~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/smart/protocol.py

  • Committer: Vincent Ladeuil
  • Date: 2009-06-22 12:52:39 UTC
  • mto: (4471.1.1 integration)
  • mto: This revision was merged to the branch mainline in revision 4472.
  • Revision ID: v.ladeuil+lp@free.fr-20090622125239-kabo9smxt9c3vnir
Use a consistent scheme for naming pyrex source files.

Show diffs side-by-side

added added

removed removed

Lines of Context:
12
12
#
13
13
# You should have received a copy of the GNU General Public License
14
14
# along with this program; if not, write to the Free Software
15
 
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 
15
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
16
16
 
17
17
"""Wire-level encoding and decoding of requests and responses for the smart
18
18
client and server.
29
29
from bzrlib import errors
30
30
from bzrlib.smart import message, request
31
31
from bzrlib.trace import log_exception_quietly, mutter
32
 
from bzrlib.util.bencode import bdecode, bencode
 
32
from bzrlib.bencode import bdecode_as_tuple, bencode
33
33
 
34
34
 
35
35
# Protocol version strings.  These are sent as prefixes of bzr requests and
109
109
        for start, length in offsets:
110
110
            txt.append('%d,%d' % (start, length))
111
111
        return '\n'.join(txt)
112
 
        
 
112
 
113
113
 
114
114
class SmartServerRequestProtocolOne(SmartProtocolBase):
115
115
    """Server-side encoding and decoding logic for smart version 1."""
116
 
    
 
116
 
117
117
    def __init__(self, backing_transport, write_func, root_client_path='/'):
118
118
        self._backing_transport = backing_transport
119
119
        self._root_client_path = root_client_path
127
127
 
128
128
    def accept_bytes(self, bytes):
129
129
        """Take bytes, and advance the internal state machine appropriately.
130
 
        
 
130
 
131
131
        :param bytes: must be a byte string
132
132
        """
133
133
        if not isinstance(bytes, str):
169
169
 
170
170
        if self._has_dispatched:
171
171
            if self._finished:
172
 
                # nothing to do.XXX: this routine should be a single state 
 
172
                # nothing to do.XXX: this routine should be a single state
173
173
                # machine too.
174
174
                self.unused_data += self.in_buffer
175
175
                self.in_buffer = ''
211
211
 
212
212
    def _write_protocol_version(self):
213
213
        """Write any prefixes this protocol requires.
214
 
        
 
214
 
215
215
        Version one doesn't send protocol versions.
216
216
        """
217
217
 
234
234
 
235
235
class SmartServerRequestProtocolTwo(SmartServerRequestProtocolOne):
236
236
    r"""Version two of the server side of the smart protocol.
237
 
   
 
237
 
238
238
    This prefixes responses with the value of RESPONSE_VERSION_TWO.
239
239
    """
240
240
 
250
250
 
251
251
    def _write_protocol_version(self):
252
252
        r"""Write any prefixes this protocol requires.
253
 
        
 
253
 
254
254
        Version two sends the value of RESPONSE_VERSION_TWO.
255
255
        """
256
256
        self._write_func(self.response_marker)
323
323
 
324
324
    def __init__(self):
325
325
        self.finished_reading = False
326
 
        self._in_buffer = ''
 
326
        self._in_buffer_list = []
 
327
        self._in_buffer_len = 0
327
328
        self.unused_data = ''
328
329
        self.bytes_left = None
329
330
        self._number_needed_bytes = None
330
331
 
 
332
    def _get_in_buffer(self):
 
333
        if len(self._in_buffer_list) == 1:
 
334
            return self._in_buffer_list[0]
 
335
        in_buffer = ''.join(self._in_buffer_list)
 
336
        if len(in_buffer) != self._in_buffer_len:
 
337
            raise AssertionError(
 
338
                "Length of buffer did not match expected value: %s != %s"
 
339
                % self._in_buffer_len, len(in_buffer))
 
340
        self._in_buffer_list = [in_buffer]
 
341
        return in_buffer
 
342
 
 
343
    def _get_in_bytes(self, count):
 
344
        """Grab X bytes from the input_buffer.
 
345
 
 
346
        Callers should have already checked that self._in_buffer_len is >
 
347
        count. Note, this does not consume the bytes from the buffer. The
 
348
        caller will still need to call _get_in_buffer() and then
 
349
        _set_in_buffer() if they actually need to consume the bytes.
 
350
        """
 
351
        # check if we can yield the bytes from just the first entry in our list
 
352
        if len(self._in_buffer_list) == 0:
 
353
            raise AssertionError('Callers must be sure we have buffered bytes'
 
354
                ' before calling _get_in_bytes')
 
355
        if len(self._in_buffer_list[0]) > count:
 
356
            return self._in_buffer_list[0][:count]
 
357
        # We can't yield it from the first buffer, so collapse all buffers, and
 
358
        # yield it from that
 
359
        in_buf = self._get_in_buffer()
 
360
        return in_buf[:count]
 
361
 
 
362
    def _set_in_buffer(self, new_buf):
 
363
        if new_buf is not None:
 
364
            self._in_buffer_list = [new_buf]
 
365
            self._in_buffer_len = len(new_buf)
 
366
        else:
 
367
            self._in_buffer_list = []
 
368
            self._in_buffer_len = 0
 
369
 
331
370
    def accept_bytes(self, bytes):
332
371
        """Decode as much of bytes as possible.
333
372
 
338
377
        data will be appended to self.unused_data.
339
378
        """
340
379
        # accept_bytes is allowed to change the state
341
 
        current_state = self.state_accept
342
380
        self._number_needed_bytes = None
343
 
        self._in_buffer += bytes
 
381
        # lsprof puts a very large amount of time on this specific call for
 
382
        # large readv arrays
 
383
        self._in_buffer_list.append(bytes)
 
384
        self._in_buffer_len += len(bytes)
344
385
        try:
345
386
            # Run the function for the current state.
 
387
            current_state = self.state_accept
346
388
            self.state_accept()
347
389
            while current_state != self.state_accept:
348
390
                # The current state has changed.  Run the function for the new
370
412
        self.chunks = collections.deque()
371
413
        self.error = False
372
414
        self.error_in_progress = None
373
 
    
 
415
 
374
416
    def next_read_size(self):
375
417
        # Note: the shortest possible chunk is 2 bytes: '0\n', and the
376
418
        # end-of-body marker is 4 bytes: 'END\n'.
379
421
            # the rest of this chunk plus an END chunk.
380
422
            return self.bytes_left + 4
381
423
        elif self.state_accept == self._state_accept_expecting_length:
382
 
            if self._in_buffer == '':
 
424
            if self._in_buffer_len == 0:
383
425
                # We're expecting a chunk length.  There's at least two bytes
384
426
                # left: a digit plus '\n'.
385
427
                return 2
390
432
        elif self.state_accept == self._state_accept_reading_unused:
391
433
            return 1
392
434
        elif self.state_accept == self._state_accept_expecting_header:
393
 
            return max(0, len('chunked\n') - len(self._in_buffer))
 
435
            return max(0, len('chunked\n') - self._in_buffer_len)
394
436
        else:
395
437
            raise AssertionError("Impossible state: %r" % (self.state_accept,))
396
438
 
401
443
            return None
402
444
 
403
445
    def _extract_line(self):
404
 
        pos = self._in_buffer.find('\n')
 
446
        in_buf = self._get_in_buffer()
 
447
        pos = in_buf.find('\n')
405
448
        if pos == -1:
406
449
            # We haven't read a complete line yet, so request more bytes before
407
450
            # we continue.
408
451
            raise _NeedMoreBytes(1)
409
 
        line = self._in_buffer[:pos]
 
452
        line = in_buf[:pos]
410
453
        # Trim the prefix (including '\n' delimiter) from the _in_buffer.
411
 
        self._in_buffer = self._in_buffer[pos+1:]
 
454
        self._set_in_buffer(in_buf[pos+1:])
412
455
        return line
413
456
 
414
457
    def _finished(self):
415
 
        self.unused_data = self._in_buffer
416
 
        self._in_buffer = ''
 
458
        self.unused_data = self._get_in_buffer()
 
459
        self._in_buffer_list = []
 
460
        self._in_buffer_len = 0
417
461
        self.state_accept = self._state_accept_reading_unused
418
462
        if self.error:
419
463
            error_args = tuple(self.error_in_progress)
448
492
            self.state_accept = self._state_accept_reading_chunk
449
493
 
450
494
    def _state_accept_reading_chunk(self):
451
 
        in_buffer_len = len(self._in_buffer)
452
 
        self.chunk_in_progress += self._in_buffer[:self.bytes_left]
453
 
        self._in_buffer = self._in_buffer[self.bytes_left:]
 
495
        in_buf = self._get_in_buffer()
 
496
        in_buffer_len = len(in_buf)
 
497
        self.chunk_in_progress += in_buf[:self.bytes_left]
 
498
        self._set_in_buffer(in_buf[self.bytes_left:])
454
499
        self.bytes_left -= in_buffer_len
455
500
        if self.bytes_left <= 0:
456
501
            # Finished with chunk
461
506
                self.chunks.append(self.chunk_in_progress)
462
507
            self.chunk_in_progress = None
463
508
            self.state_accept = self._state_accept_expecting_length
464
 
        
 
509
 
465
510
    def _state_accept_reading_unused(self):
466
 
        self.unused_data += self._in_buffer
467
 
        self._in_buffer = ''
 
511
        self.unused_data += self._get_in_buffer()
 
512
        self._in_buffer_list = []
468
513
 
469
514
 
470
515
class LengthPrefixedBodyDecoder(_StatefulDecoder):
471
516
    """Decodes the length-prefixed bulk data."""
472
 
    
 
517
 
473
518
    def __init__(self):
474
519
        _StatefulDecoder.__init__(self)
475
520
        self.state_accept = self._state_accept_expecting_length
476
521
        self.state_read = self._state_read_no_data
477
522
        self._body = ''
478
523
        self._trailer_buffer = ''
479
 
    
 
524
 
480
525
    def next_read_size(self):
481
526
        if self.bytes_left is not None:
482
527
            # Ideally we want to read all the remainder of the body and the
492
537
        else:
493
538
            # Reading excess data.  Either way, 1 byte at a time is fine.
494
539
            return 1
495
 
        
 
540
 
496
541
    def read_pending_data(self):
497
542
        """Return any pending data that has been decoded."""
498
543
        return self.state_read()
499
544
 
500
545
    def _state_accept_expecting_length(self):
501
 
        pos = self._in_buffer.find('\n')
 
546
        in_buf = self._get_in_buffer()
 
547
        pos = in_buf.find('\n')
502
548
        if pos == -1:
503
549
            return
504
 
        self.bytes_left = int(self._in_buffer[:pos])
505
 
        self._in_buffer = self._in_buffer[pos+1:]
 
550
        self.bytes_left = int(in_buf[:pos])
 
551
        self._set_in_buffer(in_buf[pos+1:])
506
552
        self.state_accept = self._state_accept_reading_body
507
553
        self.state_read = self._state_read_body_buffer
508
554
 
509
555
    def _state_accept_reading_body(self):
510
 
        self._body += self._in_buffer
511
 
        self.bytes_left -= len(self._in_buffer)
512
 
        self._in_buffer = ''
 
556
        in_buf = self._get_in_buffer()
 
557
        self._body += in_buf
 
558
        self.bytes_left -= len(in_buf)
 
559
        self._set_in_buffer(None)
513
560
        if self.bytes_left <= 0:
514
561
            # Finished with body
515
562
            if self.bytes_left != 0:
517
564
                self._body = self._body[:self.bytes_left]
518
565
            self.bytes_left = None
519
566
            self.state_accept = self._state_accept_reading_trailer
520
 
        
 
567
 
521
568
    def _state_accept_reading_trailer(self):
522
 
        self._trailer_buffer += self._in_buffer
523
 
        self._in_buffer = ''
 
569
        self._trailer_buffer += self._get_in_buffer()
 
570
        self._set_in_buffer(None)
524
571
        # TODO: what if the trailer does not match "done\n"?  Should this raise
525
572
        # a ProtocolViolation exception?
526
573
        if self._trailer_buffer.startswith('done\n'):
527
574
            self.unused_data = self._trailer_buffer[len('done\n'):]
528
575
            self.state_accept = self._state_accept_reading_unused
529
576
            self.finished_reading = True
530
 
    
 
577
 
531
578
    def _state_accept_reading_unused(self):
532
 
        self.unused_data += self._in_buffer
533
 
        self._in_buffer = ''
 
579
        self.unused_data += self._get_in_buffer()
 
580
        self._set_in_buffer(None)
534
581
 
535
582
    def _state_read_no_data(self):
536
583
        return ''
609
656
            mutter('              %d bytes in readv request', len(readv_bytes))
610
657
        self._last_verb = args[0]
611
658
 
 
659
    def call_with_body_stream(self, args, stream):
 
660
        # Protocols v1 and v2 don't support body streams.  So it's safe to
 
661
        # assume that a v1/v2 server doesn't support whatever method we're
 
662
        # trying to call with a body stream.
 
663
        self._request.finished_writing()
 
664
        self._request.finished_reading()
 
665
        raise errors.UnknownSmartMethod(args[0])
 
666
 
612
667
    def cancel_read_body(self):
613
668
        """After expecting a body, a response code may indicate one otherwise.
