~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/smart/medium.py

  • Committer: Jelmer Vernooij
  • Date: 2011-10-14 13:56:45 UTC
  • mfrom: (6215 +trunk)
  • mto: This revision was merged to the branch mainline in revision 6216.
  • Revision ID: jelmer@samba.org-20111014135645-phc3q3y21k2ks0s2
Merge bzr.dev.

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2006 Canonical Ltd
 
1
# Copyright (C) 2006-2011 Canonical Ltd
2
2
#
3
3
# This program is free software; you can redistribute it and/or modify
4
4
# it under the terms of the GNU General Public License as published by
26
26
 
27
27
import errno
28
28
import os
29
 
import socket
30
29
import sys
 
30
import time
31
31
import urllib
32
32
 
 
33
import bzrlib
33
34
from bzrlib.lazy_import import lazy_import
34
35
lazy_import(globals(), """
35
 
import atexit
 
36
import select
 
37
import socket
 
38
import thread
36
39
import weakref
 
40
 
37
41
from bzrlib import (
38
42
    debug,
39
43
    errors,
40
 
    osutils,
41
 
    symbol_versioning,
42
44
    trace,
43
45
    ui,
44
46
    urlutils,
45
47
    )
46
 
from bzrlib.smart import client, protocol, request, vfs
 
48
from bzrlib.i18n import gettext
 
49
from bzrlib.smart import client, protocol, request, signals, vfs
47
50
from bzrlib.transport import ssh
48
51
""")
49
 
 
50
 
 
51
 
# We must not read any more than 64k at a time so we don't risk "no buffer
52
 
# space available" errors on some platforms.  Windows in particular is likely
53
 
# to give error 10053 or 10055 if we read more than 64k from a socket.
54
 
_MAX_READ_SIZE = 64 * 1024
55
 
 
 
52
from bzrlib import osutils
 
53
 
 
54
# Throughout this module buffer size parameters are either limited to be at
 
55
# most _MAX_READ_SIZE, or are ignored and _MAX_READ_SIZE is used instead.
 
56
# For this module's purposes, MAX_SOCKET_CHUNK is a reasonable size for reads
 
57
# from non-sockets as well.
 
58
_MAX_READ_SIZE = osutils.MAX_SOCKET_CHUNK
56
59
 
57
60
def _get_protocol_factory_for_bytes(bytes):
58
61
    """Determine the right protocol factory for 'bytes'.
176
179
        ui.ui_factory.report_transport_activity(self, bytes, direction)
177
180
 
178
181
 
 
182
_bad_file_descriptor = (errno.EBADF,)
 
183
if sys.platform == 'win32':
 
184
    # Given on Windows if you pass a closed socket to select.select. Probably
 
185
    # also given if you pass a file handle to select.
 
186
    WSAENOTSOCK = 10038
 
187
    _bad_file_descriptor += (WSAENOTSOCK,)
 
188
 
 
189
 
179
190
class SmartServerStreamMedium(SmartMedium):
180
191
    """Handles smart commands coming over a stream.
181
192
 
194
205
        the stream.  See also the _push_back method.
195
206
    """
196
207
 
197
 
    def __init__(self, backing_transport, root_client_path='/'):
 
208
    _timer = time.time
 
209
 
 
210
    def __init__(self, backing_transport, root_client_path='/', timeout=None):
198
211
        """Construct new server.
199
212
 
200
213
        :param backing_transport: Transport for the directory served.
203
216
        self.backing_transport = backing_transport
204
217
        self.root_client_path = root_client_path
205
218
        self.finished = False
 
219
        if timeout is None:
 
220
            raise AssertionError('You must supply a timeout.')
 
221
        self._client_timeout = timeout
 
222
        self._client_poll_timeout = min(timeout / 10.0, 1.0)
206
223
        SmartMedium.__init__(self)
207
224
 
208
225
    def serve(self):
214
231
            while not self.finished:
215
232
                server_protocol = self._build_protocol()
216
233
                self._serve_one_request(server_protocol)
 
234
        except errors.ConnectionTimeout, e:
 
235
            trace.note('%s' % (e,))
 
236
            trace.log_exception_quietly()
 
237
            self._disconnect_client()
 
238
            # We reported it, no reason to make a big fuss.
 
239
            return
