~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/smart/client.py

  • Committer: Martin Pool
  • Date: 2010-02-17 05:12:01 UTC
  • mfrom: (4797.2.16 2.1)
  • mto: This revision was merged to the branch mainline in revision 5037.
  • Revision ID: mbp@sourcefrog.net-20100217051201-1sd9dssoujfdc6c4
merge 2.1 back to trunk

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# Copyright (C) 2006-2008 Canonical Ltd
 
2
#
 
3
# This program is free software; you can redistribute it and/or modify
 
4
# it under the terms of the GNU General Public License as published by
 
5
# the Free Software Foundation; either version 2 of the License, or
 
6
# (at your option) any later version.
 
7
#
 
8
# This program is distributed in the hope that it will be useful,
 
9
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
10
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
11
# GNU General Public License for more details.
 
12
#
 
13
# You should have received a copy of the GNU General Public License
 
14
# along with this program; if not, write to the Free Software
 
15
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 
16
 
 
17
import bzrlib
 
18
from bzrlib.smart import message, protocol
 
19
from bzrlib.trace import warning
 
20
from bzrlib import (
 
21
    errors,
 
22
    hooks,
 
23
    )
 
24
 
 
25
 
 
26
class _SmartClient(object):
 
27
 
 
28
    def __init__(self, medium, headers=None):
 
29
        """Constructor.
 
30
 
 
31
        :param medium: a SmartClientMedium
 
32
        """
 
33
        self._medium = medium
 
34
        if headers is None:
 
35
            self._headers = {'Software version': bzrlib.__version__}
 
36
        else:
 
37
            self._headers = dict(headers)
 
38
 
 
39
    def __repr__(self):
 
40
        return '%s(%r)' % (self.__class__.__name__, self._medium)
 
41
 
 
42
    def _send_request(self, protocol_version, method, args, body=None,
 
43
                      readv_body=None, body_stream=None):
 
44
        encoder, response_handler = self._construct_protocol(
 
45
            protocol_version)
 
46
        encoder.set_headers(self._headers)
 
47
        if body is not None:
 
48
            if readv_body is not None:
 
49
                raise AssertionError(
 
50
                    "body and readv_body are mutually exclusive.")
 
51
            if body_stream is not None:
 
52
                raise AssertionError(
 
53
                    "body and body_stream are mutually exclusive.")
 
54
            encoder.call_with_body_bytes((method, ) + args, body)
 
55
        elif readv_body is not None:
 
56
            if body_stream is not None:
 
57
                raise AssertionError(
 
58
                    "readv_body and body_stream are mutually exclusive.")
 
59
            encoder.call_with_body_readv_array((method, ) + args, readv_body)
 
60
        elif body_stream is not None:
 
61
            encoder.call_with_body_stream((method, ) + args, body_stream)
 
62
        else:
 
63
            encoder.call(method, *args)
 
64
        return response_handler
 
65
 
 
66
    def _run_call_hooks(self, method, args, body, readv_body):
 
67
        if not _SmartClient.hooks['call']:
 
68
            return
 
69
        params = CallHookParams(method, args, body, readv_body, self._medium)
 
70
        for hook in _SmartClient.hooks['call']:
 
71
            hook(params)
 
72
 
 
73
    def _call_and_read_response(self, method, args, body=None, readv_body=None,
 
74
            body_stream=None, expect_response_body=True):
 
75
        self._run_call_hooks(method, args, body, readv_body)
 
76
        if self._medium._protocol_version is not None:
 
77
            response_handler = self._send_request(
 
78
                self._medium._protocol_version, method, args, body=body,
 
79
                readv_body=readv_body, body_stream=body_stream)
 
80
            return (response_handler.read_response_tuple(
 
81
                        expect_body=expect_response_body),
 
82
                    response_handler)
 
83
        else:
 
84
            for protocol_version in [3, 2]:
 
85
                if protocol_version == 2:
 
86
                    # If v3 doesn't work, the remote side is older than 1.6.
 
87
                    self._medium._remember_remote_is_before((1, 6))
 
88
                response_handler = self._send_request(
 
89
                    protocol_version, method, args, body=body,
 
90
                    readv_body=readv_body, body_stream=body_stream)
 
91
                try:
 
92
                    response_tuple = response_handler.read_response_tuple(
 
93
                        expect_body=expect_response_body)
 
94
                except errors.UnexpectedProtocolVersionMarker, err:
 
95
                    # TODO: We could recover from this without disconnecting if
 
96
                    # we recognise the protocol version.
 
97
                    warning(
 
98
                        'Server does not understand Bazaar network protocol %d,'
 
99
                        ' reconnecting.  (Upgrade the server to avoid this.)'
 
100
                        % (protocol_version,))
 
101
                    self._medium.disconnect()
 
102
                    continue
 
103
                except errors.ErrorFromSmartServer:
 
104
                    # If we received an error reply from the server, then it
 
105
                    # must be ok with this protocol version.
 
106
                    self._medium._protocol_version = protocol_version
 
107
                    raise
 
108
                else:
 
109
                    self._medium._protocol_version = protocol_version
 
110
                    return response_tuple, response_handler
 
