~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/smart/protocol.py

  • Committer: Vincent Ladeuil
  • Date: 2010-02-10 15:46:03 UTC
  • mfrom: (4985.3.21 update)
  • mto: This revision was merged to the branch mainline in revision 5021.
  • Revision ID: v.ladeuil+lp@free.fr-20100210154603-k4no1gvfuqpzrw7p
Update performs two merges in a more logical order but stop on conflicts

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2006, 2007 Canonical Ltd
 
1
# Copyright (C) 2006, 2007, 2008, 2009 Canonical Ltd
2
2
#
3
3
# This program is free software; you can redistribute it and/or modify
4
4
# it under the terms of the GNU General Public License as published by
12
12
#
13
13
# You should have received a copy of the GNU General Public License
14
14
# along with this program; if not, write to the Free Software
15
 
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 
15
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
16
16
 
17
17
"""Wire-level encoding and decoding of requests and responses for the smart
18
18
client and server.
22
22
from cStringIO import StringIO
23
23
import struct
24
24
import sys
 
25
import thread
 
26
import threading
25
27
import time
26
28
 
27
29
import bzrlib
28
 
from bzrlib import debug
29
 
from bzrlib import errors
 
30
from bzrlib import (
 
31
    debug,
 
32
    errors,
 
33
    osutils,
 
34
    )
30
35
from bzrlib.smart import message, request
31
36
from bzrlib.trace import log_exception_quietly, mutter
32
 
from bzrlib.util.bencode import bdecode_as_tuple, bencode
 
37
from bzrlib.bencode import bdecode_as_tuple, bencode
33
38
 
34
39
 
35
40
# Protocol version strings.  These are sent as prefixes of bzr requests and
109
114
        for start, length in offsets:
110
115
            txt.append('%d,%d' % (start, length))
111
116
        return '\n'.join(txt)
112
 
        
 
117
 
113
118
 
114
119
class SmartServerRequestProtocolOne(SmartProtocolBase):
115
120
    """Server-side encoding and decoding logic for smart version 1."""
116
 
    
117
 
    def __init__(self, backing_transport, write_func, root_client_path='/'):
 
121
 
 
122
    def __init__(self, backing_transport, write_func, root_client_path='/',
 
123
            jail_root=None):
118
124
        self._backing_transport = backing_transport
119
125
        self._root_client_path = root_client_path
 
126
        self._jail_root = jail_root
120
127
        self.unused_data = ''
121
128
        self._finished = False
122
129
        self.in_buffer = ''
127
134
 
128
135
    def accept_bytes(self, bytes):
129
136
        """Take bytes, and advance the internal state machine appropriately.
130
 
        
 
137
 
131
138
        :param bytes: must be a byte string
132
139
        """
133
140
        if not isinstance(bytes, str):
144
151
                req_args = _decode_tuple(first_line)
145
152
                self.request = request.SmartServerRequestHandler(
146
153
                    self._backing_transport, commands=request.request_handlers,
147
 
                    root_client_path=self._root_client_path)
148
 
                self.request.dispatch_command(req_args[0], req_args[1:])
 
154
                    root_client_path=self._root_client_path,
 
155
                    jail_root=self._jail_root)
 
156
                self.request.args_received(req_args)
149
157
                if self.request.finished_reading:
150
158
                    # trivial request
151
159
                    self.unused_data = self.in_buffer
169
177
 
170
178
        if self._has_dispatched:
171
179
            if self._finished:
172
 
                # nothing to do.XXX: this routine should be a single state 
 
180
                # nothing to do.XXX: this routine should be a single state
173
181
                # machine too.
174
182
                self.unused_data += self.in_buffer
175
183
                self.in_buffer = ''
211
219
 
212
220
    def _write_protocol_version(self):
213
221
        """Write any prefixes this protocol requires.
214
 
        
 
222
 
215
223
        Version one doesn't send protocol versions.