217
240
        except Exception, e:
218
241
            stderr.write("%s terminating on exception %s\n" % (self, e))
219
242
            raise
 
243
        self._disconnect_client()
 
244
 
 
245
    def _stop_gracefully(self):
 
246
        """When we finish this message, stop looking for more."""
 
247
        trace.mutter('Stopping %s' % (self,))
 
248
        self.finished = True
 
249
 
 
250
    def _disconnect_client(self):
 
251
        """Close the current connection. We stopped due to a timeout/etc."""
 
252
        # The default implementation is a no-op, because that is all we used to
 
253
        # do when disconnecting from a client. I suppose we never had the
 
254
        # *server* initiate a disconnect, before
 
255
 
 
256
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
 
257
        """Wait for more bytes to be read, but timeout if none available.
 
258
 
 
259
        This allows us to detect idle connections, and stop trying to read from
 
260
        them, without setting the socket itself to non-blocking. This also
 
261
        allows us to specify when we watch for idle timeouts.
 
262
 
 
263
        :return: Did we timeout? (True if we timed out, False if there is data
 
264
            to be read)
 
265
        """
 
266
        raise NotImplementedError(self._wait_for_bytes_with_timeout)
220
267
 
221
268
    def _build_protocol(self):
222
269
        """Identifies the version of the incoming request, and returns an
227
274
 
228
275
        :returns: a SmartServerRequestProtocol.
229
276
        """
 
277
        self._wait_for_bytes_with_timeout(self._client_timeout)
 
278
        if self.finished:
 
279
            # We're stopping, so don't try to do any more work
 
280
            return None
230
281
        bytes = self._get_line()
231
282
        protocol_factory, unused_bytes = _get_protocol_factory_for_bytes(bytes)
232
283
        protocol = protocol_factory(
234
285
        protocol.accept_bytes(unused_bytes)
235
286
        return protocol
236
287
 
 
288
    def _wait_on_descriptor(self, fd, timeout_seconds):
 
289
        """select() on a file descriptor, waiting for nonblocking read()
 
290
 
 
291
        This will raise a ConnectionTimeout exception if we do not get a
 
292
        readable handle before timeout_seconds.
 
293
        :return: None
 
294
        """
 
295
        t_end = self._timer() + timeout_seconds
 
296
        poll_timeout = min(timeout_seconds, self._client_poll_timeout)
 
297
        rs = xs = None
 
298
        while not rs and not xs and self._timer() < t_end:
 
299
            if self.finished:
 
300
                return
 
301
            try:
 
302
                rs, _, xs = select.select([fd], [], [fd], poll_timeout)
 
303
            except (select.error, socket.error) as e:
 
304
                err = getattr(e, 'errno', None)
 
305
                if err is None and getattr(e, 'args', None) is not None:
 
306
                    # select.error doesn't have 'errno', it just has args[0]
 
307
                    err = e.args[0]
 
308
                if err in _bad_file_descriptor:
 
309
                    return # Not a socket indicates read() will fail
 
310
                elif err == errno.EINTR:
 
311
                    # Interrupted, keep looping.
 
312
                    continue
 
313
                raise
 
314
        if rs or xs:
 
315
            return
 
316
        raise errors.ConnectionTimeout('disconnecting client after %.1f seconds'
 
317
                                       % (timeout_seconds,))
 
318
 
237
319
    def _serve_one_request(self, protocol):
238
320
        """Read one request from input, process, send back a response.
239
321
 
240
322
        :param protocol: a SmartServerRequestProtocol.
241
323
        """
 
324
        if protocol is None:
 
325
            return
242
326
        try:
243
327
            self._serve_one_request_unguarded(protocol)
244
328
        except KeyboardInterrupt:
260
344
 
261
345
class SmartServerSocketStreamMedium(SmartServerStreamMedium):
262
346
 
263
 
    def __init__(self, sock, backing_transport, root_client_path='/'):
 
347
    def __init__(self, sock, backing_transport, root_client_path='/',
 
348
                 timeout=None):
264
349
        """Constructor.
265
350
 
266
351
        :param sock: the socket the server will read from.  It will be put
267
352
            into blocking mode.
268
353
        """
269
354
        SmartServerStreamMedium.__init__(
270
 
            self, backing_transport, root_client_path=root_client_path)
 
355
            self, backing_transport, root_client_path=root_client_path,
 
356
            timeout=timeout)
