~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/smart/medium.py

  • Committer: John Arbash Meinel
  • Date: 2008-05-29 19:46:01 UTC
  • mfrom: (3456 +trunk)
  • mto: This revision was merged to the branch mainline in revision 3459.
  • Revision ID: john@arbash-meinel.com-20080529194601-r2gpmk536xin9c4a
merge bzr.dev, put the NEWS entry in the right place

Show diffs side-by-side

added added

removed removed

Lines of Context:
27
27
import os
28
28
import socket
29
29
import sys
 
30
import urllib
30
31
 
31
32
from bzrlib import (
32
33
    errors,
33
34
    osutils,
34
35
    symbol_versioning,
 
36
    urlutils,
35
37
    )
36
38
from bzrlib.smart.protocol import (
 
39
    MESSAGE_VERSION_THREE,
37
40
    REQUEST_VERSION_TWO,
 
41
    SmartClientRequestProtocolOne,
38
42
    SmartServerRequestProtocolOne,
39
43
    SmartServerRequestProtocolTwo,
 
44
    build_server_protocol_three
40
45
    )
41
46
from bzrlib.transport import ssh
42
47
 
43
48
 
 
49
def _get_protocol_factory_for_bytes(bytes):
 
50
    """Determine the right protocol factory for 'bytes'.
 
51
 
 
52
    This will return an appropriate protocol factory depending on the version
 
53
    of the protocol being used, as determined by inspecting the given bytes.
 
54
    The bytes should have at least one newline byte (i.e. be a whole line),
 
55
    otherwise it's possible that a request will be incorrectly identified as
 
56
    version 1.
 
57
 
 
58
    Typical use would be::
 
59
 
 
60
         factory, unused_bytes = _get_protocol_factory_for_bytes(bytes)
 
61
         server_protocol = factory(transport, write_func, root_client_path)
 
62
         server_protocol.accept_bytes(unused_bytes)
 
63
 
 
64
    :param bytes: a str of bytes of the start of the request.
 
65
    :returns: 2-tuple of (protocol_factory, unused_bytes).  protocol_factory is
 
66
        a callable that takes three args: transport, write_func,
 
67
        root_client_path.  unused_bytes are any bytes that were not part of a
 
68
        protocol version marker.
 
69
    """
 
70
    if bytes.startswith(MESSAGE_VERSION_THREE):
 
71
        protocol_factory = build_server_protocol_three
 
72
        bytes = bytes[len(MESSAGE_VERSION_THREE):]
 
73
    elif bytes.startswith(REQUEST_VERSION_TWO):
 
74
        protocol_factory = SmartServerRequestProtocolTwo
 
75
        bytes = bytes[len(REQUEST_VERSION_TWO):]
 
76
    else:
 
77
        protocol_factory = SmartServerRequestProtocolOne
 
78
    return protocol_factory, bytes
 
79
 
 
80
 
44
81
class SmartServerStreamMedium(object):
45
82
    """Handles smart commands coming over a stream.
46
83
 
52
89
 
53
90
    The server passes requests through to an underlying backing transport, 
54
91
    which will typically be a LocalTransport looking at the server's filesystem.
 
92
 
 
93
    :ivar _push_back_buffer: a str of bytes that have been read from the stream
 
94
        but not used yet, or None if there are no buffered bytes.  Subclasses
 
95
        should make sure to exhaust this buffer before reading more bytes from
 
96
        the stream.  See also the _push_back method.
55
97
    """
56
98
 
57
 
    def __init__(self, backing_transport):
 
99
    def __init__(self, backing_transport, root_client_path='/'):
58
100
        """Construct new server.
59
101
 
60
102
        :param backing_transport: Transport for the directory served.
61
103
        """
62
104
        # backing_transport could be passed to serve instead of __init__
63
105
        self.backing_transport = backing_transport
 
106
        self.root_client_path = root_client_path
64
107
        self.finished = False
 
108
        self._push_back_buffer = None
 
109
 
 
110
    def _push_back(self, bytes):
 
111
        """Return unused bytes to the medium, because they belong to the next
 
112
        request(s).
 
113
 
 
114
        This sets the _push_back_buffer to the given bytes.
 
115
        """
 
116
        if self._push_back_buffer is not None:
 
117
            raise AssertionError(
 
118
                "_push_back called when self._push_back_buffer is %r"
 
119
                % (self._push_back_buffer,))
 
120
        if bytes == '':
 
121
            return
 
122
        self._push_back_buffer = bytes
 
123
 
 
124
    def _get_push_back_buffer(self):
 
125
        if self._push_back_buffer == '':
 
126
            raise AssertionError(
 
127
                '%s._push_back_buffer should never be the empty string, '
 
128
                'which can be confused with EOF' % (self,))
 
129
        bytes = self._push_back_buffer
 
130
        self._push_back_buffer = None
 
131
        return bytes
65
132
 
66
133
    def serve(self):
67
134
        """Serve requests until the client disconnects."""
