1
# Copyright (C) 2006, 2007 Canonical Ltd
3
# This program is free software; you can redistribute it and/or modify
4
# it under the terms of the GNU General Public License as published by
5
# the Free Software Foundation; either version 2 of the License, or
6
# (at your option) any later version.
8
# This program is distributed in the hope that it will be useful,
9
# but WITHOUT ANY WARRANTY; without even the implied warranty of
10
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11
# GNU General Public License for more details.
13
# You should have received a copy of the GNU General Public License
14
# along with this program; if not, write to the Free Software
15
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
17
"""Wire-level encoding and decoding of requests and responses for the smart
client and server.
"""
22
from cStringIO import StringIO
24
from bzrlib import errors
25
from bzrlib.smart import request
28
# Protocol version strings.  These are sent as prefixes of bzr requests and
# responses to identify the protocol version being used.  (There are no
# version one strings because that version doesn't send any.)
REQUEST_VERSION_TWO = 'bzr request 2\n'
RESPONSE_VERSION_TWO = 'bzr response 2\n'
35
def _recv_tuple(from_file):
    """Read one line from from_file and decode it as a tuple.

    :param from_file: an object providing a readline() method.
    :return: whatever _decode_tuple returns for the line that was read.
    """
    req_line = from_file.readline()
    return _decode_tuple(req_line)