216
224
        """
217
225
 
234
242
 
235
243
class SmartServerRequestProtocolTwo(SmartServerRequestProtocolOne):
236
244
    r"""Version two of the server side of the smart protocol.
237
 
   
 
245
 
238
246
    This prefixes responses with the value of RESPONSE_VERSION_TWO.
239
247
    """
240
248
 
250
258
 
251
259
    def _write_protocol_version(self):
252
260
        r"""Write any prefixes this protocol requires.
253
 
        
 
261
 
254
262
        Version two sends the value of RESPONSE_VERSION_TWO.
255
263
        """
256
264
        self._write_func(self.response_marker)
412
420
        self.chunks = collections.deque()
413
421
        self.error = False
414
422
        self.error_in_progress = None
415
 
    
 
423
 
416
424
    def next_read_size(self):
417
425
        # Note: the shortest possible chunk is 2 bytes: '0\n', and the
418
426
        # end-of-body marker is 4 bytes: 'END\n'.
506
514
                self.chunks.append(self.chunk_in_progress)
507
515
            self.chunk_in_progress = None
508
516
            self.state_accept = self._state_accept_expecting_length
509
 
        
 
517
 
510
518
    def _state_accept_reading_unused(self):
511
519
        self.unused_data += self._get_in_buffer()
512
520
        self._in_buffer_list = []
514
522
 
515
523
class LengthPrefixedBodyDecoder(_StatefulDecoder):
516
524
    """Decodes the length-prefixed bulk data."""
517
 
    
 
525
 
518
526
    def __init__(self):
519
527
        _StatefulDecoder.__init__(self)
520
528
        self.state_accept = self._state_accept_expecting_length
521
529
        self.state_read = self._state_read_no_data
522
530
        self._body = ''
523
531
        self._trailer_buffer = ''
524
 
    
 
532
 
525
533
    def next_read_size(self):
526
534
        if self.bytes_left is not None:
527
535
            # Ideally we want to read all the remainder of the body and the
537
545
        else:
538
546
            # Reading excess data.  Either way, 1 byte at a time is fine.
539
547
            return 1
540
 
        
 
548
 
541
549
    def read_pending_data(self):
542
550
        """Return any pending data that has been decoded."""
543
551
        return self.state_read()
564
572
                self._body = self._body[:self.bytes_left]
565
573
            self.bytes_left = None
566
574
            self.state_accept = self._state_accept_reading_trailer
567
 
        
 
575
 
568
576
    def _state_accept_reading_trailer(self):
569
577
        self._trailer_buffer += self._get_in_buffer()
570
578
        self._set_in_buffer(None)
574
582
            self.unused_data = self._trailer_buffer[len('done\n'):]
575
583
            self.state_accept = self._state_accept_reading_unused
576
584
            self.finished_reading = True
577
 
    
 
585
 
578
586
    def _state_accept_reading_unused(self):
579
587
        self.unused_data += self._get_in_buffer()
580
588
        self._set_in_buffer(None)
612
620
            mutter('hpss call:   %s', repr(args)[1:-1])
613
621
            if getattr(self._request._medium, 'base', None) is not None:
614
622
                mutter('             (to %s)', self._request._medium.base)
615
 
            self._request_start_time = time.time()
 
623
            self._request_start_time = osutils.timer_func()
616
624
        self._write_args(args)
617
625
        self._request.finished_writing()
618
626
        self._last_verb = args[0]
627
635
            if getattr(self._request._medium, '_path', None) is not None:
628
636
                mutter('                  (to %s)', self._request._medium._path)
629
637
            mutter('              %d bytes', len(body))
630
 
            self._request_start_time = time.time()
 
638
            self._request_start_time = osutils.timer_func()
631
639
            if 'hpssdetail' in debug.debug_flags:
632
640
                mutter('hpss body content: %s', body)
633
641
        self._write_args(args)
646
654
            mutter('hpss call w/readv: %s', repr(args)[1:-1])
647
655
            if getattr(self._request._medium, '_path', None) is not None:
648
656
                mutter('                  (to %s)', self._request._medium._path)
649
 
            self._request_start_time = time.time()
 
657
            self._request_start_time = osutils.timer_func()
650
658
        self._write_args(args)
651
659
        readv_bytes = self._serialise_offsets(body)
652
660
        bytes = self._encode_bulk_data(readv_bytes)
655
663
        if 'hpss' in debug.debug_flags:
656
664
            mutter('              %d bytes in readv request', len(readv_bytes))
657
665
        self._last_verb = args[0]
658
 
    
 
666
 
659
667
    def call_with_body_stream(self, args, stream):
660
668
        # Protocols v1 and v2 don't support body streams.  So it's safe to
661
669
        # assume that a v1/v2 server doesn't support whatever method we're
678
686
        if 'hpss' in debug.debug_flags:
679
687
            if self._request_start_time is not None:
680
688
                mutter('   result:   %6.3fs  %s',
681
 
                       time.time() - self._request_start_time,
 
689
                       osutils.timer_func() - self._request_start_time,
682
690
                       repr(result)[1:-1])
683
691
                self._request_start_time = None
684
692
            else:
729
737
    def _response_is_unknown_method(self, result_tuple):
730
738
        """Raise UnexpectedSmartServerResponse if the response is an 'unknonwn