614
669
 
674
729
    def _response_is_unknown_method(self, result_tuple):
675
730
        """Raise UnexpectedSmartServerResponse if the response is an 'unknonwn
676
731
        method' response to the request.
677
 
        
 
732
 
678
733
        :param response: The response from a smart client call_expecting_body
679
734
            call.
680
735
        :param verb: The verb used in that call.
687
742
            # The response will have no body, so we've finished reading.
688
743
            self._request.finished_reading()
689
744
            raise errors.UnknownSmartMethod(self._last_verb)
690
 
        
 
745
 
691
746
    def read_body_bytes(self, count=-1):
692
747
        """Read bytes from the body, decoding into a byte stream.
693
 
        
694
 
        We read all bytes at once to ensure we've checked the trailer for 
 
748
 
 
749
        We read all bytes at once to ensure we've checked the trailer for
695
750
        errors, and then feed the buffer back as read_body_bytes is called.
696
751
        """
697
752
        if self._body_buffer is not None:
698
753
            return self._body_buffer.read(count)
699
754
        _body_decoder = LengthPrefixedBodyDecoder()
700
755
 
701
 
        # Read no more than 64k at a time so that we don't risk error 10055 (no
702
 
        # buffer space available) on Windows.
703
 
        max_read = 64 * 1024
704
756
        while not _body_decoder.finished_reading:
705
 
            bytes_wanted = min(_body_decoder.next_read_size(), max_read)
706
 
            bytes = self._request.read_bytes(bytes_wanted)
 
757
            bytes = self._request.read_bytes(_body_decoder.next_read_size())
 
758
            if bytes == '':
 
759
                # end of file encountered reading from server
 
760
                raise errors.ConnectionReset(
 
761
                    "Connection lost while reading response body.")
707
762
            _body_decoder.accept_bytes(bytes)
708
763
        self._request.finished_reading()
709
764
        self._body_buffer = StringIO(_body_decoder.read_pending_data())
715
770
 
716
771
    def _recv_tuple(self):
717
772
        """Receive a tuple from the medium request."""
718
 
        return _decode_tuple(self._recv_line())
719
 
 
720
 
    def _recv_line(self):
721
 
        """Read an entire line from the medium request."""
722
 
        line = ''
723
 
        while not line or line[-1] != '\n':
724
 
            # TODO: this is inefficient - but tuples are short.
725
 
            new_char = self._request.read_bytes(1)
726
 
            if new_char == '':
727
 
                # end of file encountered reading from server
728
 
                raise errors.ConnectionReset(
729
 
                    "please check connectivity and permissions",
730
 
                    "(and try -Dhpss if further diagnosis is required)")
731
 
            line += new_char
732
 
        return line
 
773
        return _decode_tuple(self._request.read_line())
733
774
 
734
775
    def query_version(self):
735
776
        """Return protocol version number of the server."""
749
790
 
750
791
    def _write_protocol_version(self):
751
792
        """Write any prefixes this protocol requires.
