~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/smart/medium.py

  • Committer: Andrew Bennetts
  • Date: 2009-12-03 02:24:54 UTC
  • mfrom: (4634.101.4 2.0)
  • mto: This revision was merged to the branch mainline in revision 4857.
  • Revision ID: andrew.bennetts@canonical.com-20091203022454-m2gyhbcdqi1t7ujz
Merge lp:bzr/2.0 into lp:bzr.

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2006-2011 Canonical Ltd
 
1
# Copyright (C) 2006 Canonical Ltd
2
2
#
3
3
# This program is free software; you can redistribute it and/or modify
4
4
# it under the terms of the GNU General Public License as published by
24
24
bzrlib/transport/smart/__init__.py.
25
25
"""
26
26
 
27
 
from __future__ import absolute_import
28
 
 
29
27
import errno
30
28
import os
 
29
import socket
31
30
import sys
32
 
import time
 
31
import urllib
33
32
 
34
 
import bzrlib
35
33
from bzrlib.lazy_import import lazy_import
36
34
lazy_import(globals(), """
37
 
import select
38
 
import socket
39
 
import thread
 
35
import atexit
40
36
import weakref
41
 
 
42
37
from bzrlib import (
43
38
    debug,
44
39
    errors,
 
40
    symbol_versioning,
45
41
    trace,
46
42
    ui,
47
43
    urlutils,
48
44
    )
49
 
from bzrlib.i18n import gettext
50
 
from bzrlib.smart import client, protocol, request, signals, vfs
 
45
from bzrlib.smart import client, protocol, request, vfs
51
46
from bzrlib.transport import ssh
52
47
""")
 
48
#usually already imported, and getting IllegalScoperReplacer on it here.
53
49
from bzrlib import osutils
54
50
 
55
 
# Throughout this module buffer size parameters are either limited to be at
56
 
# most _MAX_READ_SIZE, or are ignored and _MAX_READ_SIZE is used instead.
57
 
# For this module's purposes, MAX_SOCKET_CHUNK is a reasonable size for reads
58
 
# from non-sockets as well.
59
 
_MAX_READ_SIZE = osutils.MAX_SOCKET_CHUNK
 
51
# We must not read any more than 64k at a time so we don't risk "no buffer
 
52
# space available" errors on some platforms.  Windows in particular is likely
 
53
# to give error 10053 or 10055 if we read more than 64k from a socket.
 
54
_MAX_READ_SIZE = 64 * 1024
 
55
 
60
56
 
61
57
def _get_protocol_factory_for_bytes(bytes):
62
58
    """Determine the right protocol factory for 'bytes'.
180
176
        ui.ui_factory.report_transport_activity(self, bytes, direction)
181
177
 
182
178
 
183
 
_bad_file_descriptor = (errno.EBADF,)
184
 
if sys.platform == 'win32':
185
 
    # Given on Windows if you pass a closed socket to select.select. Probably
186
 
    # also given if you pass a file handle to select.
187
 
    WSAENOTSOCK = 10038
188
 
    _bad_file_descriptor += (WSAENOTSOCK,)
189
 
 
190
 
 
191
179
class SmartServerStreamMedium(SmartMedium):
192
180
    """Handles smart commands coming over a stream.
193
181
 
206
194
        the stream.  See also the _push_back method.
207
195
    """
208
196
 
209
 
    _timer = time.time
210
 
 
211
 
    def __init__(self, backing_transport, root_client_path='/', timeout=None):
 
197
    def __init__(self, backing_transport, root_client_path='/'):
