~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/smart/medium.py

  • Committer: John Arbash Meinel
  • Date: 2008-10-30 00:55:00 UTC
  • mto: (3815.2.5 prepare-1.9)
  • mto: This revision was merged to the branch mainline in revision 3811.
  • Revision ID: john@arbash-meinel.com-20081030005500-r5cej1cxflqhs3io
Switch so that we are using a simple timestamp as the first action.

Show diffs side-by-side

added added

removed removed

Lines of Context:
24
24
bzrlib/transport/smart/__init__.py.
25
25
"""
26
26
 
 
27
import errno
27
28
import os
28
29
import socket
29
30
import sys
30
31
import urllib
31
32
 
 
33
from bzrlib.lazy_import import lazy_import
 
34
lazy_import(globals(), """
 
35
import atexit
 
36
import weakref
32
37
from bzrlib import (
 
38
    debug,
33
39
    errors,
34
40
    osutils,
35
41
    symbol_versioning,
 
42
    trace,
36
43
    urlutils,
37
44
    )
38
 
from bzrlib.smart.protocol import (
39
 
    MESSAGE_VERSION_THREE,
40
 
    REQUEST_VERSION_TWO,
41
 
    SmartClientRequestProtocolOne,
42
 
    SmartServerRequestProtocolOne,
43
 
    SmartServerRequestProtocolTwo,
44
 
    build_server_protocol_three
45
 
    )
 
45
from bzrlib.smart import client, protocol
46
46
from bzrlib.transport import ssh
 
47
""")
 
48
 
 
49
 
 
50
# We must not read any more than 64k at a time so we don't risk "no buffer
 
51
# space available" errors on some platforms.  Windows in particular is likely
 
52
# to give error 10053 or 10055 if we read more than 64k from a socket.
 
53
_MAX_READ_SIZE = 64 * 1024
47
54
 
48
55
 
49
56
def _get_protocol_factory_for_bytes(bytes):
67
74
        root_client_path.  unused_bytes are any bytes that were not part of a
68
75
        protocol version marker.
69
76
    """
70
 
    if bytes.startswith(MESSAGE_VERSION_THREE):
71
 
        protocol_factory = build_server_protocol_three
72
 
        bytes = bytes[len(MESSAGE_VERSION_THREE):]
73
 
    elif bytes.startswith(REQUEST_VERSION_TWO):
74
 
        protocol_factory = SmartServerRequestProtocolTwo
75
 
        bytes = bytes[len(REQUEST_VERSION_TWO):]
 
77
    if bytes.startswith(protocol.MESSAGE_VERSION_THREE):
 
78
        protocol_factory = protocol.build_server_protocol_three
 
79
        bytes = bytes[len(protocol.MESSAGE_VERSION_THREE):]
 
80
    elif bytes.startswith(protocol.REQUEST_VERSION_TWO):
 
81
        protocol_factory = protocol.SmartServerRequestProtocolTwo
 
82
        bytes = bytes[len(protocol.REQUEST_VERSION_TWO):]
76
83
    else:
77
 
        protocol_factory = SmartServerRequestProtocolOne
 
84
        protocol_factory = protocol.SmartServerRequestProtocolOne
78
85
    return protocol_factory, bytes
79
86
 
80
87
 
81
 
class SmartServerStreamMedium(object):
 
88
def _get_line(read_bytes_func):
 
89
    """Read bytes using read_bytes_func until a newline byte.
 
90
    
 
91
    This isn't particularly efficient, so should only be used when the
 
92
    expected size of the line is quite short.
 
93
    
 
94
    :returns: a tuple of two strs: (line, excess)
 
95
    """
 
96
    newline_pos = -1
 
97
    bytes = ''
 
98
    while newline_pos == -1:
 
99
        new_bytes = read_bytes_func(1)
 
100
        bytes += new_bytes
 
101
        if new_bytes == '':
 
102
            # Ran out of bytes before receiving a complete line.
 
103
            return bytes, ''
 
104
        newline_pos = bytes.find('\n')
 
105
    line = bytes[:newline_pos+1]
 
106
    excess = bytes[newline_pos+1:]
 
107
    return line, excess
 
108
 
 
109
 
 
110
class SmartMedium(object):
 
111
    """Base class for smart protocol media, both client- and server-side."""
 
112
 
 
113
    def __init__(self):
 
114
        self._push_back_buffer = None
 
115
        
 
116
    def _push_back(self, bytes):
 
117
        """Return unused bytes to the medium, because they belong to the next
 
118
        request(s).
 
119
 
 
120
        This sets the _push_back_buffer to the given bytes.
 
121
        """
 
