~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/smart/medium.py

  • Committer: Jelmer Vernooij
  • Date: 2011-12-05 14:12:23 UTC
  • mto: This revision was merged to the branch mainline in revision 6348.
  • Revision ID: jelmer@samba.org-20111205141223-8qxae4h37satlzgq
Move more functionality to vf_search.

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2006 Canonical Ltd
 
1
# Copyright (C) 2006-2011 Canonical Ltd
2
2
#
3
3
# This program is free software; you can redistribute it and/or modify
4
4
# it under the terms of the GNU General Public License as published by
26
26
 
27
27
import errno
28
28
import os
29
 
import socket
30
29
import sys
 
30
import time
31
31
import urllib
32
32
 
 
33
import bzrlib
33
34
from bzrlib.lazy_import import lazy_import
34
35
lazy_import(globals(), """
35
 
import atexit
 
36
import select
 
37
import socket
36
38
import thread
37
39
import weakref
38
40
 
39
41
from bzrlib import (
40
42
    debug,
41
43
    errors,
42
 
    symbol_versioning,
43
44
    trace,
44
45
    ui,
45
46
    urlutils,
46
47
    )
47
 
from bzrlib.smart import client, protocol, request, vfs
 
48
from bzrlib.i18n import gettext
 
49
from bzrlib.smart import client, protocol, request, signals, vfs
48
50
from bzrlib.transport import ssh
49
51
""")
50
 
#usually already imported, and getting IllegalScoperReplacer on it here.
51
52
from bzrlib import osutils
52
53
 
53
 
# We must not read any more than 64k at a time so we don't risk "no buffer
54
 
# space available" errors on some platforms.  Windows in particular is likely
55
 
# to give error 10053 or 10055 if we read more than 64k from a socket.
56
 
_MAX_READ_SIZE = 64 * 1024
57
 
 
 
54
# Throughout this module buffer size parameters are either limited to be at
 
55
# most _MAX_READ_SIZE, or are ignored and _MAX_READ_SIZE is used instead.
 
56
# For this module's purposes, MAX_SOCKET_CHUNK is a reasonable size for reads
 
57
# from non-sockets as well.
 
58
_MAX_READ_SIZE = osutils.MAX_SOCKET_CHUNK
58
59
 
59
60
def _get_protocol_factory_for_bytes(bytes):
60
61
    """Determine the right protocol factory for 'bytes'.
178
179
        ui.ui_factory.report_transport_activity(self, bytes, direction)
179
180
 
180
181
 
 
182
_bad_file_descriptor = (errno.EBADF,)
 
183
if sys.platform == 'win32':
 
184
    # Given on Windows if you pass a closed socket to select.select. Probably
 
185
    # also given if you pass a file handle to select.
 
186
    WSAENOTSOCK = 10038
 
187
    _bad_file_descriptor += (WSAENOTSOCK,)
 
188
 
 
189
 
181
190
class SmartServerStreamMedium(SmartMedium):
182
191
    """Handles smart commands coming over a stream.
183
192
 
196
205
        the stream.  See also the _push_back method.
197
206
    """
198
207
 
199
 
    def __init__(self, backing_transport, root_client_path='/'):
 
208
    _timer = time.time
 
209
 
 
210
    def __init__(self, backing_transport, root_client_path='/', timeout=None):
200
211
        """Construct new server.
201
212
 
202
213
        :param backing_transport: Transport for the directory served.
205
216
        self.backing_transport = backing_transport
206
217
        self.root_client_path = root_client_path
207
218
        self.finished = False
 
219
        if timeout is None:
 
220
            raise AssertionError('You must supply a timeout.')
 
221
        self._client_timeout = timeout
 
222
        self._client_poll_timeout = min(timeout / 10.0, 1.0)
208
223
        SmartMedium.__init__(self)
209
224
 
210
225
    def serve(self):
216
231
            while not self.finished:
217
232
                server_protocol = self._build_protocol()
218
233
                self._serve_one_request(server_protocol)
 
234
        except errors.ConnectionTimeout, e:
 
235
            trace.note('%s' % (e,))
 
236
            trace.log_exception_quietly()
 
