~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/smart/protocol.py

  • Committer: Martin Pool
  • Date: 2008-10-20 23:58:12 UTC
  • mto: This revision was merged to the branch mainline in revision 3787.
  • Revision ID: mbp@sourcefrog.net-20081020235812-itg90mk0u4dez92z
lp-upload-release now handles names like bzr-1.8.tar.gz

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2006, 2007, 2008, 2009 Canonical Ltd
 
1
# Copyright (C) 2006, 2007 Canonical Ltd
2
2
#
3
3
# This program is free software; you can redistribute it and/or modify
4
4
# it under the terms of the GNU General Public License as published by
12
12
#
13
13
# You should have received a copy of the GNU General Public License
14
14
# along with this program; if not, write to the Free Software
15
 
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 
15
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
16
16
 
17
17
"""Wire-level encoding and decoding of requests and responses for the smart
18
18
client and server.
22
22
from cStringIO import StringIO
23
23
import struct
24
24
import sys
25
 
import thread
26
 
import threading
27
25
import time
28
26
 
29
27
import bzrlib
30
 
from bzrlib import (
31
 
    debug,
32
 
    errors,
33
 
    osutils,
34
 
    )
 
28
from bzrlib import debug
 
29
from bzrlib import errors
35
30
from bzrlib.smart import message, request
36
31
from bzrlib.trace import log_exception_quietly, mutter
37
 
from bzrlib.bencode import bdecode_as_tuple, bencode
 
32
from bzrlib.util.bencode import bdecode, bencode
38
33
 
39
34
 
40
35
# Protocol version strings.  These are sent as prefixes of bzr requests and
114
109
        for start, length in offsets:
115
110
            txt.append('%d,%d' % (start, length))
116
111
        return '\n'.join(txt)
117
 
 
 
112
        
118
113
 
119
114
class SmartServerRequestProtocolOne(SmartProtocolBase):
120
115
    """Server-side encoding and decoding logic for smart version 1."""
121
 
 
122
 
    def __init__(self, backing_transport, write_func, root_client_path='/',
123
 
            jail_root=None):
 
116
    
 
117
    def __init__(self, backing_transport, write_func, root_client_path='/'):
124
118
        self._backing_transport = backing_transport
125
119
        self._root_client_path = root_client_path
126
 
        self._jail_root = jail_root
127
120
        self.unused_data = ''
128
121
        self._finished = False
129
122
        self.in_buffer = ''
134
127
 
135
128
    def accept_bytes(self, bytes):
136
129
        """Take bytes, and advance the internal state machine appropriately.
137
 
 
 
130
        
138
131
        :param bytes: must be a byte string
139
132
        """
140
133
        if not isinstance(bytes, str):
151
144
                req_args = _decode_tuple(first_line)
152
145
                self.request = request.SmartServerRequestHandler(
153
146
                    self._backing_transport, commands=request.request_handlers,
154
 
                    root_client_path=self._root_client_path,
155
 
                    jail_root=self._jail_root)
156
 
                self.request.args_received(req_args)
 
147
                    root_client_path=self._root_client_path)
 
148
                self.request.dispatch_command(req_args[0], req_args[1:])
157
149
                if self.request.finished_reading:
158
150
                    # trivial request
159
151
                    self.unused_data = self.in_buffer
177
169
 
178
170
        if self._has_dispatched:
179
171
            if self._finished:
180
 
                # nothing to do.XXX: this routine should be a single state
 
172
                # nothing to do.XXX: this routine should be a single state 
181
173
                # machine too.
182
174
                self.unused_data += self.in_buffer
183
175
                self.in_buffer = ''
219
211
 
220
212
    def _write_protocol_version(self):
221
213
        """Write any prefixes this protocol requires.
222
 
 
 
214
        
223
215
        Version one doesn't send protocol versions.
224
216
        """
225
217
 
242
234
 
243
235
class SmartServerRequestProtocolTwo(SmartServerRequestProtocolOne):
244
236
    r"""Version two of the server side of the smart protocol.
245
 
 
 
237
   