271
357
        sock.setblocking(True)
272
358
        self.socket = sock
 
359
        # Get the getpeername now, as we might be closed later when we care.
 
360
        try:
 
361
            self._client_info = sock.getpeername()
 
362
        except socket.error:
 
363
            self._client_info = '<unknown>'
 
364
 
 
365
    def __str__(self):
 
366
        return '%s(client=%s)' % (self.__class__.__name__, self._client_info)
 
367
 
 
368
    def __repr__(self):
 
369
        return '%s.%s(client=%s)' % (self.__module__, self.__class__.__name__,
 
370
            self._client_info)
273
371
 
274
372
    def _serve_one_request_unguarded(self, protocol):
275
373
        while protocol.next_read_size():
276
374
            # We can safely try to read large chunks.  If there is less data
277
 
            # than _MAX_READ_SIZE ready, the socket wil just return a short
278
 
            # read immediately rather than block.
279
 
            bytes = self.read_bytes(_MAX_READ_SIZE)
 
375
            # than MAX_SOCKET_CHUNK ready, the socket will just return a
 
376
            # short read immediately rather than block.
 
377
            bytes = self.read_bytes(osutils.MAX_SOCKET_CHUNK)
280
378
            if bytes == '':
281
379
                self.finished = True
282
380
                return
284
382
 
285
383
        self._push_back(protocol.unused_data)
286
384
 
 
385
    def _disconnect_client(self):
 
386
        """Close the current connection. We stopped due to a timeout/etc."""
 
387
        self.socket.close()
 
388
 
 
389
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
 
390
        """Wait for more bytes to be read, but timeout if none available.
 
391
 
 
392
        This allows us to detect idle connections, and stop trying to read from
 
393
        them, without setting the socket itself to non-blocking. This also
 
394
        allows us to specify when we watch for idle timeouts.
 
395
 
 
396
        :return: None, this will raise ConnectionTimeout if we time out before
 
397
            data is available.
 
398
        """
 
399
        return self._wait_on_descriptor(self.socket, timeout_seconds)
 
400
 
287
401
    def _read_bytes(self, desired_count):
288
 
        return _read_bytes_from_socket(
289
 
            self.socket.recv, desired_count, self._report_activity)
 
402
        return osutils.read_bytes_from_socket(
 
403
            self.socket, self._report_activity)
290
404
 
291
405
    def terminate_due_to_error(self):
292
406
        # TODO: This should log to a server log file, but no such thing
295
409
        self.finished = True
296
410
 
297
411
    def _write_out(self, bytes):
 
412
        tstart = osutils.timer_func()
298
413
        osutils.send_all(self.socket, bytes, self._report_activity)
 
414
        if 'hpss' in debug.debug_flags:
 
415
            thread_id = thread.get_ident()
 
416
            trace.mutter('%12s: [%s] %d bytes to the socket in %.3fs'
 
417
                         % ('wrote', thread_id, len(bytes),
 
418
                            osutils.timer_func() - tstart))
299
419
 
300
420
 
301
421
class SmartServerPipeStreamMedium(SmartServerStreamMedium):
302
422
 
303
 
    def __init__(self, in_file, out_file, backing_transport):
 
423
    def __init__(self, in_file, out_file, backing_transport, timeout=None):
304
424
        """Construct new server.
305
425
 
306
426
        :param in_file: Python file from which requests can be read.
307
427
        :param out_file: Python file to write responses.
308
428
        :param backing_transport: Transport for the directory served.
309
429
        """
310
 
        SmartServerStreamMedium.__init__(self, backing_transport)
 
430
        SmartServerStreamMedium.__init__(self, backing_transport,
 
431
            timeout=timeout)
311
432
        if sys.platform == 'win32':
312
433
            # force binary mode for files
313
434
            import msvcrt
318
439
        self._in = in_file
319
440
        self._out = out_file
320
441
 
 
442
    def serve(self):
 
443
        """See SmartServerStreamMedium.serve"""
 
444
        # This is the regular serve, except it adds signal trapping for soft
 
445
        # shutdown.
 
446
        stop_gracefully = self._stop_gracefully
 
447
        signals.register_on_hangup(id(self), stop_gracefully)
 
