13
13
# You should have received a copy of the GNU General Public License
14
14
# along with this program; if not, write to the Free Software
15
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
15
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
17
17
"""Wire-level encoding and decoding of requests and responses for the smart
21
from __future__ import absolute_import
24
22
from cStringIO import StringIO
28
from bzrlib import debug
29
from bzrlib import errors
36
30
from bzrlib.smart import message, request
37
31
from bzrlib.trace import log_exception_quietly, mutter
38
from bzrlib.bencode import bdecode_as_tuple, bencode
32
from bzrlib.util.bencode import bdecode, bencode
41
35
# Protocol version strings. These are sent as prefixes of bzr requests and
64
58
def _encode_tuple(args):
65
59
"""Encode the tuple args to a bytestream."""
66
joined = '\x01'.join(args) + '\n'
67
if type(joined) is unicode:
68
# XXX: We should fix things so this never happens! -AJB, 20100304
69
mutter('response args contain unicode, should be only bytes: %r',
71
joined = joined.encode('ascii')
60
return '\x01'.join(args) + '\n'
75
63
class Requester(object):
121
109
for start, length in offsets:
122
110
txt.append('%d,%d' % (start, length))
123
111
return '\n'.join(txt)
126
114
class SmartServerRequestProtocolOne(SmartProtocolBase):
127
115
"""Server-side encoding and decoding logic for smart version 1."""
129
def __init__(self, backing_transport, write_func, root_client_path='/',
117
def __init__(self, backing_transport, write_func, root_client_path='/'):
131
118
self._backing_transport = backing_transport
132
119
self._root_client_path = root_client_path
133
self._jail_root = jail_root
134
120
self.unused_data = ''
135
121
self._finished = False
136
122
self.in_buffer = ''
158
144
req_args = _decode_tuple(first_line)
159
145
self.request = request.SmartServerRequestHandler(
160
146
self._backing_transport, commands=request.request_handlers,
161
root_client_path=self._root_client_path,
162
jail_root=self._jail_root)
163
self.request.args_received(req_args)
147
root_client_path=self._root_client_path)
148
self.request.dispatch_command(req_args[0], req_args[1:])
164
149
if self.request.finished_reading:
165
150
# trivial request
166
151
self.unused_data = self.in_buffer
530
516
class LengthPrefixedBodyDecoder(_StatefulDecoder):
531
517
"""Decodes the length-prefixed bulk data."""
533
519
def __init__(self):
534
520
_StatefulDecoder.__init__(self)
535
521
self.state_accept = self._state_accept_expecting_length
536
522
self.state_read = self._state_read_no_data
538
524
self._trailer_buffer = ''
540
526
def next_read_size(self):
541
527
if self.bytes_left is not None:
542
528
# Ideally we want to read all the remainder of the body and the
627
613
mutter('hpss call: %s', repr(args)[1:-1])
628
614
if getattr(self._request._medium, 'base', None) is not None:
629
615
mutter(' (to %s)', self._request._medium.base)
630
self._request_start_time = osutils.timer_func()
616
self._request_start_time = time.time()
631
617
self._write_args(args)
632
618
self._request.finished_writing()
633
619
self._last_verb = args[0]
642
628
if getattr(self._request._medium, '_path', None) is not None:
643
629
mutter(' (to %s)', self._request._medium._path)
644
630
mutter(' %d bytes', len(body))
645
self._request_start_time = osutils.timer_func()
631
self._request_start_time = time.time()
646
632
if 'hpssdetail' in debug.debug_flags:
647
633
mutter('hpss body content: %s', body)
648
634
self._write_args(args)
655
641
"""Make a remote call with a readv array.
657
643
The body is encoded with one line per readv offset pair. The numbers in
658
each pair are separated by a comma, and no trailing \\n is emitted.
644
each pair are separated by a comma, and no trailing \n is emitted.
660
646
if 'hpss' in debug.debug_flags:
661
647
mutter('hpss call w/readv: %s', repr(args)[1:-1])
662
648
if getattr(self._request._medium, '_path', None) is not None:
663
649
mutter(' (to %s)', self._request._medium._path)
664
self._request_start_time = osutils.timer_func()
650
self._request_start_time = time.time()
665
651
self._write_args(args)
666
652
readv_bytes = self._serialise_offsets(body)
667
653
bytes = self._encode_bulk_data(readv_bytes)
671
657
mutter(' %d bytes in readv request', len(readv_bytes))
672
658
self._last_verb = args[0]
674
def call_with_body_stream(self, args, stream):
675
# Protocols v1 and v2 don't support body streams. So it's safe to
676
# assume that a v1/v2 server doesn't support whatever method we're
677
# trying to call with a body stream.
678
self._request.finished_writing()
679
self._request.finished_reading()
680
raise errors.UnknownSmartMethod(args[0])
682
660
def cancel_read_body(self):
683
661
"""After expecting a body, a response code may indicate one otherwise.
757
735
# The response will have no body, so we've finished reading.
758
736
self._request.finished_reading()
759
737
raise errors.UnknownSmartMethod(self._last_verb)
761
739
def read_body_bytes(self, count=-1):
762
740
"""Read bytes from the body, decoding into a byte stream.
764
We read all bytes at once to ensure we've checked the trailer for
742
We read all bytes at once to ensure we've checked the trailer for
765
743
errors, and then feed the buffer back as read_body_bytes is called.
767
745
if self._body_buffer is not None:
875
853
def build_server_protocol_three(backing_transport, write_func,
876
root_client_path, jail_root=None):
877
855
request_handler = request.SmartServerRequestHandler(
878
856
backing_transport, commands=request.request_handlers,
879
root_client_path=root_client_path, jail_root=jail_root)
857
root_client_path=root_client_path)
880
858
responder = ProtocolThreeResponder(write_func)
881
859
message_handler = message.ConventionalRequestHandler(request_handler, responder)
882
860
return ProtocolThreeDecoder(message_handler)
912
890
# We do *not* set self.decoding_failed here. The message handler
913
891
# has raised an error, but the decoder is still able to parse bytes
914
892
# and determine when this message ends.
915
if not isinstance(exception.exc_value, errors.UnknownSmartMethod):
916
log_exception_quietly()
893
log_exception_quietly()
917
894
self.message_handler.protocol_error(exception.exc_value)
918
895
# The state machine is ready to continue decoding, but the
919
896
# exception has interrupted the loop that runs the state machine.
1074
1051
class _ProtocolThreeEncoder(object):
1076
1053
response_marker = request_marker = MESSAGE_VERSION_THREE
1077
BUFFER_SIZE = 1024*1024 # 1 MiB buffer before flushing
1079
1055
def __init__(self, write_func):
1082
1057
self._real_write_func = write_func
1084
1059
def _write_func(self, bytes):
1085
# TODO: Another possibility would be to turn this into an async model.
1086
# Where we let another thread know that we have some bytes if
1087
# they want it, but we don't actually block for it
1088
# Note that osutils.send_all always sends 64kB chunks anyway, so
1089
# we might just push out smaller bits at a time?
1090
self._buf.append(bytes)
1091
self._buf_len += len(bytes)
1092
if self._buf_len > self.BUFFER_SIZE:
1095
1062
def flush(self):
1097
self._real_write_func(''.join(self._buf))
1064
self._real_write_func(self._buf)
1101
1067
def _serialise_offsets(self, offsets):
1102
1068
"""Serialise a readv offset list."""
1151
1114
_ProtocolThreeEncoder.__init__(self, write_func)
1152
1115
self.response_sent = False
1153
1116
self._headers = {'Software version': bzrlib.__version__}
1154
if 'hpss' in debug.debug_flags:
1155
self._thread_id = thread.get_ident()
1156
self._response_start_time = None
1158
def _trace(self, action, message, extra_bytes=None, include_time=False):
1159
if self._response_start_time is None:
1160
self._response_start_time = osutils.timer_func()
1162
t = '%5.3fs ' % (time.clock() - self._response_start_time)
1165
if extra_bytes is None:
1168
extra = ' ' + repr(extra_bytes[:40])
1170
extra = extra[:29] + extra[-1] + '...'
1171
mutter('%12s: [%s] %s%s%s'
1172
% (action, self._thread_id, t, message, extra))
1174
1118
def send_error(self, exception):
1175
1119
if self.response_sent:
1202
1144
self._write_success_status()
1204
1146
self._write_error_status()
1205
if 'hpss' in debug.debug_flags:
1206
self._trace('response', repr(response.args))
1207
1147
self._write_structure(response.args)
1208
1148
if response.body is not None:
1209
1149
self._write_prefixed_body(response.body)
1210
if 'hpss' in debug.debug_flags:
1211
self._trace('body', '%d bytes' % (len(response.body),),
1212
response.body, include_time=True)
1213
1150
elif response.body_stream is not None:
1214
count = num_bytes = 0
1216
for exc_info, chunk in _iter_with_errors(response.body_stream):
1218
if exc_info is not None:
1219
self._write_error_status()
1220
error_struct = request._translate_error(exc_info[1])
1221
self._write_structure(error_struct)
1224
if isinstance(chunk, request.FailedSmartServerResponse):
1225
self._write_error_status()
1226
self._write_structure(chunk.args)
1228
num_bytes += len(chunk)
1229
if first_chunk is None:
1231
self._write_prefixed_body(chunk)
1233
if 'hpssdetail' in debug.debug_flags:
1234
# Not worth timing separately, as _write_func is
1236
self._trace('body chunk',
1237
'%d bytes' % (len(chunk),),
1238
chunk, suppress_time=True)
1239
if 'hpss' in debug.debug_flags:
1240
self._trace('body stream',
1241
'%d bytes %d chunks' % (num_bytes, count),
1151
for chunk in response.body_stream:
1152
self._write_prefixed_body(chunk)
1243
1154
self._write_end()
1244
if 'hpss' in debug.debug_flags:
1245
self._trace('response end', '', include_time=True)
1248
def _iter_with_errors(iterable):
1249
"""Handle errors from iterable.next().
1253
for exc_info, value in _iter_with_errors(iterable):
1256
This is a safer alternative to::
1259
for value in iterable:
1264
Because the latter will catch errors from the for-loop body, not just
1267
If an error occurs, exc_info will be a exc_info tuple, and the generator
1268
will terminate. Otherwise exc_info will be None, and value will be the
1269
value from iterable.next(). Note that KeyboardInterrupt and SystemExit
1270
will not be itercepted.
1272
iterator = iter(iterable)
1275
yield None, iterator.next()
1276
except StopIteration:
1278
except (KeyboardInterrupt, SystemExit):
1281
mutter('_iter_with_errors caught error')
1282
log_exception_quietly()
1283
yield sys.exc_info(), None
1287
1157
class ProtocolThreeRequester(_ProtocolThreeEncoder, Requester):
1290
1160
_ProtocolThreeEncoder.__init__(self, medium_request.accept_bytes)
1291
1161
self._medium_request = medium_request
1292
1162
self._headers = {}
1293
self.body_stream_started = None
1295
1164
def set_headers(self, headers):
1296
1165
self._headers = headers.copy()
1298
1167
def call(self, *args):
1299
1168
if 'hpss' in debug.debug_flags:
1300
1169
mutter('hpss call: %s', repr(args)[1:-1])
1301
1170
base = getattr(self._medium_request._medium, 'base', None)
1302
1171
if base is not None:
1303
1172
mutter(' (to %s)', base)
1304
self._request_start_time = osutils.timer_func()
1173
self._request_start_time = time.time()
1305
1174
self._write_protocol_version()
1306
1175
self._write_headers(self._headers)
1307
1176
self._write_structure(args)
1331
1200
"""Make a remote call with a readv array.
1333
1202
The body is encoded with one line per readv offset pair. The numbers in
1334
each pair are separated by a comma, and no trailing \\n is emitted.
1203
each pair are separated by a comma, and no trailing \n is emitted.
1336
1205
if 'hpss' in debug.debug_flags:
1337
1206
mutter('hpss call w/readv: %s', repr(args)[1:-1])
1338
1207
path = getattr(self._medium_request._medium, '_path', None)
1339
1208
if path is not None:
1340
1209
mutter(' (to %s)', path)
1341
self._request_start_time = osutils.timer_func()
1210
self._request_start_time = time.time()
1342
1211
self._write_protocol_version()
1343
1212
self._write_headers(self._headers)
1344
1213
self._write_structure(args)
1349
1218
self._write_end()
1350
1219
self._medium_request.finished_writing()
1352
def call_with_body_stream(self, args, stream):
1353
if 'hpss' in debug.debug_flags:
1354
mutter('hpss call w/body stream: %r', args)
1355
path = getattr(self._medium_request._medium, '_path', None)
1356
if path is not None:
1357
mutter(' (to %s)', path)
1358
self._request_start_time = osutils.timer_func()
1359
self.body_stream_started = False
1360
self._write_protocol_version()
1361
self._write_headers(self._headers)
1362
self._write_structure(args)
1363
# TODO: notice if the server has sent an early error reply before we
1364
# have finished sending the stream. We would notice at the end
1365
# anyway, but if the medium can deliver it early then it's good
1366
# to short-circuit the whole request...
1367
# Provoke any ConnectionReset failures before we start the body stream.
1369
self.body_stream_started = True
1370
for exc_info, part in _iter_with_errors(stream):
1371
if exc_info is not None:
1372
# Iterating the stream failed. Cleanly abort the request.
1373
self._write_error_status()
1374
# Currently the client unconditionally sends ('error',) as the
1376
self._write_structure(('error',))
1378
self._medium_request.finished_writing()
1379
raise exc_info[0], exc_info[1], exc_info[2]
1381
self._write_prefixed_body(part)
1384
self._medium_request.finished_writing()