~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/smart/protocol.py

  • Committer: Vincent Ladeuil
  • Date: 2012-01-18 14:09:19 UTC
  • mto: This revision was merged to the branch mainline in revision 6468.
  • Revision ID: v.ladeuil+lp@free.fr-20120118140919-rlvdrhpc0nq1lbwi
Change set/remove to require a lock for the branch config files.

This means that tests (or any plugin for that matter) do not requires an
explicit lock on the branch anymore to change a single option. This also
means the optimisation becomes "opt-in" and as such won't be as
spectacular as it may be and/or harder to get right (nothing fails
anymore).

This reduces the diff by ~300 lines.

Code/tests that were updating more than one config option is still taking
a lock to at least avoid some IOs and demonstrate the benefits through
the decreased number of hpss calls.

The duplication between BranchStack and BranchOnlyStack will be removed
once the same sharing is in place for local config files, at which point
the Stack class itself may be able to host the changes.

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2006, 2007 Canonical Ltd
 
1
# Copyright (C) 2006-2010 Canonical Ltd
2
2
#
3
3
# This program is free software; you can redistribute it and/or modify
4
4
# it under the terms of the GNU General Public License as published by
12
12
#
13
13
# You should have received a copy of the GNU General Public License
14
14
# along with this program; if not, write to the Free Software
15
 
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 
15
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
16
16
 
17
17
"""Wire-level encoding and decoding of requests and responses for the smart
18
18
client and server.
19
19
"""
20
20
 
 
21
from __future__ import absolute_import
21
22
 
 
23
import collections
22
24
from cStringIO import StringIO
23
 
 
24
 
from bzrlib import errors
25
 
from bzrlib.smart import request
 
25
import struct
 
26
import sys
 
27
import thread
 
28
import time
 
29
 
 
30
import bzrlib
 
31
from bzrlib import (
 
32
    debug,
 
33
    errors,
 
34
    osutils,
 
35
    )
 
36
from bzrlib.smart import message, request
 
37
from bzrlib.trace import log_exception_quietly, mutter
 
38
from bzrlib.bencode import bdecode_as_tuple, bencode
 
39
 
 
40
 
 
41
# Protocol version strings.  These are sent as prefixes of bzr requests and
 
42
# responses to identify the protocol version being used. (There are no version
 
43
# one strings because that version doesn't send any).
 
44
REQUEST_VERSION_TWO = 'bzr request 2\n'
 
45
RESPONSE_VERSION_TWO = 'bzr response 2\n'
 
46
 
 
47
MESSAGE_VERSION_THREE = 'bzr message 3 (bzr 1.6)\n'
 
48
RESPONSE_VERSION_THREE = REQUEST_VERSION_THREE = MESSAGE_VERSION_THREE
26
49
 
27
50
 
28
51
def _recv_tuple(from_file):
31
54
 
32
55
 
33
56
def _decode_tuple(req_line):
34
 
    if req_line == None or req_line == '':
 
57
    if req_line is None or req_line == '':
35
58
        return None
36
59
    if req_line[-1] != '\n':
37
60
        raise errors.SmartProtocolError("request %r not terminated" % req_line)
40
63
 
41
64
def _encode_tuple(args):
42
65
    """Encode the tuple args to a bytestream."""
43
 
    return '\x01'.join(args) + '\n'
 
66
    joined = '\x01'.join(args) + '\n'
 
67
    if type(joined) is unicode:
 
68
        # XXX: We should fix things so this never happens!  -AJB, 20100304
 
69
        mutter('response args contain unicode, should be only bytes: %r',
 
70
               joined)
 
71
        joined = joined.encode('ascii')
 
72
    return joined
 
73
 
 
74
 
 
75
class Requester(object):
 
76
    """Abstract base class for an object that can issue requests on a smart
 
77
    medium.
 
78
    """
 
79
 
 
80
    def call(self, *args):
 
81
        """Make a remote call.
 
82
 
 
83
        :param args: the arguments of this call.
 
84
        """
 
85
        raise NotImplementedError(self.call)
 
86
 
 
87
    def call_with_body_bytes(self, args, body):
 
88
        """Make a remote call with a body.
 
89
 
 
90
        :param args: the arguments of this call.
 
91
        :type body: str
 
92
        :param body: the body to send with the request.
 
93
        """
 
94
        raise NotImplementedError(self.call_with_body_bytes)
 
95
 
 
96
    def call_with_body_readv_array(self, args, body):
 
97
        """Make a remote call with a readv array.
 
98
 
 
99
        :param args: the arguments of this call.
 
100
        :type body: iterable of (start, length) tuples.
 
101
        :param body: the readv ranges to send with this request.
 
102
        """
 
103
        raise NotImplementedError(self.call_with_body_readv_array)
 
104
 
 
105
    def set_headers(self, headers):
 
106
        raise NotImplementedError(self.set_headers)
44
107
 
45
108
 
46
109
class SmartProtocolBase(object):
58
121
        for start, length in offsets:
59
122
            txt.append('%d,%d' % (start, length))
60
123
        return '\n'.join(txt)
61
 
        
 
124
 
62
125
 
63
126
class SmartServerRequestProtocolOne(SmartProtocolBase):
64
127
    """Server-side encoding and decoding logic for smart version 1."""
65
 
    
66
 
    def __init__(self, backing_transport, write_func):
 
128
 
 
129
    def __init__(self, backing_transport, write_func, root_client_path='/',
 
130
            jail_root=None):
67
131
        self._backing_transport = backing_transport
68
 
        self.excess_buffer = ''
 
132
        self._root_client_path = root_client_path
 
133
        self._jail_root = jail_root
 
134
        self.unused_data = ''
69
135
        self._finished = False
70
136
        self.in_buffer = ''
71
 
        self.has_dispatched = False
 
137
        self._has_dispatched = False
72
138
        self.request = None
73
139
        self._body_decoder = None
74
140
        self._write_func = write_func
75
141
 
76
142
    def accept_bytes(self, bytes):
77
143
        """Take bytes, and advance the internal state machine appropriately.
78
 
        
 
144
 
79
145
        :param bytes: must be a byte string
80
146
        """
81
 
        assert isinstance(bytes, str)
 
147
        if not isinstance(bytes, str):
 
148
            raise ValueError(bytes)
82
149
        self.in_buffer += bytes
83
 
        if not self.has_dispatched:
 
150
        if not self._has_dispatched:
84
151
            if '\n' not in self.in_buffer:
85
152
                # no command line yet
86
153
                return
87
 
            self.has_dispatched = True
 
154
            self._has_dispatched = True
88
155
            try:
89
156
                first_line, self.in_buffer = self.in_buffer.split('\n', 1)
90
157
                first_line += '\n'
91
158
                req_args = _decode_tuple(first_line)
92
159
                self.request = request.SmartServerRequestHandler(
93
 
                    self._backing_transport, commands=request.request_handlers)
94
 
                self.request.dispatch_command(req_args[0], req_args[1:])
 
160
                    self._backing_transport, commands=request.request_handlers,
 
161
                    root_client_path=self._root_client_path,
 
162
                    jail_root=self._jail_root)
 
163
                self.request.args_received(req_args)
95
164
                if self.request.finished_reading:
96
165
                    # trivial request
97
 
                    self.excess_buffer = self.in_buffer
 
166
                    self.unused_data = self.in_buffer
98
167
                    self.in_buffer = ''
99
 
                    self._send_response(self.request.response.args,
100
 
                        self.request.response.body)
 
168
                    self._send_response(self.request.response)
101
169
            except KeyboardInterrupt:
102
170
                raise
 
171
            except errors.UnknownSmartMethod, err:
 
172
                protocol_error = errors.SmartProtocolError(
 
173
                    "bad request %r" % (err.verb,))
 
174
                failure = request.FailedSmartServerResponse(
 
175
                    ('error', str(protocol_error)))
 
176
                self._send_response(failure)
 
177
                return
103
178
            except Exception, exception:
104
179
                # everything else: pass to client, flush, and quit
105
 
                self._send_response(('error', str(exception)))
 
