~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/smart/medium.py

  • Committer: Martin Pool
  • Date: 2010-02-25 06:17:27 UTC
  • mfrom: (5055 +trunk)
  • mto: This revision was merged to the branch mainline in revision 5057.
  • Revision ID: mbp@sourcefrog.net-20100225061727-4sd9lt0qmdc6087t
merge news

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2006-2011 Canonical Ltd
 
1
# Copyright (C) 2006-2010 Canonical Ltd
2
2
#
3
3
# This program is free software; you can redistribute it and/or modify
4
4
# it under the terms of the GNU General Public License as published by
26
26
 
27
27
import errno
28
28
import os
 
29
import socket
29
30
import sys
30
 
import time
31
31
import urllib
32
32
 
33
 
import bzrlib
34
33
from bzrlib.lazy_import import lazy_import
35
34
lazy_import(globals(), """
36
 
import select
37
 
import socket
 
35
import atexit
38
36
import thread
39
37
import weakref
40
38
 
41
39
from bzrlib import (
42
40
    debug,
43
41
    errors,
 
42
    symbol_versioning,
44
43
    trace,
45
44
    ui,
46
45
    urlutils,
47
46
    )
48
 
from bzrlib.i18n import gettext
49
 
from bzrlib.smart import client, protocol, request, signals, vfs
 
47
from bzrlib.smart import client, protocol, request, vfs
50
48
from bzrlib.transport import ssh
51
49
""")
 
50
#usually already imported, and getting IllegalScoperReplacer on it here.
52
51
from bzrlib import osutils
53
52
 
54
 
# Throughout this module buffer size parameters are either limited to be at
55
 
# most _MAX_READ_SIZE, or are ignored and _MAX_READ_SIZE is used instead.
56
 
# For this module's purposes, MAX_SOCKET_CHUNK is a reasonable size for reads
57
 
# from non-sockets as well.
58
 
_MAX_READ_SIZE = osutils.MAX_SOCKET_CHUNK
 
53
# We must not read any more than 64k at a time so we don't risk "no buffer
 
54
# space available" errors on some platforms.  Windows in particular is likely
 
55
# to give error 10053 or 10055 if we read more than 64k from a socket.
 
56
_MAX_READ_SIZE = 64 * 1024
 
57
 
59
58
 
60
59
def _get_protocol_factory_for_bytes(bytes):
61
60
    """Determine the right protocol factory for 'bytes'.
179
178
        ui.ui_factory.report_transport_activity(self, bytes, direction)
180
179
 
181
180
 
182
 
_bad_file_descriptor = (errno.EBADF,)
183
 
if sys.platform == 'win32':
184
 
    # Given on Windows if you pass a closed socket to select.select. Probably
185
 
    # also given if you pass a file handle to select.
186
 
    WSAENOTSOCK = 10038
187
 
    _bad_file_descriptor += (WSAENOTSOCK,)
188
 
 
189
 
 
190
181
class SmartServerStreamMedium(SmartMedium):
191
182
    """Handles smart commands coming over a stream.
192
183
 
205
196
        the stream.  See also the _push_back method.
206
197
    """
207
198
 
208
 
    _timer = time.time
209
 
 
210
 
    def __init__(self, backing_transport, root_client_path='/', timeout=None):
 
199
    def __init__(self, backing_transport, root_client_path='/'):
211
200
        """Construct new server.
212
201
 
213
202
        :param backing_transport: Transport for the directory served.
216
205
        self.backing_transport = backing_transport
217
206
        self.root_client_path = root_client_path
218
207
        self.finished = False
219
 
        if timeout is None:
220
 
            raise AssertionError('You must supply a timeout.')
221
 
        self._client_timeout = timeout
222
 
        self._client_poll_timeout = min(timeout / 10.0, 1.0)
223
208
        SmartMedium.__init__(self)
224
209
 
225
210
    def serve(self):
230
215
        try:
231
216
            while not self.finished:
232
217
                server_protocol = self._build_protocol()
233
 
                # TODO: This seems inelegant:
234
 
                if server_protocol is None:
235
 
                    # We could 'continue' only to notice that self.finished is
236
 
                    # True...
237
 
                    break
238
218
                self._serve_one_request(server_protocol)
239
 
        except errors.ConnectionTimeout, e:
240
 
            trace.note('%s' % (e,))
241
 
            trace.log_exception_quietly()
242
 
            self._disconnect_client()
243
 
            # We reported it, no reason to make a big fuss.
244
 
            return
245
219
        except Exception, e:
246
220
            stderr.write("%s terminating on exception %s\n" % (self, e))
247
221
            raise
248
 
        self._disconnect_client()
249
 
 
250
 
    def _stop_gracefully(self):
251
 
        """When we finish this message, stop looking for more."""
252
 
        trace.mutter('Stopping %s' % (self,))
253
 
        self.finished = True
254
 
 
255
 
    def _disconnect_client(self):
256
 
        """Close the current connection. We stopped due to a timeout/etc."""
257
 
        # The default implementation is a no-op, because that is all we used to
258
 
        # do when disconnecting from a client. I suppose we never had the
259
 
        # *server* initiate a disconnect, before
260
 
 
261
 
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
262
 
        """Wait for more bytes to be read, but timeout if none available.
263
 
 
264
 
        This allows us to detect idle connections, and stop trying to read from
265
 
        them, without setting the socket itself to non-blocking. This also
266
 
        allows us to specify when we watch for idle timeouts.
267
 
 
268
 
        :return: Did we timeout? (True if we timed out, False if there is data
269
 
            to be read)
270
 
        """
271
 
        raise NotImplementedError(self._wait_for_bytes_with_timeout)
272
222
 
273
223
    def _build_protocol(self):
274
224
        """Identifies the version of the incoming request, and returns an
279
229
 
280
230
        :returns: a SmartServerRequestProtocol.
281
231
        """
282
 
        self._wait_for_bytes_with_timeout(self._client_timeout)
283
 
        if self.finished:
284
 
            # We're stopping, so don't try to do any more work
285
 
            return None
286
232
        bytes = self._get_line()
287
233
        protocol_factory, unused_bytes = _get_protocol_factory_for_bytes(bytes)
