13
13
# You should have received a copy of the GNU General Public License
14
14
# along with this program; if not, write to the Free Software
15
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
15
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
17
17
"""Wire-level encoding and decoding of requests and responses for the smart
29
29
from bzrlib import errors
30
30
from bzrlib.smart import message, request
31
31
from bzrlib.trace import log_exception_quietly, mutter
32
from bzrlib.bencode import bdecode_as_tuple, bencode
32
from bzrlib.util.bencode import bdecode, bencode
35
35
# Protocol version strings. These are sent as prefixes of bzr requests and
109
109
for start, length in offsets:
110
110
txt.append('%d,%d' % (start, length))
111
111
return '\n'.join(txt)
114
114
class SmartServerRequestProtocolOne(SmartProtocolBase):
115
115
"""Server-side encoding and decoding logic for smart version 1."""
117
117
def __init__(self, backing_transport, write_func, root_client_path='/'):
118
118
self._backing_transport = backing_transport
119
119
self._root_client_path = root_client_path
128
128
def accept_bytes(self, bytes):
129
129
"""Take bytes, and advance the internal state machine appropriately.
131
131
:param bytes: must be a byte string
133
133
if not isinstance(bytes, str):
324
324
def __init__(self):
325
325
self.finished_reading = False
326
self._in_buffer_list = []
327
self._in_buffer_len = 0
328
327
self.unused_data = ''
329
328
self.bytes_left = None
330
329
self._number_needed_bytes = None
332
def _get_in_buffer(self):
333
if len(self._in_buffer_list) == 1:
334
return self._in_buffer_list[0]
335
in_buffer = ''.join(self._in_buffer_list)
336
if len(in_buffer) != self._in_buffer_len:
337
raise AssertionError(
338
"Length of buffer did not match expected value: %s != %s"
339
% self._in_buffer_len, len(in_buffer))
340
self._in_buffer_list = [in_buffer]
343
def _get_in_bytes(self, count):
344
"""Grab X bytes from the input_buffer.
346
Callers should have already checked that self._in_buffer_len is >
347
count. Note, this does not consume the bytes from the buffer. The
348
caller will still need to call _get_in_buffer() and then
349
_set_in_buffer() if they actually need to consume the bytes.
351
# check if we can yield the bytes from just the first entry in our list
352
if len(self._in_buffer_list) == 0:
353
raise AssertionError('Callers must be sure we have buffered bytes'
354
' before calling _get_in_bytes')
355
if len(self._in_buffer_list[0]) > count:
356
return self._in_buffer_list[0][:count]
357
# We can't yield it from the first buffer, so collapse all buffers, and
359
in_buf = self._get_in_buffer()
360
return in_buf[:count]
362
def _set_in_buffer(self, new_buf):
363
if new_buf is not None:
364
self._in_buffer_list = [new_buf]
365
self._in_buffer_len = len(new_buf)
367
self._in_buffer_list = []
368
self._in_buffer_len = 0
370
331
def accept_bytes(self, bytes):
371
332
"""Decode as much of bytes as possible.
377
338
data will be appended to self.unused_data.
379
340
# accept_bytes is allowed to change the state
341
current_state = self.state_accept
380
342
self._number_needed_bytes = None
381
# lsprof puts a very large amount of time on this specific call for
383
self._in_buffer_list.append(bytes)
384
self._in_buffer_len += len(bytes)
343
self._in_buffer += bytes
386
345
# Run the function for the current state.
387
current_state = self.state_accept
388
346
self.state_accept()
389
347
while current_state != self.state_accept:
390
348
# The current state has changed. Run the function for the new
412
370
self.chunks = collections.deque()
413
371
self.error = False
414
372
self.error_in_progress = None
416
374
def next_read_size(self):
417
375
# Note: the shortest possible chunk is 2 bytes: '0\n', and the
418
376
# end-of-body marker is 4 bytes: 'END\n'.
421
379
# the rest of this chunk plus an END chunk.
422
380
return self.bytes_left + 4
423
381
elif self.state_accept == self._state_accept_expecting_length:
424
if self._in_buffer_len == 0:
382
if self._in_buffer == '':
425
383
# We're expecting a chunk length. There's at least two bytes
426
384
# left: a digit plus '\n'.
432
390
elif self.state_accept == self._state_accept_reading_unused:
434
392
elif self.state_accept == self._state_accept_expecting_header:
435
return max(0, len('chunked\n') - self._in_buffer_len)
393
return max(0, len('chunked\n') - len(self._in_buffer))
437
395
raise AssertionError("Impossible state: %r" % (self.state_accept,))
445
403
def _extract_line(self):
446
in_buf = self._get_in_buffer()
447
pos = in_buf.find('\n')
404
pos = self._in_buffer.find('\n')
449
406
# We haven't read a complete line yet, so request more bytes before
451
408
raise _NeedMoreBytes(1)
409
line = self._in_buffer[:pos]
453
410
# Trim the prefix (including '\n' delimiter) from the _in_buffer.
454
self._set_in_buffer(in_buf[pos+1:])
411
self._in_buffer = self._in_buffer[pos+1:]
457
414
def _finished(self):
458
self.unused_data = self._get_in_buffer()
459
self._in_buffer_list = []
460
self._in_buffer_len = 0
415
self.unused_data = self._in_buffer
461
417
self.state_accept = self._state_accept_reading_unused
463
419
error_args = tuple(self.error_in_progress)
492
448
self.state_accept = self._state_accept_reading_chunk
494
450
def _state_accept_reading_chunk(self):
495
in_buf = self._get_in_buffer()
496
in_buffer_len = len(in_buf)
497
self.chunk_in_progress += in_buf[:self.bytes_left]
498
self._set_in_buffer(in_buf[self.bytes_left:])
451
in_buffer_len = len(self._in_buffer)
452
self.chunk_in_progress += self._in_buffer[:self.bytes_left]
453
self._in_buffer = self._in_buffer[self.bytes_left:]
499
454
self.bytes_left -= in_buffer_len
500
455
if self.bytes_left <= 0:
501
456
# Finished with chunk
506
461
self.chunks.append(self.chunk_in_progress)
507
462
self.chunk_in_progress = None
508
463
self.state_accept = self._state_accept_expecting_length
510
465
def _state_accept_reading_unused(self):
511
self.unused_data += self._get_in_buffer()
512
self._in_buffer_list = []
466
self.unused_data += self._in_buffer
515
470
class LengthPrefixedBodyDecoder(_StatefulDecoder):
516
471
"""Decodes the length-prefixed bulk data."""
518
473
def __init__(self):
519
474
_StatefulDecoder.__init__(self)
520
475
self.state_accept = self._state_accept_expecting_length
521
476
self.state_read = self._state_read_no_data
523
478
self._trailer_buffer = ''
525
480
def next_read_size(self):
526
481
if self.bytes_left is not None:
527
482
# Ideally we want to read all the remainder of the body and the
538
493
# Reading excess data. Either way, 1 byte at a time is fine.
541
496
def read_pending_data(self):
542
497
"""Return any pending data that has been decoded."""
543
498
return self.state_read()
545
500
def _state_accept_expecting_length(self):
546
in_buf = self._get_in_buffer()
547
pos = in_buf.find('\n')
501
pos = self._in_buffer.find('\n')
550
self.bytes_left = int(in_buf[:pos])
551
self._set_in_buffer(in_buf[pos+1:])
504
self.bytes_left = int(self._in_buffer[:pos])
505
self._in_buffer = self._in_buffer[pos+1:]
552
506
self.state_accept = self._state_accept_reading_body
553
507
self.state_read = self._state_read_body_buffer
555
509
def _state_accept_reading_body(self):
556
in_buf = self._get_in_buffer()
558
self.bytes_left -= len(in_buf)
559
self._set_in_buffer(None)
510
self._body += self._in_buffer
511
self.bytes_left -= len(self._in_buffer)
560
513
if self.bytes_left <= 0:
561
514
# Finished with body
562
515
if self.bytes_left != 0:
564
517
self._body = self._body[:self.bytes_left]
565
518
self.bytes_left = None
566
519
self.state_accept = self._state_accept_reading_trailer
568
521
def _state_accept_reading_trailer(self):
569
self._trailer_buffer += self._get_in_buffer()
570
self._set_in_buffer(None)
522
self._trailer_buffer += self._in_buffer
571
524
# TODO: what if the trailer does not match "done\n"? Should this raise
572
525
# a ProtocolViolation exception?
573
526
if self._trailer_buffer.startswith('done\n'):
574
527
self.unused_data = self._trailer_buffer[len('done\n'):]
575
528
self.state_accept = self._state_accept_reading_unused
576
529
self.finished_reading = True
578
531
def _state_accept_reading_unused(self):
579
self.unused_data += self._get_in_buffer()
580
self._set_in_buffer(None)
532
self.unused_data += self._in_buffer
582
535
def _state_read_no_data(self):
656
609
mutter(' %d bytes in readv request', len(readv_bytes))
657
610
self._last_verb = args[0]
659
def call_with_body_stream(self, args, stream):
660
# Protocols v1 and v2 don't support body streams. So it's safe to
661
# assume that a v1/v2 server doesn't support whatever method we're
662
# trying to call with a body stream.
663
self._request.finished_writing()
664
self._request.finished_reading()
665
raise errors.UnknownSmartMethod(args[0])
667
612
def cancel_read_body(self):
668
613
"""After expecting a body, a response code may indicate one otherwise.
729
674
def _response_is_unknown_method(self, result_tuple):
730
675
"""Raise UnexpectedSmartServerResponse if the response is an 'unknown
731
676
method' response to the request.
733
678
:param response: The response from a smart client call_expecting_body
735
680
:param verb: The verb used in that call.
742
687
# The response will have no body, so we've finished reading.
743
688
self._request.finished_reading()
744
689
raise errors.UnknownSmartMethod(self._last_verb)
746
691
def read_body_bytes(self, count=-1):
747
692
"""Read bytes from the body, decoding into a byte stream.
749
We read all bytes at once to ensure we've checked the trailer for
694
We read all bytes at once to ensure we've checked the trailer for
750
695
errors, and then feed the buffer back as read_body_bytes is called.
752
697
if self._body_buffer is not None:
791
736
def _write_protocol_version(self):
792
737
"""Write any prefixes this protocol requires.
794
739
Version one doesn't send protocol versions.
798
743
class SmartClientRequestProtocolTwo(SmartClientRequestProtocolOne):
799
744
"""Version two of the client side of the smart protocol.
801
746
This prefixes the request with the value of REQUEST_VERSION_TWO.
897
842
# We do *not* set self.decoding_failed here. The message handler
898
843
# has raised an error, but the decoder is still able to parse bytes
899
844
# and determine when this message ends.
900
if not isinstance(exception.exc_value, errors.UnknownSmartMethod):
901
log_exception_quietly()
845
log_exception_quietly()
902
846
self.message_handler.protocol_error(exception.exc_value)
903
847
# The state machine is ready to continue decoding, but the
904
848
# exception has interrupted the loop that runs the state machine.
921
865
self.message_handler.protocol_error(exception)
923
867
def _extract_length_prefixed_bytes(self):
924
if self._in_buffer_len < 4:
868
if len(self._in_buffer) < 4:
925
869
# A length prefix by itself is 4 bytes, and we don't even have that
927
871
raise _NeedMoreBytes(4)
928
(length,) = struct.unpack('!L', self._get_in_bytes(4))
872
(length,) = struct.unpack('!L', self._in_buffer[:4])
929
873
end_of_bytes = 4 + length
930
if self._in_buffer_len < end_of_bytes:
874
if len(self._in_buffer) < end_of_bytes:
931
875
# We haven't yet read as many bytes as the length-prefix says there
933
877
raise _NeedMoreBytes(end_of_bytes)
934
878
# Extract the bytes from the buffer.
935
in_buf = self._get_in_buffer()
936
bytes = in_buf[4:end_of_bytes]
937
self._set_in_buffer(in_buf[end_of_bytes:])
879
bytes = self._in_buffer[4:end_of_bytes]
880
self._in_buffer = self._in_buffer[end_of_bytes:]
940
883
def _extract_prefixed_bencoded_data(self):
941
884
prefixed_bytes = self._extract_length_prefixed_bytes()
943
decoded = bdecode_as_tuple(prefixed_bytes)
886
decoded = bdecode(prefixed_bytes)
944
887
except ValueError:
945
888
raise errors.SmartProtocolError(
946
889
'Bytes %r not bencoded' % (prefixed_bytes,))
949
892
def _extract_single_byte(self):
950
if self._in_buffer_len == 0:
893
if self._in_buffer == '':
951
894
# The buffer is empty
952
895
raise _NeedMoreBytes(1)
953
in_buf = self._get_in_buffer()
955
self._set_in_buffer(in_buf[1:])
896
one_byte = self._in_buffer[0]
897
self._in_buffer = self._in_buffer[1:]
958
900
def _state_accept_expecting_protocol_version(self):
959
needed_bytes = len(MESSAGE_VERSION_THREE) - self._in_buffer_len
960
in_buf = self._get_in_buffer()
901
needed_bytes = len(MESSAGE_VERSION_THREE) - len(self._in_buffer)
961
902
if needed_bytes > 0:
962
903
# We don't have enough bytes to check if the protocol version
963
904
# marker is right. But we can check if it is already wrong by
967
908
# len(MESSAGE_VERSION_THREE) bytes. So if the bytes we have so far
968
909
# are wrong then we should just raise immediately rather than
970
if not MESSAGE_VERSION_THREE.startswith(in_buf):
911
if not MESSAGE_VERSION_THREE.startswith(self._in_buffer):
971
912
# We have enough bytes to know the protocol version is wrong
972
raise errors.UnexpectedProtocolVersionMarker(in_buf)
913
raise errors.UnexpectedProtocolVersionMarker(self._in_buffer)
973
914
raise _NeedMoreBytes(len(MESSAGE_VERSION_THREE))
974
if not in_buf.startswith(MESSAGE_VERSION_THREE):
975
raise errors.UnexpectedProtocolVersionMarker(in_buf)
976
self._set_in_buffer(in_buf[len(MESSAGE_VERSION_THREE):])
915
if not self._in_buffer.startswith(MESSAGE_VERSION_THREE):
916
raise errors.UnexpectedProtocolVersionMarker(self._in_buffer)
917
self._in_buffer = self._in_buffer[len(MESSAGE_VERSION_THREE):]
977
918
self.state_accept = self._state_accept_expecting_headers
979
920
def _state_accept_expecting_headers(self):
986
927
self.message_handler.headers_received(decoded)
988
929
raise errors.SmartMessageHandlerError(sys.exc_info())
990
931
def _state_accept_expecting_message_part(self):
991
932
message_part_kind = self._extract_single_byte()
992
933
if message_part_kind == 'o':
1028
969
raise errors.SmartMessageHandlerError(sys.exc_info())
1031
self.unused_data = self._get_in_buffer()
1032
self._set_in_buffer(None)
972
self.unused_data = self._in_buffer
1033
974
self.state_accept = self._state_accept_reading_unused
1035
976
self.message_handler.end_received()
1037
978
raise errors.SmartMessageHandlerError(sys.exc_info())
1039
980
def _state_accept_reading_unused(self):
1040
self.unused_data += self._get_in_buffer()
1041
self._set_in_buffer(None)
981
self.unused_data += self._in_buffer
1043
984
def next_read_size(self):
1044
985
if self.state_accept == self._state_accept_reading_unused:
1061
1002
response_marker = request_marker = MESSAGE_VERSION_THREE
1063
1004
def __init__(self, write_func):
1065
1006
self._real_write_func = write_func
1067
1008
def _write_func(self, bytes):
1068
self._buf.append(bytes)
1069
if len(self._buf) > 100:
1072
1011
def flush(self):
1074
self._real_write_func(''.join(self._buf))
1013
self._real_write_func(self._buf)
1077
1016
def _serialise_offsets(self, offsets):
1078
1017
"""Serialise a readv offset list."""
1080
1019
for start, length in offsets:
1081
1020
txt.append('%d,%d' % (start, length))
1082
1021
return '\n'.join(txt)
1084
1023
def _write_protocol_version(self):
1085
1024
self._write_func(MESSAGE_VERSION_THREE)
1161
1097
if response.body is not None:
1162
1098
self._write_prefixed_body(response.body)
1163
1099
elif response.body_stream is not None:
1164
for exc_info, chunk in _iter_with_errors(response.body_stream):
1165
if exc_info is not None:
1166
self._write_error_status()
1167
error_struct = request._translate_error(exc_info[1])
1168
self._write_structure(error_struct)
1171
if isinstance(chunk, request.FailedSmartServerResponse):
1172
self._write_error_status()
1173
self._write_structure(chunk.args)
1175
self._write_prefixed_body(chunk)
1100
for chunk in response.body_stream:
1101
self._write_prefixed_body(chunk)
1176
1103
self._write_end()
1179
def _iter_with_errors(iterable):
1180
"""Handle errors from iterable.next().
1184
for exc_info, value in _iter_with_errors(iterable):
1187
This is a safer alternative to::
1190
for value in iterable:
1195
Because the latter will catch errors from the for-loop body, not just
1198
If an error occurs, exc_info will be an exc_info tuple, and the generator
1199
will terminate. Otherwise exc_info will be None, and value will be the
1200
value from iterable.next(). Note that KeyboardInterrupt and SystemExit
1201
will not be intercepted.
1203
iterator = iter(iterable)
1206
yield None, iterator.next()
1207
except StopIteration:
1209
except (KeyboardInterrupt, SystemExit):
1212
yield sys.exc_info(), None
1216
1106
class ProtocolThreeRequester(_ProtocolThreeEncoder, Requester):
1223
1113
def set_headers(self, headers):
1224
1114
self._headers = headers.copy()
1226
1116
def call(self, *args):
1227
1117
if 'hpss' in debug.debug_flags:
1228
1118
mutter('hpss call: %s', repr(args)[1:-1])
1277
1167
self._write_end()
1278
1168
self._medium_request.finished_writing()
1280
def call_with_body_stream(self, args, stream):
1281
if 'hpss' in debug.debug_flags:
1282
mutter('hpss call w/body stream: %r', args)
1283
path = getattr(self._medium_request._medium, '_path', None)
1284
if path is not None:
1285
mutter(' (to %s)', path)
1286
self._request_start_time = time.time()
1287
self._write_protocol_version()
1288
self._write_headers(self._headers)
1289
self._write_structure(args)
1290
# TODO: notice if the server has sent an early error reply before we
1291
# have finished sending the stream. We would notice at the end
1292
# anyway, but if the medium can deliver it early then it's good
1293
# to short-circuit the whole request...
1294
for exc_info, part in _iter_with_errors(stream):
1295
if exc_info is not None:
1296
# Iterating the stream failed. Cleanly abort the request.
1297
self._write_error_status()
1298
# Currently the client unconditionally sends ('error',) as the
1300
self._write_structure(('error',))
1302
self._medium_request.finished_writing()
1303
raise exc_info[0], exc_info[1], exc_info[2]
1305
self._write_prefixed_body(part)
1308
self._medium_request.finished_writing()