~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/smart/protocol.py

  • Committer: Vincent Ladeuil
  • Date: 2016-02-01 19:26:41 UTC
  • mto: This revision was merged to the branch mainline in revision 6616.
  • Revision ID: v.ladeuil+lp@free.fr-20160201192641-mzn90m51rydhw00n
Open trunk again as 2.8b1

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2006, 2007 Canonical Ltd
 
1
# Copyright (C) 2006-2010 Canonical Ltd
2
2
#
3
3
# This program is free software; you can redistribute it and/or modify
4
4
# it under the terms of the GNU General Public License as published by
18
18
client and server.
19
19
"""
20
20
 
 
21
from __future__ import absolute_import
 
22
 
21
23
import collections
22
24
from cStringIO import StringIO
23
25
import struct
24
26
import sys
 
27
import thread
25
28
import time
26
29
 
27
30
import bzrlib
28
 
from bzrlib import debug
29
 
from bzrlib import errors
 
31
from bzrlib import (
 
32
    debug,
 
33
    errors,
 
34
    osutils,
 
35
    )
30
36
from bzrlib.smart import message, request
31
37
from bzrlib.trace import log_exception_quietly, mutter
32
38
from bzrlib.bencode import bdecode_as_tuple, bencode
57
63
 
58
64
def _encode_tuple(args):
59
65
    """Encode the tuple args to a bytestream."""
60
 
    return '\x01'.join(args) + '\n'
 
66
    joined = '\x01'.join(args) + '\n'
 
67
    if type(joined) is unicode:
 
68
        # XXX: We should fix things so this never happens!  -AJB, 20100304
 
69
        mutter('response args contain unicode, should be only bytes: %r',
 
70
               joined)
 
71
        joined = joined.encode('ascii')
 
72
    return joined
61
73
 
62
74
 
63
75
class Requester(object):
114
126
class SmartServerRequestProtocolOne(SmartProtocolBase):
115
127
    """Server-side encoding and decoding logic for smart version 1."""
116
128
 
117
 
    def __init__(self, backing_transport, write_func, root_client_path='/'):
 
129
    def __init__(self, backing_transport, write_func, root_client_path='/',
 
130
            jail_root=None):
118
131
        self._backing_transport = backing_transport
119
132
        self._root_client_path = root_client_path
 
133
        self._jail_root = jail_root
120
134
        self.unused_data = ''
121
135
        self._finished = False
122
136
        self.in_buffer = ''
144
158
                req_args = _decode_tuple(first_line)
145
159
                self.request = request.SmartServerRequestHandler(
146
160
                    self._backing_transport, commands=request.request_handlers,
147
 
                    root_client_path=self._root_client_path)
148
 
                self.request.dispatch_command(req_args[0], req_args[1:])
 
161
                    root_client_path=self._root_client_path,
 
162
                    jail_root=self._jail_root)
 
163
                self.request.args_received(req_args)
149
164
                if self.request.finished_reading:
150
165
                    # trivial request
151
166
                    self.unused_data = self.in_buffer
612
627
            mutter('hpss call:   %s', repr(args)[1:-1])
613
628
            if getattr(self._request._medium, 'base', None) is not None:
614
629
                mutter('             (to %s)', self._request._medium.base)
615
 
            self._request_start_time = time.time()
 
630
            self._request_start_time = osutils.timer_func()
616
631
        self._write_args(args)
617
632
        self._request.finished_writing()
618
633
        self._last_verb = args[0]
627
642
            if getattr(self._request._medium, '_path', None) is not None:
628
643
                mutter('                  (to %s)', self._request._medium._path)
629
644
            mutter('              %d bytes', len(body))
630
 
            self._request_start_time = time.time()
 
645
            self._request_start_time = osutils.timer_func()
631
646
            if 'hpssdetail' in debug.debug_flags:
632
647
                mutter('hpss body content: %s', body)
633
648
        self._write_args(args)
640
655
        """Make a remote call with a readv array.
641
656
 
642
657
        The body is encoded with one line per readv offset pair. The numbers in
643
 
        each pair are separated by a comma, and no trailing \n is emitted.
 
658
        each pair are separated by a comma, and no trailing \\n is emitted.
