# Copyright (C) 2006, 2007 Canonical Ltd
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA

"""Wire-level encoding and decoding of requests and responses for the smart
client and server.
"""

import collections
from cStringIO import StringIO
import struct
import sys
import time

from bzrlib import debug
from bzrlib import errors
from bzrlib.smart import message, request
from bzrlib.trace import log_exception_quietly, mutter
from bzrlib.bencode import bdecode_as_tuple, bencode


# Protocol version strings.  These are sent as prefixes of bzr requests and
# responses to identify the protocol version being used. (There are no version
# one strings because that version doesn't send any).
REQUEST_VERSION_TWO = 'bzr request 2\n'
RESPONSE_VERSION_TWO = 'bzr response 2\n'

# Version three uses one marker string for both requests and responses.
MESSAGE_VERSION_THREE = 'bzr message 3 (bzr 1.6)\n'
RESPONSE_VERSION_THREE = REQUEST_VERSION_THREE = MESSAGE_VERSION_THREE


def _recv_tuple(from_file):
    """Read one line from from_file and decode it as an argument tuple.

    :param from_file: a file-like object providing readline().
    :return: whatever _decode_tuple returns for the line read.
    """
    req_line = from_file.readline()
    return _decode_tuple(req_line)


def _decode_tuple(req_line):
    """Decode a newline-terminated, \\x01-separated line into a tuple.

    :param req_line: the raw line including its trailing newline, or None.
    :return: a tuple of the separated fields, or None when req_line is None
        or the empty string (e.g. readline() at end of file).
    :raises SmartProtocolError: if the line does not end with a newline.
    """
    if not req_line:
        return None
    if req_line[-1] != '\n':
        raise errors.SmartProtocolError("request %r not terminated" % req_line)
    return tuple(req_line[:-1].split('\x01'))


def _encode_tuple(args):
    """Encode the tuple args to a bytestream.

    The elements are joined with '\\x01' and the result is terminated by a
    newline; this is the format that _decode_tuple parses.
    """
    return '\x01'.join(args) + '\n'


class Requester(object):
    """Abstract base class for an object that can issue requests on a smart
    medium.

    Every method here raises NotImplementedError; subclasses provide the
    actual encoding.
    """

    def call(self, *args):
        """Make a remote call.

        :param args: the arguments of this call.
        """
        raise NotImplementedError(self.call)

    def call_with_body_bytes(self, args, body):
        """Make a remote call with a body.

        :param args: the arguments of this call.
        :type body: str
        :param body: the body to send with the request.
        """
        raise NotImplementedError(self.call_with_body_bytes)

    def call_with_body_readv_array(self, args, body):
        """Make a remote call with a readv array.

        :param args: the arguments of this call.
        :type body: iterable of (start, length) tuples.
        :param body: the readv ranges to send with this request.
        """
        raise NotImplementedError(self.call_with_body_readv_array)

    def set_headers(self, headers):
        """Set the headers for this requester.

        :param headers: the headers to record; their use is up to subclasses.
        """
        raise NotImplementedError(self.set_headers)


