806
627
def _write_protocol_version(self):
807
628
"""Write any prefixes this protocol requires.
809
630
Version one doesn't send protocol versions.
813
634
class SmartClientRequestProtocolTwo(SmartClientRequestProtocolOne):
814
635
"""Version two of the client side of the smart protocol.
816
637
This prefixes the request with the value of REQUEST_VERSION_TWO.
819
response_marker = RESPONSE_VERSION_TWO
820
request_marker = REQUEST_VERSION_TWO
822
640
def read_response_tuple(self, expect_body=False):
823
641
"""Read a response tuple from the wire.
825
643
This should only be called once.
827
645
version = self._request.read_line()
828
if version != self.response_marker:
829
self._request.finished_reading()
830
raise errors.UnexpectedProtocolVersionMarker(version)
831
response_status = self._request.read_line()
832
result = SmartClientRequestProtocolOne._read_response_tuple(self)
833
self._response_is_unknown_method(result)
834
if response_status == 'success\n':
835
self.response_status = True
837
self._request.finished_reading()
839
elif response_status == 'failed\n':
840
self.response_status = False
841
self._request.finished_reading()
842
raise errors.ErrorFromSmartServer(result)
646
if version != RESPONSE_VERSION_TWO:
647
raise errors.SmartProtocolError('bad protocol marker %r' % version)
648
response_status = self._recv_line()
649
if response_status not in ('success\n', 'failed\n'):
844
650
raise errors.SmartProtocolError(
845
651
'bad protocol status %r' % response_status)
652
self.response_status = response_status == 'success\n'
653
return SmartClientRequestProtocolOne.read_response_tuple(self, expect_body)
847
655
def _write_protocol_version(self):
848
656
"""Write any prefixes this protocol requires.
850
658
Version two sends the value of REQUEST_VERSION_TWO.
852
self._request.accept_bytes(self.request_marker)
660
self._request.accept_bytes(REQUEST_VERSION_TWO)
854
662
def read_streamed_body(self):
855
663
"""Read bytes from the body, decoding into a byte stream.
857
665
# Read no more than 64k at a time so that we don't risk error 10055 (no
858
666
# buffer space available) on Windows.
859
668
_body_decoder = ChunkedBodyDecoder()
860
669
while not _body_decoder.finished_reading:
861
bytes = self._request.read_bytes(_body_decoder.next_read_size())
863
# end of file encountered reading from server
864
raise errors.ConnectionReset(
865
"Connection lost while reading streamed body.")
670
bytes_wanted = min(_body_decoder.next_read_size(), max_read)
671
bytes = self._request.read_bytes(bytes_wanted)
866
672
_body_decoder.accept_bytes(bytes)
867
673
for body_bytes in iter(_body_decoder.read_next_chunk, None):
868
if 'hpss' in debug.debug_flags and type(body_bytes) is str:
674
if 'hpss' in debug.debug_flags:
869
675
mutter(' %d byte chunk read',
872
678
self._request.finished_reading()
875
def build_server_protocol_three(backing_transport, write_func,
876
root_client_path, jail_root=None):
877
request_handler = request.SmartServerRequestHandler(
878
backing_transport, commands=request.request_handlers,
879
root_client_path=root_client_path, jail_root=jail_root)
880
responder = ProtocolThreeResponder(write_func)
881
message_handler = message.ConventionalRequestHandler(request_handler, responder)
882
return ProtocolThreeDecoder(message_handler)
885
class ProtocolThreeDecoder(_StatefulDecoder):
887
response_marker = RESPONSE_VERSION_THREE
888
request_marker = REQUEST_VERSION_THREE
890
def __init__(self, message_handler, expect_version_marker=False):
891
_StatefulDecoder.__init__(self)
892
self._has_dispatched = False
894
if expect_version_marker:
895
self.state_accept = self._state_accept_expecting_protocol_version
896
# We're expecting at least the protocol version marker + some
898
self._number_needed_bytes = len(MESSAGE_VERSION_THREE) + 4
900
self.state_accept = self._state_accept_expecting_headers
901
self._number_needed_bytes = 4
902
self.decoding_failed = False
903
self.request_handler = self.message_handler = message_handler
905
def accept_bytes(self, bytes):
906
self._number_needed_bytes = None
908
_StatefulDecoder.accept_bytes(self, bytes)
909
except KeyboardInterrupt:
911
except errors.SmartMessageHandlerError, exception:
912
# We do *not* set self.decoding_failed here. The message handler
913
# has raised an error, but the decoder is still able to parse bytes
914
# and determine when this message ends.
915
if not isinstance(exception.exc_value, errors.UnknownSmartMethod):
916
log_exception_quietly()
917
self.message_handler.protocol_error(exception.exc_value)
918
# The state machine is ready to continue decoding, but the
919
# exception has interrupted the loop that runs the state machine.
920
# So we call accept_bytes again to restart it.
921
self.accept_bytes('')
922
except Exception, exception:
923
# The decoder itself has raised an exception. We cannot continue
925
self.decoding_failed = True
926
if isinstance(exception, errors.UnexpectedProtocolVersionMarker):
927
# This happens during normal operation when the client tries a
928
# protocol version the server doesn't understand, so no need to
929
# log a traceback every time.
930
# Note that this can only happen when
931
# expect_version_marker=True, which is only the case on the
935
log_exception_quietly()
936
self.message_handler.protocol_error(exception)
938
def _extract_length_prefixed_bytes(self):
939
if self._in_buffer_len < 4:
940
# A length prefix by itself is 4 bytes, and we don't even have that
942
raise _NeedMoreBytes(4)
943
(length,) = struct.unpack('!L', self._get_in_bytes(4))
944
end_of_bytes = 4 + length
945
if self._in_buffer_len < end_of_bytes:
946
# We haven't yet read as many bytes as the length-prefix says there
948
raise _NeedMoreBytes(end_of_bytes)
949
# Extract the bytes from the buffer.
950
in_buf = self._get_in_buffer()
951
bytes = in_buf[4:end_of_bytes]
952
self._set_in_buffer(in_buf[end_of_bytes:])
955
def _extract_prefixed_bencoded_data(self):
956
prefixed_bytes = self._extract_length_prefixed_bytes()
958
decoded = bdecode_as_tuple(prefixed_bytes)
960
raise errors.SmartProtocolError(
961
'Bytes %r not bencoded' % (prefixed_bytes,))
964
def _extract_single_byte(self):
965
if self._in_buffer_len == 0:
966
# The buffer is empty
967
raise _NeedMoreBytes(1)
968
in_buf = self._get_in_buffer()
970
self._set_in_buffer(in_buf[1:])
973
def _state_accept_expecting_protocol_version(self):
974
needed_bytes = len(MESSAGE_VERSION_THREE) - self._in_buffer_len
975
in_buf = self._get_in_buffer()
977
# We don't have enough bytes to check if the protocol version
978
# marker is right. But we can check if it is already wrong by
979
# checking that the start of MESSAGE_VERSION_THREE matches what
981
# [In fact, if the remote end isn't bzr we might never receive
982
# len(MESSAGE_VERSION_THREE) bytes. So if the bytes we have so far
983
# are wrong then we should just raise immediately rather than
985
if not MESSAGE_VERSION_THREE.startswith(in_buf):
986
# We have enough bytes to know the protocol version is wrong
987
raise errors.UnexpectedProtocolVersionMarker(in_buf)
988
raise _NeedMoreBytes(len(MESSAGE_VERSION_THREE))
989
if not in_buf.startswith(MESSAGE_VERSION_THREE):
990
raise errors.UnexpectedProtocolVersionMarker(in_buf)
991
self._set_in_buffer(in_buf[len(MESSAGE_VERSION_THREE):])
992
self.state_accept = self._state_accept_expecting_headers
994
def _state_accept_expecting_headers(self):
995
decoded = self._extract_prefixed_bencoded_data()
996
if type(decoded) is not dict:
997
raise errors.SmartProtocolError(
998
'Header object %r is not a dict' % (decoded,))
999
self.state_accept = self._state_accept_expecting_message_part
1001
self.message_handler.headers_received(decoded)
1003
raise errors.SmartMessageHandlerError(sys.exc_info())
1005
def _state_accept_expecting_message_part(self):
1006
message_part_kind = self._extract_single_byte()
1007
if message_part_kind == 'o':
1008
self.state_accept = self._state_accept_expecting_one_byte
1009
elif message_part_kind == 's':
1010
self.state_accept = self._state_accept_expecting_structure
1011
elif message_part_kind == 'b':
1012
self.state_accept = self._state_accept_expecting_bytes
1013
elif message_part_kind == 'e':
1016
raise errors.SmartProtocolError(
1017
'Bad message kind byte: %r' % (message_part_kind,))
1019
def _state_accept_expecting_one_byte(self):
1020
byte = self._extract_single_byte()
1021
self.state_accept = self._state_accept_expecting_message_part
1023
self.message_handler.byte_part_received(byte)
1025
raise errors.SmartMessageHandlerError(sys.exc_info())
1027
def _state_accept_expecting_bytes(self):
1028
# XXX: this should not buffer whole message part, but instead deliver
1029
# the bytes as they arrive.
1030
prefixed_bytes = self._extract_length_prefixed_bytes()
1031
self.state_accept = self._state_accept_expecting_message_part
1033
self.message_handler.bytes_part_received(prefixed_bytes)
1035
raise errors.SmartMessageHandlerError(sys.exc_info())
1037
def _state_accept_expecting_structure(self):
1038
structure = self._extract_prefixed_bencoded_data()
1039
self.state_accept = self._state_accept_expecting_message_part
1041
self.message_handler.structure_part_received(structure)
1043
raise errors.SmartMessageHandlerError(sys.exc_info())
1046
self.unused_data = self._get_in_buffer()
1047
self._set_in_buffer(None)
1048
self.state_accept = self._state_accept_reading_unused
1050
self.message_handler.end_received()
1052
raise errors.SmartMessageHandlerError(sys.exc_info())
1054
def _state_accept_reading_unused(self):
1055
self.unused_data += self._get_in_buffer()
1056
self._set_in_buffer(None)
1058
def next_read_size(self):
1059
if self.state_accept == self._state_accept_reading_unused:
1061
elif self.decoding_failed:
1062
# An exception occured while processing this message, probably from
1063
# self.message_handler. We're not sure that this state machine is
1064
# in a consistent state, so just signal that we're done (i.e. give
1068
if self._number_needed_bytes is not None:
1069
return self._number_needed_bytes - self._in_buffer_len
1071
raise AssertionError("don't know how many bytes are expected!")
1074
class _ProtocolThreeEncoder(object):
1076
response_marker = request_marker = MESSAGE_VERSION_THREE
1077
BUFFER_SIZE = 1024*1024 # 1 MiB buffer before flushing
1079
def __init__(self, write_func):
1082
self._real_write_func = write_func
1084
def _write_func(self, bytes):
1085
# TODO: Another possibility would be to turn this into an async model.
1086
# Where we let another thread know that we have some bytes if
1087
# they want it, but we don't actually block for it
1088
# Note that osutils.send_all always sends 64kB chunks anyway, so
1089
# we might just push out smaller bits at a time?
1090
self._buf.append(bytes)
1091
self._buf_len += len(bytes)
1092
if self._buf_len > self.BUFFER_SIZE:
1097
self._real_write_func(''.join(self._buf))
1101
def _serialise_offsets(self, offsets):
1102
"""Serialise a readv offset list."""
1104
for start, length in offsets:
1105
txt.append('%d,%d' % (start, length))
1106
return '\n'.join(txt)
1108
def _write_protocol_version(self):
1109
self._write_func(MESSAGE_VERSION_THREE)
1111
def _write_prefixed_bencode(self, structure):
1112
bytes = bencode(structure)
1113
self._write_func(struct.pack('!L', len(bytes)))
1114
self._write_func(bytes)
1116
def _write_headers(self, headers):
1117
self._write_prefixed_bencode(headers)
1119
def _write_structure(self, args):
1120
self._write_func('s')
1123
if type(arg) is unicode:
1124
utf8_args.append(arg.encode('utf8'))
1126
utf8_args.append(arg)
1127
self._write_prefixed_bencode(utf8_args)
1129
def _write_end(self):
1130
self._write_func('e')
1133
def _write_prefixed_body(self, bytes):
1134
self._write_func('b')
1135
self._write_func(struct.pack('!L', len(bytes)))
1136
self._write_func(bytes)
1138
def _write_chunked_body_start(self):
1139
self._write_func('oC')
1141
def _write_error_status(self):
1142
self._write_func('oE')
1144
def _write_success_status(self):
1145
self._write_func('oS')
1148
class ProtocolThreeResponder(_ProtocolThreeEncoder):
1150
def __init__(self, write_func):
1151
_ProtocolThreeEncoder.__init__(self, write_func)
1152
self.response_sent = False
1153
self._headers = {'Software version': bzrlib.__version__}
1154
if 'hpss' in debug.debug_flags:
1155
self._thread_id = thread.get_ident()
1156
self._response_start_time = None
1158
def _trace(self, action, message, extra_bytes=None, include_time=False):
1159
if self._response_start_time is None:
1160
self._response_start_time = osutils.timer_func()
1162
t = '%5.3fs ' % (time.clock() - self._response_start_time)
1165
if extra_bytes is None:
1168
extra = ' ' + repr(extra_bytes[:40])
1170
extra = extra[:29] + extra[-1] + '...'
1171
mutter('%12s: [%s] %s%s%s'
1172
% (action, self._thread_id, t, message, extra))
1174
def send_error(self, exception):
1175
if self.response_sent:
1176
raise AssertionError(
1177
"send_error(%s) called, but response already sent."
1179
if isinstance(exception, errors.UnknownSmartMethod):
1180
failure = request.FailedSmartServerResponse(
1181
('UnknownMethod', exception.verb))
1182
self.send_response(failure)
1184
if 'hpss' in debug.debug_flags:
1185
self._trace('error', str(exception))
1186
self.response_sent = True
1187
self._write_protocol_version()
1188
self._write_headers(self._headers)
1189
self._write_error_status()
1190
self._write_structure(('error', str(exception)))
1193
def send_response(self, response):
1194
if self.response_sent:
1195
raise AssertionError(
1196
"send_response(%r) called, but response already sent."
1198
self.response_sent = True
1199
self._write_protocol_version()
1200
self._write_headers(self._headers)
1201
if response.is_successful():
1202
self._write_success_status()
1204
self._write_error_status()
1205
if 'hpss' in debug.debug_flags:
1206
self._trace('response', repr(response.args))
1207
self._write_structure(response.args)
1208
if response.body is not None:
1209
self._write_prefixed_body(response.body)
1210
if 'hpss' in debug.debug_flags:
1211
self._trace('body', '%d bytes' % (len(response.body),),
1212
response.body, include_time=True)
1213
elif response.body_stream is not None:
1214
count = num_bytes = 0
1216
for exc_info, chunk in _iter_with_errors(response.body_stream):
1218
if exc_info is not None:
1219
self._write_error_status()
1220
error_struct = request._translate_error(exc_info[1])
1221
self._write_structure(error_struct)
1224
if isinstance(chunk, request.FailedSmartServerResponse):
1225
self._write_error_status()
1226
self._write_structure(chunk.args)
1228
num_bytes += len(chunk)
1229
if first_chunk is None:
1231
self._write_prefixed_body(chunk)
1233
if 'hpssdetail' in debug.debug_flags:
1234
# Not worth timing separately, as _write_func is
1236
self._trace('body chunk',
1237
'%d bytes' % (len(chunk),),
1238
chunk, suppress_time=True)
1239
if 'hpss' in debug.debug_flags:
1240
self._trace('body stream',
1241
'%d bytes %d chunks' % (num_bytes, count),
1244
if 'hpss' in debug.debug_flags:
1245
self._trace('response end', '', include_time=True)
1248
def _iter_with_errors(iterable):
1249
"""Handle errors from iterable.next().
1253
for exc_info, value in _iter_with_errors(iterable):
1256
This is a safer alternative to::
1259
for value in iterable:
1264
Because the latter will catch errors from the for-loop body, not just
1267
If an error occurs, exc_info will be a exc_info tuple, and the generator
1268
will terminate. Otherwise exc_info will be None, and value will be the
1269
value from iterable.next(). Note that KeyboardInterrupt and SystemExit
1270
will not be itercepted.
1272
iterator = iter(iterable)
1275
yield None, iterator.next()
1276
except StopIteration:
1278
except (KeyboardInterrupt, SystemExit):
1281
mutter('_iter_with_errors caught error')
1282
log_exception_quietly()
1283
yield sys.exc_info(), None
1287
class ProtocolThreeRequester(_ProtocolThreeEncoder, Requester):
1289
def __init__(self, medium_request):
1290
_ProtocolThreeEncoder.__init__(self, medium_request.accept_bytes)
1291
self._medium_request = medium_request
1293
self.body_stream_started = None
1295
def set_headers(self, headers):
1296
self._headers = headers.copy()
1298
def call(self, *args):
1299
if 'hpss' in debug.debug_flags:
1300
mutter('hpss call: %s', repr(args)[1:-1])
1301
base = getattr(self._medium_request._medium, 'base', None)
1302
if base is not None:
1303
mutter(' (to %s)', base)
1304
self._request_start_time = osutils.timer_func()
1305
self._write_protocol_version()
1306
self._write_headers(self._headers)
1307
self._write_structure(args)
1309
self._medium_request.finished_writing()
1311
def call_with_body_bytes(self, args, body):
1312
"""Make a remote call of args with body bytes 'body'.
1314
After calling this, call read_response_tuple to find the result out.
1316
if 'hpss' in debug.debug_flags:
1317
mutter('hpss call w/body: %s (%r...)', repr(args)[1:-1], body[:20])
1318
path = getattr(self._medium_request._medium, '_path', None)
1319
if path is not None:
1320
mutter(' (to %s)', path)
1321
mutter(' %d bytes', len(body))
1322
self._request_start_time = osutils.timer_func()
1323
self._write_protocol_version()
1324
self._write_headers(self._headers)
1325
self._write_structure(args)
1326
self._write_prefixed_body(body)
1328
self._medium_request.finished_writing()
1330
def call_with_body_readv_array(self, args, body):
1331
"""Make a remote call with a readv array.
1333
The body is encoded with one line per readv offset pair. The numbers in
1334
each pair are separated by a comma, and no trailing \\n is emitted.
1336
if 'hpss' in debug.debug_flags:
1337
mutter('hpss call w/readv: %s', repr(args)[1:-1])
1338
path = getattr(self._medium_request._medium, '_path', None)
1339
if path is not None:
1340
mutter(' (to %s)', path)
1341
self._request_start_time = osutils.timer_func()
1342
self._write_protocol_version()
1343
self._write_headers(self._headers)
1344
self._write_structure(args)
1345
readv_bytes = self._serialise_offsets(body)
1346
if 'hpss' in debug.debug_flags:
1347
mutter(' %d bytes in readv request', len(readv_bytes))
1348
self._write_prefixed_body(readv_bytes)
1350
self._medium_request.finished_writing()
1352
def call_with_body_stream(self, args, stream):
1353
if 'hpss' in debug.debug_flags:
1354
mutter('hpss call w/body stream: %r', args)
1355
path = getattr(self._medium_request._medium, '_path', None)
1356
if path is not None:
1357
mutter(' (to %s)', path)
1358
self._request_start_time = osutils.timer_func()
1359
self.body_stream_started = False
1360
self._write_protocol_version()
1361
self._write_headers(self._headers)
1362
self._write_structure(args)
1363
# TODO: notice if the server has sent an early error reply before we
1364
# have finished sending the stream. We would notice at the end
1365
# anyway, but if the medium can deliver it early then it's good
1366
# to short-circuit the whole request...
1367
# Provoke any ConnectionReset failures before we start the body stream.
1369
self.body_stream_started = True
1370
for exc_info, part in _iter_with_errors(stream):
1371
if exc_info is not None:
1372
# Iterating the stream failed. Cleanly abort the request.
1373
self._write_error_status()
1374
# Currently the client unconditionally sends ('error',) as the
1376
self._write_structure(('error',))
1378
self._medium_request.finished_writing()
1379
raise exc_info[0], exc_info[1], exc_info[2]
1381
self._write_prefixed_body(part)
1384
self._medium_request.finished_writing()