13
13
# You should have received a copy of the GNU General Public License
14
14
# along with this program; if not, write to the Free Software
15
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
15
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
17
17
"""Wire-level encoding and decoding of requests and responses for the smart
22
22
from cStringIO import StringIO
28
from bzrlib import debug
29
from bzrlib import errors
35
30
from bzrlib.smart import message, request
36
31
from bzrlib.trace import log_exception_quietly, mutter
37
from bzrlib.bencode import bdecode_as_tuple, bencode
32
from bzrlib.util.bencode import bdecode, bencode
40
35
# Protocol version strings. These are sent as prefixes of bzr requests and
63
58
def _encode_tuple(args):
64
59
"""Encode the tuple args to a bytestream."""
65
joined = '\x01'.join(args) + '\n'
66
if type(joined) is unicode:
67
# XXX: We should fix things so this never happens! -AJB, 20100304
68
mutter('response args contain unicode, should be only bytes: %r',
70
joined = joined.encode('ascii')
60
return '\x01'.join(args) + '\n'
74
63
class Requester(object):
120
109
for start, length in offsets:
121
110
txt.append('%d,%d' % (start, length))
122
111
return '\n'.join(txt)
125
114
class SmartServerRequestProtocolOne(SmartProtocolBase):
126
115
"""Server-side encoding and decoding logic for smart version 1."""
128
def __init__(self, backing_transport, write_func, root_client_path='/',
117
def __init__(self, backing_transport, write_func, root_client_path='/'):
130
118
self._backing_transport = backing_transport
131
119
self._root_client_path = root_client_path
132
self._jail_root = jail_root
133
120
self.unused_data = ''
134
121
self._finished = False
135
122
self.in_buffer = ''
157
144
req_args = _decode_tuple(first_line)
158
145
self.request = request.SmartServerRequestHandler(
159
146
self._backing_transport, commands=request.request_handlers,
160
root_client_path=self._root_client_path,
161
jail_root=self._jail_root)
162
self.request.args_received(req_args)
147
root_client_path=self._root_client_path)
148
self.request.dispatch_command(req_args[0], req_args[1:])
163
149
if self.request.finished_reading:
164
150
# trivial request
165
151
self.unused_data = self.in_buffer
529
516
class LengthPrefixedBodyDecoder(_StatefulDecoder):
530
517
"""Decodes the length-prefixed bulk data."""
532
519
def __init__(self):
533
520
_StatefulDecoder.__init__(self)
534
521
self.state_accept = self._state_accept_expecting_length
535
522
self.state_read = self._state_read_no_data
537
524
self._trailer_buffer = ''
539
526
def next_read_size(self):
540
527
if self.bytes_left is not None:
541
528
# Ideally we want to read all the remainder of the body and the
626
613
mutter('hpss call: %s', repr(args)[1:-1])
627
614
if getattr(self._request._medium, 'base', None) is not None:
628
615
mutter(' (to %s)', self._request._medium.base)
629
self._request_start_time = osutils.timer_func()
616
self._request_start_time = time.time()
630
617
self._write_args(args)
631
618
self._request.finished_writing()
632
619
self._last_verb = args[0]
641
628
if getattr(self._request._medium, '_path', None) is not None:
642
629
mutter(' (to %s)', self._request._medium._path)
643
630
mutter(' %d bytes', len(body))
644
self._request_start_time = osutils.timer_func()
631
self._request_start_time = time.time()
645
632
if 'hpssdetail' in debug.debug_flags:
646
633
mutter('hpss body content: %s', body)
647
634
self._write_args(args)
654
641
"""Make a remote call with a readv array.
656
643
The body is encoded with one line per readv offset pair. The numbers in
657
each pair are separated by a comma, and no trailing \\n is emitted.
644
each pair are separated by a comma, and no trailing \n is emitted.
659
646
if 'hpss' in debug.debug_flags:
660
647
mutter('hpss call w/readv: %s', repr(args)[1:-1])
661
648
if getattr(self._request._medium, '_path', None) is not None:
662
649
mutter(' (to %s)', self._request._medium._path)
663
self._request_start_time = osutils.timer_func()
650
self._request_start_time = time.time()
664
651
self._write_args(args)
665
652
readv_bytes = self._serialise_offsets(body)
666
653
bytes = self._encode_bulk_data(readv_bytes)
670
657
mutter(' %d bytes in readv request', len(readv_bytes))
671
658
self._last_verb = args[0]
673
def call_with_body_stream(self, args, stream):
674
# Protocols v1 and v2 don't support body streams. So it's safe to
675
# assume that a v1/v2 server doesn't support whatever method we're
676
# trying to call with a body stream.
677
self._request.finished_writing()
678
self._request.finished_reading()
679
raise errors.UnknownSmartMethod(args[0])
681
660
def cancel_read_body(self):
682
661
"""After expecting a body, a response code may indicate one otherwise.
756
735
# The response will have no body, so we've finished reading.
757
736
self._request.finished_reading()
758
737
raise errors.UnknownSmartMethod(self._last_verb)
760
739
def read_body_bytes(self, count=-1):
761
740
"""Read bytes from the body, decoding into a byte stream.
763
We read all bytes at once to ensure we've checked the trailer for
742
We read all bytes at once to ensure we've checked the trailer for
764
743
errors, and then feed the buffer back as read_body_bytes is called.
766
745
if self._body_buffer is not None:
874
853
def build_server_protocol_three(backing_transport, write_func,
875
root_client_path, jail_root=None):
876
855
request_handler = request.SmartServerRequestHandler(
877
856
backing_transport, commands=request.request_handlers,
878
root_client_path=root_client_path, jail_root=jail_root)
857
root_client_path=root_client_path)
879
858
responder = ProtocolThreeResponder(write_func)
880
859
message_handler = message.ConventionalRequestHandler(request_handler, responder)
881
860
return ProtocolThreeDecoder(message_handler)
911
890
# We do *not* set self.decoding_failed here. The message handler
912
891
# has raised an error, but the decoder is still able to parse bytes
913
892
# and determine when this message ends.
914
if not isinstance(exception.exc_value, errors.UnknownSmartMethod):
915
log_exception_quietly()
893
log_exception_quietly()
916
894
self.message_handler.protocol_error(exception.exc_value)
917
895
# The state machine is ready to continue decoding, but the
918
896
# exception has interrupted the loop that runs the state machine.
1073
1051
class _ProtocolThreeEncoder(object):
1075
1053
response_marker = request_marker = MESSAGE_VERSION_THREE
1076
BUFFER_SIZE = 1024*1024 # 1 MiB buffer before flushing
1078
1055
def __init__(self, write_func):
1081
1057
self._real_write_func = write_func
1083
1059
def _write_func(self, bytes):
1084
# TODO: It is probably more appropriate to use sum(map(len, _buf))
1085
# for total number of bytes to write, rather than buffer based on
1086
# the number of write() calls
1087
# TODO: Another possibility would be to turn this into an async model.
1088
# Where we let another thread know that we have some bytes if
1089
# they want it, but we don't actually block for it
1090
# Note that osutils.send_all always sends 64kB chunks anyway, so
1091
# we might just push out smaller bits at a time?
1092
self._buf.append(bytes)
1093
self._buf_len += len(bytes)
1094
if self._buf_len > self.BUFFER_SIZE:
1097
1062
def flush(self):
1099
self._real_write_func(''.join(self._buf))
1064
self._real_write_func(self._buf)
1103
1067
def _serialise_offsets(self, offsets):
1104
1068
"""Serialise a readv offset list."""
1153
1114
_ProtocolThreeEncoder.__init__(self, write_func)
1154
1115
self.response_sent = False
1155
1116
self._headers = {'Software version': bzrlib.__version__}
1156
if 'hpss' in debug.debug_flags:
1157
self._thread_id = thread.get_ident()
1158
self._response_start_time = None
1160
def _trace(self, action, message, extra_bytes=None, include_time=False):
1161
if self._response_start_time is None:
1162
self._response_start_time = osutils.timer_func()
1164
t = '%5.3fs ' % (time.clock() - self._response_start_time)
1167
if extra_bytes is None:
1170
extra = ' ' + repr(extra_bytes[:40])
1172
extra = extra[:29] + extra[-1] + '...'
1173
mutter('%12s: [%s] %s%s%s'
1174
% (action, self._thread_id, t, message, extra))
1176
1118
def send_error(self, exception):
1177
1119
if self.response_sent:
1204
1144
self._write_success_status()
1206
1146
self._write_error_status()
1207
if 'hpss' in debug.debug_flags:
1208
self._trace('response', repr(response.args))
1209
1147
self._write_structure(response.args)
1210
1148
if response.body is not None:
1211
1149
self._write_prefixed_body(response.body)
1212
if 'hpss' in debug.debug_flags:
1213
self._trace('body', '%d bytes' % (len(response.body),),
1214
response.body, include_time=True)
1215
1150
elif response.body_stream is not None:
1216
count = num_bytes = 0
1218
for exc_info, chunk in _iter_with_errors(response.body_stream):
1220
if exc_info is not None:
1221
self._write_error_status()
1222
error_struct = request._translate_error(exc_info[1])
1223
self._write_structure(error_struct)
1226
if isinstance(chunk, request.FailedSmartServerResponse):
1227
self._write_error_status()
1228
self._write_structure(chunk.args)
1230
num_bytes += len(chunk)
1231
if first_chunk is None:
1233
self._write_prefixed_body(chunk)
1235
if 'hpssdetail' in debug.debug_flags:
1236
# Not worth timing separately, as _write_func is
1238
self._trace('body chunk',
1239
'%d bytes' % (len(chunk),),
1240
chunk, suppress_time=True)
1241
if 'hpss' in debug.debug_flags:
1242
self._trace('body stream',
1243
'%d bytes %d chunks' % (num_bytes, count),
1151
for chunk in response.body_stream:
1152
self._write_prefixed_body(chunk)
1245
1154
self._write_end()
1246
if 'hpss' in debug.debug_flags:
1247
self._trace('response end', '', include_time=True)
1250
def _iter_with_errors(iterable):
1251
"""Handle errors from iterable.next().
1255
for exc_info, value in _iter_with_errors(iterable):
1258
This is a safer alternative to::
1261
for value in iterable:
1266
Because the latter will catch errors from the for-loop body, not just
1269
If an error occurs, exc_info will be a exc_info tuple, and the generator
1270
will terminate. Otherwise exc_info will be None, and value will be the
1271
value from iterable.next(). Note that KeyboardInterrupt and SystemExit
1272
will not be itercepted.
1274
iterator = iter(iterable)
1277
yield None, iterator.next()
1278
except StopIteration:
1280
except (KeyboardInterrupt, SystemExit):
1283
mutter('_iter_with_errors caught error')
1284
log_exception_quietly()
1285
yield sys.exc_info(), None
1289
1157
class ProtocolThreeRequester(_ProtocolThreeEncoder, Requester):
1296
1164
def set_headers(self, headers):
1297
1165
self._headers = headers.copy()
1299
1167
def call(self, *args):
1300
1168
if 'hpss' in debug.debug_flags:
1301
1169
mutter('hpss call: %s', repr(args)[1:-1])
1302
1170
base = getattr(self._medium_request._medium, 'base', None)
1303
1171
if base is not None:
1304
1172
mutter(' (to %s)', base)
1305
self._request_start_time = osutils.timer_func()
1173
self._request_start_time = time.time()
1306
1174
self._write_protocol_version()
1307
1175
self._write_headers(self._headers)
1308
1176
self._write_structure(args)
1332
1200
"""Make a remote call with a readv array.
1334
1202
The body is encoded with one line per readv offset pair. The numbers in
1335
each pair are separated by a comma, and no trailing \\n is emitted.
1203
each pair are separated by a comma, and no trailing \n is emitted.
1337
1205
if 'hpss' in debug.debug_flags:
1338
1206
mutter('hpss call w/readv: %s', repr(args)[1:-1])
1339
1207
path = getattr(self._medium_request._medium, '_path', None)
1340
1208
if path is not None:
1341
1209
mutter(' (to %s)', path)
1342
self._request_start_time = osutils.timer_func()
1210
self._request_start_time = time.time()
1343
1211
self._write_protocol_version()
1344
1212
self._write_headers(self._headers)
1345
1213
self._write_structure(args)
1350
1218
self._write_end()
1351
1219
self._medium_request.finished_writing()
1353
def call_with_body_stream(self, args, stream):
1354
if 'hpss' in debug.debug_flags:
1355
mutter('hpss call w/body stream: %r', args)
1356
path = getattr(self._medium_request._medium, '_path', None)
1357
if path is not None:
1358
mutter(' (to %s)', path)
1359
self._request_start_time = osutils.timer_func()
1360
self._write_protocol_version()
1361
self._write_headers(self._headers)
1362
self._write_structure(args)
1363
# TODO: notice if the server has sent an early error reply before we
1364
# have finished sending the stream. We would notice at the end
1365
# anyway, but if the medium can deliver it early then it's good
1366
# to short-circuit the whole request...
1367
for exc_info, part in _iter_with_errors(stream):
1368
if exc_info is not None:
1369
# Iterating the stream failed. Cleanly abort the request.
1370
self._write_error_status()
1371
# Currently the client unconditionally sends ('error',) as the
1373
self._write_structure(('error',))
1375
self._medium_request.finished_writing()
1376
raise exc_info[0], exc_info[1], exc_info[2]
1378
self._write_prefixed_body(part)
1381
self._medium_request.finished_writing()