212
198
        """Construct new server.
213
199
 
214
200
        :param backing_transport: Transport for the directory served.
217
203
        self.backing_transport = backing_transport
218
204
        self.root_client_path = root_client_path
219
205
        self.finished = False
220
 
        if timeout is None:
221
 
            raise AssertionError('You must supply a timeout.')
222
 
        self._client_timeout = timeout
223
 
        self._client_poll_timeout = min(timeout / 10.0, 1.0)
224
206
        SmartMedium.__init__(self)
225
207
 
226
208
    def serve(self):
232
214
            while not self.finished:
233
215
                server_protocol = self._build_protocol()
234
216
                self._serve_one_request(server_protocol)
235
 
        except errors.ConnectionTimeout, e:
236
 
            trace.note('%s' % (e,))
237
 
            trace.log_exception_quietly()
238
 
            self._disconnect_client()
239
 
            # We reported it, no reason to make a big fuss.
240
 
            return
241
217
        except Exception, e:
242
218
            stderr.write("%s terminating on exception %s\n" % (self, e))
243
219
            raise
244
 
        self._disconnect_client()
245
 
 
246
 
    def _stop_gracefully(self):
247
 
        """When we finish this message, stop looking for more."""
248
 
        trace.mutter('Stopping %s' % (self,))
249
 
        self.finished = True
250
 
 
251
 
    def _disconnect_client(self):
252
 
        """Close the current connection. We stopped due to a timeout/etc."""
253
 
        # The default implementation is a no-op, because that is all we used to
254
 
        # do when disconnecting from a client. I suppose we never had the
255
 
        # *server* initiate a disconnect, before
256
 
 
257
 
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
258
 
        """Wait for more bytes to be read, but timeout if none available.
259
 
 
260
 
        This allows us to detect idle connections, and stop trying to read from
261
 
        them, without setting the socket itself to non-blocking. This also
262
 
        allows us to specify when we watch for idle timeouts.
263
 
 
264
 
        :return: Did we timeout? (True if we timed out, False if there is data
265
 
            to be read)
266
 
        """
267
 
        raise NotImplementedError(self._wait_for_bytes_with_timeout)
268
220
 
269
221
    def _build_protocol(self):
270
222
        """Identifies the version of the incoming request, and returns an
275
227
 
276
228
        :returns: a SmartServerRequestProtocol.
277
229
        """
278
 
        self._wait_for_bytes_with_timeout(self._client_timeout)
279
 
        if self.finished:
280
 
            # We're stopping, so don't try to do any more work
281
 
            return None
282
230
        bytes = self._get_line()
283
231
        protocol_factory, unused_bytes = _get_protocol_factory_for_bytes(bytes)
284
232
        protocol = protocol_factory(
286
234
        protocol.accept_bytes(unused_bytes)
287
235
        return protocol
288
236
 
289
 
    def _wait_on_descriptor(self, fd, timeout_seconds):
290
 
        """select() on a file descriptor, waiting for nonblocking read()
291
 
 
292
 
        This will raise a ConnectionTimeout exception if we do not get a
293
 
        readable handle before timeout_seconds.
294
 
        :return: None
295
 
        """
296
 
        t_end = self._timer() + timeout_seconds
297
 
        poll_timeout = min(timeout_seconds, self._client_poll_timeout)
298
 
        rs = xs = None
299
 
        while not rs and not xs and self._timer() < t_end:
300
 
            if self.finished:
301
 
                return
302
 
            try:
303
 
                rs, _, xs = select.select([fd], [], [fd], poll_timeout)
304
 
            except (select.error, socket.error) as e:
305
 
                err = getattr(e, 'errno', None)
306
 
                if err is None and getattr(e, 'args', None) is not None:
307
 
                    # select.error doesn't have 'errno', it just has args[0]
308
 
                    err = e.args[0]
309
 
                if err in _bad_file_descriptor:
310
 
                    return # Not a socket indicates read() will fail
311
 
                elif err == errno.EINTR:
312
 
                    # Interrupted, keep looping.
313
 
                    continue
314
 
                raise
315
 
        if rs or xs:
316
 
            return
317
 
        raise errors.ConnectionTimeout('disconnecting client after %.1f seconds'
318
 
                                       % (timeout_seconds,))
319
 
 
320
237
    def _serve_one_request(self, protocol):
321
238
        """Read one request from input, process, send back a response.
322
239
 
323
240
        :param protocol: a SmartServerRequestProtocol.
324
241
        """
325
 
        if protocol is None:
326
 
            return
327
242
        try:
328
243
            self._serve_one_request_unguarded(protocol)
329
244
        except KeyboardInterrupt:
345
260
 
346
261
class SmartServerSocketStreamMedium(SmartServerStreamMedium):
347
262
 
348
 
    def __init__(self, sock, backing_transport, root_client_path='/',
349
 
                 timeout=None):
 
263
    def __init__(self, sock, backing_transport, root_client_path='/'):
350
264
        """Constructor.