122
        if self._push_back_buffer is not None:
 
123
            raise AssertionError(
 
124
                "_push_back called when self._push_back_buffer is %r"
 
125
                % (self._push_back_buffer,))
 
126
        if bytes == '':
 
127
            return
 
128
        self._push_back_buffer = bytes
 
129
 
 
130
    def _get_push_back_buffer(self):
 
131
        if self._push_back_buffer == '':
 
132
            raise AssertionError(
 
133
                '%s._push_back_buffer should never be the empty string, '
 
134
                'which can be confused with EOF' % (self,))
 
135
        bytes = self._push_back_buffer
 
136
        self._push_back_buffer = None
 
137
        return bytes
 
138
 
 
139
    def read_bytes(self, desired_count):
 
140
        """Read some bytes from this medium.
 
141
 
 
142
        :returns: some bytes, possibly more or less than the number requested
 
143
            in 'desired_count' depending on the medium.
 
144
        """
 
145
        if self._push_back_buffer is not None:
 
146
            return self._get_push_back_buffer()
 
147
        bytes_to_read = min(desired_count, _MAX_READ_SIZE)
 
148
        return self._read_bytes(bytes_to_read)
 
149
 
 
150
    def _read_bytes(self, count):
 
151
        raise NotImplementedError(self._read_bytes)
 
152
 
 
153
    def _get_line(self):
 
154
        """Read bytes from this request's response until a newline byte.
 
155
        
 
156
        This isn't particularly efficient, so should only be used when the
 
157
        expected size of the line is quite short.
 
158
 
 
159
        :returns: a string of bytes ending in a newline (byte 0x0A).
 
160
        """
 
161
        line, excess = _get_line(self.read_bytes)
 
162
        self._push_back(excess)
 
163
        return line
 
164
 
 
165
 
 
166
class SmartServerStreamMedium(SmartMedium):
82
167
    """Handles smart commands coming over a stream.
83
168
 
84
169
    The stream may be a pipe connected to sshd, or a tcp socket, or an
105
190
        self.backing_transport = backing_transport
106
191
        self.root_client_path = root_client_path
107
192
        self.finished = False
108
 
        self._push_back_buffer = None
109
 
 
110
 
    def _push_back(self, bytes):
111
 
        """Return unused bytes to the medium, because they belong to the next
112
 
        request(s).
113
 
 
114
 
        This sets the _push_back_buffer to the given bytes.
115
 
        """
116
 
        if self._push_back_buffer is not None:
117
 
            raise AssertionError(
118
 
                "_push_back called when self._push_back_buffer is %r"
119
 
                % (self._push_back_buffer,))
120
 
        if bytes == '':
121
 
            return
122
 
        self._push_back_buffer = bytes
123
 
 
124
 
    def _get_push_back_buffer(self):
125
 
        if self._push_back_buffer == '':
126
 
            raise AssertionError(
127
 
                '%s._push_back_buffer should never be the empty string, '
128
 
                'which can be confused with EOF' % (self,))
129
 
        bytes = self._push_back_buffer
130
 
        self._push_back_buffer = None
131
 
        return bytes
 
193
        SmartMedium.__init__(self)
132
194
 
133
195
    def serve(self):
134
196
        """Serve requests until the client disconnects."""
175
237
        """Called when an unhandled exception from the protocol occurs."""
176
238
        raise NotImplementedError(self.terminate_due_to_error)
177
239
 
178
 
    def _get_bytes(self, desired_count):
 
240
    def _read_bytes(self, desired_count):
179
241
        """Get some bytes from the medium.
180
242
 
181
243
        :param desired_count: number of bytes we want to read.
182
244
        """
183
 
        raise NotImplementedError(self._get_bytes)
184
 
 
185
 
    def _get_line(self):
186
 
        """Read bytes from this request's response until a newline byte.
187
 
        
188
 
        This isn't particularly efficient, so should only be used when the
189
 
        expected size of the line is quite short.
190
 
 
191
 
        :returns: a string of bytes ending in a newline (byte 0x0A).
