~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/smart/protocol.py

  • Committer: Robert Collins
  • Date: 2010-05-06 23:41:35 UTC
  • mto: This revision was merged to the branch mainline in revision 5223.
  • Revision ID: robertc@robertcollins.net-20100506234135-yivbzczw1sejxnxc
Lock methods on ``Tree``, ``Branch`` and ``Repository`` are now
expected to return an object which can be used to unlock them. This reduces
duplicate code when using cleanups. The 'tokens' previously returned by
``Branch.lock_write`` and ``Repository.lock_write`` are now attributes
of the object returned by ``lock_write``. ``repository.RepositoryWriteLockResult``
and ``branch.BranchWriteLockResult`` document this. (Robert Collins)

``log._get_info_for_log_files`` now takes an ``add_cleanup`` callable.
(Robert Collins)

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2006, 2007 Canonical Ltd
 
1
# Copyright (C) 2006-2010 Canonical Ltd
2
2
#
3
3
# This program is free software; you can redistribute it and/or modify
4
4
# it under the terms of the GNU General Public License as published by
22
22
from cStringIO import StringIO
23
23
import struct
24
24
import sys
 
25
import thread
 
26
import threading
25
27
import time
26
28
 
27
29
import bzrlib
28
 
from bzrlib import debug
29
 
from bzrlib import errors
 
30
from bzrlib import (
 
31
    debug,
 
32
    errors,
 
33
    osutils,
 
34
    )
30
35
from bzrlib.smart import message, request
31
36
from bzrlib.trace import log_exception_quietly, mutter
32
 
from bzrlib.util.bencode import bdecode_as_tuple, bencode
 
37
from bzrlib.bencode import bdecode_as_tuple, bencode
33
38
 
34
39
 
35
40
# Protocol version strings.  These are sent as prefixes of bzr requests and
57
62
 
58
63
def _encode_tuple(args):
    """Encode the tuple args to a bytestream.

    The elements of ``args`` are joined with '\\x01' and terminated with a
    newline.  If the joined result is a unicode object (because some element
    was unicode), this is logged via mutter and the result is encoded as
    ASCII, so non-ASCII unicode content raises UnicodeEncodeError.
    """
    joined = '\x01'.join(args) + '\n'
    if isinstance(joined, unicode):
        # XXX: We should fix things so this never happens!  -AJB, 20100304
        mutter('response args contain unicode, should be only bytes: %r',
               joined)
        joined = joined.encode('ascii')
    return joined
61
72
 
62
73
 
63
74
class Requester(object):
114
125
class SmartServerRequestProtocolOne(SmartProtocolBase):
115
126
    """Server-side encoding and decoding logic for smart version 1."""
116
127
 
117
 
    def __init__(self, backing_transport, write_func, root_client_path='/'):
 
128
    def __init__(self, backing_transport, write_func, root_client_path='/',
 
129
            jail_root=None):
118
130
        self._backing_transport = backing_transport
119
131
        self._root_client_path = root_client_path
 
132
        self._jail_root = jail_root
120
133
        self.unused_data = ''
121
134
        self._finished = False
122
135
        self.in_buffer = ''
144
157
                req_args = _decode_tuple(first_line)
145
158
                self.request = request.SmartServerRequestHandler(
146
159
                    self._backing_transport, commands=request.request_handlers,
147
 
                    root_client_path=self._root_client_path)
148
 
                self.request.dispatch_command(req_args[0], req_args[1:])
 
160
                    root_client_path=self._root_client_path,
 
161
                    jail_root=self._jail_root)
 
162
                self.request.args_received(req_args)
149
163
                if self.request.finished_reading:
150
164
                    # trivial request
151
165
                    self.unused_data = self.in_buffer
612
626
            mutter('hpss call:   %s', repr(args)[1:-1])
613
627
            if getattr(self._request._medium, 'base', None) is not None:
614
628
                mutter('             (to %s)', self._request._medium.base)
615
 
            self._request_start_time = time.time()
 
629
            self._request_start_time = osutils.timer_func()
616
630
        self._write_args(args)
617
631
        self._request.finished_writing()
618
632
        self._last_verb = args[0]
627
641
            if getattr(self._request._medium, '_path', None) is not None:
628
642
                mutter('                  (to %s)', self._request._medium._path)
629
643
            mutter('              %d bytes', len(body))
630
 
            self._request_start_time = time.time()
 
644
            self._request_start_time = osutils.timer_func()
631
645
            if 'hpssdetail' in debug.debug_flags:
632
646
                mutter('hpss body content: %s', body)
633
647
        self._write_args(args)
646
660
            mutter('hpss call w/readv: %s', repr(args)[1:-1])
647
661
            if getattr(self._request._medium, '_path', None) is not None:
648
662
                mutter('                  (to %s)', self._request._medium._path)