351
265
 
352
266
        :param sock: the socket the server will read from.  It will be put
353
267
            into blocking mode.
354
268
        """
355
269
        SmartServerStreamMedium.__init__(
356
 
            self, backing_transport, root_client_path=root_client_path,
357
 
            timeout=timeout)
 
270
            self, backing_transport, root_client_path=root_client_path)
358
271
        sock.setblocking(True)
359
272
        self.socket = sock
360
 
        # Get the getpeername now, as we might be closed later when we care.
361
 
        try:
362
 
            self._client_info = sock.getpeername()
363
 
        except socket.error:
364
 
            self._client_info = '<unknown>'
365
 
 
366
 
    def __str__(self):
367
 
        return '%s(client=%s)' % (self.__class__.__name__, self._client_info)
368
 
 
369
 
    def __repr__(self):
370
 
        return '%s.%s(client=%s)' % (self.__module__, self.__class__.__name__,
371
 
            self._client_info)
372
273
 
373
274
    def _serve_one_request_unguarded(self, protocol):
374
275
        while protocol.next_read_size():
375
276
            # We can safely try to read large chunks.  If there is less data
376
 
            # than MAX_SOCKET_CHUNK ready, the socket will just return a
377
 
            # short read immediately rather than block.
378
 
            bytes = self.read_bytes(osutils.MAX_SOCKET_CHUNK)
 
277
            # than _MAX_READ_SIZE ready, the socket wil just return a short
 
278
            # read immediately rather than block.
 
279
            bytes = self.read_bytes(_MAX_READ_SIZE)
379
280
            if bytes == '':
380
281
                self.finished = True
381
282
                return
383
284
 
384
285
        self._push_back(protocol.unused_data)
385
286
 
386
 
    def _disconnect_client(self):
387
 
        """Close the current connection. We stopped due to a timeout/etc."""
388
 
        self.socket.close()
389
 
 
390
 
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
391
 
        """Wait for more bytes to be read, but timeout if none available.
392
 
 
393
 
        This allows us to detect idle connections, and stop trying to read from
394
 
        them, without setting the socket itself to non-blocking. This also
395
 
        allows us to specify when we watch for idle timeouts.
396
 
 
397
 
        :return: None, this will raise ConnectionTimeout if we time out before
398
 
            data is available.
399
 
        """
400
 
        return self._wait_on_descriptor(self.socket, timeout_seconds)
401
 
 
402
287
    def _read_bytes(self, desired_count):
403
 
        return osutils.read_bytes_from_socket(
404
 
            self.socket, self._report_activity)
 
288
        return _read_bytes_from_socket(
 
289
            self.socket.recv, desired_count, self._report_activity)
405
290
 
406
291
    def terminate_due_to_error(self):
407
292
        # TODO: This should log to a server log file, but no such thing
408
293
        # exists yet.  Andrew Bennetts 2006-09-29.
409
 
        self.socket.close()
 
294
        osutils.until_no_eintr(self.socket.close)
410
295
        self.finished = True
411
296
 
412
297
    def _write_out(self, bytes):
413
 
        tstart = osutils.timer_func()
414
298
        osutils.send_all(self.socket, bytes, self._report_activity)
415
 
        if 'hpss' in debug.debug_flags:
416
 
            thread_id = thread.get_ident()
417
 
            trace.mutter('%12s: [%s] %d bytes to the socket in %.3fs'
418
 
                         % ('wrote', thread_id, len(bytes),
419
 
                            osutils.timer_func() - tstart))
420
299
 
421
300
 
422
301
class SmartServerPipeStreamMedium(SmartServerStreamMedium):
423
302
 
424
 
    def __init__(self, in_file, out_file, backing_transport, timeout=None):
 
303
    def __init__(self, in_file, out_file, backing_transport):
425
304
        """Construct new server.
426
305
 
427
306
        :param in_file: Python file from which requests can be read.
428
307
        :param out_file: Python file to write responses.
429
308
        :param backing_transport: Transport for the directory served.
