~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/smart/protocol.py

  • Committer: Canonical.com Patch Queue Manager
  • Date: 2010-03-18 09:13:28 UTC
  • mfrom: (5096.1.1 integration)
  • Revision ID: pqm@pqm.ubuntu.com-20100318091328-8fo347hq4at1usky
(vila) Get better feedback about why
        TestGetFileMTime.test_get_file_mtime is failing

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2006, 2007 Canonical Ltd
 
1
# Copyright (C) 2006-2010 Canonical Ltd
2
2
#
3
3
# This program is free software; you can redistribute it and/or modify
4
4
# it under the terms of the GNU General Public License as published by
12
12
#
13
13
# You should have received a copy of the GNU General Public License
14
14
# along with this program; if not, write to the Free Software
15
 
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 
15
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
16
16
 
17
17
"""Wire-level encoding and decoding of requests and responses for the smart
18
18
client and server.
22
22
from cStringIO import StringIO
23
23
import struct
24
24
import sys
 
25
import thread
 
26
import threading
25
27
import time
26
28
 
27
29
import bzrlib
28
 
from bzrlib import debug
29
 
from bzrlib import errors
 
30
from bzrlib import (
 
31
    debug,
 
32
    errors,
 
33
    osutils,
 
34
    )
30
35
from bzrlib.smart import message, request
31
36
from bzrlib.trace import log_exception_quietly, mutter
32
 
from bzrlib.util.bencode import bdecode, bencode
 
37
from bzrlib.bencode import bdecode_as_tuple, bencode
33
38
 
34
39
 
35
40
# Protocol version strings.  These are sent as prefixes of bzr requests and
109
114
        for start, length in offsets:
110
115
            txt.append('%d,%d' % (start, length))
111
116
        return '\n'.join(txt)
112
 
        
 
117
 
113
118
 
114
119
class SmartServerRequestProtocolOne(SmartProtocolBase):
115
120
    """Server-side encoding and decoding logic for smart version 1."""
116
 
    
117
 
    def __init__(self, backing_transport, write_func, root_client_path='/'):
 
121
 
 
122
    def __init__(self, backing_transport, write_func, root_client_path='/',
 
123
            jail_root=None):
118
124
        self._backing_transport = backing_transport
119
125
        self._root_client_path = root_client_path
 
126
        self._jail_root = jail_root
120
127
        self.unused_data = ''
121
128
        self._finished = False
122
129
        self.in_buffer = ''
127
134
 
128
135
    def accept_bytes(self, bytes):
129
136
        """Take bytes, and advance the internal state machine appropriately.
130
 
        
 
137
 
131
138
        :param bytes: must be a byte string
132
139
        """
133
140
        if not isinstance(bytes, str):
144
151
                req_args = _decode_tuple(first_line)
145
152
                self.request = request.SmartServerRequestHandler(
146
153
                    self._backing_transport, commands=request.request_handlers,
147
 
                    root_client_path=self._root_client_path)
148
 
                self.request.dispatch_command(req_args[0], req_args[1:])
 
154
                    root_client_path=self._root_client_path,
 
155
                    jail_root=self._jail_root)
 
156
                self.request.args_received(req_args)
149
157
                if self.request.finished_reading:
150
158
                    # trivial request
151
159
                    self.unused_data = self.in_buffer
169
177
 
170
178
        if self._has_dispatched:
171
179
            if self._finished:
172
 
                # nothing to do.XXX: this routine should be a single state 
 
180
                # nothing to do.XXX: this routine should be a single state
173
181
                # machine too.
174
182
                self.unused_data += self.in_buffer
175
183
                self.in_buffer = ''
211
219
 
212
220
    def _write_protocol_version(self):
213
221
        """Write any prefixes this protocol requires.
214
 
        
 
222
 
215
223
        Version one doesn't send protocol versions.
216
224
        """
217
225
 
234
242
 
235
243
class SmartServerRequestProtocolTwo(SmartServerRequestProtocolOne):
236
244
    r"""Version two of the server side of the smart protocol.
237
 
   
 
245
 
238
246
    This prefixes responses with the value of RESPONSE_VERSION_TWO.
