805
603
def _write_protocol_version(self):
806
604
"""Write any prefixes this protocol requires.
808
606
Version one doesn't send protocol versions.
812
610
class SmartClientRequestProtocolTwo(SmartClientRequestProtocolOne):
813
611
"""Version two of the client side of the smart protocol.
815
613
This prefixes the request with the value of REQUEST_VERSION_TWO.
818
response_marker = RESPONSE_VERSION_TWO
819
request_marker = REQUEST_VERSION_TWO
821
616
def read_response_tuple(self, expect_body=False):
822
617
"""Read a response tuple from the wire.
824
619
This should only be called once.
826
621
version = self._request.read_line()
827
if version != self.response_marker:
828
self._request.finished_reading()
829
raise errors.UnexpectedProtocolVersionMarker(version)
830
response_status = self._request.read_line()
831
result = SmartClientRequestProtocolOne._read_response_tuple(self)
832
self._response_is_unknown_method(result)
833
if response_status == 'success\n':
834
self.response_status = True
836
self._request.finished_reading()
838
elif response_status == 'failed\n':
839
self.response_status = False
840
self._request.finished_reading()
841
raise errors.ErrorFromSmartServer(result)
622
if version != RESPONSE_VERSION_TWO:
623
raise errors.SmartProtocolError('bad protocol marker %r' % version)
624
response_status = self._recv_line()
625
if response_status not in ('success\n', 'failed\n'):
843
626
raise errors.SmartProtocolError(
844
627
'bad protocol status %r' % response_status)
628
self.response_status = response_status == 'success\n'
629
return SmartClientRequestProtocolOne.read_response_tuple(self, expect_body)
846
631
def _write_protocol_version(self):
847
632
"""Write any prefixes this protocol requires.
849
634
Version two sends the value of REQUEST_VERSION_TWO.
851
self._request.accept_bytes(self.request_marker)
636
self._request.accept_bytes(REQUEST_VERSION_TWO)
853
638
def read_streamed_body(self):
854
639
"""Read bytes from the body, decoding into a byte stream.
856
641
# Read no more than 64k at a time so that we don't risk error 10055 (no
857
642
# buffer space available) on Windows.
858
644
_body_decoder = ChunkedBodyDecoder()
859
645
while not _body_decoder.finished_reading:
860
bytes = self._request.read_bytes(_body_decoder.next_read_size())
862
# end of file encountered reading from server
863
raise errors.ConnectionReset(
864
"Connection lost while reading streamed body.")
646
bytes_wanted = min(_body_decoder.next_read_size(), max_read)
647
bytes = self._request.read_bytes(bytes_wanted)
865
648
_body_decoder.accept_bytes(bytes)
866
649
for body_bytes in iter(_body_decoder.read_next_chunk, None):
867
if 'hpss' in debug.debug_flags and type(body_bytes) is str:
650
if 'hpss' in debug.debug_flags:
868
651
mutter(' %d byte chunk read',
871
654
self._request.finished_reading()
874
def build_server_protocol_three(backing_transport, write_func,
875
root_client_path, jail_root=None):
876
request_handler = request.SmartServerRequestHandler(
877
backing_transport, commands=request.request_handlers,
878
root_client_path=root_client_path, jail_root=jail_root)
879
responder = ProtocolThreeResponder(write_func)
880
message_handler = message.ConventionalRequestHandler(request_handler, responder)
881
return ProtocolThreeDecoder(message_handler)
884
class ProtocolThreeDecoder(_StatefulDecoder):
886
response_marker = RESPONSE_VERSION_THREE
887
request_marker = REQUEST_VERSION_THREE
889
def __init__(self, message_handler, expect_version_marker=False):
890
_StatefulDecoder.__init__(self)
891
self._has_dispatched = False
893
if expect_version_marker:
894
self.state_accept = self._state_accept_expecting_protocol_version
895
# We're expecting at least the protocol version marker + some
897
self._number_needed_bytes = len(MESSAGE_VERSION_THREE) + 4
899
self.state_accept = self._state_accept_expecting_headers
900
self._number_needed_bytes = 4
901
self.decoding_failed = False
902
self.request_handler = self.message_handler = message_handler
904
def accept_bytes(self, bytes):
905
self._number_needed_bytes = None
907
_StatefulDecoder.accept_bytes(self, bytes)
908
except KeyboardInterrupt:
910
except errors.SmartMessageHandlerError, exception:
911
# We do *not* set self.decoding_failed here. The message handler
912
# has raised an error, but the decoder is still able to parse bytes
913
# and determine when this message ends.
914
if not isinstance(exception.exc_value, errors.UnknownSmartMethod):
915
log_exception_quietly()
916
self.message_handler.protocol_error(exception.exc_value)
917
# The state machine is ready to continue decoding, but the
918
# exception has interrupted the loop that runs the state machine.
919
# So we call accept_bytes again to restart it.
920
self.accept_bytes('')
921
except Exception, exception:
922
# The decoder itself has raised an exception. We cannot continue
924
self.decoding_failed = True
925
if isinstance(exception, errors.UnexpectedProtocolVersionMarker):
926
# This happens during normal operation when the client tries a
927
# protocol version the server doesn't understand, so no need to
928
# log a traceback every time.
929
# Note that this can only happen when
930
# expect_version_marker=True, which is only the case on the
934
log_exception_quietly()
935
self.message_handler.protocol_error(exception)
937
def _extract_length_prefixed_bytes(self):
938
if self._in_buffer_len < 4:
939
# A length prefix by itself is 4 bytes, and we don't even have that
941
raise _NeedMoreBytes(4)
942
(length,) = struct.unpack('!L', self._get_in_bytes(4))
943
end_of_bytes = 4 + length
944
if self._in_buffer_len < end_of_bytes:
945
# We haven't yet read as many bytes as the length-prefix says there
947
raise _NeedMoreBytes(end_of_bytes)
948
# Extract the bytes from the buffer.
949
in_buf = self._get_in_buffer()
950
bytes = in_buf[4:end_of_bytes]
951
self._set_in_buffer(in_buf[end_of_bytes:])
954
def _extract_prefixed_bencoded_data(self):
955
prefixed_bytes = self._extract_length_prefixed_bytes()
957
decoded = bdecode_as_tuple(prefixed_bytes)
959
raise errors.SmartProtocolError(
960
'Bytes %r not bencoded' % (prefixed_bytes,))
963
def _extract_single_byte(self):
964
if self._in_buffer_len == 0:
965
# The buffer is empty
966
raise _NeedMoreBytes(1)
967
in_buf = self._get_in_buffer()
969
self._set_in_buffer(in_buf[1:])
972
def _state_accept_expecting_protocol_version(self):
973
needed_bytes = len(MESSAGE_VERSION_THREE) - self._in_buffer_len
974
in_buf = self._get_in_buffer()
976
# We don't have enough bytes to check if the protocol version
977
# marker is right. But we can check if it is already wrong by
978
# checking that the start of MESSAGE_VERSION_THREE matches what
980
# [In fact, if the remote end isn't bzr we might never receive
981
# len(MESSAGE_VERSION_THREE) bytes. So if the bytes we have so far
982
# are wrong then we should just raise immediately rather than
984
if not MESSAGE_VERSION_THREE.startswith(in_buf):
985
# We have enough bytes to know the protocol version is wrong
986
raise errors.UnexpectedProtocolVersionMarker(in_buf)
987
raise _NeedMoreBytes(len(MESSAGE_VERSION_THREE))
988
if not in_buf.startswith(MESSAGE_VERSION_THREE):
989
raise errors.UnexpectedProtocolVersionMarker(in_buf)
990
self._set_in_buffer(in_buf[len(MESSAGE_VERSION_THREE):])
991
self.state_accept = self._state_accept_expecting_headers
993
def _state_accept_expecting_headers(self):
994
decoded = self._extract_prefixed_bencoded_data()
995
if type(decoded) is not dict:
996
raise errors.SmartProtocolError(
997
'Header object %r is not a dict' % (decoded,))
998
self.state_accept = self._state_accept_expecting_message_part
1000
self.message_handler.headers_received(decoded)
1002
raise errors.SmartMessageHandlerError(sys.exc_info())
1004
def _state_accept_expecting_message_part(self):
1005
message_part_kind = self._extract_single_byte()
1006
if message_part_kind == 'o':
1007
self.state_accept = self._state_accept_expecting_one_byte
1008
elif message_part_kind == 's':
1009
self.state_accept = self._state_accept_expecting_structure
1010
elif message_part_kind == 'b':
1011
self.state_accept = self._state_accept_expecting_bytes
1012
elif message_part_kind == 'e':
1015
raise errors.SmartProtocolError(
1016
'Bad message kind byte: %r' % (message_part_kind,))
1018
def _state_accept_expecting_one_byte(self):
1019
byte = self._extract_single_byte()
1020
self.state_accept = self._state_accept_expecting_message_part
1022
self.message_handler.byte_part_received(byte)
1024
raise errors.SmartMessageHandlerError(sys.exc_info())
1026
def _state_accept_expecting_bytes(self):
1027
# XXX: this should not buffer whole message part, but instead deliver
1028
# the bytes as they arrive.
1029
prefixed_bytes = self._extract_length_prefixed_bytes()
1030
self.state_accept = self._state_accept_expecting_message_part
1032
self.message_handler.bytes_part_received(prefixed_bytes)
1034
raise errors.SmartMessageHandlerError(sys.exc_info())
1036
def _state_accept_expecting_structure(self):
1037
structure = self._extract_prefixed_bencoded_data()
1038
self.state_accept = self._state_accept_expecting_message_part
1040
self.message_handler.structure_part_received(structure)
1042
raise errors.SmartMessageHandlerError(sys.exc_info())
1045
self.unused_data = self._get_in_buffer()
1046
self._set_in_buffer(None)
1047
self.state_accept = self._state_accept_reading_unused
1049
self.message_handler.end_received()
1051
raise errors.SmartMessageHandlerError(sys.exc_info())
1053
def _state_accept_reading_unused(self):
1054
self.unused_data += self._get_in_buffer()
1055
self._set_in_buffer(None)
1057
def next_read_size(self):
1058
if self.state_accept == self._state_accept_reading_unused:
1060
elif self.decoding_failed:
1061
# An exception occured while processing this message, probably from
1062
# self.message_handler. We're not sure that this state machine is
1063
# in a consistent state, so just signal that we're done (i.e. give
1067
if self._number_needed_bytes is not None:
1068
return self._number_needed_bytes - self._in_buffer_len
1070
raise AssertionError("don't know how many bytes are expected!")
1073
class _ProtocolThreeEncoder(object):
1075
response_marker = request_marker = MESSAGE_VERSION_THREE
1076
BUFFER_SIZE = 1024*1024 # 1 MiB buffer before flushing
1078
def __init__(self, write_func):
1081
self._real_write_func = write_func
1083
def _write_func(self, bytes):
1084
# TODO: It is probably more appropriate to use sum(map(len, _buf))
1085
# for total number of bytes to write, rather than buffer based on
1086
# the number of write() calls
1087
# TODO: Another possibility would be to turn this into an async model.
1088
# Where we let another thread know that we have some bytes if
1089
# they want it, but we don't actually block for it
1090
# Note that osutils.send_all always sends 64kB chunks anyway, so
1091
# we might just push out smaller bits at a time?
1092
self._buf.append(bytes)
1093
self._buf_len += len(bytes)
1094
if self._buf_len > self.BUFFER_SIZE:
1099
self._real_write_func(''.join(self._buf))
1103
def _serialise_offsets(self, offsets):
1104
"""Serialise a readv offset list."""
1106
for start, length in offsets:
1107
txt.append('%d,%d' % (start, length))
1108
return '\n'.join(txt)
1110
def _write_protocol_version(self):
1111
self._write_func(MESSAGE_VERSION_THREE)
1113
def _write_prefixed_bencode(self, structure):
1114
bytes = bencode(structure)
1115
self._write_func(struct.pack('!L', len(bytes)))
1116
self._write_func(bytes)
1118
def _write_headers(self, headers):
1119
self._write_prefixed_bencode(headers)
1121
def _write_structure(self, args):
1122
self._write_func('s')
1125
if type(arg) is unicode:
1126
utf8_args.append(arg.encode('utf8'))
1128
utf8_args.append(arg)
1129
self._write_prefixed_bencode(utf8_args)
1131
def _write_end(self):
1132
self._write_func('e')
1135
def _write_prefixed_body(self, bytes):
1136
self._write_func('b')
1137
self._write_func(struct.pack('!L', len(bytes)))
1138
self._write_func(bytes)
1140
def _write_chunked_body_start(self):
1141
self._write_func('oC')
1143
def _write_error_status(self):
1144
self._write_func('oE')
1146
def _write_success_status(self):
1147
self._write_func('oS')
1150
class ProtocolThreeResponder(_ProtocolThreeEncoder):
1152
def __init__(self, write_func):
1153
_ProtocolThreeEncoder.__init__(self, write_func)
1154
self.response_sent = False
1155
self._headers = {'Software version': bzrlib.__version__}
1156
if 'hpss' in debug.debug_flags:
1157
self._thread_id = thread.get_ident()
1158
self._response_start_time = None
1160
def _trace(self, action, message, extra_bytes=None, include_time=False):
1161
if self._response_start_time is None:
1162
self._response_start_time = osutils.timer_func()
1164
t = '%5.3fs ' % (time.clock() - self._response_start_time)
1167
if extra_bytes is None:
1170
extra = ' ' + repr(extra_bytes[:40])
1172
extra = extra[:29] + extra[-1] + '...'
1173
mutter('%12s: [%s] %s%s%s'
1174
% (action, self._thread_id, t, message, extra))
1176
def send_error(self, exception):
1177
if self.response_sent:
1178
raise AssertionError(
1179
"send_error(%s) called, but response already sent."
1181
if isinstance(exception, errors.UnknownSmartMethod):
1182
failure = request.FailedSmartServerResponse(
1183
('UnknownMethod', exception.verb))
1184
self.send_response(failure)
1186
if 'hpss' in debug.debug_flags:
1187
self._trace('error', str(exception))
1188
self.response_sent = True
1189
self._write_protocol_version()
1190
self._write_headers(self._headers)
1191
self._write_error_status()
1192
self._write_structure(('error', str(exception)))
1195
def send_response(self, response):
1196
if self.response_sent:
1197
raise AssertionError(
1198
"send_response(%r) called, but response already sent."
1200
self.response_sent = True
1201
self._write_protocol_version()
1202
self._write_headers(self._headers)
1203
if response.is_successful():
1204
self._write_success_status()
1206
self._write_error_status()
1207
if 'hpss' in debug.debug_flags:
1208
self._trace('response', repr(response.args))
1209
self._write_structure(response.args)
1210
if response.body is not None:
1211
self._write_prefixed_body(response.body)
1212
if 'hpss' in debug.debug_flags:
1213
self._trace('body', '%d bytes' % (len(response.body),),
1214
response.body, include_time=True)
1215
elif response.body_stream is not None:
1216
count = num_bytes = 0
1218
for exc_info, chunk in _iter_with_errors(response.body_stream):
1220
if exc_info is not None:
1221
self._write_error_status()
1222
error_struct = request._translate_error(exc_info[1])
1223
self._write_structure(error_struct)
1226
if isinstance(chunk, request.FailedSmartServerResponse):
1227
self._write_error_status()
1228
self._write_structure(chunk.args)
1230
num_bytes += len(chunk)
1231
if first_chunk is None:
1233
self._write_prefixed_body(chunk)
1235
if 'hpssdetail' in debug.debug_flags:
1236
# Not worth timing separately, as _write_func is
1238
self._trace('body chunk',
1239
'%d bytes' % (len(chunk),),
1240
chunk, suppress_time=True)
1241
if 'hpss' in debug.debug_flags:
1242
self._trace('body stream',
1243
'%d bytes %d chunks' % (num_bytes, count),
1246
if 'hpss' in debug.debug_flags:
1247
self._trace('response end', '', include_time=True)
1250
def _iter_with_errors(iterable):
1251
"""Handle errors from iterable.next().
1255
for exc_info, value in _iter_with_errors(iterable):
1258
This is a safer alternative to::
1261
for value in iterable:
1266
Because the latter will catch errors from the for-loop body, not just
1269
If an error occurs, exc_info will be a exc_info tuple, and the generator
1270
will terminate. Otherwise exc_info will be None, and value will be the
1271
value from iterable.next(). Note that KeyboardInterrupt and SystemExit
1272
will not be itercepted.
1274
iterator = iter(iterable)
1277
yield None, iterator.next()
1278
except StopIteration:
1280
except (KeyboardInterrupt, SystemExit):
1283
mutter('_iter_with_errors caught error')
1284
log_exception_quietly()
1285
yield sys.exc_info(), None
1289
class ProtocolThreeRequester(_ProtocolThreeEncoder, Requester):
1291
def __init__(self, medium_request):
1292
_ProtocolThreeEncoder.__init__(self, medium_request.accept_bytes)
1293
self._medium_request = medium_request
1296
def set_headers(self, headers):
1297
self._headers = headers.copy()
1299
def call(self, *args):
1300
if 'hpss' in debug.debug_flags:
1301
mutter('hpss call: %s', repr(args)[1:-1])
1302
base = getattr(self._medium_request._medium, 'base', None)
1303
if base is not None:
1304
mutter(' (to %s)', base)
1305
self._request_start_time = osutils.timer_func()
1306
self._write_protocol_version()
1307
self._write_headers(self._headers)
1308
self._write_structure(args)
1310
self._medium_request.finished_writing()
1312
def call_with_body_bytes(self, args, body):
1313
"""Make a remote call of args with body bytes 'body'.
1315
After calling this, call read_response_tuple to find the result out.
1317
if 'hpss' in debug.debug_flags:
1318
mutter('hpss call w/body: %s (%r...)', repr(args)[1:-1], body[:20])
1319
path = getattr(self._medium_request._medium, '_path', None)
1320
if path is not None:
1321
mutter(' (to %s)', path)
1322
mutter(' %d bytes', len(body))
1323
self._request_start_time = osutils.timer_func()
1324
self._write_protocol_version()
1325
self._write_headers(self._headers)
1326
self._write_structure(args)
1327
self._write_prefixed_body(body)
1329
self._medium_request.finished_writing()
1331
def call_with_body_readv_array(self, args, body):
1332
"""Make a remote call with a readv array.
1334
The body is encoded with one line per readv offset pair. The numbers in
1335
each pair are separated by a comma, and no trailing \\n is emitted.
1337
if 'hpss' in debug.debug_flags:
1338
mutter('hpss call w/readv: %s', repr(args)[1:-1])
1339
path = getattr(self._medium_request._medium, '_path', None)
1340
if path is not None:
1341
mutter(' (to %s)', path)
1342
self._request_start_time = osutils.timer_func()
1343
self._write_protocol_version()
1344
self._write_headers(self._headers)
1345
self._write_structure(args)
1346
readv_bytes = self._serialise_offsets(body)
1347
if 'hpss' in debug.debug_flags:
1348
mutter(' %d bytes in readv request', len(readv_bytes))
1349
self._write_prefixed_body(readv_bytes)
1351
self._medium_request.finished_writing()
1353
def call_with_body_stream(self, args, stream):
1354
if 'hpss' in debug.debug_flags:
1355
mutter('hpss call w/body stream: %r', args)
1356
path = getattr(self._medium_request._medium, '_path', None)
1357
if path is not None:
1358
mutter(' (to %s)', path)
1359
self._request_start_time = osutils.timer_func()
1360
self._write_protocol_version()
1361
self._write_headers(self._headers)
1362
self._write_structure(args)
1363
# TODO: notice if the server has sent an early error reply before we
1364
# have finished sending the stream. We would notice at the end
1365
# anyway, but if the medium can deliver it early then it's good
1366
# to short-circuit the whole request...
1367
for exc_info, part in _iter_with_errors(stream):
1368
if exc_info is not None:
1369
# Iterating the stream failed. Cleanly abort the request.
1370
self._write_error_status()
1371
# Currently the client unconditionally sends ('error',) as the
1373
self._write_structure(('error',))
1375
self._medium_request.finished_writing()
1376
raise exc_info[0], exc_info[1], exc_info[2]
1378
self._write_prefixed_body(part)
1381
self._medium_request.finished_writing()