# Copyright (C) 2006-2010 Canonical Ltd
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
"""Wire-level encoding and decoding of requests and responses for the smart
client and server.
"""
22
import sys
from cStringIO import StringIO

from bzrlib.smart import message, request
from bzrlib.trace import log_exception_quietly, mutter
from bzrlib.bencode import bdecode_as_tuple, bencode
40
# Protocol version strings.  These are sent as prefixes of bzr requests and
# responses to identify the protocol version being used. (There are no version
# one strings because that version doesn't send any).
REQUEST_VERSION_TWO = 'bzr request 2\n'
RESPONSE_VERSION_TWO = 'bzr response 2\n'

# Version three uses one marker for both directions, so the request and
# response names are aliases of the same string.
MESSAGE_VERSION_THREE = 'bzr message 3 (bzr 1.6)\n'
RESPONSE_VERSION_THREE = REQUEST_VERSION_THREE = MESSAGE_VERSION_THREE
50
def _recv_tuple(from_file):
    """Read one line from from_file and decode it as a tuple.

    :param from_file: an object providing a readline() method.
    :return: the result of passing that line to _decode_tuple.
    """
    req_line = from_file.readline()
    return _decode_tuple(req_line)
55
def _decode_tuple(req_line):
56
if req_line is None or req_line == '':
58
if req_line[-1] != '\n':
59
raise errors.SmartProtocolError("request %r not terminated" % req_line)
60
return tuple(req_line[:-1].split('\x01'))
63
def _encode_tuple(args):
64
"""Encode the tuple args to a bytestream."""
65
return '\x01'.join(args) + '\n'
68
class Requester(object):
    """Abstract base class for an object that can issue smart requests.

    Every method raises NotImplementedError; subclasses provide the real
    behaviour.
    """

    def call(self, *args):
        """Make a remote call.

        :param args: the arguments of this call.
        """
        raise NotImplementedError(self.call)

    def call_with_body_bytes(self, args, body):
        """Make a remote call with a body.

        :param args: the arguments of this call.
        :param body: the body to send with the request.
        """
        raise NotImplementedError(self.call_with_body_bytes)

    def call_with_body_readv_array(self, args, body):
        """Make a remote call with a readv array.

        :param args: the arguments of this call.
        :type body: iterable of (start, length) tuples.
        :param body: the readv ranges to send with this request.
        """
        raise NotImplementedError(self.call_with_body_readv_array)

    def set_headers(self, headers):
        """Set the headers to use for requests.

        :param headers: the headers; their form is defined by subclasses.
        """
        raise NotImplementedError(self.set_headers)
102
class SmartProtocolBase(object):
103
"""Methods common to client and server"""
105
# TODO: this only actually accomodates a single block; possibly should
106
# support multiple chunks?
107
def _encode_bulk_data(self, body):
108
"""Encode body as a bulk data chunk."""
109
return ''.join(('%d\n' % len(body), body, 'done\n'))
111
def _serialise_offsets(self, offsets):
112
"""Serialise a readv offset list."""
114
for start, length in offsets:
115
txt.append('%d,%d' % (start, length))
116
return '\n'.join(txt)
119
class SmartServerRequestProtocolOne(SmartProtocolBase):
120
"""Server-side encoding and decoding logic for smart version 1."""
122
def __init__(self, backing_transport, write_func, root_client_path='/',
124
self._backing_transport = backing_transport
125
self._root_client_path = root_client_path
126
self._jail_root = jail_root
127
self.unused_data = ''
128
self._finished = False
130
self._has_dispatched = False
132
self._body_decoder = None
133
self._write_func = write_func
135
def accept_bytes(self, bytes):
136
"""Take bytes, and advance the internal state machine appropriately.
138
:param bytes: must be a byte string
140
if not isinstance(bytes, str):
141
raise ValueError(bytes)
142
self.in_buffer += bytes
143
if not self._has_dispatched:
144
if '\n' not in self.in_buffer:
145
# no command line yet
147
self._has_dispatched = True
149
first_line, self.in_buffer = self.in_buffer.split('\n', 1)
151
req_args = _decode_tuple(first_line)
152
self.request = request.SmartServerRequestHandler(
153
self._backing_transport, commands=request.request_handlers,
154
root_client_path=self._root_client_path,
155
jail_root=self._jail_root)
156
self.request.args_received(req_args)
157
if self.request.finished_reading:
159
self.unused_data = self.in_buffer
161
self._send_response(self.request.response)
162
except KeyboardInterrupt:
164
except errors.UnknownSmartMethod, err:
165
protocol_error = errors.SmartProtocolError(
166
"bad request %r" % (err.verb,))
167
failure = request.FailedSmartServerResponse(
168
('error', str(protocol_error)))
169
self._send_response(failure)
171
except Exception, exception:
172
# everything else: pass to client, flush, and quit
173
log_exception_quietly()
174
self._send_response(request.FailedSmartServerResponse(
175
('error', str(exception))))
178
if self._has_dispatched:
180
# nothing to do. XXX: this routine should be a single state
182
self.unused_data += self.in_buffer
185
if self._body_decoder is None:
186
self._body_decoder = LengthPrefixedBodyDecoder()
187
self._body_decoder.accept_bytes(self.in_buffer)
188
self.in_buffer = self._body_decoder.unused_data
189
body_data = self._body_decoder.read_pending_data()
190
self.request.accept_body(body_data)
191
if self._body_decoder.finished_reading:
192
self.request.end_of_body()
193
if not self.request.finished_reading:
194
raise AssertionError("no more body, request not finished")
195
if self.request.response is not None:
196
self._send_response(self.request.response)
197
self.unused_data = self.in_buffer
200
if self.request.finished_reading:
201
raise AssertionError(
202
"no response and we have finished reading.")
204
def _send_response(self, response):
205
"""Send a smart server response down the output stream."""
207
raise AssertionError('response already sent')
210
self._finished = True
211
self._write_protocol_version()
212
self._write_success_or_failure_prefix(response)
213
self._write_func(_encode_tuple(args))
215
if not isinstance(body, str):
216
raise ValueError(body)
217
bytes = self._encode_bulk_data(body)
218
self._write_func(bytes)
220
def _write_protocol_version(self):
221
"""Write any prefixes this protocol requires.
223
Version one doesn't send protocol versions.
226
def _write_success_or_failure_prefix(self, response):
227
"""Write the protocol specific success/failure prefix.
229
For SmartServerRequestProtocolOne this is omitted but we
230
call is_successful to ensure that the response is valid.
232
response.is_successful()
234
def next_read_size(self):
237
if self._body_decoder is None:
240
return self._body_decoder.next_read_size()
243
class SmartServerRequestProtocolTwo(SmartServerRequestProtocolOne):
    r"""Version two of the server side of the smart protocol.

    This prefixes responses with the value of RESPONSE_VERSION_TWO.
    """

    response_marker = RESPONSE_VERSION_TWO
    request_marker = REQUEST_VERSION_TWO

    def _write_success_or_failure_prefix(self, response):
        r"""Write the protocol specific success/failure prefix.

        Writes 'success\n' when response.is_successful() is true and
        'failed\n' otherwise.
        """
        if response.is_successful():
            self._write_func('success\n')
        else:
            self._write_func('failed\n')

    def _write_protocol_version(self):
        r"""Write any prefixes this protocol requires.

        Version two sends the value of RESPONSE_VERSION_TWO.
        """
        self._write_func(self.response_marker)

    def _send_response(self, response):
        """Send a smart server response down the output stream.

        Writes, in order: the protocol version marker, the success/failure
        prefix, the encoded response args, and then either the bulk-encoded
        body or the streamed body when one is set.

        :param response: the response to send; its args, body and body_stream
            attributes are read.
        :raises AssertionError: if a response was already sent, if the body is
            not a str, or if both body and body_stream are set.
        """
        # NOTE(review): this guard line was lost from the source this block
        # was recovered from; it is reconstructed from the error message below
        # and from the flag being set immediately afterwards.
        if self._finished:
            raise AssertionError('response already sent')
        self._finished = True
        self._write_protocol_version()
        self._write_success_or_failure_prefix(response)
        self._write_func(_encode_tuple(response.args))
        if response.body is not None:
            if not isinstance(response.body, str):
                raise AssertionError('body must be a str')
            if not (response.body_stream is None):
                raise AssertionError(
                    'body_stream and body cannot both be set')
            bytes = self._encode_bulk_data(response.body)
            self._write_func(bytes)
        elif response.body_stream is not None:
            _send_stream(response.body_stream, self._write_func)
286
def _send_stream(stream, write_func):
287
write_func('chunked\n')
288
_send_chunks(stream, write_func)
292
def _send_chunks(stream, write_func):
294
if isinstance(chunk, str):
295
bytes = "%x\n%s" % (len(chunk), chunk)
297
elif isinstance(chunk, request.FailedSmartServerResponse):
299
_send_chunks(chunk.args, write_func)
302
raise errors.BzrError(
303
'Chunks must be str or FailedSmartServerResponse, got %r'
307
class _NeedMoreBytes(Exception):
308
"""Raise this inside a _StatefulDecoder to stop decoding until more bytes
312
def __init__(self, count=None):
315
:param count: the total number of bytes needed by the current state.
316
May be None if the number of bytes needed is unknown.
321
class _StatefulDecoder(object):
322
"""Base class for writing state machines to decode byte streams.
324
Subclasses should provide a self.state_accept attribute that accepts bytes
325
and, if appropriate, updates self.state_accept to a different function.
326
accept_bytes will call state_accept as often as necessary to make sure the
327
state machine has progressed as far as possible before it returns.
329
See ProtocolThreeDecoder for an example subclass.
333
self.finished_reading = False
334
self._in_buffer_list = []
335
self._in_buffer_len = 0
336
self.unused_data = ''
337
self.bytes_left = None
338
self._number_needed_bytes = None
340
def _get_in_buffer(self):
341
if len(self._in_buffer_list) == 1:
342
return self._in_buffer_list[0]
343
in_buffer = ''.join(self._in_buffer_list)
344
if len(in_buffer) != self._in_buffer_len:
345
raise AssertionError(
346
"Length of buffer did not match expected value: %s != %s"
347
% self._in_buffer_len, len(in_buffer))
348
self._in_buffer_list = [in_buffer]
351
def _get_in_bytes(self, count):
352
"""Grab X bytes from the input_buffer.
354
Callers should have already checked that self._in_buffer_len is >
355
count. Note, this does not consume the bytes from the buffer. The
356
caller will still need to call _get_in_buffer() and then
357
_set_in_buffer() if they actually need to consume the bytes.
359
# check if we can yield the bytes from just the first entry in our list
360
if len(self._in_buffer_list) == 0:
361
raise AssertionError('Callers must be sure we have buffered bytes'
362
' before calling _get_in_bytes')
363
if len(self._in_buffer_list[0]) > count:
364
return self._in_buffer_list[0][:count]
365
# We can't yield it from the first buffer, so collapse all buffers, and
367
in_buf = self._get_in_buffer()
368
return in_buf[:count]
370
def _set_in_buffer(self, new_buf):
371
if new_buf is not None:
372
self._in_buffer_list = [new_buf]
373
self._in_buffer_len = len(new_buf)
375
self._in_buffer_list = []
376
self._in_buffer_len = 0
378
def accept_bytes(self, bytes):
379
"""Decode as much of bytes as possible.
381
If 'bytes' contains too much data it will be appended to
384
finished_reading will be set when no more data is required. Further
385
data will be appended to self.unused_data.
387
# accept_bytes is allowed to change the state
388
self._number_needed_bytes = None
389
# lsprof puts a very large amount of time on this specific call for
391
self._in_buffer_list.append(bytes)
392
self._in_buffer_len += len(bytes)
394
# Run the function for the current state.
395
current_state = self.state_accept
397
while current_state != self.state_accept:
398
# The current state has changed. Run the function for the new
399
# current state, so that it can:
400
# - decode any unconsumed bytes left in a buffer, and
401
# - signal how many more bytes are expected (via raising
403
current_state = self.state_accept
405
except _NeedMoreBytes, e:
406
self._number_needed_bytes = e.count
409
class ChunkedBodyDecoder(_StatefulDecoder):
410
"""Decoder for chunked body data.
412
This is very similar to HTTP's chunked encoding.  See the description of
413
streamed body data in `doc/developers/network-protocol.txt` for details.
417
_StatefulDecoder.__init__(self)
418
self.state_accept = self._state_accept_expecting_header
419
self.chunk_in_progress = None
420
self.chunks = collections.deque()
422
self.error_in_progress = None
424
def next_read_size(self):
425
# Note: the shortest possible chunk is 2 bytes: '0\n', and the
426
# end-of-body marker is 4 bytes: 'END\n'.
427
if self.state_accept == self._state_accept_reading_chunk:
428
# We're expecting more chunk content. So we're expecting at least
429
# the rest of this chunk plus an END chunk.
430
return self.bytes_left + 4
431
elif self.state_accept == self._state_accept_expecting_length:
432
if self._in_buffer_len == 0:
433
# We're expecting a chunk length. There's at least two bytes
434
# left: a digit plus '\n'.
437
# We're in the middle of reading a chunk length. So there's at
438
# least one byte left, the '\n' that terminates the length.
440
elif self.state_accept == self._state_accept_reading_unused:
442
elif self.state_accept == self._state_accept_expecting_header:
443
return max(0, len('chunked\n') - self._in_buffer_len)
445
raise AssertionError("Impossible state: %r" % (self.state_accept,))
447
def read_next_chunk(self):
449
return self.chunks.popleft()
453
def _extract_line(self):
454
in_buf = self._get_in_buffer()
455
pos = in_buf.find('\n')
457
# We haven't read a complete line yet, so request more bytes before
459
raise _NeedMoreBytes(1)
461
# Trim the prefix (including '\n' delimiter) from the _in_buffer.
462
self._set_in_buffer(in_buf[pos+1:])
466
self.unused_data = self._get_in_buffer()
467
self._in_buffer_list = []
468
self._in_buffer_len = 0
469
self.state_accept = self._state_accept_reading_unused
471
error_args = tuple(self.error_in_progress)
472
self.chunks.append(request.FailedSmartServerResponse(error_args))
473
self.error_in_progress = None
474
self.finished_reading = True
476
def _state_accept_expecting_header(self):
477
prefix = self._extract_line()
478
if prefix == 'chunked':
479
self.state_accept = self._state_accept_expecting_length
481
raise errors.SmartProtocolError(
482
'Bad chunked body header: "%s"' % (prefix,))
484
def _state_accept_expecting_length(self):
485
prefix = self._extract_line()
488
self.error_in_progress = []
489
self._state_accept_expecting_length()
491
elif prefix == 'END':
492
# We've read the end-of-body marker.
493
# Any further bytes are unused data, including the bytes left in
498
self.bytes_left = int(prefix, 16)
499
self.chunk_in_progress = ''
500
self.state_accept = self._state_accept_reading_chunk
502
def _state_accept_reading_chunk(self):
503
in_buf = self._get_in_buffer()
504
in_buffer_len = len(in_buf)
505
self.chunk_in_progress += in_buf[:self.bytes_left]
506
self._set_in_buffer(in_buf[self.bytes_left:])
507
self.bytes_left -= in_buffer_len
508
if self.bytes_left <= 0:
509
# Finished with chunk
510
self.bytes_left = None
512
self.error_in_progress.append(self.chunk_in_progress)
514
self.chunks.append(self.chunk_in_progress)
515
self.chunk_in_progress = None
516
self.state_accept = self._state_accept_expecting_length
518
def _state_accept_reading_unused(self):
519
self.unused_data += self._get_in_buffer()
520
self._in_buffer_list = []
523
class LengthPrefixedBodyDecoder(_StatefulDecoder):
524
"""Decodes the length-prefixed bulk data."""
527
_StatefulDecoder.__init__(self)
528
self.state_accept = self._state_accept_expecting_length
529
self.state_read = self._state_read_no_data
531
self._trailer_buffer = ''
533
def next_read_size(self):
534
if self.bytes_left is not None:
535
# Ideally we want to read all the remainder of the body and the
537
return self.bytes_left + 5
538
elif self.state_accept == self._state_accept_reading_trailer:
539
# Just the trailer left
540
return 5 - len(self._trailer_buffer)
541
elif self.state_accept == self._state_accept_expecting_length:
542
# There's still at least 6 bytes left ('\n' to end the length, plus
546
# Reading excess data. Either way, 1 byte at a time is fine.
549
def read_pending_data(self):
550
"""Return any pending data that has been decoded."""
551
return self.state_read()
553
def _state_accept_expecting_length(self):
554
in_buf = self._get_in_buffer()
555
pos = in_buf.find('\n')
558
self.bytes_left = int(in_buf[:pos])
559
self._set_in_buffer(in_buf[pos+1:])
560
self.state_accept = self._state_accept_reading_body
561
self.state_read = self._state_read_body_buffer
563
def _state_accept_reading_body(self):
564
in_buf = self._get_in_buffer()
566
self.bytes_left -= len(in_buf)
567
self._set_in_buffer(None)
568
if self.bytes_left <= 0:
570
if self.bytes_left != 0:
571
self._trailer_buffer = self._body[self.bytes_left:]
572
self._body = self._body[:self.bytes_left]
573
self.bytes_left = None
574
self.state_accept = self._state_accept_reading_trailer
576
def _state_accept_reading_trailer(self):
577
self._trailer_buffer += self._get_in_buffer()
578
self._set_in_buffer(None)
579
# TODO: what if the trailer does not match "done\n"? Should this raise
580
# a ProtocolViolation exception?
581
if self._trailer_buffer.startswith('done\n'):
582
self.unused_data = self._trailer_buffer[len('done\n'):]
583
self.state_accept = self._state_accept_reading_unused
584
self.finished_reading = True
586
def _state_accept_reading_unused(self):
587
self.unused_data += self._get_in_buffer()
588
self._set_in_buffer(None)
590
def _state_read_no_data(self):
593
def _state_read_body_buffer(self):
599
class SmartClientRequestProtocolOne(SmartProtocolBase, Requester,
600
message.ResponseHandler):
601
"""The client-side protocol for smart version 1."""
603
def __init__(self, request):
604
"""Construct a SmartClientRequestProtocolOne.
606
:param request: A SmartClientMediumRequest to serialise onto and
609
self._request = request
610
self._body_buffer = None
611
self._request_start_time = None
612
self._last_verb = None
615
def set_headers(self, headers):
616
self._headers = dict(headers)
618
def call(self, *args):
619
if 'hpss' in debug.debug_flags:
620
mutter('hpss call: %s', repr(args)[1:-1])
621
if getattr(self._request._medium, 'base', None) is not None:
622
mutter(' (to %s)', self._request._medium.base)
623
self._request_start_time = osutils.timer_func()
624
self._write_args(args)
625
self._request.finished_writing()
626
self._last_verb = args[0]
628
def call_with_body_bytes(self, args, body):
629
"""Make a remote call of args with body bytes 'body'.
631
After calling this, call read_response_tuple to find the result out.
633
if 'hpss' in debug.debug_flags:
634
mutter('hpss call w/body: %s (%r...)', repr(args)[1:-1], body[:20])
635
if getattr(self._request._medium, '_path', None) is not None:
636
mutter(' (to %s)', self._request._medium._path)
637
mutter(' %d bytes', len(body))
638
self._request_start_time = osutils.timer_func()
639
if 'hpssdetail' in debug.debug_flags:
640
mutter('hpss body content: %s', body)
641
self._write_args(args)
642
bytes = self._encode_bulk_data(body)
643
self._request.accept_bytes(bytes)
644
self._request.finished_writing()
645
self._last_verb = args[0]
647
def call_with_body_readv_array(self, args, body):
648
"""Make a remote call with a readv array.
650
The body is encoded with one line per readv offset pair. The numbers in
651
each pair are separated by a comma, and no trailing \n is emitted.
653
if 'hpss' in debug.debug_flags:
654
mutter('hpss call w/readv: %s', repr(args)[1:-1])
655
if getattr(self._request._medium, '_path', None) is not None:
656
mutter(' (to %s)', self._request._medium._path)
657
self._request_start_time = osutils.timer_func()
658
self._write_args(args)
659
readv_bytes = self._serialise_offsets(body)
660
bytes = self._encode_bulk_data(readv_bytes)
661
self._request.accept_bytes(bytes)
662
self._request.finished_writing()
663
if 'hpss' in debug.debug_flags:
664
mutter(' %d bytes in readv request', len(readv_bytes))
665
self._last_verb = args[0]
667
def call_with_body_stream(self, args, stream):
668
# Protocols v1 and v2 don't support body streams. So it's safe to
669
# assume that a v1/v2 server doesn't support whatever method we're
670
# trying to call with a body stream.
671
self._request.finished_writing()
672
self._request.finished_reading()
673
raise errors.UnknownSmartMethod(args[0])
675
def cancel_read_body(self):
676
"""After expecting a body, a response code may indicate one otherwise.
678
This method lets the domain client inform the protocol that no body
679
will be transmitted. This is a terminal method: after calling it the
680
protocol is not able to be used further.
682
self._request.finished_reading()
684
def _read_response_tuple(self):
685
result = self._recv_tuple()
686
if 'hpss' in debug.debug_flags:
687
if self._request_start_time is not None:
688
mutter(' result: %6.3fs %s',
689
osutils.timer_func() - self._request_start_time,
691
self._request_start_time = None
693
mutter(' result: %s', repr(result)[1:-1])
696
def read_response_tuple(self, expect_body=False):
697
"""Read a response tuple from the wire.
699
This should only be called once.
701
result = self._read_response_tuple()
702
self._response_is_unknown_method(result)
703
self._raise_args_if_error(result)
705
self._request.finished_reading()
708
def _raise_args_if_error(self, result_tuple):
709
# Later protocol versions have an explicit flag in the protocol to say
710
# if an error response is "failed" or not. In version 1 we don't have
711
# that luxury. So here is a complete list of errors that can be
712
# returned in response to existing version 1 smart requests. Responses
713
# starting with these codes are always "failed" responses.
720
'UnicodeEncodeError',
721
'UnicodeDecodeError',
727
'UnlockableTransport',
733
if result_tuple[0] in v1_error_codes:
734
self._request.finished_reading()
735
raise errors.ErrorFromSmartServer(result_tuple)
737
def _response_is_unknown_method(self, result_tuple):
"""Raise UnknownSmartMethod if the response is an 'unknown
739
method' response to the request.
741
:param response: The response from a smart client call_expecting_body
743
:param verb: The verb used in that call.
744
:raises: UnknownSmartMethod
746
if (result_tuple == ('error', "Generic bzr smart protocol error: "
747
"bad request '%s'" % self._last_verb) or
748
result_tuple == ('error', "Generic bzr smart protocol error: "
749
"bad request u'%s'" % self._last_verb)):
750
# The response will have no body, so we've finished reading.
751
self._request.finished_reading()
752
raise errors.UnknownSmartMethod(self._last_verb)
754
def read_body_bytes(self, count=-1):
755
"""Read bytes from the body, decoding into a byte stream.
757
We read all bytes at once to ensure we've checked the trailer for
758
errors, and then feed the buffer back as read_body_bytes is called.
760
if self._body_buffer is not None:
761
return self._body_buffer.read(count)
762
_body_decoder = LengthPrefixedBodyDecoder()
764
while not _body_decoder.finished_reading:
765
bytes = self._request.read_bytes(_body_decoder.next_read_size())
767
# end of file encountered reading from server
768
raise errors.ConnectionReset(
769
"Connection lost while reading response body.")
770
_body_decoder.accept_bytes(bytes)
771
self._request.finished_reading()
772
self._body_buffer = StringIO(_body_decoder.read_pending_data())
773
# XXX: TODO check the trailer result.
774
if 'hpss' in debug.debug_flags:
775
mutter(' %d body bytes read',
776
len(self._body_buffer.getvalue()))
777
return self._body_buffer.read(count)
779
def _recv_tuple(self):
780
"""Receive a tuple from the medium request."""
781
return _decode_tuple(self._request.read_line())
783
def query_version(self):
784
"""Return protocol version number of the server."""
786
resp = self.read_response_tuple()
787
if resp == ('ok', '1'):
789
elif resp == ('ok', '2'):
792
raise errors.SmartProtocolError("bad response %r" % (resp,))
794
def _write_args(self, args):
795
self._write_protocol_version()
796
bytes = _encode_tuple(args)
797
self._request.accept_bytes(bytes)
799
def _write_protocol_version(self):
800
"""Write any prefixes this protocol requires.
802
Version one doesn't send protocol versions.
806
class SmartClientRequestProtocolTwo(SmartClientRequestProtocolOne):
807
"""Version two of the client side of the smart protocol.
809
This prefixes the request with the value of REQUEST_VERSION_TWO.
812
response_marker = RESPONSE_VERSION_TWO
813
request_marker = REQUEST_VERSION_TWO
815
def read_response_tuple(self, expect_body=False):
816
"""Read a response tuple from the wire.
818
This should only be called once.
820
version = self._request.read_line()
821
if version != self.response_marker:
822
self._request.finished_reading()
823
raise errors.UnexpectedProtocolVersionMarker(version)
824
response_status = self._request.read_line()
825
result = SmartClientRequestProtocolOne._read_response_tuple(self)
826
self._response_is_unknown_method(result)
827
if response_status == 'success\n':
828
self.response_status = True
830
self._request.finished_reading()
832
elif response_status == 'failed\n':
833
self.response_status = False
834
self._request.finished_reading()
835
raise errors.ErrorFromSmartServer(result)
837
raise errors.SmartProtocolError(
838
'bad protocol status %r' % response_status)
840
def _write_protocol_version(self):
841
"""Write any prefixes this protocol requires.
843
Version two sends the value of REQUEST_VERSION_TWO.
845
self._request.accept_bytes(self.request_marker)
847
def read_streamed_body(self):
848
"""Read bytes from the body, decoding into a byte stream.
850
# Read no more than 64k at a time so that we don't risk error 10055 (no
851
# buffer space available) on Windows.
852
_body_decoder = ChunkedBodyDecoder()
853
while not _body_decoder.finished_reading:
854
bytes = self._request.read_bytes(_body_decoder.next_read_size())
856
# end of file encountered reading from server
857
raise errors.ConnectionReset(
858
"Connection lost while reading streamed body.")
859
_body_decoder.accept_bytes(bytes)
860
for body_bytes in iter(_body_decoder.read_next_chunk, None):
861
if 'hpss' in debug.debug_flags and type(body_bytes) is str:
862
mutter(' %d byte chunk read',
865
self._request.finished_reading()
868
def build_server_protocol_three(backing_transport, write_func,
                                root_client_path, jail_root=None):
    """Build the server-side decoder for protocol version three.

    Wires a SmartServerRequestHandler and a ProtocolThreeResponder together
    through a ConventionalRequestHandler, and returns a ProtocolThreeDecoder
    that feeds that handler.

    :param backing_transport: passed to the SmartServerRequestHandler.
    :param write_func: passed to the ProtocolThreeResponder.
    :param root_client_path: passed to the SmartServerRequestHandler.
    :param jail_root: passed to the SmartServerRequestHandler.
    :return: a ProtocolThreeDecoder wrapping the message handler.
    """
    request_handler = request.SmartServerRequestHandler(
        backing_transport, commands=request.request_handlers,
        root_client_path=root_client_path, jail_root=jail_root)
    responder = ProtocolThreeResponder(write_func)
    message_handler = message.ConventionalRequestHandler(
        request_handler, responder)
    return ProtocolThreeDecoder(message_handler)
878
class ProtocolThreeDecoder(_StatefulDecoder):
880
response_marker = RESPONSE_VERSION_THREE
881
request_marker = REQUEST_VERSION_THREE
883
def __init__(self, message_handler, expect_version_marker=False):
884
_StatefulDecoder.__init__(self)
885
self._has_dispatched = False
887
if expect_version_marker:
888
self.state_accept = self._state_accept_expecting_protocol_version
889
# We're expecting at least the protocol version marker + some
891
self._number_needed_bytes = len(MESSAGE_VERSION_THREE) + 4
893
self.state_accept = self._state_accept_expecting_headers
894
self._number_needed_bytes = 4
895
self.decoding_failed = False
896
self.request_handler = self.message_handler = message_handler
898
def accept_bytes(self, bytes):
899
self._number_needed_bytes = None
901
_StatefulDecoder.accept_bytes(self, bytes)
902
except KeyboardInterrupt:
904
except errors.SmartMessageHandlerError, exception:
905
# We do *not* set self.decoding_failed here. The message handler
906
# has raised an error, but the decoder is still able to parse bytes
907
# and determine when this message ends.
908
if not isinstance(exception.exc_value, errors.UnknownSmartMethod):
909
log_exception_quietly()
910
self.message_handler.protocol_error(exception.exc_value)
911
# The state machine is ready to continue decoding, but the
912
# exception has interrupted the loop that runs the state machine.
913
# So we call accept_bytes again to restart it.
914
self.accept_bytes('')
915
except Exception, exception:
916
# The decoder itself has raised an exception. We cannot continue
918
self.decoding_failed = True
919
if isinstance(exception, errors.UnexpectedProtocolVersionMarker):
920
# This happens during normal operation when the client tries a
921
# protocol version the server doesn't understand, so no need to
922
# log a traceback every time.
923
# Note that this can only happen when
924
# expect_version_marker=True, which is only the case on the
928
log_exception_quietly()
929
self.message_handler.protocol_error(exception)
931
def _extract_length_prefixed_bytes(self):
932
if self._in_buffer_len < 4:
933
# A length prefix by itself is 4 bytes, and we don't even have that
935
raise _NeedMoreBytes(4)
936
(length,) = struct.unpack('!L', self._get_in_bytes(4))
937
end_of_bytes = 4 + length
938
if self._in_buffer_len < end_of_bytes:
939
# We haven't yet read as many bytes as the length-prefix says there
941
raise _NeedMoreBytes(end_of_bytes)
942
# Extract the bytes from the buffer.
943
in_buf = self._get_in_buffer()
944
bytes = in_buf[4:end_of_bytes]
945
self._set_in_buffer(in_buf[end_of_bytes:])
948
def _extract_prefixed_bencoded_data(self):
949
prefixed_bytes = self._extract_length_prefixed_bytes()
951
decoded = bdecode_as_tuple(prefixed_bytes)
953
raise errors.SmartProtocolError(
954
'Bytes %r not bencoded' % (prefixed_bytes,))
957
def _extract_single_byte(self):
958
if self._in_buffer_len == 0:
959
# The buffer is empty
960
raise _NeedMoreBytes(1)
961
in_buf = self._get_in_buffer()
963
self._set_in_buffer(in_buf[1:])
966
def _state_accept_expecting_protocol_version(self):
967
needed_bytes = len(MESSAGE_VERSION_THREE) - self._in_buffer_len
968
in_buf = self._get_in_buffer()
970
# We don't have enough bytes to check if the protocol version
971
# marker is right. But we can check if it is already wrong by
972
# checking that the start of MESSAGE_VERSION_THREE matches what
974
# [In fact, if the remote end isn't bzr we might never receive
975
# len(MESSAGE_VERSION_THREE) bytes. So if the bytes we have so far
976
# are wrong then we should just raise immediately rather than
978
if not MESSAGE_VERSION_THREE.startswith(in_buf):
979
# We have enough bytes to know the protocol version is wrong
980
raise errors.UnexpectedProtocolVersionMarker(in_buf)
981
raise _NeedMoreBytes(len(MESSAGE_VERSION_THREE))
982
if not in_buf.startswith(MESSAGE_VERSION_THREE):
983
raise errors.UnexpectedProtocolVersionMarker(in_buf)
984
self._set_in_buffer(in_buf[len(MESSAGE_VERSION_THREE):])
985
self.state_accept = self._state_accept_expecting_headers
987
def _state_accept_expecting_headers(self):
988
decoded = self._extract_prefixed_bencoded_data()
989
if type(decoded) is not dict:
990
raise errors.SmartProtocolError(
991
'Header object %r is not a dict' % (decoded,))
992
self.state_accept = self._state_accept_expecting_message_part
994
self.message_handler.headers_received(decoded)
996
raise errors.SmartMessageHandlerError(sys.exc_info())
998
def _state_accept_expecting_message_part(self):
999
message_part_kind = self._extract_single_byte()
1000
if message_part_kind == 'o':
1001
self.state_accept = self._state_accept_expecting_one_byte
1002
elif message_part_kind == 's':
1003
self.state_accept = self._state_accept_expecting_structure
1004
elif message_part_kind == 'b':
1005
self.state_accept = self._state_accept_expecting_bytes
1006
elif message_part_kind == 'e':
1009
raise errors.SmartProtocolError(
1010
'Bad message kind byte: %r' % (message_part_kind,))
1012
def _state_accept_expecting_one_byte(self):
1013
byte = self._extract_single_byte()
1014
self.state_accept = self._state_accept_expecting_message_part
1016
self.message_handler.byte_part_received(byte)
1018
raise errors.SmartMessageHandlerError(sys.exc_info())
1020
def _state_accept_expecting_bytes(self):
1021
# XXX: this should not buffer whole message part, but instead deliver
1022
# the bytes as they arrive.
1023
prefixed_bytes = self._extract_length_prefixed_bytes()
1024
self.state_accept = self._state_accept_expecting_message_part
1026
self.message_handler.bytes_part_received(prefixed_bytes)
1028
raise errors.SmartMessageHandlerError(sys.exc_info())
1030
def _state_accept_expecting_structure(self):
1031
structure = self._extract_prefixed_bencoded_data()
1032
self.state_accept = self._state_accept_expecting_message_part
1034
self.message_handler.structure_part_received(structure)
1036
raise errors.SmartMessageHandlerError(sys.exc_info())
# NOTE(review): the 'def' line of this method was missing from the text this
# was restored from; the name 'done' is reconstructed -- confirm against the
# caller that handles the 'e' message part.
def done(self):
    """Finish decoding the current message.

    Whatever is left in the input buffer is moved to unused_data, the
    decoder switches to collecting unused bytes, and the message handler is
    told that the end of the message was received.

    :raises errors.SmartMessageHandlerError: carrying the exc_info of an
        error raised by the handler's end_received.
    """
    self.unused_data = self._get_in_buffer()
    self._set_in_buffer(None)
    self.state_accept = self._state_accept_reading_unused
    try:
        self.message_handler.end_received()
    except:
        raise errors.SmartMessageHandlerError(sys.exc_info())
def _state_accept_reading_unused(self):
1048
self.unused_data += self._get_in_buffer()
1049
self._set_in_buffer(None)
1051
def next_read_size(self):
    """Return how many more bytes the decoder wants to be given.

    :return: 0 once the message has ended (the decoder is only collecting
        unused bytes) or decoding has failed; otherwise the number of bytes
        still missing from the amount the current state asked for.
    :raises AssertionError: if no amount has been requested.
    """
    if self.state_accept == self._state_accept_reading_unused:
        return 0
    elif self.decoding_failed:
        # An exception occurred while processing this message, probably from
        # self.message_handler. We're not sure that this state machine is
        # in a consistent state, so just signal that we're done (i.e. give
        # up).
        return 0
    else:
        if self._number_needed_bytes is not None:
            return self._number_needed_bytes - self._in_buffer_len
        else:
            raise AssertionError("don't know how many bytes are expected!")
class _ProtocolThreeEncoder(object):
    """Shared helpers for writing protocol version three messages.

    Everything written is collected in an in-memory buffer and handed to
    the wrapped write function in a single call, either when the buffer
    grows past BUFFER_SIZE bytes or when flush() is called.
    """

    response_marker = request_marker = MESSAGE_VERSION_THREE
    BUFFER_SIZE = 1024*1024 # 1 MiB buffer before flushing

    def __init__(self, write_func):
        """Create an encoder.

        :param write_func: callable taking a single string; it receives the
            buffered bytes each time the buffer is flushed.
        """
        self._buf = []
        self._buf_len = 0
        self._real_write_func = write_func

    def _write_func(self, bytes):
        """Buffer bytes, flushing once more than BUFFER_SIZE is pending."""
        # The buffer is measured in bytes (_buf_len), not in the number of
        # calls made to this method.
        # TODO: Another possibility would be to turn this into an async model.
        #       Where we let another thread know that we have some bytes if
        #       they want it, but we don't actually block for it
        #       Note that osutils.send_all always sends 64kB chunks anyway, so
        #       we might just push out smaller bits at a time?
        self._buf.append(bytes)
        self._buf_len += len(bytes)
        if self._buf_len > self.BUFFER_SIZE:
            self.flush()

    def flush(self):
        """Pass any buffered bytes to the real write function as one string."""
        if self._buf:
            self._real_write_func(''.join(self._buf))
            del self._buf[:]
            self._buf_len = 0

    def _serialise_offsets(self, offsets):
        """Serialise a readv offset list.

        :param offsets: iterable of (start, length) pairs.
        :return: one 'start,length' line per pair, joined with newlines and
            with no trailing newline.
        """
        return '\n'.join(
            ['%d,%d' % (start, length) for start, length in offsets])

    def _write_protocol_version(self):
        """Write the protocol version three marker."""
        self._write_func(MESSAGE_VERSION_THREE)

    def _write_prefixed_bencode(self, structure):
        """Write structure bencoded, preceded by its length.

        The length is packed as a 4-byte unsigned integer in network byte
        order.
        """
        encoded = bencode(structure)
        self._write_func(struct.pack('!L', len(encoded)))
        self._write_func(encoded)

    def _write_headers(self, headers):
        """Write the headers as a length-prefixed bencoded structure."""
        self._write_prefixed_bencode(headers)

    def _write_structure(self, args):
        """Write a structure message part ('s').

        Unicode arguments are encoded to UTF-8 first; all other arguments
        are passed through unchanged.
        """
        self._write_func('s')
        utf8_args = []
        for arg in args:
            if isinstance(arg, unicode):
                utf8_args.append(arg.encode('utf8'))
            else:
                utf8_args.append(arg)
        self._write_prefixed_bencode(utf8_args)

    def _write_end(self):
        """Write the end-of-message marker ('e') and flush the buffer."""
        self._write_func('e')
        self.flush()

    def _write_prefixed_body(self, bytes):
        """Write a bytes message part ('b'): 4-byte length, then the bytes."""
        self._write_func('b')
        self._write_func(struct.pack('!L', len(bytes)))
        self._write_func(bytes)

    def _write_chunked_body_start(self):
        """Write the one-byte message part 'C' ('oC')."""
        self._write_func('oC')

    def _write_error_status(self):
        """Write the one-byte message part 'E' ('oE'): error status."""
        self._write_func('oE')

    def _write_success_status(self):
        """Write the one-byte message part 'S' ('oS'): success status."""
        self._write_func('oS')

class ProtocolThreeResponder(_ProtocolThreeEncoder):
    """Encode and send a single protocol version three response.

    Only one response may be sent per responder; response_sent records
    whether that has happened.
    """

    def __init__(self, write_func):
        """Create a responder writing through write_func."""
        _ProtocolThreeEncoder.__init__(self, write_func)
        self.response_sent = False
        self._headers = {'Software version': bzrlib.__version__}
        # _trace() is reached under both the 'hpss' and the 'hpssdetail'
        # debug flags, so the state it reads is always initialised.
        self._response_start_time = None
        if ('hpss' in debug.debug_flags
            or 'hpssdetail' in debug.debug_flags):
            self._thread_id = thread.get_ident()
        else:
            self._thread_id = None

    def _trace(self, action, message, extra_bytes=None, include_time=False):
        """Log one line describing a step of sending the response.

        :param action: short label for the step.
        :param message: text describing the step.
        :param extra_bytes: optional bytes; a shortened repr of the first 40
            is appended to the log line.
        :param include_time: if True, include the time elapsed since the
            first _trace call made on this responder.
        """
        if self._response_start_time is None:
            self._response_start_time = osutils.timer_func()
        if include_time:
            # Measure with the same clock that recorded the start time.
            t = '%5.3fs ' % (
                osutils.timer_func() - self._response_start_time)
        else:
            t = ''
        if extra_bytes is None:
            extra = ''
        else:
            extra = ' ' + repr(extra_bytes[:40])
            if len(extra) > 33:
                # Keep the closing quote of the repr when shortening it.
                extra = extra[:29] + extra[-1] + '...'
        mutter('%12s: [%s] %s%s%s'
               % (action, self._thread_id, t, message, extra))

    def send_error(self, exception):
        """Send an error response describing exception.

        An errors.UnknownSmartMethod is reported as a failed response with
        args ('UnknownMethod', verb); anything else is sent with an error
        status and the structure ('error', str(exception)).

        :raises AssertionError: if a response has already been sent.
        """
        if self.response_sent:
            raise AssertionError(
                "send_error(%s) called, but response already sent."
                % (exception,))
        if isinstance(exception, errors.UnknownSmartMethod):
            failure = request.FailedSmartServerResponse(
                ('UnknownMethod', exception.verb))
            self.send_response(failure)
            return
        if 'hpss' in debug.debug_flags:
            self._trace('error', str(exception))
        self.response_sent = True
        self._write_protocol_version()
        self._write_headers(self._headers)
        self._write_error_status()
        self._write_structure(('error', str(exception)))
        self._write_end()

    def send_response(self, response):
        """Send response: status, args, then the body or body stream if any.

        If iterating response.body_stream fails, or it yields a
        FailedSmartServerResponse, an error status and error structure are
        written and the rest of the stream is not sent.

        :raises AssertionError: if a response has already been sent.
        """
        if self.response_sent:
            raise AssertionError(
                "send_response(%r) called, but response already sent."
                % (response,))
        self.response_sent = True
        self._write_protocol_version()
        self._write_headers(self._headers)
        if response.is_successful():
            self._write_success_status()
        else:
            self._write_error_status()
        if 'hpss' in debug.debug_flags:
            self._trace('response', repr(response.args))
        self._write_structure(response.args)
        if response.body is not None:
            self._write_prefixed_body(response.body)
            if 'hpss' in debug.debug_flags:
                self._trace('body', '%d bytes' % (len(response.body),),
                            response.body, include_time=True)
        elif response.body_stream is not None:
            count = num_bytes = 0
            first_chunk = None
            for exc_info, chunk in _iter_with_errors(response.body_stream):
                count += 1
                if exc_info is not None:
                    self._write_error_status()
                    error_struct = request._translate_error(exc_info[1])
                    self._write_structure(error_struct)
                    break
                else:
                    if isinstance(chunk, request.FailedSmartServerResponse):
                        self._write_error_status()
                        self._write_structure(chunk.args)
                        break
                    num_bytes += len(chunk)
                    if first_chunk is None:
                        first_chunk = chunk
                    self._write_prefixed_body(chunk)
                    if 'hpssdetail' in debug.debug_flags:
                        # Not worth timing separately, as _write_func is
                        # actually buffered
                        self._trace('body chunk',
                                    '%d bytes' % (len(chunk),),
                                    chunk)
            if 'hpss' in debug.debug_flags:
                self._trace('body stream',
                            '%d bytes %d chunks' % (num_bytes, count),
                            first_chunk)
        self._write_end()
        if 'hpss' in debug.debug_flags:
            self._trace('response end', '', include_time=True)

def _iter_with_errors(iterable):
1244
"""Handle errors from iterable.next().
1248
for exc_info, value in _iter_with_errors(iterable):
1251
This is a safer alternative to::
1254
for value in iterable:
1259
Because the latter will catch errors from the for-loop body, not just
1262
If an error occurs, exc_info will be a exc_info tuple, and the generator
1263
will terminate. Otherwise exc_info will be None, and value will be the
1264
value from iterable.next(). Note that KeyboardInterrupt and SystemExit
1265
will not be itercepted.
1267
iterator = iter(iterable)
1270
yield None, iterator.next()
1271
except StopIteration:
1273
except (KeyboardInterrupt, SystemExit):
1276
mutter('_iter_with_errors caught error')
1277
log_exception_quietly()
1278
yield sys.exc_info(), None
1282
class ProtocolThreeRequester(_ProtocolThreeEncoder, Requester):
1284
def __init__(self, medium_request):
1285
_ProtocolThreeEncoder.__init__(self, medium_request.accept_bytes)
1286
self._medium_request = medium_request
1289
def set_headers(self, headers):
1290
self._headers = headers.copy()
1292
def call(self, *args):
1293
if 'hpss' in debug.debug_flags:
1294
mutter('hpss call: %s', repr(args)[1:-1])
1295
base = getattr(self._medium_request._medium, 'base', None)
1296
if base is not None:
1297
mutter(' (to %s)', base)
1298
self._request_start_time = osutils.timer_func()
1299
self._write_protocol_version()
1300
self._write_headers(self._headers)
1301
self._write_structure(args)
1303
self._medium_request.finished_writing()
1305
def call_with_body_bytes(self, args, body):
1306
"""Make a remote call of args with body bytes 'body'.
1308
After calling this, call read_response_tuple to find the result out.
1310
if 'hpss' in debug.debug_flags:
1311
mutter('hpss call w/body: %s (%r...)', repr(args)[1:-1], body[:20])
1312
path = getattr(self._medium_request._medium, '_path', None)
1313
if path is not None:
1314
mutter(' (to %s)', path)
1315
mutter(' %d bytes', len(body))
1316
self._request_start_time = osutils.timer_func()
1317
self._write_protocol_version()
1318
self._write_headers(self._headers)
1319
self._write_structure(args)
1320
self._write_prefixed_body(body)
1322
self._medium_request.finished_writing()
1324
def call_with_body_readv_array(self, args, body):
1325
"""Make a remote call with a readv array.
1327
The body is encoded with one line per readv offset pair. The numbers in
1328
each pair are separated by a comma, and no trailing \n is emitted.
1330
if 'hpss' in debug.debug_flags:
1331
mutter('hpss call w/readv: %s', repr(args)[1:-1])
1332
path = getattr(self._medium_request._medium, '_path', None)
1333
if path is not None:
1334
mutter(' (to %s)', path)
1335
self._request_start_time = osutils.timer_func()
1336
self._write_protocol_version()
1337
self._write_headers(self._headers)
1338
self._write_structure(args)
1339
readv_bytes = self._serialise_offsets(body)
1340
if 'hpss' in debug.debug_flags:
1341
mutter(' %d bytes in readv request', len(readv_bytes))
1342
self._write_prefixed_body(readv_bytes)
1344
self._medium_request.finished_writing()
1346
def call_with_body_stream(self, args, stream):
1347
if 'hpss' in debug.debug_flags:
1348
mutter('hpss call w/body stream: %r', args)
1349
path = getattr(self._medium_request._medium, '_path', None)
1350
if path is not None:
1351
mutter(' (to %s)', path)
1352
self._request_start_time = osutils.timer_func()
1353
self._write_protocol_version()
1354
self._write_headers(self._headers)
1355
self._write_structure(args)
1356
# TODO: notice if the server has sent an early error reply before we
1357
# have finished sending the stream. We would notice at the end
1358
# anyway, but if the medium can deliver it early then it's good
1359
# to short-circuit the whole request...
1360
for exc_info, part in _iter_with_errors(stream):
1361
if exc_info is not None:
1362
# Iterating the stream failed. Cleanly abort the request.
1363
self._write_error_status()
1364
# Currently the client unconditionally sends ('error',) as the
1366
self._write_structure(('error',))
1368
self._medium_request.finished_writing()
1369
raise exc_info[0], exc_info[1], exc_info[2]
1371
self._write_prefixed_body(part)
1374
self._medium_request.finished_writing()