~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/smart/medium.py

(jameinel) Allow 'bzr serve' to interpret SIGHUP as a graceful shutdown.
 (bug #795025) (John A Meinel)

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2006 Canonical Ltd
 
1
# Copyright (C) 2006-2011 Canonical Ltd
2
2
#
3
3
# This program is free software; you can redistribute it and/or modify
4
4
# it under the terms of the GNU General Public License as published by
12
12
#
13
13
# You should have received a copy of the GNU General Public License
14
14
# along with this program; if not, write to the Free Software
15
 
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 
15
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
16
16
 
17
17
"""The 'medium' layer for the smart servers and clients.
18
18
 
26
26
 
27
27
import errno
28
28
import os
29
 
import socket
30
29
import sys
 
30
import time
31
31
import urllib
32
32
 
 
33
import bzrlib
33
34
from bzrlib.lazy_import import lazy_import
34
35
lazy_import(globals(), """
35
 
import atexit
 
36
import select
 
37
import socket
 
38
import thread
36
39
import weakref
 
40
 
37
41
from bzrlib import (
38
42
    debug,
39
43
    errors,
40
 
    osutils,
41
 
    symbol_versioning,
42
44
    trace,
 
45
    ui,
43
46
    urlutils,
44
47
    )
45
 
from bzrlib.smart import client, protocol
 
48
from bzrlib.i18n import gettext
 
49
from bzrlib.smart import client, protocol, request, signals, vfs
46
50
from bzrlib.transport import ssh
47
51
""")
48
 
 
49
 
 
50
 
# We must not read any more than 64k at a time so we don't risk "no buffer
51
 
# space available" errors on some platforms.  Windows in particular is likely
52
 
# to give error 10053 or 10055 if we read more than 64k from a socket.
53
 
_MAX_READ_SIZE = 64 * 1024
54
 
 
 
52
from bzrlib import osutils
 
53
 
 
54
# Throughout this module buffer size parameters are either limited to be at
 
55
# most _MAX_READ_SIZE, or are ignored and _MAX_READ_SIZE is used instead.
 
56
# For this module's purposes, MAX_SOCKET_CHUNK is a reasonable size for reads
 
57
# from non-sockets as well.
 
58
_MAX_READ_SIZE = osutils.MAX_SOCKET_CHUNK
55
59
 
56
60
def _get_protocol_factory_for_bytes(bytes):
57
61
    """Determine the right protocol factory for 'bytes'.
87
91
 
88
92
def _get_line(read_bytes_func):
89
93
    """Read bytes using read_bytes_func until a newline byte.
90
 
    
 
94
 
91
95
    This isn't particularly efficient, so should only be used when the
92
96
    expected size of the line is quite short.
93
 
    
 
97
 
94
98
    :returns: a tuple of two strs: (line, excess)
95
99
    """
96
100
    newline_pos = -1
112
116
 
113
117
    def __init__(self):
114
118
        self._push_back_buffer = None
115
 
        
 
119
 
116
120
    def _push_back(self, bytes):
117
121
        """Return unused bytes to the medium, because they belong to the next
118
122
        request(s).
152
156
 
153
157
    def _get_line(self):
154
158
        """Read bytes from this request's response until a newline byte.
155
 
        
 
159
 
156
160
        This isn't particularly efficient, so should only be used when the
157
161
        expected size of the line is quite short.
158
162
 
161
165
        line, excess = _get_line(self.read_bytes)
162
166
        self._push_back(excess)
163
167
        return line
164
 
 
 
168
 
 
169
    def _report_activity(self, bytes, direction):
 
170
        """Notify that this medium has activity.
 
171
 
 
172
        Implementations should call this from all methods that actually do IO.
 
173
        Be careful that it's not called twice, if one method is implemented on
 
174
        top of another.
 
175
 
 
176
        :param bytes: Number of bytes read or written.
 
177
        :param direction: 'read' or 'write' or None.
 
178
        """
 
179
        ui.ui_factory.report_transport_activity(self, bytes, direction)
 
180
 
 
181
 
 
182
_bad_file_descriptor = (errno.EBADF,)
 
183
if sys.platform == 'win32':
 
184
    # Given on Windows if you pass a closed socket to select.select. Probably
 
185
    # also given if you pass a file handle to select.
 
186
    WSAENOTSOCK = 10038
 
187
    _bad_file_descriptor += (WSAENOTSOCK,)
 
188
 
165
189
 
166
190
class SmartServerStreamMedium(SmartMedium):
167
191
    """Handles smart commands coming over a stream.
172
196
    One instance is created for each connected client; it can serve multiple
173
197
    requests in the lifetime of the connection.
174
198
 
175
 
    The server passes requests through to an underlying backing transport, 
 
199
    The server passes requests through to an underlying backing transport,
176
200
    which will typically be a LocalTransport looking at the server's filesystem.
177
201
 
178
202
    :ivar _push_back_buffer: a str of bytes that have been read from the stream
181
205
        the stream.  See also the _push_back method.
182
206
    """
183
207
 
184
 
    def __init__(self, backing_transport, root_client_path='/'):
 
208
    _timer = time.time
 
209
 
 
210
    def __init__(self, backing_transport, root_client_path='/', timeout=None):
185
211
        """Construct new server.
186
212
 
187
213
        :param backing_transport: Transport for the directory served.
190
216
        self.backing_transport = backing_transport
191
217
        self.root_client_path = root_client_path
192
218
        self.finished = False
 
219
        if timeout is None:
 
220
            raise AssertionError('You must supply a timeout.')
 
221
        self._client_timeout = timeout
 
222
        self._client_poll_timeout = min(timeout / 10.0, 1.0)
193
223
        SmartMedium.__init__(self)
194
224
 
195
225
    def serve(self):
200
230
        try:
201
231
            while not self.finished:
202
232
                server_protocol = self._build_protocol()
 
233
                # TODO: This seems inelegant:
 
234
                if server_protocol is None:
 
235
                    # We could 'continue' only to notice that self.finished is
 
236
                    # True...
 
237
                    break
203
238
                self._serve_one_request(server_protocol)
 
239
        except errors.ConnectionTimeout, e:
 
240
            trace.note('%s' % (e,))
 
241
            trace.log_exception_quietly()
 
242
            self._disconnect_client()
 
243
            # We reported it, no reason to make a big fuss.
 
244
            return
204
245
        except Exception, e:
205
246
            stderr.write("%s terminating on exception %s\n" % (self, e))
206
247
            raise
 
248
        self._disconnect_client()
 
249
 
 
250
    def _stop_gracefully(self):
 
251
        """When we finish this message, stop looking for more."""
 
252
        trace.mutter('Stopping %s' % (self,))
 
253
        self.finished = True
 
254
 
 
255
    def _disconnect_client(self):
 
256
        """Close the current connection. We stopped due to a timeout/etc."""
 
257
        # The default implementation is a no-op, because that is all we used to
 
258
        # do when disconnecting from a client. I suppose we never had the
 
259
        # *server* initiate a disconnect, before
 
260
 
 
261
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
 
262
        """Wait for more bytes to be read, but timeout if none available.
 
263
 
 
264
        This allows us to detect idle connections, and stop trying to read from
 
265
        them, without setting the socket itself to non-blocking. This also
 
266
        allows us to specify when we watch for idle timeouts.
 
267
 
 
268
        :return: Did we timeout? (True if we timed out, False if there is data
 
269
            to be read)
 
270
        """
 
271
        raise NotImplementedError(self._wait_for_bytes_with_timeout)
207
272
 
208
273
    def _build_protocol(self):
209
274
        """Identifies the version of the incoming request, and returns an
214
279
 
215
280
        :returns: a SmartServerRequestProtocol.
216
281
        """
 
282
        self._wait_for_bytes_with_timeout(self._client_timeout)
 
283
        if self.finished:
 
284
            # We're stopping, so don't try to do any more work
 
285
            return None
217
286
        bytes = self._get_line()
218
287
        protocol_factory, unused_bytes = _get_protocol_factory_for_bytes(bytes)
219
288
        protocol = protocol_factory(
221
290
        protocol.accept_bytes(unused_bytes)
222
291
        return protocol
223
292
 
 
293
    def _wait_on_descriptor(self, fd, timeout_seconds):
 
294
        """select() on a file descriptor, waiting for nonblocking read()
 
295
 
 
296
        This will raise a ConnectionTimeout exception if we do not get a
 
297
        readable handle before timeout_seconds.
 
298
        :return: None
 
299
        """
 
300
        t_end = self._timer() + timeout_seconds
 
301
        poll_timeout = min(timeout_seconds, self._client_poll_timeout)
 
302
        rs = xs = None
 
303
        while not rs and not xs and self._timer() < t_end:
 
304
            if self.finished:
 
305
                return
 
306
            try:
 
307
                rs, _, xs = select.select([fd], [], [fd], poll_timeout)
 
308
            except (select.error, socket.error) as e:
 
309
                err = getattr(e, 'errno', None)
 
310
                if err is None and getattr(e, 'args', None) is not None:
 
311
                    # select.error doesn't have 'errno', it just has args[0]
 
312
                    err = e.args[0]
 
313
                if err in _bad_file_descriptor:
 
314
                    return # Not a socket indicates read() will fail
 
315
                elif err == errno.EINTR:
 
316
                    # Interrupted, keep looping.
 
317
                    continue
 
318
                raise
 
319
        if rs or xs:
 
320
            return
 
321
        raise errors.ConnectionTimeout('disconnecting client after %.1f seconds'
 
322
                                       % (timeout_seconds,))
 
323
 
224
324
    def _serve_one_request(self, protocol):
225
325
        """Read one request from input, process, send back a response.
226
 
        
 
326
 
227
327
        :param protocol: a SmartServerRequestProtocol.
228
328
        """
229
329
        try:
247
347
 
248
348
class SmartServerSocketStreamMedium(SmartServerStreamMedium):
249
349
 
250
 
    def __init__(self, sock, backing_transport, root_client_path='/'):
 
350
    def __init__(self, sock, backing_transport, root_client_path='/',
 
351
                 timeout=None):
251
352
        """Constructor.
252
353
 
253
354
        :param sock: the socket the server will read from.  It will be put
254
355
            into blocking mode.
255
356
        """
256
357
        SmartServerStreamMedium.__init__(
257
 
            self, backing_transport, root_client_path=root_client_path)
 
358
            self, backing_transport, root_client_path=root_client_path,
 
359
            timeout=timeout)
258
360
        sock.setblocking(True)
259
361
        self.socket = sock
 
362
        # Get the getpeername now, as we might be closed later when we care.
 
363
        try:
 
364
            self._client_info = sock.getpeername()
 
365
        except socket.error:
 
366
            self._client_info = '<unknown>'
 
367
 
 
368
    def __str__(self):
 
369
        return '%s(client=%s)' % (self.__class__.__name__, self._client_info)
 
370
 
 
371
    def __repr__(self):
 
372
        return '%s.%s(client=%s)' % (self.__module__, self.__class__.__name__,
 
373
            self._client_info)
260
374
 
261
375
    def _serve_one_request_unguarded(self, protocol):
262
376
        while protocol.next_read_size():
263
377
            # We can safely try to read large chunks.  If there is less data
264
 
            # than _MAX_READ_SIZE ready, the socket wil just return a short
265
 
            # read immediately rather than block.
266
 
            bytes = self.read_bytes(_MAX_READ_SIZE)
 
378
            # than MAX_SOCKET_CHUNK ready, the socket will just return a
 
379
            # short read immediately rather than block.
 
380
            bytes = self.read_bytes(osutils.MAX_SOCKET_CHUNK)
267
381
            if bytes == '':
268
382
                self.finished = True
269
383
                return
270
384
            protocol.accept_bytes(bytes)
271
 
        
 
385
 
272
386
        self._push_back(protocol.unused_data)
273
387
 
 
388
    def _disconnect_client(self):
 
389
        """Close the current connection. We stopped due to a timeout/etc."""
 
390
        self.socket.close()
 
391
 
 
392
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
 
393
        """Wait for more bytes to be read, but timeout if none available.
 
394
 
 
395
        This allows us to detect idle connections, and stop trying to read from
 
396
        them, without setting the socket itself to non-blocking. This also
 
397
        allows us to specify when we watch for idle timeouts.
 
398
 
 
399
        :return: None, this will raise ConnectionTimeout if we time out before
 
400
            data is available.
 
401
        """
 
402
        return self._wait_on_descriptor(self.socket, timeout_seconds)
 
403
 
274
404
    def _read_bytes(self, desired_count):
275
 
        # We ignore the desired_count because on sockets it's more efficient to
276
 
        # read large chunks (of _MAX_READ_SIZE bytes) at a time.
277
 
        return osutils.until_no_eintr(self.socket.recv, _MAX_READ_SIZE)
 
405
        return osutils.read_bytes_from_socket(
 
406
            self.socket, self._report_activity)
278
407
 
279
408
    def terminate_due_to_error(self):
280
409
        # TODO: This should log to a server log file, but no such thing
283
412
        self.finished = True
284
413
 
285
414
    def _write_out(self, bytes):
286
 
        osutils.send_all(self.socket, bytes)
 
415
        tstart = osutils.timer_func()
 
416
        osutils.send_all(self.socket, bytes, self._report_activity)
 
417
        if 'hpss' in debug.debug_flags:
 
418
            thread_id = thread.get_ident()
 
419
            trace.mutter('%12s: [%s] %d bytes to the socket in %.3fs'
 
420
                         % ('wrote', thread_id, len(bytes),
 
421
                            osutils.timer_func() - tstart))
287
422
 
288
423
 
289
424
class SmartServerPipeStreamMedium(SmartServerStreamMedium):
290
425
 
291
 
    def __init__(self, in_file, out_file, backing_transport):
 
426
    def __init__(self, in_file, out_file, backing_transport, timeout=None):
292
427
        """Construct new server.
293
428
 
294
429
        :param in_file: Python file from which requests can be read.
295
430
        :param out_file: Python file to write responses.
296
431
        :param backing_transport: Transport for the directory served.
297
432
        """
298
 
        SmartServerStreamMedium.__init__(self, backing_transport)
 
433
        SmartServerStreamMedium.__init__(self, backing_transport,
 
434
            timeout=timeout)
299
435
        if sys.platform == 'win32':
300
436
            # force binary mode for files
301
437
            import msvcrt
306
442
        self._in = in_file
307
443
        self._out = out_file
308
444
 
 
445
    def serve(self):
 
446
        """See SmartServerStreamMedium.serve"""
 
447
        # This is the regular serve, except it adds signal trapping for soft
 
448
        # shutdown.
 
449
        stop_gracefully = self._stop_gracefully
 
450
        signals.register_on_hangup(id(self), stop_gracefully)
 
451
        try:
 
452
            return super(SmartServerPipeStreamMedium, self).serve()
 
453
        finally:
 
454
            signals.unregister_on_hangup(id(self))
 
455
 
309
456
    def _serve_one_request_unguarded(self, protocol):
310
457
        while True:
311
458
            # We need to be careful not to read past the end of the current
324
471
                return
325
472
            protocol.accept_bytes(bytes)
326
473
 
 
474
    def _disconnect_client(self):
 
475
        self._in.close()
 
476
        self._out.flush()
 
477
        self._out.close()
 
478
 
 
479
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
 
480
        """Wait for more bytes to be read, but timeout if none available.
 
481
 
 
482
        This allows us to detect idle connections, and stop trying to read from
 
483
        them, without setting the socket itself to non-blocking. This also
 
484
        allows us to specify when we watch for idle timeouts.
 
485
 
 
486
        :return: None, this will raise ConnectionTimeout if we time out before
 
487
            data is available.
 
488
        """
 
489
        if (getattr(self._in, 'fileno', None) is None
 
490
            or sys.platform == 'win32'):
 
491
            # You can't select() file descriptors on Windows.
 
492
            return
 
493
        return self._wait_on_descriptor(self._in, timeout_seconds)
 
494
 
327
495
    def _read_bytes(self, desired_count):
328
496
        return self._in.read(desired_count)
329
497
 
350
518
    request.finished_reading()
351
519
 
352
520
    It is up to the individual SmartClientMedium whether multiple concurrent
353
 
    requests can exist. See SmartClientMedium.get_request to obtain instances 
354
 
    of SmartClientMediumRequest, and the concrete Medium you are using for 
 
521
    requests can exist. See SmartClientMedium.get_request to obtain instances
 
522
    of SmartClientMediumRequest, and the concrete Medium you are using for
355
523
    details on concurrency and pipelining.
356
524
    """
357
525
 
366
534
    def accept_bytes(self, bytes):
367
535
        """Accept bytes for inclusion in this request.
368
536
 
369
 
        This method may not be be called after finished_writing() has been
 
537
        This method may not be called after finished_writing() has been
370
538
        called.  It depends upon the Medium whether or not the bytes will be
371
539
        immediately transmitted. Message based Mediums will tend to buffer the
372
540
        bytes until finished_writing() is called.
403
571
    def _finished_reading(self):
404
572
        """Helper for finished_reading.
405
573
 
406
 
        finished_reading checks the state of the request to determine if 
 
574
        finished_reading checks the state of the request to determine if
407
575
        finished_reading is allowed, and if it is hands off to _finished_reading
408
576
        to perform the action.
409
577
        """
423
591
    def _finished_writing(self):
424
592
        """Helper for finished_writing.
425
593
 
426
 
        finished_writing checks the state of the request to determine if 
 
594
        finished_writing checks the state of the request to determine if
427
595
        finished_writing is allowed, and if it is hands off to _finished_writing
428
596
        to perform the action.
429
597
        """
449
617
        read_bytes checks the state of the request to determing if bytes
450
618
        should be read. After that it hands off to _read_bytes to do the
451
619
        actual read.
452
 
        
 
620
 
453
621
        By default this forwards to self._medium.read_bytes because we are
454
622
        operating on the medium's stream.
455
623
        """
460
628
        if not line.endswith('\n'):
461
629
            # end of file encountered reading from server
462
630
            raise errors.ConnectionReset(
463
 
                "please check connectivity and permissions",
464
 
                "(and try -Dhpss if further diagnosis is required)")
 
631
                "Unexpected end of message. Please check connectivity "
 
632
                "and permissions, and report a bug if problems persist.")