448
        try:
 
449
            return super(SmartServerPipeStreamMedium, self).serve()
 
450
        finally:
 
451
            signals.unregister_on_hangup(id(self))
 
452
 
321
453
    def _serve_one_request_unguarded(self, protocol):
322
454
        while True:
323
455
            # We need to be careful not to read past the end of the current
336
468
                return
337
469
            protocol.accept_bytes(bytes)
338
470
 
 
471
    def _disconnect_client(self):
 
472
        self._in.close()
 
473
        self._out.flush()
 
474
        self._out.close()
 
475
 
 
476
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
 
477
        """Wait for more bytes to be read, but timeout if none available.
 
478
 
 
479
        This allows us to detect idle connections, and stop trying to read from
 
480
        them, without setting the socket itself to non-blocking. This also
 
481
        allows us to specify when we watch for idle timeouts.
 
482
 
 
483
        :return: None, this will raise ConnectionTimeout if we time out before
 
484
            data is available.
 
485
        """
 
486
        if (getattr(self._in, 'fileno', None) is None
 
487
            or sys.platform == 'win32'):
 
488
            # You can't select() file descriptors on Windows.
 
489
            return
 
490
        return self._wait_on_descriptor(self._in, timeout_seconds)
 
491
 
339
492
    def _read_bytes(self, desired_count):
340
493
        return self._in.read(desired_count)
341
494
 
485
638
        return self._medium._get_line()
486
639
 
487
640
 
 
641
class _VfsRefuser(object):
 
642
    """An object that refuses all VFS requests.
 
643
 
 
644
    """
 
645
 
 
646
    def __init__(self):
 
647
        client._SmartClient.hooks.install_named_hook(
 
648
            'call', self.check_vfs, 'vfs refuser')
 
649
 
 
650
    def check_vfs(self, params):
 
651
        try:
 
652
            request_method = request.request_handlers.get(params.method)
 
653
        except KeyError:
 
654
            # A method we don't know about doesn't count as a VFS method.
 
655
            return
 
656
        if issubclass(request_method, vfs.VfsRequest):
 
657
            raise errors.HpssVfsRequestNotAllowed(params.method, params.args)
 
658
 
 
659
 
488
660
class _DebugCounter(object):
489
661
    """An object that counts the HPSS calls made to each client medium.
490
662
 
491
 
    When a medium is garbage-collected, or failing that when atexit functions
492
 
    are run, the total number of calls made on that medium are reported via
493
 
    trace.note.
 
663
    When a medium is garbage-collected, or failing that when
 
664
    bzrlib.global_state exits, the total number of calls made on that medium
 
665
    are reported via trace.note.
494
666
    """
495
667
 
496
668
    def __init__(self):
497
669
        self.counts = weakref.WeakKeyDictionary()
498
670
        client._SmartClient.hooks.install_named_hook(
499
671
            'call', self.increment_call_count, 'hpss call counter')
500
 
        atexit.register(self.flush_all)
 
672
        bzrlib.global_state.cleanups.add_cleanup(self.flush_all)
501
673
 
502
674
    def track(self, medium):
503
675
        """Start tracking calls made to a medium.
537
709
        value['count'] = 0
538
710
        value['vfs_count'] = 0
539
711
        if count != 0:
540
 
            trace.note('HPSS calls: %d (%d vfs) %s',
541
 
                       count, vfs_count, medium_repr)
 
712
            trace.note(gettext('HPSS calls: {0} ({1} vfs) {2}').format(
 
713
                       count, vfs_count, medium_repr))
542
714
 
543
715
    def flush_all(self):
544
716
        for ref in list(self.counts.keys()):
545
717
            self.done(ref)
546
718
 
547
719
_debug_counter = None
 
720
_vfs_refuser = None
548
721
 
549
722
 
550
723
class SmartClientMedium(SmartMedium):
567
740
            if _debug_counter is None:
568
741
                _debug_counter = _DebugCounter()
569
742
            _debug_counter.track(self)
 
743
        if 'hpss_client_no_vfs' in debug.debug_flags:
 
744
            global _vfs_refuser
 
745
            if _vfs_refuser is None:
 
746
                _vfs_refuser = _VfsRefuser()
570
747
 
571
748
    def _is_remote_before(self, version_tuple):
572
749
        """Is it possible the remote side supports RPCs for a given version?
