~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/smart/protocol.py

Merge bzr.dev.

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# Copyright (C) 2006, 2007 Canonical Ltd
 
2
#
 
3
# This program is free software; you can redistribute it and/or modify
 
4
# it under the terms of the GNU General Public License as published by
 
5
# the Free Software Foundation; either version 2 of the License, or
 
6
# (at your option) any later version.
 
7
#
 
8
# This program is distributed in the hope that it will be useful,
 
9
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
10
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
11
# GNU General Public License for more details.
 
12
#
 
13
# You should have received a copy of the GNU General Public License
 
14
# along with this program; if not, write to the Free Software
 
15
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 
16
 
 
17
"""Wire-level encoding and decoding of requests and responses for the smart
 
18
client and server.
 
19
"""
 
20
 
 
21
 
 
22
from cStringIO import StringIO
 
23
 
 
24
from bzrlib import errors
 
25
from bzrlib.smart import request
 
26
 
 
27
 
 
28
# Protocol version strings.  These are sent as prefixes of bzr requests and
 
29
# responses to identify the protocol version being used. (There are no version
 
30
# one strings because that version doesn't send any).
 
31
REQUEST_VERSION_TWO = 'bzr request 2\n'
 
32
RESPONSE_VERSION_TWO = 'bzr response 2\n'
 
33
 
 
34
 
 
35
def _recv_tuple(from_file):
 
36
    req_line = from_file.readline()
 
37
    return _decode_tuple(req_line)
 
38
 
 
39
 
 
40
def _decode_tuple(req_line):
 
41
    if req_line == None or req_line == '':
 
42
        return None
 
43
    if req_line[-1] != '\n':
 
44
        raise errors.SmartProtocolError("request %r not terminated" % req_line)
 
45
    return tuple(req_line[:-1].split('\x01'))
 
46
 
 
47
 
 
48
def _encode_tuple(args):
 
49
    """Encode the tuple args to a bytestream."""
 
50
    return '\x01'.join(args) + '\n'
 
51
 
 
52
 
 
53
class SmartProtocolBase(object):
 
54
    """Methods common to client and server"""
 
55
 
 
56
    # TODO: this only actually accomodates a single block; possibly should
 
57
    # support multiple chunks?
 
58
    def _encode_bulk_data(self, body):
 
59
        """Encode body as a bulk data chunk."""
 
60
        return ''.join(('%d\n' % len(body), body, 'done\n'))
 
61
 
 
62
    def _serialise_offsets(self, offsets):
 
63
        """Serialise a readv offset list."""
 
64
        txt = []
 
65
        for start, length in offsets:
 
66
            txt.append('%d,%d' % (start, length))
 
67
        return '\n'.join(txt)
 
68
        
 
69
 
 
70
class SmartServerRequestProtocolOne(SmartProtocolBase):
 
71
    """Server-side encoding and decoding logic for smart version 1."""
 
72
    
 
73
    def __init__(self, backing_transport, write_func):
 
74
        self._backing_transport = backing_transport
 
75
        self.excess_buffer = ''
 
76
        self._finished = False
 
77
        self.in_buffer = ''
 
78
        self.has_dispatched = False
 
79
        self.request = None
 
80
        self._body_decoder = None
 
81
        self._write_func = write_func
 
82
 
 
83
    def accept_bytes(self, bytes):
 
84
        """Take bytes, and advance the internal state machine appropriately.
 
85
        
 
86
        :param bytes: must be a byte string
 
87
        """
 
88
        assert isinstance(bytes, str)
 
89
        self.in_buffer += bytes
 
90
        if not self.has_dispatched:
 
91
            if '\n' not in self.in_buffer:
 
92
                # no command line yet
 
93
                return
 
94
            self.has_dispatched = True
 
95
            try:
 
96
                first_line, self.in_buffer = self.in_buffer.split('\n', 1)
 
97
                first_line += '\n'
 
98
                req_args = _decode_tuple(first_line)
 
99
                self.request = request.SmartServerRequestHandler(
 
100
                    self._backing_transport, commands=request.request_handlers)
 
101
                self.request.dispatch_command(req_args[0], req_args[1:])
 
102
                if self.request.finished_reading:
 
103
                    # trivial request
 
