~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/smart/protocol.py

  • Committer: Canonical.com Patch Queue Manager
  • Date: 2008-07-17 07:33:12 UTC
  • mfrom: (3530.3.3 btree-graphindex)
  • Revision ID: pqm@pqm.ubuntu.com-20080717073312-reglpowwyo671081
(robertc) Intern GraphIndex strings and handle frozenset inputs to
        make_mpdiffs in the case of errors. (Robert Collins, Andrew Bennetts)

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2006-2010 Canonical Ltd
 
1
# Copyright (C) 2006, 2007 Canonical Ltd
2
2
#
3
3
# This program is free software; you can redistribute it and/or modify
4
4
# it under the terms of the GNU General Public License as published by
12
12
#
13
13
# You should have received a copy of the GNU General Public License
14
14
# along with this program; if not, write to the Free Software
15
 
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 
15
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
16
16
 
17
17
"""Wire-level encoding and decoding of requests and responses for the smart
18
18
client and server.
22
22
from cStringIO import StringIO
23
23
import struct
24
24
import sys
25
 
import thread
26
 
import threading
27
25
import time
28
26
 
29
27
import bzrlib
30
 
from bzrlib import (
31
 
    debug,
32
 
    errors,
33
 
    osutils,
34
 
    )
 
28
from bzrlib import debug
 
29
from bzrlib import errors
35
30
from bzrlib.smart import message, request
36
31
from bzrlib.trace import log_exception_quietly, mutter
37
 
from bzrlib.bencode import bdecode_as_tuple, bencode
 
32
from bzrlib.util.bencode import bdecode, bencode
38
33
 
39
34
 
40
35
# Protocol version strings.  These are sent as prefixes of bzr requests and
62
57
 
63
58
def _encode_tuple(args):
64
59
    """Encode the tuple args to a bytestream."""
65
 
    joined = '\x01'.join(args) + '\n'
66
 
    if type(joined) is unicode:
67
 
        # XXX: We should fix things so this never happens!  -AJB, 20100304
68
 
        mutter('response args contain unicode, should be only bytes: %r',
69
 
               joined)
70
 
        joined = joined.encode('ascii')
71
 
    return joined
 
60
    return '\x01'.join(args) + '\n'
72
61
 
73
62
 
74
63
class Requester(object):
120
109
        for start, length in offsets:
121
110
            txt.append('%d,%d' % (start, length))
122
111
        return '\n'.join(txt)
123
 
 
 
112
        
124
113
 
125
114
class SmartServerRequestProtocolOne(SmartProtocolBase):
126
115
    """Server-side encoding and decoding logic for smart version 1."""
127
 
 
128
 
    def __init__(self, backing_transport, write_func, root_client_path='/',
129
 
            jail_root=None):
 
116
    
 
117
    def __init__(self, backing_transport, write_func, root_client_path='/'):
130
118
        self._backing_transport = backing_transport
131
119
        self._root_client_path = root_client_path
132
 
        self._jail_root = jail_root
133
120
        self.unused_data = ''
134
121
        self._finished = False
135
122
        self.in_buffer = ''
140
127
 
141
128
    def accept_bytes(self, bytes):
142
129
        """Take bytes, and advance the internal state machine appropriately.
143
 
 
 
130
        
144
131
        :param bytes: must be a byte string
145
132
        """
146
133
        if not isinstance(bytes, str):
157
144
                req_args = _decode_tuple(first_line)
158
145
                self.request = request.SmartServerRequestHandler(
159
146
                    self._backing_transport, commands=request.request_handlers,
160
 
                    root_client_path=self._root_client_path,
161
 
                    jail_root=self._jail_root)
162
 
                self.request.args_received(req_args)
 
147
                    root_client_path=self._root_client_path)
 
148
                self.request.dispatch_command(req_args[0], req_args[1:])
163
149
                if self.request.finished_reading:
164
150
                    # trivial request
165
151
                    self.unused_data = self.in_buffer
183
169
 
184
170
        if self._has_dispatched:
185
171
            if self._finished:
186
 
                # nothing to do.XXX: this routine should be a single state
 
172
                # nothing to do.XXX: this routine should be a single state 
187
173
                # machine too.
188
174
                self.unused_data += self.in_buffer
189
175
                self.in_buffer = ''
225
211
 
226
212
    def _write_protocol_version(self):
227
213
        """Write any prefixes this protocol requires.
228
 
 
 
214
        
229
215
        Version one doesn't send protocol versions.
230
216
        """
231
217
 
248
234
 
249
235
class SmartServerRequestProtocolTwo(SmartServerRequestProtocolOne):
250
236
    r"""Version two of the server side of the smart protocol.
251
 
 
 
237
   
252
238
    This prefixes responses with the value of RESPONSE_VERSION_TWO.
253
239
    """
254
240
 
264
250
 
265
251
    def _write_protocol_version(self):
266
252
        r"""Write any prefixes this protocol requires.
267
 
 
 
253
        
268
254
        Version two sends the value of RESPONSE_VERSION_TWO.
269
255
        """
270
256
        self._write_func(self.response_marker)
337
323
 
338
324
    def __init__(self):
339
325
        self.finished_reading = False
340
 
        self._in_buffer_list = []
341
 
        self._in_buffer_len = 0
 
326
        self._in_buffer = ''
342
327
        self.unused_data = ''
343
328
        self.bytes_left = None
344
329
        self._number_needed_bytes = None
345
330
 
346
 
    def _get_in_buffer(self):
347
 
        if len(self._in_buffer_list) == 1:
348
 
            return self._in_buffer_list[0]
349
 
        in_buffer = ''.join(self._in_buffer_list)
350
 
        if len(in_buffer) != self._in_buffer_len:
351
 
            raise AssertionError(
352
 
                "Length of buffer did not match expected value: %s != %s"
353
 
                % self._in_buffer_len, len(in_buffer))
354
 
        self._in_buffer_list = [in_buffer]
355
 
        return in_buffer
356
 
 
357
 
    def _get_in_bytes(self, count):
358
 
        """Grab X bytes from the input_buffer.
359
 
 
360
 
        Callers should have already checked that self._in_buffer_len is >
361
 
        count. Note, this does not consume the bytes from the buffer. The
362
 
        caller will still need to call _get_in_buffer() and then
363
 
        _set_in_buffer() if they actually need to consume the bytes.
364
 
        """
365
 
        # check if we can yield the bytes from just the first entry in our list
366
 
        if len(self._in_buffer_list) == 0:
367
 
            raise AssertionError('Callers must be sure we have buffered bytes'
368
 
                ' before calling _get_in_bytes')
369
 
        if len(self._in_buffer_list[0]) > count:
370
 
            return self._in_buffer_list[0][:count]
371
 
        # We can't yield it from the first buffer, so collapse all buffers, and
372
 
        # yield it from that
373
 
        in_buf = self._get_in_buffer()
374
 
        return in_buf[:count]
375
 
 
376
 
    def _set_in_buffer(self, new_buf):
377
 
        if new_buf is not None:
378
 
            self._in_buffer_list = [new_buf]
379
 
            self._in_buffer_len = len(new_buf)
380
 
        else:
381
 
            self._in_buffer_list = []
382
 
            self._in_buffer_len = 0
383
 
 
384
331
    def accept_bytes(self, bytes):
385
332
        """Decode as much of bytes as possible.
386
333
 
391
338
        data will be appended to self.unused_data.
392
339
        """
393
340
        # accept_bytes is allowed to change the state
 
341
        current_state = self.state_accept
394
342
        self._number_needed_bytes = None
395
 
        # lsprof puts a very large amount of time on this specific call for
396
 
        # large readv arrays
397
 
        self._in_buffer_list.append(bytes)
398
 
        self._in_buffer_len += len(bytes)
 
343
        self._in_buffer += bytes
399
344
        try:
400
345
            # Run the function for the current state.
401
 
            current_state = self.state_accept
402
346
            self.state_accept()
403
347
            while current_state != self.state_accept:
404
348
                # The current state has changed.  Run the function for the new
426
370
        self.chunks = collections.deque()
427
371
        self.error = False
428
372
        self.error_in_progress = None
429
 
 
 
373
    
430
374
    def next_read_size(self):
431
375
        # Note: the shortest possible chunk is 2 bytes: '0\n', and the
432
376
        # end-of-body marker is 4 bytes: 'END\n'.
435
379
            # the rest of this chunk plus an END chunk.
436
380
            return self.bytes_left + 4
437
381
        elif self.state_accept == self._state_accept_expecting_length:
438
 
            if self._in_buffer_len == 0:
 
382
            if self._in_buffer == '':
439
383
                # We're expecting a chunk length.  There's at least two bytes
440
384
                # left: a digit plus '\n'.
441
385
                return 2
446
390
        elif self.state_accept == self._state_accept_reading_unused:
447
391
            return 1
448
392
        elif self.state_accept == self._state_accept_expecting_header:
449
 
            return max(0, len('chunked\n') - self._in_buffer_len)
 
393
            return max(0, len('chunked\n') - len(self._in_buffer))
450
394
        else:
451
395
            raise AssertionError("Impossible state: %r" % (self.state_accept,))
452
396
 
457
401
            return None
458
402
 
459
403
    def _extract_line(self):
460
 
        in_buf = self._get_in_buffer()
461
 
        pos = in_buf.find('\n')
 
404
        pos = self._in_buffer.find('\n')
462
405
        if pos == -1:
463
406
            # We haven't read a complete line yet, so request more bytes before
464
407
            # we continue.
465
408
            raise _NeedMoreBytes(1)
466
 
        line = in_buf[:pos]
 
409
        line = self._in_buffer[:pos]
467
410
        # Trim the prefix (including '\n' delimiter) from the _in_buffer.
468
 
        self._set_in_buffer(in_buf[pos+1:])
 
411
        self._in_buffer = self._in_buffer[pos+1:]
469
412
        return line
470
413
 
471
414
    def _finished(self):
472
 
        self.unused_data = self._get_in_buffer()
473
 
        self._in_buffer_list = []
474
 
        self._in_buffer_len = 0
 
415
        self.unused_data = self._in_buffer
 
416
        self._in_buffer = ''
475
417
        self.state_accept = self._state_accept_reading_unused
476
418
        if self.error:
477
419
            error_args = tuple(self.error_in_progress)
506
448
            self.state_accept = self._state_accept_reading_chunk
507
449
 
508
450
    def _state_accept_reading_chunk(self):
509
 
        in_buf = self._get_in_buffer()
510
 
        in_buffer_len = len(in_buf)
511
 
        self.chunk_in_progress += in_buf[:self.bytes_left]
512
 
        self._set_in_buffer(in_buf[self.bytes_left:])
 
451
        in_buffer_len = len(self._in_buffer)
 
452
        self.chunk_in_progress += self._in_buffer[:self.bytes_left]
 
453
        self._in_buffer = self._in_buffer[self.bytes_left:]
513
454
        self.bytes_left -= in_buffer_len
514
455
        if self.bytes_left <= 0:
515
456
            # Finished with chunk
520
461
                self.chunks.append(self.chunk_in_progress)
521
462
            self.chunk_in_progress = None
522
463
            self.state_accept = self._state_accept_expecting_length
523
 
 
 
464
        
524
465
    def _state_accept_reading_unused(self):
525
 
        self.unused_data += self._get_in_buffer()
526
 
        self._in_buffer_list = []
 
466
        self.unused_data += self._in_buffer
 
467
        self._in_buffer = ''
527
468
 
528
469
 
529
470
class LengthPrefixedBodyDecoder(_StatefulDecoder):
530
471
    """Decodes the length-prefixed bulk data."""
531
 
 
 
472
    
532
473
    def __init__(self):
533
474
        _StatefulDecoder.__init__(self)
534
475
        self.state_accept = self._state_accept_expecting_length
535
476
        self.state_read = self._state_read_no_data
536
477
        self._body = ''
537
478
        self._trailer_buffer = ''
538
 
 
 
479
    
539
480
    def next_read_size(self):
540
481
        if self.bytes_left is not None:
541
482
            # Ideally we want to read all the remainder of the body and the
551
492
        else:
552
493
            # Reading excess data.  Either way, 1 byte at a time is fine.
553
494
            return 1
554
 
 
 
495
        
555
496
    def read_pending_data(self):
556
497
        """Return any pending data that has been decoded."""
557
498
        return self.state_read()
558
499
 
559
500
    def _state_accept_expecting_length(self):
560
 
        in_buf = self._get_in_buffer()
561
 
        pos = in_buf.find('\n')
 
501
        pos = self._in_buffer.find('\n')
562
502
        if pos == -1:
563
503
            return
564
 
        self.bytes_left = int(in_buf[:pos])
565
 
        self._set_in_buffer(in_buf[pos+1:])
 
504
        self.bytes_left = int(self._in_buffer[:pos])
 
505
        self._in_buffer = self._in_buffer[pos+1:]
566
506
        self.state_accept = self._state_accept_reading_body
567
507
        self.state_read = self._state_read_body_buffer
568
508
 
569
509
    def _state_accept_reading_body(self):
570
 
        in_buf = self._get_in_buffer()
571
 
        self._body += in_buf
572
 
        self.bytes_left -= len(in_buf)
573
 
        self._set_in_buffer(None)
 
510
        self._body += self._in_buffer
 
511
        self.bytes_left -= len(self._in_buffer)
 
512
        self._in_buffer = ''
574
513
        if self.bytes_left <= 0:
575
514
            # Finished with body
576
515
            if self.bytes_left != 0:
578
517
                self._body = self._body[:self.bytes_left]
579
518
            self.bytes_left = None
580
519
            self.state_accept = self._state_accept_reading_trailer
581
 
 
 
520
        
582
521
    def _state_accept_reading_trailer(self):
583
 
        self._trailer_buffer += self._get_in_buffer()
584
 
        self._set_in_buffer(None)
 
522
        self._trailer_buffer += self._in_buffer
 
523
        self._in_buffer = ''
585
524
        # TODO: what if the trailer does not match "done\n"?  Should this raise
586
525
        # a ProtocolViolation exception?
587
526
        if self._trailer_buffer.startswith('done\n'):
588
527
            self.unused_data = self._trailer_buffer[len('done\n'):]
589
528
            self.state_accept = self._state_accept_reading_unused
590
529
            self.finished_reading = True
591
 
 
 
530
    
592
531
    def _state_accept_reading_unused(self):
593
 
        self.unused_data += self._get_in_buffer()
594
 
        self._set_in_buffer(None)
 
532
        self.unused_data += self._in_buffer
 
533
        self._in_buffer = ''
595
534
 
596
535
    def _state_read_no_data(self):
597
536
        return ''
626
565
            mutter('hpss call:   %s', repr(args)[1:-1])
627
566
            if getattr(self._request._medium, 'base', None) is not None:
628
567
                mutter('             (to %s)', self._request._medium.base)
629
 
            self._request_start_time = osutils.timer_func()
 
568
            self._request_start_time = time.time()
630
569
        self._write_args(args)
631
570
        self._request.finished_writing()
632
571
        self._last_verb = args[0]
641
580
            if getattr(self._request._medium, '_path', None) is not None:
642
581
                mutter('                  (to %s)', self._request._medium._path)
643
582
            mutter('              %d bytes', len(body))
644
 
            self._request_start_time = osutils.timer_func()
 
583
            self._request_start_time = time.time()
645
584
            if 'hpssdetail' in debug.debug_flags:
646
585
                mutter('hpss body content: %s', body)
647
586
        self._write_args(args)
660
599
            mutter('hpss call w/readv: %s', repr(args)[1:-1])
661
600
            if getattr(self._request._medium, '_path', None) is not None:
662
601
                mutter('                  (to %s)', self._request._medium._path)
663
 
            self._request_start_time = osutils.timer_func()
 
602
            self._request_start_time = time.time()
664
603
        self._write_args(args)
665
604
        readv_bytes = self._serialise_offsets(body)
666
605
        bytes = self._encode_bulk_data(readv_bytes)
670
609
            mutter('              %d bytes in readv request', len(readv_bytes))
671
610
        self._last_verb = args[0]
672
611
 
673
 
    def call_with_body_stream(self, args, stream):
674
 
        # Protocols v1 and v2 don't support body streams.  So it's safe to
675
 
        # assume that a v1/v2 server doesn't support whatever method we're
676
 
        # trying to call with a body stream.
677
 
        self._request.finished_writing()
678
 
        self._request.finished_reading()
679
 
        raise errors.UnknownSmartMethod(args[0])
680
 
 
681
612
    def cancel_read_body(self):
682
613
        """After expecting a body, a response code may indicate one otherwise.
683
614
 
692
623
        if 'hpss' in debug.debug_flags:
693
624
            if self._request_start_time is not None:
694
625
                mutter('   result:   %6.3fs  %s',
695
 
                       osutils.timer_func() - self._request_start_time,
 
626
                       time.time() - self._request_start_time,
696
627
                       repr(result)[1:-1])
697
628
                self._request_start_time = None
698
629
            else:
743
674
    def _response_is_unknown_method(self, result_tuple):
744
675
        """Raise UnexpectedSmartServerResponse if the response is an 'unknonwn
745
676
        method' response to the request.
746
 
 
 
677
        
747
678
        :param response: The response from a smart client call_expecting_body
748
679
            call.
749
680
        :param verb: The verb used in that call.
756
687
            # The response will have no body, so we've finished reading.
757
688
            self._request.finished_reading()
758
689
            raise errors.UnknownSmartMethod(self._last_verb)
759
 
 
 
690
        
760
691
    def read_body_bytes(self, count=-1):
761
692
        """Read bytes from the body, decoding into a byte stream.
762
 
 
763
 
        We read all bytes at once to ensure we've checked the trailer for
 
693
        
 
694
        We read all bytes at once to ensure we've checked the trailer for 
764
695
        errors, and then feed the buffer back as read_body_bytes is called.
765
696
        """
766
697
        if self._body_buffer is not None:
767
698
            return self._body_buffer.read(count)
768
699
        _body_decoder = LengthPrefixedBodyDecoder()
769
700
 
 
701
        # Read no more than 64k at a time so that we don't risk error 10055 (no
 
702
        # buffer space available) on Windows.
 
703
        max_read = 64 * 1024
770
704
        while not _body_decoder.finished_reading:
771
 
            bytes = self._request.read_bytes(_body_decoder.next_read_size())
 
705
            bytes_wanted = min(_body_decoder.next_read_size(), max_read)
 
706
            bytes = self._request.read_bytes(bytes_wanted)
772
707
            if bytes == '':
773
708
                # end of file encountered reading from server
774
709
                raise errors.ConnectionReset(
784
719
 
785
720
    def _recv_tuple(self):
786
721
        """Receive a tuple from the medium request."""
787
 
        return _decode_tuple(self._request.read_line())
 
722
        return _decode_tuple(self._recv_line())
 
723
 
 
724
    def _recv_line(self):
 
725
        """Read an entire line from the medium request."""
 
726
        line = ''
 
727
        while not line or line[-1] != '\n':
 
728
            # TODO: this is inefficient - but tuples are short.
 
729
            new_char = self._request.read_bytes(1)
 
730
            if new_char == '':
 
731
                # end of file encountered reading from server
 
732
                raise errors.ConnectionReset(
 
733
                    "please check connectivity and permissions",
 
734
                    "(and try -Dhpss if further diagnosis is required)")
 
735
            line += new_char
 
736
        return line
788
737
 
789
738
    def query_version(self):
790
739
        """Return protocol version number of the server."""
804
753
 
805
754
    def _write_protocol_version(self):
806
755
        """Write any prefixes this protocol requires.
807
 
 
 
756
        
808
757
        Version one doesn't send protocol versions.
809
758
        """
810
759
 
811
760
 
812
761
class SmartClientRequestProtocolTwo(SmartClientRequestProtocolOne):
813
762
    """Version two of the client side of the smart protocol.
814
 
 
 
763
    
815
764
    This prefixes the request with the value of REQUEST_VERSION_TWO.
816
765
    """
817
766
 
827
776
        if version != self.response_marker:
828
777
            self._request.finished_reading()
829
778
            raise errors.UnexpectedProtocolVersionMarker(version)
830
 
        response_status = self._request.read_line()
 
779
        response_status = self._recv_line()
831
780
        result = SmartClientRequestProtocolOne._read_response_tuple(self)
832
781
        self._response_is_unknown_method(result)
833
782
        if response_status == 'success\n':
845
794
 
846
795
    def _write_protocol_version(self):
847
796
        """Write any prefixes this protocol requires.
848
 
 
 
797
        
849
798
        Version two sends the value of REQUEST_VERSION_TWO.
850
799
        """
851
800
        self._request.accept_bytes(self.request_marker)
855
804
        """
856
805
        # Read no more than 64k at a time so that we don't risk error 10055 (no
857
806
        # buffer space available) on Windows.
 
807
        max_read = 64 * 1024
858
808
        _body_decoder = ChunkedBodyDecoder()
859
809
        while not _body_decoder.finished_reading:
860
 
            bytes = self._request.read_bytes(_body_decoder.next_read_size())
 
810
            bytes_wanted = min(_body_decoder.next_read_size(), max_read)
 
811
            bytes = self._request.read_bytes(bytes_wanted)
861
812
            if bytes == '':
862
813
                # end of file encountered reading from server
863
814
                raise errors.ConnectionReset(
872
823
 
873
824
 
874
825
def build_server_protocol_three(backing_transport, write_func,
875
 
                                root_client_path, jail_root=None):
 
826
                                root_client_path):
876
827
    request_handler = request.SmartServerRequestHandler(
877
828
        backing_transport, commands=request.request_handlers,
878
 
        root_client_path=root_client_path, jail_root=jail_root)
 
829
        root_client_path=root_client_path)
879
830
    responder = ProtocolThreeResponder(write_func)
880
831
    message_handler = message.ConventionalRequestHandler(request_handler, responder)
881
832
    return ProtocolThreeDecoder(message_handler)
911
862
            # We do *not* set self.decoding_failed here.  The message handler
912
863
            # has raised an error, but the decoder is still able to parse bytes
913
864
            # and determine when this message ends.
914
 
            if not isinstance(exception.exc_value, errors.UnknownSmartMethod):
915
 
                log_exception_quietly()
 
865
            log_exception_quietly()
916
866
            self.message_handler.protocol_error(exception.exc_value)
917
867
            # The state machine is ready to continue decoding, but the
918
868
            # exception has interrupted the loop that runs the state machine.
935
885
            self.message_handler.protocol_error(exception)
936
886
 
937
887
    def _extract_length_prefixed_bytes(self):
938
 
        if self._in_buffer_len < 4:
 
888
        if len(self._in_buffer) < 4:
939
889
            # A length prefix by itself is 4 bytes, and we don't even have that
940
890
            # many yet.
941
891
            raise _NeedMoreBytes(4)
942
 
        (length,) = struct.unpack('!L', self._get_in_bytes(4))
 
892
        (length,) = struct.unpack('!L', self._in_buffer[:4])
943
893
        end_of_bytes = 4 + length
944
 
        if self._in_buffer_len < end_of_bytes:
 
894
        if len(self._in_buffer) < end_of_bytes:
945
895
            # We haven't yet read as many bytes as the length-prefix says there
946
896
            # are.
947
897
            raise _NeedMoreBytes(end_of_bytes)
948
898
        # Extract the bytes from the buffer.
949
 
        in_buf = self._get_in_buffer()
950
 
        bytes = in_buf[4:end_of_bytes]
951
 
        self._set_in_buffer(in_buf[end_of_bytes:])
 
899
        bytes = self._in_buffer[4:end_of_bytes]
 
900
        self._in_buffer = self._in_buffer[end_of_bytes:]
952
901
        return bytes
953
902
 
954
903
    def _extract_prefixed_bencoded_data(self):
955
904
        prefixed_bytes = self._extract_length_prefixed_bytes()
956
905
        try:
957
 
            decoded = bdecode_as_tuple(prefixed_bytes)
 
906
            decoded = bdecode(prefixed_bytes)
958
907
        except ValueError:
959
908
            raise errors.SmartProtocolError(
960
909
                'Bytes %r not bencoded' % (prefixed_bytes,))
961
910
        return decoded
962
911
 
963
912
    def _extract_single_byte(self):
964
 
        if self._in_buffer_len == 0:
 
913
        if self._in_buffer == '':
965
914
            # The buffer is empty
966
915
            raise _NeedMoreBytes(1)
967
 
        in_buf = self._get_in_buffer()
968
 
        one_byte = in_buf[0]
969
 
        self._set_in_buffer(in_buf[1:])
 
916
        one_byte = self._in_buffer[0]
 
917
        self._in_buffer = self._in_buffer[1:]
970
918
        return one_byte
971
919
 
972
920
    def _state_accept_expecting_protocol_version(self):
973
 
        needed_bytes = len(MESSAGE_VERSION_THREE) - self._in_buffer_len
974
 
        in_buf = self._get_in_buffer()
 
921
        needed_bytes = len(MESSAGE_VERSION_THREE) - len(self._in_buffer)
975
922
        if needed_bytes > 0:
976
923
            # We don't have enough bytes to check if the protocol version
977
924
            # marker is right.  But we can check if it is already wrong by
981
928
            # len(MESSAGE_VERSION_THREE) bytes.  So if the bytes we have so far
982
929
            # are wrong then we should just raise immediately rather than
983
930
            # stall.]
984
 
            if not MESSAGE_VERSION_THREE.startswith(in_buf):
 
931
            if not MESSAGE_VERSION_THREE.startswith(self._in_buffer):
985
932
                # We have enough bytes to know the protocol version is wrong
986
 
                raise errors.UnexpectedProtocolVersionMarker(in_buf)
 
933
                raise errors.UnexpectedProtocolVersionMarker(self._in_buffer)
987
934
            raise _NeedMoreBytes(len(MESSAGE_VERSION_THREE))
988
 
        if not in_buf.startswith(MESSAGE_VERSION_THREE):
989
 
            raise errors.UnexpectedProtocolVersionMarker(in_buf)
990
 
        self._set_in_buffer(in_buf[len(MESSAGE_VERSION_THREE):])
 
935
        if not self._in_buffer.startswith(MESSAGE_VERSION_THREE):
 
936
            raise errors.UnexpectedProtocolVersionMarker(self._in_buffer)
 
937
        self._in_buffer = self._in_buffer[len(MESSAGE_VERSION_THREE):]
991
938
        self.state_accept = self._state_accept_expecting_headers
992
939
 
993
940
    def _state_accept_expecting_headers(self):
1000
947
            self.message_handler.headers_received(decoded)
1001
948
        except:
1002
949
            raise errors.SmartMessageHandlerError(sys.exc_info())
1003
 
 
 
950
    
1004
951
    def _state_accept_expecting_message_part(self):
1005
952
        message_part_kind = self._extract_single_byte()
1006
953
        if message_part_kind == 'o':
1042
989
            raise errors.SmartMessageHandlerError(sys.exc_info())
1043
990
 
1044
991
    def done(self):
1045
 
        self.unused_data = self._get_in_buffer()
1046
 
        self._set_in_buffer(None)
 
992
        self.unused_data = self._in_buffer
 
993
        self._in_buffer = ''
1047
994
        self.state_accept = self._state_accept_reading_unused
1048
995
        try:
1049
996
            self.message_handler.end_received()
1051
998
            raise errors.SmartMessageHandlerError(sys.exc_info())
1052
999
 
1053
1000
    def _state_accept_reading_unused(self):
1054
 
        self.unused_data += self._get_in_buffer()
1055
 
        self._set_in_buffer(None)
 
1001
        self.unused_data += self._in_buffer
 
1002
        self._in_buffer = ''
1056
1003
 
1057
1004
    def next_read_size(self):
1058
1005
        if self.state_accept == self._state_accept_reading_unused:
1065
1012
            return 0
1066
1013
        else:
1067
1014
            if self._number_needed_bytes is not None:
1068
 
                return self._number_needed_bytes - self._in_buffer_len
 
1015
                return self._number_needed_bytes - len(self._in_buffer)
1069
1016
            else:
1070
1017
                raise AssertionError("don't know how many bytes are expected!")
1071
1018
 
1073
1020
class _ProtocolThreeEncoder(object):
1074
1021
 
1075
1022
    response_marker = request_marker = MESSAGE_VERSION_THREE
1076
 
    BUFFER_SIZE = 1024*1024 # 1 MiB buffer before flushing
1077
1023
 
1078
1024
    def __init__(self, write_func):
1079
 
        self._buf = []
1080
 
        self._buf_len = 0
 
1025
        self._buf = ''
1081
1026
        self._real_write_func = write_func
1082
1027
 
1083
1028
    def _write_func(self, bytes):
1084
 
        # TODO: It is probably more appropriate to use sum(map(len, _buf))
1085
 
        #       for total number of bytes to write, rather than buffer based on
1086
 
        #       the number of write() calls
1087
 
        # TODO: Another possibility would be to turn this into an async model.
1088
 
        #       Where we let another thread know that we have some bytes if
1089
 
        #       they want it, but we don't actually block for it
1090
 
        #       Note that osutils.send_all always sends 64kB chunks anyway, so
1091
 
        #       we might just push out smaller bits at a time?
1092
 
        self._buf.append(bytes)
1093
 
        self._buf_len += len(bytes)
1094
 
        if self._buf_len > self.BUFFER_SIZE:
1095
 
            self.flush()
 
1029
        self._buf += bytes
1096
1030
 
1097
1031
    def flush(self):
1098
1032
        if self._buf:
1099
 
            self._real_write_func(''.join(self._buf))
1100
 
            del self._buf[:]
1101
 
            self._buf_len = 0
 
1033
            self._real_write_func(self._buf)
 
1034
            self._buf = ''
1102
1035
 
1103
1036
    def _serialise_offsets(self, offsets):
1104
1037
        """Serialise a readv offset list."""
1106
1039
        for start, length in offsets:
1107
1040
            txt.append('%d,%d' % (start, length))
1108
1041
        return '\n'.join(txt)
1109
 
 
 
1042
        
1110
1043
    def _write_protocol_version(self):
1111
1044
        self._write_func(MESSAGE_VERSION_THREE)
1112
1045
 
1137
1070
        self._write_func(struct.pack('!L', len(bytes)))
1138
1071
        self._write_func(bytes)
1139
1072
 
1140
 
    def _write_chunked_body_start(self):
1141
 
        self._write_func('oC')
1142
 
 
1143
1073
    def _write_error_status(self):
1144
1074
        self._write_func('oE')
1145
1075
 
1153
1083
        _ProtocolThreeEncoder.__init__(self, write_func)
1154
1084
        self.response_sent = False
1155
1085
        self._headers = {'Software version': bzrlib.__version__}
1156
 
        if 'hpss' in debug.debug_flags:
1157
 
            self._thread_id = thread.get_ident()
1158
 
            self._response_start_time = None
1159
 
 
1160
 
    def _trace(self, action, message, extra_bytes=None, include_time=False):
1161
 
        if self._response_start_time is None:
1162
 
            self._response_start_time = osutils.timer_func()
1163
 
        if include_time:
1164
 
            t = '%5.3fs ' % (time.clock() - self._response_start_time)
1165
 
        else:
1166
 
            t = ''
1167
 
        if extra_bytes is None:
1168
 
            extra = ''
1169
 
        else:
1170
 
            extra = ' ' + repr(extra_bytes[:40])
1171
 
            if len(extra) > 33:
1172
 
                extra = extra[:29] + extra[-1] + '...'
1173
 
        mutter('%12s: [%s] %s%s%s'
1174
 
               % (action, self._thread_id, t, message, extra))
1175
1086
 
1176
1087
    def send_error(self, exception):
1177
1088
        if self.response_sent:
1183
1094
                ('UnknownMethod', exception.verb))
1184
1095
            self.send_response(failure)
1185
1096
            return
1186
 
        if 'hpss' in debug.debug_flags:
1187
 
            self._trace('error', str(exception))
1188
1097
        self.response_sent = True
1189
1098
        self._write_protocol_version()
1190
1099
        self._write_headers(self._headers)
1204
1113
            self._write_success_status()
1205
1114
        else:
1206
1115
            self._write_error_status()
1207
 
        if 'hpss' in debug.debug_flags:
1208
 
            self._trace('response', repr(response.args))
1209
1116
        self._write_structure(response.args)
1210
1117
        if response.body is not None:
1211
1118
            self._write_prefixed_body(response.body)
1212
 
            if 'hpss' in debug.debug_flags:
1213
 
                self._trace('body', '%d bytes' % (len(response.body),),
1214
 
                            response.body, include_time=True)
1215
1119
        elif response.body_stream is not None:
1216
 
            count = num_bytes = 0
1217
 
            first_chunk = None
1218
 
            for exc_info, chunk in _iter_with_errors(response.body_stream):
1219
 
                count += 1
1220
 
                if exc_info is not None:
1221
 
                    self._write_error_status()
1222
 
                    error_struct = request._translate_error(exc_info[1])
1223
 
                    self._write_structure(error_struct)
1224
 
                    break
1225
 
                else:
1226
 
                    if isinstance(chunk, request.FailedSmartServerResponse):
1227
 
                        self._write_error_status()
1228
 
                        self._write_structure(chunk.args)
1229
 
                        break
1230
 
                    num_bytes += len(chunk)
1231
 
                    if first_chunk is None:
1232
 
                        first_chunk = chunk
1233
 
                    self._write_prefixed_body(chunk)
1234
 
                    if 'hpssdetail' in debug.debug_flags:
1235
 
                        # Not worth timing separately, as _write_func is
1236
 
                        # actually buffered
1237
 
                        self._trace('body chunk',
1238
 
                                    '%d bytes' % (len(chunk),),
1239
 
                                    chunk, suppress_time=True)
1240
 
            if 'hpss' in debug.debug_flags:
1241
 
                self._trace('body stream',
1242
 
                            '%d bytes %d chunks' % (num_bytes, count),
1243
 
                            first_chunk)
 
1120
            for chunk in response.body_stream:
 
1121
                self._write_prefixed_body(chunk)
 
1122
                self.flush()
1244
1123
        self._write_end()
1245
 
        if 'hpss' in debug.debug_flags:
1246
 
            self._trace('response end', '', include_time=True)
1247
 
 
1248
 
 
1249
 
def _iter_with_errors(iterable):
1250
 
    """Handle errors from iterable.next().
1251
 
 
1252
 
    Use like::
1253
 
 
1254
 
        for exc_info, value in _iter_with_errors(iterable):
1255
 
            ...
1256
 
 
1257
 
    This is a safer alternative to::
1258
 
 
1259
 
        try:
1260
 
            for value in iterable:
1261
 
               ...
1262
 
        except:
1263
 
            ...
1264
 
 
1265
 
    Because the latter will catch errors from the for-loop body, not just
1266
 
    iterable.next()
1267
 
 
1268
 
    If an error occurs, exc_info will be a exc_info tuple, and the generator
1269
 
    will terminate.  Otherwise exc_info will be None, and value will be the
1270
 
    value from iterable.next().  Note that KeyboardInterrupt and SystemExit
1271
 
    will not be itercepted.
1272
 
    """
1273
 
    iterator = iter(iterable)
1274
 
    while True:
1275
 
        try:
1276
 
            yield None, iterator.next()
1277
 
        except StopIteration:
1278
 
            return
1279
 
        except (KeyboardInterrupt, SystemExit):
1280
 
            raise
1281
 
        except Exception:
1282
 
            mutter('_iter_with_errors caught error')
1283
 
            log_exception_quietly()
1284
 
            yield sys.exc_info(), None
1285
 
            return
1286
 
 
 
1124
        
1287
1125
 
1288
1126
class ProtocolThreeRequester(_ProtocolThreeEncoder, Requester):
1289
1127
 
1294
1132
 
1295
1133
    def set_headers(self, headers):
1296
1134
        self._headers = headers.copy()
1297
 
 
 
1135
        
1298
1136
    def call(self, *args):
1299
1137
        if 'hpss' in debug.debug_flags:
1300
1138
            mutter('hpss call:   %s', repr(args)[1:-1])
1301
1139
            base = getattr(self._medium_request._medium, 'base', None)
1302
1140
            if base is not None:
1303
1141
                mutter('             (to %s)', base)
1304
 
            self._request_start_time = osutils.timer_func()
 
1142
            self._request_start_time = time.time()
1305
1143
        self._write_protocol_version()
1306
1144
        self._write_headers(self._headers)
1307
1145
        self._write_structure(args)
1319
1157
            if path is not None:
1320
1158
                mutter('                  (to %s)', path)
1321
1159
            mutter('              %d bytes', len(body))
1322
 
            self._request_start_time = osutils.timer_func()
 
1160
            self._request_start_time = time.time()
1323
1161
        self._write_protocol_version()
1324
1162
        self._write_headers(self._headers)
1325
1163
        self._write_structure(args)
1338
1176
            path = getattr(self._medium_request._medium, '_path', None)
1339
1177
            if path is not None:
1340
1178
                mutter('                  (to %s)', path)
1341
 
            self._request_start_time = osutils.timer_func()
 
1179
            self._request_start_time = time.time()
1342
1180
        self._write_protocol_version()
1343
1181
        self._write_headers(self._headers)
1344
1182
        self._write_structure(args)
1349
1187
        self._write_end()
1350
1188
        self._medium_request.finished_writing()
1351
1189
 
1352
 
    def call_with_body_stream(self, args, stream):
1353
 
        if 'hpss' in debug.debug_flags:
1354
 
            mutter('hpss call w/body stream: %r', args)
1355
 
            path = getattr(self._medium_request._medium, '_path', None)
1356
 
            if path is not None:
1357
 
                mutter('                  (to %s)', path)
1358
 
            self._request_start_time = osutils.timer_func()
1359
 
        self._write_protocol_version()
1360
 
        self._write_headers(self._headers)
1361
 
        self._write_structure(args)
1362
 
        # TODO: notice if the server has sent an early error reply before we
1363
 
        #       have finished sending the stream.  We would notice at the end
1364
 
        #       anyway, but if the medium can deliver it early then it's good
1365
 
        #       to short-circuit the whole request...
1366
 
        for exc_info, part in _iter_with_errors(stream):
1367
 
            if exc_info is not None:
1368
 
                # Iterating the stream failed.  Cleanly abort the request.
1369
 
                self._write_error_status()
1370
 
                # Currently the client unconditionally sends ('error',) as the
1371
 
                # error args.
1372
 
                self._write_structure(('error',))
1373
 
                self._write_end()
1374
 
                self._medium_request.finished_writing()
1375
 
                raise exc_info[0], exc_info[1], exc_info[2]
1376
 
            else:
1377
 
                self._write_prefixed_body(part)
1378
 
                self.flush()
1379
 
        self._write_end()
1380
 
        self._medium_request.finished_writing()
1381