180
                log_exception_quietly()
 
181
                self._send_response(request.FailedSmartServerResponse(
 
182
                    ('error', str(exception))))
106
183
                return
107
184
 
108
 
        if self.has_dispatched:
 
185
        if self._has_dispatched:
109
186
            if self._finished:
110
 
                # nothing to do.XXX: this routine should be a single state 
 
187
                # nothing to do.XXX: this routine should be a single state
111
188
                # machine too.
112
 
                self.excess_buffer += self.in_buffer
 
189
                self.unused_data += self.in_buffer
113
190
                self.in_buffer = ''
114
191
                return
115
192
            if self._body_decoder is None:
120
197
            self.request.accept_body(body_data)
121
198
            if self._body_decoder.finished_reading:
122
199
                self.request.end_of_body()
123
 
                assert self.request.finished_reading, \
124
 
                    "no more body, request not finished"
 
200
                if not self.request.finished_reading:
 
201
                    raise AssertionError("no more body, request not finished")
125
202
            if self.request.response is not None:
126
 
                self._send_response(self.request.response.args,
127
 
                    self.request.response.body)
128
 
                self.excess_buffer = self.in_buffer
 
203
                self._send_response(self.request.response)
 
204
                self.unused_data = self.in_buffer
129
205
                self.in_buffer = ''
130
206
            else:
131
 
                assert not self.request.finished_reading, \
132
 
                    "no response and we have finished reading."
 
207
                if self.request.finished_reading:
 
208
                    raise AssertionError(
 
209
                        "no response and we have finished reading.")
133
210
 
134
 
    def _send_response(self, args, body=None):
 
211
    def _send_response(self, response):
135
212
        """Send a smart server response down the output stream."""
136
 
        assert not self._finished, 'response already sent'
 
213
        if self._finished:
 
214
            raise AssertionError('response already sent')
 
215
        args = response.args
 
216
        body = response.body
137
217
        self._finished = True
 
218
        self._write_protocol_version()
 
219
        self._write_success_or_failure_prefix(response)
138
220
        self._write_func(_encode_tuple(args))
139
221
        if body is not None:
140
 
            assert isinstance(body, str), 'body must be a str'
 
222
            if not isinstance(body, str):
 
223
                raise ValueError(body)
141
224
            bytes = self._encode_bulk_data(body)
142
225
            self._write_func(bytes)
143
226
 
 
227
    def _write_protocol_version(self):
 
228
        """Write any prefixes this protocol requires.
 
229
 
 
230
        Version one doesn't send protocol versions.
 
231
        """
 
232
 
 
233
    def _write_success_or_failure_prefix(self, response):
 
234
        """Write the protocol specific success/failure prefix.
 
235
 
 
236
        For SmartServerRequestProtocolOne this is omitted but we
 
237
        call is_successful to ensure that the response is valid.
 
238
        """
 
239
        response.is_successful()
 
240
 
144
241
    def next_read_size(self):
145
242
        if self._finished:
146
243
            return 0
150
247
            return self._body_decoder.next_read_size()
151
248
 
152
249
 
153
 
class LengthPrefixedBodyDecoder(object):
154
 
    """Decodes the length-prefixed bulk data."""
155
 
    
 
250
class SmartServerRequestProtocolTwo(SmartServerRequestProtocolOne):
 
251
    r"""Version two of the server side of the smart protocol.
 
252
 
 
253
    This prefixes responses with the value of RESPONSE_VERSION_TWO.
 
254
    """
 
255
 
 
256
    response_marker = RESPONSE_VERSION_TWO
 
257
    request_marker = REQUEST_VERSION_TWO
 
258
 
 
259
    def _write_success_or_failure_prefix(self, response):
 
260
        """Write the protocol specific success/failure prefix."""
 
261
        if response.is_successful():
 
262
            self._write_func('success\n')
 
263
        else:
 
264
            self._write_func('failed\n')
 
265
 
 
266
    def _write_protocol_version(self):
 
267
        r"""Write any prefixes this protocol requires.
 
268
 
 
269
        Version two sends the value of RESPONSE_VERSION_TWO.
 
270
        """
 
271
        self._write_func(self.response_marker)
 
272
 
 
273
    def _send_response(self, response):
 
274
        """Send a smart server response down the output stream."""
 
275
        if (self._finished):
 
276
            raise AssertionError('response already sent')
 
277
        self._finished = True
 
278
        self._write_protocol_version()
 
279
        self._write_success_or_failure_prefix(response)
 
280
        self._write_func(_encode_tuple(response.args))
 
281
        if response.body is not None:
 
282
            if not isinstance(response.body, str):
 
283
                raise AssertionError('body must be a str')
 
284
            if not (response.body_stream is None):
 
285
                raise AssertionError(
 
286
                    'body_stream and body cannot both be set')
 
287
            bytes = self._encode_bulk_data(response.body)
 
288
            self._write_func(bytes)
 
289
        elif response.body_stream is not None:
 
290
            _send_stream(response.body_stream, self._write_func)
 
291
 
 
292
 
 
293
def _send_stream(stream, write_func):
 
294
    write_func('chunked\n')
 
295
    _send_chunks(stream, write_func)
 
296
    write_func('END\n')
 
297
 
 
298
 
 
299
def _send_chunks(stream, write_func):
 
300
    for chunk in stream:
 
301
        if isinstance(chunk, str):
 
302
            bytes = "%x\n%s" % (len(chunk), chunk)
 
303
            write_func(bytes)
 
304
        elif isinstance(chunk, request.FailedSmartServerResponse):
 
305
            write_func('ERR\n')
 
306
            _send_chunks(chunk.args, write_func)
 
307
            return
 
308
        else:
 
309
            raise errors.BzrError(
 
310
                'Chunks must be str or FailedSmartServerResponse, got %r'
 
311
                % chunk)
 
312
 
 
313
 
 
314
class _NeedMoreBytes(Exception):
 
315
    """Raise this inside a _StatefulDecoder to stop decoding until more bytes
 
316
    have been received.
 
317
    """
 
318
 
 
319
    def __init__(self, count=None):
 
320
        """Constructor.
 
321
 
 
322
        :param count: the total number of bytes needed by the current state.
 
323
            May be None if the number of bytes needed is unknown.
 
324
        """
 
325
        self.count = count
 
326
 
 
327
 
 
328
class _StatefulDecoder(object):
 
329
    """Base class for writing state machines to decode byte streams.
 
330
 
 
331
    Subclasses should provide a self.state_accept attribute that accepts bytes
 
332
    and, if appropriate, updates self.state_accept to a different function.
 
333
    accept_bytes will call state_accept as often as necessary to make sure the
 
334
    state machine has progressed as far as possible before it returns.
 
335
 
 
336
    See ProtocolThreeDecoder for an example subclass.
 
337
    """
 
338
 
156
339
    def __init__(self):
 
340
        self.finished_reading = False
 
341
        self._in_buffer_list = []
 
342
        self._in_buffer_len = 0
 
343
        self.unused_data = ''
157
344
        self.bytes_left = None
158
 
        self.finished_reading = False
159
 
        self.unused_data = ''
160
 
        self.state_accept = self._state_accept_expecting_length
161
 
        self.state_read = self._state_read_no_data
162
 
        self._in_buffer = ''
163
 
        self._trailer_buffer = ''
164
 
    
 
345
        self._number_needed_bytes = None
 
346
 
 
347
    def _get_in_buffer(self):
 
348
        if len(self._in_buffer_list) == 1:
 
349
            return self._in_buffer_list[0]
 
350
        in_buffer = ''.join(self._in_buffer_list)
 
351
        if len(in_buffer) != self._in_buffer_len:
 
352
            raise AssertionError(
 
353
                "Length of buffer did not match expected value: %s != %s"
 
354
                % self._in_buffer_len, len(in_buffer))
 
355
        self._in_buffer_list = [in_buffer]
 
356
        return in_buffer
 
357
 
 
358
    def _get_in_bytes(self, count):
 