752
 
        
 
793
 
753
794
        Version one doesn't send protocol versions.
754
795
        """
755
796
 
756
797
 
757
798
class SmartClientRequestProtocolTwo(SmartClientRequestProtocolOne):
758
799
    """Version two of the client side of the smart protocol.
759
 
    
 
800
 
760
801
    This prefixes the request with the value of REQUEST_VERSION_TWO.
761
802
    """
762
803
 
772
813
        if version != self.response_marker:
773
814
            self._request.finished_reading()
774
815
            raise errors.UnexpectedProtocolVersionMarker(version)
775
 
        response_status = self._recv_line()
 
816
        response_status = self._request.read_line()
776
817
        result = SmartClientRequestProtocolOne._read_response_tuple(self)
777
818
        self._response_is_unknown_method(result)
778
819
        if response_status == 'success\n':
790
831
 
791
832
    def _write_protocol_version(self):
792
833
        """Write any prefixes this protocol requires.
793
 
        
 
834
 
794
835
        Version two sends the value of REQUEST_VERSION_TWO.
795
836
        """
796
837
        self._request.accept_bytes(self.request_marker)
800
841
        """
801
842
        # Read no more than 64k at a time so that we don't risk error 10055 (no
802
843
        # buffer space available) on Windows.
803
 
        max_read = 64 * 1024