465
633
        return line
466
634
 
467
635
    def _read_line(self):
468
636
        """Helper for SmartClientMediumRequest.read_line.
469
 
        
 
637
 
470
638
        By default this forwards to self._medium._get_line because we are
471
639
        operating on the medium's stream.
472
640
        """
473
641
        return self._medium._get_line()
474
642
 
475
643
 
 
644
class _VfsRefuser(object):
 
645
    """An object that refuses all VFS requests.
 
646
 
 
647
    """
 
648
 
 
649
    def __init__(self):
 
650
        client._SmartClient.hooks.install_named_hook(
 
651
            'call', self.check_vfs, 'vfs refuser')
 
652
 
 
653
    def check_vfs(self, params):
 
654
        try:
 
655
            request_method = request.request_handlers.get(params.method)
 
656
        except KeyError:
 
657
            # A method we don't know about doesn't count as a VFS method.
 
658
            return
 
659
        if issubclass(request_method, vfs.VfsRequest):
 
660
            raise errors.HpssVfsRequestNotAllowed(params.method, params.args)
 
661
 
 
662
 
476
663
class _DebugCounter(object):
477
664
    """An object that counts the HPSS calls made to each client medium.
478
665
 
479
 
    When a medium is garbage-collected, or failing that when atexit functions
480
 
    are run, the total number of calls made on that medium are reported via
481
 
    trace.note.
 
666
    When a medium is garbage-collected, or failing that when
 
667
    bzrlib.global_state exits, the total number of calls made on that medium
 
668
    are reported via trace.note.
482
669
    """
483
670
 
484
671
    def __init__(self):
485
672
        self.counts = weakref.WeakKeyDictionary()
486
673
        client._SmartClient.hooks.install_named_hook(
487
674
            'call', self.increment_call_count, 'hpss call counter')
488
 
        atexit.register(self.flush_all)
 
675
        bzrlib.global_state.cleanups.add_cleanup(self.flush_all)
489
676
 
490
677
    def track(self, medium):
491
678
        """Start tracking calls made to a medium.
495
682
        """
496
683
        medium_repr = repr(medium)
497
684
        # Add this medium to the WeakKeyDictionary
498
 
        self.counts[medium] = [0, medium_repr]
 
685
        self.counts[medium] = dict(count=0, vfs_count=0,
 
686
                                   medium_repr=medium_repr)
499
687
        # Weakref callbacks are fired in reverse order of their association
500
688
        # with the referenced object.  So we add a weakref *after* adding to
501
689
        # the WeakKeyDict so that we can report the value from it before the
505
693
    def increment_call_count(self, params):
506
694
        # Increment the count in the WeakKeyDictionary
507
695
        value = self.counts[params.medium]
508
 
        value[0] += 1
 
696
        value['count'] += 1
 
697
        try:
 
698
            request_method = request.request_handlers.get(params.method)
 
699
        except KeyError:
 
700
            # A method we don't know about doesn't count as a VFS method.
 
701
            return
 
702
        if issubclass(request_method, vfs.VfsRequest):
 
703
            value['vfs_count'] += 1
509
704
 
510
705
    def done(self, ref):
511
706
        value = self.counts[ref]
512
 
        count, medium_repr = value
 
707
        count, vfs_count, medium_repr = (
 
708
            value['count'], value['vfs_count'], value['medium_repr'])
513
709
        # In case this callback is invoked for the same ref twice (by the
514
710
        # weakref callback and by the atexit function), set the call count back
515
711
        # to 0 so this item won't be reported twice.
516
 
        value[0] = 0
 
712
        value['count'] = 0
 
713
        value['vfs_count'] = 0
517
714
        if count != 0:
518
 
            trace.note('HPSS calls: %d %s', count, medium_repr)
519
 
        
 
715
            trace.note(gettext('HPSS calls: {0} ({1} vfs) {2}').format(
 
716
                       count, vfs_count, medium_repr))
 
717
 
520
718
    def flush_all(self):
521
719
        for ref in list(self.counts.keys()):
522
720
            self.done(ref)
523
721
 
524
722
_debug_counter = None
525
 
  
526
 
  
 
723
_vfs_refuser = None
 
