~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/smart/protocol.py

  • Committer: John Arbash Meinel
  • Date: 2009-06-02 19:56:24 UTC
  • mto: This revision was merged to the branch mainline in revision 4469.
  • Revision ID: john@arbash-meinel.com-20090602195624-utljsyz0qgmq63lg
Add a chunks_to_gzip function.
This allows the _record_to_data code to build up a list of chunks,
rather than requiring a single string.
Performance should be about the same when using a single string, since
we are only adding a for() loop over the chunks and an if check.
We could possibly just remove the if check and not worry about adding
some empty strings in there.

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2006-2010 Canonical Ltd
 
1
# Copyright (C) 2006, 2007 Canonical Ltd
2
2
#
3
3
# This program is free software; you can redistribute it and/or modify
4
4
# it under the terms of the GNU General Public License as published by
22
22
from cStringIO import StringIO
23
23
import struct
24
24
import sys
25
 
import thread
26
 
import threading
27
25
import time
28
26
 
29
27
import bzrlib
30
 
from bzrlib import (
31
 
    debug,
32
 
    errors,
33
 
    osutils,
34
 
    )
 
28
from bzrlib import debug
 
29
from bzrlib import errors
35
30
from bzrlib.smart import message, request
36
31
from bzrlib.trace import log_exception_quietly, mutter
37
 
from bzrlib.bencode import bdecode_as_tuple, bencode
 
32
from bzrlib.util.bencode import bdecode_as_tuple, bencode
38
33
 
39
34
 
40
35
# Protocol version strings.  These are sent as prefixes of bzr requests and
62
57
 
63
58
def _encode_tuple(args):
64
59
    """Encode the tuple args to a bytestream."""
65
 
    joined = '\x01'.join(args) + '\n'
66
 
    if type(joined) is unicode:
67
 
        # XXX: We should fix things so this never happens!  -AJB, 20100304
68
 
        mutter('response args contain unicode, should be only bytes: %r',
69
 
               joined)
70
 
        joined = joined.encode('ascii')
71
 
    return joined
 
60
    return '\x01'.join(args) + '\n'
72
61
 
73
62
 
74
63
class Requester(object):
125
114
class SmartServerRequestProtocolOne(SmartProtocolBase):
126
115
    """Server-side encoding and decoding logic for smart version 1."""
127
116
 
128
 
    def __init__(self, backing_transport, write_func, root_client_path='/',
129
 
            jail_root=None):
 
117
    def __init__(self, backing_transport, write_func, root_client_path='/'):
130
118
        self._backing_transport = backing_transport
131
119
        self._root_client_path = root_client_path
132
 
        self._jail_root = jail_root
133
120
        self.unused_data = ''
134
121
        self._finished = False
135
122
        self.in_buffer = ''
157
144
                req_args = _decode_tuple(first_line)
158
145
                self.request = request.SmartServerRequestHandler(
159
146
                    self._backing_transport, commands=request.request_handlers,
160
 
                    root_client_path=self._root_client_path,
161
 
                    jail_root=self._jail_root)
162
 
                self.request.args_received(req_args)
 
147
                    root_client_path=self._root_client_path)
 
148
                self.request.dispatch_command(req_args[0], req_args[1:])
163
149
                if self.request.finished_reading:
164
150
                    # trivial request
165
151
                    self.unused_data = self.in_buffer
626
612
            mutter('hpss call:   %s', repr(args)[1:-1])
627
613
            if getattr(self._request._medium, 'base', None) is not None:
628
614
                mutter('             (to %s)', self._request._medium.base)
629
 
            self._request_start_time = osutils.timer_func()
 
615
            self._request_start_time = time.time()
630
616
        self._write_args(args)
631
617
        self._request.finished_writing()
632
618
        self._last_verb = args[0]
641
627
            if getattr(self._request._medium, '_path', None) is not None:
642
628
                mutter('                  (to %s)', self._request._medium._path)
643
629
            mutter('              %d bytes', len(body))
644
 
            self._request_start_time = osutils.timer_func()
 
630
            self._request_start_time = time.time()
645
631
            if 'hpssdetail' in debug.debug_flags:
646
632
                mutter('hpss body content: %s', body)
647
633
        self._write_args(args)
660
646
            mutter('hpss call w/readv: %s', repr(args)[1:-1])
661
647
            if getattr(self._request._medium, '_path', None) is not None:
662
648
                mutter('                  (to %s)', self._request._medium._path)
663
 
            self._request_start_time = osutils.timer_func()
 
649
            self._request_start_time = time.time()
664
650
        self._write_args(args)
665
651
        readv_bytes = self._serialise_offsets(body)
666
652
        bytes = self._encode_bulk_data(readv_bytes)
692
678
        if 'hpss' in debug.debug_flags:
693
679
            if self._request_start_time is not None:
694
680
                mutter('   result:   %6.3fs  %s',
695
 
                       osutils.timer_func() - self._request_start_time,
 
681
                       time.time() - self._request_start_time,
696
682
                       repr(result)[1:-1])
697
683
                self._request_start_time = None
698
684
            else:
872
858
 
873
859
 
874
860
def build_server_protocol_three(backing_transport, write_func,
875
 
                                root_client_path, jail_root=None):
 
861
                                root_client_path):
876
862
    request_handler = request.SmartServerRequestHandler(
877
863
        backing_transport, commands=request.request_handlers,
878
 
        root_client_path=root_client_path, jail_root=jail_root)
 
864
        root_client_path=root_client_path)
879
865
    responder = ProtocolThreeResponder(write_func)
880
866
    message_handler = message.ConventionalRequestHandler(request_handler, responder)
881
867
    return ProtocolThreeDecoder(message_handler)
911
897
            # We do *not* set self.decoding_failed here.  The message handler
912
898
            # has raised an error, but the decoder is still able to parse bytes
913
899
            # and determine when this message ends.
914
 
            if not isinstance(exception.exc_value, errors.UnknownSmartMethod):
915
 
                log_exception_quietly()
 
900
            log_exception_quietly()
916
901
            self.message_handler.protocol_error(exception.exc_value)
917
902
            # The state machine is ready to continue decoding, but the
918
903
            # exception has interrupted the loop that runs the state machine.
1051
1036
            raise errors.SmartMessageHandlerError(sys.exc_info())
1052
1037
 
1053
1038
    def _state_accept_reading_unused(self):
1054
 
        self.unused_data += self._get_in_buffer()
 
1039
        self.unused_data = self._get_in_buffer()
1055
1040
        self._set_in_buffer(None)
1056
1041
 
1057
1042
    def next_read_size(self):
1073
1058
class _ProtocolThreeEncoder(object):
1074
1059
 
1075
1060
    response_marker = request_marker = MESSAGE_VERSION_THREE
1076
 
    BUFFER_SIZE = 1024*1024 # 1 MiB buffer before flushing
1077
1061
 
1078
1062
    def __init__(self, write_func):
1079
1063
        self._buf = []
1080
 
        self._buf_len = 0
1081
1064
        self._real_write_func = write_func
1082
1065
 
1083
1066
    def _write_func(self, bytes):
1084
 
        # TODO: It is probably more appropriate to use sum(map(len, _buf))
1085
 
        #       for total number of bytes to write, rather than buffer based on
1086
 
        #       the number of write() calls
1087
 
        # TODO: Another possibility would be to turn this into an async model.
1088
 
        #       Where we let another thread know that we have some bytes if
1089
 
        #       they want it, but we don't actually block for it
1090
 
        #       Note that osutils.send_all always sends 64kB chunks anyway, so
1091
 
        #       we might just push out smaller bits at a time?
1092
1067
        self._buf.append(bytes)
1093
 
        self._buf_len += len(bytes)
1094
 
        if self._buf_len > self.BUFFER_SIZE:
 
1068
        if len(self._buf) > 100:
1095
1069
            self.flush()
1096
1070
 
1097
1071
    def flush(self):
1098
1072
        if self._buf:
1099
1073
            self._real_write_func(''.join(self._buf))
1100
1074
            del self._buf[:]
1101
 
            self._buf_len = 0
1102
1075
 
1103
1076
    def _serialise_offsets(self, offsets):
1104
1077
        """Serialise a readv offset list."""
1153
1126
        _ProtocolThreeEncoder.__init__(self, write_func)
1154
1127
        self.response_sent = False
1155
1128
        self._headers = {'Software version': bzrlib.__version__}
1156
 
        if 'hpss' in debug.debug_flags:
1157
 
            self._thread_id = thread.get_ident()
1158
 
            self._response_start_time = None
1159
 
 
1160
 
    def _trace(self, action, message, extra_bytes=None, include_time=False):
1161
 
        if self._response_start_time is None:
1162
 
            self._response_start_time = osutils.timer_func()
1163
 
        if include_time:
1164
 
            t = '%5.3fs ' % (time.clock() - self._response_start_time)
1165
 
        else:
1166
 
            t = ''
1167
 
        if extra_bytes is None:
1168
 
            extra = ''
1169
 
        else:
1170
 
            extra = ' ' + repr(extra_bytes[:40])
1171
 
            if len(extra) > 33:
1172
 
                extra = extra[:29] + extra[-1] + '...'
