13
13
# You should have received a copy of the GNU General Public License
14
14
# along with this program; if not, write to the Free Software
15
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
15
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17
17
"""Wire-level encoding and decoding of requests and responses for the smart
22
22
from cStringIO import StringIO
28
from bzrlib import debug
29
from bzrlib import errors
30
35
from bzrlib.smart import message, request
31
36
from bzrlib.trace import log_exception_quietly, mutter
32
from bzrlib.util.bencode import bdecode_as_tuple, bencode
37
from bzrlib.bencode import bdecode_as_tuple, bencode
35
40
# Protocol version strings. These are sent as prefixes of bzr requests and
58
63
def _encode_tuple(args):
59
64
"""Encode the tuple args to a bytestream."""
60
return '\x01'.join(args) + '\n'
65
joined = '\x01'.join(args) + '\n'
66
if type(joined) is unicode:
67
# XXX: We should fix things so this never happens! -AJB, 20100304
68
mutter('response args contain unicode, should be only bytes: %r',
70
joined = joined.encode('ascii')
63
74
class Requester(object):
109
120
for start, length in offsets:
110
121
txt.append('%d,%d' % (start, length))
111
122
return '\n'.join(txt)
114
125
class SmartServerRequestProtocolOne(SmartProtocolBase):
115
126
"""Server-side encoding and decoding logic for smart version 1."""
117
def __init__(self, backing_transport, write_func, root_client_path='/'):
128
def __init__(self, backing_transport, write_func, root_client_path='/',
118
130
self._backing_transport = backing_transport
119
131
self._root_client_path = root_client_path
132
self._jail_root = jail_root
120
133
self.unused_data = ''
121
134
self._finished = False
122
135
self.in_buffer = ''
144
157
req_args = _decode_tuple(first_line)
145
158
self.request = request.SmartServerRequestHandler(
146
159
self._backing_transport, commands=request.request_handlers,
147
root_client_path=self._root_client_path)
148
self.request.dispatch_command(req_args[0], req_args[1:])
160
root_client_path=self._root_client_path,
161
jail_root=self._jail_root)
162
self.request.args_received(req_args)
149
163
if self.request.finished_reading:
150
164
# trivial request
151
165
self.unused_data = self.in_buffer
515
529
class LengthPrefixedBodyDecoder(_StatefulDecoder):
516
530
"""Decodes the length-prefixed bulk data."""
518
532
def __init__(self):
519
533
_StatefulDecoder.__init__(self)
520
534
self.state_accept = self._state_accept_expecting_length
521
535
self.state_read = self._state_read_no_data
523
537
self._trailer_buffer = ''
525
539
def next_read_size(self):
526
540
if self.bytes_left is not None:
527
541
# Ideally we want to read all the remainder of the body and the
612
626
mutter('hpss call: %s', repr(args)[1:-1])
613
627
if getattr(self._request._medium, 'base', None) is not None:
614
628
mutter(' (to %s)', self._request._medium.base)
615
self._request_start_time = time.time()
629
self._request_start_time = osutils.timer_func()
616
630
self._write_args(args)
617
631
self._request.finished_writing()
618
632
self._last_verb = args[0]
627
641
if getattr(self._request._medium, '_path', None) is not None:
628
642
mutter(' (to %s)', self._request._medium._path)
629
643
mutter(' %d bytes', len(body))
630
self._request_start_time = time.time()
644
self._request_start_time = osutils.timer_func()
631
645
if 'hpssdetail' in debug.debug_flags:
632
646
mutter('hpss body content: %s', body)
633
647
self._write_args(args)
646
660
mutter('hpss call w/readv: %s', repr(args)[1:-1])
647
661
if getattr(self._request._medium, '_path', None) is not None:
648
662
mutter(' (to %s)', self._request._medium._path)
649
self._request_start_time = time.time()
663
self._request_start_time = osutils.timer_func()
650
664
self._write_args(args)
651
665
readv_bytes = self._serialise_offsets(body)
652
666
bytes = self._encode_bulk_data(readv_bytes)
742
756
# The response will have no body, so we've finished reading.
743
757
self._request.finished_reading()
744
758
raise errors.UnknownSmartMethod(self._last_verb)
746
760
def read_body_bytes(self, count=-1):
747
761
"""Read bytes from the body, decoding into a byte stream.
749
We read all bytes at once to ensure we've checked the trailer for
763
We read all bytes at once to ensure we've checked the trailer for
750
764
errors, and then feed the buffer back as read_body_bytes is called.
752
766
if self._body_buffer is not None:
860
874
def build_server_protocol_three(backing_transport, write_func,
875
root_client_path, jail_root=None):
862
876
request_handler = request.SmartServerRequestHandler(
863
877
backing_transport, commands=request.request_handlers,
864
root_client_path=root_client_path)
878
root_client_path=root_client_path, jail_root=jail_root)
865
879
responder = ProtocolThreeResponder(write_func)
866
880
message_handler = message.ConventionalRequestHandler(request_handler, responder)
867
881
return ProtocolThreeDecoder(message_handler)
897
911
# We do *not* set self.decoding_failed here. The message handler
898
912
# has raised an error, but the decoder is still able to parse bytes
899
913
# and determine when this message ends.
900
log_exception_quietly()
914
if not isinstance(exception.exc_value, errors.UnknownSmartMethod):
915
log_exception_quietly()
901
916
self.message_handler.protocol_error(exception.exc_value)
902
917
# The state machine is ready to continue decoding, but the
903
918
# exception has interrupted the loop that runs the state machine.
1058
1073
class _ProtocolThreeEncoder(object):
1060
1075
response_marker = request_marker = MESSAGE_VERSION_THREE
1076
BUFFER_SIZE = 1024*1024 # 1 MiB buffer before flushing
1062
1078
def __init__(self, write_func):
1064
1081
self._real_write_func = write_func
1066
1083
def _write_func(self, bytes):
1084
# TODO: It is probably more appropriate to use sum(map(len, _buf))
1085
# for total number of bytes to write, rather than buffer based on
1086
# the number of write() calls
1087
# TODO: Another possibility would be to turn this into an async model.
1088
# Where we let another thread know that we have some bytes if
1089
# they want it, but we don't actually block for it
1090
# Note that osutils.send_all always sends 64kB chunks anyway, so
1091
# we might just push out smaller bits at a time?
1092
self._buf.append(bytes)
1093
self._buf_len += len(bytes)
1094
if self._buf_len > self.BUFFER_SIZE:
1069
1097
def flush(self):
1071
self._real_write_func(self._buf)
1099
self._real_write_func(''.join(self._buf))
1074
1103
def _serialise_offsets(self, offsets):
1075
1104
"""Serialise a readv offset list."""
1124
1153
_ProtocolThreeEncoder.__init__(self, write_func)
1125
1154
self.response_sent = False
1126
1155
self._headers = {'Software version': bzrlib.__version__}
1156
if 'hpss' in debug.debug_flags:
1157
self._thread_id = thread.get_ident()
1158
self._response_start_time = None
1160
def _trace(self, action, message, extra_bytes=None, include_time=False):
1161
if self._response_start_time is None:
1162
self._response_start_time = osutils.timer_func()
1164
t = '%5.3fs ' % (time.clock() - self._response_start_time)
1167
if extra_bytes is None:
1170
extra = ' ' + repr(extra_bytes[:40])
1172
extra = extra[:29] + extra[-1] + '...'
1173
mutter('%12s: [%s] %s%s%s'
1174
% (action, self._thread_id, t, message, extra))
1128
1176
def send_error(self, exception):
1129
1177
if self.response_sent:
1154
1204
self._write_success_status()
1156
1206
self._write_error_status()
1207
if 'hpss' in debug.debug_flags:
1208
self._trace('response', repr(response.args))
1157
1209
self._write_structure(response.args)
1158
1210
if response.body is not None:
1159
1211
self._write_prefixed_body(response.body)
1212
if 'hpss' in debug.debug_flags:
1213
self._trace('body', '%d bytes' % (len(response.body),),
1214
response.body, include_time=True)
1160
1215
elif response.body_stream is not None:
1161
for chunk in response.body_stream:
1162
self._write_prefixed_body(chunk)
1216
count = num_bytes = 0
1218
for exc_info, chunk in _iter_with_errors(response.body_stream):
1220
if exc_info is not None:
1221
self._write_error_status()
1222
error_struct = request._translate_error(exc_info[1])
1223
self._write_structure(error_struct)
1226
if isinstance(chunk, request.FailedSmartServerResponse):
1227
self._write_error_status()
1228
self._write_structure(chunk.args)
1230
num_bytes += len(chunk)
1231
if first_chunk is None:
1233
self._write_prefixed_body(chunk)
1234
if 'hpssdetail' in debug.debug_flags:
1235
# Not worth timing separately, as _write_func is
1237
self._trace('body chunk',
1238
'%d bytes' % (len(chunk),),
1239
chunk, suppress_time=True)
1240
if 'hpss' in debug.debug_flags:
1241
self._trace('body stream',
1242
'%d bytes %d chunks' % (num_bytes, count),
1164
1244
self._write_end()
1245
if 'hpss' in debug.debug_flags:
1246
self._trace('response end', '', include_time=True)
1249
def _iter_with_errors(iterable):
1250
"""Handle errors from iterable.next().
1254
for exc_info, value in _iter_with_errors(iterable):
1257
This is a safer alternative to::
1260
for value in iterable:
1265
Because the latter will catch errors from the for-loop body, not just
1268
If an error occurs, exc_info will be a exc_info tuple, and the generator
1269
will terminate. Otherwise exc_info will be None, and value will be the
1270
value from iterable.next(). Note that KeyboardInterrupt and SystemExit
1271
will not be itercepted.
1273
iterator = iter(iterable)
1276
yield None, iterator.next()
1277
except StopIteration:
1279
except (KeyboardInterrupt, SystemExit):
1282
mutter('_iter_with_errors caught error')
1283
log_exception_quietly()
1284
yield sys.exc_info(), None
1167
1288
class ProtocolThreeRequester(_ProtocolThreeEncoder, Requester):
1174
1295
def set_headers(self, headers):
1175
1296
self._headers = headers.copy()
1177
1298
def call(self, *args):
1178
1299
if 'hpss' in debug.debug_flags:
1179
1300
mutter('hpss call: %s', repr(args)[1:-1])
1180
1301
base = getattr(self._medium_request._medium, 'base', None)
1181
1302
if base is not None:
1182
1303
mutter(' (to %s)', base)
1183
self._request_start_time = time.time()
1304
self._request_start_time = osutils.timer_func()
1184
1305
self._write_protocol_version()
1185
1306
self._write_headers(self._headers)
1186
1307
self._write_structure(args)
1242
1363
# have finished sending the stream. We would notice at the end
1243
1364
# anyway, but if the medium can deliver it early then it's good
1244
1365
# to short-circuit the whole request...
1366
for exc_info, part in _iter_with_errors(stream):
1367
if exc_info is not None:
1368
# Iterating the stream failed. Cleanly abort the request.
1369
self._write_error_status()
1370
# Currently the client unconditionally sends ('error',) as the
1372
self._write_structure(('error',))
1374
self._medium_request.finished_writing()
1375
raise exc_info[0], exc_info[1], exc_info[2]
1247
1377
self._write_prefixed_body(part)
1250
# Iterating the stream failed. Cleanly abort the request.
1251
self._write_error_status()
1252
# Currently the client unconditionally sends ('error',) as the
1254
self._write_structure(('error',))
1255
1379
self._write_end()
1256
1380
self._medium_request.finished_writing()