13
13
# You should have received a copy of the GNU General Public License
14
14
# along with this program; if not, write to the Free Software
15
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
15
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17
17
"""Wire-level encoding and decoding of requests and responses for the smart
29
29
from bzrlib import errors
30
30
from bzrlib.smart import message, request
31
31
from bzrlib.trace import log_exception_quietly, mutter
32
from bzrlib.util.bencode import bdecode, bencode
32
from bzrlib.bencode import bdecode_as_tuple, bencode
35
35
# Protocol version strings. These are sent as prefixes of bzr requests and
109
109
for start, length in offsets:
110
110
txt.append('%d,%d' % (start, length))
111
111
return '\n'.join(txt)
114
114
class SmartServerRequestProtocolOne(SmartProtocolBase):
115
115
"""Server-side encoding and decoding logic for smart version 1."""
117
117
def __init__(self, backing_transport, write_func, root_client_path='/'):
118
118
self._backing_transport = backing_transport
119
119
self._root_client_path = root_client_path
128
128
def accept_bytes(self, bytes):
129
129
"""Take bytes, and advance the internal state machine appropriately.
131
131
:param bytes: must be a byte string
133
133
if not isinstance(bytes, str):
324
324
def __init__(self):
325
325
self.finished_reading = False
326
self._in_buffer_list = []
327
self._in_buffer_len = 0
327
328
self.unused_data = ''
328
329
self.bytes_left = None
329
330
self._number_needed_bytes = None
332
def _get_in_buffer(self):
333
if len(self._in_buffer_list) == 1:
334
return self._in_buffer_list[0]
335
in_buffer = ''.join(self._in_buffer_list)
336
if len(in_buffer) != self._in_buffer_len:
337
raise AssertionError(
338
"Length of buffer did not match expected value: %s != %s"
339
% self._in_buffer_len, len(in_buffer))
340
self._in_buffer_list = [in_buffer]
343
def _get_in_bytes(self, count):
344
"""Grab X bytes from the input_buffer.
346
Callers should have already checked that self._in_buffer_len is >
347
count. Note, this does not consume the bytes from the buffer. The
348
caller will still need to call _get_in_buffer() and then
349
_set_in_buffer() if they actually need to consume the bytes.
351
# check if we can yield the bytes from just the first entry in our list
352
if len(self._in_buffer_list) == 0:
353
raise AssertionError('Callers must be sure we have buffered bytes'
354
' before calling _get_in_bytes')
355
if len(self._in_buffer_list[0]) > count:
356
return self._in_buffer_list[0][:count]
357
# We can't yield it from the first buffer, so collapse all buffers, and
359
in_buf = self._get_in_buffer()
360
return in_buf[:count]
362
def _set_in_buffer(self, new_buf):
363
if new_buf is not None:
364
self._in_buffer_list = [new_buf]
365
self._in_buffer_len = len(new_buf)
367
self._in_buffer_list = []
368
self._in_buffer_len = 0
331
370
def accept_bytes(self, bytes):
332
371
"""Decode as much of bytes as possible.
338
377
data will be appended to self.unused_data.
340
379
# accept_bytes is allowed to change the state
341
current_state = self.state_accept
342
380
self._number_needed_bytes = None
343
self._in_buffer += bytes
381
# lsprof puts a very large amount of time on this specific call for
383
self._in_buffer_list.append(bytes)
384
self._in_buffer_len += len(bytes)
345
386
# Run the function for the current state.
387
current_state = self.state_accept
346
388
self.state_accept()
347
389
while current_state != self.state_accept:
348
390
# The current state has changed. Run the function for the new
370
412
self.chunks = collections.deque()
371
413
self.error = False
372
414
self.error_in_progress = None
374
416
def next_read_size(self):
375
417
# Note: the shortest possible chunk is 2 bytes: '0\n', and the
376
418
# end-of-body marker is 4 bytes: 'END\n'.
379
421
# the rest of this chunk plus an END chunk.
380
422
return self.bytes_left + 4
381
423
elif self.state_accept == self._state_accept_expecting_length:
382
if self._in_buffer == '':
424
if self._in_buffer_len == 0:
383
425
# We're expecting a chunk length. There's at least two bytes
384
426
# left: a digit plus '\n'.
390
432
elif self.state_accept == self._state_accept_reading_unused:
392
434
elif self.state_accept == self._state_accept_expecting_header:
393
return max(0, len('chunked\n') - len(self._in_buffer))
435
return max(0, len('chunked\n') - self._in_buffer_len)
395
437
raise AssertionError("Impossible state: %r" % (self.state_accept,))
403
445
def _extract_line(self):
404
pos = self._in_buffer.find('\n')
446
in_buf = self._get_in_buffer()
447
pos = in_buf.find('\n')
406
449
# We haven't read a complete line yet, so request more bytes before
408
451
raise _NeedMoreBytes(1)
409
line = self._in_buffer[:pos]
410
453
# Trim the prefix (including '\n' delimiter) from the _in_buffer.
411
self._in_buffer = self._in_buffer[pos+1:]
454
self._set_in_buffer(in_buf[pos+1:])
414
457
def _finished(self):
415
self.unused_data = self._in_buffer
458
self.unused_data = self._get_in_buffer()
459
self._in_buffer_list = []
460
self._in_buffer_len = 0
417
461
self.state_accept = self._state_accept_reading_unused
419
463
error_args = tuple(self.error_in_progress)
448
492
self.state_accept = self._state_accept_reading_chunk
450
494
def _state_accept_reading_chunk(self):
451
in_buffer_len = len(self._in_buffer)
452
self.chunk_in_progress += self._in_buffer[:self.bytes_left]
453
self._in_buffer = self._in_buffer[self.bytes_left:]
495
in_buf = self._get_in_buffer()
496
in_buffer_len = len(in_buf)
497
self.chunk_in_progress += in_buf[:self.bytes_left]
498
self._set_in_buffer(in_buf[self.bytes_left:])
454
499
self.bytes_left -= in_buffer_len
455
500
if self.bytes_left <= 0:
456
501
# Finished with chunk
461
506
self.chunks.append(self.chunk_in_progress)
462
507
self.chunk_in_progress = None
463
508
self.state_accept = self._state_accept_expecting_length
465
510
def _state_accept_reading_unused(self):
466
self.unused_data += self._in_buffer
511
self.unused_data += self._get_in_buffer()
512
self._in_buffer_list = []
470
515
class LengthPrefixedBodyDecoder(_StatefulDecoder):
471
516
"""Decodes the length-prefixed bulk data."""
473
518
def __init__(self):
474
519
_StatefulDecoder.__init__(self)
475
520
self.state_accept = self._state_accept_expecting_length
476
521
self.state_read = self._state_read_no_data
478
523
self._trailer_buffer = ''
480
525
def next_read_size(self):
481
526
if self.bytes_left is not None:
482
527
# Ideally we want to read all the remainder of the body and the
493
538
# Reading excess data. Either way, 1 byte at a time is fine.
496
541
def read_pending_data(self):
497
542
"""Return any pending data that has been decoded."""
498
543
return self.state_read()
500
545
def _state_accept_expecting_length(self):
501
pos = self._in_buffer.find('\n')
546
in_buf = self._get_in_buffer()
547
pos = in_buf.find('\n')
504
self.bytes_left = int(self._in_buffer[:pos])
505
self._in_buffer = self._in_buffer[pos+1:]
550
self.bytes_left = int(in_buf[:pos])
551
self._set_in_buffer(in_buf[pos+1:])
506
552
self.state_accept = self._state_accept_reading_body
507
553
self.state_read = self._state_read_body_buffer
509
555
def _state_accept_reading_body(self):
510
self._body += self._in_buffer
511
self.bytes_left -= len(self._in_buffer)
556
in_buf = self._get_in_buffer()
558
self.bytes_left -= len(in_buf)
559
self._set_in_buffer(None)
513
560
if self.bytes_left <= 0:
514
561
# Finished with body
515
562
if self.bytes_left != 0:
517
564
self._body = self._body[:self.bytes_left]
518
565
self.bytes_left = None
519
566
self.state_accept = self._state_accept_reading_trailer
521
568
def _state_accept_reading_trailer(self):
522
self._trailer_buffer += self._in_buffer
569
self._trailer_buffer += self._get_in_buffer()
570
self._set_in_buffer(None)
524
571
# TODO: what if the trailer does not match "done\n"? Should this raise
525
572
# a ProtocolViolation exception?
526
573
if self._trailer_buffer.startswith('done\n'):
527
574
self.unused_data = self._trailer_buffer[len('done\n'):]
528
575
self.state_accept = self._state_accept_reading_unused
529
576
self.finished_reading = True
531
578
def _state_accept_reading_unused(self):
532
self.unused_data += self._in_buffer
579
self.unused_data += self._get_in_buffer()
580
self._set_in_buffer(None)
535
582
def _state_read_no_data(self):
609
656
mutter(' %d bytes in readv request', len(readv_bytes))
610
657
self._last_verb = args[0]
659
def call_with_body_stream(self, args, stream):
660
# Protocols v1 and v2 don't support body streams. So it's safe to
661
# assume that a v1/v2 server doesn't support whatever method we're
662
# trying to call with a body stream.
663
self._request.finished_writing()
664
self._request.finished_reading()
665
raise errors.UnknownSmartMethod(args[0])
612
667
def cancel_read_body(self):
613
668
"""After expecting a body, a response code may indicate one otherwise.
674
729
def _response_is_unknown_method(self, result_tuple):
675
730
"""Raise UnexpectedSmartServerResponse if the response is an 'unknonwn
676
731
method' response to the request.
678
733
:param response: The response from a smart client call_expecting_body
680
735
:param verb: The verb used in that call.
687
742
# The response will have no body, so we've finished reading.
688
743
self._request.finished_reading()
689
744
raise errors.UnknownSmartMethod(self._last_verb)
691
746
def read_body_bytes(self, count=-1):
692
747
"""Read bytes from the body, decoding into a byte stream.
694
We read all bytes at once to ensure we've checked the trailer for
749
We read all bytes at once to ensure we've checked the trailer for
695
750
errors, and then feed the buffer back as read_body_bytes is called.
697
752
if self._body_buffer is not None:
698
753
return self._body_buffer.read(count)
699
754
_body_decoder = LengthPrefixedBodyDecoder()
701
# Read no more than 64k at a time so that we don't risk error 10055 (no
702
# buffer space available) on Windows.
704
756
while not _body_decoder.finished_reading:
705
bytes_wanted = min(_body_decoder.next_read_size(), max_read)
706
bytes = self._request.read_bytes(bytes_wanted)
757
bytes = self._request.read_bytes(_body_decoder.next_read_size())
708
759
# end of file encountered reading from server
709
760
raise errors.ConnectionReset(
720
771
def _recv_tuple(self):
721
772
"""Receive a tuple from the medium request."""
722
return _decode_tuple(self._recv_line())
724
def _recv_line(self):
725
"""Read an entire line from the medium request."""
727
while not line or line[-1] != '\n':
728
# TODO: this is inefficient - but tuples are short.
729
new_char = self._request.read_bytes(1)
731
# end of file encountered reading from server
732
raise errors.ConnectionReset(
733
"please check connectivity and permissions",
734
"(and try -Dhpss if further diagnosis is required)")
773
return _decode_tuple(self._request.read_line())
738
775
def query_version(self):
739
776
"""Return protocol version number of the server."""
754
791
def _write_protocol_version(self):
755
792
"""Write any prefixes this protocol requires.
757
794
Version one doesn't send protocol versions.
761
798
class SmartClientRequestProtocolTwo(SmartClientRequestProtocolOne):
762
799
"""Version two of the client side of the smart protocol.
764
801
This prefixes the request with the value of REQUEST_VERSION_TWO.
776
813
if version != self.response_marker:
777
814
self._request.finished_reading()
778
815
raise errors.UnexpectedProtocolVersionMarker(version)
779
response_status = self._recv_line()
816
response_status = self._request.read_line()
780
817
result = SmartClientRequestProtocolOne._read_response_tuple(self)
781
818
self._response_is_unknown_method(result)
782
819
if response_status == 'success\n':
805
842
# Read no more than 64k at a time so that we don't risk error 10055 (no
806
843
# buffer space available) on Windows.
808
844
_body_decoder = ChunkedBodyDecoder()
809
845
while not _body_decoder.finished_reading:
810
bytes_wanted = min(_body_decoder.next_read_size(), max_read)
811
bytes = self._request.read_bytes(bytes_wanted)
846
bytes = self._request.read_bytes(_body_decoder.next_read_size())
813
848
# end of file encountered reading from server
814
849
raise errors.ConnectionReset(
862
897
# We do *not* set self.decoding_failed here. The message handler
863
898
# has raised an error, but the decoder is still able to parse bytes
864
899
# and determine when this message ends.
865
log_exception_quietly()
900
if not isinstance(exception.exc_value, errors.UnknownSmartMethod):
901
log_exception_quietly()
866
902
self.message_handler.protocol_error(exception.exc_value)
867
903
# The state machine is ready to continue decoding, but the
868
904
# exception has interrupted the loop that runs the state machine.
885
921
self.message_handler.protocol_error(exception)
887
923
def _extract_length_prefixed_bytes(self):
888
if len(self._in_buffer) < 4:
924
if self._in_buffer_len < 4:
889
925
# A length prefix by itself is 4 bytes, and we don't even have that
891
927
raise _NeedMoreBytes(4)
892
(length,) = struct.unpack('!L', self._in_buffer[:4])
928
(length,) = struct.unpack('!L', self._get_in_bytes(4))
893
929
end_of_bytes = 4 + length
894
if len(self._in_buffer) < end_of_bytes:
930
if self._in_buffer_len < end_of_bytes:
895
931
# We haven't yet read as many bytes as the length-prefix says there
897
933
raise _NeedMoreBytes(end_of_bytes)
898
934
# Extract the bytes from the buffer.
899
bytes = self._in_buffer[4:end_of_bytes]
900
self._in_buffer = self._in_buffer[end_of_bytes:]
935
in_buf = self._get_in_buffer()
936
bytes = in_buf[4:end_of_bytes]
937
self._set_in_buffer(in_buf[end_of_bytes:])
903
940
def _extract_prefixed_bencoded_data(self):
904
941
prefixed_bytes = self._extract_length_prefixed_bytes()
906
decoded = bdecode(prefixed_bytes)
943
decoded = bdecode_as_tuple(prefixed_bytes)
907
944
except ValueError:
908
945
raise errors.SmartProtocolError(
909
946
'Bytes %r not bencoded' % (prefixed_bytes,))
912
949
def _extract_single_byte(self):
913
if self._in_buffer == '':
950
if self._in_buffer_len == 0:
914
951
# The buffer is empty
915
952
raise _NeedMoreBytes(1)
916
one_byte = self._in_buffer[0]
917
self._in_buffer = self._in_buffer[1:]
953
in_buf = self._get_in_buffer()
955
self._set_in_buffer(in_buf[1:])
920
958
def _state_accept_expecting_protocol_version(self):
921
needed_bytes = len(MESSAGE_VERSION_THREE) - len(self._in_buffer)
959
needed_bytes = len(MESSAGE_VERSION_THREE) - self._in_buffer_len
960
in_buf = self._get_in_buffer()
922
961
if needed_bytes > 0:
923
962
# We don't have enough bytes to check if the protocol version
924
963
# marker is right. But we can check if it is already wrong by
928
967
# len(MESSAGE_VERSION_THREE) bytes. So if the bytes we have so far
929
968
# are wrong then we should just raise immediately rather than
931
if not MESSAGE_VERSION_THREE.startswith(self._in_buffer):
970
if not MESSAGE_VERSION_THREE.startswith(in_buf):
932
971
# We have enough bytes to know the protocol version is wrong
933
raise errors.UnexpectedProtocolVersionMarker(self._in_buffer)
972
raise errors.UnexpectedProtocolVersionMarker(in_buf)
934
973
raise _NeedMoreBytes(len(MESSAGE_VERSION_THREE))
935
if not self._in_buffer.startswith(MESSAGE_VERSION_THREE):
936
raise errors.UnexpectedProtocolVersionMarker(self._in_buffer)
937
self._in_buffer = self._in_buffer[len(MESSAGE_VERSION_THREE):]
974
if not in_buf.startswith(MESSAGE_VERSION_THREE):
975
raise errors.UnexpectedProtocolVersionMarker(in_buf)
976
self._set_in_buffer(in_buf[len(MESSAGE_VERSION_THREE):])
938
977
self.state_accept = self._state_accept_expecting_headers
940
979
def _state_accept_expecting_headers(self):
947
986
self.message_handler.headers_received(decoded)
949
988
raise errors.SmartMessageHandlerError(sys.exc_info())
951
990
def _state_accept_expecting_message_part(self):
952
991
message_part_kind = self._extract_single_byte()
953
992
if message_part_kind == 'o':
989
1028
raise errors.SmartMessageHandlerError(sys.exc_info())
992
self.unused_data = self._in_buffer
1031
self.unused_data = self._get_in_buffer()
1032
self._set_in_buffer(None)
994
1033
self.state_accept = self._state_accept_reading_unused
996
1035
self.message_handler.end_received()
998
1037
raise errors.SmartMessageHandlerError(sys.exc_info())
1000
1039
def _state_accept_reading_unused(self):
1001
self.unused_data += self._in_buffer
1002
self._in_buffer = ''
1040
self.unused_data += self._get_in_buffer()
1041
self._set_in_buffer(None)
1004
1043
def next_read_size(self):
1005
1044
if self.state_accept == self._state_accept_reading_unused:
1022
1061
response_marker = request_marker = MESSAGE_VERSION_THREE
1024
1063
def __init__(self, write_func):
1026
1065
self._real_write_func = write_func
1028
1067
def _write_func(self, bytes):
1068
self._buf.append(bytes)
1069
if len(self._buf) > 100:
1031
1072
def flush(self):
1033
self._real_write_func(self._buf)
1074
self._real_write_func(''.join(self._buf))
1036
1077
def _serialise_offsets(self, offsets):
1037
1078
"""Serialise a readv offset list."""
1039
1080
for start, length in offsets:
1040
1081
txt.append('%d,%d' % (start, length))
1041
1082
return '\n'.join(txt)
1043
1084
def _write_protocol_version(self):
1044
1085
self._write_func(MESSAGE_VERSION_THREE)
1117
1161
if response.body is not None:
1118
1162
self._write_prefixed_body(response.body)
1119
1163
elif response.body_stream is not None:
1120
for chunk in response.body_stream:
1121
self._write_prefixed_body(chunk)
1164
for exc_info, chunk in _iter_with_errors(response.body_stream):
1165
if exc_info is not None:
1166
self._write_error_status()
1167
error_struct = request._translate_error(exc_info[1])
1168
self._write_structure(error_struct)
1171
if isinstance(chunk, request.FailedSmartServerResponse):
1172
self._write_error_status()
1173
self._write_structure(chunk.args)
1175
self._write_prefixed_body(chunk)
1123
1176
self._write_end()
1179
def _iter_with_errors(iterable):
1180
"""Handle errors from iterable.next().
1184
for exc_info, value in _iter_with_errors(iterable):
1187
This is a safer alternative to::
1190
for value in iterable:
1195
Because the latter will catch errors from the for-loop body, not just
1198
If an error occurs, exc_info will be a exc_info tuple, and the generator
1199
will terminate. Otherwise exc_info will be None, and value will be the
1200
value from iterable.next(). Note that KeyboardInterrupt and SystemExit
1201
will not be itercepted.
1203
iterator = iter(iterable)
1206
yield None, iterator.next()
1207
except StopIteration:
1209
except (KeyboardInterrupt, SystemExit):
1212
mutter('_iter_with_errors caught error')
1213
log_exception_quietly()
1214
yield sys.exc_info(), None
1126
1218
class ProtocolThreeRequester(_ProtocolThreeEncoder, Requester):
1133
1225
def set_headers(self, headers):
1134
1226
self._headers = headers.copy()
1136
1228
def call(self, *args):
1137
1229
if 'hpss' in debug.debug_flags:
1138
1230
mutter('hpss call: %s', repr(args)[1:-1])
1187
1279
self._write_end()
1188
1280
self._medium_request.finished_writing()
1282
def call_with_body_stream(self, args, stream):
1283
if 'hpss' in debug.debug_flags:
1284
mutter('hpss call w/body stream: %r', args)
1285
path = getattr(self._medium_request._medium, '_path', None)
1286
if path is not None:
1287
mutter(' (to %s)', path)
1288
self._request_start_time = time.time()
1289
self._write_protocol_version()
1290
self._write_headers(self._headers)
1291
self._write_structure(args)
1292
# TODO: notice if the server has sent an early error reply before we
1293
# have finished sending the stream. We would notice at the end
1294
# anyway, but if the medium can deliver it early then it's good
1295
# to short-circuit the whole request...
1296
for exc_info, part in _iter_with_errors(stream):
1297
if exc_info is not None:
1298
# Iterating the stream failed. Cleanly abort the request.
1299
self._write_error_status()
1300
# Currently the client unconditionally sends ('error',) as the
1302
self._write_structure(('error',))
1304
self._medium_request.finished_writing()
1305
raise exc_info[0], exc_info[1], exc_info[2]
1307
self._write_prefixed_body(part)
1310
self._medium_request.finished_writing()