13
13
# You should have received a copy of the GNU General Public License
14
14
# along with this program; if not, write to the Free Software
15
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
15
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17
17
"""Wire-level encoding and decoding of requests and responses for the smart
22
22
from cStringIO import StringIO
28
from bzrlib import debug
29
from bzrlib import errors
30
35
from bzrlib.smart import message, request
31
36
from bzrlib.trace import log_exception_quietly, mutter
32
from bzrlib.util.bencode import bdecode, bencode
37
from bzrlib.bencode import bdecode_as_tuple, bencode
35
40
# Protocol version strings. These are sent as prefixes of bzr requests and
58
63
def _encode_tuple(args):
59
64
"""Encode the tuple args to a bytestream."""
60
return '\x01'.join(args) + '\n'
65
joined = '\x01'.join(args) + '\n'
66
if type(joined) is unicode:
67
# XXX: We should fix things so this never happens! -AJB, 20100304
68
mutter('response args contain unicode, should be only bytes: %r',
70
joined = joined.encode('ascii')
63
74
class Requester(object):
109
120
for start, length in offsets:
110
121
txt.append('%d,%d' % (start, length))
111
122
return '\n'.join(txt)
114
125
class SmartServerRequestProtocolOne(SmartProtocolBase):
115
126
"""Server-side encoding and decoding logic for smart version 1."""
117
def __init__(self, backing_transport, write_func, root_client_path='/'):
128
def __init__(self, backing_transport, write_func, root_client_path='/',
118
130
self._backing_transport = backing_transport
119
131
self._root_client_path = root_client_path
132
self._jail_root = jail_root
120
133
self.unused_data = ''
121
134
self._finished = False
122
135
self.in_buffer = ''
144
157
req_args = _decode_tuple(first_line)
145
158
self.request = request.SmartServerRequestHandler(
146
159
self._backing_transport, commands=request.request_handlers,
147
root_client_path=self._root_client_path)
148
self.request.dispatch_command(req_args[0], req_args[1:])
160
root_client_path=self._root_client_path,
161
jail_root=self._jail_root)
162
self.request.args_received(req_args)
149
163
if self.request.finished_reading:
150
164
# trivial request
151
165
self.unused_data = self.in_buffer
516
529
class LengthPrefixedBodyDecoder(_StatefulDecoder):
517
530
"""Decodes the length-prefixed bulk data."""
519
532
def __init__(self):
520
533
_StatefulDecoder.__init__(self)
521
534
self.state_accept = self._state_accept_expecting_length
522
535
self.state_read = self._state_read_no_data
524
537
self._trailer_buffer = ''
526
539
def next_read_size(self):
527
540
if self.bytes_left is not None:
528
541
# Ideally we want to read all the remainder of the body and the
613
626
mutter('hpss call: %s', repr(args)[1:-1])
614
627
if getattr(self._request._medium, 'base', None) is not None:
615
628
mutter(' (to %s)', self._request._medium.base)
616
self._request_start_time = time.time()
629
self._request_start_time = osutils.timer_func()
617
630
self._write_args(args)
618
631
self._request.finished_writing()
619
632
self._last_verb = args[0]
628
641
if getattr(self._request._medium, '_path', None) is not None:
629
642
mutter(' (to %s)', self._request._medium._path)
630
643
mutter(' %d bytes', len(body))
631
self._request_start_time = time.time()
644
self._request_start_time = osutils.timer_func()
632
645
if 'hpssdetail' in debug.debug_flags:
633
646
mutter('hpss body content: %s', body)
634
647
self._write_args(args)
641
654
"""Make a remote call with a readv array.
643
656
The body is encoded with one line per readv offset pair. The numbers in
644
each pair are separated by a comma, and no trailing \n is emitted.
657
each pair are separated by a comma, and no trailing \\n is emitted.
646
659
if 'hpss' in debug.debug_flags:
647
660
mutter('hpss call w/readv: %s', repr(args)[1:-1])
648
661
if getattr(self._request._medium, '_path', None) is not None:
649
662
mutter(' (to %s)', self._request._medium._path)
650
self._request_start_time = time.time()
663
self._request_start_time = osutils.timer_func()
651
664
self._write_args(args)
652
665
readv_bytes = self._serialise_offsets(body)
653
666
bytes = self._encode_bulk_data(readv_bytes)
657
670
mutter(' %d bytes in readv request', len(readv_bytes))
658
671
self._last_verb = args[0]
673
def call_with_body_stream(self, args, stream):
    """Refuse a request whose body would be a stream.

    Protocols v1 and v2 have no way to carry a body stream, so it is safe
    to assume that a v1/v2 server does not support whatever method is
    being called with one.  The request is marked finished for writing
    and then for reading, after which UnknownSmartMethod is raised for
    the verb.

    :param args: the request arguments; args[0] is the verb name.
    :param stream: the body stream; it is never consumed here.
    :raises errors.UnknownSmartMethod: always, naming args[0].
    """
    medium_request = self._request
    # Finish both directions of the request before signalling failure.
    medium_request.finished_writing()
    medium_request.finished_reading()
    verb = args[0]
    raise errors.UnknownSmartMethod(verb)