430
309
        """
431
 
        SmartServerStreamMedium.__init__(self, backing_transport,
432
 
            timeout=timeout)
 
310
        SmartServerStreamMedium.__init__(self, backing_transport)
433
311
        if sys.platform == 'win32':
434
312
            # force binary mode for files
435
313
            import msvcrt
440
318
        self._in = in_file
441
319
        self._out = out_file
442
320
 
443
 
    def serve(self):
444
 
        """See SmartServerStreamMedium.serve"""
445
 
        # This is the regular serve, except it adds signal trapping for soft
446
 
        # shutdown.
447
 
        stop_gracefully = self._stop_gracefully
448
 
        signals.register_on_hangup(id(self), stop_gracefully)
449
 
        try:
450
 
            return super(SmartServerPipeStreamMedium, self).serve()
451
 
        finally:
452
 
            signals.unregister_on_hangup(id(self))
453
 
 
454
321
    def _serve_one_request_unguarded(self, protocol):
455
322
        while True:
456
323
            # We need to be careful not to read past the end of the current
459
326
            bytes_to_read = protocol.next_read_size()
460
327
            if bytes_to_read == 0:
461
328
                # Finished serving this request.
462
 
                self._out.flush()
 
329
                osutils.until_no_eintr(self._out.flush)
463
330
                return
464
331
            bytes = self.read_bytes(bytes_to_read)
465
332
            if bytes == '':
466
333
                # Connection has been closed.
467
334
                self.finished = True
468
 
                self._out.flush()
 
335
                osutils.until_no_eintr(self._out.flush)
469
336
                return
470
337
            protocol.accept_bytes(bytes)
471
338
 
472
 
    def _disconnect_client(self):
473
 
        self._in.close()
474
 
        self._out.flush()
475
 
        self._out.close()
476
 
 
477
 
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
478
 
        """Wait for more bytes to be read, but timeout if none available.
479
 
 
480
 
        This allows us to detect idle connections, and stop trying to read from
481
 
        them, without setting the socket itself to non-blocking. This also
482
 
        allows us to specify when we watch for idle timeouts.
483
 
 
484
 
        :return: None, this will raise ConnectionTimeout if we time out before
485
 
            data is available.
486
 
        """
487
 
        if (getattr(self._in, 'fileno', None) is None
488
 
            or sys.platform == 'win32'):
489
 
            # You can't select() file descriptors on Windows.
490
 
            return
491
 
        return self._wait_on_descriptor(self._in, timeout_seconds)
492
 
 
493
339
    def _read_bytes(self, desired_count):
494
 
        return self._in.read(desired_count)
 
340
        return osutils.until_no_eintr(self._in.read, desired_count)
495
341
 
496
342
    def terminate_due_to_error(self):
497
343
        # TODO: This should log to a server log file, but no such thing
498
344
        # exists yet.  Andrew Bennetts 2006-09-29.
499
 
        self._out.close()
 
345
        osutils.until_no_eintr(self._out.close)
500
346
        self.finished = True
501
347
 
502
348
    def _write_out(self, bytes):
503
 
        self._out.write(bytes)
 
349
        osutils.until_no_eintr(self._out.write, bytes)
504
350
 
505
351
 
506
352
class SmartClientMediumRequest(object):
639
485
        return self._medium._get_line()
640
486
 
641
487
 
642
 
class _VfsRefuser(object):
643
 
    """An object that refuses all VFS requests.
644
 
 
645
 
    """
646
 
 
647
 
    def __init__(self):
648
 
        client._SmartClient.hooks.install_named_hook(
649
 
            'call', self.check_vfs, 'vfs refuser')
650
 
 
651
 
    def check_vfs(self, params):
652
 
        try:
653
 
            request_method = request.request_handlers.get(params.method)
654
 
        except KeyError:
655
 
            # A method we don't know about doesn't count as a VFS method.
656
 
            return
657
 
        if issubclass(request_method, vfs.VfsRequest):
658
 
            raise errors.HpssVfsRequestNotAllowed(params.method, params.args)
659
 
 
660
 
 
661
488
class _DebugCounter(object):
662
489
    """An object that counts the HPSS calls made to each client medium.
663
490
 
664
 
    When a medium is garbage-collected, or failing that when
665
 
    bzrlib.global_state exits, the total number of calls made on that medium
666
 
    are reported via trace.note.
 
491
    When a medium is garbage-collected, or failing that when atexit functions
 
492
    are run, the total number of calls made on that medium are reported via
 