644
659
        """
645
660
        if 'hpss' in debug.debug_flags:
646
661
            mutter('hpss call w/readv: %s', repr(args)[1:-1])
647
662
            if getattr(self._request._medium, '_path', None) is not None:
648
663
                mutter('                  (to %s)', self._request._medium._path)
649
 
            self._request_start_time = time.time()
 
664
            self._request_start_time = osutils.timer_func()
650
665
        self._write_args(args)
651
666
        readv_bytes = self._serialise_offsets(body)
652
667
        bytes = self._encode_bulk_data(readv_bytes)
678
693
        if 'hpss' in debug.debug_flags:
679
694
            if self._request_start_time is not None:
680
695
                mutter('   result:   %6.3fs  %s',
681
 
                       time.time() - self._request_start_time,
 
696
                       osutils.timer_func() - self._request_start_time,
682
697
                       repr(result)[1:-1])
683
698
                self._request_start_time = None
684
699
            else:
858
873
 
859
874
 
860
875
def build_server_protocol_three(backing_transport, write_func,
861
 
                                root_client_path):
 
876
                                root_client_path, jail_root=None):
862
877
    request_handler = request.SmartServerRequestHandler(
863
878
        backing_transport, commands=request.request_handlers,
864
 
        root_client_path=root_client_path)
 
879
        root_client_path=root_client_path, jail_root=jail_root)
865
880
    responder = ProtocolThreeResponder(write_func)
866
881
    message_handler = message.ConventionalRequestHandler(request_handler, responder)
867
882
    return ProtocolThreeDecoder(message_handler)
1059
1074
class _ProtocolThreeEncoder(object):
1060
1075
 
1061
1076
    response_marker = request_marker = MESSAGE_VERSION_THREE
 
1077
    BUFFER_SIZE = 1024*1024 # 1 MiB buffer before flushing
1062
1078
 
1063
1079
    def __init__(self, write_func):
1064
1080
        self._buf = []
 
1081
        self._buf_len = 0
1065
1082
        self._real_write_func = write_func
1066
1083
 
1067
1084
    def _write_func(self, bytes):
 
1085
        # TODO: Another possibility would be to turn this into an async model.
 
1086
        #       Where we let another thread know that we have some bytes if
 
1087
        #       they want it, but we don't actually block for it
 
1088
        #       Note that osutils.send_all always sends 64kB chunks anyway, so
 
1089
        #       we might just push out smaller bits at a time?
1068
1090
        self._buf.append(bytes)
1069
 
        if len(self._buf) > 100:
 
1091
        self._buf_len += len(bytes)
 
1092
        if self._buf_len > self.BUFFER_SIZE:
1070
1093
            self.flush()
1071
1094
 
1072
1095
    def flush(self):
1073
1096
        if self._buf:
1074
1097
            self._real_write_func(''.join(self._buf))
1075
1098
            del self._buf[:]
 
1099
            self._buf_len = 0
1076
1100
 
1077
1101
    def _serialise_offsets(self, offsets):
1078
1102
        """Serialise a readv offset list."""
1127
1151
        _ProtocolThreeEncoder.__init__(self, write_func)
1128
1152
        self.response_sent = False
1129
1153
        self._headers = {'Software version': bzrlib.__version__}
 
1154
        if 'hpss' in debug.debug_flags:
 
1155
            self._thread_id = thread.get_ident()
 
1156
            self._response_start_time = None
 
1157
 
 
1158
    def _trace(self, action, message, extra_bytes=None, include_time=False):
 
1159
        if self._response_start_time is None:
 
1160
            self._response_start_time = osutils.timer_func()
 
1161
        if include_time:
 
1162
            t = '%5.3fs ' % (time.clock() - self._response_start_time)
 
1163
        else:
 
1164
            t = ''
 
1165
        if extra_bytes is None:
 
1166
            extra = ''
 
1167
        else:
 
1168
            extra = ' ' + repr(extra_bytes[:40])
 
1169
            if len(extra) > 33:
 
1170
                extra = extra[:29] + extra[-1] + '...'
 
1171
        mutter('%12s: [%s] %s%s%s'
 
1172
               % (action, self._thread_id, t, message, extra))
1130
1173
 
1131
1174
    def send_error(self, exception):
1132
1175
        if self.response_sent:
1138
1181
                ('UnknownMethod', exception.verb))
1139
1182
            self.send_response(failure)
1140
1183
            return
 
1184
        if 'hpss' in debug.debug_flags:
 
1185
            self._trace('error', str(exception))
1141
1186
        self.response_sent = True
1142
1187
        self._write_protocol_version()
1143
1188
        self._write_headers(self._headers)
1157
1202
            self._write_success_status()
1158
1203
        else:
1159
1204
            self._write_error_status()
 
1205
        if 'hpss' in debug.debug_flags:
 
1206
            self._trace('response', repr(response.args))
1160
1207
        self._write_structure(response.args)
1161
1208
        if response.body is not None:
1162
1209
            self._write_prefixed_body(response.body)
 
1210
            if 'hpss' in debug.debug_flags:
 
1211
                self._trace('body', '%d bytes' % (len(response.body),),
 
1212
                            response.body, include_time=True)
1163
1213
        elif response.body_stream is not None:
 
1214
            count = num_bytes = 0
 
1215
            first_chunk = None
1164
1216
            for exc_info, chunk in _iter_with_errors(response.body_stream):
 
1217
                count += 1
1165
1218
                if exc_info is not None:
1166
1219
                    self._write_error_status()
1167
1220
                    error_struct = request._translate_error(exc_info[1])
1172
1225
                        self._write_error_status()
1173
1226
                        self._write_structure(chunk.args)
1174
1227
                        break
 
1228
                    num_bytes += len(chunk)
 
1229
                    if first_chunk is None:
 
1230
                        first_chunk = chunk
1175
1231
                    self._write_prefixed_body(chunk)
 
1232
                    self.flush()
 
1233
                    if 'hpssdetail' in debug.debug_flags:
 
1234
                        # Not worth timing separately, as _write_func is
 
1235
                        # actually buffered
 
1236
                        self._trace('body chunk',
 
1237
                                    '%d bytes' % (len(chunk),),
 
1238
                                    chunk, suppress_time=True)
 
1239
            if 'hpss' in debug.debug_flags:
 
1240
                self._trace('body stream',
 
1241
                            '%d bytes %d chunks' % (num_bytes, count),
 
1242
                            first_chunk)
1176
1243
        self._write_end()
 
1244
        if 'hpss' in debug.debug_flags:
 
1245
            self._trace('response end', '', include_time=True)
1177
1246
 
1178
1247
 
1179
1248
def _iter_with_errors(iterable):
1209
1278
        except (KeyboardInterrupt, SystemExit):
1210
1279
            raise
1211
1280
        except Exception:
 
1281
            mutter('_iter_with_errors caught error')
 
1282
            log_exception_quietly()
1212
1283
            yield sys.exc_info(), None
1213
1284
            return
1214
1285
 
1219
1290
        _ProtocolThreeEncoder.__init__(self, medium_request.accept_bytes)
1220
1291
        self._medium_request = medium_request
1221
1292
        self._headers = {}
 
1293
        self.body_stream_started = None
1222
1294
 
1223
1295
    def set_headers(self, headers):
1224
1296
        self._headers = headers.copy()
1229
1301
            base = getattr(self._medium_request._medium, 'base', None)
1230
1302
            if base is not None:
1231
1303
                mutter('             (to %s)', base)
1232
 
            self._request_start_time = time.time()
 
1304
            self._request_start_time = osutils.timer_func()
1233
1305
        self._write_protocol_version()
1234
1306
        self._write_headers(self._headers)
1235
1307
        self._write_structure(args)
1247
1319
            if path is not None:
1248
1320
                mutter('                  (to %s)', path)
1249
1321
            mutter('              %d bytes', len(body))
1250
 
            self._request_start_time = time.time()
 
1322
            self._request_start_time = osutils.timer_func()
1251
1323
        self._write_protocol_version()
1252
1324
        self._write_headers(self._headers)
1253
1325
        self._write_structure(args)
1259
1331
        """Make a remote call with a readv array.