660
681
def cancel_read_body(self):
661
682
"""After expecting a body, a response code may indicate one otherwise.
735
756
# The response will have no body, so we've finished reading.
736
757
self._request.finished_reading()
737
758
raise errors.UnknownSmartMethod(self._last_verb)
739
760
def read_body_bytes(self, count=-1):
740
761
"""Read bytes from the body, decoding into a byte stream.
742
We read all bytes at once to ensure we've checked the trailer for
763
We read all bytes at once to ensure we've checked the trailer for
743
764
errors, and then feed the buffer back as read_body_bytes is called.
745
766
if self._body_buffer is not None:
853
874
def build_server_protocol_three(backing_transport, write_func,
                                root_client_path, jail_root=None):
    """Construct a server-side protocol-v3 decoder.

    A SmartServerRequestHandler is created over backing_transport, using
    request.request_handlers as its command table and forwarding
    root_client_path and jail_root unchanged.  That handler is paired
    with a ProtocolThreeResponder writing through write_func by a
    message.ConventionalRequestHandler, which in turn feeds the returned
    ProtocolThreeDecoder.

    :param backing_transport: transport the request handler operates on.
    :param write_func: callable the responder uses to emit bytes.
    :param root_client_path: passed through to the request handler.
    :param jail_root: passed through to the request handler (defaults to
        None; its exact meaning is defined by SmartServerRequestHandler).
    :returns: a ProtocolThreeDecoder wrapping the conventional handler.
    """
    handler = request.SmartServerRequestHandler(
        backing_transport,
        commands=request.request_handlers,
        root_client_path=root_client_path,
        jail_root=jail_root,
    )
    return ProtocolThreeDecoder(
        message.ConventionalRequestHandler(
            handler, ProtocolThreeResponder(write_func)))