804
844
        _body_decoder = ChunkedBodyDecoder()
805
845
        while not _body_decoder.finished_reading:
806
 
            bytes_wanted = min(_body_decoder.next_read_size(), max_read)
807
 
            bytes = self._request.read_bytes(bytes_wanted)
 
846
            bytes = self._request.read_bytes(_body_decoder.next_read_size())
 
847
            if bytes == '':
 
848
                # end of file encountered reading from server
 
849
                raise errors.ConnectionReset(
 
850
                    "Connection lost while reading streamed body.")
808
851
            _body_decoder.accept_bytes(bytes)
809
852
            for body_bytes in iter(_body_decoder.read_next_chunk, None):
810
853
                if 'hpss' in debug.debug_flags and type(body_bytes) is str:
877
920
            self.message_handler.protocol_error(exception)
878
921
 
879
922
    def _extract_length_prefixed_bytes(self):
880
 
        if len(self._in_buffer) < 4:
 
923
        if self._in_buffer_len < 4:
881
924
            # A length prefix by itself is 4 bytes, and we don't even have that
882
925
            # many yet.
883
926
            raise _NeedMoreBytes(4)
884
 
        (length,) = struct.unpack('!L', self._in_buffer[:4])
 
927
        (length,) = struct.unpack('!L', self._get_in_bytes(4))
885
928
        end_of_bytes = 4 + length
886
 
        if len(self._in_buffer) < end_of_bytes:
 
929
        if self._in_buffer_len < end_of_bytes:
887
930
            # We haven't yet read as many bytes as the length-prefix says there
888
931
            # are.
889
932
            raise _NeedMoreBytes(end_of_bytes)
890
933
        # Extract the bytes from the buffer.
891
 
        bytes = self._in_buffer[4:end_of_bytes]
892
 
        self._in_buffer = self._in_buffer[end_of_bytes:]
 
934
        in_buf = self._get_in_buffer()
 
935
        bytes = in_buf[4:end_of_bytes]
 
936
        self._set_in_buffer(in_buf[end_of_bytes:])
893
937
        return bytes
894
938
 
895
939
    def _extract_prefixed_bencoded_data(self):
896
940
        prefixed_bytes = self._extract_length_prefixed_bytes()
897
941
        try:
898
 
            decoded = bdecode(prefixed_bytes)
 
942
            decoded = bdecode_as_tuple(prefixed_bytes)
899
943
        except ValueError:
900
944
            raise errors.SmartProtocolError(
901
945
                'Bytes %r not bencoded' % (prefixed_bytes,))
902
946
        return decoded
903
947
 
904
948
    def _extract_single_byte(self):
905
 
        if self._in_buffer == '':
 
949
        if self._in_buffer_len == 0:
906
950
            # The buffer is empty
907
951
            raise _NeedMoreBytes(1)
908
 
        one_byte = self._in_buffer[0]
909
 
        self._in_buffer = self._in_buffer[1:]
 
952
        in_buf = self._get_in_buffer()
 
953
        one_byte = in_buf[0]
 
954
        self._set_in_buffer(in_buf[1:])
910
955
        return one_byte
911
956
 
912
957
    def _state_accept_expecting_protocol_version(self):
913
 
        needed_bytes = len(MESSAGE_VERSION_THREE) - len(self._in_buffer)
 
958
        needed_bytes = len(MESSAGE_VERSION_THREE) - self._in_buffer_len
 
959
        in_buf = self._get_in_buffer()
914
960
        if needed_bytes > 0:
915
961
            # We don't have enough bytes to check if the protocol version
916
962
            # marker is right.  But we can check if it is already wrong by
920
966
            # len(MESSAGE_VERSION_THREE) bytes.  So if the bytes we have so far
921
967
            # are wrong then we should just raise immediately rather than
922
968
            # stall.]
923
 
            if not MESSAGE_VERSION_THREE.startswith(self._in_buffer):
 
969
            if not MESSAGE_VERSION_THREE.startswith(in_buf):
924
970
                # We have enough bytes to know the protocol version is wrong
925
 
                raise errors.UnexpectedProtocolVersionMarker(self._in_buffer)
 
971
                raise errors.UnexpectedProtocolVersionMarker(in_buf)
926
972
            raise _NeedMoreBytes(len(MESSAGE_VERSION_THREE))
927
 
        if not self._in_buffer.startswith(MESSAGE_VERSION_THREE):
928
 
            raise errors.UnexpectedProtocolVersionMarker(self._in_buffer)
929
 
        self._in_buffer = self._in_buffer[len(MESSAGE_VERSION_THREE):]
 
973
        if not in_buf.startswith(MESSAGE_VERSION_THREE):
 
974
            raise errors.UnexpectedProtocolVersionMarker(in_buf)
 
975
        self._set_in_buffer(in_buf[len(MESSAGE_VERSION_THREE):])
930
976
        self.state_accept = self._state_accept_expecting_headers
931
977
 
932
978
    def _state_accept_expecting_headers(self):
939
985
            self.message_handler.headers_received(decoded)
940
986
        except:
941
987
            raise errors.SmartMessageHandlerError(sys.exc_info())
942
 
    
 
988
 
943
989
    def _state_accept_expecting_message_part(self):
944
990
        message_part_kind = self._extract_single_byte()
945
991
        if message_part_kind == 'o':
981
1027
            raise errors.SmartMessageHandlerError(sys.exc_info())
982
1028
 
983
1029
    def done(self):
984
 
        self.unused_data = self._in_buffer
985
 
        self._in_buffer = ''
 
1030
        self.unused_data = self._get_in_buffer()
 
1031
        self._set_in_buffer(None)
986
1032
        self.state_accept = self._state_accept_reading_unused
987
1033
        try:
988
1034
            self.message_handler.end_received()
990
1036
            raise errors.SmartMessageHandlerError(sys.exc_info())
991
1037
 
992
1038
    def _state_accept_reading_unused(self):
993
 
        self.unused_data += self._in_buffer
994
 
        self._in_buffer = ''
 
1039
        self.unused_data = self._get_in_buffer()
 
1040
        self._set_in_buffer(None)
995
1041
 
996
1042
    def next_read_size(self):
997
1043
        if self.state_accept == self._state_accept_reading_unused:
1004
1050
            return 0
1005
1051
        else:
1006
1052
            if self._number_needed_bytes is not None:
1007
 
                return self._number_needed_bytes - len(self._in_buffer)
 
1053
                return self._number_needed_bytes - self._in_buffer_len
1008
1054
            else:
1009
1055
                raise AssertionError("don't know how many bytes are expected!")
1010
1056
 
1014
1060
    response_marker = request_marker = MESSAGE_VERSION_THREE
1015
1061
 
1016
1062
    def __init__(self, write_func):
1017
 
        self._buf = ''
 
1063
        self._buf = []
1018
1064
        self._real_write_func = write_func
1019
1065
 
1020
1066
    def _write_func(self, bytes):
1021
 
        self._buf += bytes
 
1067
        self._buf.append(bytes)
 
1068
        if len(self._buf) > 100:
 
1069
            self.flush()
1022
1070
 
1023
1071
    def flush(self):
1024
1072
        if self._buf:
1025
 
            self._real_write_func(self._buf)
1026
 
            self._buf = ''
 
1073
            self._real_write_func(''.join(self._buf))
 
1074
            del self._buf[:]
1027
1075
 
1028
1076
    def _serialise_offsets(self, offsets):
1029
1077
        """Serialise a readv offset list."""
1031
1079
        for start, length in offsets:
1032
1080
            txt.append('%d,%d' % (start, length))
1033
1081
        return '\n'.join(txt)
1034
 
        
 
1082
 
1035
1083
    def _write_protocol_version(self):
1036
1084
        self._write_func(MESSAGE_VERSION_THREE)
1037
1085
 
1062
1110
        self._write_func(struct.pack('!L', len(bytes)))
1063
1111
        self._write_func(bytes)
1064
1112
 
 
1113
    def _write_chunked_body_start(self):
 
1114
        self._write_func('oC')
 
1115
 
1065
1116
    def _write_error_status(self):
1066
1117
        self._write_func('oE')