237
            self._disconnect_client()
 
238
            # We reported it, no reason to make a big fuss.
 
239
            return
219
240
        except Exception, e:
220
241
            stderr.write("%s terminating on exception %s\n" % (self, e))
221
242
            raise
 
243
        self._disconnect_client()
 
244
 
 
245
    def _stop_gracefully(self):
 
246
        """When we finish this message, stop looking for more."""
 
247
        trace.mutter('Stopping %s' % (self,))
 
248
        self.finished = True
 
249
 
 
250
    def _disconnect_client(self):
 
251
        """Close the current connection. We stopped due to a timeout/etc."""
 
252
        # The default implementation is a no-op, because that is all we used to
 
253
        # do when disconnecting from a client. I suppose we never had the
 
254
        # *server* initiate a disconnect, before
 
255
 
 
256
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
 
257
        """Wait for more bytes to be read, but timeout if none available.
 
258
 
 
259
        This allows us to detect idle connections, and stop trying to read from
 
260
        them, without setting the socket itself to non-blocking. This also
 
261
        allows us to specify when we watch for idle timeouts.
 
262
 
 
263
        :return: Did we timeout? (True if we timed out, False if there is data
 
264
            to be read)
 
265
        """
 
266
        raise NotImplementedError(self._wait_for_bytes_with_timeout)
222
267
 
223
268
    def _build_protocol(self):
224
269
        """Identifies the version of the incoming request, and returns an
229
274
 
230
275
        :returns: a SmartServerRequestProtocol.
231
276
        """
 
277
        self._wait_for_bytes_with_timeout(self._client_timeout)
 
278
        if self.finished:
 
279
            # We're stopping, so don't try to do any more work
 
280
            return None
232
281
        bytes = self._get_line()
233
282
        protocol_factory, unused_bytes = _get_protocol_factory_for_bytes(bytes)
234
283
        protocol = protocol_factory(
236
285
        protocol.accept_bytes(unused_bytes)
237
286
        return protocol
238
287
 
 
288
    def _wait_on_descriptor(self, fd, timeout_seconds):
 
289
        """select() on a file descriptor, waiting for nonblocking read()
 
290
 
 
291
        This will raise a ConnectionTimeout exception if we do not get a
 
292
        readable handle before timeout_seconds.
 
293
        :return: None
 
294
        """
 
295
        t_end = self._timer() + timeout_seconds
 
296
        poll_timeout = min(timeout_seconds, self._client_poll_timeout)
 
297
        rs = xs = None
 
298
        while not rs and not xs and self._timer() < t_end:
 
299
            if self.finished:
 
300
                return
 
301
            try:
 
302
                rs, _, xs = select.select([fd], [], [fd], poll_timeout)
 
303
            except (select.error, socket.error) as e:
 
304
                err = getattr(e, 'errno', None)
 
305
                if err is None and getattr(e, 'args', None) is not None:
 
306
                    # select.error doesn't have 'errno', it just has args[0]
 
307
                    err = e.args[0]
 
308
                if err in _bad_file_descriptor:
 
309
                    return # Not a socket indicates read() will fail
 
310
                elif err == errno.EINTR:
 
311
                    # Interrupted, keep looping.
 
312
                    continue
 
313
                raise
 
314
        if rs or xs:
 
315
            return
 
316
        raise errors.ConnectionTimeout('disconnecting client after %.1f seconds'
 
317
                                       % (timeout_seconds,))
 
318
 
239
319
    def _serve_one_request(self, protocol):
240
320
        """Read one request from input, process, send back a response.
241
321
 
242
322
        :param protocol: a SmartServerRequestProtocol.
243
323
        """
 
324
        if protocol is None:
 
325
            return
244
326
        try:
245
327
            self._serve_one_request_unguarded(protocol)
246
328
        except KeyboardInterrupt:
262
344
 
263
345
class SmartServerSocketStreamMedium(SmartServerStreamMedium):
264
346
 
265
 
    def __init__(self, sock, backing_transport, root_client_path='/'):
 
347
    def __init__(self, sock, backing_transport, root_client_path='/',
 
348
                 timeout=None):
266
349
        """Constructor.
267
350
 
268
351
        :param sock: the socket the server will read from.  It will be put
269
352
            into blocking mode.
270
353
        """
271
354
        SmartServerStreamMedium.__init__(
272
 
            self, backing_transport, root_client_path=root_client_path)
 
355
            self, backing_transport, root_client_path=root_client_path,
 
356
            timeout=timeout)