724
 
 
725
 
527
726
class SmartClientMedium(SmartMedium):
528
727
    """Smart client is a medium for sending smart protocol requests over."""
529
728
 
544
743
            if _debug_counter is None:
545
744
                _debug_counter = _DebugCounter()
546
745
            _debug_counter.track(self)
 
746
        if 'hpss_client_no_vfs' in debug.debug_flags:
 
747
            global _vfs_refuser
 
748
            if _vfs_refuser is None:
 
749
                _vfs_refuser = _VfsRefuser()
547
750
 
548
751
    def _is_remote_before(self, version_tuple):
549
752
        """Is it possible the remote side supports RPCs for a given version?
574
777
        """
575
778
        if (self._remote_version_is_before is not None and
576
779
            version_tuple > self._remote_version_is_before):
577
 
            raise AssertionError(
 
780
            # We have been told that the remote side is older than some version
 
781
            # which is newer than a previously supplied older-than version.
 
782
            # This indicates that some smart verb call is not guarded
 
783
            # appropriately (it should simply not have been tried).
 
784
            trace.mutter(
578
785
                "_remember_remote_is_before(%r) called, but "
579
786
                "_remember_remote_is_before(%r) was called previously."
580
 
                % (version_tuple, self._remote_version_is_before))
 
787
                , version_tuple, self._remote_version_is_before)
 
788
            if 'hpss' in debug.debug_flags:
 
789
                ui.ui_factory.show_warning(
 
790
                    "_remember_remote_is_before(%r) called, but "
 
791
                    "_remember_remote_is_before(%r) was called previously."
 
792
                    % (version_tuple, self._remote_version_is_before))
 
793
            return
581
794
        self._remote_version_is_before = version_tuple
582
795
 
583
796
    def protocol_version(self):
617
830
 
618
831
    def disconnect(self):
619
832
        """If this medium maintains a persistent connection, close it.
620
 
        
 
833
 
621
834
        The default implementation does nothing.
622
835
        """