104
                    self.excess_buffer = self.in_buffer
 
105
                    self.in_buffer = ''
 
106
                    self._send_response(self.request.response.args,
 
107
                        self.request.response.body)
 
108
            except KeyboardInterrupt:
 
109
                raise
 
110
            except Exception, exception:
 
111
                # everything else: pass to client, flush, and quit
 
112
                self._send_response(('error', str(exception)))
 
113
                return
 
114
 
 
115
        if self.has_dispatched:
 
116
            if self._finished:
 
117
                # nothing to do.XXX: this routine should be a single state 
 
118
                # machine too.
 
119
                self.excess_buffer += self.in_buffer
 
120
                self.in_buffer = ''
 
121
                return
 
122
            if self._body_decoder is None:
 
123
                self._body_decoder = LengthPrefixedBodyDecoder()
 
124
            self._body_decoder.accept_bytes(self.in_buffer)
 
125
            self.in_buffer = self._body_decoder.unused_data
 
126
            body_data = self._body_decoder.read_pending_data()
 
127
            self.request.accept_body(body_data)
 
128
            if self._body_decoder.finished_reading:
 
129
                self.request.end_of_body()
 
130
                assert self.request.finished_reading, \
 
131
                    "no more body, request not finished"
 
132
            if self.request.response is not None:
 
133
                self._send_response(self.request.response.args,
 
134
                    self.request.response.body)
 
135
                self.excess_buffer = self.in_buffer
 
136
                self.in_buffer = ''
 
137
            else:
 
138
                assert not self.request.finished_reading, \
 
139
                    "no response and we have finished reading."
 
140
 
 
141
    def _send_response(self, args, body=None):
 
142
        """Send a smart server response down the output stream."""
 
143
        assert not self._finished, 'response already sent'
 
144
        self._finished = True
 
145
        self._write_protocol_version()
 
146
        self._write_func(_encode_tuple(args))
 
147
        if body is not None:
 
148
            assert isinstance(body, str), 'body must be a str'
 
149
            bytes = self._encode_bulk_data(body)
 
150
            self._write_func(bytes)
 
151
 
 
152
    def _write_protocol_version(self):
 
153
        """Write any prefixes this protocol requires.
 
154
        
 
155
        Version one doesn't send protocol versions.
 
156
        """
 
157
 
 
158
    def next_read_size(self):
 
159
        if self._finished:
 
160
            return 0
 
161
        if self._body_decoder is None:
 
162
            return 1
 
163
        else:
 
164
            return self._body_decoder.next_read_size()
 
165
 
 
166
 
 
167
class SmartServerRequestProtocolTwo(SmartServerRequestProtocolOne):
 
168
    r"""Version two of the server side of the smart protocol.
 
169
   
 
170
    This prefixes responses with the value of RESPONSE_VERSION_TWO.
 
171
    """
 
172
 
 
173
    def _write_protocol_version(self):
 
174
        r"""Write any prefixes this protocol requires.
 
175
        
 
176
        Version two sends the value of RESPONSE_VERSION_TWO.
 
177
        """
 
178
        self._write_func(RESPONSE_VERSION_TWO)
 
179
 
 
180
 
 
181
class LengthPrefixedBodyDecoder(object):
 
182
    """Decodes the length-prefixed bulk data."""
 
183
    
 
184
    def __init__(self):
 
185
        self.bytes_left = None
 
186
        self.finished_reading = False
 
187
        self.unused_data = ''
 
188
        self.state_accept = self._state_accept_expecting_length
 
189
        self.state_read = self._state_read_no_data
 
190
        self._in_buffer = ''
 
191
        self._trailer_buffer = ''
 
192
    
 
193
    def accept_bytes(self, bytes):
 
194
        """Decode as much of bytes as possible.
 
195
 
 
196
        If 'bytes' contains too much data it will be appended to
 
197
        self.unused_data.
 
198
 
 
199
        finished_reading will be set when no more data is required.  Further
 
200
        data will be appended to self.unused_data.
 
201
        """
 
202
        # accept_bytes is allowed to change the state
 
203
        current_state = self.state_accept
 
204
        self.state_accept(bytes)
 
205
        while current_state != self.state_accept:
 
206
            current_state = self.state_accept
 
