~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/smart/medium.py

  • Committer: Jelmer Vernooij
  • Date: 2011-12-19 13:23:58 UTC
  • mto: This revision was merged to the branch mainline in revision 6386.
  • Revision ID: jelmer@canonical.com-20111219132358-uvs5a6y92gomzacd
Move importing from future until after doc string, otherwise the doc string will disappear.

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2006-2010 Canonical Ltd
 
1
# Copyright (C) 2006-2011 Canonical Ltd
2
2
#
3
3
# This program is free software; you can redistribute it and/or modify
4
4
# it under the terms of the GNU General Public License as published by
24
24
bzrlib/transport/smart/__init__.py.
25
25
"""
26
26
 
 
27
from __future__ import absolute_import
 
28
 
 
29
import errno
27
30
import os
28
31
import sys
29
 
import urllib
 
32
import time
30
33
 
 
34
import bzrlib
31
35
from bzrlib.lazy_import import lazy_import
32
36
lazy_import(globals(), """
33
 
import atexit
 
37
import select
34
38
import socket
35
39
import thread
36
40
import weakref
38
42
from bzrlib import (
39
43
    debug,
40
44
    errors,
41
 
    symbol_versioning,
42
45
    trace,
43
46
    ui,
44
47
    urlutils,
45
48
    )
46
 
from bzrlib.smart import client, protocol, request, vfs
 
49
from bzrlib.i18n import gettext
 
50
from bzrlib.smart import client, protocol, request, signals, vfs
47
51
from bzrlib.transport import ssh
48
52
""")
49
53
from bzrlib import osutils
176
180
        ui.ui_factory.report_transport_activity(self, bytes, direction)
177
181
 
178
182
 
 
183
_bad_file_descriptor = (errno.EBADF,)
 
184
if sys.platform == 'win32':
 
185
    # Given on Windows if you pass a closed socket to select.select. Probably
 
186
    # also given if you pass a file handle to select.
 
187
    WSAENOTSOCK = 10038
 
188
    _bad_file_descriptor += (WSAENOTSOCK,)
 
189
 
 
190
 
179
191
class SmartServerStreamMedium(SmartMedium):
180
192
    """Handles smart commands coming over a stream.
181
193
 
194
206
        the stream.  See also the _push_back method.
195
207
    """
196
208
 
197
 
    def __init__(self, backing_transport, root_client_path='/'):
 
209
    _timer = time.time
 
210
 
 
211
    def __init__(self, backing_transport, root_client_path='/', timeout=None):
198
212
        """Construct new server.
199
213
 
200
214
        :param backing_transport: Transport for the directory served.
203
217
        self.backing_transport = backing_transport
204
218
        self.root_client_path = root_client_path
205
219
        self.finished = False
 
220
        if timeout is None:
 
221
            raise AssertionError('You must supply a timeout.')
 
222
        self._client_timeout = timeout
 
223
        self._client_poll_timeout = min(timeout / 10.0, 1.0)
206
224
        SmartMedium.__init__(self)
207
225
 
208
226
    def serve(self):
214
232
            while not self.finished:
215
233
                server_protocol = self._build_protocol()
216
234
                self._serve_one_request(server_protocol)
 
235
        except errors.ConnectionTimeout, e:
 
236
            trace.note('%s' % (e,))
 
237
            trace.log_exception_quietly()
 
238
            self._disconnect_client()
 
239
            # We reported it, no reason to make a big fuss.
 
240
            return
217
241
        except Exception, e:
218
242
            stderr.write("%s terminating on exception %s\n" % (self, e))
219
243
            raise
 
244
        self._disconnect_client()
 
245
 
 
246
    def _stop_gracefully(self):
 
247
        """When we finish this message, stop looking for more."""
 
248
        trace.mutter('Stopping %s' % (self,))
 
249
        self.finished = True
 
250
 
 
251
    def _disconnect_client(self):
 
252
        """Close the current connection. We stopped due to a timeout/etc."""
 
253
        # The default implementation is a no-op, because that is all we used to
 
254
        # do when disconnecting from a client. I suppose we never had the
 
