~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/smart/client.py

Update to bzr.dev.

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2006 Canonical Ltd
 
1
# Copyright (C) 2006-2008 Canonical Ltd
2
2
#
3
3
# This program is free software; you can redistribute it and/or modify
4
4
# it under the terms of the GNU General Public License as published by
14
14
# along with this program; if not, write to the Free Software
15
15
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
16
16
 
17
 
import urllib
18
 
from urlparse import urlparse
19
 
 
20
 
from bzrlib.smart import protocol
21
 
from bzrlib import (
22
 
    errors,
23
 
    urlutils,
24
 
    )
 
17
import bzrlib
 
18
from bzrlib.smart import message, protocol
 
19
from bzrlib.trace import warning
 
20
from bzrlib import errors
25
21
 
26
22
 
27
23
class _SmartClient(object):
28
24
 
29
 
    def __init__(self, medium, base):
 
25
    def __init__(self, medium, headers=None):
30
26
        """Constructor.
31
27
 
32
28
        :param medium: a SmartClientMedium
33
 
        :param base: a URL
34
29
        """
35
30
        self._medium = medium
36
 
        self._base = base
37
 
 
38
 
    def _build_client_protocol(self):
39
 
        version = self._medium.protocol_version()
 
31
        if headers is None:
 
32
            self._headers = {'Software version': bzrlib.__version__}
 
33
        else:
 
34
            self._headers = dict(headers)
 
35
 
 
36
    def _send_request(self, protocol_version, method, args, body=None,
 
37
                      readv_body=None):
 
38
        encoder, response_handler = self._construct_protocol(
 
39
            protocol_version)
 
40
        encoder.set_headers(self._headers)
 
41
        if body is not None:
 
42
            if readv_body is not None:
 
43
                raise AssertionError(
 
44
                    "body and readv_body are mutually exclusive.")
 
45
            encoder.call_with_body_bytes((method, ) + args, body)
 
46
        elif readv_body is not None:
 
47
            encoder.call_with_body_readv_array((method, ) + args,
 
48
                    readv_body)
 
49
        else:
 
50
            encoder.call(method, *args)
 
51
        return response_handler
 
52
 
 
53
    def _call_and_read_response(self, method, args, body=None, readv_body=None,
 
54
            expect_response_body=True):
 
55
        if self._medium._protocol_version is not None:
 
56
            response_handler = self._send_request(
 
57
                self._medium._protocol_version, method, args, body=body,
 
58
                readv_body=readv_body)
 
59
            return (response_handler.read_response_tuple(
 
60
                        expect_body=expect_response_body),
 
61
                    response_handler)
 
62
        else:
 
63
            for protocol_version in [3, 2]:
 
64
                response_handler = self._send_request(
 
65
                    protocol_version, method, args, body=body,
 
66
                    readv_body=readv_body)
 
67
                try:
 
68
                    response_tuple = response_handler.read_response_tuple(
 
69
                        expect_body=expect_response_body)
 
70
                except errors.UnexpectedProtocolVersionMarker, err:
 
71
                    # TODO: We could recover from this without disconnecting if
 
72
                    # we recognise the protocol version.
 
73
                    warning(
 
74
                        'Server does not understand Bazaar network protocol %d,'
 
75
                        ' reconnecting.  (Upgrade the server to avoid this.)'
 
76
                        % (protocol_version,))
 
77
                    self._medium.disconnect()
 
78
                    continue
 
79
                except errors.ErrorFromSmartServer:
 
80
                    # If we received an error reply from the server, then it
 
81
                    # must be ok with this protocol version.
 
82
                    self._medium._protocol_version = protocol_version
 
83
                    raise
 
84
                else:
 
85
                    self._medium._protocol_version = protocol_version
 
86
                    return response_tuple, response_handler
 
87
            raise errors.SmartProtocolError(
 
88
                'Server is not a Bazaar server: ' + str(err))
 
