13
13
# You should have received a copy of the GNU General Public License
14
14
# along with this program; if not, write to the Free Software
15
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
15
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17
17
"""Wire-level encoding and decoding of requests and responses for the smart
22
22
from cStringIO import StringIO
28
from bzrlib import debug
29
from bzrlib import errors
30
35
from bzrlib.smart import message, request
31
36
from bzrlib.trace import log_exception_quietly, mutter
32
from bzrlib.util.bencode import bdecode_as_tuple, bencode
37
from bzrlib.bencode import bdecode_as_tuple, bencode
35
40
# Protocol version strings. These are sent as prefixes of bzr requests and
109
114
for start, length in offsets:
110
115
txt.append('%d,%d' % (start, length))
111
116
return '\n'.join(txt)
114
119
class SmartServerRequestProtocolOne(SmartProtocolBase):
115
120
"""Server-side encoding and decoding logic for smart version 1."""
117
def __init__(self, backing_transport, write_func, root_client_path='/'):
122
def __init__(self, backing_transport, write_func, root_client_path='/',
118
124
self._backing_transport = backing_transport
119
125
self._root_client_path = root_client_path
126
self._jail_root = jail_root
120
127
self.unused_data = ''
121
128
self._finished = False
122
129
self.in_buffer = ''
144
151
req_args = _decode_tuple(first_line)
145
152
self.request = request.SmartServerRequestHandler(
146
153
self._backing_transport, commands=request.request_handlers,
147
root_client_path=self._root_client_path)
148
self.request.dispatch_command(req_args[0], req_args[1:])
154
root_client_path=self._root_client_path,
155
jail_root=self._jail_root)
156
self.request.args_received(req_args)
149
157
if self.request.finished_reading:
150
158
# trivial request
151
159
self.unused_data = self.in_buffer
515
523
class LengthPrefixedBodyDecoder(_StatefulDecoder):
516
524
"""Decodes the length-prefixed bulk data."""
518
526
def __init__(self):
519
527
_StatefulDecoder.__init__(self)
520
528
self.state_accept = self._state_accept_expecting_length
521
529
self.state_read = self._state_read_no_data
523
531
self._trailer_buffer = ''
525
533
def next_read_size(self):
526
534
if self.bytes_left is not None:
527
535
# Ideally we want to read all the remainder of the body and the
612
620
mutter('hpss call: %s', repr(args)[1:-1])
613
621
if getattr(self._request._medium, 'base', None) is not None:
614
622
mutter(' (to %s)', self._request._medium.base)
615
self._request_start_time = time.time()
623
self._request_start_time = osutils.timer_func()
616
624
self._write_args(args)
617
625
self._request.finished_writing()
618
626
self._last_verb = args[0]
627
635
if getattr(self._request._medium, '_path', None) is not None:
628
636
mutter(' (to %s)', self._request._medium._path)
629
637
mutter(' %d bytes', len(body))
630
self._request_start_time = time.time()
638
self._request_start_time = osutils.timer_func()
631
639
if 'hpssdetail' in debug.debug_flags:
632
640
mutter('hpss body content: %s', body)
633
641
self._write_args(args)
646
654
mutter('hpss call w/readv: %s', repr(args)[1:-1])
647
655
if getattr(self._request._medium, '_path', None) is not None:
648
656
mutter(' (to %s)', self._request._medium._path)
649
self._request_start_time = time.time()
657
self._request_start_time = osutils.timer_func()
650
658
self._write_args(args)
651
659
readv_bytes = self._serialise_offsets(body)
652
660
bytes = self._encode_bulk_data(readv_bytes)
742
750
# The response will have no body, so we've finished reading.
743
751
self._request.finished_reading()
744
752
raise errors.UnknownSmartMethod(self._last_verb)
746
754
def read_body_bytes(self, count=-1):
747
755
"""Read bytes from the body, decoding into a byte stream.
749
We read all bytes at once to ensure we've checked the trailer for
757
We read all bytes at once to ensure we've checked the trailer for
750
758
errors, and then feed the buffer back as read_body_bytes is called.
752
760
if self._body_buffer is not None:
860
868
def build_server_protocol_three(backing_transport, write_func,
869
root_client_path, jail_root=None):
862
870
request_handler = request.SmartServerRequestHandler(
863
871
backing_transport, commands=request.request_handlers,
864
root_client_path=root_client_path)
872
root_client_path=root_client_path, jail_root=jail_root)
865
873
responder = ProtocolThreeResponder(write_func)
866
874
message_handler = message.ConventionalRequestHandler(request_handler, responder)
867
875
return ProtocolThreeDecoder(message_handler)
897
905
# We do *not* set self.decoding_failed here. The message handler
898
906
# has raised an error, but the decoder is still able to parse bytes
899
907
# and determine when this message ends.
900
log_exception_quietly()
908
if not isinstance(exception.exc_value, errors.UnknownSmartMethod):
909
log_exception_quietly()
901
910
self.message_handler.protocol_error(exception.exc_value)
902
911
# The state machine is ready to continue decoding, but the
903
912
# exception has interrupted the loop that runs the state machine.
1058
1067
class _ProtocolThreeEncoder(object):
1060
1069
response_marker = request_marker = MESSAGE_VERSION_THREE
1070
BUFFER_SIZE = 1024*1024 # 1 MiB buffer before flushing
1062
1072
def __init__(self, write_func):
1064
1075
self._real_write_func = write_func
1066
1077
def _write_func(self, bytes):
1078
# TODO: It is probably more appropriate to use sum(map(len, _buf))
1079
# for total number of bytes to write, rather than buffer based on
1080
# the number of write() calls
1081
# TODO: Another possibility would be to turn this into an async model.
1082
# Where we let another thread know that we have some bytes if
1083
# they want it, but we don't actually block for it
1084
# Note that osutils.send_all always sends 64kB chunks anyway, so
1085
# we might just push out smaller bits at a time?
1086
self._buf.append(bytes)
1087
self._buf_len += len(bytes)
1088
if self._buf_len > self.BUFFER_SIZE:
1069
1091
def flush(self):
1071
self._real_write_func(self._buf)
1093
self._real_write_func(''.join(self._buf))
1074
1097
def _serialise_offsets(self, offsets):
1075
1098
"""Serialise a readv offset list."""
1124
1147
_ProtocolThreeEncoder.__init__(self, write_func)
1125
1148
self.response_sent = False
1126
1149
self._headers = {'Software version': bzrlib.__version__}
1150
if 'hpss' in debug.debug_flags:
1151
self._thread_id = thread.get_ident()
1152
self._response_start_time = None
1154
def _trace(self, action, message, extra_bytes=None, include_time=False):
1155
if self._response_start_time is None:
1156
self._response_start_time = osutils.timer_func()
1158
t = '%5.3fs ' % (time.clock() - self._response_start_time)
1161
if extra_bytes is None:
1164
extra = ' ' + repr(extra_bytes[:40])
1166
extra = extra[:29] + extra[-1] + '...'
1167
mutter('%12s: [%s] %s%s%s'
1168
% (action, self._thread_id, t, message, extra))
1128
1170
def send_error(self, exception):
1129
1171
if self.response_sent:
1154
1198
self._write_success_status()
1156
1200
self._write_error_status()
1201
if 'hpss' in debug.debug_flags:
1202
self._trace('response', repr(response.args))
1157
1203
self._write_structure(response.args)
1158
1204
if response.body is not None:
1159
1205
self._write_prefixed_body(response.body)
1206
if 'hpss' in debug.debug_flags:
1207
self._trace('body', '%d bytes' % (len(response.body),),
1208
response.body, include_time=True)
1160
1209
elif response.body_stream is not None:
1161
for chunk in response.body_stream:
1162
self._write_prefixed_body(chunk)
1210
count = num_bytes = 0
1212
for exc_info, chunk in _iter_with_errors(response.body_stream):
1214
if exc_info is not None:
1215
self._write_error_status()
1216
error_struct = request._translate_error(exc_info[1])
1217
self._write_structure(error_struct)
1220
if isinstance(chunk, request.FailedSmartServerResponse):
1221
self._write_error_status()
1222
self._write_structure(chunk.args)
1224
num_bytes += len(chunk)
1225
if first_chunk is None:
1227
self._write_prefixed_body(chunk)
1228
if 'hpssdetail' in debug.debug_flags:
1229
# Not worth timing separately, as _write_func is
1231
self._trace('body chunk',
1232
'%d bytes' % (len(chunk),),
1233
chunk, suppress_time=True)
1234
if 'hpss' in debug.debug_flags:
1235
self._trace('body stream',
1236
'%d bytes %d chunks' % (num_bytes, count),
1164
1238
self._write_end()
1239
if 'hpss' in debug.debug_flags:
1240
self._trace('response end', '', include_time=True)
1243
def _iter_with_errors(iterable):
1244
"""Handle errors from iterable.next().
1248
for exc_info, value in _iter_with_errors(iterable):
1251
This is a safer alternative to::
1254
for value in iterable:
1259
Because the latter will catch errors from the for-loop body, not just
1262
If an error occurs, exc_info will be a exc_info tuple, and the generator
1263
will terminate. Otherwise exc_info will be None, and value will be the
1264
value from iterable.next(). Note that KeyboardInterrupt and SystemExit
1265
will not be itercepted.
1267
iterator = iter(iterable)
1270
yield None, iterator.next()
1271
except StopIteration:
1273
except (KeyboardInterrupt, SystemExit):
1276
mutter('_iter_with_errors caught error')
1277
log_exception_quietly()
1278
yield sys.exc_info(), None
1167
1282
class ProtocolThreeRequester(_ProtocolThreeEncoder, Requester):
1174
1289
def set_headers(self, headers):
1175
1290
self._headers = headers.copy()
1177
1292
def call(self, *args):
1178
1293
if 'hpss' in debug.debug_flags:
1179
1294
mutter('hpss call: %s', repr(args)[1:-1])
1180
1295
base = getattr(self._medium_request._medium, 'base', None)
1181
1296
if base is not None:
1182
1297
mutter(' (to %s)', base)
1183
self._request_start_time = time.time()
1298
self._request_start_time = osutils.timer_func()
1184
1299
self._write_protocol_version()
1185
1300
self._write_headers(self._headers)
1186
1301
self._write_structure(args)
1242
1357
# have finished sending the stream. We would notice at the end
1243
1358
# anyway, but if the medium can deliver it early then it's good
1244
1359
# to short-circuit the whole request...
1360
for exc_info, part in _iter_with_errors(stream):
1361
if exc_info is not None:
1362
# Iterating the stream failed. Cleanly abort the request.
1363
self._write_error_status()
1364
# Currently the client unconditionally sends ('error',) as the
1366
self._write_structure(('error',))
1368
self._medium_request.finished_writing()
1369
raise exc_info[0], exc_info[1], exc_info[2]
1247
1371
self._write_prefixed_body(part)
1250
# Iterating the stream failed. Cleanly abort the request.
1251
self._write_error_status()
1252
# Currently the client unconditionally sends ('error',) as the
1254
self._write_structure(('error',))
1255
1373
self._write_end()
1256
1374
self._medium_request.finished_writing()