273
357
        sock.setblocking(True)
274
358
        self.socket = sock
 
359
        # Get the getpeername now, as we might be closed later when we care.
 
360
        try:
 
361
            self._client_info = sock.getpeername()
 
362
        except socket.error:
 
363
            self._client_info = '<unknown>'
 
364
 
 
365
    def __str__(self):
 
366
        return '%s(client=%s)' % (self.__class__.__name__, self._client_info)
 
367
 
 
368
    def __repr__(self):
 
369
        return '%s.%s(client=%s)' % (self.__module__, self.__class__.__name__,
 
370
            self._client_info)
275
371
 
276
372
    def _serve_one_request_unguarded(self, protocol):
277
373
        while protocol.next_read_size():
278
374
            # We can safely try to read large chunks.  If there is less data
279
 
            # than _MAX_READ_SIZE ready, the socket wil just return a short
280
 
            # read immediately rather than block.
281
 
            bytes = self.read_bytes(_MAX_READ_SIZE)
 
375
            # than MAX_SOCKET_CHUNK ready, the socket will just return a
 
376
            # short read immediately rather than block.
 
377
            bytes = self.read_bytes(osutils.MAX_SOCKET_CHUNK)
282
378
            if bytes == '':
283
379
                self.finished = True
284
380
                return
286
382
 
287
383
        self._push_back(protocol.unused_data)
288
384
 
 
385
    def _disconnect_client(self):
 
386
        """Close the current connection. We stopped due to a timeout/etc."""
 
387
        self.socket.close()
 
388
 
 
389
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
 
390
        """Wait for more bytes to be read, but timeout if none available.
 
391
 
 
392
        This allows us to detect idle connections, and stop trying to read from
 
393
        them, without setting the socket itself to non-blocking. This also
 
394
        allows us to specify when we watch for idle timeouts.
 
395
 
 
396
        :return: None, this will raise ConnectionTimeout if we time out before
 
397
            data is available.
 
398
        """
 
399
        return self._wait_on_descriptor(self.socket, timeout_seconds)
 
400
 
289
401
    def _read_bytes(self, desired_count):
290
 
        return _read_bytes_from_socket(
291
 
            self.socket.recv, desired_count, self._report_activity)
 
402
        return osutils.read_bytes_from_socket(
 
403
            self.socket, self._report_activity)
292
404
 
293
405
    def terminate_due_to_error(self):
294
406
        # TODO: This should log to a server log file, but no such thing
295
407
        # exists yet.  Andrew Bennetts 2006-09-29.
296
 
        osutils.until_no_eintr(self.socket.close)
 
408
        self.socket.close()
297
409
        self.finished = True
298
410
 
299
411
    def _write_out(self, bytes):
308
420
 
309
421
class SmartServerPipeStreamMedium(SmartServerStreamMedium):
310
422
 
311
 
    def __init__(self, in_file, out_file, backing_transport):
 
423
    def __init__(self, in_file, out_file, backing_transport, timeout=None):
312
424
        """Construct new server.
313
425
 
314
426
        :param in_file: Python file from which requests can be read.
315
427
        :param out_file: Python file to write responses.
316
428
        :param backing_transport: Transport for the directory served.
317
429
        """
318
 
        SmartServerStreamMedium.__init__(self, backing_transport)
 
430
        SmartServerStreamMedium.__init__(self, backing_transport,
 
431
            timeout=timeout)
319
432
        if sys.platform == 'win32':
320
433
            # force binary mode for files
321
434
            import msvcrt
326
439
        self._in = in_file
327
440
        self._out = out_file
328
441
 
 
442
    def serve(self):
 
443
        """See SmartServerStreamMedium.serve"""
 
444
        # This is the regular serve, except it adds signal trapping for soft
 
445
        # shutdown.
 