649
 
            self._request_start_time = time.time()
 
663
            self._request_start_time = osutils.timer_func()
650
664
        self._write_args(args)
651
665
        readv_bytes = self._serialise_offsets(body)
652
666
        bytes = self._encode_bulk_data(readv_bytes)
678
692
        if 'hpss' in debug.debug_flags:
679
693
            if self._request_start_time is not None:
680
694
                mutter('   result:   %6.3fs  %s',
681
 
                       time.time() - self._request_start_time,
 
695
                       osutils.timer_func() - self._request_start_time,
682
696
                       repr(result)[1:-1])
683
697
                self._request_start_time = None
684
698
            else:
858
872
 
859
873
 
860
874
def build_server_protocol_three(backing_transport, write_func,
                                root_client_path, jail_root=None):
    """Return a protocol-three decoder for the server side of a connection.

    A SmartServerRequestHandler (dispatching through
    ``request.request_handlers``) and a ProtocolThreeResponder that writes
    with ``write_func`` are joined by a ConventionalRequestHandler.  The
    returned ProtocolThreeDecoder feeds incoming messages to that handler.
    """
    handler = request.SmartServerRequestHandler(
        backing_transport,
        commands=request.request_handlers,
        root_client_path=root_client_path,
        jail_root=jail_root)
    return ProtocolThreeDecoder(
        message.ConventionalRequestHandler(
            handler, ProtocolThreeResponder(write_func)))
897
911
            # We do *not* set self.decoding_failed here.  The message handler
898
912
            # has raised an error, but the decoder is still able to parse bytes
899
913
            # and determine when this message ends.
900
 
            log_exception_quietly()
 
914
            if not isinstance(exception.exc_value, errors.UnknownSmartMethod):
 
915
                log_exception_quietly()
901
916
            self.message_handler.protocol_error(exception.exc_value)
902
917
            # The state machine is ready to continue decoding, but the
903
918
            # exception has interrupted the loop that runs the state machine.
1036
1051
            raise errors.SmartMessageHandlerError(sys.exc_info())
1037
1052
 
1038
1053
    def _state_accept_reading_unused(self):
1039
 
        self.unused_data = self._get_in_buffer()
 
1054
        self.unused_data += self._get_in_buffer()
1040
1055
        self._set_in_buffer(None)
1041
1056
 
1042
1057
    def next_read_size(self):
1058
1073
class _ProtocolThreeEncoder(object):
1059
1074
 
1060
1075
    response_marker = request_marker = MESSAGE_VERSION_THREE
 
1076
    BUFFER_SIZE = 1024*1024 # 1 MiB buffer before flushing
1061
1077
 
1062
1078
    def __init__(self, write_func):
1063
1079
        self._buf = []
 
1080
        self._buf_len = 0
1064
1081
        self._real_write_func = write_func
1065
1082
 
1066
1083
    def _write_func(self, bytes):
 
1084
        # TODO: It is probably more appropriate to use sum(map(len, _buf))
 
1085
        #       for total number of bytes to write, rather than buffer based on
 
1086
        #       the number of write() calls
 
1087
        # TODO: Another possibility would be to turn this into an async model.
 
1088
        #       Where we let another thread know that we have some bytes if
 
1089
        #       they want it, but we don't actually block for it
 
1090
        #       Note that osutils.send_all always sends 64kB chunks anyway, so
 
1091
        #       we might just push out smaller bits at a time?
1067
1092
        self._buf.append(bytes)
1068
 
        if len(self._buf) > 100:
 
1093
        self._buf_len += len(bytes)
 
1094
        if self._buf_len > self.BUFFER_SIZE:
1069
1095
            self.flush()
1070
1096
 
1071
1097
    def flush(self):
1072
1098
        if self._buf:
1073
1099
            self._real_write_func(''.join(self._buf))
1074
1100
            del self._buf[:]
 
1101
            self._buf_len = 0
1075
1102
 
1076
1103
    def _serialise_offsets(self, offsets):
1077
1104
        """Serialise a readv offset list."""
1126
1153
        _ProtocolThreeEncoder.__init__(self, write_func)
1127
1154
        self.response_sent = False
1128
1155
        self._headers = {'Software version': bzrlib.__version__}
 
1156
        if 'hpss' in debug.debug_flags:
 
1157
            self._thread_id = thread.get_ident()
 
1158
            self._response_start_time = None
 
1159
 
 
1160
    def _trace(self, action, message, extra_bytes=None, include_time=False):
 
1161
        if self._response_start_time is None:
 
1162
            self._response_start_time = osutils.timer_func()
 
1163
        if include_time:
 
1164
            t = '%5.3fs ' % (time.clock() - self._response_start_time)
 
1165
        else:
 
1166
            t = ''
 
1167
        if extra_bytes is None:
 
1168
            extra = ''
 