class SmartProtocolBase(object):
98
"""Methods common to client and server"""
100
# TODO: this only actually accomodates a single block; possibly should
101
# support multiple chunks?
102
def _encode_bulk_data(self, body):
103
"""Encode body as a bulk data chunk."""
104
return ''.join(('%d\n' % len(body), body, 'done\n'))
106
def _serialise_offsets(self, offsets):
107
"""Serialise a readv offset list."""
109
for start, length in offsets:
110
txt.append('%d,%d' % (start, length))
111
return '\n'.join(txt)
114
class SmartServerRequestProtocolOne(SmartProtocolBase):
115
"""Server-side encoding and decoding logic for smart version 1."""
117
def __init__(self, backing_transport, write_func, root_client_path='/',
119
self._backing_transport = backing_transport
120
self._root_client_path = root_client_path
121
self._jail_root = jail_root
122
self.unused_data = ''
123
self._finished = False
125
self._has_dispatched = False
127
self._body_decoder = None
128
self._write_func = write_func
130
def accept_bytes(self, bytes):
131
"""Take bytes, and advance the internal state machine appropriately.
133
:param bytes: must be a byte string
135
if not isinstance(bytes, str):
136
raise ValueError(bytes)
137
self.in_buffer += bytes
138
if not self._has_dispatched:
139
if '\n' not in self.in_buffer:
140
# no command line yet
142
self._has_dispatched = True
144
first_line, self.in_buffer = self.in_buffer.split('\n', 1)
146
req_args = _decode_tuple(first_line)
147
self.request = request.SmartServerRequestHandler(
148
self._backing_transport, commands=request.request_handlers,
149
root_client_path=self._root_client_path,
150
jail_root=self._jail_root)
151
self.request.args_received(req_args)
152
if self.request.finished_reading:
154
self.unused_data = self.in_buffer
156
self._send_response(self.request.response)
157
except KeyboardInterrupt:
159
except errors.UnknownSmartMethod, err:
160
protocol_error = errors.SmartProtocolError(
161
"bad request %r" % (err.verb,))
162
failure = request.FailedSmartServerResponse(
163
('error', str(protocol_error)))
164
self._send_response(failure)
166
except Exception, exception:
167
# everything else: pass to client, flush, and quit
168
log_exception_quietly()
169
self._send_response(request.FailedSmartServerResponse(
170
('error', str(exception))))
173
if self._has_dispatched:
175
# nothing to do.XXX: this routine should be a single state
177
self.unused_data += self.in_buffer
180
if self._body_decoder is None:
181
self._body_decoder = LengthPrefixedBodyDecoder()
182
self._body_decoder.accept_bytes(self.in_buffer)
183
self.in_buffer = self._body_decoder.unused_data
184
body_data = self._body_decoder.read_pending_data()
185
self.request.accept_body(body_data)
186
if self._body_decoder.finished_reading:
187
self.request.end_of_body()
188
if not self.request.finished_reading:
189
raise AssertionError("no more body, request not finished")
190
if self.request.response is not None:
191
self._send_response(self.request.response)
192
self.unused_data = self.in_buffer
195
if self.request.finished_reading:
196
raise AssertionError(
197
"no response and we have finished reading.")
199
def _send_response(self, response):
200
"""Send a smart server response down the output stream."""
202
raise AssertionError('response already sent')
205
self._finished = True
206
self._write_protocol_version()
207
self._write_success_or_failure_prefix(response)
208
self._write_func(_encode_tuple(args))
210
if not isinstance(body, str):
211
raise ValueError(body)
212
bytes = self._encode_bulk_data(body)
213
self._write_func(bytes)
215
def _write_protocol_version(self):
216
"""Write any prefixes this protocol requires.
218
Version one doesn't send protocol versions.
221
def _write_success_or_failure_prefix(self, response):
222
"""Write the protocol specific success/failure prefix.
224
For SmartServerRequestProtocolOne this is omitted but we
225
call is_successful to ensure that the response is valid.
227
response.is_successful()
229
def next_read_size(self):
232
if self._body_decoder is None:
235
return self._body_decoder.next_read_size()
238
class SmartServerRequestProtocolTwo(SmartServerRequestProtocolOne):
    r"""Version two of the server side of the smart protocol.

    This prefixes responses with the value of RESPONSE_VERSION_TWO.
    """

    response_marker = RESPONSE_VERSION_TWO
    request_marker = REQUEST_VERSION_TWO

    def _write_success_or_failure_prefix(self, response):
        """Write the protocol specific success/failure prefix."""
        if response.is_successful():
            self._write_func('success\n')
        else:
            self._write_func('failed\n')

    def _write_protocol_version(self):
        r"""Write any prefixes this protocol requires.

        Version two sends the value of RESPONSE_VERSION_TWO.
        """
        self._write_func(self.response_marker)

    def _send_response(self, response):
        """Send a smart server response down the output stream.

        The body is written either as a bulk data chunk (response.body) or in
        the chunked encoding (response.body_stream); setting both is an error.

        :raises AssertionError: if a response was already sent, the body is
            not a str, or both body and body_stream are set.
        """
        if self._finished:
            raise AssertionError('response already sent')
        self._finished = True
        self._write_protocol_version()
        self._write_success_or_failure_prefix(response)
        self._write_func(_encode_tuple(response.args))
        if response.body is not None:
            if not isinstance(response.body, str):
                raise AssertionError('body must be a str')
            if not (response.body_stream is None):
                raise AssertionError(
                    'body_stream and body cannot both be set')
            bytes = self._encode_bulk_data(response.body)
            self._write_func(bytes)
        elif response.body_stream is not None:
            _send_stream(response.body_stream, self._write_func)


def _send_stream(stream, write_func):
    """Write stream to write_func using the chunked body encoding.

    The output is a 'chunked\\n' header, the chunks written by _send_chunks,
    and an 'END\\n' terminator (the end-of-body marker ChunkedBodyDecoder
    looks for).
    """
    write_func('chunked\n')
    _send_chunks(stream, write_func)
    write_func('END\n')


def _send_chunks(stream, write_func):
    """Write each item of stream to write_func as a chunk.

    A str chunk is written as its length in lowercase hex, a newline, and the
    chunk itself.  A FailedSmartServerResponse is written as an 'ERR\\n'
    marker followed by its args (each sent as a chunk), and ends the stream.

    :raises BzrError: for any other kind of chunk.
    """
    for chunk in stream:
        if isinstance(chunk, str):
            encoded = "%x\n%s" % (len(chunk), chunk)
            write_func(encoded)
        elif isinstance(chunk, request.FailedSmartServerResponse):
            write_func('ERR\n')
            _send_chunks(chunk.args, write_func)
            return
        else:
            # Wrap chunk in a 1-tuple: a bare tuple chunk would otherwise be
            # treated as the format arguments and raise TypeError instead.
            raise errors.BzrError(
                'Chunks must be str or FailedSmartServerResponse, got %r'
                % (chunk,))


class _NeedMoreBytes(Exception):
    """Raise this inside a _StatefulDecoder to stop decoding until more bytes
    have been received.
    """

    def __init__(self, count=None):
        """Constructor.

        :param count: the total number of bytes needed by the current state.
            May be None if the number of bytes needed is unknown.
        """
        self.count = count


class _StatefulDecoder(object):
317
"""Base class for writing state machines to decode byte streams.
319
Subclasses should provide a self.state_accept attribute that accepts bytes
320
and, if appropriate, updates self.state_accept to a different function.
321
accept_bytes will call state_accept as often as necessary to make sure the
322
state machine has progressed as far as possible before it returns.
324
See ProtocolThreeDecoder for an example subclass.
328
self.finished_reading = False
329
self._in_buffer_list = []
330
self._in_buffer_len = 0
331
self.unused_data = ''
332
self.bytes_left = None
333
self._number_needed_bytes = None
335
def _get_in_buffer(self):
336
if len(self._in_buffer_list) == 1:
337
return self._in_buffer_list[0]
338
in_buffer = ''.join(self._in_buffer_list)
339
if len(in_buffer) != self._in_buffer_len:
340
raise AssertionError(
341
"Length of buffer did not match expected value: %s != %s"
342
% self._in_buffer_len, len(in_buffer))
343
self._in_buffer_list = [in_buffer]
346
def _get_in_bytes(self, count):
347
"""Grab X bytes from the input_buffer.
349
Callers should have already checked that self._in_buffer_len is >
350
count. Note, this does not consume the bytes from the buffer. The
351
caller will still need to call _get_in_buffer() and then
352
_set_in_buffer() if they actually need to consume the bytes.
354
# check if we can yield the bytes from just the first entry in our list
355
if len(self._in_buffer_list) == 0:
356
raise AssertionError('Callers must be sure we have buffered bytes'
357
' before calling _get_in_bytes')
358
if len(self._in_buffer_list[0]) > count:
359
return self._in_buffer_list[0][:count]
360
# We can't yield it from the first buffer, so collapse all buffers, and
362
in_buf = self._get_in_buffer()
363
return in_buf[:count]
365
def _set_in_buffer(self, new_buf):
366
if new_buf is not None:
367
self._in_buffer_list = [new_buf]
368
self._in_buffer_len = len(new_buf)
370
self._in_buffer_list = []
371
self._in_buffer_len = 0
373
def accept_bytes(self, bytes):
374
"""Decode as much of bytes as possible.
376
If 'bytes' contains too much data it will be appended to
379
finished_reading will be set when no more data is required. Further
380
data will be appended to self.unused_data.
382
# accept_bytes is allowed to change the state
383
self._number_needed_bytes = None
384
# lsprof puts a very large amount of time on this specific call for
386
self._in_buffer_list.append(bytes)
387
self._in_buffer_len += len(bytes)
389
# Run the function for the current state.
390
current_state = self.state_accept
392
while current_state != self.state_accept:
393
# The current state has changed. Run the function for the new
394
# current state, so that it can:
395
# - decode any unconsumed bytes left in a buffer, and
396
# - signal how many more bytes are expected (via raising
398
current_state = self.state_accept
400
except _NeedMoreBytes, e:
401
self._number_needed_bytes = e.count
404
class ChunkedBodyDecoder(_StatefulDecoder):
    """Decoder for chunked body data.

    This is very similar the HTTP's chunked encoding.  See the description of
    streamed body data in `doc/developers/network-protocol.txt` for details.

    Decoded chunks are queued in self.chunks and retrieved with
    read_next_chunk.
    """

    def __init__(self):
        _StatefulDecoder.__init__(self)
        self.state_accept = self._state_accept_expecting_header
        self.chunk_in_progress = None
        self.chunks = collections.deque()
        # Set after an 'ERR' marker: later chunks become error arguments.
        self.error = False
        self.error_in_progress = None

    def next_read_size(self):
        """Return a lower bound on the number of bytes still expected."""
        # Note: the shortest possible chunk is 2 bytes: '0\n', and the
        # end-of-body marker is 4 bytes: 'END\n'.
        if self.state_accept == self._state_accept_reading_chunk:
            # We're expecting more chunk content.  So we're expecting at least
            # the rest of this chunk plus an END chunk.
            return self.bytes_left + 4
        elif self.state_accept == self._state_accept_expecting_length:
            if self._in_buffer_len == 0:
                # We're expecting a chunk length.  There's at least two bytes
                # left: a digit plus '\n'.
                return 2
            else:
                # We're in the middle of reading a chunk length.  So there's at
                # least one byte left, the '\n' that terminates the length.
                return 1
        elif self.state_accept == self._state_accept_reading_unused:
            return 1
        elif self.state_accept == self._state_accept_expecting_header:
            return max(0, len('chunked\n') - self._in_buffer_len)
        else:
            raise AssertionError("Impossible state: %r" % (self.state_accept,))

    def read_next_chunk(self):
        """Return the next decoded chunk, or None if none is queued."""
        try:
            return self.chunks.popleft()
        except IndexError:
            return None

    def _extract_line(self):
        """Consume and return the next buffered line, without its newline.

        :raises _NeedMoreBytes: if no complete line is buffered yet.
        """
        in_buf = self._get_in_buffer()
        pos = in_buf.find('\n')
        if pos == -1:
            # We haven't read a complete line yet, so request more bytes before
            # we continue.
            raise _NeedMoreBytes(1)
        line = in_buf[:pos]
        # Trim the prefix (including '\n' delimiter) from the _in_buffer.
        self._set_in_buffer(in_buf[pos+1:])
        return line

    def _finished(self):
        """Handle the end-of-body marker.

        Remaining buffered bytes become unused_data, and any error arguments
        collected since an 'ERR' marker are queued as a
        FailedSmartServerResponse.
        """
        self.unused_data = self._get_in_buffer()
        self._set_in_buffer(None)
        self.state_accept = self._state_accept_reading_unused
        if self.error:
            error_args = tuple(self.error_in_progress)
            self.chunks.append(request.FailedSmartServerResponse(error_args))
            self.error_in_progress = None
        self.finished_reading = True

    def _state_accept_expecting_header(self):
        prefix = self._extract_line()
        if prefix == 'chunked':
            self.state_accept = self._state_accept_expecting_length
        else:
            raise errors.SmartProtocolError(
                'Bad chunked body header: "%s"' % (prefix,))

    def _state_accept_expecting_length(self):
        prefix = self._extract_line()
        if prefix == 'ERR':
            self.error = True
            self.error_in_progress = []
            self._state_accept_expecting_length()
            return
        elif prefix == 'END':
            # We've read the end-of-body marker.
            # Any further bytes are unused data, including the bytes left in
            # the _in_buffer.
            self._finished()
            return
        else:
            # Chunk lengths are written in hex by _send_chunks.
            self.bytes_left = int(prefix, 16)
            self.chunk_in_progress = ''
            self.state_accept = self._state_accept_reading_chunk

    def _state_accept_reading_chunk(self):
        in_buf = self._get_in_buffer()
        in_buffer_len = len(in_buf)
        self.chunk_in_progress += in_buf[:self.bytes_left]
        self._set_in_buffer(in_buf[self.bytes_left:])
        self.bytes_left -= in_buffer_len
        if self.bytes_left <= 0:
            # Finished with chunk
            self.bytes_left = None
            if self.error:
                self.error_in_progress.append(self.chunk_in_progress)
            else:
                self.chunks.append(self.chunk_in_progress)
            self.chunk_in_progress = None
            self.state_accept = self._state_accept_expecting_length

    def _state_accept_reading_unused(self):
        self.unused_data += self._get_in_buffer()
        # Reset the byte count too, not just the list, so _in_buffer_len
        # stays consistent with the (now empty) buffer.
        self._set_in_buffer(None)


class LengthPrefixedBodyDecoder(_StatefulDecoder):
    """Decodes the length-prefixed bulk data.

    The expected input is a decimal length, a newline, that many body bytes,
    and a 'done\\n' trailer (the format SmartProtocolBase._encode_bulk_data
    produces).
    """

    def __init__(self):
        _StatefulDecoder.__init__(self)
        self.state_accept = self._state_accept_expecting_length
        self.state_read = self._state_read_no_data
        # Decoded body bytes not yet returned by read_pending_data.
        self._body = ''
        self._trailer_buffer = ''

    def next_read_size(self):
        """Return a hint for how many bytes to read next."""
        if self.bytes_left is not None:
            # Ideally we want to read all the remainder of the body and the
            # trailer in one go.
            return self.bytes_left + 5
        elif self.state_accept == self._state_accept_reading_trailer:
            # Just the trailer left
            return 5 - len(self._trailer_buffer)
        elif self.state_accept == self._state_accept_expecting_length:
            # There's still at least 6 bytes left ('\n' to end the length, plus
            # 'done\n').
            return 6
        else:
            # Reading excess data.  Either way, 1 byte at a time is fine.
            return 1

    def read_pending_data(self):
        """Return any pending data that has been decoded."""
        return self.state_read()

    def _state_accept_expecting_length(self):
        in_buf = self._get_in_buffer()
        pos = in_buf.find('\n')
        if pos == -1:
            # The length line is not complete yet.
            return
        self.bytes_left = int(in_buf[:pos])
        self._set_in_buffer(in_buf[pos+1:])
        self.state_accept = self._state_accept_reading_body
        self.state_read = self._state_read_body_buffer

    def _state_accept_reading_body(self):
        in_buf = self._get_in_buffer()
        self._body += in_buf
        self.bytes_left -= len(in_buf)
        self._set_in_buffer(None)
        if self.bytes_left <= 0:
            # Finished with body
            if self.bytes_left != 0:
                # We read past the end of the body: a negative bytes_left is
                # the number of trailer bytes that ended up in self._body.
                self._trailer_buffer = self._body[self.bytes_left:]
                self._body = self._body[:self.bytes_left]
            self.bytes_left = None
            self.state_accept = self._state_accept_reading_trailer

    def _state_accept_reading_trailer(self):
        self._trailer_buffer += self._get_in_buffer()
        self._set_in_buffer(None)
        # TODO: what if the trailer does not match "done\n"?  Should this raise
        # a ProtocolViolation exception?
        if self._trailer_buffer.startswith('done\n'):
            self.unused_data = self._trailer_buffer[len('done\n'):]
            self.state_accept = self._state_accept_reading_unused
            self.finished_reading = True

    def _state_accept_reading_unused(self):
        self.unused_data += self._get_in_buffer()
        self._set_in_buffer(None)

    def _state_read_no_data(self):
        return ''

    def _state_read_body_buffer(self):
        # Hand back everything decoded so far and start a fresh buffer.
        result = self._body
        self._body = ''
        return result


class SmartClientRequestProtocolOne(SmartProtocolBase, Requester,
            message.ResponseHandler):
    """The client-side protocol for smart version 1."""

    def __init__(self, request):
        """Construct a SmartClientRequestProtocolOne.

        :param request: A SmartClientMediumRequest to serialise onto and
            deserialise from.
        """
        self._request = request
        self._body_buffer = None
        self._request_start_time = None
        self._last_verb = None
        self._headers = None

    def set_headers(self, headers):
        """Store a dict() copy of headers.

        Nothing in this class writes the stored headers to the wire.
        """
        self._headers = dict(headers)

    def call(self, *args):
        """Send a request with no body; args[0] is recorded as the verb."""
        if 'hpss' in debug.debug_flags:
            mutter('hpss call: %s', repr(args)[1:-1])
            if getattr(self._request._medium, 'base', None) is not None:
                mutter(' (to %s)', self._request._medium.base)
            self._request_start_time = time.time()
        self._write_args(args)
        self._request.finished_writing()
        self._last_verb = args[0]

    def call_with_body_bytes(self, args, body):
        """Make a remote call of args with body bytes 'body'.

        After calling this, call read_response_tuple to find the result out.
        """
        if 'hpss' in debug.debug_flags:
            mutter('hpss call w/body: %s (%r...)', repr(args)[1:-1], body[:20])
            if getattr(self._request._medium, '_path', None) is not None:
                mutter(' (to %s)', self._request._medium._path)
            mutter(' %d bytes', len(body))
            self._request_start_time = time.time()
            if 'hpssdetail' in debug.debug_flags:
                mutter('hpss body content: %s', body)
        self._write_args(args)
        bytes = self._encode_bulk_data(body)
        self._request.accept_bytes(bytes)
        self._request.finished_writing()
        self._last_verb = args[0]

    def call_with_body_readv_array(self, args, body):
        """Make a remote call with a readv array.

        The body is encoded with one line per readv offset pair. The numbers in
        each pair are separated by a comma, and no trailing \\n is emitted.
        """
        if 'hpss' in debug.debug_flags:
            mutter('hpss call w/readv: %s', repr(args)[1:-1])
            if getattr(self._request._medium, '_path', None) is not None:
                mutter(' (to %s)', self._request._medium._path)
            self._request_start_time = time.time()
        self._write_args(args)
        readv_bytes = self._serialise_offsets(body)
        bytes = self._encode_bulk_data(readv_bytes)
        self._request.accept_bytes(bytes)
        self._request.finished_writing()
        if 'hpss' in debug.debug_flags:
            mutter(' %d bytes in readv request', len(readv_bytes))
        self._last_verb = args[0]

    def call_with_body_stream(self, args, stream):
        """Always raise UnknownSmartMethod(args[0]); see comment below."""
        # Protocols v1 and v2 don't support body streams.  So it's safe to
        # assume that a v1/v2 server doesn't support whatever method we're
        # trying to call with a body stream.
        self._request.finished_writing()
        self._request.finished_reading()
        raise errors.UnknownSmartMethod(args[0])

    def cancel_read_body(self):
        """After expecting a body, a response code may indicate one otherwise.

        This method lets the domain client inform the protocol that no body
        will be transmitted. This is a terminal method: after calling it the
        protocol is not able to be used further.
        """
        self._request.finished_reading()

    def _read_response_tuple(self):
        """Read and decode one response line, logging it under -Dhpss."""
        result = self._recv_tuple()
        if 'hpss' in debug.debug_flags:
            if self._request_start_time is not None:
                mutter(' result: %6.3fs %s',
                       time.time() - self._request_start_time,
                       repr(result)[1:-1])
                self._request_start_time = None
            else:
                mutter(' result: %s', repr(result)[1:-1])
        return result

    def read_response_tuple(self, expect_body=False):
        """Read a response tuple from the wire.

        This should only be called once.

        :param expect_body: if False, the request is marked as finished
            reading before returning.
        :raises UnknownSmartMethod: for an 'unknown method' error response.
        :raises ErrorFromSmartServer: for a known version 1 error response.
        """
        result = self._read_response_tuple()
        self._response_is_unknown_method(result)
        self._raise_args_if_error(result)
        if not expect_body:
            self._request.finished_reading()
        return result

    def _raise_args_if_error(self, result_tuple):
        # Later protocol versions have an explicit flag in the protocol to say
        # if an error response is "failed" or not.  In version 1 we don't have
        # that luxury.  So here is a complete list of errors that can be
        # returned in response to existing version 1 smart requests.  Responses
        # starting with these codes are always "failed" responses.
        v1_error_codes = [
            'norepository',
            'NoSuchFile',
            'FileExists',
            'DirectoryNotEmpty',
            'ShortReadvError',
            'UnicodeEncodeError',
            'UnicodeDecodeError',
            'ReadOnlyError',
            'nobranch',
            'NoSuchRevision',
            'nosuchrevision',
            'LockContention',
            'UnlockableTransport',
            'LockFailed',
            'TokenMismatch',
            'ReadError',
            'PermissionDenied',
            ]
        if result_tuple[0] in v1_error_codes:
            self._request.finished_reading()
            raise errors.ErrorFromSmartServer(result_tuple)

    def _response_is_unknown_method(self, result_tuple):
        """Raise UnknownSmartMethod if the response is an 'unknown method'
        response to the request.

        An unknown verb is reported by these servers as a generic 'bad
        request' protocol error that names the verb, so result_tuple is
        compared with both the plain and the unicode-repr form of
        self._last_verb.

        :param result_tuple: the decoded response tuple.
        :raises: UnknownSmartMethod
        """
        if (result_tuple == ('error', "Generic bzr smart protocol error: "
                "bad request '%s'" % self._last_verb) or
              result_tuple == ('error', "Generic bzr smart protocol error: "
                "bad request u'%s'" % self._last_verb)):
            # The response will have no body, so we've finished reading.
            self._request.finished_reading()
            raise errors.UnknownSmartMethod(self._last_verb)

    def read_body_bytes(self, count=-1):
        """Read bytes from the body, decoding into a byte stream.

        We read all bytes at once to ensure we've checked the trailer for
        errors, and then feed the buffer back as read_body_bytes is called.

        :raises ConnectionReset: if the server closes the connection before
            the whole body has been read.
        """
        if self._body_buffer is not None:
            return self._body_buffer.read(count)
        _body_decoder = LengthPrefixedBodyDecoder()

        while not _body_decoder.finished_reading:
            bytes = self._request.read_bytes(_body_decoder.next_read_size())
            if bytes == '':
                # end of file encountered reading from server
                raise errors.ConnectionReset(
                    "Connection lost while reading response body.")
            _body_decoder.accept_bytes(bytes)
        self._request.finished_reading()
        self._body_buffer = StringIO(_body_decoder.read_pending_data())
        # XXX: TODO check the trailer result.
        if 'hpss' in debug.debug_flags:
            mutter(' %d body bytes read',
                   len(self._body_buffer.getvalue()))
        return self._body_buffer.read(count)

    def _recv_tuple(self):
        """Receive a tuple from the medium request."""
        return _decode_tuple(self._request.read_line())

    def query_version(self):
        """Return protocol version number of the server.

        :raises SmartProtocolError: if the 'hello' response is not
            ('ok', '1') or ('ok', '2').
        """
        self.call('hello')
        resp = self.read_response_tuple()
        if resp == ('ok', '1'):
            return 1
        elif resp == ('ok', '2'):
            return 2
        else:
            raise errors.SmartProtocolError("bad response %r" % (resp,))

    def _write_args(self, args):
        """Write the protocol prefix and the encoded request arguments."""
        self._write_protocol_version()
        bytes = _encode_tuple(args)
        self._request.accept_bytes(bytes)

    def _write_protocol_version(self):
        """Write any prefixes this protocol requires.

        Version one doesn't send protocol versions.
        """


class SmartClientRequestProtocolTwo(SmartClientRequestProtocolOne):
    """Version two of the client side of the smart protocol.

    This prefixes the request with the value of REQUEST_VERSION_TWO.
    """

    response_marker = RESPONSE_VERSION_TWO
    request_marker = REQUEST_VERSION_TWO

    def read_response_tuple(self, expect_body=False):
        """Read a response tuple from the wire.

        This should only be called once.

        :param expect_body: if False, a successful response marks the request
            as finished reading before returning.
        :raises UnexpectedProtocolVersionMarker: if the response does not
            start with RESPONSE_VERSION_TWO.
        :raises ErrorFromSmartServer: for a 'failed' response.
        :raises SmartProtocolError: for an unrecognised status line.
        """
        version = self._request.read_line()
        if version != self.response_marker:
            self._request.finished_reading()
            raise errors.UnexpectedProtocolVersionMarker(version)
        response_status = self._request.read_line()
        result = SmartClientRequestProtocolOne._read_response_tuple(self)
        self._response_is_unknown_method(result)
        if response_status == 'success\n':
            self.response_status = True
            if not expect_body:
                self._request.finished_reading()
            return result
        elif response_status == 'failed\n':
            self.response_status = False
            self._request.finished_reading()
            raise errors.ErrorFromSmartServer(result)
        else:
            raise errors.SmartProtocolError(
                'bad protocol status %r' % response_status)

    def _write_protocol_version(self):
        """Write any prefixes this protocol requires.

        Version two sends the value of REQUEST_VERSION_TWO.
        """
        self._request.accept_bytes(self.request_marker)

    def read_streamed_body(self):
        """Read bytes from the body, decoding into a byte stream.

        This is a generator: it yields each item the ChunkedBodyDecoder
        produces as soon as it is decoded.  Items that are not str (the
        decoder's error responses) are yielded as well.

        :raises ConnectionReset: if the connection is closed mid-body.
        """
        # Read no more than 64k at a time so that we don't risk error 10055 (no
        # buffer space available) on Windows.
        # NOTE(review): the read size below comes from next_read_size(); any
        # 64k cap presumably lives in the medium's read_bytes -- confirm.
        _body_decoder = ChunkedBodyDecoder()
        while not _body_decoder.finished_reading:
            bytes = self._request.read_bytes(_body_decoder.next_read_size())
            if bytes == '':
                # end of file encountered reading from server
                raise errors.ConnectionReset(
                    "Connection lost while reading streamed body.")
            _body_decoder.accept_bytes(bytes)
            for body_bytes in iter(_body_decoder.read_next_chunk, None):
                if 'hpss' in debug.debug_flags and type(body_bytes) is str:
                    mutter(' %d byte chunk read',
                           len(body_bytes))
                yield body_bytes
        self._request.finished_reading()


def build_server_protocol_three(backing_transport, write_func,
                                root_client_path, jail_root=None):
    """Build a server-side ProtocolThreeDecoder.

    The decoder's message handler is a ConventionalRequestHandler combining a
    SmartServerRequestHandler (dispatching to request.request_handlers) with
    a ProtocolThreeResponder that writes through write_func.
    """
    request_handler = request.SmartServerRequestHandler(
        backing_transport, commands=request.request_handlers,
        root_client_path=root_client_path, jail_root=jail_root)
    responder = ProtocolThreeResponder(write_func)
    message_handler = message.ConventionalRequestHandler(request_handler, responder)
    return ProtocolThreeDecoder(message_handler)


class ProtocolThreeDecoder(_StatefulDecoder):
875
response_marker = RESPONSE_VERSION_THREE
876
request_marker = REQUEST_VERSION_THREE
878
def __init__(self, message_handler, expect_version_marker=False):
879
_StatefulDecoder.__init__(self)
880
self._has_dispatched = False
882
if expect_version_marker:
883
self.state_accept = self._state_accept_expecting_protocol_version
884
# We're expecting at least the protocol version marker + some
886
self._number_needed_bytes = len(MESSAGE_VERSION_THREE) + 4
888
self.state_accept = self._state_accept_expecting_headers
889
self._number_needed_bytes = 4
890
self.decoding_failed = False
891
self.request_handler = self.message_handler = message_handler
893
def accept_bytes(self, bytes):
894
self._number_needed_bytes = None
896
_StatefulDecoder.accept_bytes(self, bytes)
897
except KeyboardInterrupt:
899
except errors.SmartMessageHandlerError, exception:
900
# We do *not* set self.decoding_failed here. The message handler
901
# has raised an error, but the decoder is still able to parse bytes
902
# and determine when this message ends.
903
if not isinstance(exception.exc_value, errors.UnknownSmartMethod):
904
log_exception_quietly()
905
self.message_handler.protocol_error(exception.exc_value)
906
# The state machine is ready to continue decoding, but the
907
# exception has interrupted the loop that runs the state machine.
908
# So we call accept_bytes again to restart it.
909
self.accept_bytes('')
910
except Exception, exception:
911
# The decoder itself has raised an exception. We cannot continue
913
self.decoding_failed = True
914
if isinstance(exception, errors.UnexpectedProtocolVersionMarker):
915
# This happens during normal operation when the client tries a
916
# protocol version the server doesn't understand, so no need to
917
# log a traceback every time.
918
# Note that this can only happen when
919
# expect_version_marker=True, which is only the case on the
923
log_exception_quietly()
924
self.message_handler.protocol_error(exception)
926
def _extract_length_prefixed_bytes(self):
927
if self._in_buffer_len < 4:
928
# A length prefix by itself is 4 bytes, and we don't even have that
930
raise _NeedMoreBytes(4)
931
(length,) = struct.unpack('!L', self._get_in_bytes(4))
932
end_of_bytes = 4 + length
933
if self._in_buffer_len < end_of_bytes:
934
# We haven't yet read as many bytes as the length-prefix says there
936
raise _NeedMoreBytes(end_of_bytes)
937
# Extract the bytes from the buffer.
938
in_buf = self._get_in_buffer()
939
bytes = in_buf[4:end_of_bytes]
940
self._set_in_buffer(in_buf[end_of_bytes:])
943
def _extract_prefixed_bencoded_data(self):
944
prefixed_bytes = self._extract_length_prefixed_bytes()
946
decoded = bdecode_as_tuple(prefixed_bytes)
948
raise errors.SmartProtocolError(
949
'Bytes %r not bencoded' % (prefixed_bytes,))
952
def _extract_single_byte(self):
953
if self._in_buffer_len == 0:
954
# The buffer is empty
955
raise _NeedMoreBytes(1)
956
in_buf = self._get_in_buffer()
958
self._set_in_buffer(in_buf[1:])
961
def _state_accept_expecting_protocol_version(self):
962
needed_bytes = len(MESSAGE_VERSION_THREE) - self._in_buffer_len
963
in_buf = self._get_in_buffer()
965
# We don't have enough bytes to check if the protocol version
966
# marker is right. But we can check if it is already wrong by
967
# checking that the start of MESSAGE_VERSION_THREE matches what
969
# [In fact, if the remote end isn't bzr we might never receive
970
# len(MESSAGE_VERSION_THREE) bytes. So if the bytes we have so far
971
# are wrong then we should just raise immediately rather than
973
if not MESSAGE_VERSION_THREE.startswith(in_buf):
974
# We have enough bytes to know the protocol version is wrong
975
raise errors.UnexpectedProtocolVersionMarker(in_buf)
976
raise _NeedMoreBytes(len(MESSAGE_VERSION_THREE))
977
if not in_buf.startswith(MESSAGE_VERSION_THREE):
978
raise errors.UnexpectedProtocolVersionMarker(in_buf)
979
self._set_in_buffer(in_buf[len(MESSAGE_VERSION_THREE):])
980
self.state_accept = self._state_accept_expecting_headers
982
def _state_accept_expecting_headers(self):
983
decoded = self._extract_prefixed_bencoded_data()
984
if type(decoded) is not dict:
985
raise errors.SmartProtocolError(
986
'Header object %r is not a dict' % (decoded,))
987
self.state_accept = self._state_accept_expecting_message_part
989
self.message_handler.headers_received(decoded)
991
raise errors.SmartMessageHandlerError(sys.exc_info())
993
def _state_accept_expecting_message_part(self):
994
message_part_kind = self._extract_single_byte()
995
if message_part_kind == 'o':
996
self.state_accept = self._state_accept_expecting_one_byte
997
elif message_part_kind == 's':
998
self.state_accept = self._state_accept_expecting_structure
999
elif message_part_kind == 'b':
1000
self.state_accept = self._state_accept_expecting_bytes
1001
elif message_part_kind == 'e':
1004
raise errors.SmartProtocolError(
1005
'Bad message kind byte: %r' % (message_part_kind,))
1007
def _state_accept_expecting_one_byte(self):
1008
byte = self._extract_single_byte()
1009
self.state_accept = self._state_accept_expecting_message_part
1011
self.message_handler.byte_part_received(byte)
1013
raise errors.SmartMessageHandlerError(sys.exc_info())
1015
def _state_accept_expecting_bytes(self):
    """Consume a length-prefixed bytes part and pass it to the handler.

    The decoder returns to expecting the next message part before the
    handler's ``bytes_part_received`` is called.

    :raises errors.SmartMessageHandlerError: wrapping the ``sys.exc_info()``
        of anything the message handler raised.
    """
    # XXX: this should not buffer whole message part, but instead deliver
    # the bytes as they arrive.
    prefixed_bytes = self._extract_length_prefixed_bytes()
    self.state_accept = self._state_accept_expecting_message_part
    try:
        self.message_handler.bytes_part_received(prefixed_bytes)
    except:
        # Deliberately bare: every handler failure is wrapped.
        raise errors.SmartMessageHandlerError(sys.exc_info())
def _state_accept_expecting_structure(self):
    """Consume a bencoded structure part and pass it to the handler.

    The decoder returns to expecting the next message part before the
    handler's ``structure_part_received`` is called.

    :raises errors.SmartMessageHandlerError: wrapping the ``sys.exc_info()``
        of anything the message handler raised.
    """
    structure = self._extract_prefixed_bencoded_data()
    self.state_accept = self._state_accept_expecting_message_part
    try:
        self.message_handler.structure_part_received(structure)
    except:
        # Deliberately bare: every handler failure is wrapped.
        raise errors.SmartMessageHandlerError(sys.exc_info())
def done(self):
    """Finish decoding the current message.

    Whatever is left in the input buffer is kept as ``self.unused_data``.
    The buffer is cleared and the decoder switches to the state that only
    accumulates unused data. Then the handler's ``end_received`` is called.

    :raises errors.SmartMessageHandlerError: wrapping the ``sys.exc_info()``
        of anything the message handler raised.
    """
    self.unused_data = self._get_in_buffer()
    self._set_in_buffer(None)
    self.state_accept = self._state_accept_reading_unused
    try:
        self.message_handler.end_received()
    except:
        # Deliberately bare: every handler failure is wrapped.
        raise errors.SmartMessageHandlerError(sys.exc_info())
def _state_accept_reading_unused(self):
    """Append any newly buffered input to ``self.unused_data``.

    This is the state entered after the end of a message. Input received
    from then on is not decoded, only accumulated, and the buffer is cleared.
    """
    self.unused_data += self._get_in_buffer()
    self._set_in_buffer(None)
def next_read_size(self):
    """Return how many more bytes the decoder wants to read.

    :return: 0 once the message has ended or decoding has failed;
        otherwise the number of bytes still needed to satisfy the current
        state (``_number_needed_bytes`` minus what is already buffered).
    :raises AssertionError: if the decoder does not know how many bytes it
        needs.
    """
    # Compare with ==, not `is`: every attribute access creates a new
    # bound-method object, so an identity test would never match.
    if self.state_accept == self._state_accept_reading_unused:
        return 0
    if self.decoding_failed:
        # An exception occurred while processing this message, probably from
        # self.message_handler. We're not sure that this state machine is in
        # a consistent state, so just signal that we're done (i.e. give up).
        return 0
    if self._number_needed_bytes is None:
        raise AssertionError("don't know how many bytes are expected!")
    return self._number_needed_bytes - self._in_buffer_len
class _ProtocolThreeEncoder(object):
    """Serialisation helpers shared by protocol-3 requesters and responders.

    Output is collected in ``self._buf``. It is passed to the ``write_func``
    given to the constructor as a single joined string when more than 100
    chunks are pending, when ``flush`` is called, or when the end-of-message
    marker is written.
    """

    response_marker = request_marker = MESSAGE_VERSION_THREE

    def __init__(self, write_func):
        self._buf = []
        self._real_write_func = write_func

    def _write_func(self, bytes):
        # Buffer writes: many parts are only a few bytes long ('oS', 'e',
        # 4-byte length prefixes), so writing each one directly is wasteful.
        self._buf.append(bytes)
        if len(self._buf) > 100:
            self.flush()

    def flush(self):
        """Pass all buffered chunks to the real write function, if any."""
        if self._buf:
            self._real_write_func(''.join(self._buf))
            del self._buf[:]

    def _serialise_offsets(self, offsets):
        """Serialise a readv offset list.

        Each ``(start, length)`` pair becomes a ``'start,length'`` line.
        Lines are joined with ``'\\n'`` and there is no trailing newline.
        """
        return '\n'.join(['%d,%d' % (start, length)
                          for start, length in offsets])

    def _write_protocol_version(self):
        """Write the protocol-3 version marker."""
        self._write_func(MESSAGE_VERSION_THREE)

    def _write_prefixed_bencode(self, structure):
        """Write ``structure`` bencoded, preceded by its length.

        The length is packed with ``'!L'``: 4 bytes, big-endian, unsigned.
        """
        bytes = bencode(structure)
        self._write_func(struct.pack('!L', len(bytes)))
        self._write_func(bytes)

    def _write_headers(self, headers):
        """Write the headers as a length-prefixed bencoded value."""
        self._write_prefixed_bencode(headers)

    def _write_structure(self, args):
        """Write a structure part: ``'s'`` then the prefixed bencoded args.

        Unicode arguments are encoded to UTF-8 first. Other arguments are
        passed to bencode unchanged.
        """
        self._write_func('s')
        utf8_args = []
        for arg in args:
            if type(arg) is unicode:
                utf8_args.append(arg.encode('utf8'))
            else:
                utf8_args.append(arg)
        self._write_prefixed_bencode(utf8_args)

    def _write_end(self):
        """Write the end-of-message marker ``'e'`` and flush the buffer."""
        self._write_func('e')
        self.flush()

    def _write_prefixed_body(self, bytes):
        """Write a bytes part: ``'b'``, a 4-byte length, then ``bytes``."""
        self._write_func('b')
        self._write_func(struct.pack('!L', len(bytes)))
        self._write_func(bytes)

    def _write_chunked_body_start(self):
        """Write the one-byte part ``'C'``."""
        self._write_func('oC')

    def _write_error_status(self):
        """Write the one-byte error status part ``'E'``."""
        self._write_func('oE')

    def _write_success_status(self):
        """Write the one-byte success status part ``'S'``."""
        self._write_func('oS')
class ProtocolThreeResponder(_ProtocolThreeEncoder):
    """Serialises exactly one protocol-3 response (or error) to write_func.

    Every response carries a ``'Software version'`` header with
    ``bzrlib.__version__``.
    """

    def __init__(self, write_func):
        _ProtocolThreeEncoder.__init__(self, write_func)
        self.response_sent = False
        self._headers = {'Software version': bzrlib.__version__}

    def send_error(self, exception):
        """Send ``exception`` to the peer as an error response.

        ``errors.UnknownSmartMethod`` is sent as a failed response with args
        ``('UnknownMethod', verb)``. Any other exception is sent as an error
        status followed by ``('error', str(exception))``.

        :raises AssertionError: if a response has already been sent.
        """
        if self.response_sent:
            raise AssertionError(
                "send_error(%s) called, but response already sent."
                % (exception,))
        if isinstance(exception, errors.UnknownSmartMethod):
            failure = request.FailedSmartServerResponse(
                ('UnknownMethod', exception.verb))
            # send_response writes the whole message and marks it sent, so
            # nothing more may be written here.
            self.send_response(failure)
            return
        self.response_sent = True
        self._write_protocol_version()
        self._write_headers(self._headers)
        self._write_error_status()
        self._write_structure(('error', str(exception)))
        self._write_end()

    def send_response(self, response):
        """Send ``response`` to the peer.

        The response is written as a success or error status (per
        ``response.is_successful()``) and its args structure. Then come
        either ``response.body`` as one bytes part, or each chunk of
        ``response.body_stream`` as its own bytes part. A failure while
        iterating the stream, or a ``FailedSmartServerResponse`` chunk, ends
        the stream with an error status and an error structure.

        :raises AssertionError: if a response has already been sent.
        """
        if self.response_sent:
            raise AssertionError(
                "send_response(%r) called, but response already sent."
                % (response,))
        self.response_sent = True
        self._write_protocol_version()
        self._write_headers(self._headers)
        if response.is_successful():
            self._write_success_status()
        else:
            self._write_error_status()
        self._write_structure(response.args)
        if response.body is not None:
            self._write_prefixed_body(response.body)
        elif response.body_stream is not None:
            for exc_info, chunk in _iter_with_errors(response.body_stream):
                if exc_info is not None:
                    # Producing the next chunk raised: report it in-band
                    # and stop streaming.
                    self._write_error_status()
                    error_struct = request._translate_error(exc_info[1])
                    self._write_structure(error_struct)
                    break
                if isinstance(chunk, request.FailedSmartServerResponse):
                    # The stream yielded a failure instead of bytes.
                    self._write_error_status()
                    self._write_structure(chunk.args)
                    break
                self._write_prefixed_body(chunk)
        self._write_end()
def _iter_with_errors(iterable):
    """Handle errors from iterable.next().

    Use like::

        for exc_info, value in _iter_with_errors(iterable):
            ...

    This is a safer alternative to::

        try:
            for value in iterable:
               ...
        except:
            ...

    Because the latter will catch errors from the for-loop body, not just
    iterable.next()

    If an error occurs, exc_info will be an exc_info tuple, and the generator
    will terminate. Otherwise exc_info will be None, and value will be the
    value from iterable.next(). Note that KeyboardInterrupt and SystemExit
    will not be intercepted.
    """
    iterator = iter(iterable)
    while True:
        try:
            yield None, iterator.next()
        except StopIteration:
            return
        except (KeyboardInterrupt, SystemExit):
            raise
        except Exception:
            mutter('_iter_with_errors caught error')
            log_exception_quietly()
            # Report the failure once, then stop: the iterator's state is
            # unknown after it has raised.
            yield sys.exc_info(), None
            return
class ProtocolThreeRequester(_ProtocolThreeEncoder, Requester):
1223
def __init__(self, medium_request):
1224
_ProtocolThreeEncoder.__init__(self, medium_request.accept_bytes)
1225
self._medium_request = medium_request
1228
def set_headers(self, headers):
1229
self._headers = headers.copy()
1231
def call(self, *args):
1232
if 'hpss' in debug.debug_flags:
1233
mutter('hpss call: %s', repr(args)[1:-1])
1234
base = getattr(self._medium_request._medium, 'base', None)
1235
if base is not None:
1236
mutter(' (to %s)', base)
1237
self._request_start_time = time.time()
1238
self._write_protocol_version()
1239
self._write_headers(self._headers)
1240
self._write_structure(args)
1242
self._medium_request.finished_writing()
1244
def call_with_body_bytes(self, args, body):
1245
"""Make a remote call of args with body bytes 'body'.
1247
After calling this, call read_response_tuple to find the result out.
1249
if 'hpss' in debug.debug_flags:
1250
mutter('hpss call w/body: %s (%r...)', repr(args)[1:-1], body[:20])
1251
path = getattr(self._medium_request._medium, '_path', None)
1252
if path is not None:
1253
mutter(' (to %s)', path)
1254
mutter(' %d bytes', len(body))
1255
self._request_start_time = time.time()
1256
self._write_protocol_version()
1257
self._write_headers(self._headers)
1258
self._write_structure(args)
1259
self._write_prefixed_body(body)
1261
self._medium_request.finished_writing()
1263
def call_with_body_readv_array(self, args, body):
1264
"""Make a remote call with a readv array.
1266
The body is encoded with one line per readv offset pair. The numbers in
1267
each pair are separated by a comma, and no trailing \n is emitted.
1269
if 'hpss' in debug.debug_flags:
1270
mutter('hpss call w/readv: %s', repr(args)[1:-1])
1271
path = getattr(self._medium_request._medium, '_path', None)
1272
if path is not None:
1273
mutter(' (to %s)', path)
1274
self._request_start_time = time.time()
1275
self._write_protocol_version()
1276
self._write_headers(self._headers)
1277
self._write_structure(args)
1278
readv_bytes = self._serialise_offsets(body)
1279
if 'hpss' in debug.debug_flags:
1280
mutter(' %d bytes in readv request', len(readv_bytes))
1281
self._write_prefixed_body(readv_bytes)
1283
self._medium_request.finished_writing()
1285
def call_with_body_stream(self, args, stream):
1286
if 'hpss' in debug.debug_flags:
1287
mutter('hpss call w/body stream: %r', args)
1288
path = getattr(self._medium_request._medium, '_path', None)
1289
if path is not None:
1290
mutter(' (to %s)', path)
1291
self._request_start_time = time.time()
1292
self._write_protocol_version()
1293
self._write_headers(self._headers)
1294
self._write_structure(args)
1295
# TODO: notice if the server has sent an early error reply before we
1296
# have finished sending the stream. We would notice at the end
1297
# anyway, but if the medium can deliver it early then it's good
1298
# to short-circuit the whole request...
1299
for exc_info, part in _iter_with_errors(stream):
1300
if exc_info is not None:
1301
# Iterating the stream failed. Cleanly abort the request.
1302
self._write_error_status()
1303
# Currently the client unconditionally sends ('error',) as the
1305
self._write_structure(('error',))
1307
self._medium_request.finished_writing()
1308
raise exc_info[0], exc_info[1], exc_info[2]
1310
self._write_prefixed_body(part)
1313
self._medium_request.finished_writing()