601
778
            # which is newer than a previously supplied older-than version.
602
779
            # This indicates that some smart verb call is not guarded
603
780
            # appropriately (it should simply not have been tried).
604
 
            raise AssertionError(
 
781
            trace.mutter(
605
782
                "_remember_remote_is_before(%r) called, but "
606
783
                "_remember_remote_is_before(%r) was called previously."
607
 
                % (version_tuple, self._remote_version_is_before))
 
784
                , version_tuple, self._remote_version_is_before)
 
785
            if 'hpss' in debug.debug_flags:
 
786
                ui.ui_factory.show_warning(
 
787
                    "_remember_remote_is_before(%r) called, but "
 
788
                    "_remember_remote_is_before(%r) was called previously."
 
789
                    % (version_tuple, self._remote_version_is_before))
 
790
            return
608
791
        self._remote_version_is_before = version_tuple
609
792
 
610
793
    def protocol_version(self):
721
904
 
722
905
    def _read_bytes(self, count):
723
906
        """See SmartClientStreamMedium._read_bytes."""
724
 
        bytes = self._readable_pipe.read(count)
 
907
        bytes_to_read = min(count, _MAX_READ_SIZE)
 
908
        bytes = self._readable_pipe.read(bytes_to_read)
725
909
        self._report_activity(len(bytes), 'read')
726
910
        return bytes
727
911
 
728
912
 
 
913
class SSHParams(object):
 
914
    """A set of parameters for starting a remote bzr via SSH."""
 
915
 
 
916
    def __init__(self, host, port=None, username=None, password=None,
 
917
            bzr_remote_path='bzr'):
 
918
        self.host = host
 
919
        self.port = port
 
920
        self.username = username
 
921
        self.password = password
 
922
        self.bzr_remote_path = bzr_remote_path
 
923
 
 
924
 
729
925
class SmartSSHClientMedium(SmartClientStreamMedium):
730
 
    """A client medium using SSH."""
 
926
    """A client medium using SSH.
 
927
    
 
928
    It delegates IO to a SmartClientSocketMedium or
 
929
    SmartClientAlreadyConnectedSocketMedium (depending on platform).
 
930
    """
731
931
 
732
 
    def __init__(self, host, port=None, username=None, password=None,
733
 
            base=None, vendor=None, bzr_remote_path=None):
 
932
    def __init__(self, base, ssh_params, vendor=None):
734
933
        """Creates a client that will connect on the first use.
735
934
 
 
935
        :param ssh_params: A SSHParams instance.
736
936
        :param vendor: An optional override for the ssh vendor to use. See
737
937
            bzrlib.transport.ssh for details on ssh vendors.
738
938
        """
739
 
        self._connected = False
740
 
        self._host = host
741
 
        self._password = password
742
 
        self._port = port
743
 
        self._username = username
 
939
        self._real_medium = None
 
940
        self._ssh_params = ssh_params
 
941
        # for the benefit of progress making a short description of this
 
942
        # transport
 
943
        self._scheme = 'bzr+ssh'
744
944
        # SmartClientStreamMedium stores the repr of this object in its
745
945
        # _DebugCounter so we have to store all the values used in our repr
746
946
        # method before calling the super init.
747
947
        SmartClientStreamMedium.__init__(self, base)
748
 
        self._read_from = None
 
948
        self._vendor = vendor
749
949
        self._ssh_connection = None
750
 
        self._vendor = vendor
751
 
        self._write_to = None
752
 
        self._bzr_remote_path = bzr_remote_path
753
 
        # for the benefit of progress making a short description of this
754
 
        # transport
755
 
        self._scheme = 'bzr+ssh'
756
950
 
757
951
    def __repr__(self):
758
 
        return "%s(connected=%r, username=%r, host=%r, port=%r)" % (
 
952
        if self._ssh_params.port is None:
 
953
            maybe_port = ''
 
954
        else:
 
955
            maybe_port = ':%s' % self._ssh_params.port
 
956
        return "%s(%s://%s@%s%s/)" % (
759
957
            self.__class__.__name__,
760
 
            self._connected,
761
 
            self._username,
762
 
            self._host,
763
 
            self._port)
 
958
            self._scheme,
 
959
            self._ssh_params.username,
 
960
            self._ssh_params.host,
 
961
            maybe_port)