493
    trace.note.
667
494
    """
668
495
 
669
496
    def __init__(self):
670
497
        self.counts = weakref.WeakKeyDictionary()
671
498
        client._SmartClient.hooks.install_named_hook(
672
499
            'call', self.increment_call_count, 'hpss call counter')
673
 
        bzrlib.global_state.cleanups.add_cleanup(self.flush_all)
 
500
        atexit.register(self.flush_all)
674
501
 
675
502
    def track(self, medium):
676
503
        """Start tracking calls made to a medium.
710
537
        value['count'] = 0
711
538
        value['vfs_count'] = 0
712
539
        if count != 0:
713
 
            trace.note(gettext('HPSS calls: {0} ({1} vfs) {2}').format(
714
 
                       count, vfs_count, medium_repr))
 
540
            trace.note('HPSS calls: %d (%d vfs) %s',
 
541
                       count, vfs_count, medium_repr)
715
542
 
716
543
    def flush_all(self):
717
544
        for ref in list(self.counts.keys()):
718
545
            self.done(ref)
719
546
 
720
547
_debug_counter = None
721
 
_vfs_refuser = None
722
548
 
723
549
 
724
550
class SmartClientMedium(SmartMedium):
741
567
            if _debug_counter is None:
742
568
                _debug_counter = _DebugCounter()
743
569
            _debug_counter.track(self)
744
 
        if 'hpss_client_no_vfs' in debug.debug_flags:
745
 
            global _vfs_refuser
746
 
            if _vfs_refuser is None:
747
 
                _vfs_refuser = _VfsRefuser()
748
570
 
749
571
    def _is_remote_before(self, version_tuple):
750
572
        """Is it possible the remote side supports RPCs for a given version?
779
601
            # which is newer than a previously supplied older-than version.
780
602
            # This indicates that some smart verb call is not guarded
781
603
            # appropriately (it should simply not have been tried).
782
 
            trace.mutter(
 
604
            raise AssertionError(
783
605
                "_remember_remote_is_before(%r) called, but "
784
606
                "_remember_remote_is_before(%r) was called previously."
785
 
                , version_tuple, self._remote_version_is_before)
786
 
            if 'hpss' in debug.debug_flags:
787
 
                ui.ui_factory.show_warning(
788
 
                    "_remember_remote_is_before(%r) called, but "
789
 
                    "_remember_remote_is_before(%r) was called previously."
790
 
                    % (version_tuple, self._remote_version_is_before))
791
 
            return
 
607
                % (version_tuple, self._remote_version_is_before))
792
608
        self._remote_version_is_before = version_tuple
793
609
 
794
610
    def protocol_version(self):
841
657
        """
842
658
        medium_base = urlutils.join(self.base, '/')
843
659
        rel_url = urlutils.relative_url(medium_base, transport.base)
844
 
        return urlutils.unquote(rel_url)
 
660
        return urllib.unquote(rel_url)
845
661
 
846
662
 
847
663
class SmartClientStreamMedium(SmartClientMedium):
882
698
        """
883
699
        return SmartClientStreamMediumRequest(self)
884
700
 
885
 
    def reset(self):
886
 
        """We have been disconnected, reset current state.
887
 
 
888
 
        This resets things like _current_request and connected state.
889
 
        """
890
 
        self.disconnect()
891
 
        self._current_request = None
892
 
 
893
701
 
894
702
class SmartSimplePipesClientMedium(SmartClientStreamMedium):
895
703
    """A client medium using simple pipes.
904
712
 
905
713
    def _accept_bytes(self, bytes):
906
714
        """See SmartClientStreamMedium.accept_bytes."""
907
 
        try:
908
 
            self._writeable_pipe.write(bytes)
909
 
        except IOError, e:
910
 
            if e.errno in (errno.EINVAL, errno.EPIPE):
911
 
                raise errors.ConnectionReset(
912
 
                    "Error trying to write to subprocess:\n%s" % (e,))
913
 
            raise
 
715
        osutils.until_no_eintr(self._writeable_pipe.write, bytes)
914
716
        self._report_activity(len(bytes), 'write')
915
717
 
916
718
    def _flush(self):
917
719
        """See SmartClientStreamMedium._flush()."""