1260
1332
 
1261
1333
        The body is encoded with one line per readv offset pair. The numbers in
1262
 
        each pair are separated by a comma, and no trailing \n is emitted.
 
1334
        each pair are separated by a comma, and no trailing \\n is emitted.
1263
1335
        """
1264
1336
        if 'hpss' in debug.debug_flags:
1265
1337
            mutter('hpss call w/readv: %s', repr(args)[1:-1])
1266
1338
            path = getattr(self._medium_request._medium, '_path', None)
1267
1339
            if path is not None:
1268
1340
                mutter('                  (to %s)', path)
1269
 
            self._request_start_time = time.time()
 
1341
            self._request_start_time = osutils.timer_func()
1270
1342
        self._write_protocol_version()
1271
1343
        self._write_headers(self._headers)
1272
1344
        self._write_structure(args)
1283
1355
            path = getattr(self._medium_request._medium, '_path', None)
1284
1356
            if path is not None:
1285
1357
                mutter('                  (to %s)', path)
1286
 
            self._request_start_time = time.time()
 
1358
            self._request_start_time = osutils.timer_func()
 
1359
        self.body_stream_started = False
1287
1360
        self._write_protocol_version()
1288
1361
        self._write_headers(self._headers)
1289
1362
        self._write_structure(args)
1291
1364
        #       have finished sending the stream.  We would notice at the end
1292
1365
        #       anyway, but if the medium can deliver it early then it's good
1293
1366
        #       to short-circuit the whole request...
 
1367
        # Provoke any ConnectionReset failures before we start the body stream.
 
1368
        self.flush()
 
1369
        self.body_stream_started = True
1294
1370
        for exc_info, part in _iter_with_errors(stream):
1295
1371
            if exc_info is not None:
1296
1372
                # Iterating the stream failed.  Cleanly abort the request.