192
 
        """
193
 
        newline_pos = -1
194
 
        bytes = ''
195
 
        while newline_pos == -1:
196
 
            new_bytes = self._get_bytes(1)
197
 
            bytes += new_bytes
198
 
            if new_bytes == '':
199
 
                # Ran out of bytes before receiving a complete line.
200
 
                return bytes
201
 
            newline_pos = bytes.find('\n')
202
 
        line = bytes[:newline_pos+1]
203
 
        self._push_back(bytes[newline_pos+1:])
204
 
        return line
205
 
 
 
245
        raise NotImplementedError(self._read_bytes)
 
246
 
206
247
 
207
248
class SmartServerSocketStreamMedium(SmartServerStreamMedium):
208
249
 
219
260
 
220
261
    def _serve_one_request_unguarded(self, protocol):
221
262
        while protocol.next_read_size():
222
 
            bytes = self._get_bytes(4096)
 
263
            # We can safely try to read large chunks.  If there is less data
 
264
            # than _MAX_READ_SIZE ready, the socket wil just return a short
 
265
            # read immediately rather than block.
 
266
            bytes = self.read_bytes(_MAX_READ_SIZE)
223
267
            if bytes == '':
224
268
                self.finished = True
225
269
                return
227
271
        
228
272
        self._push_back(protocol.unused_data)
229
273
 
230
 
    def _get_bytes(self, desired_count):
231
 
        if self._push_back_buffer is not None:
232
 
            return self._get_push_back_buffer()
 
274
    def _read_bytes(self, desired_count):
233
275
        # We ignore the desired_count because on sockets it's more efficient to
234
 
        # read 4k at a time.
235
 
        return self.socket.recv(4096)
236
 
    
 
276
        # read large chunks (of _MAX_READ_SIZE bytes) at a time.
 
277
        return self.socket.recv(_MAX_READ_SIZE)
 
278
 
237
279
    def terminate_due_to_error(self):
238
280
        # TODO: This should log to a server log file, but no such thing
239
281
        # exists yet.  Andrew Bennetts 2006-09-29.
266
308
 
267
309
    def _serve_one_request_unguarded(self, protocol):
268
310
        while True:
 
311
            # We need to be careful not to read past the end of the current
 
312
            # request, or else the read from the pipe will block, so we use
 
313
            # protocol.next_read_size().
269
314
            bytes_to_read = protocol.next_read_size()
270
315
            if bytes_to_read == 0:
271
316
                # Finished serving this request.
272
317
                self._out.flush()
273
318
                return
274
 
            bytes = self._get_bytes(bytes_to_read)
 
319
            bytes = self.read_bytes(bytes_to_read)
275
320
            if bytes == '':
276
321
                # Connection has been closed.
277
322
                self.finished = True
279
324
                return
280
325
            protocol.accept_bytes(bytes)
281
326
 
282
 
    def _get_bytes(self, desired_count):
283
 
        if self._push_back_buffer is not None:
284
 
            return self._get_push_back_buffer()
 
327
    def _read_bytes(self, desired_count):
285
328
        return self._in.read(desired_count)
286
329
 
287
330
    def terminate_due_to_error(self):
401
444
        return self._read_bytes(count)
402
445
 
403
446
    def _read_bytes(self, count):
404
 
        """Helper for read_bytes.
 
447
        """Helper for SmartClientMediumRequest.read_bytes.
405
448
 
406
449
        read_bytes checks the state of the request to determing if bytes
407
450
        should be read. After that it hands off to _read_bytes to do the
408
451
        actual read.
 
452
        
 
453
        By default this forwards to self._medium.read_bytes because we are
 
454
        operating on the medium's stream.
409
455
        """
410
 
        raise NotImplementedError(self._read_bytes)
 
456
        return self._medium.read_bytes(count)
411
457
 
412
458
    def read_line(self):
413
 
        """Read bytes from this request's response until a newline byte.
414
 
        
415
 
        This isn't particularly efficient, so should only be used when the
416
 
        expected size of the line is quite short.
417
 
 
418
 
        :returns: a string of bytes ending in a newline (byte 0x0A).
419
 
        """
420
 
        # XXX: this duplicates SmartClientRequestProtocolOne._recv_tuple
421
 
        line = ''
422
 
        while not line or line[-1] != '\n':
423
 
            new_char = self.read_bytes(1)
424
 
            line += new_char
425
 
            if new_char == '':
426
 
                # end of file encountered reading from server
427
 
                raise errors.ConnectionReset(
428
 
                    "please check connectivity and permissions",
429
 
                    "(and try -Dhpss if further diagnosis is required)")
 
459
        line = self._read_line()
 
460
        if not line.endswith('\n'):
 
461
            # end of file encountered reading from server
 
462
            raise errors.ConnectionReset(
 
463
                "please check connectivity and permissions",
 
464
                "(and try -Dhpss if further diagnosis is required)")
430
465
        return line
431
466
 
432
 
 
433
 
class SmartClientMedium(object):
 
467
    def _read_line(self):
 
