13
13
# You should have received a copy of the GNU General Public License
14
14
# along with this program; if not, write to the Free Software
15
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
15
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
17
17
"""Wire-level encoding and decoding of requests and responses for the smart
21
from __future__ import absolute_import
24
22
from cStringIO import StringIO
28
from bzrlib import debug
29
from bzrlib import errors
36
30
from bzrlib.smart import message, request
37
31
from bzrlib.trace import log_exception_quietly, mutter
38
from bzrlib.bencode import bdecode_as_tuple, bencode
32
from bzrlib.util.bencode import bdecode, bencode
41
35
# Protocol version strings. These are sent as prefixes of bzr requests and
64
58
def _encode_tuple(args):
65
59
"""Encode the tuple args to a bytestream."""
66
joined = '\x01'.join(args) + '\n'
67
if type(joined) is unicode:
68
# XXX: We should fix things so this never happens! -AJB, 20100304
69
mutter('response args contain unicode, should be only bytes: %r',
71
joined = joined.encode('ascii')
60
return '\x01'.join(args) + '\n'
75
63
class Requester(object):
121
109
for start, length in offsets:
122
110
txt.append('%d,%d' % (start, length))
123
111
return '\n'.join(txt)
126
114
class SmartServerRequestProtocolOne(SmartProtocolBase):
127
115
"""Server-side encoding and decoding logic for smart version 1."""
129
def __init__(self, backing_transport, write_func, root_client_path='/',
117
def __init__(self, backing_transport, write_func, root_client_path='/'):
131
118
self._backing_transport = backing_transport
132
119
self._root_client_path = root_client_path
133
self._jail_root = jail_root
134
120
self.unused_data = ''
135
121
self._finished = False
136
122
self.in_buffer = ''
158
144
req_args = _decode_tuple(first_line)
159
145
self.request = request.SmartServerRequestHandler(
160
146
self._backing_transport, commands=request.request_handlers,
161
root_client_path=self._root_client_path,
162
jail_root=self._jail_root)
163
self.request.args_received(req_args)
147
root_client_path=self._root_client_path)
148
self.request.dispatch_command(req_args[0], req_args[1:])
164
149
if self.request.finished_reading:
165
150
# trivial request
166
151
self.unused_data = self.in_buffer
530
515
class LengthPrefixedBodyDecoder(_StatefulDecoder):
531
516
"""Decodes the length-prefixed bulk data."""
533
518
def __init__(self):
534
519
_StatefulDecoder.__init__(self)
535
520
self.state_accept = self._state_accept_expecting_length
536
521
self.state_read = self._state_read_no_data
538
523
self._trailer_buffer = ''
540
525
def next_read_size(self):
541
526
if self.bytes_left is not None:
542
527
# Ideally we want to read all the remainder of the body and the
627
612
mutter('hpss call: %s', repr(args)[1:-1])
628
613
if getattr(self._request._medium, 'base', None) is not None:
629
614
mutter(' (to %s)', self._request._medium.base)
630
self._request_start_time = osutils.timer_func()
615
self._request_start_time = time.time()
631
616
self._write_args(args)
632
617
self._request.finished_writing()
633
618
self._last_verb = args[0]
642
627
if getattr(self._request._medium, '_path', None) is not None:
643
628
mutter(' (to %s)', self._request._medium._path)
644
629
mutter(' %d bytes', len(body))
645
self._request_start_time = osutils.timer_func()
630
self._request_start_time = time.time()
646
631
if 'hpssdetail' in debug.debug_flags:
647
632
mutter('hpss body content: %s', body)
648
633
self._write_args(args)
655
640
"""Make a remote call with a readv array.
657
642
The body is encoded with one line per readv offset pair. The numbers in
658
each pair are separated by a comma, and no trailing \\n is emitted.
643
each pair are separated by a comma, and no trailing \n is emitted.
660
645
if 'hpss' in debug.debug_flags:
661
646
mutter('hpss call w/readv: %s', repr(args)[1:-1])
662
647
if getattr(self._request._medium, '_path', None) is not None:
663
648
mutter(' (to %s)', self._request._medium._path)
664
self._request_start_time = osutils.timer_func()
649
self._request_start_time = time.time()
665
650
self._write_args(args)
666
651
readv_bytes = self._serialise_offsets(body)
667
652
bytes = self._encode_bulk_data(readv_bytes)
671
656
mutter(' %d bytes in readv request', len(readv_bytes))
672
657
self._last_verb = args[0]
674
def call_with_body_stream(self, args, stream):
675
# Protocols v1 and v2 don't support body streams. So it's safe to
676
# assume that a v1/v2 server doesn't support whatever method we're
677
# trying to call with a body stream.
678
self._request.finished_writing()
679
self._request.finished_reading()
680
raise errors.UnknownSmartMethod(args[0])
682
659
def cancel_read_body(self):
683
660
"""After expecting a body, a response code may indicate one otherwise.
757
734
# The response will have no body, so we've finished reading.
758
735
self._request.finished_reading()
759
736
raise errors.UnknownSmartMethod(self._last_verb)
761
738
def read_body_bytes(self, count=-1):
762
739
"""Read bytes from the body, decoding into a byte stream.
764
We read all bytes at once to ensure we've checked the trailer for
741
We read all bytes at once to ensure we've checked the trailer for
765
742
errors, and then feed the buffer back as read_body_bytes is called.
767
744
if self._body_buffer is not None:
875
852
def build_server_protocol_three(backing_transport, write_func,
876
root_client_path, jail_root=None):
877
854
request_handler = request.SmartServerRequestHandler(
878
855
backing_transport, commands=request.request_handlers,
879
root_client_path=root_client_path, jail_root=jail_root)
856
root_client_path=root_client_path)
880
857
responder = ProtocolThreeResponder(write_func)
881
858
message_handler = message.ConventionalRequestHandler(request_handler, responder)
882
859
return ProtocolThreeDecoder(message_handler)
912
889
# We do *not* set self.decoding_failed here. The message handler
913
890
# has raised an error, but the decoder is still able to parse bytes
914
891
# and determine when this message ends.
915
if not isinstance(exception.exc_value, errors.UnknownSmartMethod):
916
log_exception_quietly()
892
log_exception_quietly()
917
893
self.message_handler.protocol_error(exception.exc_value)
918
894
# The state machine is ready to continue decoding, but the
919
895
# exception has interrupted the loop that runs the state machine.
1074
1050
class _ProtocolThreeEncoder(object):
1076
1052
response_marker = request_marker = MESSAGE_VERSION_THREE
1077
BUFFER_SIZE = 1024*1024 # 1 MiB buffer before flushing
1079
1054
def __init__(self, write_func):
1082
1056
self._real_write_func = write_func
1084
1058
def _write_func(self, bytes):
1085
# TODO: Another possibility would be to turn this into an async model.
1086
# Where we let another thread know that we have some bytes if
1087
# they want it, but we don't actually block for it
1088
# Note that osutils.send_all always sends 64kB chunks anyway, so
1089
# we might just push out smaller bits at a time?
1090
self._buf.append(bytes)
1091
self._buf_len += len(bytes)
1092
if self._buf_len > self.BUFFER_SIZE:
1095
1061
def flush(self):
1097
self._real_write_func(''.join(self._buf))
1063
self._real_write_func(self._buf)
1101
1066
def _serialise_offsets(self, offsets):
1102
1067
"""Serialise a readv offset list."""
1151
1113
_ProtocolThreeEncoder.__init__(self, write_func)
1152
1114
self.response_sent = False
1153
1115
self._headers = {'Software version': bzrlib.__version__}
1154
if 'hpss' in debug.debug_flags:
1155
self._thread_id = thread.get_ident()
1156
self._response_start_time = None
1158
def _trace(self, action, message, extra_bytes=None, include_time=False):
1159
if self._response_start_time is None:
1160
self._response_start_time = osutils.timer_func()
1162
t = '%5.3fs ' % (time.clock() - self._response_start_time)
1165
if extra_bytes is None:
1168
extra = ' ' + repr(extra_bytes[:40])
1170
extra = extra[:29] + extra[-1] + '...'
1171
mutter('%12s: [%s] %s%s%s'
1172
% (action, self._thread_id, t, message, extra))
1174
1117
def send_error(self, exception):
1175
1118
if self.response_sent:
1202
1143
self._write_success_status()
1204
1145
self._write_error_status()
1205
if 'hpss' in debug.debug_flags:
1206
self._trace('response', repr(response.args))
1207
1146
self._write_structure(response.args)
1208
1147
if response.body is not None:
1209
1148
self._write_prefixed_body(response.body)
1210
if 'hpss' in debug.debug_flags:
1211
self._trace('body', '%d bytes' % (len(response.body),),
1212
response.body, include_time=True)
1213
1149
elif response.body_stream is not None:
1214
count = num_bytes = 0
1216
for exc_info, chunk in _iter_with_errors(response.body_stream):
1218
if exc_info is not None:
1219
self._write_error_status()
1220
error_struct = request._translate_error(exc_info[1])
1221
self._write_structure(error_struct)
1224
if isinstance(chunk, request.FailedSmartServerResponse):
1225
self._write_error_status()
1226
self._write_structure(chunk.args)
1228
num_bytes += len(chunk)
1229
if first_chunk is None:
1231
self._write_prefixed_body(chunk)
1233
if 'hpssdetail' in debug.debug_flags:
1234
# Not worth timing separately, as _write_func is
1236
self._trace('body chunk',
1237
'%d bytes' % (len(chunk),),
1238
chunk, suppress_time=True)
1239
if 'hpss' in debug.debug_flags:
1240
self._trace('body stream',
1241
'%d bytes %d chunks' % (num_bytes, count),
1150
for chunk in response.body_stream:
1151
self._write_prefixed_body(chunk)
1243
1153
self._write_end()
1244
if 'hpss' in debug.debug_flags:
1245
self._trace('response end', '', include_time=True)
1248
def _iter_with_errors(iterable):
1249
"""Handle errors from iterable.next().
1253
for exc_info, value in _iter_with_errors(iterable):
1256
This is a safer alternative to::
1259
for value in iterable:
1264
Because the latter will catch errors from the for-loop body, not just
1267
If an error occurs, exc_info will be a exc_info tuple, and the generator
1268
will terminate. Otherwise exc_info will be None, and value will be the
1269
value from iterable.next(). Note that KeyboardInterrupt and SystemExit
1270
will not be itercepted.
1272
iterator = iter(iterable)
1275
yield None, iterator.next()
1276
except StopIteration:
1278
except (KeyboardInterrupt, SystemExit):
1281
mutter('_iter_with_errors caught error')
1282
log_exception_quietly()
1283
yield sys.exc_info(), None
1287
1156
class ProtocolThreeRequester(_ProtocolThreeEncoder, Requester):
1290
1159
_ProtocolThreeEncoder.__init__(self, medium_request.accept_bytes)
1291
1160
self._medium_request = medium_request
1292
1161
self._headers = {}
1293
self.body_stream_started = None
1295
1163
def set_headers(self, headers):
1296
1164
self._headers = headers.copy()
1298
1166
def call(self, *args):
1299
1167
if 'hpss' in debug.debug_flags:
1300
1168
mutter('hpss call: %s', repr(args)[1:-1])
1301
1169
base = getattr(self._medium_request._medium, 'base', None)
1302
1170
if base is not None:
1303
1171
mutter(' (to %s)', base)
1304
self._request_start_time = osutils.timer_func()
1172
self._request_start_time = time.time()
1305
1173
self._write_protocol_version()
1306
1174
self._write_headers(self._headers)
1307
1175
self._write_structure(args)
1331
1199
"""Make a remote call with a readv array.
1333
1201
The body is encoded with one line per readv offset pair. The numbers in
1334
each pair are separated by a comma, and no trailing \\n is emitted.
1202
each pair are separated by a comma, and no trailing \n is emitted.
1336
1204
if 'hpss' in debug.debug_flags:
1337
1205
mutter('hpss call w/readv: %s', repr(args)[1:-1])
1338
1206
path = getattr(self._medium_request._medium, '_path', None)
1339
1207
if path is not None:
1340
1208
mutter(' (to %s)', path)
1341
self._request_start_time = osutils.timer_func()
1209
self._request_start_time = time.time()
1342
1210
self._write_protocol_version()
1343
1211
self._write_headers(self._headers)
1344
1212
self._write_structure(args)
1349
1217
self._write_end()
1350
1218
self._medium_request.finished_writing()
1352
def call_with_body_stream(self, args, stream):
1353
if 'hpss' in debug.debug_flags:
1354
mutter('hpss call w/body stream: %r', args)
1355
path = getattr(self._medium_request._medium, '_path', None)
1356
if path is not None:
1357
mutter(' (to %s)', path)
1358
self._request_start_time = osutils.timer_func()
1359
self.body_stream_started = False
1360
self._write_protocol_version()
1361
self._write_headers(self._headers)
1362
self._write_structure(args)
1363
# TODO: notice if the server has sent an early error reply before we
1364
# have finished sending the stream. We would notice at the end
1365
# anyway, but if the medium can deliver it early then it's good
1366
# to short-circuit the whole request...
1367
# Provoke any ConnectionReset failures before we start the body stream.
1369
self.body_stream_started = True
1370
for exc_info, part in _iter_with_errors(stream):
1371
if exc_info is not None:
1372
# Iterating the stream failed. Cleanly abort the request.
1373
self._write_error_status()
1374
# Currently the client unconditionally sends ('error',) as the
1376
self._write_structure(('error',))
1378
self._medium_request.finished_writing()
1379
raise exc_info[0], exc_info[1], exc_info[2]
1381
self._write_prefixed_body(part)
1384
self._medium_request.finished_writing()