731
739
        method' response to the request.
732
 
        
 
740
 
733
741
        :param response: The response from a smart client call_expecting_body
734
742
            call.
735
743
        :param verb: The verb used in that call.
742
750
            # The response will have no body, so we've finished reading.
743
751
            self._request.finished_reading()
744
752
            raise errors.UnknownSmartMethod(self._last_verb)
745
 
        
 
753
 
746
754
    def read_body_bytes(self, count=-1):
747
755
        """Read bytes from the body, decoding into a byte stream.
748
 
        
749
 
        We read all bytes at once to ensure we've checked the trailer for 
 
756
 
 
757
        We read all bytes at once to ensure we've checked the trailer for
750
758
        errors, and then feed the buffer back as read_body_bytes is called.
751
759
        """
752
760
        if self._body_buffer is not None:
790
798
 
791
799
    def _write_protocol_version(self):
792
800
        """Write any prefixes this protocol requires.
793
 
        
 
801
 
794
802
        Version one doesn't send protocol versions.
795
803
        """
796
804
 
797
805
 
798
806
class SmartClientRequestProtocolTwo(SmartClientRequestProtocolOne):
799
807
    """Version two of the client side of the smart protocol.
800
 
    
 
808
 
801
809
    This prefixes the request with the value of REQUEST_VERSION_TWO.
802
810
    """
803
811
 
831
839
 
832
840
    def _write_protocol_version(self):
833
841
        """Write any prefixes this protocol requires.
834
 
        
 
842
 
835
843
        Version two sends the value of REQUEST_VERSION_TWO.
836
844
        """
837
845
        self._request.accept_bytes(self.request_marker)
858
866
 
859
867
 
860
868
def build_server_protocol_three(backing_transport, write_func,
861
 
                                root_client_path):
 
869
                                root_client_path, jail_root=None):
862
870
    request_handler = request.SmartServerRequestHandler(
863
871
        backing_transport, commands=request.request_handlers,
864
 
        root_client_path=root_client_path)
 
872
        root_client_path=root_client_path, jail_root=jail_root)
865
873
    responder = ProtocolThreeResponder(write_func)
866
874
    message_handler = message.ConventionalRequestHandler(request_handler, responder)
867
875
    return ProtocolThreeDecoder(message_handler)
897
905
            # We do *not* set self.decoding_failed here.  The message handler
898
906
            # has raised an error, but the decoder is still able to parse bytes
899
907
            # and determine when this message ends.
900
 
            log_exception_quietly()
 
908
            if not isinstance(exception.exc_value, errors.UnknownSmartMethod):
 