111
            raise errors.SmartProtocolError(
 
112
                'Server is not a Bazaar server: ' + str(err))
 
113
 
 
114
    def _construct_protocol(self, version):
 
115
        request = self._medium.get_request()
 
116
        if version == 3:
 
117
            request_encoder = protocol.ProtocolThreeRequester(request)
 
118
            response_handler = message.ConventionalResponseHandler()
 
119
            response_proto = protocol.ProtocolThreeDecoder(
 
120
                response_handler, expect_version_marker=True)
 
121
            response_handler.setProtoAndMediumRequest(response_proto, request)
 
122
        elif version == 2:
 
123
            request_encoder = protocol.SmartClientRequestProtocolTwo(request)
 
124
            response_handler = request_encoder
 
125
        else:
 
126
            request_encoder = protocol.SmartClientRequestProtocolOne(request)
 
127
            response_handler = request_encoder
 
128
        return request_encoder, response_handler
 
129
 
 
130
    def call(self, method, *args):
 
131
        """Call a method on the remote server."""
 
132
        result, protocol = self.call_expecting_body(method, *args)
 
133
        protocol.cancel_read_body()
 
134
        return result
 
135
 
 
136
    def call_expecting_body(self, method, *args):
 
137
        """Call a method and return the result and the protocol object.
 
138
 
 
139
        The body can be read like so::
 
140
 
 
141
            result, smart_protocol = smart_client.call_expecting_body(...)
 
142
            body = smart_protocol.read_body_bytes()
 
143
        """
 
144
        return self._call_and_read_response(
 
145
            method, args, expect_response_body=True)
 
146
 
 
147
    def call_with_body_bytes(self, method, args, body):
 
148
        """Call a method on the remote server with body bytes."""
 
149
        if type(method) is not str:
 
150
            raise TypeError('method must be a byte string, not %r' % (method,))
 
151
        for arg in args:
 
152
            if type(arg) is not str:
 
153
                raise TypeError('args must be byte strings, not %r' % (args,))
 
154
        if type(body) is not str:
 
155
            raise TypeError('body must be byte string, not %r' % (body,))
 
156
        response, response_handler = self._call_and_read_response(
 
157
            method, args, body=body, expect_response_body=False)
 
158
        return response
 
159
 
 
160
    def call_with_body_bytes_expecting_body(self, method, args, body):
 
161
        """Call a method on the remote server with body bytes."""
 
162
        if type(method) is not str:
 
163
            raise TypeError('method must be a byte string, not %r' % (method,))
 
164
        for arg in args:
 
165
            if type(arg) is not str:
 
166
                raise TypeError('args must be byte strings, not %r' % (args,))
 
167
        if type(body) is not str:
 
168
            raise TypeError('body must be byte string, not %r' % (body,))
 
169
        response, response_handler = self._call_and_read_response(
 
170
            method, args, body=body, expect_response_body=True)
 
171
        return (response, response_handler)
 
172
 
 
173
    def call_with_body_readv_array(self, args, body):
 
174
        response, response_handler = self._call_and_read_response(
 
175
                args[0], args[1:], readv_body=body, expect_response_body=True)
 
176
        return (response, response_handler)
 
177
 
 
178
    def call_with_body_stream(self, args, stream):
 
179
        response, response_handler = self._call_and_read_response(
 
180
                args[0], args[1:], body_stream=stream,
 
181
                expect_response_body=False)
 
182
        return (response, response_handler)
 
183
 
 
184
    def remote_path_from_transport(self, transport):
 
185
        """Convert transport into a path suitable for using in a request.
 
186
 
 
187
        Note that the resulting remote path doesn't encode the host name or
 
188
        anything but path, so it is only safe to use it in requests sent over
 
189
        the medium from the matching transport.
 
190
        """
 
191
        return self._medium.remote_path_from_transport(transport)
 
192
 
 
193
 
 
194
class SmartClientHooks(hooks.Hooks):
 
195
 
 
196
    def __init__(self):
 
197
        hooks.Hooks.__init__(self)
 
198
        self.create_hook(hooks.HookPoint('call',
 
199
            "Called when the smart client is submitting a request to the "
 
200
            "smart server. Called with a bzrlib.smart.client.CallHookParams "
 
201
            "object. Streaming request bodies, and responses, are not "
 
202
            "accessible.", None, None))
 
203
 
 
204
 
 
205
_SmartClient.hooks = SmartClientHooks()
 
206
 
 
207
 
 
208
class CallHookParams(object):
 
209
 
 
210
    def __init__(self, method, args, body, readv_body, medium):
 
211
        self.method = method
 
212
        self.args = args
 
213
        self.body = body
 
214
        self.readv_body = readv_body
 
215
        self.medium = medium
 
216
 
 
217
    def __repr__(self):
 
218
        attrs = dict((k, v) for (k, v) in self.__dict__.iteritems()
 
219
                     if v is not None)
 
220
        return '<%s %r>' % (self.__class__.__name__, attrs)
 
221
 
 
222
    def __eq__(self, other):
 
223
        if type(other) is not type(self):
 
224
            return NotImplemented
 
225
        return self.__dict__ == other.__dict__
 
226
 
 
227
    def __ne__(self, other):
 
228
        return not self == other