89
 
 
90
    def _construct_protocol(self, version):
40
91
        request = self._medium.get_request()
41
 
        if version == 2:
42
 
            smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
 
92
        if version == 3:
 
93
            request_encoder = protocol.ProtocolThreeRequester(request)
 
94
            response_handler = message.ConventionalResponseHandler()
 
95
            response_proto = protocol.ProtocolThreeDecoder(
 
96
                response_handler, expect_version_marker=True)
 
97
            response_handler.setProtoAndMediumRequest(response_proto, request)
 
98
        elif version == 2:
 
99
            request_encoder = protocol.SmartClientRequestProtocolTwo(request)
 
100
            response_handler = request_encoder
43
101
        else:
44
 
            smart_protocol = protocol.SmartClientRequestProtocolOne(request)
45
 
        return smart_protocol
 
102
            request_encoder = protocol.SmartClientRequestProtocolOne(request)
 
103
            response_handler = request_encoder
 
104
        return request_encoder, response_handler
46
105
 
47
106
    def call(self, method, *args):
48
107
        """Call a method on the remote server."""
58
117
            result, smart_protocol = smart_client.call_expecting_body(...)
59
118
            body = smart_protocol.read_body_bytes()
60
119
        """
61
 
        smart_protocol = self._build_client_protocol()
62
 
        smart_protocol.call(method, *args)
63
 
        return smart_protocol.read_response_tuple(expect_body=True), smart_protocol
 
120
        return self._call_and_read_response(
 
121
            method, args, expect_response_body=True)
64
122
 
65
123
    def call_with_body_bytes(self, method, args, body):
66
124
        """Call a method on the remote server with body bytes."""
71
129
                raise TypeError('args must be byte strings, not %r' % (args,))
72
130
        if type(body) is not str:
73
131
            raise TypeError('body must be byte string, not %r' % (body,))
74
 
        smart_protocol = self._build_client_protocol()
75
 
        smart_protocol.call_with_body_bytes((method, ) + args, body)
76
 
        return smart_protocol.read_response_tuple()
 
132
        response, response_handler = self._call_and_read_response(
 
133
            method, args, body=body, expect_response_body=False)
 
134
        return response
77
135
 
78
136
    def call_with_body_bytes_expecting_body(self, method, args, body):
79
137
        """Call a method on the remote server with body bytes."""
84
142
                raise TypeError('args must be byte strings, not %r' % (args,))
85
143
        if type(body) is not str:
86
144
            raise TypeError('body must be byte string, not %r' % (body,))
87
 
        smart_protocol = self._build_client_protocol()
88
 
        smart_protocol.call_with_body_bytes((method, ) + args, body)
89
 
        return smart_protocol.read_response_tuple(expect_body=True), smart_protocol
 
145
        response, response_handler = self._call_and_read_response(
 
146
            method, args, body=body, expect_response_body=True)
 
147
        return (response, response_handler)
 
148
 
 
149
    def call_with_body_readv_array(self, args, body):
 
150
        response, response_handler = self._call_and_read_response(
 
151
                args[0], args[1:], readv_body=body, expect_response_body=True)
 
152
        return (response, response_handler)
90
153
 
91
154
    def remote_path_from_transport(self, transport):
92
155
        """Convert transport into a path suitable for using in a request.
95
158
        anything but path, so it is only safe to use it in requests sent over
96
159
        the medium from the matching transport.
97
160
        """
98
 
        base = self._base
99
 
        if (base.startswith('bzr+http://') or base.startswith('bzr+https://')
100
 
            or base.startswith('http://') or base.startswith('https://')):
101
 
            medium_base = self._base
102
 
        else:
103
 
            medium_base = urlutils.join(self._base, '/')
104
 
            
105
 
        rel_url = urlutils.relative_url(medium_base, transport.base)
106
 
        return urllib.unquote(rel_url)
 
161
        return self._medium.remote_path_from_transport(transport)
107
162