246
238
    This prefixes responses with the value of RESPONSE_VERSION_TWO.
247
239
    """
248
240
 
258
250
 
259
251
    def _write_protocol_version(self):
260
252
        r"""Write any prefixes this protocol requires.
261
 
 
 
253
        
262
254
        Version two sends the value of RESPONSE_VERSION_TWO.
263
255
        """
264
256
        self._write_func(self.response_marker)
420
412
        self.chunks = collections.deque()
421
413
        self.error = False
422
414
        self.error_in_progress = None
423
 
 
 
415
    
424
416
    def next_read_size(self):
425
417
        # Note: the shortest possible chunk is 2 bytes: '0\n', and the
426
418
        # end-of-body marker is 4 bytes: 'END\n'.
514
506
                self.chunks.append(self.chunk_in_progress)
515
507
            self.chunk_in_progress = None
516
508
            self.state_accept = self._state_accept_expecting_length
517
 
 
 
509
        
518
510
    def _state_accept_reading_unused(self):
519
511
        self.unused_data += self._get_in_buffer()
520
512
        self._in_buffer_list = []
522
514
 
523
515
class LengthPrefixedBodyDecoder(_StatefulDecoder):
524
516
    """Decodes the length-prefixed bulk data."""
525
 
 
 
517
    
526
518
    def __init__(self):
527
519
        _StatefulDecoder.__init__(self)
528
520
        self.state_accept = self._state_accept_expecting_length
529
521
        self.state_read = self._state_read_no_data
530
522
        self._body = ''
531
523
        self._trailer_buffer = ''
532
 
 
 
524
    
533
525
    def next_read_size(self):
534
526
        if self.bytes_left is not None:
535
527
            # Ideally we want to read all the remainder of the body and the
545
537
        else:
546
538
            # Reading excess data.  Either way, 1 byte at a time is fine.
547
539
            return 1
548
 
 
 
540
        
549
541
    def read_pending_data(self):
550
542
        """Return any pending data that has been decoded."""
551
543
        return self.state_read()
572
564
                self._body = self._body[:self.bytes_left]
573
565
            self.bytes_left = None
574
566
            self.state_accept = self._state_accept_reading_trailer
575
 
 
 
567
        
576
568
    def _state_accept_reading_trailer(self):
577
569
        self._trailer_buffer += self._get_in_buffer()
578
570
        self._set_in_buffer(None)
582
574
            self.unused_data = self._trailer_buffer[len('done\n'):]
583
575
            self.state_accept = self._state_accept_reading_unused
584
576
            self.finished_reading = True
585
 
 
 
577
    
586
578
    def _state_accept_reading_unused(self):
587
579
        self.unused_data += self._get_in_buffer()
588
580
        self._set_in_buffer(None)
620
612
            mutter('hpss call:   %s', repr(args)[1:-1])
621
613
            if getattr(self._request._medium, 'base', None) is not None:
622
614
                mutter('             (to %s)', self._request._medium.base)
623
 
            self._request_start_time = osutils.timer_func()
 
615
            self._request_start_time = time.time()
624
616
        self._write_args(args)
625
617
        self._request.finished_writing()
626
618
        self._last_verb = args[0]
635
627
            if getattr(self._request._medium, '_path', None) is not None:
636
628
                mutter('                  (to %s)', self._request._medium._path)
637
629
            mutter('              %d bytes', len(body))
638
 
            self._request_start_time = osutils.timer_func()
 
630
            self._request_start_time = time.time()
639
631
            if 'hpssdetail' in debug.debug_flags:
640
632
                mutter('hpss body content: %s', body)
641
633
        self._write_args(args)
654
646
            mutter('hpss call w/readv: %s', repr(args)[1:-1])
655
647
            if getattr(self._request._medium, '_path', None) is not None:
656
648
                mutter('                  (to %s)', self._request._medium._path)
657
 
            self._request_start_time = osutils.timer_func()
 
649
            self._request_start_time = time.time()
658
650
        self._write_args(args)
659
651
        readv_bytes = self._serialise_offsets(body)
660
652
        bytes = self._encode_bulk_data(readv_bytes)
664
656
            mutter('              %d bytes in readv request', len(readv_bytes))
665
657
        self._last_verb = args[0]
666
658
 
667
 
    def call_with_body_stream(self, args, stream):
668
 
        # Protocols v1 and v2 don't support body streams.  So it's safe to
669
 
        # assume that a v1/v2 server doesn't support whatever method we're
670
 
        # trying to call with a body stream.
671
 
        self._request.finished_writing()
672
 
        self._request.finished_reading()
673
 
        raise errors.UnknownSmartMethod(args[0])
674
 
 
675
659
    def cancel_read_body(self):
676
660
        """After expecting a body, a response code may indicate one otherwise.