85
152
 
86
153
        :returns: a SmartServerRequestProtocol.
87
154
        """
88
 
        # Identify the protocol version.
89
155
        bytes = self._get_line()
90
 
        if bytes.startswith(REQUEST_VERSION_TWO):
91
 
            protocol_class = SmartServerRequestProtocolTwo
92
 
            bytes = bytes[len(REQUEST_VERSION_TWO):]
93
 
        else:
94
 
            protocol_class = SmartServerRequestProtocolOne
95
 
        protocol = protocol_class(self.backing_transport, self._write_out)
96
 
        protocol.accept_bytes(bytes)
 
156
        protocol_factory, unused_bytes = _get_protocol_factory_for_bytes(bytes)
 
157
        protocol = protocol_factory(
 
158
            self.backing_transport, self._write_out, self.root_client_path)
 
159
        protocol.accept_bytes(unused_bytes)
97
160
        return protocol
98
161
 
99
162
    def _serve_one_request(self, protocol):
127
190
 
128
191
        :returns: a string of bytes ending in a newline (byte 0x0A).
129
192
        """
130
 
        # XXX: this duplicates SmartClientRequestProtocolOne._recv_tuple
131
 
        line = ''
132
 
        while not line or line[-1] != '\n':
133
 
            new_char = self._get_bytes(1)
134
 
            line += new_char
135
 
            if new_char == '':
 
193
        newline_pos = -1
 
194
        bytes = ''
 
195
        while newline_pos == -1:
 
196
            new_bytes = self._get_bytes(1)
 
197
            bytes += new_bytes
 
198
            if new_bytes == '':
136
199
                # Ran out of bytes before receiving a complete line.
137
 
                break
 
200
                return bytes
 
201
            newline_pos = bytes.find('\n')
 
202
        line = bytes[:newline_pos+1]
 
203
        self._push_back(bytes[newline_pos+1:])
138
204
        return line
139
 
 
 
205
 
140
206
 
141
207
class SmartServerSocketStreamMedium(SmartServerStreamMedium):
142
208
 
143
 
    def __init__(self, sock, backing_transport):
 
209
    def __init__(self, sock, backing_transport, root_client_path='/'):
144
210
        """Constructor.
145
211
 
146
212
        :param sock: the socket the server will read from.  It will be put
147
213
            into blocking mode.
148
214
        """
149
 
        SmartServerStreamMedium.__init__(self, backing_transport)
150
 
        self.push_back = ''
 
215
        SmartServerStreamMedium.__init__(
 
216
            self, backing_transport, root_client_path=root_client_path)
151
217
        sock.setblocking(True)
152
218
        self.socket = sock
153
219
 
154
220
    def _serve_one_request_unguarded(self, protocol):
155
221
        while protocol.next_read_size():
156
 
            if self.push_back:
157
 
                protocol.accept_bytes(self.push_back)
158
 
                self.push_back = ''
159
 
            else:
160
 
                bytes = self._get_bytes(4096)
161
 
                if bytes == '':
162
 
                    self.finished = True
163
 
                    return
164
 
                protocol.accept_bytes(bytes)
 
222
            bytes = self._get_bytes(4096)
 
223
            if bytes == '':
 
224
                self.finished = True
 
225
                return
 
226
            protocol.accept_bytes(bytes)
165
227
        
166
 
        self.push_back = protocol.excess_buffer
 
228
        self._push_back(protocol.unused_data)
167
229
 
168
230
    def _get_bytes(self, desired_count):
 
231
        if self._push_back_buffer is not None:
 
232
            return self._get_push_back_buffer()
169
233
        # We ignore the desired_count because on sockets it's more efficient to
170
234
        # read 4k at a time.
171
235
        return self.socket.recv(4096)
172
236
    
173
237
    def terminate_due_to_error(self):
174
 
        """Called when an unhandled exception from the protocol occurs."""
175
238
        # TODO: This should log to a server log file, but no such thing
176
239
        # exists yet.  Andrew Bennetts 2006-09-29.
177
240
        self.socket.close()
217
280
            protocol.accept_bytes(bytes)
218
281
 
219
282
    def _get_bytes(self, desired_count):
 
283
        if self._push_back_buffer is not None:
 
284
            return self._get_push_back_buffer()
220
285
        return self._in.read(desired_count)
221
286
 
222
287
    def terminate_due_to_error(self):
368
433
class SmartClientMedium(object):
369
434
    """Smart client is a medium for sending smart protocol requests over."""
370
435
 
 
436
    def __init__(self, base):
 
437
        super(SmartClientMedium, self).__init__()
 
438
        self.base = base
 
439
        self._protocol_version_error = None
 
440
        self._protocol_version = None
 
441
        self._done_hello = False
 
442
        # Be optimistic: we assume the remote end can accept new remote
 
443
        # requests until we get an error saying otherwise.  (1.2 adds some
 
444
        # requests that send bodies, which confuses older servers.)
 
445
        self._remote_is_at_least_1_2 = True
 