1173
 
        mutter('%12s: [%s] %s%s%s'
1174
 
               % (action, self._thread_id, t, message, extra))
1175
1129
 
1176
1130
    def send_error(self, exception):
1177
1131
        if self.response_sent:
1183
1137
                ('UnknownMethod', exception.verb))
1184
1138
            self.send_response(failure)
1185
1139
            return
1186
 
        if 'hpss' in debug.debug_flags:
1187
 
            self._trace('error', str(exception))
1188
1140
        self.response_sent = True
1189
1141
        self._write_protocol_version()
1190
1142
        self._write_headers(self._headers)
1204
1156
            self._write_success_status()
1205
1157
        else:
1206
1158
            self._write_error_status()
1207
 
        if 'hpss' in debug.debug_flags:
1208
 
            self._trace('response', repr(response.args))
1209
1159
        self._write_structure(response.args)
1210
1160
        if response.body is not None:
1211
1161
            self._write_prefixed_body(response.body)
1212
 
            if 'hpss' in debug.debug_flags:
1213
 
                self._trace('body', '%d bytes' % (len(response.body),),
1214
 
                            response.body, include_time=True)
1215
1162
        elif response.body_stream is not None:
1216
 
            count = num_bytes = 0
1217
 
            first_chunk = None
1218
1163
            for exc_info, chunk in _iter_with_errors(response.body_stream):
1219
 
                count += 1
1220
1164
                if exc_info is not None:
1221
1165
                    self._write_error_status()
1222
1166
                    error_struct = request._translate_error(exc_info[1])
1227
1171
                        self._write_error_status()
1228
1172
                        self._write_structure(chunk.args)
1229
1173
                        break
1230
 
                    num_bytes += len(chunk)
1231
 
                    if first_chunk is None:
1232
 
                        first_chunk = chunk
1233
1174
                    self._write_prefixed_body(chunk)
1234
 
                    self.flush()
1235
 
                    if 'hpssdetail' in debug.debug_flags:
1236
 
                        # Not worth timing separately, as _write_func is
1237
 
                        # actually buffered
1238
 
                        self._trace('body chunk',
1239
 
                                    '%d bytes' % (len(chunk),),
1240
 
                                    chunk, suppress_time=True)
1241
 
            if 'hpss' in debug.debug_flags:
1242
 
                self._trace('body stream',
1243
 
                            '%d bytes %d chunks' % (num_bytes, count),
1244
 
                            first_chunk)
1245
1175
        self._write_end()
1246
 
        if 'hpss' in debug.debug_flags:
1247
 
            self._trace('response end', '', include_time=True)
1248
1176
 
1249
1177
 
1250
1178
def _iter_with_errors(iterable):
1280
1208
        except (KeyboardInterrupt, SystemExit):
1281
1209
            raise
1282
1210
        except Exception:
1283
 
            mutter('_iter_with_errors caught error')
1284
 
            log_exception_quietly()
1285
1211
            yield sys.exc_info(), None
1286
1212
            return
1287
1213
 
1302
1228
            base = getattr(self._medium_request._medium, 'base', None)
1303
1229
            if base is not None:
1304
1230
                mutter('             (to %s)', base)
1305
 
            self._request_start_time = osutils.timer_func()
 
1231
            self._request_start_time = time.time()
1306
1232
        self._write_protocol_version()
1307
1233
        self._write_headers(self._headers)
1308
1234
        self._write_structure(args)
1320
1246
            if path is not None:
1321
1247
                mutter('                  (to %s)', path)
1322
1248
            mutter('              %d bytes', len(body))
1323
 
            self._request_start_time = osutils.timer_func()
 
1249
            self._request_start_time = time.time()
1324
1250
        self._write_protocol_version()
1325
1251
        self._write_headers(self._headers)
1326
1252
        self._write_structure(args)
1339
1265
            path = getattr(self._medium_request._medium, '_path', None)
1340
1266
            if path is not None:
1341
1267
                mutter('                  (to %s)', path)
1342
 
            self._request_start_time = osutils.timer_func()
 
1268
            self._request_start_time = time.time()
1343
1269
        self._write_protocol_version()
1344
1270
        self._write_headers(self._headers)
1345
1271
        self._write_structure(args)
1356
1282
            path = getattr(self._medium_request._medium, '_path', None)
1357
1283
            if path is not None:
1358
1284
                mutter('                  (to %s)', path)
1359
 
            self._request_start_time = osutils.timer_func()
 
1285
            self._request_start_time = time.time()
1360
1286
        self._write_protocol_version()
1361
1287
        self._write_headers(self._headers)
1362
1288
        self._write_structure(args)