359
        """Grab X bytes from the input_buffer.
 
360
 
 
361
        Callers should have already checked that self._in_buffer_len is >
 
362
        count. Note, this does not consume the bytes from the buffer. The
 
363
        caller will still need to call _get_in_buffer() and then
 
364
        _set_in_buffer() if they actually need to consume the bytes.
 
365
        """
 
366
        # check if we can yield the bytes from just the first entry in our list
 
367
        if len(self._in_buffer_list) == 0:
 
368
            raise AssertionError('Callers must be sure we have buffered bytes'
 
369
                ' before calling _get_in_bytes')
 
370
        if len(self._in_buffer_list[0]) > count:
 
371
            return self._in_buffer_list[0][:count]
 
372
        # We can't yield it from the first buffer, so collapse all buffers, and
 
373
        # yield it from that
 
374
        in_buf = self._get_in_buffer()
 
375
        return in_buf[:count]
 
376
 
 
377
    def _set_in_buffer(self, new_buf):
 
378
        if new_buf is not None:
 
379
            self._in_buffer_list = [new_buf]
 
380
            self._in_buffer_len = len(new_buf)
 
381
        else:
 
382
            self._in_buffer_list = []
 
383
            self._in_buffer_len = 0
 
384
 
165
385
    def accept_bytes(self, bytes):
166
386
        """Decode as much of bytes as possible.
167
387
 
172
392
        data will be appended to self.unused_data.
173
393
        """
174
394
        # accept_bytes is allowed to change the state
175
 
        current_state = self.state_accept
176
 
        self.state_accept(bytes)
177
 
        while current_state != self.state_accept:
 
395
        self._number_needed_bytes = None
 
396
        # lsprof puts a very large amount of time on this specific call for
 
397
        # large readv arrays
 
398
        self._in_buffer_list.append(bytes)
 
399
        self._in_buffer_len += len(bytes)
 
400
        try:
 
401
            # Run the function for the current state.
178
402
            current_state = self.state_accept
179
 
            self.state_accept('')
 
403
            self.state_accept()
 
404
            while current_state != self.state_accept:
 
405
                # The current state has changed.  Run the function for the new
 
406
                # current state, so that it can:
 
407
                #   - decode any unconsumed bytes left in a buffer, and
 
408
                #   - signal how many more bytes are expected (via raising
 
409
                #     _NeedMoreBytes).
 
410
                current_state = self.state_accept
 
411
                self.state_accept()
 
412
        except _NeedMoreBytes, e:
 
413
            self._number_needed_bytes = e.count
 
414
 
 
415
 
 
416
class ChunkedBodyDecoder(_StatefulDecoder):
 
417
    """Decoder for chunked body data.
 
418
 
 
419
    This is very similar the HTTP's chunked encoding.  See the description of
 
420
    streamed body data in `doc/developers/network-protocol.txt` for details.
 
421
    """
 
422
 
 
423
    def __init__(self):
 
424
        _StatefulDecoder.__init__(self)
 
425
        self.state_accept = self._state_accept_expecting_header
 
426
        self.chunk_in_progress = None
 
427
        self.chunks = collections.deque()
 
428
        self.error = False
 
429
        self.error_in_progress = None
 
430
 
 
431
    def next_read_size(self):
 
432
        # Note: the shortest possible chunk is 2 bytes: '0\n', and the
 
433
        # end-of-body marker is 4 bytes: 'END\n'.
 
434
        if self.state_accept == self._state_accept_reading_chunk:
 
435
            # We're expecting more chunk content.  So we're expecting at least
 
436
            # the rest of this chunk plus an END chunk.
 
437
            return self.bytes_left + 4
 
438
        elif self.state_accept == self._state_accept_expecting_length:
 
439
            if self._in_buffer_len == 0:
 
440
                # We're expecting a chunk length.  There's at least two bytes
 
441
                # left: a digit plus '\n'.
 
442
                return 2
 
443
            else:
 
444
                # We're in the middle of reading a chunk length.  So there's at
 
445
                # least one byte left, the '\n' that terminates the length.
 
446
                return 1
 
447
        elif self.state_accept == self._state_accept_reading_unused:
 
448
            return 1
 
449
        elif self.state_accept == self._state_accept_expecting_header:
 
450
            return max(0, len('chunked\n') - self._in_buffer_len)
 
451
        else:
 
452
            raise AssertionError("Impossible state: %r" % (self.state_accept,))
 
453
 
 
454
    def read_next_chunk(self):
 
455
        try:
 
456
            return self.chunks.popleft()
 
457
        except IndexError:
 
458
            return None
 
459
 
 
460
    def _extract_line(self):
 
461
        in_buf = self._get_in_buffer()
 
462
        pos = in_buf.find('\n')
 
463
        if pos == -1:
 
464
            # We haven't read a complete line yet, so request more bytes before
 
465
            # we continue.
 
466
            raise _NeedMoreBytes(1)
 
467
        line = in_buf[:pos]
 
468
        # Trim the prefix (including '\n' delimiter) from the _in_buffer.
 
469
        self._set_in_buffer(in_buf[pos+1:])
 
470
        return line
 
471
 
 
472
    def _finished(self):
 
473
        self.unused_data = self._get_in_buffer()
 
474
        self._in_buffer_list = []
 
475
        self._in_buffer_len = 0
 
476
        self.state_accept = self._state_accept_reading_unused
 
477
        if self.error:
 
478
            error_args = tuple(self.error_in_progress)
 
479
            self.chunks.append(request.FailedSmartServerResponse(error_args))
 
480
            self.error_in_progress = None
 
481
        self.finished_reading = True
 
482
 
 
483
    def _state_accept_expecting_header(self):
 
484
        prefix = self._extract_line()
 
485
        if prefix == 'chunked':
 
486
            self.state_accept = self._state_accept_expecting_length
 
487
        else:
 
488
            raise errors.SmartProtocolError(
 
489
                'Bad chunked body header: "%s"' % (prefix,))
 
490
 
 
491
    def _state_accept_expecting_length(self):
 
492
        prefix = self._extract_line()
 
493
        if prefix == 'ERR':
 
494
            self.error = True
 
495
            self.error_in_progress = []
 
496
            self._state_accept_expecting_length()
 
497
            return
 
498
        elif prefix == 'END':
 
499
            # We've read the end-of-body marker.
 
500
            # Any further bytes are unused data, including the bytes left in
 
501
            # the _in_buffer.
 
502
            self._finished()
 
503
            return
 
504
        else:
 
505
            self.bytes_left = int(prefix, 16)
 
506
            self.chunk_in_progress = ''
 
507
            self.state_accept = self._state_accept_reading_chunk
 
508
 
 
509
    def _state_accept_reading_chunk(self):
 
510
        in_buf = self._get_in_buffer()
 
511
        in_buffer_len = len(in_buf)
 
512
        self.chunk_in_progress += in_buf[:self.bytes_left]
 
513
        self._set_in_buffer(in_buf[self.bytes_left:])
 
514
        self.bytes_left -= in_buffer_len
 
515
        if self.bytes_left <= 0:
 
516
            # Finished with chunk
 
517
            self.bytes_left = None
 
518
            if self.error:
 
519
                self.error_in_progress.append(self.chunk_in_progress)
 
520
            else:
 
521
                self.chunks.append(self.chunk_in_progress)
 
522
            self.chunk_in_progress = None
 
523
            self.state_accept = self._state_accept_expecting_length
 
524
 
 
525
    def _state_accept_reading_unused(self):
 
526
        self.unused_data += self._get_in_buffer()
 
527
        self._in_buffer_list = []
 
528
 
 
529
 
 
530
class LengthPrefixedBodyDecoder(_StatefulDecoder):
 
531
    """Decodes the length-prefixed bulk data."""
 
532
 
 
533
    def __init__(self):
 
534
        _StatefulDecoder.__init__(self)
 
535
        self.state_accept = self._state_accept_expecting_length
 
536
        self.state_read = self._state_read_no_data
 
537
        self._body = ''
 
538
        self._trailer_buffer = ''
180
539
 
181
540
    def next_read_size(self):
182
541
        if self.bytes_left is not None:
193
552
        else:
194
553
            # Reading excess data.  Either way, 1 byte at a time is fine.
195
554
            return 1
196
 
        
 
555
 
197
556
    def read_pending_data(self):
198
557
        """Return any pending data that has been decoded."""
199
558
        return self.state_read()
200
559
 
201
 
    def _state_accept_expecting_length(self, bytes):
202
 
        self._in_buffer += bytes
203
 
        pos = self._in_buffer.find('\n')
 
560
    def _state_accept_expecting_length(self):
 
561
        in_buf = self._get_in_buffer()
 
562
        pos = in_buf.find('\n')
204
563
        if pos == -1:
205
564
            return
206
 
        self.bytes_left = int(self._in_buffer[:pos])
207
 
        self._in_buffer = self._in_buffer[pos+1:]
208
 
        self.bytes_left -= len(self._in_buffer)
 
565
        self.bytes_left = int(in_buf[:pos])
 
566
        self._set_in_buffer(in_buf[pos+1:])
209
567
        self.state_accept = self._state_accept_reading_body
210
 
        self.state_read = self._state_read_in_buffer
 
568
        self.state_read = self._state_read_body_buffer
211
569
 
212
 
    def _state_accept_reading_body(self, bytes):
213
 
        self._in_buffer += bytes
214
 
        self.bytes_left -= len(bytes)
 
570
    def _state_accept_reading_body(self):
 
571
        in_buf = self._get_in_buffer()
 
572
        self._body += in_buf
 
573
        self.bytes_left -= len(in_buf)
 
574
        self._set_in_buffer(None)
215
575
        if self.bytes_left <= 0:
216
576
            # Finished with body
217
577
            if self.bytes_left != 0:
218
 
                self._trailer_buffer = self._in_buffer[self.bytes_left:]
219
 
                self._in_buffer = self._in_buffer[:self.bytes_left]
 
578
                self._trailer_buffer = self._body[self.bytes_left:]
 
579
                self._body = self._body[:self.bytes_left]
220
580
            self.bytes_left = None
221
581
            self.state_accept = self._state_accept_reading_trailer
222
 
        
223
 
    def _state_accept_reading_trailer(self, bytes):
224
 
        self._trailer_buffer += bytes
 
582
 
 
583
    def _state_accept_reading_trailer(self):
 
584
        self._trailer_buffer += self._get_in_buffer()
 
585
        self._set_in_buffer(None)
225
586
        # TODO: what if the trailer does not match "done\n"?  Should this raise
226
587
        # a ProtocolViolation exception?
227
588
        if self._trailer_buffer.startswith('done\n'):
228
589
            self.unused_data = self._trailer_buffer[len('done\n'):]
229
590
            self.state_accept = self._state_accept_reading_unused
230
591
            self.finished_reading = True
231
 
    
232
 
    def _state_accept_reading_unused(self, bytes):
233
 
        self.unused_data += bytes
 
592
 
 
593
    def _state_accept_reading_unused(self):
 
594
        self.unused_data += self._get_in_buffer()
 
595
        self._set_in_buffer(None)
234
596
 
235
597
    def _state_read_no_data(self):
236
598
        return ''
237
599
 
238
 
    def _state_read_in_buffer(self):
239
 
        result = self._in_buffer
240
 
        self._in_buffer = ''
 
600
    def _state_read_body_buffer(self):
 
601
        result = self._body
 
602
        self._body = ''
241
603
        return result
242
604
 
243
605
 
244
 
class SmartClientRequestProtocolOne(SmartProtocolBase):
 
606
class SmartClientRequestProtocolOne(SmartProtocolBase, Requester,
 
607
                                    message.ResponseHandler):
245
608
    """The client-side protocol for smart version 1."""
246
609
 
247
610
    def __init__(self, request):
252
615
        """
253
616
        self._request = request
254
617
        self._body_buffer = None
 
618
        self._request_start_time = None
 
619
        self._last_verb = None
 
620
        self._headers = None
 
621
 
 
622
    def set_headers(self, headers):
 
623
        self._headers = dict(headers)
255
624
 
256
625
    def call(self, *args):
257
 
        bytes = _encode_tuple(args)
258
 
        self._request.accept_bytes(bytes)
 
626
        if 'hpss' in debug.debug_flags:
 
627
            mutter('hpss call:   %s', repr(args)[1:-1])
 
628
            if getattr(self._request._medium, 'base', None) is not None:
 
629
                mutter('             (to %s)', self._request._medium.base)
 
630
            self._request_start_time = osutils.timer_func()
 
631
        self._write_args(args)
259
632
        self._request.finished_writing()
 
633
        self._last_verb = args[0]
260
634
 
261
635
    def call_with_body_bytes(self, args, body):
262
636
        """Make a remote call of args with body bytes 'body'.
263
637
 
264
638
        After calling this, call read_response_tuple to find the result out.
265
639
        """
266
 
        bytes = _encode_tuple(args)
267
 
        self._request.accept_bytes(bytes)
 
640
        if 'hpss' in debug.debug_flags:
 
641
            mutter('hpss call w/body: %s (%r...)', repr(args)[1:-1], body[:20])
 
642
            if getattr(self._request._medium, '_path', None) is not None:
 
643
                mutter('                  (to %s)', self._request._medium._path)
 
644
            mutter('              %d bytes', len(body))
 
645
            self._request_start_time = osutils.timer_func()
 
646
            if 'hpssdetail' in debug.debug_flags:
 
647
                mutter('hpss body content: %s', body)
 
648
        self._write_args(args)
268
649
        bytes = self._encode_bulk_data(body)
269
650
        self._request.accept_bytes(bytes)
270
651
        self._request.finished_writing()
 
652
        self._last_verb = args[0]
271
653
 
272
654
    def call_with_body_readv_array(self, args, body):
273
655
        """Make a remote call with a readv array.
274
656
 
275
657
        The body is encoded with one line per readv offset pair. The numbers in
276
 
        each pair are separated by a comma, and no trailing \n is emitted.
 
658
        each pair are separated by a comma, and no trailing \\n is emitted.
277
659
        """
278
 
        bytes = _encode_tuple(args)
279
 
        self._request.accept_bytes(bytes)
 
660
        if 'hpss' in debug.debug_flags:
 
661
            mutter('hpss call w/readv: %s', repr(args)[1:-1])
 
662
            if getattr(self._request._medium, '_path', None) is not None:
 
663
                mutter('                  (to %s)', self._request._medium._path)
 
664
            self._request_start_time = osutils.timer_func()
 
665
        self._write_args(args)
280
666
        readv_bytes = self._serialise_offsets(body)
281
667
        bytes = self._encode_bulk_data(readv_bytes)
282
668
        self._request.accept_bytes(bytes)
283
669
        self._request.finished_writing()
 
670
        if 'hpss' in debug.debug_flags:
 
671
            mutter('              %d bytes in readv request', len(readv_bytes))
 
672
        self._last_verb = args[0]
 
673
 
 
674
    def call_with_body_stream(self, args, stream):
 
675
        # Protocols v1 and v2 don't support body streams.  So it's safe to
 
676
        # assume that a v1/v2 server doesn't support whatever method we're
 
677
        # trying to call with a body stream.
 
678
        self._request.finished_writing()
 
679
        self._request.finished_reading()
 
680
        raise errors.UnknownSmartMethod(args[0])
284
681
 
285
682
    def cancel_read_body(self):
286
683
        """After expecting a body, a response code may indicate one otherwise.
291
688
        """
292
689
        self._request.finished_reading()
293
690
 
294
 
    def read_response_tuple(self, expect_body=False):
295
 
        """Read a response tuple from the wire.
296
 
 
297
 
        This should only be called once.
298
 
        """
 
691
    def _read_response_tuple(self):
299
692
        result = self._recv_tuple()
 
693
        if 'hpss' in debug.debug_flags:
 
694
            if self._request_start_time is not None:
 
695
                mutter('   result:   %6.3fs  %s',
 
696
                       osutils.timer_func() - self._request_start_time,
 
697
                       repr(result)[1:-1])
 
698
                self._request_start_time = None
 
699
            else:
 
700
                mutter('   result:   %s', repr(result)[1:-1])
 
701
        return result
 
702
 
 
703
    def read_response_tuple(self, expect_body=False):
 
704
        """Read a response tuple from the wire.
 