677
661
 
686
670
        if 'hpss' in debug.debug_flags:
687
671
            if self._request_start_time is not None:
688
672
                mutter('   result:   %6.3fs  %s',
689
 
                       osutils.timer_func() - self._request_start_time,
 
673
                       time.time() - self._request_start_time,
690
674
                       repr(result)[1:-1])
691
675
                self._request_start_time = None
692
676
            else:
737
721
    def _response_is_unknown_method(self, result_tuple):
738
722
        """Raise UnexpectedSmartServerResponse if the response is an 'unknonwn
739
723
        method' response to the request.
740
 
 
 
724
        
741
725
        :param response: The response from a smart client call_expecting_body
742
726
            call.
743
727
        :param verb: The verb used in that call.
750
734
            # The response will have no body, so we've finished reading.
751
735
            self._request.finished_reading()
752
736
            raise errors.UnknownSmartMethod(self._last_verb)
753
 
 
 
737
        
754
738
    def read_body_bytes(self, count=-1):
755
739
        """Read bytes from the body, decoding into a byte stream.
756
 
 
757
 
        We read all bytes at once to ensure we've checked the trailer for
 
740
        
 
741
        We read all bytes at once to ensure we've checked the trailer for 
758
742
        errors, and then feed the buffer back as read_body_bytes is called.
759
743
        """
760
744
        if self._body_buffer is not None:
798
782
 
799
783
    def _write_protocol_version(self):
800
784
        """Write any prefixes this protocol requires.
801
 
 
 
785
        
802
786
        Version one doesn't send protocol versions.
803
787
        """
804
788
 
805
789
 
806
790
class SmartClientRequestProtocolTwo(SmartClientRequestProtocolOne):
807
791
    """Version two of the client side of the smart protocol.
808
 
 
 
792
    
809
793
    This prefixes the request with the value of REQUEST_VERSION_TWO.
810
794
    """
811
795
 
839
823
 
840
824
    def _write_protocol_version(self):
841
825
        """Write any prefixes this protocol requires.
842
 
 
 
826
        
843
827
        Version two sends the value of REQUEST_VERSION_TWO.