446
        stop_gracefully = self._stop_gracefully
 
447
        signals.register_on_hangup(id(self), stop_gracefully)
 
448
        try:
 
449
            return super(SmartServerPipeStreamMedium, self).serve()
 
450
        finally:
 
451
            signals.unregister_on_hangup(id(self))
 
452
 
329
453
    def _serve_one_request_unguarded(self, protocol):
330
454
        while True:
331
455
            # We need to be careful not to read past the end of the current
334
458
            bytes_to_read = protocol.next_read_size()
335
459
            if bytes_to_read == 0:
336
460
                # Finished serving this request.
337
 
                osutils.until_no_eintr(self._out.flush)
 
461
                self._out.flush()
338
462
                return
339
463
            bytes = self.read_bytes(bytes_to_read)
340
464
            if bytes == '':
341
465
                # Connection has been closed.
342
466
                self.finished = True
343
 
                osutils.until_no_eintr(self._out.flush)
 
467
                self._out.flush()
344
468
                return
345
469
            protocol.accept_bytes(bytes)
346
470
 
 
471
    def _disconnect_client(self):
 
472
        self._in.close()
 
473
        self._out.flush()
 
474
        self._out.close()
 
475
 
 
476
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
 
477
        """Wait for more bytes to be read, but timeout if none available.
 
478
 
 
479
        This allows us to detect idle connections, and stop trying to read from
 
480
        them, without setting the socket itself to non-blocking. This also
 
481
        allows us to specify when we watch for idle timeouts.
 
482
 
 
483
        :return: None, this will raise ConnectionTimeout if we time out before
 
484
            data is available.
 
485
        """
 
486
        if (getattr(self._in, 'fileno', None) is None
 
487
            or sys.platform == 'win32'):
 
488
            # You can't select() file descriptors on Windows.
 
489
            return
 
490
        return self._wait_on_descriptor(self._in, timeout_seconds)
 
491
 
347
492
    def _read_bytes(self, desired_count):
348
 
        return osutils.until_no_eintr(self._in.read, desired_count)
 
493
        return self._in.read(desired_count)
349
494
 
350
495
    def terminate_due_to_error(self):
351
496
        # TODO: This should log to a server log file, but no such thing
352
497
        # exists yet.  Andrew Bennetts 2006-09-29.
353
 
        osutils.until_no_eintr(self._out.close)
 
498
        self._out.close()
354
499
        self.finished = True
355
500
 
356
501
    def _write_out(self, bytes):
357
 
        osutils.until_no_eintr(self._out.write, bytes)
 
502
        self._out.write(bytes)
358
503
 
359
504
 
360
505
class SmartClientMediumRequest(object):
493
638
        return self._medium._get_line()
494
639
 
495
640
 
 
641
class _VfsRefuser(object):
 
642
    """An object that refuses all VFS requests.
 
643
 
 
644
    """
 
645
 
 
646
    def __init__(self):
 
647
        client._SmartClient.hooks.install_named_hook(
 
648
            'call', self.check_vfs, 'vfs refuser')
 
649
 
 
650
    def check_vfs(self, params):
 
651
        try:
 
652
            request_method = request.request_handlers.get(params.method)
 
653
        except KeyError:
 
654
            # A method we don't know about doesn't count as a VFS method.
 
655
            return
 
656
        if issubclass(request_method, vfs.VfsRequest):
 
657
            raise errors.HpssVfsRequestNotAllowed(params.method, params.args)
 
658
 
 
659
 
496
660
class _DebugCounter(object):
497
661
    """An object that counts the HPSS calls made to each client medium.
498
662
 
499
 
    When a medium is garbage-collected, or failing that when atexit functions
500
 
    are run, the total number of calls made on that medium are reported via
501
 
    trace.note.
 
663
    When a medium is garbage-collected, or failing that when
 
664
    bzrlib.global_state exits, the total number of calls made on that medium
 
665
    are reported via trace.note.
502
666
    """
503
667
 
504
668
    def __init__(self):
505
669
        self.counts = weakref.WeakKeyDictionary()
506
670
        client._SmartClient.hooks.install_named_hook(
507
671
            'call', self.increment_call_count, 'hpss call counter')