909
                log_exception_quietly()
901
910
            self.message_handler.protocol_error(exception.exc_value)
902
911
            # The state machine is ready to continue decoding, but the
903
912
            # exception has interrupted the loop that runs the state machine.
985
994
            self.message_handler.headers_received(decoded)
986
995
        except:
987
996
            raise errors.SmartMessageHandlerError(sys.exc_info())
988
 
    
 
997
 
989
998
    def _state_accept_expecting_message_part(self):
990
999
        message_part_kind = self._extract_single_byte()
991
1000
        if message_part_kind == 'o':
1036
1045
            raise errors.SmartMessageHandlerError(sys.exc_info())
1037
1046
 
1038
1047
    def _state_accept_reading_unused(self):
1039
 
        self.unused_data = self._get_in_buffer()
 
1048
        self.unused_data += self._get_in_buffer()
1040
1049
        self._set_in_buffer(None)
1041
1050
 
1042
1051
    def next_read_size(self):
1058
1067
class _ProtocolThreeEncoder(object):
1059
1068
 
1060
1069
    response_marker = request_marker = MESSAGE_VERSION_THREE
 
1070
    BUFFER_SIZE = 1024*1024 # 1 MiB buffer before flushing
1061
1071
 
1062
1072
    def __init__(self, write_func):
1063
 
        self._buf = ''
 
1073
        self._buf = []
 
1074
        self._buf_len = 0
1064
1075
        self._real_write_func = write_func
1065
1076
 
1066
1077
    def _write_func(self, bytes):
1067
 
        self._buf += bytes
 
1078
        # TODO: It is probably more appropriate to use sum(map(len, _buf))
 
1079
        #       for total number of bytes to write, rather than buffer based on
 
1080
        #       the number of write() calls
 
1081
        # TODO: Another possibility would be to turn this into an async model.
 
1082
        #       Where we let another thread know that we have some bytes if
 
1083
        #       they want it, but we don't actually block for it
 
1084
        #       Note that osutils.send_all always sends 64kB chunks anyway, so
 
1085
        #       we might just push out smaller bits at a time?
 
1086
        self._buf.append(bytes)
 
1087
        self._buf_len += len(bytes)
 
1088
        if self._buf_len > self.BUFFER_SIZE:
 
1089
            self.flush()
1068
1090
 
1069
1091
    def flush(self):
1070
1092
        if self._buf:
1071
 
            self._real_write_func(self._buf)
1072
 
            self._buf = ''
 
1093
            self._real_write_func(''.join(self._buf))
 
1094
            del self._buf[:]
 
1095
            self._buf_len = 0
1073
1096
 
1074
1097
    def _serialise_offsets(self, offsets):
1075
1098
        """Serialise a readv offset list."""
1077
1100
        for start, length in offsets:
1078
1101
            txt.append('%d,%d' % (start, length))
1079
1102
        return '\n'.join(txt)
1080
 
        
 
1103
 
1081
1104
    def _write_protocol_version(self):
1082
1105
        self._write_func(MESSAGE_VERSION_THREE)
1083
1106
 
1124
1147
        _ProtocolThreeEncoder.__init__(self, write_func)
1125
1148
        self.response_sent = False
1126
1149
        self._headers = {'Software version': bzrlib.__version__}
 
1150
        if 'hpss' in debug.debug_flags:
 
1151
            self._thread_id = thread.get_ident()
 
1152
            self._response_start_time = None
 
1153
 
 
1154
    def _trace(self, action, message, extra_bytes=None, include_time=False):
 
1155
        if self._response_start_time is None:
 
1156
            self._response_start_time = osutils.timer_func()
 
1157
        if include_time:
 
1158
            t = '%5.3fs ' % (time.clock() - self._response_start_time)
 
1159
        else:
 
1160
            t = ''
 
1161
        if extra_bytes is None:
 
1162
            extra = ''
 
1163
        else:
 