255
        # *server* initiate a disconnect, before
 
256
 
 
257
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
 
258
        """Wait for more bytes to be read, but timeout if none available.
 
259
 
 
260
        This allows us to detect idle connections, and stop trying to read from
 
261
        them, without setting the socket itself to non-blocking. This also
 
262
        allows us to specify when we watch for idle timeouts.
 
263
 
 
264
        :return: Did we timeout? (True if we timed out, False if there is data
 
265
            to be read)
 
266
        """
 
267
        raise NotImplementedError(self._wait_for_bytes_with_timeout)
220
268
 
221
269
    def _build_protocol(self):
222
270
        """Identifies the version of the incoming request, and returns an
227
275
 
228
276
        :returns: a SmartServerRequestProtocol.
229
277
        """
 
278
        self._wait_for_bytes_with_timeout(self._client_timeout)
 
279
        if self.finished:
 
280
            # We're stopping, so don't try to do any more work
 
281
            return None
230
282
        bytes = self._get_line()
231
283
        protocol_factory, unused_bytes = _get_protocol_factory_for_bytes(bytes)
232
284
        protocol = protocol_factory(
234
286
        protocol.accept_bytes(unused_bytes)
235
287
        return protocol
236
288
 
 
289
    def _wait_on_descriptor(self, fd, timeout_seconds):
 
290
        """select() on a file descriptor, waiting for nonblocking read()
 
291
 
 
292
        This will raise a ConnectionTimeout exception if we do not get a
 
293
        readable handle before timeout_seconds.
 
294
        :return: None
 
295
        """
 
296
        t_end = self._timer() + timeout_seconds
 
297
        poll_timeout = min(timeout_seconds, self._client_poll_timeout)
 
298
        rs = xs = None
 
299
        while not rs and not xs and self._timer() < t_end:
 
300
            if self.finished:
 
301
                return
 
302
            try:
 
303
                rs, _, xs = select.select([fd], [], [fd], poll_timeout)
 
304
            except (select.error, socket.error) as e:
 
305
                err = getattr(e, 'errno', None)
 
306
                if err is None and getattr(e, 'args', None) is not None:
 
307
                    # select.error doesn't have 'errno', it just has args[0]
 
308
                    err = e.args[0]
 
309
                if err in _bad_file_descriptor:
 
310
                    return # Not a socket indicates read() will fail
 
311
                elif err == errno.EINTR:
 
312
                    # Interrupted, keep looping.
 
313
                    continue
 
314
                raise
 
315
        if rs or xs:
 
316
            return
 
317
        raise errors.ConnectionTimeout('disconnecting client after %.1f seconds'
 
318
                                       % (timeout_seconds,))
 
319
 
237
320
    def _serve_one_request(self, protocol):
238
321
        """Read one request from input, process, send back a response.
239
322
 
240
323
        :param protocol: a SmartServerRequestProtocol.
241
324
        """
 
325
        if protocol is None:
 
326
            return
242
327
        try:
243
328
            self._serve_one_request_unguarded(protocol)
244
329
        except KeyboardInterrupt:
260
345
 
261
346
class SmartServerSocketStreamMedium(SmartServerStreamMedium):
262
347
 
263
 
    def __init__(self, sock, backing_transport, root_client_path='/'):
 
348
    def __init__(self, sock, backing_transport, root_client_path='/',
 
349
                 timeout=None):
264
350
        """Constructor.
265
351
 
266
352
        :param sock: the socket the server will read from.  It will be put
267
353
            into blocking mode.
268
354
        """
269
355
        SmartServerStreamMedium.__init__(
270
 
            self, backing_transport, root_client_path=root_client_path)
 
356
            self, backing_transport, root_client_path=root_client_path,
 
357
            timeout=timeout)
271
358
        sock.setblocking(True)
272
359
        self.socket = sock
 
360
        # Get the getpeername now, as we might be closed later when we care.
 
361
        try:
 
362
            self._client_info = sock.getpeername()
 
363
        except socket.error:
 
364
            self._client_info = '<unknown>'
 