508
 
        atexit.register(self.flush_all)
 
672
        bzrlib.global_state.cleanups.add_cleanup(self.flush_all)
509
673
 
510
674
    def track(self, medium):
511
675
        """Start tracking calls made to a medium.
545
709
        value['count'] = 0
546
710
        value['vfs_count'] = 0
547
711
        if count != 0:
548
 
            trace.note('HPSS calls: %d (%d vfs) %s',
549
 
                       count, vfs_count, medium_repr)
 
712
            trace.note(gettext('HPSS calls: {0} ({1} vfs) {2}').format(
 
713
                       count, vfs_count, medium_repr))
550
714
 
551
715
    def flush_all(self):
552
716
        for ref in list(self.counts.keys()):
553
717
            self.done(ref)
554
718
 
555
719
_debug_counter = None
 
720
_vfs_refuser = None
556
721
 
557
722
 
558
723
class SmartClientMedium(SmartMedium):
575
740
            if _debug_counter is None:
576
741
                _debug_counter = _DebugCounter()
577
742
            _debug_counter.track(self)
 
743
        if 'hpss_client_no_vfs' in debug.debug_flags:
 
744
            global _vfs_refuser
 
745
            if _vfs_refuser is None:
 
746
                _vfs_refuser = _VfsRefuser()
578
747
 
579
748
    def _is_remote_before(self, version_tuple):
580
749
        """Is it possible the remote side supports RPCs for a given version?
609
778
            # which is newer than a previously supplied older-than version.
610
779
            # This indicates that some smart verb call is not guarded
611
780
            # appropriately (it should simply not have been tried).
612
 
            raise AssertionError(
 
781
            trace.mutter(
613
782
                "_remember_remote_is_before(%r) called, but "
614
783
                "_remember_remote_is_before(%r) was called previously."
615
 
                % (version_tuple, self._remote_version_is_before))
 
784
                , version_tuple, self._remote_version_is_before)
 
785
            if 'hpss' in debug.debug_flags:
 
786
                ui.ui_factory.show_warning(
 
787
                    "_remember_remote_is_before(%r) called, but "
 
788
                    "_remember_remote_is_before(%r) was called previously."
 
789
                    % (version_tuple, self._remote_version_is_before))
 
790
            return
616
791
        self._remote_version_is_before = version_tuple
617
792
 
618
793
    def protocol_version(self):
706
881
        """
707
882
        return SmartClientStreamMediumRequest(self)
708
883
 
 
884
    def reset(self):
 
885
        """We have been disconnected, reset current state.
 
886
 
 
887
        This resets things like _current_request and connected state.
 
888
        """
 
889
        self.disconnect()
 
890
        self._current_request = None
 
891
 
709
892
 
710
893
class SmartSimplePipesClientMedium(SmartClientStreamMedium):
711
894
    """A client medium using simple pipes.
720
903
 
721
904
    def _accept_bytes(self, bytes):
722
905
        """See SmartClientStreamMedium.accept_bytes."""
723
 
        osutils.until_no_eintr(self._writeable_pipe.write, bytes)
 
906
        try:
 
907
            self._writeable_pipe.write(bytes)
 
908
        except IOError, e:
 
909
            if e.errno in (errno.EINVAL, errno.EPIPE):
 
910
                raise errors.ConnectionReset(
 
911
                    "Error trying to write to subprocess:\n%s" % (e,))
 
912
            raise
724
913
        self._report_activity(len(bytes), 'write')
725
914
 
726
915
    def _flush(self):
727
916
        """See SmartClientStreamMedium._flush()."""
728
 
        osutils.until_no_eintr(self._writeable_pipe.flush)
 
917
        # Note: If flush were to fail, we'd like to raise ConnectionReset, etc.
 
918
        #       However, testing shows that even when the child process is
 
919
        #       gone, this doesn't error.
 
920
        self._writeable_pipe.flush()
729
921
 
730
922
    def _read_bytes(self, count):
731
923
        """See SmartClientStreamMedium._read_bytes."""
732
 
        bytes = osutils.until_no_eintr(self._readable_pipe.read, count)
 