40
def _decode_tuple(req_line):
41
if req_line == None or req_line == '':
43
if req_line[-1] != '\n':
44
raise errors.SmartProtocolError("request %r not terminated" % req_line)
45
return tuple(req_line[:-1].split('\x01'))
48
def _encode_tuple(args):
49
"""Encode the tuple args to a bytestream."""
50
return '\x01'.join(args) + '\n'
53
class SmartProtocolBase(object):
54
"""Methods common to client and server"""
56
# TODO: this only actually accomodates a single block; possibly should
57
# support multiple chunks?
58
def _encode_bulk_data(self, body):
59
"""Encode body as a bulk data chunk."""
60
return ''.join(('%d\n' % len(body), body, 'done\n'))
62
def _serialise_offsets(self, offsets):
63
"""Serialise a readv offset list."""
65
for start, length in offsets:
66
txt.append('%d,%d' % (start, length))
70
class SmartServerRequestProtocolOne(SmartProtocolBase):
71
"""Server-side encoding and decoding logic for smart version 1."""
73
def __init__(self, backing_transport, write_func):
74
self._backing_transport = backing_transport
75
self.excess_buffer = ''
76
self._finished = False
78
self.has_dispatched = False
80
self._body_decoder = None
81
self._write_func = write_func
83
def accept_bytes(self, bytes):
84
"""Take bytes, and advance the internal state machine appropriately.
86
:param bytes: must be a byte string
88
assert isinstance(bytes, str)
89
self.in_buffer += bytes
90
if not self.has_dispatched:
91
if '\n' not in self.in_buffer:
94
self.has_dispatched = True
96
first_line, self.in_buffer = self.in_buffer.split('\n', 1)
98
req_args = _decode_tuple(first_line)
99
self.request = request.SmartServerRequestHandler(
100
self._backing_transport, commands=request.request_handlers)
101
self.request.dispatch_command(req_args[0], req_args[1:])
102
if self.request.finished_reading:
104
self.excess_buffer = self.in_buffer
106
self._send_response(self.request.response)
107
except KeyboardInterrupt:
109
except Exception, exception:
110
# everything else: pass to client, flush, and quit
111
self._send_response(request.FailedSmartServerResponse(
112
('error', str(exception))))
115
if self.has_dispatched:
117
# nothing to do.XXX: this routine should be a single state
119
self.excess_buffer += self.in_buffer
122
if self._body_decoder is None:
123
self._body_decoder = LengthPrefixedBodyDecoder()
124
self._body_decoder.accept_bytes(self.in_buffer)
125
self.in_buffer = self._body_decoder.unused_data
126
body_data = self._body_decoder.read_pending_data()
127
self.request.accept_body(body_data)
128
if self._body_decoder.finished_reading:
129
self.request.end_of_body()
130
assert self.request.finished_reading, \
131
"no more body, request not finished"
132
if self.request.response is not None:
133
self._send_response(self.request.response)
134
self.excess_buffer = self.in_buffer
137
assert not self.request.finished_reading, \
138
"no response and we have finished reading."
140
def _send_response(self, response):
141
"""Send a smart server response down the output stream."""
142
assert not self._finished, 'response already sent'
145
self._finished = True
146
self._write_protocol_version()
147
self._write_success_or_failure_prefix(response)
148
self._write_func(_encode_tuple(args))
150
assert isinstance(body, str), 'body must be a str'
151
bytes = self._encode_bulk_data(body)
152
self._write_func(bytes)
154
def _write_protocol_version(self):
155
"""Write any prefixes this protocol requires.
157
Version one doesn't send protocol versions.
160
def _write_success_or_failure_prefix(self, response):
161
"""Write the protocol specific success/failure prefix.
163
For SmartServerRequestProtocolOne this is omitted but we
164
call is_successful to ensure that the response is valid.
166
response.is_successful()
168
def next_read_size(self):
171
if self._body_decoder is None:
174
return self._body_decoder.next_read_size()
177
class SmartServerRequestProtocolTwo(SmartServerRequestProtocolOne):
    r"""Version two of the server side of the smart protocol.

    This prefixes responses with the value of RESPONSE_VERSION_TWO.
    """

    def _write_success_or_failure_prefix(self, response):
        """Write the protocol specific success/failure prefix.

        :param response: an object providing is_successful(); a 'success'
            line is written when it returns true, a 'failed' line otherwise.
        """
        if response.is_successful():
            self._write_func('success\n')
        else:
            self._write_func('failed\n')

    def _write_protocol_version(self):
        r"""Write any prefixes this protocol requires.

        Version two sends the value of RESPONSE_VERSION_TWO.
        """
        self._write_func(RESPONSE_VERSION_TWO)
198
class LengthPrefixedBodyDecoder(object):
    """Decodes the length-prefixed bulk data.

    The expected input is a decimal length on a line of its own, then that
    many bytes of body, then a 'done' trailer line.  The decoder is a small
    state machine: self.state_accept is the handler for incoming bytes and
    self.state_read is the handler that returns decoded body data.
    """

    def __init__(self):
        # Body bytes still expected; None while the length is not yet known
        # and again once the whole body has been received.
        self.bytes_left = None
        self.finished_reading = False
        # Bytes received after the end of the trailer.
        self.unused_data = ''
        self.state_accept = self._state_accept_expecting_length
        self.state_read = self._state_read_no_data
        # Holds the length line while it is incomplete, then pending body
        # data that has not yet been handed out by read_pending_data().
        self._in_buffer = ''
        self._trailer_buffer = ''

    def accept_bytes(self, bytes):
        """Decode as much of bytes as possible.

        If 'bytes' contains too much data it will be appended to
        self.unused_data.

        finished_reading will be set when no more data is required.  Further
        data will be appended to self.unused_data.

        :param bytes: a byte string; may be empty.
        """
        # accept_bytes is allowed to change the state
        current_state = self.state_accept
        self.state_accept(bytes)
        # Each time the state changes, poke the new state with no new data so
        # that it can process whatever the previous state left buffered.
        while current_state != self.state_accept:
            current_state = self.state_accept
            self.state_accept('')

    def next_read_size(self):
        """Return a suggested number of bytes to read next."""
        if self.bytes_left is not None:
            # Ideally we want to read all the remainder of the body and the
            # trailer in one go.
            return self.bytes_left + 5
        elif self.state_accept == self._state_accept_reading_trailer:
            # Just the trailer left
            return 5 - len(self._trailer_buffer)
        elif self.state_accept == self._state_accept_expecting_length:
            # There's still at least 6 bytes left ('\n' to end the length, plus
            # 'done\n').
            return 6
        else:
            # Reading excess data.  Either way, 1 byte at a time is fine.
            return 1

    def read_pending_data(self):
        """Return any pending data that has been decoded."""
        return self.state_read()

    def _state_accept_expecting_length(self, bytes):
        """Buffer bytes until the length line is complete, then parse it."""
        self._in_buffer += bytes
        pos = self._in_buffer.find('\n')
        if pos == -1:
            # The length line is not complete yet.
            return
        self.bytes_left = int(self._in_buffer[:pos])
        self._in_buffer = self._in_buffer[pos+1:]
        # Anything that followed the length line already counts as body; a
        # negative result means trailer/excess bytes arrived as well.
        self.bytes_left -= len(self._in_buffer)
        self.state_accept = self._state_accept_reading_body
        self.state_read = self._state_read_in_buffer

    def _state_accept_reading_body(self, bytes):
        """Accumulate body bytes; move on to the trailer once all arrived."""
        self._in_buffer += bytes
        self.bytes_left -= len(bytes)
        if self.bytes_left <= 0:
            # Finished with body
            if self.bytes_left != 0:
                # bytes_left is negative here: the last -bytes_left bytes of
                # the buffer lie beyond the body and belong to the trailer.
                self._trailer_buffer = self._in_buffer[self.bytes_left:]
                self._in_buffer = self._in_buffer[:self.bytes_left]
            self.bytes_left = None
            self.state_accept = self._state_accept_reading_trailer

    def _state_accept_reading_trailer(self, bytes):
        """Accumulate trailer bytes until the 'done' line has been seen."""
        self._trailer_buffer += bytes
        # TODO: what if the trailer does not match "done\n"?  Should this raise
        # a ProtocolViolation exception?
        if self._trailer_buffer.startswith('done\n'):
            self.unused_data = self._trailer_buffer[len('done\n'):]
            self.state_accept = self._state_accept_reading_unused
            self.finished_reading = True

    def _state_accept_reading_unused(self, bytes):
        """Collect everything received after the trailer."""
        self.unused_data += bytes

    def _state_read_no_data(self):
        """Read handler used before any body data can be available."""
        return ''

    def _state_read_in_buffer(self):
        """Hand out the buffered body data and empty the buffer."""
        result = self._in_buffer
        self._in_buffer = ''
        return result
289
class SmartClientRequestProtocolOne(SmartProtocolBase):
    """The client-side protocol for smart version 1."""

    def __init__(self, request):
        """Construct a SmartClientRequestProtocolOne.

        :param request: A SmartClientMediumRequest to serialise onto and
            deserialise from.
        """
        self._request = request
        # Created on the first read_body_bytes call; holds the decoded body.
        self._body_buffer = None

    def call(self, *args):
        """Make a remote call of args with no body.

        After calling this, call read_response_tuple to find the result out.
        """
        self._write_args(args)
        self._request.finished_writing()

    def call_with_body_bytes(self, args, body):
        """Make a remote call of args with body bytes 'body'.

        After calling this, call read_response_tuple to find the result out.
        """
        self._write_args(args)
        encoded_body = self._encode_bulk_data(body)
        self._request.accept_bytes(encoded_body)
        self._request.finished_writing()

    def call_with_body_readv_array(self, args, body):
        """Make a remote call with a readv array.

        The body is encoded with one line per readv offset pair. The numbers in
        each pair are separated by a comma, and no trailing newline is emitted.
        """
        self._write_args(args)
        readv_bytes = self._serialise_offsets(body)
        encoded_body = self._encode_bulk_data(readv_bytes)
        self._request.accept_bytes(encoded_body)
        self._request.finished_writing()

    def cancel_read_body(self):
        """After expecting a body, a response code may indicate one otherwise.

        This method lets the domain client inform the protocol that no body
        will be transmitted. This is a terminal method: after calling it the
        protocol is not able to be used further.
        """
        self._request.finished_reading()

    def read_response_tuple(self, expect_body=False):
        """Read a response tuple from the wire.

        This should only be called once.

        :param expect_body: when false, the medium request is told that
            reading has finished as soon as the tuple has been read; when
            true that is left to read_body_bytes or cancel_read_body.
        :return: the decoded response tuple.
        """
        result = self._recv_tuple()
        if not expect_body:
            self._request.finished_reading()
        return result

    def read_body_bytes(self, count=-1):
        """Read bytes from the body, decoding into a byte stream.

        We read all bytes at once to ensure we've checked the trailer for
        errors, and then feed the buffer back as read_body_bytes is called.

        :param count: passed to the read() of the buffer holding the body.
        """
        if self._body_buffer is not None:
            return self._body_buffer.read(count)
        _body_decoder = LengthPrefixedBodyDecoder()

        while not _body_decoder.finished_reading:
            bytes_wanted = _body_decoder.next_read_size()
            received = self._request.read_bytes(bytes_wanted)
            _body_decoder.accept_bytes(received)
        self._request.finished_reading()
        self._body_buffer = StringIO(_body_decoder.read_pending_data())
        # XXX: TODO check the trailer result.
        return self._body_buffer.read(count)

    def _recv_tuple(self):
        """Receive a tuple from the medium request."""
        return _decode_tuple(self._recv_line())

    def _recv_line(self):
        """Read an entire line from the medium request."""
        line = ''
        while not line or line[-1] != '\n':
            # TODO: this is inefficient - but tuples are short.
            new_char = self._request.read_bytes(1)
            line += new_char
            assert new_char != '', "end of file reading from server."
        return line

    def query_version(self):
        """Return protocol version number of the server."""
        # NOTE(review): the request verb was restored from the upstream
        # implementation of the smart protocol; confirm it is 'hello'.
        self.call('hello')
        resp = self.read_response_tuple()
        if resp == ('ok', '1'):
            return 1
        elif resp == ('ok', '2'):
            return 2
        else:
            raise errors.SmartProtocolError("bad response %r" % (resp,))

    def _write_args(self, args):
        """Write the protocol prefix (if any) and the encoded args tuple."""
        self._write_protocol_version()
        encoded_args = _encode_tuple(args)
        self._request.accept_bytes(encoded_args)

    def _write_protocol_version(self):
        """Write any prefixes this protocol requires.

        Version one doesn't send protocol versions.
        """
402
class SmartClientRequestProtocolTwo(SmartClientRequestProtocolOne):
    """Version two of the client side of the smart protocol.

    This prefixes the request with the value of REQUEST_VERSION_TWO.
    """

    def read_response_tuple(self, expect_body=False):
        """Read a response tuple from the wire.

        This should only be called once.

        The response must start with RESPONSE_VERSION_TWO followed by a
        'success' or 'failed' line; the outcome is recorded as a boolean in
        self.response_status before the tuple itself is read.

        :raises SmartProtocolError: if the version marker or the status line
            is not one of the expected values.
        """
        version = self._request.read_line()
        if version != RESPONSE_VERSION_TWO:
            raise errors.SmartProtocolError('bad protocol marker %r' % version)
        response_status = self._recv_line()
        if response_status not in ('success\n', 'failed\n'):
            raise errors.SmartProtocolError(
                'bad protocol status %r' % response_status)
        self.response_status = response_status == 'success\n'
        return SmartClientRequestProtocolOne.read_response_tuple(
            self, expect_body)

    def _write_protocol_version(self):
        r"""Write any prefixes this protocol requires.

        Version two sends the value of REQUEST_VERSION_TWO.
        """
        self._request.accept_bytes(REQUEST_VERSION_TWO)