468
        """Helper for SmartClientMediumRequest.read_line.
 
469
        
 
470
        By default this forwards to self._medium._get_line because we are
 
471
        operating on the medium's stream.
 
472
        """
 
473
        return self._medium._get_line()
 
474
 
 
475
 
 
476
class _DebugCounter(object):
 
477
    """An object that counts the HPSS calls made to each client medium.
 
478
 
 
479
    When a medium is garbage-collected, or failing that when atexit functions
 
480
    are run, the total number of calls made on that medium are reported via
 
481
    trace.note.
 
482
    """
 
483
 
 
484
    def __init__(self):
 
485
        self.counts = weakref.WeakKeyDictionary()
 
486
        client._SmartClient.hooks.install_named_hook(
 
487
            'call', self.increment_call_count, 'hpss call counter')
 
488
        atexit.register(self.flush_all)
 
489
 
 
490
    def track(self, medium):
 
491
        """Start tracking calls made to a medium.
 
492
 
 
493
        This only keeps a weakref to the medium, so shouldn't affect the
 
494
        medium's lifetime.
 
495
        """
 
496
        medium_repr = repr(medium)
 
497
        # Add this medium to the WeakKeyDictionary
 
498
        self.counts[medium] = [0, medium_repr]
 
499
        # Weakref callbacks are fired in reverse order of their association
 
500
        # with the referenced object.  So we add a weakref *after* adding to
 
501
        # the WeakKeyDict so that we can report the value from it before the
 
502
        # entry is removed by the WeakKeyDict's own callback.
 
503
        ref = weakref.ref(medium, self.done)
 
504
 
 
505
    def increment_call_count(self, params):
 
506
        # Increment the count in the WeakKeyDictionary
 
507
        value = self.counts[params.medium]
 
508
        value[0] += 1
 
509
 
 
510
    def done(self, ref):
 
511
        value = self.counts[ref]
 
512
        count, medium_repr = value
 
513
        # In case this callback is invoked for the same ref twice (by the
 
514
        # weakref callback and by the atexit function), set the call count back
 
515
        # to 0 so this item won't be reported twice.
 
516
        value[0] = 0
 
517
        if count != 0:
 
518
            trace.note('HPSS calls: %d %s', count, medium_repr)
 
519
        
 
520
    def flush_all(self):
 
521
        for ref in list(self.counts.keys()):
 
522
            self.done(ref)
 
523
 
 
524
_debug_counter = None
 
525
  
 
526
  
 
527
class SmartClientMedium(SmartMedium):
434
528
    """Smart client is a medium for sending smart protocol requests over."""
435
529
 
436
530
    def __init__(self, base):
440
534
        self._protocol_version = None
441
535
        self._done_hello = False
442
536
        # Be optimistic: we assume the remote end can accept new remote
443
 
        # requests until we get an error saying otherwise.  (1.2 adds some
444
 
        # requests that send bodies, which confuses older servers.)
445
 
        self._remote_is_at_least_1_2 = True
 
537
        # requests until we get an error saying otherwise.
 
538
        # _remote_version_is_before tracks the bzr version the remote side
 
539
        # can be based on what we've seen so far.
 
540
        self._remote_version_is_before = None
 
541
        # Install debug hook function if debug flag is set.
 
542
        if 'hpss' in debug.debug_flags:
 
543
            global _debug_counter
 
544
            if _debug_counter is None:
 
545
                _debug_counter = _DebugCounter()
 
546
            _debug_counter.track(self)
 
547
 
 
548
    def _is_remote_before(self, version_tuple):
 
549
        """Is it possible the remote side supports RPCs for a given version?
 
550
 
 
551
        Typical use::
 
552
 
 
553
            needed_version = (1, 2)
 
554
            if medium._is_remote_before(needed_version):
 
555
                fallback_to_pre_1_2_rpc()
 
556
            else:
 
557
                try:
 
558
                    do_1_2_rpc()
 
559
                except UnknownSmartMethod:
 
560
                    medium._remember_remote_is_before(needed_version)
 
561
                    fallback_to_pre_1_2_rpc()
 
562
 
 
563
        :seealso: _remember_remote_is_before
 
564
        """
 
565
        if self._remote_version_is_before is None:
 
566
            # So far, the remote side seems to support everything
 
567
            return False
 
568
        return version_tuple >= self._remote_version_is_before
 
569
 
 
570
    def _remember_remote_is_before(self, version_tuple):
 
571
        """Tell this medium that the remote side is older the given version.
 
572
 
 
573
        :seealso: _is_remote_before
 
574
        """
 
575
        if (self._remote_version_is_before is not None and
 
576
            version_tuple > self._remote_version_is_before):
 
