~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/smart/protocol.py

  • Committer: Andrew Bennetts
  • Date: 2008-05-07 10:33:00 UTC
  • mto: This revision was merged to the branch mainline in revision 3428.
  • Revision ID: andrew.bennetts@canonical.com-20080507103300-d0vr775rbjj83xte
Make _SmartClient automatically detect and use the highest protocol version compatible with the server.

Show diffs side-by-side

added added

removed removed

Lines of Context:
89
89
        """
90
90
        raise NotImplementedError(self.call_with_body_readv_array)
91
91
 
 
92
    def set_headers(self, headers):
 
93
        raise NotImplementedError(self.set_headers)
 
94
 
92
95
 
93
96
class SmartProtocolBase(object):
94
97
    """Methods common to client and server"""
523
526
        self._body_buffer = None
524
527
        self._request_start_time = None
525
528
        self._last_verb = None
 
529
        self._headers = None
 
530
 
 
531
    def set_headers(self, headers):
 
532
        self._headers = dict(headers)
526
533
 
527
534
    def call(self, *args):
528
535
        if 'hpss' in debug.debug_flags:
788
795
    response_marker = RESPONSE_VERSION_THREE
789
796
    request_marker = REQUEST_VERSION_THREE
790
797
 
791
 
    def __init__(self, message_handler):
 
798
    def __init__(self, message_handler, expect_version_marker=False):
792
799
        _StatefulDecoder.__init__(self)
793
800
        self.has_dispatched = False
794
801
        # Initial state
795
802
        self._in_buffer = ''
796
 
        self._number_needed_bytes = 4
797
 
        self.state_accept = self._state_accept_expecting_headers
 
803
        if expect_version_marker:
 
804
            self.state_accept = self._state_accept_expecting_protocol_version
 
805
            # We're expecting at least the protocol version marker + some
 
806
            # headers.
 
807
            self._number_needed_bytes = len(MESSAGE_VERSION_THREE) + 4
 
808
        else:
 
809
            self.state_accept = self._state_accept_expecting_headers
 
810
            self._number_needed_bytes = 4
798
811
        self.errored = False
799
812
 
800
813
        self.request_handler = self.message_handler = message_handler
843
856
        self._in_buffer = self._in_buffer[1:]
844
857
        return one_byte
845
858
 
 
859
    def _state_accept_expecting_protocol_version(self, bytes):
 
860
        self._in_buffer += bytes
 
861
        needed_bytes = len(MESSAGE_VERSION_THREE) - len(self._in_buffer)
 
862
        if needed_bytes > 0:
 
863
            if not MESSAGE_VERSION_THREE.startswith(self._in_buffer):
 
864
                # We have enough bytes to know the protocol version is wrong
 
865
                raise errors.UnexpectedProtocolVersionMarker(self._in_buffer)
 
866
            raise _NeedMoreBytes(len(MESSAGE_VERSION_THREE))
 
867
        if not self._in_buffer.startswith(MESSAGE_VERSION_THREE):
 
868
            raise errors.UnexpectedProtocolVersionMarker(self._in_buffer)
 
869
        self._in_buffer = self._in_buffer[len(MESSAGE_VERSION_THREE):]
 
870
        self.state_accept = self._state_accept_expecting_headers
 
871
 
846
872
    def _state_accept_expecting_headers(self, bytes):
847
873
        self._in_buffer += bytes
848
874
        decoded = self._extract_prefixed_bencoded_data()
934
960
        self._write_func(struct.pack('!L', len(bytes)))
935
961
        self._write_func(bytes)
936
962
 
937
 
    def _write_headers(self, headers=None):
938
 
        if headers is None:
939
 
            headers = {'Software version': bzrlib.__version__}
 
963
    def _write_headers(self, headers):
940
964
        self._write_prefixed_bencode(headers)
941
965
 
942
966
    def _write_structure(self, args):
969
993
    def __init__(self, write_func):
970
994
        _ProtocolThreeEncoder.__init__(self, write_func)
971
995
        self.response_sent = False
 
996
        self._headers = {'Software version': bzrlib.__version__}
972
997
 
973
998
    def send_error(self, exception):
974
999
        assert not self.response_sent
978
1003
            self.send_response(failure)
979
1004
            return
980
1005
        self.response_sent = True
981
 
        self._write_headers()
 
1006
        self._write_protocol_version()
 
1007
        self._write_headers(self._headers)
982
1008
        self._write_error_status()
983
1009
        self._write_structure(('error', str(exception)))
984
1010
        self._write_end()
986
1012
    def send_response(self, response):
987
1013
        assert not self.response_sent
988
1014
        self.response_sent = True
989
 
        self._write_headers()
 
1015
        self._write_protocol_version()
 
1016
        self._write_headers(self._headers)
990
1017
        if response.is_successful():
991
1018
            self._write_success_status()
992
1019
        else:
1005
1032
    def __init__(self, medium_request):
1006
1033
        _ProtocolThreeEncoder.__init__(self, medium_request.accept_bytes)
1007
1034
        self._medium_request = medium_request
 
1035
        self._headers = {}
1008
1036
 
1009
 
    def call(self, *args, **kw):
1010
 
        # XXX: ideally, signature would be call(self, *args, headers=None), but
1011
 
        # python doesn't allow that.  So, we fake it.
1012
 
        headers = None
1013
 
        if 'headers' in kw:
1014
 
            headers = kw.pop('headers')
1015
 
        if kw != {}:
1016
 
            raise TypeError('Unexpected keyword arguments: %r' % (kw,))
 
1037
    def set_headers(self, headers):
 
1038
        self._headers = dict(headers)
 
1039
        
 
1040
    def call(self, *args):
1017
1041
        if 'hpss' in debug.debug_flags:
1018
1042
            mutter('hpss call:   %s', repr(args)[1:-1])
1019
1043
            base = getattr(self._medium_request._medium, 'base', None)
1021
1045
                mutter('             (to %s)', base)
1022
1046
            self._request_start_time = time.time()
1023
1047
        self._write_protocol_version()
1024
 
        self._write_headers(headers)
 
1048
        self._write_headers(self._headers)
1025
1049
        self._write_structure(args)
1026
1050
        self._write_end()
1027
1051
        self._medium_request.finished_writing()
1028
1052
 
1029
 
    def call_with_body_bytes(self, args, body, headers=None):
 
1053
    def call_with_body_bytes(self, args, body):
1030
1054
        """Make a remote call of args with body bytes 'body'.
1031
1055
 
1032
1056
        After calling this, call read_response_tuple to find the result out.
1039
1063
            mutter('              %d bytes', len(body))
1040
1064
            self._request_start_time = time.time()
1041
1065
        self._write_protocol_version()
1042
 
        self._write_headers(headers)
 
1066
        self._write_headers(self._headers)
1043
1067
        self._write_structure(args)
1044
1068
        self._write_prefixed_body(body)
1045
1069
        self._write_end()
1046
1070
        self._medium_request.finished_writing()
1047
1071
 
1048
 
    def call_with_body_readv_array(self, args, body, headers=None):
 
1072
    def call_with_body_readv_array(self, args, body):
1049
1073
        """Make a remote call with a readv array.
1050
1074
 
1051
1075
        The body is encoded with one line per readv offset pair. The numbers in
1058
1082
                mutter('                  (to %s)', path)
1059
1083
            self._request_start_time = time.time()
1060
1084
        self._write_protocol_version()
1061
 
        self._write_headers(headers)
 
1085
        self._write_headers(self._headers)
1062
1086
        self._write_structure(args)
1063
1087
        readv_bytes = self._serialise_offsets(body)
1064
1088
        if 'hpss' in debug.debug_flags: