~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/smart/protocol.py

(jameinel) Allow 'bzr serve' to interpret SIGHUP as a graceful shutdown.
 (bug #795025) (John A Meinel)

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2006, 2007 Canonical Ltd
 
1
# Copyright (C) 2006-2010 Canonical Ltd
2
2
#
3
3
# This program is free software; you can redistribute it and/or modify
4
4
# it under the terms of the GNU General Public License as published by
12
12
#
13
13
# You should have received a copy of the GNU General Public License
14
14
# along with this program; if not, write to the Free Software
15
 
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 
15
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
16
16
 
17
17
"""Wire-level encoding and decoding of requests and responses for the smart
18
18
client and server.
22
22
from cStringIO import StringIO
23
23
import struct
24
24
import sys
 
25
import thread
 
26
import threading
25
27
import time
26
28
 
27
29
import bzrlib
28
 
from bzrlib import debug
29
 
from bzrlib import errors
 
30
from bzrlib import (
 
31
    debug,
 
32
    errors,
 
33
    osutils,
 
34
    )
30
35
from bzrlib.smart import message, request
31
36
from bzrlib.trace import log_exception_quietly, mutter
32
 
from bzrlib.util.bencode import bdecode_as_tuple, bencode
 
37
from bzrlib.bencode import bdecode_as_tuple, bencode
33
38
 
34
39
 
35
40
# Protocol version strings.  These are sent as prefixes of bzr requests and
57
62
 
58
63
def _encode_tuple(args):
59
64
    """Encode the tuple args to a bytestream."""
60
 
    return '\x01'.join(args) + '\n'
 
65
    joined = '\x01'.join(args) + '\n'
 
66
    if type(joined) is unicode:
 
67
        # XXX: We should fix things so this never happens!  -AJB, 20100304
 
68
        mutter('response args contain unicode, should be only bytes: %r',
 
69
               joined)
 
70
        joined = joined.encode('ascii')
 
71
    return joined
61
72
 
62
73
 
63
74
class Requester(object):
109
120
        for start, length in offsets:
110
121
            txt.append('%d,%d' % (start, length))
111
122
        return '\n'.join(txt)
112
 
        
 
123
 
113
124
 
114
125
class SmartServerRequestProtocolOne(SmartProtocolBase):
115
126
    """Server-side encoding and decoding logic for smart version 1."""
116
 
    
117
 
    def __init__(self, backing_transport, write_func, root_client_path='/'):
 
127
 
 
128
    def __init__(self, backing_transport, write_func, root_client_path='/',
 
129
            jail_root=None):
118
130
        self._backing_transport = backing_transport
119
131
        self._root_client_path = root_client_path
 
132
        self._jail_root = jail_root
120
133
        self.unused_data = ''
121
134
        self._finished = False
122
135
        self.in_buffer = ''
127
140
 
128
141
    def accept_bytes(self, bytes):
129
142
        """Take bytes, and advance the internal state machine appropriately.
130
 
        
 
143
 
131
144
        :param bytes: must be a byte string
132
145
        """
133
146
        if not isinstance(bytes, str):
144
157
                req_args = _decode_tuple(first_line)
145
158
                self.request = request.SmartServerRequestHandler(
146
159
                    self._backing_transport, commands=request.request_handlers,
147
 
                    root_client_path=self._root_client_path)
148
 
                self.request.dispatch_command(req_args[0], req_args[1:])
 
160
                    root_client_path=self._root_client_path,
 
161
                    jail_root=self._jail_root)
 
162
                self.request.args_received(req_args)
149
163
                if self.request.finished_reading:
150
164
                    # trivial request
151
165
                    self.unused_data = self.in_buffer
169
183
 
170
184
        if self._has_dispatched:
171
185
            if self._finished:
172
 
                # nothing to do.XXX: this routine should be a single state 
 
186
                # nothing to do.XXX: this routine should be a single state
173
187
                # machine too.
174
188
                self.unused_data += self.in_buffer
175
189
                self.in_buffer = ''
211
225
 
212
226
    def _write_protocol_version(self):
213
227
        """Write any prefixes this protocol requires.
214
 
        
 
228
 
215
229
        Version one doesn't send protocol versions.
216
230
        """
217
231
 
234
248
 
235
249
class SmartServerRequestProtocolTwo(SmartServerRequestProtocolOne):
236
250
    r"""Version two of the server side of the smart protocol.
237
 
   
 
251
 
238
252
    This prefixes responses with the value of RESPONSE_VERSION_TWO.
239
253
    """
240
254
 
250
264
 
251
265
    def _write_protocol_version(self):
252
266
        r"""Write any prefixes this protocol requires.
253
 
        
 
267
 
254
268
        Version two sends the value of RESPONSE_VERSION_TWO.
255
269
        """
256
270
        self._write_func(self.response_marker)
412
426
        self.chunks = collections.deque()
413
427
        self.error = False
414
428
        self.error_in_progress = None
415
 
    
 
429
 
416
430
    def next_read_size(self):
417
431
        # Note: the shortest possible chunk is 2 bytes: '0\n', and the
418
432
        # end-of-body marker is 4 bytes: 'END\n'.
506
520
                self.chunks.append(self.chunk_in_progress)
507
521
            self.chunk_in_progress = None
508
522
            self.state_accept = self._state_accept_expecting_length
509
 
        
 
523
 
510
524
    def _state_accept_reading_unused(self):
511
525
        self.unused_data += self._get_in_buffer()
512
526
        self._in_buffer_list = []
514
528
 
515
529
class LengthPrefixedBodyDecoder(_StatefulDecoder):
516
530
    """Decodes the length-prefixed bulk data."""
517
 
    
 
531
 
518
532
    def __init__(self):
519
533
        _StatefulDecoder.__init__(self)
520
534
        self.state_accept = self._state_accept_expecting_length
521
535
        self.state_read = self._state_read_no_data
522
536
        self._body = ''
523
537
        self._trailer_buffer = ''
524
 
    
 
538
 
525
539
    def next_read_size(self):
526
540
        if self.bytes_left is not None:
527
541
            # Ideally we want to read all the remainder of the body and the
537
551
        else:
538
552
            # Reading excess data.  Either way, 1 byte at a time is fine.
539
553
            return 1
540
 
        
 
554
 
541
555
    def read_pending_data(self):
542
556
        """Return any pending data that has been decoded."""
543
557
        return self.state_read()
564
578
                self._body = self._body[:self.bytes_left]
565
579
            self.bytes_left = None
566
580
            self.state_accept = self._state_accept_reading_trailer
567
 
        
 
581
 
568
582
    def _state_accept_reading_trailer(self):
569
583
        self._trailer_buffer += self._get_in_buffer()
570
584
        self._set_in_buffer(None)
574
588
            self.unused_data = self._trailer_buffer[len('done\n'):]
575
589
            self.state_accept = self._state_accept_reading_unused
576
590
            self.finished_reading = True
577
 
    
 
591
 
578
592
    def _state_accept_reading_unused(self):
579
593
        self.unused_data += self._get_in_buffer()
580
594
        self._set_in_buffer(None)
612
626
            mutter('hpss call:   %s', repr(args)[1:-1])
613
627
            if getattr(self._request._medium, 'base', None) is not None:
614
628
                mutter('             (to %s)', self._request._medium.base)
615
 
            self._request_start_time = time.time()
 
629
            self._request_start_time = osutils.timer_func()
616
630
        self._write_args(args)
617
631
        self._request.finished_writing()
618
632
        self._last_verb = args[0]
627
641
            if getattr(self._request._medium, '_path', None) is not None:
628
642
                mutter('                  (to %s)', self._request._medium._path)
629
643
            mutter('              %d bytes', len(body))
630
 
            self._request_start_time = time.time()
 
644
            self._request_start_time = osutils.timer_func()
631
645
            if 'hpssdetail' in debug.debug_flags:
632
646
                mutter('hpss body content: %s', body)
633
647
        self._write_args(args)
640
654
        """Make a remote call with a readv array.
641
655
 
642
656
        The body is encoded with one line per readv offset pair. The numbers in
643
 
        each pair are separated by a comma, and no trailing \n is emitted.
 
657
        each pair are separated by a comma, and no trailing \\n is emitted.
644
658
        """
645
659
        if 'hpss' in debug.debug_flags:
646
660
            mutter('hpss call w/readv: %s', repr(args)[1:-1])
647
661
            if getattr(self._request._medium, '_path', None) is not None:
648
662
                mutter('                  (to %s)', self._request._medium._path)
649
 
            self._request_start_time = time.time()
 
663
            self._request_start_time = osutils.timer_func()
650
664
        self._write_args(args)
651
665
        readv_bytes = self._serialise_offsets(body)
652
666
        bytes = self._encode_bulk_data(readv_bytes)
655
669
        if 'hpss' in debug.debug_flags:
656
670
            mutter('              %d bytes in readv request', len(readv_bytes))
657
671
        self._last_verb = args[0]
658
 
    
 
672
 
659
673
    def call_with_body_stream(self, args, stream):
660
674
        # Protocols v1 and v2 don't support body streams.  So it's safe to
661
675
        # assume that a v1/v2 server doesn't support whatever method we're
678
692
        if 'hpss' in debug.debug_flags:
679
693
            if self._request_start_time is not None:
680
694
                mutter('   result:   %6.3fs  %s',
681
 
                       time.time() - self._request_start_time,
 
695
                       osutils.timer_func() - self._request_start_time,
682
696
                       repr(result)[1:-1])
683
697
                self._request_start_time = None
684
698
            else:
729
743
    def _response_is_unknown_method(self, result_tuple):
730
744
        """Raise UnexpectedSmartServerResponse if the response is an 'unknonwn
731
745
        method' response to the request.
732
 
        
 
746
 
733
747
        :param response: The response from a smart client call_expecting_body
734
748
            call.
735
749
        :param verb: The verb used in that call.
742
756
            # The response will have no body, so we've finished reading.
743
757
            self._request.finished_reading()
744
758
            raise errors.UnknownSmartMethod(self._last_verb)
745
 
        
 
759
 
746
760
    def read_body_bytes(self, count=-1):
747
761
        """Read bytes from the body, decoding into a byte stream.
748
 
        
749
 
        We read all bytes at once to ensure we've checked the trailer for 
 
762
 
 
763
        We read all bytes at once to ensure we've checked the trailer for
750
764
        errors, and then feed the buffer back as read_body_bytes is called.
751
765
        """
752
766
        if self._body_buffer is not None:
790
804
 
791
805
    def _write_protocol_version(self):
792
806
        """Write any prefixes this protocol requires.
793
 
        
 
807
 
794
808
        Version one doesn't send protocol versions.
795
809
        """
796
810
 
797
811
 
798
812
class SmartClientRequestProtocolTwo(SmartClientRequestProtocolOne):
799
813
    """Version two of the client side of the smart protocol.
800
 
    
 
814
 
801
815
    This prefixes the request with the value of REQUEST_VERSION_TWO.
802
816
    """
803
817
 
831
845
 
832
846
    def _write_protocol_version(self):
833
847
        """Write any prefixes this protocol requires.
834
 
        
 
848
 
835
849
        Version two sends the value of REQUEST_VERSION_TWO.
836
850
        """
837
851
        self._request.accept_bytes(self.request_marker)
858
872
 
859
873
 
860
874
def build_server_protocol_three(backing_transport, write_func,
861
 
                                root_client_path):
 
875
                                root_client_path, jail_root=None):
862
876
    request_handler = request.SmartServerRequestHandler(
863
877
        backing_transport, commands=request.request_handlers,
864
 
        root_client_path=root_client_path)
 
878
        root_client_path=root_client_path, jail_root=jail_root)
865
879
    responder = ProtocolThreeResponder(write_func)
866
880
    message_handler = message.ConventionalRequestHandler(request_handler, responder)
867
881
    return ProtocolThreeDecoder(message_handler)
897
911
            # We do *not* set self.decoding_failed here.  The message handler
898
912
            # has raised an error, but the decoder is still able to parse bytes
899
913
            # and determine when this message ends.
900
 
            log_exception_quietly()
 
914
            if not isinstance(exception.exc_value, errors.UnknownSmartMethod):
 
915
                log_exception_quietly()
901
916
            self.message_handler.protocol_error(exception.exc_value)
902
917
            # The state machine is ready to continue decoding, but the
903
918
            # exception has interrupted the loop that runs the state machine.
985
1000
            self.message_handler.headers_received(decoded)
986
1001
        except:
987
1002
            raise errors.SmartMessageHandlerError(sys.exc_info())
988
 
    
 
1003
 
989
1004
    def _state_accept_expecting_message_part(self):
990
1005
        message_part_kind = self._extract_single_byte()
991
1006
        if message_part_kind == 'o':
1036
1051
            raise errors.SmartMessageHandlerError(sys.exc_info())
1037
1052
 
1038
1053
    def _state_accept_reading_unused(self):
1039
 
        self.unused_data = self._get_in_buffer()
 
1054
        self.unused_data += self._get_in_buffer()
1040
1055
        self._set_in_buffer(None)
1041
1056
 
1042
1057
    def next_read_size(self):
1058
1073
class _ProtocolThreeEncoder(object):
1059
1074
 
1060
1075
    response_marker = request_marker = MESSAGE_VERSION_THREE
 
1076
    BUFFER_SIZE = 1024*1024 # 1 MiB buffer before flushing
1061
1077
 
1062
1078
    def __init__(self, write_func):
1063
 
        self._buf = ''
 
1079
        self._buf = []
 
1080
        self._buf_len = 0
1064
1081
        self._real_write_func = write_func
1065
1082
 
1066
1083
    def _write_func(self, bytes):
1067
 
        self._buf += bytes
 
1084
        # TODO: It is probably more appropriate to use sum(map(len, _buf))
 
1085
        #       for total number of bytes to write, rather than buffer based on
 
1086
        #       the number of write() calls
 
1087
        # TODO: Another possibility would be to turn this into an async model.
 
1088
        #       Where we let another thread know that we have some bytes if
 
1089
        #       they want it, but we don't actually block for it
 
1090
        #       Note that osutils.send_all always sends 64kB chunks anyway, so
 
1091
        #       we might just push out smaller bits at a time?
 
1092
        self._buf.append(bytes)
 
1093
        self._buf_len += len(bytes)
 
1094
        if self._buf_len > self.BUFFER_SIZE:
 
1095
            self.flush()
1068
1096
 
1069
1097
    def flush(self):
1070
1098
        if self._buf:
1071
 
            self._real_write_func(self._buf)
1072
 
            self._buf = ''
 
1099
            self._real_write_func(''.join(self._buf))
 
1100
            del self._buf[:]
 
1101
            self._buf_len = 0
1073
1102
 
1074
1103
    def _serialise_offsets(self, offsets):
1075
1104
        """Serialise a readv offset list."""
1077
1106
        for start, length in offsets:
1078
1107
            txt.append('%d,%d' % (start, length))
1079
1108
        return '\n'.join(txt)
1080
 
        
 
1109
 
1081
1110
    def _write_protocol_version(self):
1082
1111
        self._write_func(MESSAGE_VERSION_THREE)
1083
1112
 
1124
1153
        _ProtocolThreeEncoder.__init__(self, write_func)
1125
1154
        self.response_sent = False
1126
1155
        self._headers = {'Software version': bzrlib.__version__}
 
1156
        if 'hpss' in debug.debug_flags:
 
1157
            self._thread_id = thread.get_ident()
 
1158
            self._response_start_time = None
 
1159
 
 
1160
    def _trace(self, action, message, extra_bytes=None, include_time=False):
 
1161
        if self._response_start_time is None:
 
1162
            self._response_start_time = osutils.timer_func()
 
1163
        if include_time:
 
1164
            t = '%5.3fs ' % (time.clock() - self._response_start_time)
 
1165
        else:
 
1166
            t = ''
 
1167
        if extra_bytes is None:
 
1168
            extra = ''
 
1169
        else:
 
1170
            extra = ' ' + repr(extra_bytes[:40])
 
1171
            if len(extra) > 33:
 
1172
                extra = extra[:29] + extra[-1] + '...'
 
1173
        mutter('%12s: [%s] %s%s%s'
 
1174
               % (action, self._thread_id, t, message, extra))
1127
1175
 
1128
1176
    def send_error(self, exception):
1129
1177
        if self.response_sent:
1135
1183
                ('UnknownMethod', exception.verb))
1136
1184
            self.send_response(failure)
1137
1185
            return
 
1186
        if 'hpss' in debug.debug_flags:
 
1187
            self._trace('error', str(exception))
1138
1188
        self.response_sent = True
1139
1189
        self._write_protocol_version()
1140
1190
        self._write_headers(self._headers)
1154
1204
            self._write_success_status()
1155
1205
        else:
1156
1206
            self._write_error_status()
 
1207
        if 'hpss' in debug.debug_flags:
 
1208
            self._trace('response', repr(response.args))
1157
1209
        self._write_structure(response.args)
1158
1210
        if response.body is not None:
1159
1211
            self._write_prefixed_body(response.body)
 
1212
            if 'hpss' in debug.debug_flags:
 
1213
                self._trace('body', '%d bytes' % (len(response.body),),
 
1214
                            response.body, include_time=True)
1160
1215
        elif response.body_stream is not None:
1161
 
            for chunk in response.body_stream:
1162
 
                self._write_prefixed_body(chunk)
1163
 
                self.flush()
 
1216
            count = num_bytes = 0
 
1217
            first_chunk = None
 
1218
            for exc_info, chunk in _iter_with_errors(response.body_stream):
 
1219
                count += 1
 
1220
                if exc_info is not None:
 
1221
                    self._write_error_status()
 
1222
                    error_struct = request._translate_error(exc_info[1])
 
1223
                    self._write_structure(error_struct)
 
1224
                    break
 
1225
                else:
 
1226
                    if isinstance(chunk, request.FailedSmartServerResponse):
 
1227
                        self._write_error_status()
 
1228
                        self._write_structure(chunk.args)
 
1229
                        break
 
1230
                    num_bytes += len(chunk)
 
1231
                    if first_chunk is None:
 
1232
                        first_chunk = chunk
 
1233
                    self._write_prefixed_body(chunk)
 
1234
                    self.flush()
 
1235
                    if 'hpssdetail' in debug.debug_flags:
 
1236
                        # Not worth timing separately, as _write_func is
 
1237
                        # actually buffered
 
1238
                        self._trace('body chunk',
 
1239
                                    '%d bytes' % (len(chunk),),
 
1240
                                    chunk, suppress_time=True)
 
1241
            if 'hpss' in debug.debug_flags:
 
1242
                self._trace('body stream',
 
1243
                            '%d bytes %d chunks' % (num_bytes, count),
 
1244
                            first_chunk)
1164
1245
        self._write_end()
1165
 
        
 
1246
        if 'hpss' in debug.debug_flags:
 
1247
            self._trace('response end', '', include_time=True)
 
1248
 
 
1249
 
 
1250
def _iter_with_errors(iterable):
 
1251
    """Handle errors from iterable.next().
 
1252
 
 
1253
    Use like::
 
1254
 
 
1255
        for exc_info, value in _iter_with_errors(iterable):
 
1256
            ...
 
1257
 
 
1258
    This is a safer alternative to::
 
1259
 
 
1260
        try:
 
1261
            for value in iterable:
 
1262
               ...
 
1263
        except:
 
1264
            ...
 
1265
 
 
1266
    Because the latter will catch errors from the for-loop body, not just
 
1267
    iterable.next()
 
1268
 
 
1269
    If an error occurs, exc_info will be a exc_info tuple, and the generator
 
1270
    will terminate.  Otherwise exc_info will be None, and value will be the
 
1271
    value from iterable.next().  Note that KeyboardInterrupt and SystemExit
 
1272
    will not be itercepted.
 
1273
    """
 
1274
    iterator = iter(iterable)
 
1275
    while True:
 
1276
        try:
 
1277
            yield None, iterator.next()
 
1278
        except StopIteration:
 
1279
            return
 
1280
        except (KeyboardInterrupt, SystemExit):
 
1281
            raise
 
1282
        except Exception:
 
1283
            mutter('_iter_with_errors caught error')
 
1284
            log_exception_quietly()
 
1285
            yield sys.exc_info(), None
 
1286
            return
 
1287
 
1166
1288
 
1167
1289
class ProtocolThreeRequester(_ProtocolThreeEncoder, Requester):
1168
1290
 
1173
1295
 
1174
1296
    def set_headers(self, headers):
1175
1297
        self._headers = headers.copy()
1176
 
        
 
1298
 
1177
1299
    def call(self, *args):
1178
1300
        if 'hpss' in debug.debug_flags:
1179
1301
            mutter('hpss call:   %s', repr(args)[1:-1])
1180
1302
            base = getattr(self._medium_request._medium, 'base', None)
1181
1303
            if base is not None:
1182
1304
                mutter('             (to %s)', base)
1183
 
            self._request_start_time = time.time()
 
1305
            self._request_start_time = osutils.timer_func()
1184
1306
        self._write_protocol_version()
1185
1307
        self._write_headers(self._headers)
1186
1308
        self._write_structure(args)
1198
1320
            if path is not None:
1199
1321
                mutter('                  (to %s)', path)
1200
1322
            mutter('              %d bytes', len(body))
1201
 
            self._request_start_time = time.time()
 
1323
            self._request_start_time = osutils.timer_func()
1202
1324
        self._write_protocol_version()
1203
1325
        self._write_headers(self._headers)
1204
1326
        self._write_structure(args)
1210
1332
        """Make a remote call with a readv array.
1211
1333
 
1212
1334
        The body is encoded with one line per readv offset pair. The numbers in
1213
 
        each pair are separated by a comma, and no trailing \n is emitted.
 
1335
        each pair are separated by a comma, and no trailing \\n is emitted.
1214
1336
        """
1215
1337
        if 'hpss' in debug.debug_flags:
1216
1338
            mutter('hpss call w/readv: %s', repr(args)[1:-1])
1217
1339
            path = getattr(self._medium_request._medium, '_path', None)
1218
1340
            if path is not None:
1219
1341
                mutter('                  (to %s)', path)
1220
 
            self._request_start_time = time.time()
 
1342
            self._request_start_time = osutils.timer_func()
1221
1343
        self._write_protocol_version()
1222
1344
        self._write_headers(self._headers)
1223
1345
        self._write_structure(args)
1234
1356
            path = getattr(self._medium_request._medium, '_path', None)
1235
1357
            if path is not None:
1236
1358
                mutter('                  (to %s)', path)
1237
 
            self._request_start_time = time.time()
 
1359
            self._request_start_time = osutils.timer_func()
1238
1360
        self._write_protocol_version()
1239
1361
        self._write_headers(self._headers)
1240
1362
        self._write_structure(args)
1242
1364
        #       have finished sending the stream.  We would notice at the end
1243
1365
        #       anyway, but if the medium can deliver it early then it's good
1244
1366
        #       to short-circuit the whole request...
1245
 
        try:
1246
 
            for part in stream:
 
1367
        for exc_info, part in _iter_with_errors(stream):
 
1368
            if exc_info is not None:
 
1369
                # Iterating the stream failed.  Cleanly abort the request.
 
1370
                self._write_error_status()
 
1371
                # Currently the client unconditionally sends ('error',) as the
 
1372
                # error args.
 
1373
                self._write_structure(('error',))
 
1374
                self._write_end()
 
1375
                self._medium_request.finished_writing()
 
1376
                raise exc_info[0], exc_info[1], exc_info[2]
 
1377
            else:
1247
1378
                self._write_prefixed_body(part)
1248
1379
                self.flush()
1249
 
        except Exception:
1250
 
            # Iterating the stream failed.  Cleanly abort the request.
1251
 
            self._write_error_status()
1252
 
            # Currently the client unconditionally sends ('error',) as the
1253
 
            # error args.
1254
 
            self._write_structure(('error',))
1255
1380
        self._write_end()
1256
1381
        self._medium_request.finished_writing()
1257
1382