918
 
        # Note: If flush were to fail, we'd like to raise ConnectionReset, etc.
919
 
        #       However, testing shows that even when the child process is
920
 
        #       gone, this doesn't error.
921
 
        self._writeable_pipe.flush()
 
720
        osutils.until_no_eintr(self._writeable_pipe.flush)
922
721
 
923
722
    def _read_bytes(self, count):
924
723
        """See SmartClientStreamMedium._read_bytes."""
925
 
        bytes_to_read = min(count, _MAX_READ_SIZE)
926
 
        bytes = self._readable_pipe.read(bytes_to_read)
 
724
        bytes = osutils.until_no_eintr(self._readable_pipe.read, count)
927
725
        self._report_activity(len(bytes), 'read')
928
726
        return bytes
929
727
 
930
728
 
931
 
class SSHParams(object):
932
 
    """A set of parameters for starting a remote bzr via SSH."""
 
729
class SmartSSHClientMedium(SmartClientStreamMedium):
 
730
    """A client medium using SSH."""
933
731
 
934
732
    def __init__(self, host, port=None, username=None, password=None,
935
 
            bzr_remote_path='bzr'):
936
 
        self.host = host
937
 
        self.port = port
938
 
        self.username = username
939
 
        self.password = password
940
 
        self.bzr_remote_path = bzr_remote_path
941
 
 
942
 
 
943
 
class SmartSSHClientMedium(SmartClientStreamMedium):
944
 
    """A client medium using SSH.
945
 
 
946
 
    It delegates IO to a SmartSimplePipesClientMedium or
947
 
    SmartClientAlreadyConnectedSocketMedium (depending on platform).
948
 
    """
949
 
 
950
 
    def __init__(self, base, ssh_params, vendor=None):
 
733
            base=None, vendor=None, bzr_remote_path=None):
951
734
        """Creates a client that will connect on the first use.
952
735
 
953
 
        :param ssh_params: A SSHParams instance.
954
736
        :param vendor: An optional override for the ssh vendor to use. See
955
737
            bzrlib.transport.ssh for details on ssh vendors.
956
738
        """
957
 
        self._real_medium = None
958
 
        self._ssh_params = ssh_params
959
 
        # for the benefit of progress making a short description of this
960
 
        # transport
961
 
        self._scheme = 'bzr+ssh'
 
739
        self._connected = False
 
740
        self._host = host
 
741
        self._password = password
 
742
        self._port = port
 
743
        self._username = username
962
744
        # SmartClientStreamMedium stores the repr of this object in its
963
745
        # _DebugCounter so we have to store all the values used in our repr
964
746
        # method before calling the super init.
965
747
        SmartClientStreamMedium.__init__(self, base)
 
748
        self._read_from = None
 
749
        self._ssh_connection = None
966
750
        self._vendor = vendor
967
 
        self._ssh_connection = None
 
751
        self._write_to = None
 
752
        self._bzr_remote_path = bzr_remote_path
 
753
        # for the benefit of progress making a short description of this
 
754
        # transport
 
755
        self._scheme = 'bzr+ssh'
968
756
 
969
757
    def __repr__(self):
970
 
        if self._ssh_params.port is None:
971
 
            maybe_port = ''
972
 
        else:
973
 
            maybe_port = ':%s' % self._ssh_params.port
974
 
        if self._ssh_params.username is None:
975
 
            maybe_user = ''
976
 
        else:
977
 
            maybe_user = '%s@' % self._ssh_params.username
978
 
        return "%s(%s://%s%s%s/)" % (
 
758
        return "%s(connected=%r, username=%r, host=%r, port=%r)" % (
979
759
            self.__class__.__name__,
980
 
            self._scheme,
981
 
            maybe_user,
982
 
            self._ssh_params.host,
983
 
            maybe_port)
 
760
            self._connected,
 
761
            self._username,
 
762
            self._host,
 
763
            self._port)
984
764
 
985
765
    def _accept_bytes(self, bytes):
986
766
        """See SmartClientStreamMedium.accept_bytes."""
987
767
        self._ensure_connection()
988
 
        self._real_medium.accept_bytes(bytes)
 
768
        osutils.until_no_eintr(self._write_to.write, bytes)
 
769
        self._report_activity(len(bytes), 'write')
989
770
 