239
247
    """
240
248
 
250
258
 
251
259
    def _write_protocol_version(self):
252
260
        r"""Write any prefixes this protocol requires.
253
 
        
 
261
 
254
262
        Version two sends the value of RESPONSE_VERSION_TWO.
255
263
        """
256
264
        self._write_func(self.response_marker)
412
420
        self.chunks = collections.deque()
413
421
        self.error = False
414
422
        self.error_in_progress = None
415
 
    
 
423
 
416
424
    def next_read_size(self):
417
425
        # Note: the shortest possible chunk is 2 bytes: '0\n', and the
418
426
        # end-of-body marker is 4 bytes: 'END\n'.
456
464
 
457
465
    def _finished(self):
458
466
        self.unused_data = self._get_in_buffer()
459
 
        # self._in_buffer = None
460
467
        self._in_buffer_list = []
461
468
        self._in_buffer_len = 0
462
469
        self.state_accept = self._state_accept_reading_unused
507
514
                self.chunks.append(self.chunk_in_progress)
508
515
            self.chunk_in_progress = None
509
516
            self.state_accept = self._state_accept_expecting_length
510
 
        
 
517
 
511
518
    def _state_accept_reading_unused(self):
512
519
        self.unused_data += self._get_in_buffer()
513
520
        self._in_buffer_list = []
515
522
 
516
523
class LengthPrefixedBodyDecoder(_StatefulDecoder):
517
524
    """Decodes the length-prefixed bulk data."""
518
 
    
 
525
 
519
526
    def __init__(self):
520
527
        _StatefulDecoder.__init__(self)
521
528
        self.state_accept = self._state_accept_expecting_length
522
529
        self.state_read = self._state_read_no_data
523
530
        self._body = ''
524
531
        self._trailer_buffer = ''
525
 
    
 
532
 
526
533
    def next_read_size(self):
527
534
        if self.bytes_left is not None:
528
535
            # Ideally we want to read all the remainder of the body and the
538
545
        else:
539
546
            # Reading excess data.  Either way, 1 byte at a time is fine.
540
547
            return 1
541
 
        
 
548
 
542
549
    def read_pending_data(self):
543
550
        """Return any pending data that has been decoded."""
544
551
        return self.state_read()
565
572
                self._body = self._body[:self.bytes_left]
566
573
            self.bytes_left = None
567
574
            self.state_accept = self._state_accept_reading_trailer
568
 
        
 
575
 
569
576
    def _state_accept_reading_trailer(self):
570
577
        self._trailer_buffer += self._get_in_buffer()
571
578
        self._set_in_buffer(None)
575
582
            self.unused_data = self._trailer_buffer[len('done\n'):]
576
583
            self.state_accept = self._state_accept_reading_unused
577
584
            self.finished_reading = True
578
 
    
 
585
 
579
586
    def _state_accept_reading_unused(self):
580
587
        self.unused_data += self._get_in_buffer()
581
588
        self._set_in_buffer(None)
613
620
            mutter('hpss call:   %s', repr(args)[1:-1])
614
621
            if getattr(self._request._medium, 'base', None) is not None:
615
622
                mutter('             (to %s)', self._request._medium.base)
616
 
            self._request_start_time = time.time()
 
623
            self._request_start_time = osutils.timer_func()
617
624
        self._write_args(args)
618
625
        self._request.finished_writing()
619
626
        self._last_verb = args[0]
628
635
            if getattr(self._request._medium, '_path', None) is not None:
629
636
                mutter('                  (to %s)', self._request._medium._path)
630
637
            mutter('              %d bytes', len(body))
631
 
            self._request_start_time = time.time()
 
638
            self._request_start_time = osutils.timer_func()
632
639
            if 'hpssdetail' in debug.debug_flags:
633
640
                mutter('hpss body content: %s', body)
634
641
        self._write_args(args)
647
654
            mutter('hpss call w/readv: %s', repr(args)[1:-1])
648
655
            if getattr(self._request._medium, '_path', None) is not None:
649
656
                mutter('                  (to %s)', self._request._medium._path)
650
 
            self._request_start_time = time.time()
 
657
            self._request_start_time = osutils.timer_func()
651
658
        self._write_args(args)
652
659
        readv_bytes = self._serialise_offsets(body)
653
660
        bytes = self._encode_bulk_data(readv_bytes)
657
664
            mutter('              %d bytes in readv request', len(readv_bytes))
658
665
        self._last_verb = args[0]
659
666
 
 
667
    def call_with_body_stream(self, args, stream):
 
668
        # Protocols v1 and v2 don't support body streams.  So it's safe to
 
669
        # assume that a v1/v2 server doesn't support whatever method we're
 
670
        # trying to call with a body stream.
 
671
        self._request.finished_writing()
 
672
        self._request.finished_reading()
 
673
        raise errors.UnknownSmartMethod(args[0])
 
674
 
660
675
    def cancel_read_body(self):
661
676
        """After expecting a body, a response code may indicate one otherwise.