844
828
        """
845
829
        self._request.accept_bytes(self.request_marker)
866
850
 
867
851
 
868
852
def build_server_protocol_three(backing_transport, write_func,
869
 
                                root_client_path, jail_root=None):
 
853
                                root_client_path):
870
854
    request_handler = request.SmartServerRequestHandler(
871
855
        backing_transport, commands=request.request_handlers,
872
 
        root_client_path=root_client_path, jail_root=jail_root)
 
856
        root_client_path=root_client_path)
873
857
    responder = ProtocolThreeResponder(write_func)
874
858
    message_handler = message.ConventionalRequestHandler(request_handler, responder)
875
859
    return ProtocolThreeDecoder(message_handler)
905
889
            # We do *not* set self.decoding_failed here.  The message handler
906
890
            # has raised an error, but the decoder is still able to parse bytes
907
891
            # and determine when this message ends.
908
 
            if not isinstance(exception.exc_value, errors.UnknownSmartMethod):
909
 
                log_exception_quietly()
 
892
            log_exception_quietly()
910
893
            self.message_handler.protocol_error(exception.exc_value)
911
894
            # The state machine is ready to continue decoding, but the
912
895
            # exception has interrupted the loop that runs the state machine.
948
931
    def _extract_prefixed_bencoded_data(self):
949
932
        prefixed_bytes = self._extract_length_prefixed_bytes()
950
933
        try:
951
 
            decoded = bdecode_as_tuple(prefixed_bytes)
 
934
            decoded = bdecode(prefixed_bytes)
952
935
        except ValueError:
953
936
            raise errors.SmartProtocolError(
954
937
                'Bytes %r not bencoded' % (prefixed_bytes,))
994
977
            self.message_handler.headers_received(decoded)
995
978
        except:
996
979
            raise errors.SmartMessageHandlerError(sys.exc_info())
997
 
 
 
980
    
998
981
    def _state_accept_expecting_message_part(self):
999
982
        message_part_kind = self._extract_single_byte()
1000
983
        if message_part_kind == 'o':
1045
1028
            raise errors.SmartMessageHandlerError(sys.exc_info())
1046
1029
 
1047
1030
    def _state_accept_reading_unused(self):
1048
 
        self.unused_data += self._get_in_buffer()
 
1031
        self.unused_data = self._get_in_buffer()
1049
1032
        self._set_in_buffer(None)
1050
1033
 
1051
1034
    def next_read_size(self):
1067
1050
class _ProtocolThreeEncoder(object):
1068
1051
 
1069
1052
    response_marker = request_marker = MESSAGE_VERSION_THREE
1070
 
    BUFFER_SIZE = 1024*1024 # 1 MiB buffer before flushing
1071
1053
 
1072
1054
    def __init__(self, write_func):
1073
 
        self._buf = []
1074
 
        self._buf_len = 0
 
1055
        self._buf = ''
1075
1056
        self._real_write_func = write_func
1076
1057
 
1077
1058
    def _write_func(self, bytes):
1078
 
        # TODO: It is probably more appropriate to use sum(map(len, _buf))
1079
 
        #       for total number of bytes to write, rather than buffer based on
1080
 
        #       the number of write() calls
1081
 
        # TODO: Another possibility would be to turn this into an async model.
1082
 
        #       Where we let another thread know that we have some bytes if
1083
 
        #       they want it, but we don't actually block for it
1084
 
        #       Note that osutils.send_all always sends 64kB chunks anyway, so
1085
 
        #       we might just push out smaller bits at a time?
1086
 
        self._buf.append(bytes)
1087
 
        self._buf_len += len(bytes)
1088
 
        if self._buf_len > self.BUFFER_SIZE:
1089
 
            self.flush()
 
1059
        self._buf += bytes
1090
1060
 
1091
1061
    def flush(self):
1092
1062
        if self._buf:
1093
 
            self._real_write_func(''.join(self._buf))
1094
 
            del self._buf[:]
1095
 
            self._buf_len = 0
 
1063
            self._real_write_func(self._buf)
 
1064
            self._buf = ''
1096
1065
 
1097
1066
    def _serialise_offsets(self, offsets):
1098
1067
        """Serialise a readv offset list."""
1100
1069
        for start, length in offsets:
1101
1070
            txt.append('%d,%d' % (start, length))
1102
1071
        return '\n'.join(txt)
1103
 
 
 
1072
        
1104
1073
    def _write_protocol_version(self):
1105
1074
        self._write_func(MESSAGE_VERSION_THREE)
1106
1075
 
1131
1100
        self._write_func(struct.pack('!L', len(bytes)))
1132
1101
        self._write_func(bytes)
1133
1102
 
1134
 
    def _write_chunked_body_start(self):
1135
 
        self._write_func('oC')
1136
 
 
1137
1103
    def _write_error_status(self):
1138
1104
        self._write_func('oE')
1139
1105
 
1147
1113
        _ProtocolThreeEncoder.__init__(self, write_func)
1148
1114
        self.response_sent = False
1149
1115
        self._headers = {'Software version': bzrlib.__version__}
1150
 
        if 'hpss' in debug.debug_flags:
1151
 
            self._thread_id = thread.get_ident()
1152
 
            self._response_start_time = None
1153
 
 
1154
 
    def _trace(self, action, message, extra_bytes=None, include_time=False):
1155
 
        if self._response_start_time is None:
1156
 
            self._response_start_time = osutils.timer_func()
1157
 
        if include_time:
1158
 
            t = '%5.3fs ' % (time.clock() - self._response_start_time)
1159
 
        else:
1160
 
            t = ''
1161
 
        if extra_bytes is None:
1162
 
            extra = ''
1163
 
        else:
1164
 
            extra = ' ' + repr(extra_bytes[:40])
1165
 
            if len(extra) > 33:
1166
 
                extra = extra[:29] + extra[-1] + '...'
1167
 
        mutter('%12s: [%s] %s%s%s'
1168
 
               % (action, self._thread_id, t, message, extra))
1169
1116
 
1170
1117
    def send_error(self, exception):
1171
1118
        if self.response_sent:
1177
1124
                ('UnknownMethod', exception.verb))
1178
1125
            self.send_response(failure)
1179
1126
            return
1180
 
        if 'hpss' in debug.debug_flags:
1181
 
            self._trace('error', str(exception))
1182
1127
        self.response_sent = True
1183
1128
        self._write_protocol_version()
1184
1129
        self._write_headers(self._headers)
1198
1143
            self._write_success_status()
1199
1144
        else:
1200
1145
            self._write_error_status()
1201
 
        if 'hpss' in debug.debug_flags:
1202
 
            self._trace('response', repr(response.args))
1203
1146
        self._write_structure(response.args)
1204
1147
        if response.body is not None:
1205
1148
            self._write_prefixed_body(response.body)
1206
 
            if 'hpss' in debug.debug_flags:
1207
 
                self._trace('body', '%d bytes' % (len(response.body),),
1208
 
                            response.body, include_time=True)
1209
1149
        elif response.body_stream is not None:
1210
 
            count = num_bytes = 0
1211
 
            first_chunk = None
1212
 
            for exc_info, chunk in _iter_with_errors(response.body_stream):
1213
 
                count += 1
1214
 
                if exc_info is not None:
1215
 
                    self._write_error_status()
1216
 
                    error_struct = request._translate_error(exc_info[1])
1217
 
                    self._write_structure(error_struct)
1218
 
                    break
1219
 
                else:
1220
 
                    if isinstance(chunk, request.FailedSmartServerResponse):
1221
 
                        self._write_error_status()
1222
 
                        self._write_structure(chunk.args)
1223
 
                        break
1224
 
                    num_bytes += len(chunk)
1225
 
                    if first_chunk is None:
1226
 
                        first_chunk = chunk
1227
 
                    self._write_prefixed_body(chunk)
1228
 
                    if 'hpssdetail' in debug.debug_flags:
1229
 
                        # Not worth timing separately, as _write_func is
1230
 
                        # actually buffered
1231
 
                        self._trace('body chunk',
1232
 
                                    '%d bytes' % (len(chunk),),
1233
 
                                    chunk, suppress_time=True)
1234
 
            if 'hpss' in debug.debug_flags:
1235
 
                self._trace('body stream',
1236
 
                            '%d bytes %d chunks' % (num_bytes, count),
1237
 
                            first_chunk)
 
1150
            for chunk in response.body_stream:
 
1151
                self._write_prefixed_body(chunk)
 
1152
                self.flush()
1238
1153
        self._write_end()
1239
 
        if 'hpss' in debug.debug_flags:
1240
 
            self._trace('response end', '', include_time=True)
1241
 
 
1242
 
 
1243
 
def _iter_with_errors(iterable):
1244
 
    """Handle errors from iterable.next().