1067
1118
 
1109
1160
        if response.body is not None:
1110
1161
            self._write_prefixed_body(response.body)
1111
1162
        elif response.body_stream is not None:
1112
 
            for chunk in response.body_stream:
1113
 
                self._write_prefixed_body(chunk)
1114
 
                self.flush()
 
1163
            for exc_info, chunk in _iter_with_errors(response.body_stream):
 
1164
                if exc_info is not None:
 
1165
                    self._write_error_status()
 
1166
                    error_struct = request._translate_error(exc_info[1])
 
1167
                    self._write_structure(error_struct)
 
1168
                    break
 
1169
                else:
 
1170
                    if isinstance(chunk, request.FailedSmartServerResponse):
 
1171
                        self._write_error_status()
 
1172
                        self._write_structure(chunk.args)
 
1173
                        break
 
1174
                    self._write_prefixed_body(chunk)
1115
1175
        self._write_end()
1116
 
        
 
1176
 
 
1177
 
 
1178
def _iter_with_errors(iterable):
 
1179
    """Handle errors from iterable.next().
 
1180
 
 
1181
    Use like::
 
1182
 
 
1183
        for exc_info, value in _iter_with_errors(iterable):
 
1184
            ...
 
1185
 
 
1186
    This is a safer alternative to::
 
1187
 
 
1188
        try:
 
1189
            for value in iterable:
 
1190
               ...
 
1191
        except:
 
1192
            ...
 
1193
 
 
1194
    Because the latter will catch errors from the for-loop body, not just
 
1195
    iterable.next()
 
1196
 
 
1197
    If an error occurs, exc_info will be a exc_info tuple, and the generator
 
1198
    will terminate.  Otherwise exc_info will be None, and value will be the
 
1199
    value from iterable.next().  Note that KeyboardInterrupt and SystemExit
 
1200
    will not be itercepted.
 
1201
    """
 
1202
    iterator = iter(iterable)
 
1203
    while True:
 
1204
        try:
 
1205
            yield None, iterator.next()
 
1206
        except StopIteration:
 
1207
            return
 
1208
        except (KeyboardInterrupt, SystemExit):
 
1209
            raise
 
1210
        except Exception:
 
1211
            yield sys.exc_info(), None
 
1212
            return
 
1213
 
1117
1214
 
1118
1215
class ProtocolThreeRequester(_ProtocolThreeEncoder, Requester):
1119
1216
 
1124
1221
 
1125
1222
    def set_headers(self, headers):
1126
1223
        self._headers = headers.copy()
1127
 
        
 
1224
 
1128
1225
    def call(self, *args):
1129
1226
        if 'hpss' in debug.debug_flags:
1130
1227
            mutter('hpss call:   %s', repr(args)[1:-1])
1179
1276
        self._write_end()
1180
1277
        self._medium_request.finished_writing()
1181
1278
 
 
1279
    def call_with_body_stream(self, args, stream):
 
1280
        if 'hpss' in debug.debug_flags:
 
1281
            mutter('hpss call w/body stream: %r', args)
 
1282
            path = getattr(self._medium_request._medium, '_path', None)
 
1283
            if path is not None:
 
1284
                mutter('                  (to %s)', path)
 
1285
            self._request_start_time = time.time()
 
1286
        self._write_protocol_version()
 
1287
        self._write_headers(self._headers)
 
1288
        self._write_structure(args)
 
1289
        # TODO: notice if the server has sent an early error reply before we
 
1290
        #       have finished sending the stream.  We would notice at the end
 
1291
        #       anyway, but if the medium can deliver it early then it's good
 
1292
        #       to short-circuit the whole request...
 
1293
        for exc_info, part in _iter_with_errors(stream):
 
1294
            if exc_info is not None:
 
1295
                # Iterating the stream failed.  Cleanly abort the request.
 
1296
                self._write_error_status()
 
1297
                # Currently the client unconditionally sends ('error',) as the
 
1298
                # error args.
 
1299
                self._write_structure(('error',))
 
1300
                self._write_end()
 
1301
                self._medium_request.finished_writing()
 
1302
                raise exc_info[0], exc_info[1], exc_info[2]
 
1303
            else:
 
1304
                self._write_prefixed_body(part)
 
1305
                self.flush()
 
1306
        self._write_end()
 
1307
        self._medium_request.finished_writing()
 
1308