13
13
# You should have received a copy of the GNU General Public License
14
14
# along with this program; if not, write to the Free Software
15
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
15
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17
17
"""Wire-level encoding and decoding of requests and responses for the smart
29
29
from bzrlib import errors
30
30
from bzrlib.smart import message, request
31
31
from bzrlib.trace import log_exception_quietly, mutter
32
from bzrlib.util.bencode import bdecode, bencode
32
from bzrlib.bencode import bdecode_as_tuple, bencode
35
35
# Protocol version strings. These are sent as prefixes of bzr requests and
109
109
for start, length in offsets:
110
110
txt.append('%d,%d' % (start, length))
111
111
return '\n'.join(txt)
114
114
class SmartServerRequestProtocolOne(SmartProtocolBase):
115
115
"""Server-side encoding and decoding logic for smart version 1."""
117
117
def __init__(self, backing_transport, write_func, root_client_path='/'):
118
118
self._backing_transport = backing_transport
119
119
self._root_client_path = root_client_path
128
128
def accept_bytes(self, bytes):
129
129
"""Take bytes, and advance the internal state machine appropriately.
131
131
:param bytes: must be a byte string
133
133
if not isinstance(bytes, str):
324
324
def __init__(self):
325
325
self.finished_reading = False
326
self._in_buffer_list = []
327
self._in_buffer_len = 0
327
328
self.unused_data = ''
328
329
self.bytes_left = None
329
330
self._number_needed_bytes = None
332
def _get_in_buffer(self):
333
if len(self._in_buffer_list) == 1:
334
return self._in_buffer_list[0]
335
in_buffer = ''.join(self._in_buffer_list)
336
if len(in_buffer) != self._in_buffer_len:
337
raise AssertionError(
338
"Length of buffer did not match expected value: %s != %s"
339
% self._in_buffer_len, len(in_buffer))
340
self._in_buffer_list = [in_buffer]
343
def _get_in_bytes(self, count):
344
"""Grab X bytes from the input_buffer.
346
Callers should have already checked that self._in_buffer_len is >
347
count. Note, this does not consume the bytes from the buffer. The
348
caller will still need to call _get_in_buffer() and then
349
_set_in_buffer() if they actually need to consume the bytes.
351
# check if we can yield the bytes from just the first entry in our list
352
if len(self._in_buffer_list) == 0:
353
raise AssertionError('Callers must be sure we have buffered bytes'
354
' before calling _get_in_bytes')
355
if len(self._in_buffer_list[0]) > count:
356
return self._in_buffer_list[0][:count]
357
# We can't yield it from the first buffer, so collapse all buffers, and
359
in_buf = self._get_in_buffer()
360
return in_buf[:count]
362
def _set_in_buffer(self, new_buf):
363
if new_buf is not None:
364
self._in_buffer_list = [new_buf]
365
self._in_buffer_len = len(new_buf)
367
self._in_buffer_list = []
368
self._in_buffer_len = 0
331
370
def accept_bytes(self, bytes):
332
371
"""Decode as much of bytes as possible.
338
377
data will be appended to self.unused_data.
340
379
# accept_bytes is allowed to change the state
341
current_state = self.state_accept
342
380
self._number_needed_bytes = None
343
self._in_buffer += bytes
381
# lsprof puts a very large amount of time on this specific call for
383
self._in_buffer_list.append(bytes)
384
self._in_buffer_len += len(bytes)
345
386
# Run the function for the current state.
387
current_state = self.state_accept
346
388
self.state_accept()
347
389
while current_state != self.state_accept:
348
390
# The current state has changed. Run the function for the new
370
412
self.chunks = collections.deque()
371
413
self.error = False
372
414
self.error_in_progress = None
374
416
def next_read_size(self):
375
417
# Note: the shortest possible chunk is 2 bytes: '0\n', and the
376
418
# end-of-body marker is 4 bytes: 'END\n'.
379
421
# the rest of this chunk plus an END chunk.
380
422
return self.bytes_left + 4
381
423
elif self.state_accept == self._state_accept_expecting_length:
382
if self._in_buffer == '':
424
if self._in_buffer_len == 0:
383
425
# We're expecting a chunk length. There's at least two bytes
384
426
# left: a digit plus '\n'.
390
432
elif self.state_accept == self._state_accept_reading_unused:
392
434
elif self.state_accept == self._state_accept_expecting_header:
393
return max(0, len('chunked\n') - len(self._in_buffer))
435
return max(0, len('chunked\n') - self._in_buffer_len)
395
437
raise AssertionError("Impossible state: %r" % (self.state_accept,))
403
445
def _extract_line(self):
404
pos = self._in_buffer.find('\n')
446
in_buf = self._get_in_buffer()
447
pos = in_buf.find('\n')
406
449
# We haven't read a complete line yet, so request more bytes before
408
451
raise _NeedMoreBytes(1)
409
line = self._in_buffer[:pos]
410
453
# Trim the prefix (including '\n' delimiter) from the _in_buffer.
411
self._in_buffer = self._in_buffer[pos+1:]
454
self._set_in_buffer(in_buf[pos+1:])
414
457
def _finished(self):
415
self.unused_data = self._in_buffer
458
self.unused_data = self._get_in_buffer()
459
self._in_buffer_list = []
460
self._in_buffer_len = 0
417
461
self.state_accept = self._state_accept_reading_unused
419
463
error_args = tuple(self.error_in_progress)
448
492
self.state_accept = self._state_accept_reading_chunk
450
494
def _state_accept_reading_chunk(self):
451
in_buffer_len = len(self._in_buffer)
452
self.chunk_in_progress += self._in_buffer[:self.bytes_left]
453
self._in_buffer = self._in_buffer[self.bytes_left:]
495
in_buf = self._get_in_buffer()
496
in_buffer_len = len(in_buf)
497
self.chunk_in_progress += in_buf[:self.bytes_left]
498
self._set_in_buffer(in_buf[self.bytes_left:])
454
499
self.bytes_left -= in_buffer_len
455
500
if self.bytes_left <= 0:
456
501
# Finished with chunk
461
506
self.chunks.append(self.chunk_in_progress)
462
507
self.chunk_in_progress = None
463
508
self.state_accept = self._state_accept_expecting_length
465
510
def _state_accept_reading_unused(self):
466
self.unused_data += self._in_buffer
511
self.unused_data += self._get_in_buffer()
512
self._in_buffer_list = []
470
515
class LengthPrefixedBodyDecoder(_StatefulDecoder):
471
516
"""Decodes the length-prefixed bulk data."""
473
518
def __init__(self):
474
519
_StatefulDecoder.__init__(self)
475
520
self.state_accept = self._state_accept_expecting_length
476
521
self.state_read = self._state_read_no_data
478
523
self._trailer_buffer = ''
480
525
def next_read_size(self):
481
526
if self.bytes_left is not None:
482
527
# Ideally we want to read all the remainder of the body and the
493
538
# Reading excess data. Either way, 1 byte at a time is fine.
496
541
def read_pending_data(self):
497
542
"""Return any pending data that has been decoded."""
498
543
return self.state_read()
500
545
def _state_accept_expecting_length(self):
501
pos = self._in_buffer.find('\n')
546
in_buf = self._get_in_buffer()
547
pos = in_buf.find('\n')
504
self.bytes_left = int(self._in_buffer[:pos])
505
self._in_buffer = self._in_buffer[pos+1:]
550
self.bytes_left = int(in_buf[:pos])
551
self._set_in_buffer(in_buf[pos+1:])
506
552
self.state_accept = self._state_accept_reading_body
507
553
self.state_read = self._state_read_body_buffer
509
555
def _state_accept_reading_body(self):
510
self._body += self._in_buffer
511
self.bytes_left -= len(self._in_buffer)
556
in_buf = self._get_in_buffer()
558
self.bytes_left -= len(in_buf)
559
self._set_in_buffer(None)
513
560
if self.bytes_left <= 0:
514
561
# Finished with body
515
562
if self.bytes_left != 0:
517
564
self._body = self._body[:self.bytes_left]
518
565
self.bytes_left = None
519
566
self.state_accept = self._state_accept_reading_trailer
521
568
def _state_accept_reading_trailer(self):
522
self._trailer_buffer += self._in_buffer
569
self._trailer_buffer += self._get_in_buffer()
570
self._set_in_buffer(None)
524
571
# TODO: what if the trailer does not match "done\n"? Should this raise
525
572
# a ProtocolViolation exception?
526
573
if self._trailer_buffer.startswith('done\n'):
527
574
self.unused_data = self._trailer_buffer[len('done\n'):]
528
575
self.state_accept = self._state_accept_reading_unused
529
576
self.finished_reading = True
531
578
def _state_accept_reading_unused(self):
532
self.unused_data += self._in_buffer
579
self.unused_data += self._get_in_buffer()
580
self._set_in_buffer(None)
535
582
def _state_read_no_data(self):
609
656
mutter(' %d bytes in readv request', len(readv_bytes))
610
657
self._last_verb = args[0]
659
def call_with_body_stream(self, args, stream):
660
# Protocols v1 and v2 don't support body streams. So it's safe to
661
# assume that a v1/v2 server doesn't support whatever method we're
662
# trying to call with a body stream.
663
self._request.finished_writing()
664
self._request.finished_reading()
665
raise errors.UnknownSmartMethod(args[0])
612
667
def cancel_read_body(self):
613
668
"""After expecting a body, a response code may indicate one otherwise.
674
729
def _response_is_unknown_method(self, result_tuple):
675
730
"""Raise UnexpectedSmartServerResponse if the response is an 'unknonwn
676
731
method' response to the request.
678
733
:param response: The response from a smart client call_expecting_body
680
735
:param verb: The verb used in that call.
687
742
# The response will have no body, so we've finished reading.
688
743
self._request.finished_reading()
689
744
raise errors.UnknownSmartMethod(self._last_verb)
691
746
def read_body_bytes(self, count=-1):
692
747
"""Read bytes from the body, decoding into a byte stream.
694
We read all bytes at once to ensure we've checked the trailer for
749
We read all bytes at once to ensure we've checked the trailer for
695
750
errors, and then feed the buffer back as read_body_bytes is called.
697
752
if self._body_buffer is not None:
698
753
return self._body_buffer.read(count)
699
754
_body_decoder = LengthPrefixedBodyDecoder()
701
# Read no more than 64k at a time so that we don't risk error 10055 (no
702
# buffer space available) on Windows.
704
756
while not _body_decoder.finished_reading:
705
bytes_wanted = min(_body_decoder.next_read_size(), max_read)
706
bytes = self._request.read_bytes(bytes_wanted)
757
bytes = self._request.read_bytes(_body_decoder.next_read_size())
759
# end of file encountered reading from server
760
raise errors.ConnectionReset(
761
"Connection lost while reading response body.")
707
762
_body_decoder.accept_bytes(bytes)
708
763
self._request.finished_reading()
709
764
self._body_buffer = StringIO(_body_decoder.read_pending_data())
716
771
def _recv_tuple(self):
717
772
"""Receive a tuple from the medium request."""
718
return _decode_tuple(self._recv_line())
720
def _recv_line(self):
721
"""Read an entire line from the medium request."""
723
while not line or line[-1] != '\n':
724
# TODO: this is inefficient - but tuples are short.
725
new_char = self._request.read_bytes(1)
727
# end of file encountered reading from server
728
raise errors.ConnectionReset(
729
"please check connectivity and permissions",
730
"(and try -Dhpss if further diagnosis is required)")
773
return _decode_tuple(self._request.read_line())
734
775
def query_version(self):
735
776
"""Return protocol version number of the server."""
750
791
def _write_protocol_version(self):
751
792
"""Write any prefixes this protocol requires.
753
794
Version one doesn't send protocol versions.
757
798
class SmartClientRequestProtocolTwo(SmartClientRequestProtocolOne):
758
799
"""Version two of the client side of the smart protocol.
760
801
This prefixes the request with the value of REQUEST_VERSION_TWO.
772
813
if version != self.response_marker:
773
814
self._request.finished_reading()
774
815
raise errors.UnexpectedProtocolVersionMarker(version)
775
response_status = self._recv_line()
816
response_status = self._request.read_line()
776
817
result = SmartClientRequestProtocolOne._read_response_tuple(self)
777
818
self._response_is_unknown_method(result)
778
819
if response_status == 'success\n':
801
842
# Read no more than 64k at a time so that we don't risk error 10055 (no
802
843
# buffer space available) on Windows.
804
844
_body_decoder = ChunkedBodyDecoder()
805
845
while not _body_decoder.finished_reading:
806
bytes_wanted = min(_body_decoder.next_read_size(), max_read)
807
bytes = self._request.read_bytes(bytes_wanted)
846
bytes = self._request.read_bytes(_body_decoder.next_read_size())
848
# end of file encountered reading from server
849
raise errors.ConnectionReset(
850
"Connection lost while reading streamed body.")
808
851
_body_decoder.accept_bytes(bytes)
809
852
for body_bytes in iter(_body_decoder.read_next_chunk, None):
810
853
if 'hpss' in debug.debug_flags and type(body_bytes) is str:
877
920
self.message_handler.protocol_error(exception)
879
922
def _extract_length_prefixed_bytes(self):
880
if len(self._in_buffer) < 4:
923
if self._in_buffer_len < 4:
881
924
# A length prefix by itself is 4 bytes, and we don't even have that
883
926
raise _NeedMoreBytes(4)
884
(length,) = struct.unpack('!L', self._in_buffer[:4])
927
(length,) = struct.unpack('!L', self._get_in_bytes(4))
885
928
end_of_bytes = 4 + length
886
if len(self._in_buffer) < end_of_bytes:
929
if self._in_buffer_len < end_of_bytes:
887
930
# We haven't yet read as many bytes as the length-prefix says there
889
932
raise _NeedMoreBytes(end_of_bytes)
890
933
# Extract the bytes from the buffer.
891
bytes = self._in_buffer[4:end_of_bytes]
892
self._in_buffer = self._in_buffer[end_of_bytes:]
934
in_buf = self._get_in_buffer()
935
bytes = in_buf[4:end_of_bytes]
936
self._set_in_buffer(in_buf[end_of_bytes:])
895
939
def _extract_prefixed_bencoded_data(self):
896
940
prefixed_bytes = self._extract_length_prefixed_bytes()
898
decoded = bdecode(prefixed_bytes)
942
decoded = bdecode_as_tuple(prefixed_bytes)
899
943
except ValueError:
900
944
raise errors.SmartProtocolError(
901
945
'Bytes %r not bencoded' % (prefixed_bytes,))
904
948
def _extract_single_byte(self):
905
if self._in_buffer == '':
949
if self._in_buffer_len == 0:
906
950
# The buffer is empty
907
951
raise _NeedMoreBytes(1)
908
one_byte = self._in_buffer[0]
909
self._in_buffer = self._in_buffer[1:]
952
in_buf = self._get_in_buffer()
954
self._set_in_buffer(in_buf[1:])
912
957
def _state_accept_expecting_protocol_version(self):
913
needed_bytes = len(MESSAGE_VERSION_THREE) - len(self._in_buffer)
958
needed_bytes = len(MESSAGE_VERSION_THREE) - self._in_buffer_len
959
in_buf = self._get_in_buffer()
914
960
if needed_bytes > 0:
915
961
# We don't have enough bytes to check if the protocol version
916
962
# marker is right. But we can check if it is already wrong by
920
966
# len(MESSAGE_VERSION_THREE) bytes. So if the bytes we have so far
921
967
# are wrong then we should just raise immediately rather than
923
if not MESSAGE_VERSION_THREE.startswith(self._in_buffer):
969
if not MESSAGE_VERSION_THREE.startswith(in_buf):
924
970
# We have enough bytes to know the protocol version is wrong
925
raise errors.UnexpectedProtocolVersionMarker(self._in_buffer)
971
raise errors.UnexpectedProtocolVersionMarker(in_buf)
926
972
raise _NeedMoreBytes(len(MESSAGE_VERSION_THREE))
927
if not self._in_buffer.startswith(MESSAGE_VERSION_THREE):
928
raise errors.UnexpectedProtocolVersionMarker(self._in_buffer)
929
self._in_buffer = self._in_buffer[len(MESSAGE_VERSION_THREE):]
973
if not in_buf.startswith(MESSAGE_VERSION_THREE):
974
raise errors.UnexpectedProtocolVersionMarker(in_buf)
975
self._set_in_buffer(in_buf[len(MESSAGE_VERSION_THREE):])
930
976
self.state_accept = self._state_accept_expecting_headers
932
978
def _state_accept_expecting_headers(self):
939
985
self.message_handler.headers_received(decoded)
941
987
raise errors.SmartMessageHandlerError(sys.exc_info())
943
989
def _state_accept_expecting_message_part(self):
944
990
message_part_kind = self._extract_single_byte()
945
991
if message_part_kind == 'o':
981
1027
raise errors.SmartMessageHandlerError(sys.exc_info())
984
self.unused_data = self._in_buffer
1030
self.unused_data = self._get_in_buffer()
1031
self._set_in_buffer(None)
986
1032
self.state_accept = self._state_accept_reading_unused
988
1034
self.message_handler.end_received()
990
1036
raise errors.SmartMessageHandlerError(sys.exc_info())
992
1038
def _state_accept_reading_unused(self):
993
self.unused_data += self._in_buffer
1039
self.unused_data = self._get_in_buffer()
1040
self._set_in_buffer(None)
996
1042
def next_read_size(self):
997
1043
if self.state_accept == self._state_accept_reading_unused:
1014
1060
response_marker = request_marker = MESSAGE_VERSION_THREE
1016
1062
def __init__(self, write_func):
1018
1064
self._real_write_func = write_func
1020
1066
def _write_func(self, bytes):
1067
self._buf.append(bytes)
1068
if len(self._buf) > 100:
1023
1071
def flush(self):
1025
self._real_write_func(self._buf)
1073
self._real_write_func(''.join(self._buf))
1028
1076
def _serialise_offsets(self, offsets):
1029
1077
"""Serialise a readv offset list."""
1031
1079
for start, length in offsets:
1032
1080
txt.append('%d,%d' % (start, length))
1033
1081
return '\n'.join(txt)
1035
1083
def _write_protocol_version(self):
1036
1084
self._write_func(MESSAGE_VERSION_THREE)
1109
1160
if response.body is not None:
1110
1161
self._write_prefixed_body(response.body)
1111
1162
elif response.body_stream is not None:
1112
for chunk in response.body_stream:
1113
self._write_prefixed_body(chunk)
1163
for exc_info, chunk in _iter_with_errors(response.body_stream):
1164
if exc_info is not None:
1165
self._write_error_status()
1166
error_struct = request._translate_error(exc_info[1])
1167
self._write_structure(error_struct)
1170
if isinstance(chunk, request.FailedSmartServerResponse):
1171
self._write_error_status()
1172
self._write_structure(chunk.args)
1174
self._write_prefixed_body(chunk)
1115
1175
self._write_end()
1178
def _iter_with_errors(iterable):
1179
"""Handle errors from iterable.next().
1183
for exc_info, value in _iter_with_errors(iterable):
1186
This is a safer alternative to::
1189
for value in iterable:
1194
Because the latter will catch errors from the for-loop body, not just
1197
If an error occurs, exc_info will be a exc_info tuple, and the generator
1198
will terminate. Otherwise exc_info will be None, and value will be the
1199
value from iterable.next(). Note that KeyboardInterrupt and SystemExit
1200
will not be itercepted.
1202
iterator = iter(iterable)
1205
yield None, iterator.next()
1206
except StopIteration:
1208
except (KeyboardInterrupt, SystemExit):
1211
yield sys.exc_info(), None
1118
1215
class ProtocolThreeRequester(_ProtocolThreeEncoder, Requester):
1125
1222
def set_headers(self, headers):
1126
1223
self._headers = headers.copy()
1128
1225
def call(self, *args):
1129
1226
if 'hpss' in debug.debug_flags:
1130
1227
mutter('hpss call: %s', repr(args)[1:-1])
1179
1276
self._write_end()
1180
1277
self._medium_request.finished_writing()
1279
def call_with_body_stream(self, args, stream):
1280
if 'hpss' in debug.debug_flags:
1281
mutter('hpss call w/body stream: %r', args)
1282
path = getattr(self._medium_request._medium, '_path', None)
1283
if path is not None:
1284
mutter(' (to %s)', path)
1285
self._request_start_time = time.time()
1286
self._write_protocol_version()
1287
self._write_headers(self._headers)
1288
self._write_structure(args)
1289
# TODO: notice if the server has sent an early error reply before we
1290
# have finished sending the stream. We would notice at the end
1291
# anyway, but if the medium can deliver it early then it's good
1292
# to short-circuit the whole request...
1293
for exc_info, part in _iter_with_errors(stream):
1294
if exc_info is not None:
1295
# Iterating the stream failed. Cleanly abort the request.
1296
self._write_error_status()
1297
# Currently the client unconditionally sends ('error',) as the
1299
self._write_structure(('error',))
1301
self._medium_request.finished_writing()
1302
raise exc_info[0], exc_info[1], exc_info[2]
1304
self._write_prefixed_body(part)
1307
self._medium_request.finished_writing()