705
 
 
706
        This should only be called once.
 
707
        """
 
708
        result = self._read_response_tuple()
 
709
        self._response_is_unknown_method(result)
 
710
        self._raise_args_if_error(result)
300
711
        if not expect_body:
301
712
            self._request.finished_reading()
302
713
        return result
303
714
 
 
715
    def _raise_args_if_error(self, result_tuple):
 
716
        # Later protocol versions have an explicit flag in the protocol to say
 
717
        # if an error response is "failed" or not.  In version 1 we don't have
 
718
        # that luxury.  So here is a complete list of errors that can be
 
719
        # returned in response to existing version 1 smart requests.  Responses
 
720
        # starting with these codes are always "failed" responses.
 
721
        v1_error_codes = [
 
722
            'norepository',
 
723
            'NoSuchFile',
 
724
            'FileExists',
 
725
            'DirectoryNotEmpty',
 
726
            'ShortReadvError',
 
727
            'UnicodeEncodeError',
 
728
            'UnicodeDecodeError',
 
729
            'ReadOnlyError',
 
730
            'nobranch',
 
731
            'NoSuchRevision',
 
732
            'nosuchrevision',
 
733
            'LockContention',
 
734
            'UnlockableTransport',
 
735
            'LockFailed',
 
736
            'TokenMismatch',
 
737
            'ReadError',
 
738
            'PermissionDenied',
 
739
            ]
 
740
        if result_tuple[0] in v1_error_codes:
 
741
            self._request.finished_reading()
 
742
            raise errors.ErrorFromSmartServer(result_tuple)
 
743
 
 
744
    def _response_is_unknown_method(self, result_tuple):
 
745
        """Raise UnexpectedSmartServerResponse if the response is an 'unknonwn
 
746
        method' response to the request.
 
747
 
 
748
        :param response: The response from a smart client call_expecting_body
 
749
            call.
 
750
        :param verb: The verb used in that call.
 
751
        :raises: UnexpectedSmartServerResponse
 
752
        """
 
753
        if (result_tuple == ('error', "Generic bzr smart protocol error: "
 
754
                "bad request '%s'" % self._last_verb) or
 
755
              result_tuple == ('error', "Generic bzr smart protocol error: "
 
756
                "bad request u'%s'" % self._last_verb)):
 
757
            # The response will have no body, so we've finished reading.
 
758
            self._request.finished_reading()
 
759
            raise errors.UnknownSmartMethod(self._last_verb)
 
760
 
304
761
    def read_body_bytes(self, count=-1):
305
762
        """Read bytes from the body, decoding into a byte stream.
306
 
        
307
 
        We read all bytes at once to ensure we've checked the trailer for 
 
763
 
 
764
        We read all bytes at once to ensure we've checked the trailer for
308
765
        errors, and then feed the buffer back as read_body_bytes is called.
309
766
        """
310
767
        if self._body_buffer is not None:
312
769
        _body_decoder = LengthPrefixedBodyDecoder()
313
770
 
314
771
        while not _body_decoder.finished_reading:
315
 
            bytes_wanted = _body_decoder.next_read_size()
316
 
            bytes = self._request.read_bytes(bytes_wanted)
 
772
            bytes = self._request.read_bytes(_body_decoder.next_read_size())
 
773
            if bytes == '':
 
774
                # end of file encountered reading from server
 
775
                raise errors.ConnectionReset(
 
776
                    "Connection lost while reading response body.")
317
777
            _body_decoder.accept_bytes(bytes)
318
778
        self._request.finished_reading()
319
779
        self._body_buffer = StringIO(_body_decoder.read_pending_data())
320
780
        # XXX: TODO check the trailer result.
 
781
        if 'hpss' in debug.debug_flags:
 
782
            mutter('              %d body bytes read',
 
783
                   len(self._body_buffer.getvalue()))
321
784
        return self._body_buffer.read(count)
322
785
 
323
786
    def _recv_tuple(self):
324
787
        """Receive a tuple from the medium request."""
325
 
        line = ''
326
 
        while not line or line[-1] != '\n':
327
 
            # TODO: this is inefficient - but tuples are short.
328
 
            new_char = self._request.read_bytes(1)
329
 
            line += new_char
330
 
            assert new_char != '', "end of file reading from server."
331
 
        return _decode_tuple(line)
 
788
        return _decode_tuple(self._request.read_line())
332
789
 
333
790
    def query_version(self):
334
791
        """Return protocol version number of the server."""
336
793
        resp = self.read_response_tuple()
337
794
        if resp == ('ok', '1'):
338
795
            return 1
 
796
        elif resp == ('ok', '2'):
 
797
            return 2
339
798
        else:
340
799
            raise errors.SmartProtocolError("bad response %r" % (resp,))
341
800
 
342
 
 
 
801
    def _write_args(self, args):
 
802
        self._write_protocol_version()
 
803
        bytes = _encode_tuple(args)
 
804
        self._request.accept_bytes(bytes)
 
805
 
 
806
    def _write_protocol_version(self):
 
807
        """Write any prefixes this protocol requires.
 
808
 
 
809
        Version one doesn't send protocol versions.
 
810
        """
 
811
 
 
812
 
 
813
class SmartClientRequestProtocolTwo(SmartClientRequestProtocolOne):
 
814
    """Version two of the client side of the smart protocol.
 
815
 
 
816
    This prefixes the request with the value of REQUEST_VERSION_TWO.
 
817
    """
 
818
 
 
819
    response_marker = RESPONSE_VERSION_TWO
 
820
    request_marker = REQUEST_VERSION_TWO
 
821
 
 
822
    def read_response_tuple(self, expect_body=False):
 
823
        """Read a response tuple from the wire.
 
824
 
 
825
        This should only be called once.
 
826
        """
 
827
        version = self._request.read_line()
 
828
        if version != self.response_marker:
 
829
            self._request.finished_reading()
 
830
            raise errors.UnexpectedProtocolVersionMarker(version)
 
831
        response_status = self._request.read_line()
 
832
        result = SmartClientRequestProtocolOne._read_response_tuple(self)
 
833
        self._response_is_unknown_method(result)
 
834
        if response_status == 'success\n':
 
835
            self.response_status = True
 
836
            if not expect_body:
 
837
                self._request.finished_reading()
 
838
            return result
 
839
        elif response_status == 'failed\n':
 
840
            self.response_status = False
 
841
            self._request.finished_reading()
 
842
            raise errors.ErrorFromSmartServer(result)
 
843
        else:
 
844
            raise errors.SmartProtocolError(
 
845
                'bad protocol status %r' % response_status)
 
846
 
 
847
    def _write_protocol_version(self):
 
848
        """Write any prefixes this protocol requires.
 
849
 
 
850
        Version two sends the value of REQUEST_VERSION_TWO.
 
851
        """
 
852
        self._request.accept_bytes(self.request_marker)
 
853
 
 
854
    def read_streamed_body(self):
 
855
        """Read bytes from the body, decoding into a byte stream.
 