365
 
 
366
    def __str__(self):
 
367
        return '%s(client=%s)' % (self.__class__.__name__, self._client_info)
 
368
 
 
369
    def __repr__(self):
 
370
        return '%s.%s(client=%s)' % (self.__module__, self.__class__.__name__,
 
371
            self._client_info)
273
372
 
274
373
    def _serve_one_request_unguarded(self, protocol):
275
374
        while protocol.next_read_size():
284
383
 
285
384
        self._push_back(protocol.unused_data)
286
385
 
 
386
    def _disconnect_client(self):
 
387
        """Close the current connection. We stopped due to a timeout/etc."""
 
388
        self.socket.close()
 
389
 
 
390
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
 
391
        """Wait for more bytes to be read, but timeout if none available.
 
392
 
 
393
        This allows us to detect idle connections, and stop trying to read from
 
394
        them, without setting the socket itself to non-blocking. This also
 
395
        allows us to specify when we watch for idle timeouts.
 
396
 
 
397
        :return: None, this will raise ConnectionTimeout if we time out before
 
398
            data is available.
 
399
        """
 
400
        return self._wait_on_descriptor(self.socket, timeout_seconds)
 
401
 
287
402
    def _read_bytes(self, desired_count):
288
403
        return osutils.read_bytes_from_socket(
289
404
            self.socket, self._report_activity)
306
421
 
307
422
class SmartServerPipeStreamMedium(SmartServerStreamMedium):
308
423
 
309
 
    def __init__(self, in_file, out_file, backing_transport):
 
424
    def __init__(self, in_file, out_file, backing_transport, timeout=None):
310
425
        """Construct new server.
311
426
 
312
427
        :param in_file: Python file from which requests can be read.
313
428
        :param out_file: Python file to write responses.
314
429
        :param backing_transport: Transport for the directory served.
315
430
        """
316
 
        SmartServerStreamMedium.__init__(self, backing_transport)
 
431
        SmartServerStreamMedium.__init__(self, backing_transport,
 
432
            timeout=timeout)
317
433
        if sys.platform == 'win32':
318
434
            # force binary mode for files
319
435
            import msvcrt
324
440
        self._in = in_file
325
441
        self._out = out_file
326
442
 
 
443
    def serve(self):
 
444
        """See SmartServerStreamMedium.serve"""
 
445
        # This is the regular serve, except it adds signal trapping for soft
 
446
        # shutdown.
 
447
        stop_gracefully = self._stop_gracefully
 
448
        signals.register_on_hangup(id(self), stop_gracefully)
 
449
        try:
 
450
            return super(SmartServerPipeStreamMedium, self).serve()
 
451
        finally:
 
452
            signals.unregister_on_hangup(id(self))
 
453
 
327
454
    def _serve_one_request_unguarded(self, protocol):
328
455
        while True:
329
456
            # We need to be careful not to read past the end of the current
342
469
                return
343
470
            protocol.accept_bytes(bytes)
344
471
 
 
472
    def _disconnect_client(self):
 
473
        self._in.close()
 
474
        self._out.flush()
 
475
        self._out.close()
 
476
 
 
477
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
 
478
        """Wait for more bytes to be read, but timeout if none available.
 
479
 
 
480
        This allows us to detect idle connections, and stop trying to read from
 
481
        them, without setting the socket itself to non-blocking. This also
 
482
        allows us to specify when we watch for idle timeouts.
 
483
 
 
484
        :return: None, this will raise ConnectionTimeout if we time out before
 
485
            data is available.
 
486
        """
 
487
        if (getattr(self._in, 'fileno', None) is None
 
488
            or sys.platform == 'win32'):
 
489
            # You can't select() file descriptors on Windows.
 
490
            return
 
491
        return self._wait_on_descriptor(self._in, timeout_seconds)
 
492
 
345
493
    def _read_bytes(self, desired_count):
346
494
        return self._in.read(desired_count)
347
495
 
491
639
        return self._medium._get_line()
492
640
 
493
641
 
 
642
class _VfsRefuser(object):
 