1245
 
 
1246
 
    Use like::
1247
 
 
1248
 
        for exc_info, value in _iter_with_errors(iterable):
1249
 
            ...
1250
 
 
1251
 
    This is a safer alternative to::
1252
 
 
1253
 
        try:
1254
 
            for value in iterable:
1255
 
               ...
1256
 
        except:
1257
 
            ...
1258
 
 
1259
 
    Because the latter will catch errors from the for-loop body, not just
1260
 
    iterable.next()
1261
 
 
1262
 
    If an error occurs, exc_info will be a exc_info tuple, and the generator
1263
 
    will terminate.  Otherwise exc_info will be None, and value will be the
1264
 
    value from iterable.next().  Note that KeyboardInterrupt and SystemExit
1265
 
    will not be itercepted.
1266
 
    """
1267
 
    iterator = iter(iterable)
1268
 
    while True:
1269
 
        try:
1270
 
            yield None, iterator.next()
1271
 
        except StopIteration:
1272
 
            return
1273
 
        except (KeyboardInterrupt, SystemExit):
1274
 
            raise
1275
 
        except Exception:
1276
 
            mutter('_iter_with_errors caught error')
1277
 
            log_exception_quietly()
1278
 
            yield sys.exc_info(), None
1279
 
            return
1280
 
 
 
1154
        
1281
1155
 
1282
1156
class ProtocolThreeRequester(_ProtocolThreeEncoder, Requester):
1283
1157
 
1288
1162
 
1289
1163
    def set_headers(self, headers):
1290
1164
        self._headers = headers.copy()
1291
 
 
 
1165
        
1292
1166
    def call(self, *args):
1293
1167
        if 'hpss' in debug.debug_flags:
1294
1168
            mutter('hpss call:   %s', repr(args)[1:-1])
1295
1169
            base = getattr(self._medium_request._medium, 'base', None)
1296
1170
            if base is not None:
1297
1171
                mutter('             (to %s)', base)
1298
 
            self._request_start_time = osutils.timer_func()
 
1172
            self._request_start_time = time.time()
1299
1173
        self._write_protocol_version()
1300
1174
        self._write_headers(self._headers)
1301
1175
        self._write_structure(args)
1313
1187
            if path is not None:
1314
1188
                mutter('                  (to %s)', path)
1315
1189
            mutter('              %d bytes', len(body))
1316
 
            self._request_start_time = osutils.timer_func()
 
1190
            self._request_start_time = time.time()
1317
1191
        self._write_protocol_version()
1318
1192
        self._write_headers(self._headers)
1319
1193
        self._write_structure(args)
1332
1206
            path = getattr(self._medium_request._medium, '_path', None)
1333
1207
            if path is not None:
1334
1208
                mutter('                  (to %s)', path)
1335
 
            self._request_start_time = osutils.timer_func()
 
1209
            self._request_start_time = time.time()
1336
1210
        self._write_protocol_version()
1337
1211
        self._write_headers(self._headers)
1338
1212
        self._write_structure(args)
1343
1217
        self._write_end()
1344
1218
        self._medium_request.finished_writing()
1345
1219
 
1346
 
    def call_with_body_stream(self, args, stream):
1347
 
        if 'hpss' in debug.debug_flags:
1348
 
            mutter('hpss call w/body stream: %r', args)
1349
 
            path = getattr(self._medium_request._medium, '_path', None)
1350
 
            if path is not None:
1351
 
                mutter('                  (to %s)', path)
1352
 
            self._request_start_time = osutils.timer_func()
1353
 
        self._write_protocol_version()
1354
 
        self._write_headers(self._headers)
1355
 
        self._write_structure(args)
1356
 
        # TODO: notice if the server has sent an early error reply before we
1357
 
        #       have finished sending the stream.  We would notice at the end
1358
 
        #       anyway, but if the medium can deliver it early then it's good
1359
 
        #       to short-circuit the whole request...
1360
 
        for exc_info, part in _iter_with_errors(stream):
1361
 
            if exc_info is not None:
1362
 
                # Iterating the stream failed.  Cleanly abort the request.
1363
 
                self._write_error_status()
1364
 
                # Currently the client unconditionally sends ('error',) as the
1365
 
                # error args.
1366
 
                self._write_structure(('error',))
1367
 
                self._write_end()
1368
 
                self._medium_request.finished_writing()
1369
 
                raise exc_info[0], exc_info[1], exc_info[2]
1370
 
            else:
1371
 
                self._write_prefixed_body(part)
1372
 
                self.flush()
1373
 
        self._write_end()
1374
 
        self._medium_request.finished_writing()
1375