207
            self.state_accept('')
 
208
 
 
209
    def next_read_size(self):
 
210
        if self.bytes_left is not None:
 
211
            # Ideally we want to read all the remainder of the body and the
 
212
            # trailer in one go.
 
213
            return self.bytes_left + 5
 
214
        elif self.state_accept == self._state_accept_reading_trailer:
 
215
            # Just the trailer left
 
216
            return 5 - len(self._trailer_buffer)
 
217
        elif self.state_accept == self._state_accept_expecting_length:
 
218
            # There's still at least 6 bytes left ('\n' to end the length, plus
 
219
            # 'done\n').
 
220
            return 6
 
221
        else:
 
222
            # Reading excess data.  Either way, 1 byte at a time is fine.
 
223
            return 1
 
224
        
 
225
    def read_pending_data(self):
 
226
        """Return any pending data that has been decoded."""
 
227
        return self.state_read()
 
228
 
 
229
    def _state_accept_expecting_length(self, bytes):
 
230
        self._in_buffer += bytes
 
231
        pos = self._in_buffer.find('\n')
 
232
        if pos == -1:
 
233
            return
 
234
        self.bytes_left = int(self._in_buffer[:pos])
 
235
        self._in_buffer = self._in_buffer[pos+1:]
 
236
        self.bytes_left -= len(self._in_buffer)
 
237
        self.state_accept = self._state_accept_reading_body
 
238
        self.state_read = self._state_read_in_buffer
 
239
 
 
240
    def _state_accept_reading_body(self, bytes):
 
241
        self._in_buffer += bytes
 
242
        self.bytes_left -= len(bytes)
 
243
        if self.bytes_left <= 0:
 
244
            # Finished with body
 
245
            if self.bytes_left != 0:
 
246
                self._trailer_buffer = self._in_buffer[self.bytes_left:]
 
247
                self._in_buffer = self._in_buffer[:self.bytes_left]
 
248
            self.bytes_left = None
 
249
            self.state_accept = self._state_accept_reading_trailer
 
250
        
 
251
    def _state_accept_reading_trailer(self, bytes):
 
252
        self._trailer_buffer += bytes
 
253
        # TODO: what if the trailer does not match "done\n"?  Should this raise
 
254
        # a ProtocolViolation exception?
 
255
        if self._trailer_buffer.startswith('done\n'):
 
256
            self.unused_data = self._trailer_buffer[len('done\n'):]
 
257
            self.state_accept = self._state_accept_reading_unused
 
258
            self.finished_reading = True
 
259
    
 
260
    def _state_accept_reading_unused(self, bytes):
 
261
        self.unused_data += bytes
 
262
 
 
263
    def _state_read_no_data(self):
 
264
        return ''
 
265
 
 
266
    def _state_read_in_buffer(self):
 
267
        result = self._in_buffer
 
268
        self._in_buffer = ''
 
269
        return result
 
270
 
 
271
 
 
272
class SmartClientRequestProtocolOne(SmartProtocolBase):
 
273
    """The client-side protocol for smart version 1."""
 
274
 
 
275
    def __init__(self, request):
 
276
        """Construct a SmartClientRequestProtocolOne.
 
277
 
 
278
        :param request: A SmartClientMediumRequest to serialise onto and
 
279
            deserialise from.
 
280
        """
 
281
        self._request = request
 
282
        self._body_buffer = None
 
283
 
 
284
    def call(self, *args):
 
285
        self._write_args(args)
 
286
        self._request.finished_writing()
 
287
 
 
288
    def call_with_body_bytes(self, args, body):
 
289
        """Make a remote call of args with body bytes 'body'.
 
290
 
 
291
        After calling this, call read_response_tuple to find the result out.
 
292
        """
 
293
        self._write_args(args)
 
294
        bytes = self._encode_bulk_data(body)
 
295
        self._request.accept_bytes(bytes)
 
296
        self._request.finished_writing()
 
297
 
 
298
    def call_with_body_readv_array(self, args, body):
 
299
        """Make a remote call with a readv array.
 
300
 
 
301
        The body is encoded with one line per readv offset pair. The numbers in
 
302
        each pair are separated by a comma, and no trailing \n is emitted.
 
303
        """
 
