13
13
# You should have received a copy of the GNU General Public License
14
14
# along with this program; if not, write to the Free Software
15
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
15
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17
17
"""Wire-level encoding and decoding of requests and responses for the smart
21
from __future__ import absolute_import
22
24
from cStringIO import StringIO
28
from bzrlib import debug
29
from bzrlib import errors
30
36
from bzrlib.smart import message, request
31
37
from bzrlib.trace import log_exception_quietly, mutter
32
from bzrlib.util.bencode import bdecode, bencode
38
from bzrlib.bencode import bdecode_as_tuple, bencode
35
41
# Protocol version strings. These are sent as prefixes of bzr requests and
58
64
def _encode_tuple(args):
59
65
"""Encode the tuple args to a bytestream."""
60
return '\x01'.join(args) + '\n'
66
joined = '\x01'.join(args) + '\n'
67
if type(joined) is unicode:
68
# XXX: We should fix things so this never happens! -AJB, 20100304
69
mutter('response args contain unicode, should be only bytes: %r',
71
joined = joined.encode('ascii')
63
75
class Requester(object):
109
121
for start, length in offsets:
110
122
txt.append('%d,%d' % (start, length))
111
123
return '\n'.join(txt)
114
126
class SmartServerRequestProtocolOne(SmartProtocolBase):
115
127
"""Server-side encoding and decoding logic for smart version 1."""
117
def __init__(self, backing_transport, write_func, root_client_path='/'):
129
def __init__(self, backing_transport, write_func, root_client_path='/',
118
131
self._backing_transport = backing_transport
119
132
self._root_client_path = root_client_path
133
self._jail_root = jail_root
120
134
self.unused_data = ''
121
135
self._finished = False
122
136
self.in_buffer = ''
144
158
req_args = _decode_tuple(first_line)
145
159
self.request = request.SmartServerRequestHandler(
146
160
self._backing_transport, commands=request.request_handlers,
147
root_client_path=self._root_client_path)
148
self.request.dispatch_command(req_args[0], req_args[1:])
161
root_client_path=self._root_client_path,
162
jail_root=self._jail_root)
163
self.request.args_received(req_args)
149
164
if self.request.finished_reading:
150
165
# trivial request
151
166
self.unused_data = self.in_buffer
516
530
class LengthPrefixedBodyDecoder(_StatefulDecoder):
517
531
"""Decodes the length-prefixed bulk data."""
519
533
def __init__(self):
520
534
_StatefulDecoder.__init__(self)
521
535
self.state_accept = self._state_accept_expecting_length
522
536
self.state_read = self._state_read_no_data
524
538
self._trailer_buffer = ''
526
540
def next_read_size(self):
527
541
if self.bytes_left is not None:
528
542
# Ideally we want to read all the remainder of the body and the
613
627
mutter('hpss call: %s', repr(args)[1:-1])
614
628
if getattr(self._request._medium, 'base', None) is not None:
615
629
mutter(' (to %s)', self._request._medium.base)
616
self._request_start_time = time.time()
630
self._request_start_time = osutils.timer_func()
617
631
self._write_args(args)
618
632
self._request.finished_writing()
619
633
self._last_verb = args[0]
628
642
if getattr(self._request._medium, '_path', None) is not None:
629
643
mutter(' (to %s)', self._request._medium._path)
630
644
mutter(' %d bytes', len(body))
631
self._request_start_time = time.time()
645
self._request_start_time = osutils.timer_func()
632
646
if 'hpssdetail' in debug.debug_flags:
633
647
mutter('hpss body content: %s', body)
634
648
self._write_args(args)
641
655
"""Make a remote call with a readv array.
643
657
The body is encoded with one line per readv offset pair. The numbers in
644
each pair are separated by a comma, and no trailing \n is emitted.
658
each pair are separated by a comma, and no trailing \\n is emitted.
646
660
if 'hpss' in debug.debug_flags:
647
661
mutter('hpss call w/readv: %s', repr(args)[1:-1])
648
662
if getattr(self._request._medium, '_path', None) is not None:
649
663
mutter(' (to %s)', self._request._medium._path)
650
self._request_start_time = time.time()
664
self._request_start_time = osutils.timer_func()
651
665
self._write_args(args)
652
666
readv_bytes = self._serialise_offsets(body)
653
667
bytes = self._encode_bulk_data(readv_bytes)
657
671
mutter(' %d bytes in readv request', len(readv_bytes))
658
672
self._last_verb = args[0]
674
def call_with_body_stream(self, args, stream):
675
# Protocols v1 and v2 don't support body streams. So it's safe to
676
# assume that a v1/v2 server doesn't support whatever method we're
677
# trying to call with a body stream.
678
self._request.finished_writing()
679
self._request.finished_reading()
680
raise errors.UnknownSmartMethod(args[0])
660
682
def cancel_read_body(self):
661
683
"""After expecting a body, a response code may indicate one otherwise.
735
757
# The response will have no body, so we've finished reading.
736
758
self._request.finished_reading()
737
759
raise errors.UnknownSmartMethod(self._last_verb)
739
761
def read_body_bytes(self, count=-1):
740
762
"""Read bytes from the body, decoding into a byte stream.
742
We read all bytes at once to ensure we've checked the trailer for
764
We read all bytes at once to ensure we've checked the trailer for
743
765
errors, and then feed the buffer back as read_body_bytes is called.
745
767
if self._body_buffer is not None:
853
875
def build_server_protocol_three(backing_transport, write_func,
876
root_client_path, jail_root=None):
855
877
request_handler = request.SmartServerRequestHandler(
856
878
backing_transport, commands=request.request_handlers,
857
root_client_path=root_client_path)
879
root_client_path=root_client_path, jail_root=jail_root)
858
880
responder = ProtocolThreeResponder(write_func)
859
881
message_handler = message.ConventionalRequestHandler(request_handler, responder)
860
882
return ProtocolThreeDecoder(message_handler)
890
912
# We do *not* set self.decoding_failed here. The message handler
891
913
# has raised an error, but the decoder is still able to parse bytes
892
914
# and determine when this message ends.
893
log_exception_quietly()
915
if not isinstance(exception.exc_value, errors.UnknownSmartMethod):
916
log_exception_quietly()
894
917
self.message_handler.protocol_error(exception.exc_value)
895
918
# The state machine is ready to continue decoding, but the
896
919
# exception has interrupted the loop that runs the state machine.
1051
1074
class _ProtocolThreeEncoder(object):
1053
1076
response_marker = request_marker = MESSAGE_VERSION_THREE
1077
BUFFER_SIZE = 1024*1024 # 1 MiB buffer before flushing
1055
1079
def __init__(self, write_func):
1057
1082
self._real_write_func = write_func
1059
1084
def _write_func(self, bytes):
1085
# TODO: Another possibility would be to turn this into an async model.
1086
# Where we let another thread know that we have some bytes if
1087
# they want it, but we don't actually block for it
1088
# Note that osutils.send_all always sends 64kB chunks anyway, so
1089
# we might just push out smaller bits at a time?
1090
self._buf.append(bytes)
1091
self._buf_len += len(bytes)
1092
if self._buf_len > self.BUFFER_SIZE:
1062
1095
def flush(self):
1064
self._real_write_func(self._buf)
1097
self._real_write_func(''.join(self._buf))
1067
1101
def _serialise_offsets(self, offsets):
1068
1102
"""Serialise a readv offset list."""
1114
1151
_ProtocolThreeEncoder.__init__(self, write_func)
1115
1152
self.response_sent = False
1116
1153
self._headers = {'Software version': bzrlib.__version__}
1154
if 'hpss' in debug.debug_flags:
1155
self._thread_id = thread.get_ident()
1156
self._response_start_time = None
1158
def _trace(self, action, message, extra_bytes=None, include_time=False):
1159
if self._response_start_time is None:
1160
self._response_start_time = osutils.timer_func()
1162
t = '%5.3fs ' % (time.clock() - self._response_start_time)
1165
if extra_bytes is None:
1168
extra = ' ' + repr(extra_bytes[:40])
1170
extra = extra[:29] + extra[-1] + '...'
1171
mutter('%12s: [%s] %s%s%s'
1172
% (action, self._thread_id, t, message, extra))
1118
1174
def send_error(self, exception):
1119
1175
if self.response_sent:
1144
1202
self._write_success_status()
1146
1204
self._write_error_status()
1205
if 'hpss' in debug.debug_flags:
1206
self._trace('response', repr(response.args))
1147
1207
self._write_structure(response.args)
1148
1208
if response.body is not None:
1149
1209
self._write_prefixed_body(response.body)
1210
if 'hpss' in debug.debug_flags:
1211
self._trace('body', '%d bytes' % (len(response.body),),
1212
response.body, include_time=True)
1150
1213
elif response.body_stream is not None:
1151
for chunk in response.body_stream:
1152
self._write_prefixed_body(chunk)
1214
count = num_bytes = 0
1216
for exc_info, chunk in _iter_with_errors(response.body_stream):
1218
if exc_info is not None:
1219
self._write_error_status()
1220
error_struct = request._translate_error(exc_info[1])
1221
self._write_structure(error_struct)
1224
if isinstance(chunk, request.FailedSmartServerResponse):
1225
self._write_error_status()
1226
self._write_structure(chunk.args)
1228
num_bytes += len(chunk)
1229
if first_chunk is None:
1231
self._write_prefixed_body(chunk)
1233
if 'hpssdetail' in debug.debug_flags:
1234
# Not worth timing separately, as _write_func is
1236
self._trace('body chunk',
1237
'%d bytes' % (len(chunk),),
1238
chunk, suppress_time=True)
1239
if 'hpss' in debug.debug_flags:
1240
self._trace('body stream',
1241
'%d bytes %d chunks' % (num_bytes, count),
1154
1243
self._write_end()
1244
if 'hpss' in debug.debug_flags:
1245
self._trace('response end', '', include_time=True)
1248
def _iter_with_errors(iterable):
1249
"""Handle errors from iterable.next().
1253
for exc_info, value in _iter_with_errors(iterable):
1256
This is a safer alternative to::
1259
for value in iterable:
1264
Because the latter will catch errors from the for-loop body, not just
1267
If an error occurs, exc_info will be a exc_info tuple, and the generator
1268
will terminate. Otherwise exc_info will be None, and value will be the
1269
value from iterable.next(). Note that KeyboardInterrupt and SystemExit
1270
will not be itercepted.
1272
iterator = iter(iterable)
1275
yield None, iterator.next()
1276
except StopIteration:
1278
except (KeyboardInterrupt, SystemExit):
1281
mutter('_iter_with_errors caught error')
1282
log_exception_quietly()
1283
yield sys.exc_info(), None
1157
1287
class ProtocolThreeRequester(_ProtocolThreeEncoder, Requester):
1160
1290
_ProtocolThreeEncoder.__init__(self, medium_request.accept_bytes)
1161
1291
self._medium_request = medium_request
1162
1292
self._headers = {}
1293
self.body_stream_started = None
1164
1295
def set_headers(self, headers):
1165
1296
self._headers = headers.copy()
1167
1298
def call(self, *args):
1168
1299
if 'hpss' in debug.debug_flags:
1169
1300
mutter('hpss call: %s', repr(args)[1:-1])
1170
1301
base = getattr(self._medium_request._medium, 'base', None)
1171
1302
if base is not None:
1172
1303
mutter(' (to %s)', base)
1173
self._request_start_time = time.time()
1304
self._request_start_time = osutils.timer_func()
1174
1305
self._write_protocol_version()
1175
1306
self._write_headers(self._headers)
1176
1307
self._write_structure(args)
1200
1331
"""Make a remote call with a readv array.
1202
1333
The body is encoded with one line per readv offset pair. The numbers in
1203
each pair are separated by a comma, and no trailing \n is emitted.
1334
each pair are separated by a comma, and no trailing \\n is emitted.
1205
1336
if 'hpss' in debug.debug_flags:
1206
1337
mutter('hpss call w/readv: %s', repr(args)[1:-1])
1207
1338
path = getattr(self._medium_request._medium, '_path', None)
1208
1339
if path is not None:
1209
1340
mutter(' (to %s)', path)
1210
self._request_start_time = time.time()
1341
self._request_start_time = osutils.timer_func()
1211
1342
self._write_protocol_version()
1212
1343
self._write_headers(self._headers)
1213
1344
self._write_structure(args)
1218
1349
self._write_end()
1219
1350
self._medium_request.finished_writing()
1352
def call_with_body_stream(self, args, stream):
1353
if 'hpss' in debug.debug_flags:
1354
mutter('hpss call w/body stream: %r', args)
1355
path = getattr(self._medium_request._medium, '_path', None)
1356
if path is not None:
1357
mutter(' (to %s)', path)
1358
self._request_start_time = osutils.timer_func()
1359
self.body_stream_started = False
1360
self._write_protocol_version()
1361
self._write_headers(self._headers)
1362
self._write_structure(args)
1363
# TODO: notice if the server has sent an early error reply before we
1364
# have finished sending the stream. We would notice at the end
1365
# anyway, but if the medium can deliver it early then it's good
1366
# to short-circuit the whole request...
1367
# Provoke any ConnectionReset failures before we start the body stream.
1369
self.body_stream_started = True
1370
for exc_info, part in _iter_with_errors(stream):
1371
if exc_info is not None:
1372
# Iterating the stream failed. Cleanly abort the request.
1373
self._write_error_status()
1374
# Currently the client unconditionally sends ('error',) as the
1376
self._write_structure(('error',))
1378
self._medium_request.finished_writing()
1379
raise exc_info[0], exc_info[1], exc_info[2]
1381
self._write_prefixed_body(part)
1384
self._medium_request.finished_writing()