643
    """An object that refuses all VFS requests.
 
644
 
 
645
    """
 
646
 
 
647
    def __init__(self):
 
648
        client._SmartClient.hooks.install_named_hook(
 
649
            'call', self.check_vfs, 'vfs refuser')
 
650
 
 
651
    def check_vfs(self, params):
 
652
        try:
 
653
            request_method = request.request_handlers.get(params.method)
 
654
        except KeyError:
 
655
            # A method we don't know about doesn't count as a VFS method.
 
656
            return
 
657
        if issubclass(request_method, vfs.VfsRequest):
 
658
            raise errors.HpssVfsRequestNotAllowed(params.method, params.args)
 
659
 
 
660
 
494
661
class _DebugCounter(object):
495
662
    """An object that counts the HPSS calls made to each client medium.
496
663
 
497
 
    When a medium is garbage-collected, or failing that when atexit functions
498
 
    are run, the total number of calls made on that medium are reported via
499
 
    trace.note.
 
664
    When a medium is garbage-collected, or failing that when
 
665
    bzrlib.global_state exits, the total number of calls made on that medium
 
666
    are reported via trace.note.
500
667
    """
501
668
 
502
669
    def __init__(self):
503
670
        self.counts = weakref.WeakKeyDictionary()
504
671
        client._SmartClient.hooks.install_named_hook(
505
672
            'call', self.increment_call_count, 'hpss call counter')
506
 
        atexit.register(self.flush_all)
 
673
        bzrlib.global_state.cleanups.add_cleanup(self.flush_all)
507
674
 
508
675
    def track(self, medium):
509
676
        """Start tracking calls made to a medium.
543
710
        value['count'] = 0
544
711
        value['vfs_count'] = 0
545
712
        if count != 0:
546
 
            trace.note('HPSS calls: %d (%d vfs) %s',
547
 
                       count, vfs_count, medium_repr)
 
713
            trace.note(gettext('HPSS calls: {0} ({1} vfs) {2}').format(
 
714
                       count, vfs_count, medium_repr))
548
715
 
549
716
    def flush_all(self):
550
717
        for ref in list(self.counts.keys()):
551
718
            self.done(ref)
552
719
 
553
720
_debug_counter = None
 
721
_vfs_refuser = None
554
722
 
555
723
 
556
724
class SmartClientMedium(SmartMedium):
573
741
            if _debug_counter is None:
574
742
                _debug_counter = _DebugCounter()
575
743
            _debug_counter.track(self)
 
744
        if 'hpss_client_no_vfs' in debug.debug_flags:
 
745
            global _vfs_refuser
 
746
            if _vfs_refuser is None:
 
747
                _vfs_refuser = _VfsRefuser()
576
748
 
577
749
    def _is_remote_before(self, version_tuple):
578
750
        """Is it possible the remote side supports RPCs for a given version?
669
841
        """
670
842
        medium_base = urlutils.join(self.base, '/')
671
843
        rel_url = urlutils.relative_url(medium_base, transport.base)
672
 
        return urllib.unquote(rel_url)
 
844
        return urlutils.unquote(rel_url)
673
845
 
674
846
 
675
847
class SmartClientStreamMedium(SmartClientMedium):
710
882
        """
711
883
        return SmartClientStreamMediumRequest(self)
712
884
 
 
885
    def reset(self):
 
886
        """We have been disconnected, reset current state.
 
887
 
 
888
        This resets things like _current_request and connected state.
 
889
        """
 
890
        self.disconnect()
 
891
        self._current_request = None
 
892
 
713
893
 
714
894
class SmartSimplePipesClientMedium(SmartClientStreamMedium):
715
895
    """A client medium using simple pipes.
716
896
 
717
897
    This client does not manage the pipes: it assumes they will always be open.
718
 
 
719
 
    Note that if readable_pipe.read might raise IOError or OSError with errno
720
 
    of EINTR, it must be safe to retry the read.  Plain CPython fileobjects
721
 
    (such as used for sys.stdin) are safe.
722
898
    """
723
899
 
724
900
    def __init__(self, readable_pipe, writeable_pipe, base):