764
962
 
765
963
    def _accept_bytes(self, bytes):
766
964
        """See SmartClientStreamMedium.accept_bytes."""
767
965
        self._ensure_connection()
768
 
        self._write_to.write(bytes)
769
 
        self._report_activity(len(bytes), 'write')
 
966
        self._real_medium.accept_bytes(bytes)
770
967
 
771
968
    def disconnect(self):
772
969
        """See SmartClientMedium.disconnect()."""
773
 
        if not self._connected:
774
 
            return
775
 
        self._read_from.close()
776
 
        self._write_to.close()
777
 
        self._ssh_connection.close()
778
 
        self._connected = False
 
970
        if self._real_medium is not None:
 
971
            self._real_medium.disconnect()
 
972
            self._real_medium = None
 
973
        if self._ssh_connection is not None:
 
974
            self._ssh_connection.close()
 
975
            self._ssh_connection = None
779
976
 
780
977
    def _ensure_connection(self):
781
978
        """Connect this medium if not already connected."""
782
 
        if self._connected:
 
979
        if self._real_medium is not None:
783
980
            return
784
981
        if self._vendor is None:
785
982
            vendor = ssh._get_ssh_vendor()
786
983
        else:
787
984
            vendor = self._vendor
788
 
        self._ssh_connection = vendor.connect_ssh(self._username,
789
 
                self._password, self._host, self._port,
790
 
                command=[self._bzr_remote_path, 'serve', '--inet',
 
985
        self._ssh_connection = vendor.connect_ssh(self._ssh_params.username,
 
986
                self._ssh_params.password, self._ssh_params.host,
 
987
                self._ssh_params.port,
 
988
                command=[self._ssh_params.bzr_remote_path, 'serve', '--inet',
791
989
                         '--directory=/', '--allow-writes'])
792
 
        self._read_from, self._write_to = \
793
 
            self._ssh_connection.get_filelike_channels()
794
 
        self._connected = True
 
990
        io_kind, io_object = self._ssh_connection.get_sock_or_pipes()
 
991
        if io_kind == 'socket':
 
992
            self._real_medium = SmartClientAlreadyConnectedSocketMedium(
 
993
                self.base, io_object)
 
994
        elif io_kind == 'pipes':
 
995
            read_from, write_to = io_object
 
996
            self._real_medium = SmartSimplePipesClientMedium(
 
997
                read_from, write_to, self.base)
 
998
        else:
 
999
            raise AssertionError(
 
1000
                "Unexpected io_kind %r from %r"
 
1001
                % (io_kind, self._ssh_connection))
795
1002
 
796
1003
    def _flush(self):
797
1004
        """See SmartClientStreamMedium._flush()."""
798
 
        self._write_to.flush()
 
1005
        self._real_medium._flush()
799
1006
 
800
1007
    def _read_bytes(self, count):
801
1008
        """See SmartClientStreamMedium.read_bytes."""
802
 
        if not self._connected:
 
1009
        if self._real_medium is None:
803
1010
            raise errors.MediumNotConnected(self)
804
 
        bytes_to_read = min(count, _MAX_READ_SIZE)
805
 
        bytes = self._read_from.read(bytes_to_read)
806
 
        self._report_activity(len(bytes), 'read')
807
 
        return bytes
 
1011
        return self._real_medium.read_bytes(count)
808
1012
 
809
1013
 
810
1014
# Port 4155 is the default port for bzr://, registered with IANA.
812
1016
BZR_DEFAULT_PORT = 4155
813
1017
 
814
1018
 
815
 
class SmartTCPClientMedium(SmartClientStreamMedium):
816
 
    """A client medium using TCP."""
 
1019
class SmartClientSocketMedium(SmartClientStreamMedium):
 
1020
    """A client medium using a socket.
 
1021
    
 
1022
    This class isn't usable directly.  Use one of its subclasses instead.
 
1023
    """
817
1024
 
818
 
    def __init__(self, host, port, base):
819
 
        """Creates a client that will connect on the first use."""
 
1025
    def __init__(self, base):
820
1026
        SmartClientStreamMedium.__init__(self, base)
 
1027
        self._socket = None
821
1028
        self._connected = False
822
 
        self._host = host
823
 
        self._port = port
824
 
        self._socket = None
825
1029
 
826
1030
    def _accept_bytes(self, bytes):
827
1031
        """See SmartClientMedium.accept_bytes."""
828
1032
        self._ensure_connection()
829
1033
        osutils.send_all(self._socket, bytes, self._report_activity)
830
1034
 
 
1035
    def _ensure_connection(self):
 
1036
        """Connect this medium if not already connected."""
 
1037
        raise NotImplementedError(self._ensure_connection)
 
1038
 
 
1039
    def _flush(self):
 
1040
        """See SmartClientStreamMedium._flush().
 
1041
 
 
1042
        For sockets we do no flushing. For TCP sockets we may want to turn off
 
1043
        TCP_NODELAY and add a means to do a flush, but that can be done in the
 
1044
        future.
 
1045
        """
 
1046
 
 
1047
    def _read_bytes(self, count):
 
1048
        """See SmartClientMedium.read_bytes."""
 
1049
        if not self._connected:
 
1050
            raise errors.MediumNotConnected(self)
 
1051
        return osutils.read_bytes_from_socket(
 
1052
            self._socket, self._report_activity)
 
1053
 
831
1054
    def disconnect(self):
832
1055
        """See SmartClientMedium.disconnect()."""
833
1056
        if not self._connected:
836
1059
        self._socket = None
837
1060
        self._connected = False
838
1061
 
 
1062
 
 
1063
class SmartTCPClientMedium(SmartClientSocketMedium):
 
1064
    """A client medium that creates a TCP connection."""
 
1065
 
 
1066
    def __init__(self, host, port, base):
 
1067
        """Creates a client that will connect on the first use."""
 
1068
        SmartClientSocketMedium.__init__(self, base)
 
1069
        self._host = host
 
1070
        self._port = port
 
1071
 
839
1072
    def _ensure_connection(self):
840
1073
        """Connect this medium if not already connected."""
841
1074
        if self._connected:
875
1108
                    (self._host, port, err_msg))
876
1109
        self._connected = True
877
1110
 
878
 
    def _flush(self):
879
 
        """See SmartClientStreamMedium._flush().
880
 
 
881
 
        For TCP we do no flushing. We may want to turn off TCP_NODELAY and
882
 
        add a means to do a flush, but that can be done in the future.
883
 
        """
884
 
 
885
 
    def _read_bytes(self, count):
886
 
        """See SmartClientMedium.read_bytes."""
887
 
        if not self._connected:
888
 
            raise errors.MediumNotConnected(self)
889
 
        return _read_bytes_from_socket(
890
 
            self._socket.recv, count, self._report_activity)
 
1111
 
 
1112
class SmartClientAlreadyConnectedSocketMedium(SmartClientSocketMedium):
 
1113
    """A client medium for an already connected socket.
 
1114
    
 
1115
    Note that this class will assume it "owns" the socket, so it will close it
 
1116
    when its disconnect method is called.
 
1117
    """
 
1118
 
 
1119
    def __init__(self, base, sock):
 
1120
        SmartClientSocketMedium.__init__(self, base)
 
1121
        self._socket = sock
 
1122
        self._connected = True
 
1123
 
 
1124
    def _ensure_connection(self):
 
1125
        # Already connected, by definition!  So nothing to do.
 
1126
        pass
891
1127
 
892
1128
 
893
1129
class SmartClientStreamMediumRequest(SmartClientMediumRequest):
930
1166
        self._medium._flush()
931
1167
 
932
1168
 
933
 
def _read_bytes_from_socket(sock, desired_count, report_activity):
934
 
    # We ignore the desired_count because on sockets it's more efficient to
935
 
    # read large chunks (of _MAX_READ_SIZE bytes) at a time.
936
 
    try:
937
 
        bytes = osutils.until_no_eintr(sock, _MAX_READ_SIZE)
938
 
    except socket.error, e:
939
 
        if len(e.args) and e.args[0] in (errno.ECONNRESET, 10054):
940
 
            # The connection was closed by the other side.  Callers expect an
941
 
            # empty string to signal end-of-stream.
942
 
            bytes = ''
943
 
        else:
944
 
            raise
945
 
    else:
946
 
        report_activity(len(bytes), 'read')
947
 
    return bytes
948