662
677
 
671
686
        if 'hpss' in debug.debug_flags:
672
687
            if self._request_start_time is not None:
673
688
                mutter('   result:   %6.3fs  %s',
674
 
                       time.time() - self._request_start_time,
 
689
                       osutils.timer_func() - self._request_start_time,
675
690
                       repr(result)[1:-1])
676
691
                self._request_start_time = None
677
692
            else:
722
737
    def _response_is_unknown_method(self, result_tuple):
723
738
        """Raise UnexpectedSmartServerResponse if the response is an 'unknonwn
724
739
        method' response to the request.
725
 
        
 
740
 
726
741
        :param response: The response from a smart client call_expecting_body
727
742
            call.
728
743
        :param verb: The verb used in that call.
735
750
            # The response will have no body, so we've finished reading.
736
751
            self._request.finished_reading()
737
752
            raise errors.UnknownSmartMethod(self._last_verb)
738
 
        
 
753
 
739
754
    def read_body_bytes(self, count=-1):
740
755
        """Read bytes from the body, decoding into a byte stream.
741
 
        
742
 
        We read all bytes at once to ensure we've checked the trailer for 
 
756
 
 
757
        We read all bytes at once to ensure we've checked the trailer for
743
758
        errors, and then feed the buffer back as read_body_bytes is called.
744
759
        """
745
760
        if self._body_buffer is not None:
783
798
 
784
799
    def _write_protocol_version(self):
785
800
        """Write any prefixes this protocol requires.
786
 
        
 
801
 
787
802
        Version one doesn't send protocol versions.
788
803
        """
789
804
 
790
805
 
791
806
class SmartClientRequestProtocolTwo(SmartClientRequestProtocolOne):
792
807
    """Version two of the client side of the smart protocol.
793
 
    
 
808
 
794
809
    This prefixes the request with the value of REQUEST_VERSION_TWO.
795
810
    """
796
811
 
824
839
 
825
840
    def _write_protocol_version(self):
826
841
        """Write any prefixes this protocol requires.
827
 
        
 
842
 
828
843
        Version two sends the value of REQUEST_VERSION_TWO.