728
904
 
729
905
    def _accept_bytes(self, bytes):
730
906
        """See SmartClientStreamMedium.accept_bytes."""
731
 
        self._writeable_pipe.write(bytes)
 
907
        try:
 
908
            self._writeable_pipe.write(bytes)
 
909
        except IOError, e:
 
910
            if e.errno in (errno.EINVAL, errno.EPIPE):
 
911
                raise errors.ConnectionReset(
 
912
                    "Error trying to write to subprocess:\n%s" % (e,))
 
913
            raise
732
914
        self._report_activity(len(bytes), 'write')
733
915
 
734
916
    def _flush(self):
735
917
        """See SmartClientStreamMedium._flush()."""
 
918
        # Note: If flush were to fail, we'd like to raise ConnectionReset, etc.
 
919
        #       However, testing shows that even when the child process is
 
920
        #       gone, this doesn't error.
736
921
        self._writeable_pipe.flush()
737
922
 
738
923
    def _read_bytes(self, count):
739
924
        """See SmartClientStreamMedium._read_bytes."""
740
 
        bytes = osutils.until_no_eintr(self._readable_pipe.read, count)
 
925
        bytes_to_read = min(count, _MAX_READ_SIZE)
 
926
        bytes = self._readable_pipe.read(bytes_to_read)
741
927
        self._report_activity(len(bytes), 'read')
742
928
        return bytes
743
929
 
744
930
 
 
931
class SSHParams(object):
 
932
    """A set of parameters for starting a remote bzr via SSH."""
 
933
 
 
934
    def __init__(self, host, port=None, username=None, password=None,
 
935
            bzr_remote_path='bzr'):
 
936
        self.host = host
 
937
        self.port = port
 
938
        self.username = username
 
939
        self.password = password
 
940
        self.bzr_remote_path = bzr_remote_path
 
941
 
 
942
 
745
943
class SmartSSHClientMedium(SmartClientStreamMedium):
746
 
    """A client medium using SSH."""
747
 
 
748
 
    def __init__(self, host, port=None, username=None, password=None,
749
 
            base=None, vendor=None, bzr_remote_path=None):
 
944
    """A client medium using SSH.
 
945
 
 
946
    It delegates IO to a SmartSimplePipesClientMedium or
 
947
    SmartClientAlreadyConnectedSocketMedium (depending on platform).
 
948
    """
 
949
 
 
950
    def __init__(self, base, ssh_params, vendor=None):
750
951
        """Creates a client that will connect on the first use.
751
952
 
 
953
        :param ssh_params: A SSHParams instance.
752
954
        :param vendor: An optional override for the ssh vendor to use. See
753
955
            bzrlib.transport.ssh for details on ssh vendors.
754
956
        """
755
 
        self._connected = False
756
 
        self._host = host
757
 
        self._password = password
758
 
        self._port = port
759
 
        self._username = username
 
957
        self._real_medium = None
 
958
        self._ssh_params = ssh_params
760
959
        # for the benefit of progress making a short description of this
761
960
        # transport
762
961
        self._scheme = 'bzr+ssh'
764
963
        # _DebugCounter so we have to store all the values used in our repr
765
964
        # method before calling the super init.
766
965
        SmartClientStreamMedium.__init__(self, base)
767
 
        self._read_from = None
 
966
        self._vendor = vendor
768
967
        self._ssh_connection = None
769
 
        self._vendor = vendor
770
 
        self._write_to = None
771
 
        self._bzr_remote_path = bzr_remote_path
772
968
 
773
969
    def __repr__(self):
774
 
        if self._port is None:
 
970
        if self._ssh_params.port is None:
775
971
            maybe_port = ''
776
972
        else:
777
 
            maybe_port = ':%s' % self._port