623
 
        
 
836
 
624
837
    def remote_path_from_transport(self, transport):
625
838
        """Convert transport into a path suitable for using in a request.
626
 
        
 
839
 
627
840
        Note that the resulting remote path doesn't encode the host name or
628
841
        anything but path, so it is only safe to use it in requests sent over
629
842
        the medium from the matching transport.
657
870
 
658
871
    def _flush(self):
659
872
        """Flush the output stream.
660
 
        
 
873
 
661
874
        This method is used by the SmartClientStreamMediumRequest to ensure that
662
875
        all data for a request is sent, to avoid long timeouts or deadlocks.
663
876
        """
674
887
 
675
888
class SmartSimplePipesClientMedium(SmartClientStreamMedium):
676
889
    """A client medium using simple pipes.
677
 
    
 
890
 
678
891
    This client does not manage the pipes: it assumes they will always be open.
679
892
    """
680
893
 
686
899
    def _accept_bytes(self, bytes):
687
900
        """See SmartClientStreamMedium.accept_bytes."""
688
901
        self._writeable_pipe.write(bytes)
 
902
        self._report_activity(len(bytes), 'write')
689
903
 
690
904
    def _flush(self):
691
905
        """See SmartClientStreamMedium._flush()."""
693
907
 
694
908
    def _read_bytes(self, count):
695
909
        """See SmartClientStreamMedium._read_bytes."""
696
 
        return self._readable_pipe.read(count)
 
910
        bytes_to_read = min(count, _MAX_READ_SIZE)
 
911
        bytes = self._readable_pipe.read(bytes_to_read)
 
912
        self._report_activity(len(bytes), 'read')
 
913
        return bytes
 
914
 
 
915
 
 
916
class SSHParams(object):
 
917
    """A set of parameters for starting a remote bzr via SSH."""
 
918
 
 
919
    def __init__(self, host, port=None, username=None, password=None,
 
920
            bzr_remote_path='bzr'):
 
921
        self.host = host
 
922
        self.port = port
 
923
        self.username = username
 
924
        self.password = password
 
925
        self.bzr_remote_path = bzr_remote_path
697
926
 
698
927
 
699
928
class SmartSSHClientMedium(SmartClientStreamMedium):
700
 
    """A client medium using SSH."""
 
929
    """A client medium using SSH.
701
930
    
702
 
    def __init__(self, host, port=None, username=None, password=None,
703
 
            base=None, vendor=None, bzr_remote_path=None):
 
931
    It delegates IO to a SmartClientSocketMedium or
 
932
    SmartClientAlreadyConnectedSocketMedium (depending on platform).
 
933
    """
 
934
 
 
935
    def __init__(self, base, ssh_params, vendor=None):
704
936
        """Creates a client that will connect on the first use.
705
 
        
 
937
 
 
938
        :param ssh_params: A SSHParams instance.
706
939
        :param vendor: An optional override for the ssh vendor to use. See
707
940
            bzrlib.transport.ssh for details on ssh vendors.