856
        """
 
857
        # Read no more than 64k at a time so that we don't risk error 10055 (no
 
858
        # buffer space available) on Windows.
 
859
        _body_decoder = ChunkedBodyDecoder()
 
860
        while not _body_decoder.finished_reading:
 
861
            bytes = self._request.read_bytes(_body_decoder.next_read_size())
 
862
            if bytes == '':
 
863
                # end of file encountered reading from server
 
864
                raise errors.ConnectionReset(
 
865
                    "Connection lost while reading streamed body.")
 
866
            _body_decoder.accept_bytes(bytes)
 
867
            for body_bytes in iter(_body_decoder.read_next_chunk, None):
 
868
                if 'hpss' in debug.debug_flags and type(body_bytes) is str:
 
869
                    mutter('              %d byte chunk read',
 
870
                           len(body_bytes))
 
871
                yield body_bytes
 
872
        self._request.finished_reading()
 
873
 
 
874
 
 
875
def build_server_protocol_three(backing_transport, write_func,
 
876
                                root_client_path, jail_root=None):
 
877
    request_handler = request.SmartServerRequestHandler(
 
878
        backing_transport, commands=request.request_handlers,
 
879
        root_client_path=root_client_path, jail_root=jail_root)
 
880
    responder = ProtocolThreeResponder(write_func)
 
881
    message_handler = message.ConventionalRequestHandler(request_handler, responder)
 
882
    return ProtocolThreeDecoder(message_handler)
 
883
 
 
884
 
 
885
class ProtocolThreeDecoder(_StatefulDecoder):
 
886
 
 
887
    response_marker = RESPONSE_VERSION_THREE
 
888
    request_marker = REQUEST_VERSION_THREE
 
889
 
 
890
    def __init__(self, message_handler, expect_version_marker=False):
 
891
        _StatefulDecoder.__init__(self)
 
892
        self._has_dispatched = False
 
893
        # Initial state
 
894
        if expect_version_marker:
 
895
            self.state_accept = self._state_accept_expecting_protocol_version
 
896
            # We're expecting at least the protocol version marker + some
 
897
            # headers.
 
898
            self._number_needed_bytes = len(MESSAGE_VERSION_THREE) + 4
 
899
        else:
 
900
            self.state_accept = self._state_accept_expecting_headers
 
901
            self._number_needed_bytes = 4
 
902
        self.decoding_failed = False
 
903
        self.request_handler = self.message_handler = message_handler
 
904
 
 
905
    def accept_bytes(self, bytes):
 
906
        self._number_needed_bytes = None
 
907
        try:
 
908
            _StatefulDecoder.accept_bytes(self, bytes)
 
909
        except KeyboardInterrupt:
 
910
            raise
 
911
        except errors.SmartMessageHandlerError, exception:
 
912
            # We do *not* set self.decoding_failed here.  The message handler
 
913
            # has raised an error, but the decoder is still able to parse bytes
 
914
            # and determine when this message ends.
 
915
            if not isinstance(exception.exc_value, errors.UnknownSmartMethod):
 
916
                log_exception_quietly()
 
917
            self.message_handler.protocol_error(exception.exc_value)
 
918
            # The state machine is ready to continue decoding, but the
 
919
            # exception has interrupted the loop that runs the state machine.
 
920
            # So we call accept_bytes again to restart it.
 
921
            self.accept_bytes('')
 
922
        except Exception, exception:
 
923
            # The decoder itself has raised an exception.  We cannot continue
 
924
            # decoding.
 
925
            self.decoding_failed = True
 
926
            if isinstance(exception, errors.UnexpectedProtocolVersionMarker):
 
927
                # This happens during normal operation when the client tries a
 
928
                # protocol version the server doesn't understand, so no need to
 
929
                # log a traceback every time.
 
930
                # Note that this can only happen when
 
931
                # expect_version_marker=True, which is only the case on the
 
932
                # client side.
 
933
                pass
 
934
            else:
 
935
                log_exception_quietly()
 
936
            self.message_handler.protocol_error(exception)
 
937
 
 
938
    def _extract_length_prefixed_bytes(self):
 
939
        if self._in_buffer_len < 4:
 
940
            # A length prefix by itself is 4 bytes, and we don't even have that
 
941
            # many yet.
 
942
            raise _NeedMoreBytes(4)
 
943
        (length,) = struct.unpack('!L', self._get_in_bytes(4))
 
944
        end_of_bytes = 4 + length
 
945
        if self._in_buffer_len < end_of_bytes:
 
946
            # We haven't yet read as many bytes as the length-prefix says there
 
947
            # are.
 
948
            raise _NeedMoreBytes(end_of_bytes)
 
949
        # Extract the bytes from the buffer.
 
950
        in_buf = self._get_in_buffer()
 
951
        bytes = in_buf[4:end_of_bytes]
 
952
        self._set_in_buffer(in_buf[end_of_bytes:])
 
953
        return bytes
 
954
 
 
955
    def _extract_prefixed_bencoded_data(self):
 
956
        prefixed_bytes = self._extract_length_prefixed_bytes()
 
957
        try:
 
958
            decoded = bdecode_as_tuple(prefixed_bytes)
 
959
        except ValueError:
 
960
            raise errors.SmartProtocolError(
 
961
                'Bytes %r not bencoded' % (prefixed_bytes,))
 
962
        return decoded
 
963
 
 
964
    def _extract_single_byte(self):
 
965
        if self._in_buffer_len == 0:
 
966
            # The buffer is empty
 
967
            raise _NeedMoreBytes(1)
 
968
        in_buf = self._get_in_buffer()
 
969
        one_byte = in_buf[0]
 
970
        self._set_in_buffer(in_buf[1:])
 
971
        return one_byte
 
972
 
 
973
    def _state_accept_expecting_protocol_version(self):
 
974
        needed_bytes = len(MESSAGE_VERSION_THREE) - self._in_buffer_len
 
975
        in_buf = self._get_in_buffer()
 
976
        if needed_bytes > 0:
 
977
            # We don't have enough bytes to check if the protocol version
 
978
            # marker is right.  But we can check if it is already wrong by
 
979
            # checking that the start of MESSAGE_VERSION_THREE matches what
 
980
            # we've read so far.
 
981
            # [In fact, if the remote end isn't bzr we might never receive
 
982
            # len(MESSAGE_VERSION_THREE) bytes.  So if the bytes we have so far
 
983
            # are wrong then we should just raise immediately rather than
 
984
            # stall.]
 
985
            if not MESSAGE_VERSION_THREE.startswith(in_buf):
 
986
                # We have enough bytes to know the protocol version is wrong
 
987
                raise errors.UnexpectedProtocolVersionMarker(in_buf)
 
988
            raise _NeedMoreBytes(len(MESSAGE_VERSION_THREE))
 
989
        if not in_buf.startswith(MESSAGE_VERSION_THREE):
 
990
            raise errors.UnexpectedProtocolVersionMarker(in_buf)
 
991
        self._set_in_buffer(in_buf[len(MESSAGE_VERSION_THREE):])
 
992
        self.state_accept = self._state_accept_expecting_headers
 
993
 
 
994
    def _state_accept_expecting_headers(self):
 
995
        decoded = self._extract_prefixed_bencoded_data()
 
996
        if type(decoded) is not dict:
 
997
            raise errors.SmartProtocolError(
 
998
                'Header object %r is not a dict' % (decoded,))
 
999
        self.state_accept = self._state_accept_expecting_message_part
 
1000
        try:
 
1001
            self.message_handler.headers_received(decoded)
 
1002
        except:
 
1003
            raise errors.SmartMessageHandlerError(sys.exc_info())
 
1004
 
 
1005
    def _state_accept_expecting_message_part(self):
 
1006
        message_part_kind = self._extract_single_byte()
 
1007
        if message_part_kind == 'o':
 
1008
            self.state_accept = self._state_accept_expecting_one_byte
 
1009
        elif message_part_kind == 's':
 
1010
            self.state_accept = self._state_accept_expecting_structure
 
1011
        elif message_part_kind == 'b':
 
1012
            self.state_accept = self._state_accept_expecting_bytes
 
1013
        elif message_part_kind == 'e':
 
1014
            self.done()
 
1015
        else:
 
1016
            raise errors.SmartProtocolError(
 
1017
                'Bad message kind byte: %r' % (message_part_kind,))
 
1018
 
 
1019
    def _state_accept_expecting_one_byte(self):
 
1020
        byte = self._extract_single_byte()
 
1021
        self.state_accept = self._state_accept_expecting_message_part
 
1022
        try:
 
1023
            self.message_handler.byte_part_received(byte)
 
1024
        except:
 
1025
            raise errors.SmartMessageHandlerError(sys.exc_info())
 
1026
 
 
1027
    def _state_accept_expecting_bytes(self):
 
1028
        # XXX: this should not buffer whole message part, but instead deliver
 
1029
        # the bytes as they arrive.
 
1030
        prefixed_bytes = self._extract_length_prefixed_bytes()
 
1031
        self.state_accept = self._state_accept_expecting_message_part
 
1032
        try:
 
1033
            self.message_handler.bytes_part_received(prefixed_bytes)
 
1034
        except:
 
1035
            raise errors.SmartMessageHandlerError(sys.exc_info())
 
1036
 
 
1037
    def _state_accept_expecting_structure(self):
 
1038
        structure = self._extract_prefixed_bencoded_data()
 
1039
        self.state_accept = self._state_accept_expecting_message_part
 
1040
        try:
 
1041
            self.message_handler.structure_part_received(structure)
 
1042
        except:
 
1043
            raise errors.SmartMessageHandlerError(sys.exc_info())
 
1044
 
 
1045
    def done(self):
 
1046
        self.unused_data = self._get_in_buffer()
 
1047
        self._set_in_buffer(None)
 
1048
        self.state_accept = self._state_accept_reading_unused
 
1049
        try:
 
1050
            self.message_handler.end_received()
 
1051
        except:
 
1052
            raise errors.SmartMessageHandlerError(sys.exc_info())
 
1053
 
 
1054
    def _state_accept_reading_unused(self):
 
1055
        self.unused_data += self._get_in_buffer()
 
1056
        self._set_in_buffer(None)
 
1057
 
 
1058
    def next_read_size(self):
 
1059
        if self.state_accept == self._state_accept_reading_unused:
 
1060
            return 0
 
1061
        elif self.decoding_failed:
 
1062
            # An exception occured while processing this message, probably from
 
1063
            # self.message_handler.  We're not sure that this state machine is
 
1064
            # in a consistent state, so just signal that we're done (i.e. give
 
1065
            # up).
 
1066
            return 0
 
1067
        else:
 
1068
            if self._number_needed_bytes is not None:
 
1069
                return self._number_needed_bytes - self._in_buffer_len
 
1070
            else:
 
1071
                raise AssertionError("don't know how many bytes are expected!")
 
1072
 
 
1073
 
 
1074
class _ProtocolThreeEncoder(object):
 
1075
 
 
1076
    response_marker = request_marker = MESSAGE_VERSION_THREE
 
1077
    BUFFER_SIZE = 1024*1024 # 1 MiB buffer before flushing
 
1078
 
 
1079
    def __init__(self, write_func):
 
1080
        self._buf = []
 
1081
        self._buf_len = 0
 
1082
        self._real_write_func = write_func
 
1083
 
 
1084
    def _write_func(self, bytes):
 
1085
        # TODO: Another possibility would be to turn this into an async model.
 
1086
        #       Where we let another thread know that we have some bytes if
 
1087
        #       they want it, but we don't actually block for it
 
1088
        #       Note that osutils.send_all always sends 64kB chunks anyway, so
 
1089
        #       we might just push out smaller bits at a time?
 
1090
        self._buf.append(bytes)
 
1091
        self._buf_len += len(bytes)
 
1092
        if self._buf_len > self.BUFFER_SIZE:
 
1093
            self.flush()
 
1094
 
 
1095
    def flush(self):
 
1096
        if self._buf:
 
1097
            self._real_write_func(''.join(self._buf))
 
1098
            del self._buf[:]
 
1099
            self._buf_len = 0
 
1100
 
 
1101
    def _serialise_offsets(self, offsets):
 
1102
        """Serialise a readv offset list."""
 
1103
        txt = []
 
1104
        for start, length in offsets:
 
1105
            txt.append('%d,%d' % (start, length))
 
1106
        return '\n'.join(txt)
 
1107
 
 
1108
    def _write_protocol_version(self):
 
1109
        self._write_func(MESSAGE_VERSION_THREE)
 
1110
 
 
1111
    def _write_prefixed_bencode(self, structure):
 
1112
        bytes = bencode(structure)
 
1113
        self._write_func(struct.pack('!L', len(bytes)))
 
1114
        self._write_func(bytes)
 
1115
 
 
1116
    def _write_headers(self, headers):
 
1117
        self._write_prefixed_bencode(headers)
 
1118
 
 
1119
    def _write_structure(self, args):
 
1120
        self._write_func('s')
 
1121
        utf8_args = []
 
1122
        for arg in args:
 
1123
            if type(arg) is unicode:
 
1124
                utf8_args.append(arg.encode('utf8'))
 
1125
            else:
 
1126
                utf8_args.append(arg)
 
1127
        self._write_prefixed_bencode(utf8_args)
 
1128
 
 
1129
    def _write_end(self):
 
1130
        self._write_func('e')
 
1131
        self.flush()
 
1132
 
 
1133
    def _write_prefixed_body(self, bytes):
 
1134
        self._write_func('b')
 
1135
        self._write_func(struct.pack('!L', len(bytes)))
 
1136
        self._write_func(bytes)
 
1137
 
 
1138
    def _write_chunked_body_start(self):
 
1139
        self._write_func('oC')
 
1140
 
 
1141
    def _write_error_status(self):
 
1142
        self._write_func('oE')
 
1143
 
 
1144
    def _write_success_status(self):
 
1145
        self._write_func('oS')
 
1146
 
 
1147
 
 
1148
class ProtocolThreeResponder(_ProtocolThreeEncoder):
 
1149
 
 
1150
    def __init__(self, write_func):
 
1151
        _ProtocolThreeEncoder.__init__(self, write_func)
 
1152
        self.response_sent = False
 
1153
        self._headers = {'Software version': bzrlib.__version__}
 
1154
        if 'hpss' in debug.debug_flags:
 
1155
            self._thread_id = thread.get_ident()
 
1156
            self._response_start_time = None
 
1157
 
 
1158
    def _trace(self, action, message, extra_bytes=None, include_time=False):
 
1159
        if self._response_start_time is None:
 
1160
            self._response_start_time = osutils.timer_func()
 
1161
        if include_time:
 
1162
            t = '%5.3fs ' % (time.clock() - self._response_start_time)
 
1163
        else:
 
1164
            t = ''
 
1165
        if extra_bytes is None:
 
1166
            extra = ''
 
1167
        else:
 
1168
            extra = ' ' + repr(extra_bytes[:40])
 
1169
            if len(extra) > 33:
 
1170
                extra = extra[:29] + extra[-1] + '...'
 
1171
        mutter('%12s: [%s] %s%s%s'
 
1172
               % (action, self._thread_id, t, message, extra))
 
1173
 
 
1174
    def send_error(self, exception):
 
1175
        if self.response_sent:
 
1176
            raise AssertionError(
 
1177
                "send_error(%s) called, but response already sent."
 
1178
                % (exception,))
 
1179
        if isinstance(exception, errors.UnknownSmartMethod):
 
1180
            failure = request.FailedSmartServerResponse(
 
1181
                ('UnknownMethod', exception.verb))
 
1182
            self.send_response(failure)
 
1183
            return
 
1184
        if 'hpss' in debug.debug_flags:
 
1185
            self._trace('error', str(exception))
 
1186
        self.response_sent = True
 
1187
        self._write_protocol_version()
 
1188
        self._write_headers(self._headers)
 
1189
        self._write_error_status()
 
1190
        self._write_structure(('error', str(exception)))
 
1191
        self._write_end()
 
1192
 
 
1193
    def send_response(self, response):
 
1194
        if self.response_sent:
 
1195
            raise AssertionError(
 
1196
                "send_response(%r) called, but response already sent."
 
1197
                % (response,))
 
1198
        self.response_sent = True
 
1199
        self._write_protocol_version()
 
1200
        self._write_headers(self._headers)
 
1201
        if response.is_successful():
 
1202
            self._write_success_status()
 
1203
        else:
 
1204
            self._write_error_status()
 
1205
        if 'hpss' in debug.debug_flags:
 
1206
            self._trace('response', repr(response.args))
 
1207
        self._write_structure(response.args)
 
1208
        if response.body is not None:
 
1209
            self._write_prefixed_body(response.body)
 
1210
            if 'hpss' in debug.debug_flags:
 
1211
                self._trace('body', '%d bytes' % (len(response.body),),
 
1212
                            response.body, include_time=True)
 
1213
        elif response.body_stream is not None:
 
1214
            count = num_bytes = 0
 
1215
            first_chunk = None
 
1216
            for exc_info, chunk in _iter_with_errors(response.body_stream):
 
1217
                count += 1
 
1218
                if exc_info is not None:
 
1219
                    self._write_error_status()
 
1220
                    error_struct = request._translate_error(exc_info[1])
 
1221
                    self._write_structure(error_struct)
 
1222
                    break
 
1223
                else:
 
1224
                    if isinstance(chunk, request.FailedSmartServerResponse):
 
1225
                        self._write_error_status()
 
1226
                        self._write_structure(chunk.args)
 
1227
                        break
 
1228
                    num_bytes += len(chunk)
 
1229
                    if first_chunk is None:
 
1230
                        first_chunk = chunk
 
1231
                    self._write_prefixed_body(chunk)
 
1232
                    self.flush()
 
1233
                    if 'hpssdetail' in debug.debug_flags:
 
1234
                        # Not worth timing separately, as _write_func is
 
1235
                        # actually buffered
 
1236
                        self._trace('body chunk',
 
1237
                                    '%d bytes' % (len(chunk),),
 
1238
                                    chunk, suppress_time=True)
 
1239
            if 'hpss' in debug.debug_flags:
 
1240
                self._trace('body stream',
 
1241
                            '%d bytes %d chunks' % (num_bytes, count),
 
1242
                            first_chunk)
 
1243
        self._write_end()
 
1244
        if 'hpss' in debug.debug_flags:
 
1245
            self._trace('response end', '', include_time=True)
 
1246
 
 
1247
 
 
1248
def _iter_with_errors(iterable):
 
1249
    """Handle errors from iterable.next().
 