446
 
 
447
    def protocol_version(self):
 
448
        """Find out if 'hello' smart request works."""
 
449
        if self._protocol_version_error is not None:
 
450
            raise self._protocol_version_error
 
451
        if not self._done_hello:
 
452
            try:
 
453
                medium_request = self.get_request()
 
454
                # Send a 'hello' request in protocol version one, for maximum
 
455
                # backwards compatibility.
 
456
                client_protocol = SmartClientRequestProtocolOne(medium_request)
 
457
                client_protocol.query_version()
 
458
                self._done_hello = True
 
459
            except errors.SmartProtocolError, e:
 
460
                # Cache the error, just like we would cache a successful
 
461
                # result.
 
462
                self._protocol_version_error = e
 
463
                raise
 
464
        return '2'
 
465
 
 
466
    def should_probe(self):
 
467
        """Should RemoteBzrDirFormat.probe_transport send a smart request on
 
468
        this medium?
 
469
 
 
470
        Some transports are unambiguously smart-only; there's no need to check
 
471
        if the transport is able to carry smart requests, because that's all
 
472
        it is for.  In those cases, this method should return False.
 
473
 
 
474
        But some HTTP transports can sometimes fail to carry smart requests,
 
475
        but still be usuable for accessing remote bzrdirs via plain file
 
476
        accesses.  So for those transports, their media should return True here
 
477
        so that RemoteBzrDirFormat can determine if it is appropriate for that
 
478
        transport.
 
479
        """
 
480
        return False
 
481
 
371
482
    def disconnect(self):
372
483
        """If this medium maintains a persistent connection, close it.
373
484
        
374
485
        The default implementation does nothing.
375
486
        """
376
487
        
 
488
    def remote_path_from_transport(self, transport):
 
489
        """Convert transport into a path suitable for using in a request.
 
490
        
 
491
        Note that the resulting remote path doesn't encode the host name or
 
492
        anything but path, so it is only safe to use it in requests sent over
 
493
        the medium from the matching transport.
 
494
        """
 
495
        medium_base = urlutils.join(self.base, '/')
 
496
        rel_url = urlutils.relative_url(medium_base, transport.base)
 
497
        return urllib.unquote(rel_url)
 
498
 
377
499
 
378
500
class SmartClientStreamMedium(SmartClientMedium):
379
501
    """Stream based medium common class.
384
506
    receive bytes.
385
507
    """
386
508
 
387
 
    def __init__(self):
 
509
    def __init__(self, base):
 
510
        SmartClientMedium.__init__(self, base)
388
511
        self._current_request = None
389
 
        # Be optimistic: we assume the remote end can accept new remote
390
 
        # requests until we get an error saying otherwise.  (1.2 adds some
391
 
        # requests that send bodies, which confuses older servers.)
392
 
        self._remote_is_at_least_1_2 = True
393
512
 
394
513
    def accept_bytes(self, bytes):
395
514
        self._accept_bytes(bytes)
426
545
    This client does not manage the pipes: it assumes they will always be open.
427
546
    """
428
547
 
429
 
    def __init__(self, readable_pipe, writeable_pipe):
430
 
        SmartClientStreamMedium.__init__(self)
 
548
    def __init__(self, readable_pipe, writeable_pipe, base):
 
549
        SmartClientStreamMedium.__init__(self, base)
431
550
        self._readable_pipe = readable_pipe
432
551
        self._writeable_pipe = writeable_pipe
433
552
 
448
567
    """A client medium using SSH."""
449
568
    
450
569
    def __init__(self, host, port=None, username=None, password=None,
451
 
            vendor=None, bzr_remote_path=None):
 
570
            base=None, vendor=None, bzr_remote_path=None):
452
571
        """Creates a client that will connect on the first use.
453
572
        
454
573
        :param vendor: An optional override for the ssh vendor to use. See
455
574
            bzrlib.transport.ssh for details on ssh vendors.
456
575
        """
457
 
        SmartClientStreamMedium.__init__(self)
 
576
        SmartClientStreamMedium.__init__(self, base)
458
577
        self._connected = False
459
578
        self._host = host
460
579
        self._password = password
520
639
class SmartTCPClientMedium(SmartClientStreamMedium):
521
640
    """A client medium using TCP."""
522
641
    
523
 
    def __init__(self, host, port):
 
642
    def __init__(self, host, port, base):
524
643
        """Creates a client that will connect on the first use."""
525
 
        SmartClientStreamMedium.__init__(self)
 
644
        SmartClientStreamMedium.__init__(self, base)
526
645
        self._connected = False
527
646
        self._host = host
528
647
        self._port = port
606
725
        This clears the _current_request on self._medium to allow a new 
607
726
        request to be created.
608
727
        """
609
 
        assert self._medium._current_request is self
 
728
        if self._medium._current_request is not self:
 
729
            raise AssertionError()
610
730
        self._medium._current_request = None
611
731
        
612
732
    def _finished_writing(self):