708
941
        """
 
942
        self._real_medium = None
 
943
        self._ssh_params = ssh_params
 
944
        # for the benefit of progress making a short description of this
 
945
        # transport
 
946
        self._scheme = 'bzr+ssh'
 
947
        # SmartClientStreamMedium stores the repr of this object in its
 
948
        # _DebugCounter so we have to store all the values used in our repr
 
949
        # method before calling the super init.
709
950
        SmartClientStreamMedium.__init__(self, base)
710
 
        self._connected = False
711
 
        self._host = host
712
 
        self._password = password
713
 
        self._port = port
714
 
        self._username = username
715
 
        self._read_from = None
 
951
        self._vendor = vendor
716
952
        self._ssh_connection = None
717
 
        self._vendor = vendor
718
 
        self._write_to = None
719
 
        self._bzr_remote_path = bzr_remote_path
720
 
        if self._bzr_remote_path is None:
721
 
            symbol_versioning.warn(
722
 
                'bzr_remote_path is required as of bzr 0.92',
723
 
                DeprecationWarning, stacklevel=2)
724
 
            self._bzr_remote_path = os.environ.get('BZR_REMOTE_PATH', 'bzr')
 
953
 
 
954
    def __repr__(self):
 
955
        if self._ssh_params.port is None:
 
956
            maybe_port = ''
 
957
        else:
 
958
            maybe_port = ':%s' % self._ssh_params.port
 
959
        return "%s(%s://%s@%s%s/)" % (
 
960
            self.__class__.__name__,
 
961
            self._scheme,
 
962
            self._ssh_params.username,
 
963
            self._ssh_params.host,
 
964
            maybe_port)
725
965
 
726
966
    def _accept_bytes(self, bytes):
727
967
        """See SmartClientStreamMedium.accept_bytes."""
728
968
        self._ensure_connection()
729
 
        self._write_to.write(bytes)
 
969
        self._real_medium.accept_bytes(bytes)
730
970
 
731
971
    def disconnect(self):
732
972
        """See SmartClientMedium.disconnect()."""
733
 
        if not self._connected:
734
 
            return
735
 
        self._read_from.close()
736
 
        self._write_to.close()
737
 
        self._ssh_connection.close()
738
 
        self._connected = False
 
973
        if self._real_medium is not None:
 
974
            self._real_medium.disconnect()
 
975
            self._real_medium = None
 
976
        if self._ssh_connection is not None:
 
977
            self._ssh_connection.close()
 
978
            self._ssh_connection = None
739
979
 
740
980
    def _ensure_connection(self):
741
981
        """Connect this medium if not already connected."""
742
 
        if self._connected:
 
982
        if self._real_medium is not None:
743
983
            return
744
984
        if self._vendor is None:
745
985
            vendor = ssh._get_ssh_vendor()
746
986
        else:
747
987
            vendor = self._vendor
748
 
        self._ssh_connection = vendor.connect_ssh(self._username,
749
 
                self._password, self._host, self._port,
750
 
                command=[self._bzr_remote_path, 'serve', '--inet',
 
988
        self._ssh_connection = vendor.connect_ssh(self._ssh_params.username,
 
989
                self._ssh_params.password, self._ssh_params.host,
 
990
                self._ssh_params.port,
 
991
                command=[self._ssh_params.bzr_remote_path, 'serve', '--inet',
751
992
                         '--directory=/', '--allow-writes'])
752
 
        self._read_from, self._write_to = \
753
 
            self._ssh_connection.get_filelike_channels()
754
 
        self._connected = True
 
993
        io_kind, io_object = self._ssh_connection.get_sock_or_pipes()
 
994
        if io_kind == 'socket':
 
995
            self._real_medium = SmartClientAlreadyConnectedSocketMedium(
 
996
                self.base, io_object)
 
997
        elif io_kind == 'pipes':
 
998
            read_from, write_to = io_object
 
999
            self._real_medium = SmartSimplePipesClientMedium(
 
1000
                read_from, write_to, self.base)
 
1001
        else:
 
1002
            raise AssertionError(
 
1003
                "Unexpected io_kind %r from %r"
 
1004
                % (io_kind, self._ssh_connection))
755
1005
 
756
1006
    def _flush(self):
757
1007
        """See SmartClientStreamMedium._flush()."""
758
 
        self._write_to.flush()
 
1008
        self._real_medium._flush()
759
1009
 
760
1010
    def _read_bytes(self, count):
761
1011
        """See SmartClientStreamMedium.read_bytes."""
762
 
        if not self._connected:
 
1012
        if self._real_medium is None:
763
1013
            raise errors.MediumNotConnected(self)
764
 
        bytes_to_read = min(count, _MAX_READ_SIZE)
765
 
        return self._read_from.read(bytes_to_read)
 
1014
        return self._real_medium.read_bytes(count)
766
1015
 
767
1016
 
768
1017
# Port 4155 is the default port for bzr://, registered with IANA.
770
1019
BZR_DEFAULT_PORT = 4155
771
1020
 
772
1021
 
773
 
class SmartTCPClientMedium(SmartClientStreamMedium):
774
 
    """A client medium using TCP."""
 
1022
class SmartClientSocketMedium(SmartClientStreamMedium):
 
1023
    """A client medium using a socket.
775
1024
    
 
1025
    This class isn't usable directly.  Use one of its subclasses instead.
 
1026
    """
 
1027
 
 
1028
    def __init__(self, base):
 
1029
        SmartClientStreamMedium.__init__(self, base)
 
1030
        self._socket = None
 
1031
        self._connected = False
 
1032
 
 
1033
    def _accept_bytes(self, bytes):
 
1034
        """See SmartClientMedium.accept_bytes."""
 
1035
        self._ensure_connection()
 
1036
        osutils.send_all(self._socket, bytes, self._report_activity)
 
1037
 
 
1038
    def _ensure_connection(self):
 
1039
        """Connect this medium if not already connected."""
 
1040
        raise NotImplementedError(self._ensure_connection)
 
1041
 
 
1042
    def _flush(self):
 
1043
        """See SmartClientStreamMedium._flush().
 
1044
 
 
1045
        For sockets we do no flushing. For TCP sockets we may want to turn off
 
1046
        TCP_NODELAY and add a means to do a flush, but that can be done in the
 
1047
        future.
 
