~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/smart/protocol.py

  • Committer: Canonical.com Patch Queue Manager
  • Date: 2010-01-14 00:01:32 UTC
  • mfrom: (4957.1.1 jam-integration)
  • Revision ID: pqm@pqm.ubuntu.com-20100114000132-3p3rabnonjw3gzqb
(jam) Merge bzr.stable, bringing in bug fixes #175839, #504390

Show diffs side-by-side

Legend: added lines

Legend: removed lines

Lines of Context:
1
 
# Copyright (C) 2006, 2007 Canonical Ltd
 
1
# Copyright (C) 2006, 2007, 2008, 2009 Canonical Ltd
2
2
#
3
3
# This program is free software; you can redistribute it and/or modify
4
4
# it under the terms of the GNU General Public License as published by
22
22
from cStringIO import StringIO
23
23
import struct
24
24
import sys
 
25
import thread
 
26
import threading
25
27
import time
26
28
 
27
29
import bzrlib
28
 
from bzrlib import debug
29
 
from bzrlib import errors
 
30
from bzrlib import (
 
31
    debug,
 
32
    errors,
 
33
    osutils,
 
34
    )
30
35
from bzrlib.smart import message, request
31
36
from bzrlib.trace import log_exception_quietly, mutter
32
 
from bzrlib.util.bencode import bdecode_as_tuple, bencode
 
37
from bzrlib.bencode import bdecode_as_tuple, bencode
33
38
 
34
39
 
35
40
# Protocol version strings.  These are sent as prefixes of bzr requests and
114
119
class SmartServerRequestProtocolOne(SmartProtocolBase):
115
120
    """Server-side encoding and decoding logic for smart version 1."""
116
121
 
117
 
    def __init__(self, backing_transport, write_func, root_client_path='/'):
 
122
    def __init__(self, backing_transport, write_func, root_client_path='/',
 
123
            jail_root=None):
118
124
        self._backing_transport = backing_transport
119
125
        self._root_client_path = root_client_path
 
126
        self._jail_root = jail_root
120
127
        self.unused_data = ''
121
128
        self._finished = False
122
129
        self.in_buffer = ''
144
151
                req_args = _decode_tuple(first_line)
145
152
                self.request = request.SmartServerRequestHandler(
146
153
                    self._backing_transport, commands=request.request_handlers,
147
 
                    root_client_path=self._root_client_path)
148
 
                self.request.dispatch_command(req_args[0], req_args[1:])
 
154
                    root_client_path=self._root_client_path,
 
155
                    jail_root=self._jail_root)
 
156
                self.request.args_received(req_args)
149
157
                if self.request.finished_reading:
150
158
                    # trivial request
151
159
                    self.unused_data = self.in_buffer
612
620
            mutter('hpss call:   %s', repr(args)[1:-1])
613
621
            if getattr(self._request._medium, 'base', None) is not None:
614
622
                mutter('             (to %s)', self._request._medium.base)
615
 
            self._request_start_time = time.time()
 
623
            self._request_start_time = osutils.timer_func()
616
624
        self._write_args(args)
617
625
        self._request.finished_writing()
618
626
        self._last_verb = args[0]
627
635
            if getattr(self._request._medium, '_path', None) is not None:
628
636
                mutter('                  (to %s)', self._request._medium._path)
629
637
            mutter('              %d bytes', len(body))
630
 
            self._request_start_time = time.time()
 
638
            self._request_start_time = osutils.timer_func()
631
639
            if 'hpssdetail' in debug.debug_flags:
632
640
                mutter('hpss body content: %s', body)
633
641
        self._write_args(args)
646
654
            mutter('hpss call w/readv: %s', repr(args)[1:-1])
647
655
            if getattr(self._request._medium, '_path', None) is not None:
648
656
                mutter('                  (to %s)', self._request._medium._path)
649
 
            self._request_start_time = time.time()
 
657
            self._request_start_time = osutils.timer_func()
650
658
        self._write_args(args)
651
659
        readv_bytes = self._serialise_offsets(body)
652
660
        bytes = self._encode_bulk_data(readv_bytes)
678
686
        if 'hpss' in debug.debug_flags:
679
687
            if self._request_start_time is not None:
680
688
                mutter('   result:   %6.3fs  %s',
681
 
                       time.time() - self._request_start_time,
 
689
                       osutils.timer_func() - self._request_start_time,
682
690
                       repr(result)[1:-1])
683
691
                self._request_start_time = None
684
692
            else:
858
866
 
859
867
 
860
868
def build_server_protocol_three(backing_transport, write_func,
861
 
                                root_client_path):
 
869
                                root_client_path, jail_root=None):
862
870
    request_handler = request.SmartServerRequestHandler(
863
871
        backing_transport, commands=request.request_handlers,
864
 
        root_client_path=root_client_path)
 
872
        root_client_path=root_client_path, jail_root=jail_root)
865
873
    responder = ProtocolThreeResponder(write_func)
866
874
    message_handler = message.ConventionalRequestHandler(request_handler, responder)
867
875
    return ProtocolThreeDecoder(message_handler)
897
905
            # We do *not* set self.decoding_failed here.  The message handler
898
906
            # has raised an error, but the decoder is still able to parse bytes
899
907
            # and determine when this message ends.
900
 
            log_exception_quietly()
 
908
            if not isinstance(exception.exc_value, errors.UnknownSmartMethod):
 
909
                log_exception_quietly()
901
910
            self.message_handler.protocol_error(exception.exc_value)
902
911
            # The state machine is ready to continue decoding, but the
903
912
            # exception has interrupted the loop that runs the state machine.
1036
1045
            raise errors.SmartMessageHandlerError(sys.exc_info())
1037
1046
 
1038
1047
    def _state_accept_reading_unused(self):
1039
 
        self.unused_data = self._get_in_buffer()
 
1048
        self.unused_data += self._get_in_buffer()
1040
1049
        self._set_in_buffer(None)
1041
1050
 
1042
1051
    def next_read_size(self):
1058
1067
class _ProtocolThreeEncoder(object):
1059
1068
 
1060
1069
    response_marker = request_marker = MESSAGE_VERSION_THREE
 
1070
    BUFFER_SIZE = 1024*1024 # 1 MiB buffer before flushing
1061
1071
 
1062
1072
    def __init__(self, write_func):
1063
1073
        self._buf = []
 
1074
        self._buf_len = 0
1064
1075
        self._real_write_func = write_func
1065
1076
 
1066
1077
    def _write_func(self, bytes):
 
1078
        # TODO: It is probably more appropriate to use sum(map(len, _buf))
 
1079
        #       for total number of bytes to write, rather than buffer based on
 
1080
        #       the number of write() calls
 
1081
        # TODO: Another possibility would be to turn this into an async model.
 
1082
        #       Where we let another thread know that we have some bytes if
 
1083
        #       they want it, but we don't actually block for it
 
1084
        #       Note that osutils.send_all always sends 64kB chunks anyway, so
 
1085
        #       we might just push out smaller bits at a time?
1067
1086
        self._buf.append(bytes)
1068
 
        if len(self._buf) > 100:
 
1087
        self._buf_len += len(bytes)
 
1088
        if self._buf_len > self.BUFFER_SIZE:
1069
1089
            self.flush()
1070
1090
 
1071
1091
    def flush(self):
1072
1092
        if self._buf:
1073
1093
            self._real_write_func(''.join(self._buf))
1074
1094
            del self._buf[:]
 
1095
            self._buf_len = 0
1075
1096
 
1076
1097
    def _serialise_offsets(self, offsets):
1077
1098
        """Serialise a readv offset list."""
1126
1147
        _ProtocolThreeEncoder.__init__(self, write_func)
1127
1148
        self.response_sent = False
1128
1149
        self._headers = {'Software version': bzrlib.__version__}
 
1150
        if 'hpss' in debug.debug_flags:
 
1151
            self._thread_id = thread.get_ident()
 
1152
            self._response_start_time = None
 
1153
 
 
1154
    def _trace(self, action, message, extra_bytes=None, include_time=False):
 
1155
        if self._response_start_time is None:
 
1156
            self._response_start_time = osutils.timer_func()
 
1157
        if include_time:
 
1158
            t = '%5.3fs ' % (time.clock() - self._response_start_time)
 
1159
        else:
 
1160
            t = ''
 
1161
        if extra_bytes is None:
 
1162
            extra = ''
 
1163
        else:
 
1164
            extra = ' ' + repr(extra_bytes[:40])
 
1165
            if len(extra) > 33:
 
1166
                extra = extra[:29] + extra[-1] + '...'
 
1167
        mutter('%12s: [%s] %s%s%s'
 
1168
               % (action, self._thread_id, t, message, extra))
1129
1169
 
1130
1170
    def send_error(self, exception):
1131
1171
        if self.response_sent:
1137
1177
                ('UnknownMethod', exception.verb))
1138
1178
            self.send_response(failure)
1139
1179
            return
 
1180
        if 'hpss' in debug.debug_flags:
 
1181
            self._trace('error', str(exception))
1140
1182
        self.response_sent = True
1141
1183
        self._write_protocol_version()
1142
1184
        self._write_headers(self._headers)
1156
1198
            self._write_success_status()
1157
1199
        else:
1158
1200
            self._write_error_status()
 
1201
        if 'hpss' in debug.debug_flags:
 
1202
            self._trace('response', repr(response.args))
1159
1203
        self._write_structure(response.args)
1160
1204
        if response.body is not None:
1161
1205
            self._write_prefixed_body(response.body)
 
1206
            if 'hpss' in debug.debug_flags:
 
1207
                self._trace('body', '%d bytes' % (len(response.body),),
 
1208
                            response.body, include_time=True)
1162
1209
        elif response.body_stream is not None:
 
1210
            count = num_bytes = 0
 
1211
            first_chunk = None
1163
1212
            for exc_info, chunk in _iter_with_errors(response.body_stream):
 
1213
                count += 1
1164
1214
                if exc_info is not None:
1165
1215
                    self._write_error_status()
1166
1216
                    error_struct = request._translate_error(exc_info[1])
1171
1221
                        self._write_error_status()
1172
1222
                        self._write_structure(chunk.args)
1173
1223
                        break
 
1224
                    num_bytes += len(chunk)
 
1225
                    if first_chunk is None:
 
1226
                        first_chunk = chunk
1174
1227
                    self._write_prefixed_body(chunk)
 
1228
                    if 'hpssdetail' in debug.debug_flags:
 
1229
                        # Not worth timing separately, as _write_func is
 
1230
                        # actually buffered
 
1231
                        self._trace('body chunk',
 
1232
                                    '%d bytes' % (len(chunk),),
 
1233
                                    chunk, suppress_time=True)
 
1234
            if 'hpss' in debug.debug_flags:
 
1235
                self._trace('body stream',
 
1236
                            '%d bytes %d chunks' % (num_bytes, count),
 
1237
                            first_chunk)
1175
1238
        self._write_end()
 
1239
        if 'hpss' in debug.debug_flags:
 
1240
            self._trace('response end', '', include_time=True)
1176
1241
 
1177
1242
 
1178
1243
def _iter_with_errors(iterable):
1208
1273
        except (KeyboardInterrupt, SystemExit):
1209
1274
            raise
1210
1275
        except Exception:
 
1276
            mutter('_iter_with_errors caught error')
 
1277
            log_exception_quietly()
1211
1278
            yield sys.exc_info(), None
1212
1279
            return
1213
1280
 
1228
1295
            base = getattr(self._medium_request._medium, 'base', None)
1229
1296
            if base is not None:
1230
1297
                mutter('             (to %s)', base)
1231
 
            self._request_start_time = time.time()
 
1298
            self._request_start_time = osutils.timer_func()
1232
1299
        self._write_protocol_version()
1233
1300
        self._write_headers(self._headers)
1234
1301
        self._write_structure(args)
1246
1313
            if path is not None:
1247
1314
                mutter('                  (to %s)', path)
1248
1315
            mutter('              %d bytes', len(body))
1249
 
            self._request_start_time = time.time()
 
1316
            self._request_start_time = osutils.timer_func()
1250
1317
        self._write_protocol_version()
1251
1318
        self._write_headers(self._headers)
1252
1319
        self._write_structure(args)
1265
1332
            path = getattr(self._medium_request._medium, '_path', None)
1266
1333
            if path is not None:
1267
1334
                mutter('                  (to %s)', path)
1268
 
            self._request_start_time = time.time()
 
1335
            self._request_start_time = osutils.timer_func()
1269
1336
        self._write_protocol_version()
1270
1337
        self._write_headers(self._headers)
1271
1338
        self._write_structure(args)
1282
1349
            path = getattr(self._medium_request._medium, '_path', None)
1283
1350
            if path is not None:
1284
1351
                mutter('                  (to %s)', path)
1285
 
            self._request_start_time = time.time()
 
1352
            self._request_start_time = osutils.timer_func()
1286
1353
        self._write_protocol_version()
1287
1354
        self._write_headers(self._headers)
1288
1355
        self._write_structure(args)