1250
 
 
1251
    Use like::
 
1252
 
 
1253
        for exc_info, value in _iter_with_errors(iterable):
 
1254
            ...
 
1255
 
 
1256
    This is a safer alternative to::
 
1257
 
 
1258
        try:
 
1259
            for value in iterable:
 
1260
               ...
 
1261
        except:
 
1262
            ...
 
1263
 
 
1264
    Because the latter will catch errors from the for-loop body, not just
 
1265
    iterable.next()
 
1266
 
 
1267
    If an error occurs, exc_info will be a exc_info tuple, and the generator
 
1268
    will terminate.  Otherwise exc_info will be None, and value will be the
 
1269
    value from iterable.next().  Note that KeyboardInterrupt and SystemExit
 
1270
    will not be itercepted.
 
1271
    """
 
1272
    iterator = iter(iterable)
 
1273
    while True:
 
1274
        try:
 
1275
            yield None, iterator.next()
 
1276
        except StopIteration:
 
1277
            return
 
1278
        except (KeyboardInterrupt, SystemExit):
 
1279
            raise
 
1280
        except Exception:
 
1281
            mutter('_iter_with_errors caught error')
 
1282
            log_exception_quietly()
 
1283
            yield sys.exc_info(), None
 
1284
            return
 
1285
 
 
1286
 
 
1287
class ProtocolThreeRequester(_ProtocolThreeEncoder, Requester):
 
1288
 
 
1289
    def __init__(self, medium_request):
 
1290
        _ProtocolThreeEncoder.__init__(self, medium_request.accept_bytes)
 
1291
        self._medium_request = medium_request
 
1292
        self._headers = {}
 
1293
        self.body_stream_started = None
 
1294
 
 
1295
    def set_headers(self, headers):
 
1296
        self._headers = headers.copy()
 
1297
 
 
1298
    def call(self, *args):
 
1299
        if 'hpss' in debug.debug_flags:
 
1300
            mutter('hpss call:   %s', repr(args)[1:-1])
 
1301
            base = getattr(self._medium_request._medium, 'base', None)
 
1302
            if base is not None:
 
1303
                mutter('             (to %s)', base)
 
1304
            self._request_start_time = osutils.timer_func()
 
1305
        self._write_protocol_version()
 
1306
        self._write_headers(self._headers)
 
1307
        self._write_structure(args)
 
1308
        self._write_end()
 
1309
        self._medium_request.finished_writing()
 
1310
 
 
1311
    def call_with_body_bytes(self, args, body):
 
1312
        """Make a remote call of args with body bytes 'body'.
 