890
911
# We do *not* set self.decoding_failed here. The message handler
891
912
# has raised an error, but the decoder is still able to parse bytes
892
913
# and determine when this message ends.
893
log_exception_quietly()
914
if not isinstance(exception.exc_value, errors.UnknownSmartMethod):
915
log_exception_quietly()
894
916
self.message_handler.protocol_error(exception.exc_value)
895
917
# The state machine is ready to continue decoding, but the
896
918
# exception has interrupted the loop that runs the state machine.
1051
1073
class _ProtocolThreeEncoder(object):
1053
1075
response_marker = request_marker = MESSAGE_VERSION_THREE
1076
BUFFER_SIZE = 1024*1024 # 1 MiB buffer before flushing
1055
1078
def __init__(self, write_func):
1057
1081
self._real_write_func = write_func
1059
1083
def _write_func(self, bytes):
1084
# TODO: It is probably more appropriate to use sum(map(len, _buf))
1085
# for total number of bytes to write, rather than buffer based on
1086
# the number of write() calls
1087
# TODO: Another possibility would be to turn this into an async model.
1088
# Where we let another thread know that we have some bytes if
1089
# they want it, but we don't actually block for it
1090
# Note that osutils.send_all always sends 64kB chunks anyway, so
1091
# we might just push out smaller bits at a time?
1092
self._buf.append(bytes)
1093
self._buf_len += len(bytes)
1094
if self._buf_len > self.BUFFER_SIZE:
1062
1097
def flush(self):
1064
self._real_write_func(self._buf)
1099
self._real_write_func(''.join(self._buf))
1067
1103
def _serialise_offsets(self, offsets):
1068
1104
"""Serialise a readv offset list."""
1114
1153
_ProtocolThreeEncoder.__init__(self, write_func)
1115
1154
self.response_sent = False
1116
1155
self._headers = {'Software version': bzrlib.__version__}
1156
if 'hpss' in debug.debug_flags:
1157
self._thread_id = thread.get_ident()
1158
self._response_start_time = None
1160
def _trace(self, action, message, extra_bytes=None, include_time=False):
1161
if self._response_start_time is None:
1162
self._response_start_time = osutils.timer_func()
1164
t = '%5.3fs ' % (time.clock() - self._response_start_time)
1167
if extra_bytes is None:
1170
extra = ' ' + repr(extra_bytes[:40])
1172
extra = extra[:29] + extra[-1] + '...'
1173
mutter('%12s: [%s] %s%s%s'
1174
% (action, self._thread_id, t, message, extra))
1118
1176
def send_error(self, exception):
1119
1177
if self.response_sent:
1144
1204
self._write_success_status()
1146
1206
self._write_error_status()
1207
if 'hpss' in debug.debug_flags:
1208
self._trace('response', repr(response.args))
1147
1209
self._write_structure(response.args)
1148
1210
if response.body is not None:
1149
1211
self._write_prefixed_body(response.body)
1212
if 'hpss' in debug.debug_flags:
1213
self._trace('body', '%d bytes' % (len(response.body),),
1214
response.body, include_time=True)
1150
1215
elif response.body_stream is not None:
1151
for chunk in response.body_stream:
1152
self._write_prefixed_body(chunk)
1216
count = num_bytes = 0
1218
for exc_info, chunk in _iter_with_errors(response.body_stream):
1220
if exc_info is not None:
1221
self._write_error_status()
1222
error_struct = request._translate_error(exc_info[1])
1223
self._write_structure(error_struct)
1226
if isinstance(chunk, request.FailedSmartServerResponse):
1227
self._write_error_status()
1228
self._write_structure(chunk.args)
1230
num_bytes += len(chunk)
1231
if first_chunk is None:
1233
self._write_prefixed_body(chunk)
1235
if 'hpssdetail' in debug.debug_flags:
1236
# Not worth timing separately, as _write_func is
1238
self._trace('body chunk',
1239
'%d bytes' % (len(chunk),),
1240
chunk, suppress_time=True)
1241
if 'hpss' in debug.debug_flags:
1242
self._trace('body stream',
1243
'%d bytes %d chunks' % (num_bytes, count),
1154
1245
self._write_end()
1246
if 'hpss' in debug.debug_flags:
1247
self._trace('response end', '', include_time=True)
1250
def _iter_with_errors(iterable):
1251
"""Handle errors from iterable.next().
1255
for exc_info, value in _iter_with_errors(iterable):
1258
This is a safer alternative to::
1261
for value in iterable:
1266
Because the latter will catch errors from the for-loop body, not just
1269
If an error occurs, exc_info will be a exc_info tuple, and the generator
1270
will terminate. Otherwise exc_info will be None, and value will be the
1271
value from iterable.next(). Note that KeyboardInterrupt and SystemExit
1272
will not be itercepted.
1274
iterator = iter(iterable)
1277
yield None, iterator.next()
1278
except StopIteration:
1280
except (KeyboardInterrupt, SystemExit):
1283
mutter('_iter_with_errors caught error')
1284
log_exception_quietly()
1285
yield sys.exc_info(), None
1157
1289
class ProtocolThreeRequester(_ProtocolThreeEncoder, Requester):
1164
1296
def set_headers(self, headers):
1165
1297
self._headers = headers.copy()
1167
1299
def call(self, *args):
1168
1300
if 'hpss' in debug.debug_flags:
1169
1301
mutter('hpss call: %s', repr(args)[1:-1])
1170
1302
base = getattr(self._medium_request._medium, 'base', None)
1171
1303
if base is not None:
1172
1304
mutter(' (to %s)', base)
1173
self._request_start_time = time.time()
1305
self._request_start_time = osutils.timer_func()
1174
1306
self._write_protocol_version()
1175
1307
self._write_headers(self._headers)
1176
1308
self._write_structure(args)
1200
1332
"""Make a remote call with a readv array.
1202
1334
The body is encoded with one line per readv offset pair. The numbers in
1203
each pair are separated by a comma, and no trailing \n is emitted.
1335
each pair are separated by a comma, and no trailing \\n is emitted.
1205
1337
if 'hpss' in debug.debug_flags:
1206
1338
mutter('hpss call w/readv: %s', repr(args)[1:-1])
1207
1339
path = getattr(self._medium_request._medium, '_path', None)
1208
1340
if path is not None:
1209
1341
mutter(' (to %s)', path)
1210
self._request_start_time = time.time()
1342
self._request_start_time = osutils.timer_func()
1211
1343
self._write_protocol_version()
1212
1344
self._write_headers(self._headers)
1213
1345
self._write_structure(args)
1218
1350
self._write_end()
1219
1351
self._medium_request.finished_writing()
1353
def call_with_body_stream(self, args, stream):
1354
if 'hpss' in debug.debug_flags:
1355
mutter('hpss call w/body stream: %r', args)
1356
path = getattr(self._medium_request._medium, '_path', None)
1357
if path is not None:
1358
mutter(' (to %s)', path)
1359
self._request_start_time = osutils.timer_func()
1360
self._write_protocol_version()
1361
self._write_headers(self._headers)
1362
self._write_structure(args)
1363
# TODO: notice if the server has sent an early error reply before we
1364
# have finished sending the stream. We would notice at the end
1365
# anyway, but if the medium can deliver it early then it's good
1366
# to short-circuit the whole request...
1367
for exc_info, part in _iter_with_errors(stream):
1368
if exc_info is not None:
1369
# Iterating the stream failed. Cleanly abort the request.
1370
self._write_error_status()
1371
# Currently the client unconditionally sends ('error',) as the
1373
self._write_structure(('error',))
1375
self._medium_request.finished_writing()
1376
raise exc_info[0], exc_info[1], exc_info[2]
1378
self._write_prefixed_body(part)
1381
self._medium_request.finished_writing()