577
            raise AssertionError(
 
578
                "_remember_remote_is_before(%r) called, but "
 
579
                "_remember_remote_is_before(%r) was called previously."
 
580
                % (version_tuple, self._remote_version_is_before))
 
581
        self._remote_version_is_before = version_tuple
446
582
 
447
583
    def protocol_version(self):
448
584
        """Find out if 'hello' smart request works."""
453
589
                medium_request = self.get_request()
454
590
                # Send a 'hello' request in protocol version one, for maximum
455
591
                # backwards compatibility.
456
 
                client_protocol = SmartClientRequestProtocolOne(medium_request)
 
592
                client_protocol = protocol.SmartClientRequestProtocolOne(medium_request)
457
593
                client_protocol.query_version()
458
594
                self._done_hello = True
459
595
            except errors.SmartProtocolError, e:
535
671
        """
536
672
        return SmartClientStreamMediumRequest(self)
537
673
 
538
 
    def read_bytes(self, count):
539
 
        return self._read_bytes(count)
540
 
 
541
674
 
542
675
class SmartSimplePipesClientMedium(SmartClientStreamMedium):
543
676
    """A client medium using simple pipes.
628
761
        """See SmartClientStreamMedium.read_bytes."""
629
762
        if not self._connected:
630
763
            raise errors.MediumNotConnected(self)
631
 
        return self._read_from.read(count)
 
764
        bytes_to_read = min(count, _MAX_READ_SIZE)
 
765
        return self._read_from.read(bytes_to_read)
632
766
 
633
767
 
634
768
# Port 4155 is the default port for bzr://, registered with IANA.
635
 
BZR_DEFAULT_INTERFACE = '0.0.0.0'
 
769
BZR_DEFAULT_INTERFACE = None
636
770
BZR_DEFAULT_PORT = 4155
637
771
 
638
772
 
664
798
        """Connect this medium if not already connected."""
665
799
        if self._connected:
666
800
            return
667
 
        self._socket = socket.socket()
668
 
        self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
669
801
        if self._port is None:
670
802
            port = BZR_DEFAULT_PORT
671
803
        else:
672
804
            port = int(self._port)
673
805
        try:
674
 
            self._socket.connect((self._host, port))
675
 
        except socket.error, err:
 
806
            sockaddrs = socket.getaddrinfo(self._host, port, socket.AF_UNSPEC, 
 
807
                socket.SOCK_STREAM, 0, 0)
 
808
        except socket.gaierror, (err_num, err_msg):
 
809
            raise errors.ConnectionError("failed to lookup %s:%d: %s" %
 
810
                    (self._host, port, err_msg))
 
811
        # Initialize err in case there are no addresses returned:
 
812
        err = socket.error("no address found for %s" % self._host)
 
813
        for (family, socktype, proto, canonname, sockaddr) in sockaddrs:
 
814
            try:
 
815
                self._socket = socket.socket(family, socktype, proto)
 
816
                self._socket.setsockopt(socket.IPPROTO_TCP, 
 
817
                                        socket.TCP_NODELAY, 1)
 
818
                self._socket.connect(sockaddr)
 
819
            except socket.error, err:
 
820
                if self._socket is not None:
 
821
                    self._socket.close()
 
822
                self._socket = None
 
823
                continue
 
824
            break
 
825
        if self._socket is None:
676
826
            # socket errors either have a (string) or (errno, string) as their
677
827
            # args.
678
828
            if type(err.args) is str:
694
844
        """See SmartClientMedium.read_bytes."""
695
845
        if not self._connected:
696
846
            raise errors.MediumNotConnected(self)
697
 
        return self._socket.recv(count)
 
847
        # We ignore the desired_count because on sockets it's more efficient to
 
848
        # read large chunks (of _MAX_READ_SIZE bytes) at a time.
 
849
        try:
 
850
            return self._socket.recv(_MAX_READ_SIZE)
 
851
        except socket.error, e:
 
852
            if len(e.args) and e.args[0] == errno.ECONNRESET:
 
853
                # Callers expect an empty string in that case
 
854
                return ''
 
855
            else:
 
856
                raise
698
857
 
699
858
 
700
859
class SmartClientStreamMediumRequest(SmartClientMediumRequest):
736
895
        """
737
896
        self._medium._flush()
738
897
 
739
 
    def _read_bytes(self, count):
740
 
        """See SmartClientMediumRequest._read_bytes.
741
 
        
742
 
        This forwards to self._medium._read_bytes because we are operating
743
 
        on the mediums stream.
744
 
        """
745
 
        return self._medium._read_bytes(count)
746