1313
 
 
1314
        After calling this, call read_response_tuple to find the result out.
 
1315
        """
 
1316
        if 'hpss' in debug.debug_flags:
 
1317
            mutter('hpss call w/body: %s (%r...)', repr(args)[1:-1], body[:20])
 
1318
            path = getattr(self._medium_request._medium, '_path', None)
 
1319
            if path is not None:
 
1320
                mutter('                  (to %s)', path)
 
1321
            mutter('              %d bytes', len(body))
 
1322
            self._request_start_time = osutils.timer_func()
 
1323
        self._write_protocol_version()
 
1324
        self._write_headers(self._headers)
 
1325
        self._write_structure(args)
 
1326
        self._write_prefixed_body(body)
 
1327
        self._write_end()
 
1328
        self._medium_request.finished_writing()
 
1329
 
 
1330
    def call_with_body_readv_array(self, args, body):
 
1331
        """Make a remote call with a readv array.
 
1332
 
 
1333
        The body is encoded with one line per readv offset pair. The numbers in
 
1334
        each pair are separated by a comma, and no trailing \\n is emitted.
 
1335
        """
 
1336
        if 'hpss' in debug.debug_flags:
 
1337
            mutter('hpss call w/readv: %s', repr(args)[1:-1])
 
1338
            path = getattr(self._medium_request._medium, '_path', None)
 
1339
            if path is not None:
 
1340
                mutter('                  (to %s)', path)
 
1341
            self._request_start_time = osutils.timer_func()
 
1342
        self._write_protocol_version()
 
1343
        self._write_headers(self._headers)
 
1344
        self._write_structure(args)
 
1345
        readv_bytes = self._serialise_offsets(body)
 
1346
        if 'hpss' in debug.debug_flags:
 
1347
            mutter('              %d bytes in readv request', len(readv_bytes))
 
1348
        self._write_prefixed_body(readv_bytes)
 
1349
        self._write_end()
 
1350
        self._medium_request.finished_writing()
 
1351
 
 
1352
    def call_with_body_stream(self, args, stream):
 
1353
        if 'hpss' in debug.debug_flags:
 
1354
            mutter('hpss call w/body stream: %r', args)
 
1355
            path = getattr(self._medium_request._medium, '_path', None)
 
1356
            if path is not None:
 
1357
                mutter('                  (to %s)', path)
 
1358
            self._request_start_time = osutils.timer_func()
 
1359
        self.body_stream_started = False
 
1360
        self._write_protocol_version()
 
1361
        self._write_headers(self._headers)
 
1362
        self._write_structure(args)
 
1363
        # TODO: notice if the server has sent an early error reply before we
 
1364
        #       have finished sending the stream.  We would notice at the end
 
1365
        #       anyway, but if the medium can deliver it early then it's good
 
1366
        #       to short-circuit the whole request...
 
1367
        # Provoke any ConnectionReset failures before we start the body stream.
 
1368
        self.flush()
 
1369
        self.body_stream_started = True
 
1370
        for exc_info, part in _iter_with_errors(stream):
 
1371
            if exc_info is not None:
 
1372
                # Iterating the stream failed.  Cleanly abort the request.
 
1373
                self._write_error_status()
 
1374
                # Currently the client unconditionally sends ('error',) as the
 
1375
                # error args.
 
1376
                self._write_structure(('error',))
 
1377
                self._write_end()
 
1378
                self._medium_request.finished_writing()
 
1379
                raise exc_info[0], exc_info[1], exc_info[2]
 
1380
            else:
 
1381
                self._write_prefixed_body(part)
 
1382
                self.flush()
 
1383
        self._write_end()
 
1384
        self._medium_request.finished_writing()
343
1385