304
        self._write_args(args)
 
305
        readv_bytes = self._serialise_offsets(body)
 
306
        bytes = self._encode_bulk_data(readv_bytes)
 
307
        self._request.accept_bytes(bytes)
 
308
        self._request.finished_writing()
 
309
 
 
310
    def cancel_read_body(self):
 
311
        """After expecting a body, a response code may indicate one otherwise.
 
312
 
 
313
        This method lets the domain client inform the protocol that no body
 
314
        will be transmitted. This is a terminal method: after calling it the
 
315
        protocol is not able to be used further.
 
316
        """
 
317
        self._request.finished_reading()
 
318
 
 
319
    def read_response_tuple(self, expect_body=False):
 
320
        """Read a response tuple from the wire.
 
321
 
 
322
        This should only be called once.
 
323
        """
 
324
        result = self._recv_tuple()
 
325
        if not expect_body:
 
326
            self._request.finished_reading()
 
327
        return result
 
328
 
 
329
    def read_body_bytes(self, count=-1):
 
330
        """Read bytes from the body, decoding into a byte stream.
 
331
        
 
332
        We read all bytes at once to ensure we've checked the trailer for 
 
333
        errors, and then feed the buffer back as read_body_bytes is called.
 
334
        """
 
335
        if self._body_buffer is not None:
 
336
            return self._body_buffer.read(count)
 
337
        _body_decoder = LengthPrefixedBodyDecoder()
 
338
 
 
339
        while not _body_decoder.finished_reading:
 
340
            bytes_wanted = _body_decoder.next_read_size()
 
341
            bytes = self._request.read_bytes(bytes_wanted)
 
342
            _body_decoder.accept_bytes(bytes)
 
343
        self._request.finished_reading()
 
344
        self._body_buffer = StringIO(_body_decoder.read_pending_data())
 
345
        # XXX: TODO check the trailer result.
 
346
        return self._body_buffer.read(count)
 
347
 
 
348
    def _recv_tuple(self):
 
349
        """Receive a tuple from the medium request."""
 
350
        line = ''
 
351
        while not line or line[-1] != '\n':
 
352
            # TODO: this is inefficient - but tuples are short.
 
353
            new_char = self._request.read_bytes(1)
 
354
            line += new_char
 
355
            assert new_char != '', "end of file reading from server."
 
356
        return _decode_tuple(line)
 
357
 
 
358
    def query_version(self):
 
359
        """Return protocol version number of the server."""
 
360
        self.call('hello')
 
361
        resp = self.read_response_tuple()
 
362
        if resp == ('ok', '1'):
 
363
            return 1
 
364
        elif resp == ('ok', '2'):
 
365
            return 2
 
366
        else:
 
367
            raise errors.SmartProtocolError("bad response %r" % (resp,))
 
368
 
 
369
    def _write_args(self, args):
 
370
        self._write_protocol_version()
 
371
        bytes = _encode_tuple(args)
 
372
        self._request.accept_bytes(bytes)
 
373
 
 
374
    def _write_protocol_version(self):
 
375
        """Write any prefixes this protocol requires.
 
376
        
 
377
        Version one doesn't send protocol versions.
 
378
        """
 
379
 
 
380
 
 
381
class SmartClientRequestProtocolTwo(SmartClientRequestProtocolOne):
 
382
    """Version two of the client side of the smart protocol.
 
383
    
 
384
    This prefixes the request with the value of REQUEST_VERSION_TWO.
 
385
    """
 
386
 
 
387
    def read_response_tuple(self, expect_body=False):
 
388
        """Read a response tuple from the wire.
 
389
 
 
390
        This should only be called once.
 
391
        """
 
392
        version = self._request.read_line()
 
393
        if version != RESPONSE_VERSION_TWO:
 
394
            raise errors.SmartProtocolError('bad protocol marker %r' % version)
 
395
        return SmartClientRequestProtocolOne.read_response_tuple(self, expect_body)
 
396
 
 
397
    def _write_protocol_version(self):
 
398
        r"""Write any prefixes this protocol requires.
 
399
        
 
400
        Version two sends the value of REQUEST_VERSION_TWO.
 
401
        """
 
402
        self._request.accept_bytes(REQUEST_VERSION_TWO)
 
403