1164
            extra = ' ' + repr(extra_bytes[:40])
 
1165
            if len(extra) > 33:
 
1166
                extra = extra[:29] + extra[-1] + '...'
 
1167
        mutter('%12s: [%s] %s%s%s'
 
1168
               % (action, self._thread_id, t, message, extra))
1127
1169
 
1128
1170
    def send_error(self, exception):
1129
1171
        if self.response_sent:
1135
1177
                ('UnknownMethod', exception.verb))
1136
1178
            self.send_response(failure)
1137
1179
            return
 
1180
        if 'hpss' in debug.debug_flags:
 
1181
            self._trace('error', str(exception))
1138
1182
        self.response_sent = True
1139
1183
        self._write_protocol_version()
1140
1184
        self._write_headers(self._headers)
1154
1198
            self._write_success_status()
1155
1199
        else:
1156
1200
            self._write_error_status()
 
1201
        if 'hpss' in debug.debug_flags:
 
1202
            self._trace('response', repr(response.args))
1157
1203
        self._write_structure(response.args)
1158
1204
        if response.body is not None:
1159
1205
            self._write_prefixed_body(response.body)
 
1206
            if 'hpss' in debug.debug_flags:
 
1207
                self._trace('body', '%d bytes' % (len(response.body),),
 
1208
                            response.body, include_time=True)
1160
1209
        elif response.body_stream is not None:
1161
 
            for chunk in response.body_stream:
1162
 
                self._write_prefixed_body(chunk)
1163
 
                self.flush()
 
1210
            count = num_bytes = 0
 
1211
            first_chunk = None
 
1212
            for exc_info, chunk in _iter_with_errors(response.body_stream):
 
1213
                count += 1
 
1214
                if exc_info is not None:
 
1215
                    self._write_error_status()
 
1216
                    error_struct = request._translate_error(exc_info[1])
 
1217
                    self._write_structure(error_struct)
 
1218
                    break
 
1219
                else:
 
1220
                    if isinstance(chunk, request.FailedSmartServerResponse):
 
1221
                        self._write_error_status()
 
1222
                        self._write_structure(chunk.args)
 
1223
                        break
 
1224
                    num_bytes += len(chunk)
 
1225
                    if first_chunk is None:
 
1226
                        first_chunk = chunk
 
1227
                    self._write_prefixed_body(chunk)
 
1228
                    if 'hpssdetail' in debug.debug_flags:
 
1229
                        # Not worth timing separately, as _write_func is
 
1230
                        # actually buffered
 
1231
                        self._trace('body chunk',
 
1232
                                    '%d bytes' % (len(chunk),),
 
1233
                                    chunk, suppress_time=True)
 
1234
            if 'hpss' in debug.debug_flags:
 
1235
                self._trace('body stream',
 
1236
                            '%d bytes %d chunks' % (num_bytes, count),
 
1237
                            first_chunk)
1164
1238
        self._write_end()
1165
 
        
 
1239
        if 'hpss' in debug.debug_flags:
 
1240
            self._trace('response end', '', include_time=True)
 
1241
 
 
1242
 
 
1243
def _iter_with_errors(iterable):
 
1244
    """Handle errors from iterable.next().
 
1245
 
 
1246
    Use like::
 
1247
 
 
1248
        for exc_info, value in _iter_with_errors(iterable):
 
1249
            ...
 
1250
 
 
1251
    This is a safer alternative to::
 
1252
 
 
1253
        try:
 
1254
            for value in iterable:
 
1255
               ...
 
1256
        except:
 
1257
            ...
 
1258
 
 
1259
    Because the latter will catch errors from the for-loop body, not just
 
1260
    iterable.next()
 
1261
 
 
1262
    If an error occurs, exc_info will be a exc_info tuple, and the generator
 
1263
    will terminate.  Otherwise exc_info will be None, and value will be the
 
1264
    value from iterable.next().  Note that KeyboardInterrupt and SystemExit
 
1265
    will not be itercepted.
 
1266
    """
 