288
234
        protocol = protocol_factory(
290
236
        protocol.accept_bytes(unused_bytes)
291
237
        return protocol
292
238
 
293
 
    def _wait_on_descriptor(self, fd, timeout_seconds):
294
 
        """select() on a file descriptor, waiting for nonblocking read()
295
 
 
296
 
        This will raise a ConnectionTimeout exception if we do not get a
297
 
        readable handle before timeout_seconds.
298
 
        :return: None
299
 
        """
300
 
        t_end = self._timer() + timeout_seconds
301
 
        poll_timeout = min(timeout_seconds, self._client_poll_timeout)
302
 
        rs = xs = None
303
 
        while not rs and not xs and self._timer() < t_end:
304
 
            if self.finished:
305
 
                return
306
 
            try:
307
 
                rs, _, xs = select.select([fd], [], [fd], poll_timeout)
308
 
            except (select.error, socket.error) as e:
309
 
                err = getattr(e, 'errno', None)
310
 
                if err is None and getattr(e, 'args', None) is not None:
311
 
                    # select.error doesn't have 'errno', it just has args[0]
312
 
                    err = e.args[0]
313
 
                if err in _bad_file_descriptor:
314
 
                    return # Not a socket indicates read() will fail
315
 
                elif err == errno.EINTR:
316
 
                    # Interrupted, keep looping.
317
 
                    continue
318
 
                raise
319
 
        if rs or xs:
320
 
            return
321
 
        raise errors.ConnectionTimeout('disconnecting client after %.1f seconds'
322
 
                                       % (timeout_seconds,))
323
 
 
324
239
    def _serve_one_request(self, protocol):
325
240
        """Read one request from input, process, send back a response.
326
241
 
347
262
 
348
263
class SmartServerSocketStreamMedium(SmartServerStreamMedium):
349
264
 
350
 
    def __init__(self, sock, backing_transport, root_client_path='/',
351
 
                 timeout=None):
 
265
    def __init__(self, sock, backing_transport, root_client_path='/'):
352
266
        """Constructor.
353
267
 
354
268
        :param sock: the socket the server will read from.  It will be put
355
269
            into blocking mode.
356
270
        """
357
271
        SmartServerStreamMedium.__init__(
358
 
            self, backing_transport, root_client_path=root_client_path,
359
 
            timeout=timeout)
 
272
            self, backing_transport, root_client_path=root_client_path)
360
273
        sock.setblocking(True)
361
274
        self.socket = sock
362
 
        # Get the getpeername now, as we might be closed later when we care.
363
 
        try:
364
 
            self._client_info = sock.getpeername()
365
 
        except socket.error:
366
 
            self._client_info = '<unknown>'
367
 
 
368
 
    def __str__(self):
369
 
        return '%s(client=%s)' % (self.__class__.__name__, self._client_info)
370
 
 
371
 
    def __repr__(self):
372
 
        return '%s.%s(client=%s)' % (self.__module__, self.__class__.__name__,
373
 
            self._client_info)
374
275
 
375
276
    def _serve_one_request_unguarded(self, protocol):
376
277
        while protocol.next_read_size():
377
278
            # We can safely try to read large chunks.  If there is less data
378
 
            # than MAX_SOCKET_CHUNK ready, the socket will just return a
379
 
            # short read immediately rather than block.
380
 
            bytes = self.read_bytes(osutils.MAX_SOCKET_CHUNK)
 
279
            # than _MAX_READ_SIZE ready, the socket wil just return a short
 
280
            # read immediately rather than block.
 
281
            bytes = self.read_bytes(_MAX_READ_SIZE)
381
282
            if bytes == '':
382
283
                self.finished = True
383
284
                return
385
286
 
386
287
        self._push_back(protocol.unused_data)
387
288
 
388
 
    def _disconnect_client(self):
389
 
        """Close the current connection. We stopped due to a timeout/etc."""
390
 
        self.socket.close()
391
 
 
392
 
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
393
 
        """Wait for more bytes to be read, but timeout if none available.
394
 
 
395
 
        This allows us to detect idle connections, and stop trying to read from
396
 
        them, without setting the socket itself to non-blocking. This also
397
 
        allows us to specify when we watch for idle timeouts.
398
 
 
399
 
        :return: None, this will raise ConnectionTimeout if we time out before
400
 
            data is available.
401
 
        """
402
 
        return self._wait_on_descriptor(self.socket, timeout_seconds)
403
 
 
404
289
    def _read_bytes(self, desired_count):
405
 
        return osutils.read_bytes_from_socket(
406
 
            self.socket, self._report_activity)
 
290
        return _read_bytes_from_socket(
 
291
            self.socket.recv, desired_count, self._report_activity)
407
292
 
408
293
    def terminate_due_to_error(self):
409
294
        # TODO: This should log to a server log file, but no such thing
410
295
        # exists yet.  Andrew Bennetts 2006-09-29.
411
 
        self.socket.close()
 
296
        osutils.until_no_eintr(self.socket.close)
412
297
        self.finished = True
413
298
 
414
299
    def _write_out(self, bytes):
423
308
 
424
309
class SmartServerPipeStreamMedium(SmartServerStreamMedium):
425
310
 
426
 
    def __init__(self, in_file, out_file, backing_transport, timeout=None):
 
311
    def __init__(self, in_file, out_file, backing_transport):
427
312
        """Construct new server.
428
313
 
429
314
        :param in_file: Python file from which requests can be read.
430
315
        :param out_file: Python file to write responses.
431
316
        :param backing_transport: Transport for the directory served.
432
317
        """
433
 
        SmartServerStreamMedium.__init__(self, backing_transport,
434
 
            timeout=timeout)
 
318
        SmartServerStreamMedium.__init__(self, backing_transport)
435
319
        if sys.platform == 'win32':
436
320
            # force binary mode for files
437
321
            import msvcrt
442
326
        self._in = in_file
443
327
        self._out = out_file
444
328
 
445
 
    def serve(self):
446
 
        """See SmartServerStreamMedium.serve"""
447
 
        # This is the regular serve, except it adds signal trapping for soft
448
 
        # shutdown.
449
 
        stop_gracefully = self._stop_gracefully
450
 
        signals.register_on_hangup(id(self), stop_gracefully)
451
 
        try:
452
 
            return super(SmartServerPipeStreamMedium, self).serve()
453
 
        finally:
454
 
            signals.unregister_on_hangup(id(self))
455
 
 
456
329
    def _serve_one_request_unguarded(self, protocol):
457
330
        while True:
458
331
            # We need to be careful not to read past the end of the current
461
334
            bytes_to_read = protocol.next_read_size()
462
335
            if bytes_to_read == 0:
463
336
                # Finished serving this request.
464
 
                self._out.flush()
 
337
                osutils.until_no_eintr(self._out.flush)
465
338
                return
466
339
            bytes = self.read_bytes(bytes_to_read)
467
340
            if bytes == '':
468
341
                # Connection has been closed.
469
342
                self.finished = True
470
 
                self._out.flush()
 
343
                osutils.until_no_eintr(self._out.flush)
471
344
                return
472
345
            protocol.accept_bytes(bytes)
473
346
 
474
 
    def _disconnect_client(self):
475
 
        self._in.close()
476
 
        self._out.flush()
477
 
        self._out.close()
478
 
 
479
 
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
480
 
        """Wait for more bytes to be read, but timeout if none available.
481
 
 
482
 
        This allows us to detect idle connections, and stop trying to read from
483
 
        them, without setting the socket itself to non-blocking. This also
484
 
        allows us to specify when we watch for idle timeouts.
485
 
 
486
 
        :return: None, this will raise ConnectionTimeout if we time out before
487
 
            data is available.
488
 
        """
489
 
        if (getattr(self._in, 'fileno', None) is None
490
 
            or sys.platform == 'win32'):
491
 
            # You can't select() file descriptors on Windows.
492
 
            return
493
 
        return self._wait_on_descriptor(self._in, timeout_seconds)
494
 
 
495
347
    def _read_bytes(self, desired_count):
496
 
        return self._in.read(desired_count)
 
348
        return osutils.until_no_eintr(self._in.read, desired_count)
497
349
 
498
350
    def terminate_due_to_error(self):
499
351
        # TODO: This should log to a server log file, but no such thing
500
352
        # exists yet.  Andrew Bennetts 2006-09-29.
501
 
        self._out.close()
 
353
        osutils.until_no_eintr(self._out.close)
502
354
        self.finished = True
503
355
 
504
356
    def _write_out(self, bytes):
505
 
        self._out.write(bytes)
 
357
        osutils.until_no_eintr(self._out.write, bytes)
506
358
 
507
359
 
508
360
class SmartClientMediumRequest(object):
641
493
        return self._medium._get_line()
642
494
 
643
495
 
644
 
class _VfsRefuser(object):
645
 
    """An object that refuses all VFS requests.
646
 
 
647
 
    """
648
 
 
649
 
    def __init__(self):
650
 
        client._SmartClient.hooks.install_named_hook(
651
 
            'call', self.check_vfs, 'vfs refuser')
652
 
 
653
 
    def check_vfs(self, params):
654
 
        try:
655
 
            request_method = request.request_handlers.get(params.method)
656
 
        except KeyError:
657
 
            # A method we don't know about doesn't count as a VFS method.
658
 
            return
659
 
        if issubclass(request_method, vfs.VfsRequest):
660
 
            raise errors.HpssVfsRequestNotAllowed(params.method, params.args)
661
 
 
662
 
 
663
496
class _DebugCounter(object):
664
497
    """An object that counts the HPSS calls made to each client medium.
665
498
 
666
 
    When a medium is garbage-collected, or failing that when
667
 
    bzrlib.global_state exits, the total number of calls made on that medium
668
 
    are reported via trace.note.
 
499
    When a medium is garbage-collected, or failing that when atexit functions
 
500
    are run, the total number of calls made on that medium are reported via
 
501
    trace.note.
669
502
    """
670
503
 
671
504
    def __init__(self):
672
505
        self.counts = weakref.WeakKeyDictionary()
673
506
        client._SmartClient.hooks.install_named_hook(
674
507
            'call', self.increment_call_count, 'hpss call counter')
675
 
        bzrlib.global_state.cleanups.add_cleanup(self.flush_all)
 
508
        atexit.register(self.flush_all)
676
509
 
677
510
    def track(self, medium):
678
511
        """Start tracking calls made to a medium.
712
545
        value['count'] = 0
713
546
        value['vfs_count'] = 0
714
547
        if count != 0:
715
 
            trace.note(gettext('HPSS calls: {0} ({1} vfs) {2}').format(
716
 
                       count, vfs_count, medium_repr))
 
548
            trace.note('HPSS calls: %d (%d vfs) %s',
 
549
                       count, vfs_count, medium_repr)
717
550
 
718
551
    def flush_all(self):
719
552
        for ref in list(self.counts.keys()):
720
553
            self.done(ref)
721
554
 
722
555
_debug_counter = None
723
 
_vfs_refuser = None
724
556
 
725
557
 
726
558
class SmartClientMedium(SmartMedium):
743
575
            if _debug_counter is None:
744
576
                _debug_counter = _DebugCounter()
745
577
            _debug_counter.track(self)
746
 
        if 'hpss_client_no_vfs' in debug.debug_flags:
747
 
            global _vfs_refuser
748
 
            if _vfs_refuser is None:
749
 
                _vfs_refuser = _VfsRefuser()
750
578
 
751
579
    def _is_remote_before(self, version_tuple):
752
580
        """Is it possible the remote side supports RPCs for a given version?
781
609
            # which is newer than a previously supplied older-than version.
782
610
            # This indicates that some smart verb call is not guarded
783
611
            # appropriately (it should simply not have been tried).
784
 
            trace.mutter(
 
612
            raise AssertionError(
785
613
                "_remember_remote_is_before(%r) called, but "
786
614
                "_remember_remote_is_before(%r) was called previously."
787
 
                , version_tuple, self._remote_version_is_before)
788
 
            if 'hpss' in debug.debug_flags:
789
 
                ui.ui_factory.show_warning(
790
 
                    "_remember_remote_is_before(%r) called, but "
791
 
                    "_remember_remote_is_before(%r) was called previously."
792
 
                    % (version_tuple, self._remote_version_is_before))
793
 
            return
 
615
                % (version_tuple, self._remote_version_is_before))
794
616
        self._remote_version_is_before = version_tuple
795
617
 
796
618
    def protocol_version(self):
898
720
 
899
721
    def _accept_bytes(self, bytes):
900
722
        """See SmartClientStreamMedium.accept_bytes."""
901
 
        self._writeable_pipe.write(bytes)
 
723
        osutils.until_no_eintr(self._writeable_pipe.write, bytes)
902
724
        self._report_activity(len(bytes), 'write')
903
725
 
904
726
    def _flush(self):
905
727
        """See SmartClientStreamMedium._flush()."""
906
 
        self._writeable_pipe.flush()
 
728
        osutils.until_no_eintr(self._writeable_pipe.flush)
907
729
 
908
730
    def _read_bytes(self, count):
909
731
        """See SmartClientStreamMedium._read_bytes."""
910
 
        bytes_to_read = min(count, _MAX_READ_SIZE)
911
 
        bytes = self._readable_pipe.read(bytes_to_read)
 
732
        bytes = osutils.until_no_eintr(self._readable_pipe.read, count)
912
733
        self._report_activity(len(bytes), 'read')
913
734
        return bytes
914
735
 
915
736
 
916
 
class SSHParams(object):
917
 
    """A set of parameters for starting a remote bzr via SSH."""
 
737
class SmartSSHClientMedium(SmartClientStreamMedium):
 
738
    """A client medium using SSH."""
918
739
 
919
740
    def __init__(self, host, port=None, username=None, password=None,
920
 
            bzr_remote_path='bzr'):
921
 
        self.host = host
922
 
        self.port = port
923
 
        self.username = username
924
 
        self.password = password
925
 
        self.bzr_remote_path = bzr_remote_path
926
 
 
927
 
 
928
 
class SmartSSHClientMedium(SmartClientStreamMedium):
929
 
    """A client medium using SSH.
930
 
    
931
 
    It delegates IO to a SmartClientSocketMedium or
932
 
    SmartClientAlreadyConnectedSocketMedium (depending on platform).
933
 
    """
934
 
 
935
 
    def __init__(self, base, ssh_params, vendor=None):
 
741
            base=None, vendor=None, bzr_remote_path=None):
936
742
        """Creates a client that will connect on the first use.
937
743
 
938
 
        :param ssh_params: A SSHParams instance.
939
744
        :param vendor: An optional override for the ssh vendor to use. See
940
745
            bzrlib.transport.ssh for details on ssh vendors.
941
746
        """
942
 
        self._real_medium = None
943
 
        self._ssh_params = ssh_params
 
747
        self._connected = False
 
748
        self._host = host
 
749
        self._password = password
 
750
        self._port = port
 
751
        self._username = username
944
752
        # for the benefit of progress making a short description of this
945
753
        # transport
946
754
        self._scheme = 'bzr+ssh'
948
756
        # _DebugCounter so we have to store all the values used in our repr
949
757
        # method before calling the super init.
950
758
        SmartClientStreamMedium.__init__(self, base)
 
759
        self._read_from = None
 
760
        self._ssh_connection = None
951
761
        self._vendor = vendor
952
 
        self._ssh_connection = None
 
762
        self._write_to = None
 
763
        self._bzr_remote_path = bzr_remote_path
953
764
 
954
765
    def __repr__(self):
955
 
        if self._ssh_params.port is None:
 
766
        if self._port is None:
956
767
            maybe_port = ''
957
768
        else:
958
 
            maybe_port = ':%s' % self._ssh_params.port
 
769
            maybe_port = ':%s' % self._port
959
770
        return "%s(%s://%s@%s%s/)" % (
960
771
            self.__class__.__name__,
961
772
            self._scheme,
962
 
            self._ssh_params.username,
963
 
            self._ssh_params.host,
 
773
            self._username,
 
774
            self._host,
964
775
            maybe_port)
965
776
 
966
777
    def _accept_bytes(self, bytes):
967
778
        """See SmartClientStreamMedium.accept_bytes."""
968
779
        self._ensure_connection()
969
 
        self._real_medium.accept_bytes(bytes)
 
780
        osutils.until_no_eintr(self._write_to.write, bytes)
 
781
        self._report_activity(len(bytes), 'write')
970
782
 
971
783
    def disconnect(self):
972
784
        """See SmartClientMedium.disconnect()."""
973
 
        if self._real_medium is not None:
974
 
            self._real_medium.disconnect()
975
 
            self._real_medium = None
976
 
        if self._ssh_connection is not None:
977
 
            self._ssh_connection.close()
978
 
            self._ssh_connection = None
 
785
        if not self._connected:
 
786
            return
 
787
        osutils.until_no_eintr(self._read_from.close)
 
788
        osutils.until_no_eintr(self._write_to.close)
 
789
        self._ssh_connection.close()
 
790
        self._connected = False
979
791
 
980
792
    def _ensure_connection(self):
981
793
        """Connect this medium if not already connected."""
982
 
        if self._real_medium is not None:
 
794
        if self._connected:
983
795
            return
984
796
        if self._vendor is None:
985
797
            vendor = ssh._get_ssh_vendor()
986
798
        else:
987
799
            vendor = self._vendor
988
 
        self._ssh_connection = vendor.connect_ssh(self._ssh_params.username,
989
 
                self._ssh_params.password, self._ssh_params.host,
990
 
                self._ssh_params.port,
991
 
                command=[self._ssh_params.bzr_remote_path, 'serve', '--inet',
 
800
        self._ssh_connection = vendor.connect_ssh(self._username,
 
801
                self._password, self._host, self._port,
 
802
                command=[self._bzr_remote_path, 'serve', '--inet',
992
803
                         '--directory=/', '--allow-writes'])
993
 
        io_kind, io_object = self._ssh_connection.get_sock_or_pipes()
994
 
        if io_kind == 'socket':
995
 
            self._real_medium = SmartClientAlreadyConnectedSocketMedium(
996
 
                self.base, io_object)
997
 
        elif io_kind == 'pipes':
998
 
            read_from, write_to = io_object
999
 
            self._real_medium = SmartSimplePipesClientMedium(
1000
 
                read_from, write_to, self.base)
1001
 
        else:
1002
 
            raise AssertionError(
1003
 
                "Unexpected io_kind %r from %r"
1004
 
                % (io_kind, self._ssh_connection))
 
804
        self._read_from, self._write_to = \
 
805
            self._ssh_connection.get_filelike_channels()
 
806
        self._connected = True
1005
807
 
1006
808
    def _flush(self):
1007
809
        """See SmartClientStreamMedium._flush()."""
1008
 
        self._real_medium._flush()
 
810
        self._write_to.flush()
1009
811
 
1010
812
    def _read_bytes(self, count):
1011
813
        """See SmartClientStreamMedium.read_bytes."""
1012
 
        if self._real_medium is None:
 
814
        if not self._connected:
1013
815
            raise errors.MediumNotConnected(self)
1014
 
        return self._real_medium.read_bytes(count)
 
816
        bytes_to_read = min(count, _MAX_READ_SIZE)
 
817
        bytes = osutils.until_no_eintr(self._read_from.read, bytes_to_read)
 
818
        self._report_activity(len(bytes), 'read')
 
819
        return bytes
1015
820
 
1016
821
 
1017
822
# Port 4155 is the default port for bzr://, registered with IANA.
1019
824
BZR_DEFAULT_PORT = 4155
1020
825
 
1021
826
 
1022
 
class SmartClientSocketMedium(SmartClientStreamMedium):
1023
 
    """A client medium using a socket.
1024
 
    
1025
 
    This class isn't usable directly.  Use one of its subclasses instead.
1026
 
    """
 
827
class SmartTCPClientMedium(SmartClientStreamMedium):
 
828
    """A client medium using TCP."""
1027
829
 
1028
 
    def __init__(self, base):
 
830
    def __init__(self, host, port, base):
 
831
        """Creates a client that will connect on the first use."""
1029
832
        SmartClientStreamMedium.__init__(self, base)
 
833
        self._connected = False
 
834
        self._host = host
 
835
        self._port = port
1030
836
        self._socket = None
1031
 
        self._connected = False
1032
837
 
1033
838
    def _accept_bytes(self, bytes):
1034
839
        """See SmartClientMedium.accept_bytes."""
1035
840
        self._ensure_connection()
1036
841
        osutils.send_all(self._socket, bytes, self._report_activity)
1037
842
 
1038
 
    def _ensure_connection(self):
1039
 
        """Connect this medium if not already connected."""
1040
 
        raise NotImplementedError(self._ensure_connection)
1041
 
 
1042
 
    def _flush(self):
1043
 
        """See SmartClientStreamMedium._flush().
1044
 
 
1045
 
        For sockets we do no flushing. For TCP sockets we may want to turn off
1046
 
        TCP_NODELAY and add a means to do a flush, but that can be done in the
1047
 
        future.
1048
 
        """
1049
 
 
1050
 
    def _read_bytes(self, count):
1051
 
        """See SmartClientMedium.read_bytes."""
1052
 
        if not self._connected:
1053
 
            raise errors.MediumNotConnected(self)
1054
 
        return osutils.read_bytes_from_socket(
1055
 
            self._socket, self._report_activity)
1056
 
 
1057
843
    def disconnect(self):
1058
844
        """See SmartClientMedium.disconnect()."""
1059
845
        if not self._connected:
1060
846
            return
1061
 
        self._socket.close()
 
847
        osutils.until_no_eintr(self._socket.close)
1062
848
        self._socket = None
1063
849
        self._connected = False
1064
850
 
1065
 
 
1066
 
class SmartTCPClientMedium(SmartClientSocketMedium):
1067
 
    """A client medium that creates a TCP connection."""
1068
 
 
1069
 
    def __init__(self, host, port, base):
1070
 
        """Creates a client that will connect on the first use."""
1071
 
        SmartClientSocketMedium.__init__(self, base)
1072
 
        self._host = host
1073
 
        self._port = port
1074
 
 
1075
851
    def _ensure_connection(self):
1076
852
        """Connect this medium if not already connected."""
1077
853
        if self._connected:
1111
887
                    (self._host, port, err_msg))
1112
888
        self._connected = True
1113
889
 
1114
 
 
1115
 
class SmartClientAlreadyConnectedSocketMedium(SmartClientSocketMedium):
1116
 
    """A client medium for an already connected socket.
1117
 
    
1118
 
    Note that this class will assume it "owns" the socket, so it will close it
1119
 
    when its disconnect method is called.
1120
 
    """
1121
 
 
1122
 
    def __init__(self, base, sock):
1123
 
        SmartClientSocketMedium.__init__(self, base)
1124
 
        self._socket = sock
1125
 
        self._connected = True
1126
 
 
1127
 
    def _ensure_connection(self):
1128
 
        # Already connected, by definition!  So nothing to do.
1129
 
        pass
 
890
    def _flush(self):
 
891
        """See SmartClientStreamMedium._flush().
 
892
 
 
893
        For TCP we do no flushing. We may want to turn off TCP_NODELAY and
 
894
        add a means to do a flush, but that can be done in the future.
 
895
        """
 
896
 
 
897
    def _read_bytes(self, count):
 
898
        """See SmartClientMedium.read_bytes."""
 
899
        if not self._connected:
 
900
            raise errors.MediumNotConnected(self)
 
901
        return _read_bytes_from_socket(
 
902
            self._socket.recv, count, self._report_activity)
1130
903
 
1131
904
 
1132
905
class SmartClientStreamMediumRequest(SmartClientMediumRequest):
1169
942
        self._medium._flush()
1170
943
 
1171
944
 
 
945
def _read_bytes_from_socket(sock, desired_count, report_activity):
 
946
    # We ignore the desired_count because on sockets it's more efficient to
 
947
    # read large chunks (of _MAX_READ_SIZE bytes) at a time.
 
948
    try:
 
949
        bytes = osutils.until_no_eintr(sock, _MAX_READ_SIZE)
 
950
    except socket.error, e:
 
951
        if len(e.args) and e.args[0] in (errno.ECONNRESET, 10054):
 
952
            # The connection was closed by the other side.  Callers expect an
 
953
            # empty string to signal end-of-stream.
 
954
            bytes = ''
 
955
        else:
 
956
            raise
 
957
    else:
 
958
        report_activity(len(bytes), 'read')
 
959
    return bytes
 
960