924
        bytes_to_read = min(count, _MAX_READ_SIZE)
 
925
        bytes = self._readable_pipe.read(bytes_to_read)
733
926
        self._report_activity(len(bytes), 'read')
734
927
        return bytes
735
928
 
736
929
 
 
930
class SSHParams(object):
 
931
    """A set of parameters for starting a remote bzr via SSH."""
 
932
 
 
933
    def __init__(self, host, port=None, username=None, password=None,
 
934
            bzr_remote_path='bzr'):
 
935
        self.host = host
 
936
        self.port = port
 
937
        self.username = username
 
938
        self.password = password
 
939
        self.bzr_remote_path = bzr_remote_path
 
940
 
 
941
 
737
942
class SmartSSHClientMedium(SmartClientStreamMedium):
738
 
    """A client medium using SSH."""
739
 
 
740
 
    def __init__(self, host, port=None, username=None, password=None,
741
 
            base=None, vendor=None, bzr_remote_path=None):
 
943
    """A client medium using SSH.
 
944
 
 
945
    It delegates IO to a SmartSimplePipesClientMedium or
 
946
    SmartClientAlreadyConnectedSocketMedium (depending on platform).
 
947
    """
 
948
 
 
949
    def __init__(self, base, ssh_params, vendor=None):
742
950
        """Creates a client that will connect on the first use.
743
951
 
 
952
        :param ssh_params: A SSHParams instance.
744
953
        :param vendor: An optional override for the ssh vendor to use. See
745
954
            bzrlib.transport.ssh for details on ssh vendors.
746
955
        """
747
 
        self._connected = False
748
 
        self._host = host
749
 
        self._password = password
750
 
        self._port = port
751
 
        self._username = username
 
956
        self._real_medium = None
 
957
        self._ssh_params = ssh_params
752
958
        # for the benefit of progress making a short description of this
753
959
        # transport
754
960
        self._scheme = 'bzr+ssh'
756
962
        # _DebugCounter so we have to store all the values used in our repr
757
963
        # method before calling the super init.
758
964
        SmartClientStreamMedium.__init__(self, base)
759
 
        self._read_from = None
 
965
        self._vendor = vendor
760
966
        self._ssh_connection = None
761
 
        self._vendor = vendor
762
 
        self._write_to = None
763
 
        self._bzr_remote_path = bzr_remote_path
764
967
 
765
968
    def __repr__(self):
766
 
        if self._port is None:
 
969
        if self._ssh_params.port is None:
767
970
            maybe_port = ''
768
971
        else:
769
 
            maybe_port = ':%s' % self._port