778
 
        return "%s(%s://%s@%s%s/)" % (
 
973
            maybe_port = ':%s' % self._ssh_params.port
 
974
        if self._ssh_params.username is None:
 
975
            maybe_user = ''
 
976
        else:
 
977
            maybe_user = '%s@' % self._ssh_params.username
 
978
        return "%s(%s://%s%s%s/)" % (
779
979
            self.__class__.__name__,
780
980
            self._scheme,
781
 
            self._username,
782
 
            self._host,
 
981
            maybe_user,
 
982
            self._ssh_params.host,
783
983
            maybe_port)
784
984
 
785
985
    def _accept_bytes(self, bytes):
786
986
        """See SmartClientStreamMedium.accept_bytes."""
787
987
        self._ensure_connection()
788
 
        self._write_to.write(bytes)
789
 
        self._report_activity(len(bytes), 'write')
 
988
        self._real_medium.accept_bytes(bytes)
790
989
 
791
990
    def disconnect(self):
792
991
        """See SmartClientMedium.disconnect()."""
793
 
        if not self._connected:
794
 
            return
795
 
        self._read_from.close()
796
 
        self._write_to.close()
797
 
        self._ssh_connection.close()
798
 
        self._connected = False
 
992
        if self._real_medium is not None:
 
993
            self._real_medium.disconnect()
 
994
            self._real_medium = None
 
995
        if self._ssh_connection is not None:
 
996
            self._ssh_connection.close()
 
997
            self._ssh_connection = None
799
998
 
800
999
    def _ensure_connection(self):
801
1000
        """Connect this medium if not already connected."""
802
 
        if self._connected:
 
1001
        if self._real_medium is not None:
803
1002
            return
804
1003
        if self._vendor is None:
805
1004
            vendor = ssh._get_ssh_vendor()
806
1005
        else:
807
1006
            vendor = self._vendor
808
 
        self._ssh_connection = vendor.connect_ssh(self._username,
809
 
                self._password, self._host, self._port,
810
 
                command=[self._bzr_remote_path, 'serve', '--inet',
 
1007
        self._ssh_connection = vendor.connect_ssh(self._ssh_params.username,
 
1008
                self._ssh_params.password, self._ssh_params.host,
 
1009
                self._ssh_params.port,
 
1010
                command=[self._ssh_params.bzr_remote_path, 'serve', '--inet',
811
1011
                         '--directory=/', '--allow-writes'])
812
 
        self._read_from, self._write_to = \
813
 
            self._ssh_connection.get_filelike_channels()
814
 
        self._connected = True
 
1012
        io_kind, io_object = self._ssh_connection.get_sock_or_pipes()
 
1013
        if io_kind == 'socket':
 
1014
            self._real_medium = SmartClientAlreadyConnectedSocketMedium(
 
1015
                self.base, io_object)
 
1016
        elif io_kind == 'pipes':
 
1017
            read_from, write_to = io_object
 
1018
            self._real_medium = SmartSimplePipesClientMedium(
 
1019
                read_from, write_to, self.base)
 
1020
        else:
 
1021
            raise AssertionError(
 
1022
                "Unexpected io_kind %r from %r"
 
1023
                % (io_kind, self._ssh_connection))
815
1024
 
816
1025
    def _flush(self):
817
1026
        """See SmartClientStreamMedium._flush()."""
818
 
        self._write_to.flush()
 
1027
        self._real_medium._flush()
819
1028
 
820
1029
    def _read_bytes(self, count):
821
1030
        """See SmartClientStreamMedium.read_bytes."""
822
 
        if not self._connected:
 
1031
        if self._real_medium is None:
823
1032
            raise errors.MediumNotConnected(self)
824
 
        bytes_to_read = min(count, _MAX_READ_SIZE)
825
 
        bytes = self._read_from.read(bytes_to_read)
826
 
        self._report_activity(len(bytes), 'read')
827
 
        return bytes
 
1033
        return self._real_medium.read_bytes(count)
828
1034
 
829
1035
 
830
1036
# Port 4155 is the default port for bzr://, registered with IANA.
832
1038
BZR_DEFAULT_PORT = 4155
833
1039
 
834
1040
 
835
 
class SmartTCPClientMedium(SmartClientStreamMedium):
836
 
    """A client medium using TCP."""
 
1041
class SmartClientSocketMedium(SmartClientStreamMedium):
 
1042
    """A client medium using a socket.
 
1043
    
 
1044
    This class isn't usable directly.  Use one of its subclasses instead.
 
1045
    """
837
1046
 
838
 
    def __init__(self, host, port, base):
839
 
        """Creates a client that will connect on the first use."""
 
1047
    def __init__(self, base):
840
1048
        SmartClientStreamMedium.__init__(self, base)
 
1049
        self._socket = None
841
1050
        self._connected = False
842
 
        self._host = host
843
 
        self._port = port
844
 
        self._socket = None
845
1051
 
846
1052
    def _accept_bytes(self, bytes):
847
1053
        """See SmartClientMedium.accept_bytes."""
848
1054
        self._ensure_connection()
849
1055
        osutils.send_all(self._socket, bytes, self._report_activity)
850
1056
 
 
1057
    def _ensure_connection(self):
 
1058
        """Connect this medium if not already connected."""
 
1059
        raise NotImplementedError(self._ensure_connection)
 
1060
 
 
1061
    def _flush(self):
 
1062
        """See SmartClientStreamMedium._flush().
 
1063
 
 
1064
        For sockets we do no flushing. For TCP sockets we may want to turn off
 
1065
        TCP_NODELAY and add a means to do a flush, but that can be done in the
 
1066
        future.
 
1067
        """
 
1068
 
 
1069
    def _read_bytes(self, count):
 
1070
        """See SmartClientMedium.read_bytes."""
 
1071
        if not self._connected:
 
1072
            raise errors.MediumNotConnected(self)
 
1073
        return osutils.read_bytes_from_socket(
 
1074
            self._socket, self._report_activity)
 
1075
 
851
1076
    def disconnect(self):
852
1077
        """See SmartClientMedium.disconnect()."""
853
1078
        if not self._connected:
856
1081
        self._socket = None
857
1082
        self._connected = False
858
1083
 
 
1084
 
 
1085
class SmartTCPClientMedium(SmartClientSocketMedium):
 
1086
    """A client medium that creates a TCP connection."""
 
1087
 
 
1088
    def __init__(self, host, port, base):
 
1089
        """Creates a client that will connect on the first use."""
 
1090
        SmartClientSocketMedium.__init__(self, base)
 
1091
        self._host = host
 
1092
        self._port = port
 
1093
 
859
1094
    def _ensure_connection(self):
860
1095
        """Connect this medium if not already connected."""
861
1096
        if self._connected:
895
1130
                    (self._host, port, err_msg))
896
1131
        self._connected = True
897
1132
 
898
 
    def _flush(self):
899
 
        """See SmartClientStreamMedium._flush().
900
 
 
901
 
        For TCP we do no flushing. We may want to turn off TCP_NODELAY and
902
 
        add a means to do a flush, but that can be done in the future.
903
 
        """
904
 
 
905
 
    def _read_bytes(self, count):
906
 
        """See SmartClientMedium.read_bytes."""
907
 
        if not self._connected:
908
 
            raise errors.MediumNotConnected(self)
909
 
        return osutils.read_bytes_from_socket(
910
 
            self._socket, self._report_activity)
 
1133
 
 
1134
class SmartClientAlreadyConnectedSocketMedium(SmartClientSocketMedium):
 
1135
    """A client medium for an already connected socket.
 
1136
    
 
1137
    Note that this class will assume it "owns" the socket, so it will close it
 
1138
    when its disconnect method is called.
 
1139
    """
 
1140
 
 
1141
    def __init__(self, base, sock):
 
1142
        SmartClientSocketMedium.__init__(self, base)
 
1143
        self._socket = sock
 
1144
        self._connected = True
 
1145
 
 
1146
    def _ensure_connection(self):
 
1147
        # Already connected, by definition!  So nothing to do.
 
1148
        pass
911
1149
 
912
1150
 
913
1151
class SmartClientStreamMediumRequest(SmartClientMediumRequest):
948
1186
        This invokes self._medium._flush to ensure all bytes are transmitted.
949
1187
        """
950
1188
        self._medium._flush()
951
 
 
952