829
844
        """
830
845
        self._request.accept_bytes(self.request_marker)
851
866
 
852
867
 
853
868
def build_server_protocol_three(backing_transport, write_func,
854
 
                                root_client_path):
 
869
                                root_client_path, jail_root=None):
855
870
    request_handler = request.SmartServerRequestHandler(
856
871
        backing_transport, commands=request.request_handlers,
857
 
        root_client_path=root_client_path)
 
872
        root_client_path=root_client_path, jail_root=jail_root)
858
873
    responder = ProtocolThreeResponder(write_func)
859
874
    message_handler = message.ConventionalRequestHandler(request_handler, responder)
860
875
    return ProtocolThreeDecoder(message_handler)
890
905
            # We do *not* set self.decoding_failed here.  The message handler
891
906
            # has raised an error, but the decoder is still able to parse bytes
892
907
            # and determine when this message ends.
893
 
            log_exception_quietly()
 
908
            if not isinstance(exception.exc_value, errors.UnknownSmartMethod):
 
909
                log_exception_quietly()
894
910
            self.message_handler.protocol_error(exception.exc_value)
895
911
            # The state machine is ready to continue decoding, but the
896
912
            # exception has interrupted the loop that runs the state machine.
932
948
    def _extract_prefixed_bencoded_data(self):
933
949
        prefixed_bytes = self._extract_length_prefixed_bytes()
934
950
        try:
935
 
            decoded = bdecode(prefixed_bytes)
 
951
            decoded = bdecode_as_tuple(prefixed_bytes)
936
952
        except ValueError:
937
953
            raise errors.SmartProtocolError(
938
954
                'Bytes %r not bencoded' % (prefixed_bytes,))
978
994
            self.message_handler.headers_received(decoded)
979
995
        except:
980
996
            raise errors.SmartMessageHandlerError(sys.exc_info())
981
 
    
 
997
 
982
998
    def _state_accept_expecting_message_part(self):
983
999
        message_part_kind = self._extract_single_byte()
984
1000
        if message_part_kind == 'o':
1029
1045
            raise errors.SmartMessageHandlerError(sys.exc_info())
1030
1046
 
1031
1047
    def _state_accept_reading_unused(self):
1032
 
        self.unused_data = self._get_in_buffer()
 
1048
        self.unused_data += self._get_in_buffer()
1033
1049
        self._set_in_buffer(None)
1034
1050
 
1035
1051
    def next_read_size(self):
1051
1067
class _ProtocolThreeEncoder(object):
1052
1068
 
1053
1069
    response_marker = request_marker = MESSAGE_VERSION_THREE
 
1070
    BUFFER_SIZE = 1024*1024 # 1 MiB buffer before flushing
1054
1071
 
1055
1072
    def __init__(self, write_func):
1056
 
        self._buf = ''
 
1073
        self._buf = []
 
1074
        self._buf_len = 0
1057
1075
        self._real_write_func = write_func
1058
1076
 
1059
1077
    def _write_func(self, bytes):
1060
 
        self._buf += bytes
 
1078
        # TODO: It is probably more appropriate to use sum(map(len, _buf))
 
1079
        #       for total number of bytes to write, rather than buffer based on
 
1080
        #       the number of write() calls
 
1081
        # TODO: Another possibility would be to turn this into an async model.
 
1082
        #       Where we let another thread know that we have some bytes if
 
1083
        #       they want it, but we don't actually block for it
 
1084
        #       Note that osutils.send_all always sends 64kB chunks anyway, so
 
1085
        #       we might just push out smaller bits at a time?
 
1086
        self._buf.append(bytes)
 
1087
        self._buf_len += len(bytes)
 
1088
        if self._buf_len > self.BUFFER_SIZE:
 
1089
            self.flush()
1061
1090
 
1062
1091
    def flush(self):
1063
1092
        if self._buf:
1064
 
            self._real_write_func(self._buf)
1065
 
            self._buf = ''
 
1093
            self._real_write_func(''.join(self._buf))
 
1094
            del self._buf[:]
 
1095
            self._buf_len = 0
1066
1096
 
1067
1097
    def _serialise_offsets(self, offsets):
1068
1098
        """Serialise a readv offset list."""
1070
1100
        for start, length in offsets:
1071
1101
            txt.append('%d,%d' % (start, length))
1072
1102
        return '\n'.join(txt)
1073
 
        
 
1103
 
1074
1104
    def _write_protocol_version(self):
1075
1105
        self._write_func(MESSAGE_VERSION_THREE)
1076
1106
 
1101
1131
        self._write_func(struct.pack('!L', len(bytes)))
1102
1132
        self._write_func(bytes)
1103
1133
 
 
1134
    def _write_chunked_body_start(self):
 
1135
        self._write_func('oC')
 
1136
 
1104
1137
    def _write_error_status(self):
1105
1138
        self._write_func('oE')
1106
1139
 
1114
1147
        _ProtocolThreeEncoder.__init__(self, write_func)
1115
1148
        self.response_sent = False
1116
1149
        self._headers = {'Software version': bzrlib.__version__}
 
1150
        if 'hpss' in debug.debug_flags:
 
1151
            self._thread_id = thread.get_ident()
 
1152
            self._response_start_time = None
 
1153
 
 
1154
    def _trace(self, action, message, extra_bytes=None, include_time=False):
 
1155
        if self._response_start_time is None:
 
1156
            self._response_start_time = osutils.timer_func()
 
1157
        if include_time:
 
1158
            t = '%5.3fs ' % (time.clock() - self._response_start_time)
 
1159
        else:
 
1160
            t = ''
 
1161
        if extra_bytes is None:
 
1162
            extra = ''
 
1163
        else:
 
1164
            extra = ' ' + repr(extra_bytes[:40])
 
1165
            if len(extra) > 33:
 
1166
                extra = extra[:29] + extra[-1] + '...'
 
1167
        mutter('%12s: [%s] %s%s%s'
 
1168
               % (action, self._thread_id, t, message, extra))
1117
1169
 
1118
1170
    def send_error(self, exception):
1119
1171
        if self.response_sent:
1125
1177
                ('UnknownMethod', exception.verb))
1126
1178
            self.send_response(failure)
1127
1179
            return
 
1180
        if 'hpss' in debug.debug_flags:
 
1181
            self._trace('error', str(exception))
1128
1182
        self.response_sent = True
1129
1183
        self._write_protocol_version()
1130
1184
        self._write_headers(self._headers)
1144
1198
            self._write_success_status()
1145
1199
        else:
1146
1200
            self._write_error_status()
 
1201
        if 'hpss' in debug.debug_flags:
 
1202
            self._trace('response', repr(response.args))
1147
1203
        self._write_structure(response.args)
1148
1204
        if response.body is not None:
1149
1205
            self._write_prefixed_body(response.body)
 
1206
            if 'hpss' in debug.debug_flags:
 
1207
                self._trace('body', '%d bytes' % (len(response.body),),
 
1208
                            response.body, include_time=True)
1150
1209
        elif response.body_stream is not None:
1151
 
            for chunk in response.body_stream:
1152
 
                self._write_prefixed_body(chunk)
1153
 
                self.flush()
 
1210
            count = num_bytes = 0
 
1211
            first_chunk = None
 
1212
            for exc_info, chunk in _iter_with_errors(response.body_stream):
 
1213
                count += 1
 
1214
                if exc_info is not None:
 
1215
                    self._write_error_status()
 
1216
                    error_struct = request._translate_error(exc_info[1])
 
1217
                    self._write_structure(error_struct)
 
1218
                    break
 
1219
                else:
 
1220
                    if isinstance(chunk, request.FailedSmartServerResponse):
 
1221
                        self._write_error_status()
 
1222
                        self._write_structure(chunk.args)
 
1223
                        break
 
1224
                    num_bytes += len(chunk)
 
1225
                    if first_chunk is None:
 
1226
                        first_chunk = chunk
 
1227
                    self._write_prefixed_body(chunk)
 
1228
                    if 'hpssdetail' in debug.debug_flags:
 
1229
                        # Not worth timing separately, as _write_func is
 
1230
                        # actually buffered
 
1231
                        self._trace('body chunk',
 
1232
                                    '%d bytes' % (len(chunk),),
 
1233
                                    chunk, suppress_time=True)
 
1234
            if 'hpss' in debug.debug_flags:
 
1235
                self._trace('body stream',
 
1236
                            '%d bytes %d chunks' % (num_bytes, count),
 
1237
                            first_chunk)
1154
1238
        self._write_end()
1155
 
        
 
1239
        if 'hpss' in debug.debug_flags:
 
1240
            self._trace('response end', '', include_time=True)
 
1241
 
 
1242
 
 
1243
def _iter_with_errors(iterable):
 
1244
    """Handle errors from iterable.next().
 