770
 
        return "%s(%s://%s@%s%s/)" % (
 
972
            maybe_port = ':%s' % self._ssh_params.port
 
973
        if self._ssh_params.username is None:
 
974
            maybe_user = ''
 
975
        else:
 
976
            maybe_user = '%s@' % self._ssh_params.username
 
977
        return "%s(%s://%s%s%s/)" % (
771
978
            self.__class__.__name__,
772
979
            self._scheme,
773
 
            self._username,
774
 
            self._host,
 
980
            maybe_user,
 
981
            self._ssh_params.host,
775
982
            maybe_port)
776
983
 
777
984
    def _accept_bytes(self, bytes):
778
985
        """See SmartClientStreamMedium.accept_bytes."""
779
986
        self._ensure_connection()
780
 
        osutils.until_no_eintr(self._write_to.write, bytes)
781
 
        self._report_activity(len(bytes), 'write')
 
987
        self._real_medium.accept_bytes(bytes)
782
988
 
783
989
    def disconnect(self):
784
990
        """See SmartClientMedium.disconnect()."""
785
 
        if not self._connected:
786
 
            return
787
 
        osutils.until_no_eintr(self._read_from.close)
788
 
        osutils.until_no_eintr(self._write_to.close)
789
 
        self._ssh_connection.close()
790
 
        self._connected = False
 
991
        if self._real_medium is not None:
 
992
            self._real_medium.disconnect()
 
993
            self._real_medium = None
 
994
        if self._ssh_connection is not None:
 
995
            self._ssh_connection.close()
 
996
            self._ssh_connection = None
791
997
 
792
998
    def _ensure_connection(self):
793
999
        """Connect this medium if not already connected."""
794
 
        if self._connected:
 
1000
        if self._real_medium is not None:
795
1001
            return
796
1002
        if self._vendor is None:
797
1003
            vendor = ssh._get_ssh_vendor()
798
1004
        else:
799
1005
            vendor = self._vendor
800
 
        self._ssh_connection = vendor.connect_ssh(self._username,
801
 
                self._password, self._host, self._port,
802
 
                command=[self._bzr_remote_path, 'serve', '--inet',
 
1006
        self._ssh_connection = vendor.connect_ssh(self._ssh_params.username,
 
1007
                self._ssh_params.password, self._ssh_params.host,
 
1008
                self._ssh_params.port,
 
1009
                command=[self._ssh_params.bzr_remote_path, 'serve', '--inet',
803
1010
                         '--directory=/', '--allow-writes'])
804
 
        self._read_from, self._write_to = \
805
 
            self._ssh_connection.get_filelike_channels()
806
 
        self._connected = True
 
1011
        io_kind, io_object = self._ssh_connection.get_sock_or_pipes()
 
1012
        if io_kind == 'socket':
 
1013
            self._real_medium = SmartClientAlreadyConnectedSocketMedium(
 
1014
                self.base, io_object)
 
1015
        elif io_kind == 'pipes':
 
1016
            read_from, write_to = io_object
 
1017
            self._real_medium = SmartSimplePipesClientMedium(
 
1018
                read_from, write_to, self.base)
 
1019
        else:
 
1020
            raise AssertionError(
 
1021
                "Unexpected io_kind %r from %r"
 
1022
                % (io_kind, self._ssh_connection))
807
1023
 
808
1024
    def _flush(self):
809
1025
        """See SmartClientStreamMedium._flush()."""
810
 
        self._write_to.flush()
 
1026
        self._real_medium._flush()
811
1027
 
812
1028
    def _read_bytes(self, count):
813
1029
        """See SmartClientStreamMedium.read_bytes."""
814
 
        if not self._connected:
 
1030
        if self._real_medium is None:
815
1031
            raise errors.MediumNotConnected(self)
816
 
        bytes_to_read = min(count, _MAX_READ_SIZE)
817
 
        bytes = osutils.until_no_eintr(self._read_from.read, bytes_to_read)
818
 
        self._report_activity(len(bytes), 'read')
819
 
        return bytes
 
1032
        return self._real_medium.read_bytes(count)
820
1033
 
821
1034
 
822
1035
# Port 4155 is the default port for bzr://, registered with IANA.
824
1037
BZR_DEFAULT_PORT = 4155
825
1038
 
826
1039
 
827
 
class SmartTCPClientMedium(SmartClientStreamMedium):
828
 
    """A client medium using TCP."""
 
1040
class SmartClientSocketMedium(SmartClientStreamMedium):
 
1041
    """A client medium using a socket.
 
1042
    
 
1043
    This class isn't usable directly.  Use one of its subclasses instead.
 
1044
    """
 
1045
 
 
1046
    def __init__(self, base):
 
1047
        SmartClientStreamMedium.__init__(self, base)
 
1048
        self._socket = None
 
1049
        self._connected = False
 
1050
 
 
1051
    def _accept_bytes(self, bytes):
 
1052
        """See SmartClientMedium.accept_bytes."""
 
1053
        self._ensure_connection()
 
1054
        osutils.send_all(self._socket, bytes, self._report_activity)
 
1055
 
 
1056
    def _ensure_connection(self):
 
1057
        """Connect this medium if not already connected."""
 
1058
        raise NotImplementedError(self._ensure_connection)
 
1059
 
 
1060
    def _flush(self):
 
1061
        """See SmartClientStreamMedium._flush().
 
1062
 
 
1063
        For sockets we do no flushing. For TCP sockets we may want to turn off
 
1064
        TCP_NODELAY and add a means to do a flush, but that can be done in the
 
1065
        future.
 
1066
        """
 
1067
 
 
1068
    def _read_bytes(self, count):
 
1069
        """See SmartClientMedium.read_bytes."""
 
1070
        if not self._connected:
 
1071
            raise errors.MediumNotConnected(self)
 
1072
        return osutils.read_bytes_from_socket(
 
1073
            self._socket, self._report_activity)
 
1074
 
 
1075
    def disconnect(self):
 
1076
        """See SmartClientMedium.disconnect()."""
 
1077
        if not self._connected:
 
1078
            return
 
1079
        self._socket.close()
 
1080
        self._socket = None
 
1081
        self._connected = False
 
1082
 
 
1083
 
 
1084
class SmartTCPClientMedium(SmartClientSocketMedium):
 
1085
    """A client medium that creates a TCP connection."""
829
1086
 
830
1087
    def __init__(self, host, port, base):
831
1088
        """Creates a client that will connect on the first use."""
832
 
        SmartClientStreamMedium.__init__(self, base)
833
 
        self._connected = False
 
1089
        SmartClientSocketMedium.__init__(self, base)
834
1090
        self._host = host
835
1091
        self._port = port
836
 
        self._socket = None
837
 
 
838
 
    def _accept_bytes(self, bytes):
839
 
        """See SmartClientMedium.accept_bytes."""
840
 
        self._ensure_connection()
841
 
        osutils.send_all(self._socket, bytes, self._report_activity)
842
 
 
843
 
    def disconnect(self):
844
 
        """See SmartClientMedium.disconnect()."""
845
 
        if not self._connected:
846
 
            return
847
 
        osutils.until_no_eintr(self._socket.close)
848
 
        self._socket = None
849
 
        self._connected = False
850
1092
 
851
1093
    def _ensure_connection(self):
852
1094
        """Connect this medium if not already connected."""
887
1129
                    (self._host, port, err_msg))
888
1130
        self._connected = True
889
1131
 
890
 
    def _flush(self):
891
 
        """See SmartClientStreamMedium._flush().
892
 
 
893
 
        For TCP we do no flushing. We may want to turn off TCP_NODELAY and
894
 
        add a means to do a flush, but that can be done in the future.
895
 
        """
896
 
 
897
 
    def _read_bytes(self, count):
898
 
        """See SmartClientMedium.read_bytes."""
899
 
        if not self._connected:
900
 
            raise errors.MediumNotConnected(self)
901
 
        return _read_bytes_from_socket(
902
 
            self._socket.recv, count, self._report_activity)
 
1132
 
 
1133
class SmartClientAlreadyConnectedSocketMedium(SmartClientSocketMedium):
 
1134
    """A client medium for an already connected socket.
 
1135
    
 
1136
    Note that this class will assume it "owns" the socket, so it will close it
 
1137
    when its disconnect method is called.
 
1138
    """
 
1139
 
 
1140
    def __init__(self, base, sock):
 
1141
        SmartClientSocketMedium.__init__(self, base)
 
1142
        self._socket = sock
 
1143
        self._connected = True
 
1144
 
 
1145
    def _ensure_connection(self):
 
1146
        # Already connected, by definition!  So nothing to do.
 
1147
        pass
903
1148
 
904
1149
 
905
1150
class SmartClientStreamMediumRequest(SmartClientMediumRequest):
940
1185
        This invokes self._medium._flush to ensure all bytes are transmitted.
941
1186
        """
942
1187
        self._medium._flush()
943
 
 
944
 
 
945
 
def _read_bytes_from_socket(sock, desired_count, report_activity):
946
 
    # We ignore the desired_count because on sockets it's more efficient to
947
 
    # read large chunks (of _MAX_READ_SIZE bytes) at a time.
948
 
    try:
949
 
        bytes = osutils.until_no_eintr(sock, _MAX_READ_SIZE)
950
 
    except socket.error, e:
951
 
        if len(e.args) and e.args[0] in (errno.ECONNRESET, 10054):
952
 
            # The connection was closed by the other side.  Callers expect an
953
 
            # empty string to signal end-of-stream.
954
 
            bytes = ''
955
 
        else:
956
 
            raise
957
 
    else:
958
 
        report_activity(len(bytes), 'read')
959
 
    return bytes
960