1267
    iterator = iter(iterable)
 
1268
    while True:
 
1269
        try:
 
1270
            yield None, iterator.next()
 
1271
        except StopIteration:
 
1272
            return
 
1273
        except (KeyboardInterrupt, SystemExit):
 
1274
            raise
 
1275
        except Exception:
 
1276
            mutter('_iter_with_errors caught error')
 
1277
            log_exception_quietly()
 
1278
            yield sys.exc_info(), None
 
1279
            return
 
1280
 
1166
1281
 
1167
1282
class ProtocolThreeRequester(_ProtocolThreeEncoder, Requester):
1168
1283
 
1173
1288
 
1174
1289
    def set_headers(self, headers):
1175
1290
        self._headers = headers.copy()
1176
 
        
 
1291
 
1177
1292
    def call(self, *args):
1178
1293
        if 'hpss' in debug.debug_flags:
1179
1294
            mutter('hpss call:   %s', repr(args)[1:-1])
1180
1295
            base = getattr(self._medium_request._medium, 'base', None)
1181
1296
            if base is not None:
1182
1297
                mutter('             (to %s)', base)
1183
 
            self._request_start_time = time.time()
 
1298
            self._request_start_time = osutils.timer_func()
1184
1299
        self._write_protocol_version()
1185
1300
        self._write_headers(self._headers)
1186
1301
        self._write_structure(args)
1198
1313
            if path is not None:
1199
1314
                mutter('                  (to %s)', path)
1200
1315
            mutter('              %d bytes', len(body))
1201
 
            self._request_start_time = time.time()
 
1316
            self._request_start_time = osutils.timer_func()
1202
1317
        self._write_protocol_version()
1203
1318
        self._write_headers(self._headers)
1204
1319
        self._write_structure(args)
1217
1332
            path = getattr(self._medium_request._medium, '_path', None)
1218
1333
            if path is not None:
1219
1334
                mutter('                  (to %s)', path)
1220
 
            self._request_start_time = time.time()
 
1335
            self._request_start_time = osutils.timer_func()
1221
1336
        self._write_protocol_version()
1222
1337
        self._write_headers(self._headers)
1223
1338
        self._write_structure(args)
1234
1349
            path = getattr(self._medium_request._medium, '_path', None)
1235
1350
            if path is not None:
1236
1351
                mutter('                  (to %s)', path)
1237
 
            self._request_start_time = time.time()
 
1352
            self._request_start_time = osutils.timer_func()
1238
1353
        self._write_protocol_version()
1239
1354
        self._write_headers(self._headers)
1240
1355
        self._write_structure(args)
1242
1357
        #       have finished sending the stream.  We would notice at the end
1243
1358
        #       anyway, but if the medium can deliver it early then it's good
1244
1359
        #       to short-circuit the whole request...
1245
 
        try:
1246
 
            for part in stream:
 
1360
        for exc_info, part in _iter_with_errors(stream):
 
1361
            if exc_info is not None:
 
1362
                # Iterating the stream failed.  Cleanly abort the request.
 
1363
                self._write_error_status()
 
1364
                # Currently the client unconditionally sends ('error',) as the
 
1365
                # error args.
 
1366
                self._write_structure(('error',))
 
1367
                self._write_end()
 
1368
                self._medium_request.finished_writing()
 
1369
                raise exc_info[0], exc_info[1], exc_info[2]
 
1370
            else:
1247
1371
                self._write_prefixed_body(part)
1248
1372
                self.flush()
1249
 
        except Exception:
1250
 
            # Iterating the stream failed.  Cleanly abort the request.
1251
 
            self._write_error_status()
1252
 
            # Currently the client unconditionally sends ('error',) as the
1253
 
            # error args.
1254
 
            self._write_structure(('error',))
1255
1373
        self._write_end()
1256
1374
        self._medium_request.finished_writing()
1257
1375