1048
        """
 
1049
 
 
1050
    def _read_bytes(self, count):
 
1051
        """See SmartClientMedium.read_bytes."""
 
1052
        if not self._connected:
 
1053
            raise errors.MediumNotConnected(self)
 
1054
        return osutils.read_bytes_from_socket(
 
1055
            self._socket, self._report_activity)
 
1056
 
 
1057
    def disconnect(self):
 
1058
        """See SmartClientMedium.disconnect()."""
 
1059
        if not self._connected:
 
1060
            return
 
1061
        self._socket.close()
 
1062
        self._socket = None
 
1063
        self._connected = False
 
1064
 
 
1065
 
 
1066
class SmartTCPClientMedium(SmartClientSocketMedium):
 
1067
    """A client medium that creates a TCP connection."""
 
1068
 
776
1069
    def __init__(self, host, port, base):
777
1070
        """Creates a client that will connect on the first use."""
778
 
        SmartClientStreamMedium.__init__(self, base)
779
 
        self._connected = False
 
1071
        SmartClientSocketMedium.__init__(self, base)
780
1072
        self._host = host
781
1073
        self._port = port
782
 
        self._socket = None
783
 
 
784
 
    def _accept_bytes(self, bytes):
785
 
        """See SmartClientMedium.accept_bytes."""
786
 
        self._ensure_connection()
787
 
        osutils.send_all(self._socket, bytes)
788
 
 
789
 
    def disconnect(self):
790
 
        """See SmartClientMedium.disconnect()."""
791
 
        if not self._connected:
792
 
            return
793
 
        self._socket.close()
794
 
        self._socket = None
795
 
        self._connected = False
796
1074
 
797
1075
    def _ensure_connection(self):
798
1076
        """Connect this medium if not already connected."""
803
1081
        else:
804
1082
            port = int(self._port)
805
1083
        try:
806
 
            sockaddrs = socket.getaddrinfo(self._host, port, socket.AF_UNSPEC, 
 
1084
            sockaddrs = socket.getaddrinfo(self._host, port, socket.AF_UNSPEC,
807
1085
                socket.SOCK_STREAM, 0, 0)
808
1086
        except socket.gaierror, (err_num, err_msg):
809
1087
            raise errors.ConnectionError("failed to lookup %s:%d: %s" %
813
1091
        for (family, socktype, proto, canonname, sockaddr) in sockaddrs:
814
1092
            try:
815
1093
                self._socket = socket.socket(family, socktype, proto)
816
 
                self._socket.setsockopt(socket.IPPROTO_TCP, 
 
1094
                self._socket.setsockopt(socket.IPPROTO_TCP,
817
1095
                                        socket.TCP_NODELAY, 1)
818
1096
                self._socket.connect(sockaddr)
819
1097
            except socket.error, err:
833
1111
                    (self._host, port, err_msg))
834
1112
        self._connected = True
835
1113
 
836
 
    def _flush(self):
837
 
        """See SmartClientStreamMedium._flush().
838
 
        
839
 
        For TCP we do no flushing. We may want to turn off TCP_NODELAY and 
840
 
        add a means to do a flush, but that can be done in the future.
841
 
        """
842
 
 
843
 
    def _read_bytes(self, count):
844
 
        """See SmartClientMedium.read_bytes."""
845
 
        if not self._connected:
846
 
            raise errors.MediumNotConnected(self)
847
 
        # We ignore the desired_count because on sockets it's more efficient to
848
 
        # read large chunks (of _MAX_READ_SIZE bytes) at a time.
849
 
        try:
850
 
            return self._socket.recv(_MAX_READ_SIZE)
851
 
        except socket.error, e:
852
 
            if len(e.args) and e.args[0] == errno.ECONNRESET:
853
 
                # Callers expect an empty string in that case
854
 
                return ''
855
 
            else:
856
 
                raise
 
1114
 
 
1115
class SmartClientAlreadyConnectedSocketMedium(SmartClientSocketMedium):
 
1116
    """A client medium for an already connected socket.
 
1117
    
 
1118
    Note that this class will assume it "owns" the socket, so it will close it
 
1119
    when its disconnect method is called.
 
1120
    """
 
1121
 
 
1122
    def __init__(self, base, sock):
 
1123
        SmartClientSocketMedium.__init__(self, base)
 
1124
        self._socket = sock
 
1125
        self._connected = True
 
1126
 
 
1127
    def _ensure_connection(self):
 
1128
        # Already connected, by definition!  So nothing to do.
 
1129
        pass
857
1130
 
858
1131
 
859
1132
class SmartClientStreamMediumRequest(SmartClientMediumRequest):
872
1145
 
873
1146
    def _accept_bytes(self, bytes):
874
1147
        """See SmartClientMediumRequest._accept_bytes.
875
 
        
 
1148
 
876
1149
        This forwards to self._medium._accept_bytes because we are operating
877
1150
        on the mediums stream.
878
1151
        """
881
1154
    def _finished_reading(self):
882
1155
        """See SmartClientMediumRequest._finished_reading.
883
1156
 
884
 
        This clears the _current_request on self._medium to allow a new 
 
1157
        This clears the _current_request on self._medium to allow a new
885
1158
        request to be created.
886
1159
        """
887
1160
        if self._medium._current_request is not self:
888
1161
            raise AssertionError()
889
1162
        self._medium._current_request = None
890
 
        
 
1163
 
891
1164
    def _finished_writing(self):
892
1165
        """See SmartClientMediumRequest._finished_writing.
893
1166
 
895
1168
        """
896
1169
        self._medium._flush()
897
1170
 
 
1171