1169
        else:
 
1170
            extra = ' ' + repr(extra_bytes[:40])
 
1171
            if len(extra) > 33:
 
1172
                extra = extra[:29] + extra[-1] + '...'
 
1173
        mutter('%12s: [%s] %s%s%s'
 
1174
               % (action, self._thread_id, t, message, extra))
1129
1175
 
1130
1176
    def send_error(self, exception):
1131
1177
        if self.response_sent:
1137
1183
                ('UnknownMethod', exception.verb))
1138
1184
            self.send_response(failure)
1139
1185
            return
 
1186
        if 'hpss' in debug.debug_flags:
 
1187
            self._trace('error', str(exception))
1140
1188
        self.response_sent = True
1141
1189
        self._write_protocol_version()
1142
1190
        self._write_headers(self._headers)
1156
1204
            self._write_success_status()
1157
1205
        else:
1158
1206
            self._write_error_status()
 
1207
        if 'hpss' in debug.debug_flags:
 
1208
            self._trace('response', repr(response.args))
1159
1209
        self._write_structure(response.args)
1160
1210
        if response.body is not None:
1161
1211
            self._write_prefixed_body(response.body)
 
1212
            if 'hpss' in debug.debug_flags:
 
1213
                self._trace('body', '%d bytes' % (len(response.body),),
 
1214
                            response.body, include_time=True)
1162
1215
        elif response.body_stream is not None:
 
1216
            count = num_bytes = 0
 
1217
            first_chunk = None
1163
1218
            for exc_info, chunk in _iter_with_errors(response.body_stream):
 
1219
                count += 1
1164
1220
                if exc_info is not None:
1165
1221
                    self._write_error_status()
1166
1222
                    error_struct = request._translate_error(exc_info[1])
1171
1227
                        self._write_error_status()
1172
1228
                        self._write_structure(chunk.args)
1173
1229
                        break
 
1230
                    num_bytes += len(chunk)
 
1231
                    if first_chunk is None:
 
1232
                        first_chunk = chunk
1174
1233
                    self._write_prefixed_body(chunk)
 
1234
                    if 'hpssdetail' in debug.debug_flags:
 
1235
                        # Not worth timing separately, as _write_func is
 
1236
                        # actually buffered
 
1237
                        self._trace('body chunk',
 
1238
                                    '%d bytes' % (len(chunk),),
 
1239
                                    chunk, suppress_time=True)
 
1240
            if 'hpss' in debug.debug_flags:
 
1241
                self._trace('body stream',
 
1242
                            '%d bytes %d chunks' % (num_bytes, count),
 
1243
                            first_chunk)
1175
1244
        self._write_end()
 
1245
        if 'hpss' in debug.debug_flags:
 
1246
            self._trace('response end', '', include_time=True)
1176
1247
 
1177
1248
 
1178
1249
def _iter_with_errors(iterable):
1208
1279
        except (KeyboardInterrupt, SystemExit):
1209
1280
            raise
1210
1281
        except Exception:
 
1282
            mutter('_iter_with_errors caught error')
 
1283
            log_exception_quietly()
1211
1284
            yield sys.exc_info(), None
1212
1285
            return
1213
1286
 
1228
1301
            base = getattr(self._medium_request._medium, 'base', None)
1229
1302
            if base is not None:
1230
1303
                mutter('             (to %s)', base)
1231
 
            self._request_start_time = time.time()
 
1304
            self._request_start_time = osutils.timer_func()
1232
1305
        self._write_protocol_version()
1233
1306
        self._write_headers(self._headers)
1234
1307
        self._write_structure(args)
1246
1319
            if path is not None:
1247
1320
                mutter('                  (to %s)', path)
1248
1321
            mutter('              %d bytes', len(body))
1249
 
            self._request_start_time = time.time()
 
1322
            self._request_start_time = osutils.timer_func()
1250
1323
        self._write_protocol_version()
1251
1324
        self._write_headers(self._headers)
1252
1325
        self._write_structure(args)
1265
1338
            path = getattr(self._medium_request._medium, '_path', None)
1266
1339
            if path is not None:
1267
1340
                mutter('                  (to %s)', path)
1268
 
            self._request_start_time = time.time()
 
1341
            self._request_start_time = osutils.timer_func()
1269
1342
        self._write_protocol_version()
1270
1343
        self._write_headers(self._headers)
1271
1344
        self._write_structure(args)
1282
1355
            path = getattr(self._medium_request._medium, '_path', None)
1283
1356
            if path is not None:
1284
1357
                mutter('                  (to %s)', path)
1285
 
            self._request_start_time = time.time()
 
1358
            self._request_start_time = osutils.timer_func()
1286
1359
        self._write_protocol_version()
1287
1360
        self._write_headers(self._headers)
1288
1361
        self._write_structure(args)