1245
 
 
1246
    Use like::
 
1247
 
 
1248
        for exc_info, value in _iter_with_errors(iterable):
 
1249
            ...
 
1250
 
 
1251
    This is a safer alternative to::
 
1252
 
 
1253
        try:
 
1254
            for value in iterable:
 
1255
               ...
 
1256
        except:
 
1257
            ...
 
1258
 
 
1259
    Because the latter will catch errors from the for-loop body, not just
 
1260
    iterable.next()
 
1261
 
 
1262
    If an error occurs, exc_info will be a exc_info tuple, and the generator
 
1263
    will terminate.  Otherwise exc_info will be None, and value will be the
 
1264
    value from iterable.next().  Note that KeyboardInterrupt and SystemExit
 
1265
    will not be itercepted.
 
1266
    """
 
1267
    iterator = iter(iterable)
 
1268
    while True:
 
1269
        try:
 
1270
            yield None, iterator.next()
 
1271
        except StopIteration:
 
1272
            return
 
1273
        except (KeyboardInterrupt, SystemExit):
 
1274
            raise
 
1275
        except Exception:
 
1276
            mutter('_iter_with_errors caught error')
 
1277
            log_exception_quietly()
 
1278
            yield sys.exc_info(), None
 
1279
            return
 
1280
 
1156
1281
 
1157
1282
class ProtocolThreeRequester(_ProtocolThreeEncoder, Requester):
1158
1283
 
1163
1288
 
1164
1289
    def set_headers(self, headers):
1165
1290
        self._headers = headers.copy()
1166
 
        
 
1291
 
1167
1292
    def call(self, *args):
1168
1293
        if 'hpss' in debug.debug_flags:
1169
1294
            mutter('hpss call:   %s', repr(args)[1:-1])
1170
1295
            base = getattr(self._medium_request._medium, 'base', None)
1171
1296
            if base is not None:
1172
1297
                mutter('             (to %s)', base)
1173
 
            self._request_start_time = time.time()
 
1298
            self._request_start_time = osutils.timer_func()
1174
1299
        self._write_protocol_version()
1175
1300
        self._write_headers(self._headers)
1176
1301
        self._write_structure(args)
1188
1313
            if path is not None:
1189
1314
                mutter('                  (to %s)', path)
1190
1315
            mutter('              %d bytes', len(body))
1191
 
            self._request_start_time = time.time()
 
1316
            self._request_start_time = osutils.timer_func()
1192
1317
        self._write_protocol_version()
1193
1318
        self._write_headers(self._headers)
1194
1319
        self._write_structure(args)
1207
1332
            path = getattr(self._medium_request._medium, '_path', None)
1208
1333
            if path is not None:
1209
1334
                mutter('                  (to %s)', path)
1210
 
            self._request_start_time = time.time()
 
1335
            self._request_start_time = osutils.timer_func()
1211
1336
        self._write_protocol_version()
1212
1337
        self._write_headers(self._headers)
1213
1338
        self._write_structure(args)
1218
1343
        self._write_end()
1219
1344
        self._medium_request.finished_writing()
1220
1345
 
 
1346
    def call_with_body_stream(self, args, stream):
 
1347
        if 'hpss' in debug.debug_flags:
 
1348
            mutter('hpss call w/body stream: %r', args)
 
1349
            path = getattr(self._medium_request._medium, '_path', None)
 
1350
            if path is not None:
 
1351
                mutter('                  (to %s)', path)
 
1352
            self._request_start_time = osutils.timer_func()
 
1353
        self._write_protocol_version()
 
1354
        self._write_headers(self._headers)
 
1355
        self._write_structure(args)
 
1356
        # TODO: notice if the server has sent an early error reply before we
 
1357
        #       have finished sending the stream.  We would notice at the end
 
1358
        #       anyway, but if the medium can deliver it early then it's good
 
1359
        #       to short-circuit the whole request...
 
1360
        for exc_info, part in _iter_with_errors(stream):
 
1361
            if exc_info is not None:
 
1362
                # Iterating the stream failed.  Cleanly abort the request.
 
1363
                self._write_error_status()
 
1364
                # Currently the client unconditionally sends ('error',) as the
 
1365
                # error args.
 
1366
                self._write_structure(('error',))
 
1367
                self._write_end()
 
1368
                self._medium_request.finished_writing()
 
1369
                raise exc_info[0], exc_info[1], exc_info[2]
 
1370
            else:
 
1371
                self._write_prefixed_body(part)
 
1372
                self.flush()
 
1373
        self._write_end()
 
1374
        self._medium_request.finished_writing()
 
1375