990
771
    def disconnect(self):
991
772
        """See SmartClientMedium.disconnect()."""
992
 
        if self._real_medium is not None:
993
 
            self._real_medium.disconnect()
994
 
            self._real_medium = None
995
 
        if self._ssh_connection is not None:
996
 
            self._ssh_connection.close()
997
 
            self._ssh_connection = None
 
773
        if not self._connected:
 
774
            return
 
775
        osutils.until_no_eintr(self._read_from.close)
 
776
        osutils.until_no_eintr(self._write_to.close)
 
777
        self._ssh_connection.close()
 
778
        self._connected = False
998
779
 
999
780
    def _ensure_connection(self):
1000
781
        """Connect this medium if not already connected."""
1001
 
        if self._real_medium is not None:
 
782
        if self._connected:
1002
783
            return
1003
784
        if self._vendor is None:
1004
785
            vendor = ssh._get_ssh_vendor()
1005
786
        else:
1006
787
            vendor = self._vendor
1007
 
        self._ssh_connection = vendor.connect_ssh(self._ssh_params.username,
1008
 
                self._ssh_params.password, self._ssh_params.host,
1009
 
                self._ssh_params.port,
1010
 
                command=[self._ssh_params.bzr_remote_path, 'serve', '--inet',
 
788
        self._ssh_connection = vendor.connect_ssh(self._username,
 
789
                self._password, self._host, self._port,
 
790
                command=[self._bzr_remote_path, 'serve', '--inet',
1011
791
                         '--directory=/', '--allow-writes'])
1012
 
        io_kind, io_object = self._ssh_connection.get_sock_or_pipes()
1013
 
        if io_kind == 'socket':
1014
 
            self._real_medium = SmartClientAlreadyConnectedSocketMedium(
1015
 
                self.base, io_object)
1016
 
        elif io_kind == 'pipes':
1017
 
            read_from, write_to = io_object
1018
 
            self._real_medium = SmartSimplePipesClientMedium(
1019
 
                read_from, write_to, self.base)
1020
 
        else:
1021
 
            raise AssertionError(
1022
 
                "Unexpected io_kind %r from %r"
1023
 
                % (io_kind, self._ssh_connection))
 
792
        self._read_from, self._write_to = \
 
793
            self._ssh_connection.get_filelike_channels()
 
794
        self._connected = True
1024
795
 
1025
796
    def _flush(self):
1026
797
        """See SmartClientStreamMedium._flush()."""
1027
 
        self._real_medium._flush()
 
798
        self._write_to.flush()
1028
799
 
1029
800
    def _read_bytes(self, count):
1030
801
        """See SmartClientStreamMedium.read_bytes."""
1031
 
        if self._real_medium is None:
 
802
        if not self._connected:
1032
803
            raise errors.MediumNotConnected(self)
1033
 
        return self._real_medium.read_bytes(count)
 
804
        bytes_to_read = min(count, _MAX_READ_SIZE)
 
805
        bytes = osutils.until_no_eintr(self._read_from.read, bytes_to_read)
 
806
        self._report_activity(len(bytes), 'read')
 
807
        return bytes
1034
808
 
1035
809
 
1036
810
# Port 4155 is the default port for bzr://, registered with IANA.
1038
812
BZR_DEFAULT_PORT = 4155
1039
813
 
1040
814
 
1041
 
class SmartClientSocketMedium(SmartClientStreamMedium):
1042
 
    """A client medium using a socket.
1043
 
    
1044
 
    This class isn't usable directly.  Use one of its subclasses instead.
1045
 
    """
 
815
class SmartTCPClientMedium(SmartClientStreamMedium):
 
816
    """A client medium using TCP."""
1046
817
 
1047
 
    def __init__(self, base):
 
818
    def __init__(self, host, port, base):
 
819
        """Creates a client that will connect on the first use."""
1048
820
        SmartClientStreamMedium.__init__(self, base)
 
821
        self._connected = False
 
822
        self._host = host
 
823
        self._port = port
1049
824
        self._socket = None
1050
 
        self._connected = False
1051
825
 
1052
826
    def _accept_bytes(self, bytes):
1053
827
        """See SmartClientMedium.accept_bytes."""
1054
828
        self._ensure_connection()
1055
829
        osutils.send_all(self._socket, bytes, self._report_activity)
1056
830
 
1057
 
    def _ensure_connection(self):
1058
 
        """Connect this medium if not already connected."""
1059
 
        raise NotImplementedError(self._ensure_connection)
1060
 
 
1061
 
    def _flush(self):
1062
 
        """See SmartClientStreamMedium._flush().
1063
 
 
1064
 
        For sockets we do no flushing. For TCP sockets we may want to turn off
1065
 
        TCP_NODELAY and add a means to do a flush, but that can be done in the
1066
 
        future.
1067
 
        """
1068
 
 
1069
 
    def _read_bytes(self, count):
1070
 
        """See SmartClientMedium.read_bytes."""
1071
 
        if not self._connected:
1072
 
            raise errors.MediumNotConnected(self)
1073
 
        return osutils.read_bytes_from_socket(
1074
 
            self._socket, self._report_activity)
1075
 
 
1076
831
    def disconnect(self):
1077
832
        """See SmartClientMedium.disconnect()."""
1078
833
        if not self._connected:
1079
834
            return
1080
 
        self._socket.close()
 
835
        osutils.until_no_eintr(self._socket.close)
1081
836
        self._socket = None
1082
837
        self._connected = False
1083
838
 
1084
 
 
1085
 
class SmartTCPClientMedium(SmartClientSocketMedium):
1086
 
    """A client medium that creates a TCP connection."""
1087
 
 
1088
 
    def __init__(self, host, port, base):
1089
 
        """Creates a client that will connect on the first use."""
1090
 
        SmartClientSocketMedium.__init__(self, base)
1091
 
        self._host = host
1092
 
        self._port = port
1093
 
 
1094
839
    def _ensure_connection(self):
1095
840
        """Connect this medium if not already connected."""
1096
841
        if self._connected:
1130
875
                    (self._host, port, err_msg))
1131
876
        self._connected = True
1132
877
 
1133
 
 
1134
 
class SmartClientAlreadyConnectedSocketMedium(SmartClientSocketMedium):
1135
 
    """A client medium for an already connected socket.
1136
 
    
1137
 
    Note that this class will assume it "owns" the socket, so it will close it
1138
 
    when its disconnect method is called.
1139
 
    """
1140
 
 
1141
 
    def __init__(self, base, sock):
1142
 
        SmartClientSocketMedium.__init__(self, base)
1143
 
        self._socket = sock
1144
 
        self._connected = True
1145
 
 
1146
 
    def _ensure_connection(self):
1147
 
        # Already connected, by definition!  So nothing to do.
1148
 
        pass
 
878
    def _flush(self):
 
879
        """See SmartClientStreamMedium._flush().
 
880
 
 
881
        For TCP we do no flushing. We may want to turn off TCP_NODELAY and
 
882
        add a means to do a flush, but that can be done in the future.
 
883
        """
 
884
 
 
885
    def _read_bytes(self, count):
 
886
        """See SmartClientMedium.read_bytes."""
 
887
        if not self._connected:
 
888
            raise errors.MediumNotConnected(self)
 
889
        return _read_bytes_from_socket(
 
890
            self._socket.recv, count, self._report_activity)
1149
891
 
1150
892
 
1151
893
class SmartClientStreamMediumRequest(SmartClientMediumRequest):
1186
928
        This invokes self._medium._flush to ensure all bytes are transmitted.
1187
929
        """
1188
930
        self._medium._flush()
 
931
 
 
932
 
 
933
def _read_bytes_from_socket(sock, desired_count, report_activity):
 
934
    # We ignore the desired_count because on sockets it's more efficient to
 
935
    # read large chunks (of _MAX_READ_SIZE bytes) at a time.
 
936
    try:
 
937
        bytes = osutils.until_no_eintr(sock, _MAX_READ_SIZE)
 
938
    except socket.error, e:
 
939
        if len(e.args) and e.args[0] in (errno.ECONNRESET, 10054):
 
940
            # The connection was closed by the other side.  Callers expect an
 
941
            # empty string to signal end-of-stream.
 
942
            bytes = ''
 
943
        else:
 
944
            raise
 
945
    else:
 
946
        report_activity(len(bytes), 'read')
 
947
    return bytes
 
948