~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/smart/protocol.py

  • Committer: Patch Queue Manager
  • Date: 2015-09-30 16:43:21 UTC
  • mfrom: (6603.2.2 fix-keep-dirty)
  • Revision ID: pqm@pqm.ubuntu.com-20150930164321-ct2v2qnmvimqt8qf
(vila) Avoid associating dirty patch headers with the previous file in the
 patch. (Colin Watson)

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2006, 2007 Canonical Ltd
 
1
# Copyright (C) 2006-2010 Canonical Ltd
2
2
#
3
3
# This program is free software; you can redistribute it and/or modify
4
4
# it under the terms of the GNU General Public License as published by
12
12
#
13
13
# You should have received a copy of the GNU General Public License
14
14
# along with this program; if not, write to the Free Software
15
 
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 
15
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
16
16
 
17
17
"""Wire-level encoding and decoding of requests and responses for the smart
18
18
client and server.
19
19
"""
20
20
 
 
21
from __future__ import absolute_import
 
22
 
 
23
import collections
21
24
from cStringIO import StringIO
 
25
import struct
 
26
import sys
 
27
import thread
22
28
import time
23
29
 
24
 
from bzrlib import debug
25
 
from bzrlib import errors
26
 
from bzrlib.smart import request
 
30
import bzrlib
 
31
from bzrlib import (
 
32
    debug,
 
33
    errors,
 
34
    osutils,
 
35
    )
 
36
from bzrlib.smart import message, request
27
37
from bzrlib.trace import log_exception_quietly, mutter
 
38
from bzrlib.bencode import bdecode_as_tuple, bencode
28
39
 
29
40
 
30
41
# Protocol version strings.  These are sent as prefixes of bzr requests and
33
44
REQUEST_VERSION_TWO = 'bzr request 2\n'
34
45
RESPONSE_VERSION_TWO = 'bzr response 2\n'
35
46
 
 
47
MESSAGE_VERSION_THREE = 'bzr message 3 (bzr 1.6)\n'
 
48
RESPONSE_VERSION_THREE = REQUEST_VERSION_THREE = MESSAGE_VERSION_THREE
 
49
 
36
50
 
37
51
def _recv_tuple(from_file):
38
52
    req_line = from_file.readline()
40
54
 
41
55
 
42
56
def _decode_tuple(req_line):
43
 
    if req_line == None or req_line == '':
 
57
    if req_line is None or req_line == '':
44
58
        return None
45
59
    if req_line[-1] != '\n':
46
60
        raise errors.SmartProtocolError("request %r not terminated" % req_line)
49
63
 
50
64
def _encode_tuple(args):
51
65
    """Encode the tuple args to a bytestream."""
52
 
    return '\x01'.join(args) + '\n'
 
66
    joined = '\x01'.join(args) + '\n'
 
67
    if type(joined) is unicode:
 
68
        # XXX: We should fix things so this never happens!  -AJB, 20100304
 
69
        mutter('response args contain unicode, should be only bytes: %r',
 
70
               joined)
 
71
        joined = joined.encode('ascii')
 
72
    return joined
 
73
 
 
74
 
 
75
class Requester(object):
 
76
    """Abstract base class for an object that can issue requests on a smart
 
77
    medium.
 
78
    """
 
79
 
 
80
    def call(self, *args):
 
81
        """Make a remote call.
 
82
 
 
83
        :param args: the arguments of this call.
 
84
        """
 
85
        raise NotImplementedError(self.call)
 
86
 
 
87
    def call_with_body_bytes(self, args, body):
 
88
        """Make a remote call with a body.
 
89
 
 
90
        :param args: the arguments of this call.
 
91
        :type body: str
 
92
        :param body: the body to send with the request.
 
93
        """
 
94
        raise NotImplementedError(self.call_with_body_bytes)
 
95
 
 
96
    def call_with_body_readv_array(self, args, body):
 
97
        """Make a remote call with a readv array.
 
98
 
 
99
        :param args: the arguments of this call.
 
100
        :type body: iterable of (start, length) tuples.
 
101
        :param body: the readv ranges to send with this request.
 
102
        """
 
103
        raise NotImplementedError(self.call_with_body_readv_array)
 
104
 
 
105
    def set_headers(self, headers):
 
106
        raise NotImplementedError(self.set_headers)
53
107
 
54
108
 
55
109
class SmartProtocolBase(object):
67
121
        for start, length in offsets:
68
122
            txt.append('%d,%d' % (start, length))
69
123
        return '\n'.join(txt)
70
 
        
 
124
 
71
125
 
72
126
class SmartServerRequestProtocolOne(SmartProtocolBase):
73
127
    """Server-side encoding and decoding logic for smart version 1."""
74
 
    
75
 
    def __init__(self, backing_transport, write_func):
 
128
 
 
129
    def __init__(self, backing_transport, write_func, root_client_path='/',
 
130
            jail_root=None):
76
131
        self._backing_transport = backing_transport
77
 
        self.excess_buffer = ''
 
132
        self._root_client_path = root_client_path
 
133
        self._jail_root = jail_root
 
134
        self.unused_data = ''
78
135
        self._finished = False
79
136
        self.in_buffer = ''
80
 
        self.has_dispatched = False
 
137
        self._has_dispatched = False
81
138
        self.request = None
82
139
        self._body_decoder = None
83
140
        self._write_func = write_func
84
141
 
85
142
    def accept_bytes(self, bytes):
86
143
        """Take bytes, and advance the internal state machine appropriately.
87
 
        
 
144
 
88
145
        :param bytes: must be a byte string
89
146
        """
90
 
        assert isinstance(bytes, str)
 
147
        if not isinstance(bytes, str):
 
148
            raise ValueError(bytes)
91
149
        self.in_buffer += bytes
92
 
        if not self.has_dispatched:
 
150
        if not self._has_dispatched:
93
151
            if '\n' not in self.in_buffer:
94
152
                # no command line yet
95
153
                return
96
 
            self.has_dispatched = True
 
154
            self._has_dispatched = True
97
155
            try:
98
156
                first_line, self.in_buffer = self.in_buffer.split('\n', 1)
99
157
                first_line += '\n'
100
158
                req_args = _decode_tuple(first_line)
101
159
                self.request = request.SmartServerRequestHandler(
102
 
                    self._backing_transport, commands=request.request_handlers)
103
 
                self.request.dispatch_command(req_args[0], req_args[1:])
 
160
                    self._backing_transport, commands=request.request_handlers,
 
161
                    root_client_path=self._root_client_path,
 
162
                    jail_root=self._jail_root)
 
163
                self.request.args_received(req_args)
104
164
                if self.request.finished_reading:
105
165
                    # trivial request
106
 
                    self.excess_buffer = self.in_buffer
 
166
                    self.unused_data = self.in_buffer
107
167
                    self.in_buffer = ''
108
168
                    self._send_response(self.request.response)
109
169
            except KeyboardInterrupt:
110
170
                raise
 
171
            except errors.UnknownSmartMethod, err:
 
172
                protocol_error = errors.SmartProtocolError(
 
173
                    "bad request %r" % (err.verb,))
 
174
                failure = request.FailedSmartServerResponse(
 
175
                    ('error', str(protocol_error)))
 
176
                self._send_response(failure)
 
177
                return
111
178
            except Exception, exception:
112
179
                # everything else: pass to client, flush, and quit
113
180
                log_exception_quietly()
115
182
                    ('error', str(exception))))
116
183
                return
117
184
 
118
 
        if self.has_dispatched:
 
185
        if self._has_dispatched:
119
186
            if self._finished:
120
 
                # nothing to do.XXX: this routine should be a single state 
 
187
                # nothing to do.XXX: this routine should be a single state
121
188
                # machine too.
122
 
                self.excess_buffer += self.in_buffer
 
189
                self.unused_data += self.in_buffer
123
190
                self.in_buffer = ''
124
191
                return
125
192
            if self._body_decoder is None:
130
197
            self.request.accept_body(body_data)
131
198
            if self._body_decoder.finished_reading:
132
199
                self.request.end_of_body()
133
 
                assert self.request.finished_reading, \
134
 
                    "no more body, request not finished"
 
200
                if not self.request.finished_reading:
 
201
                    raise AssertionError("no more body, request not finished")
135
202
            if self.request.response is not None:
136
203
                self._send_response(self.request.response)
137
 
                self.excess_buffer = self.in_buffer
 
204
                self.unused_data = self.in_buffer
138
205
                self.in_buffer = ''
139
206
            else:
140
 
                assert not self.request.finished_reading, \
141
 
                    "no response and we have finished reading."
 
207
                if self.request.finished_reading:
 
208
                    raise AssertionError(
 
209
                        "no response and we have finished reading.")
142
210
 
143
211
    def _send_response(self, response):
144
212
        """Send a smart server response down the output stream."""
145
 
        assert not self._finished, 'response already sent'
 
213
        if self._finished:
 
214
            raise AssertionError('response already sent')
146
215
        args = response.args
147
216
        body = response.body
148
217
        self._finished = True
150
219
        self._write_success_or_failure_prefix(response)
151
220
        self._write_func(_encode_tuple(args))
152
221
        if body is not None:
153
 
            assert isinstance(body, str), 'body must be a str'
 
222
            if not isinstance(body, str):
 
223
                raise ValueError(body)
154
224
            bytes = self._encode_bulk_data(body)
155
225
            self._write_func(bytes)
156
226
 
157
227
    def _write_protocol_version(self):
158
228
        """Write any prefixes this protocol requires.
159
 
        
 
229
 
160
230
        Version one doesn't send protocol versions.
161
231
        """
162
232
 
179
249
 
180
250
class SmartServerRequestProtocolTwo(SmartServerRequestProtocolOne):
181
251
    r"""Version two of the server side of the smart protocol.
182
 
   
 
252
 
183
253
    This prefixes responses with the value of RESPONSE_VERSION_TWO.
184
254
    """
185
255
 
 
256
    response_marker = RESPONSE_VERSION_TWO
 
257
    request_marker = REQUEST_VERSION_TWO
 
258
 
186
259
    def _write_success_or_failure_prefix(self, response):
187
260
        """Write the protocol specific success/failure prefix."""
188
261
        if response.is_successful():
192
265
 
193
266
    def _write_protocol_version(self):
194
267
        r"""Write any prefixes this protocol requires.
195
 
        
 
268
 
196
269
        Version two sends the value of RESPONSE_VERSION_TWO.
197
270
        """
198
 
        self._write_func(RESPONSE_VERSION_TWO)
199
 
 
200
 
 
201
 
class LengthPrefixedBodyDecoder(object):
202
 
    """Decodes the length-prefixed bulk data."""
203
 
    
 
271
        self._write_func(self.response_marker)
 
272
 
 
273
    def _send_response(self, response):
 
274
        """Send a smart server response down the output stream."""
 
275
        if (self._finished):
 
276
            raise AssertionError('response already sent')
 
277
        self._finished = True
 
278
        self._write_protocol_version()
 
279
        self._write_success_or_failure_prefix(response)
 
280
        self._write_func(_encode_tuple(response.args))
 
281
        if response.body is not None:
 
282
            if not isinstance(response.body, str):
 
283
                raise AssertionError('body must be a str')
 
284
            if not (response.body_stream is None):
 
285
                raise AssertionError(
 
286
                    'body_stream and body cannot both be set')
 
287
            bytes = self._encode_bulk_data(response.body)
 
288
            self._write_func(bytes)
 
289
        elif response.body_stream is not None:
 
290
            _send_stream(response.body_stream, self._write_func)
 
291
 
 
292
 
 
293
def _send_stream(stream, write_func):
 
294
    write_func('chunked\n')
 
295
    _send_chunks(stream, write_func)
 
296
    write_func('END\n')
 
297
 
 
298
 
 
299
def _send_chunks(stream, write_func):
 
300
    for chunk in stream:
 
301
        if isinstance(chunk, str):
 
302
            bytes = "%x\n%s" % (len(chunk), chunk)
 
303
            write_func(bytes)
 
304
        elif isinstance(chunk, request.FailedSmartServerResponse):
 
305
            write_func('ERR\n')
 
306
            _send_chunks(chunk.args, write_func)
 
307
            return
 
308
        else:
 
309
            raise errors.BzrError(
 
310
                'Chunks must be str or FailedSmartServerResponse, got %r'
 
311
                % chunk)
 
312
 
 
313
 
 
314
class _NeedMoreBytes(Exception):
 
315
    """Raise this inside a _StatefulDecoder to stop decoding until more bytes
 
316
    have been received.
 
317
    """
 
318
 
 
319
    def __init__(self, count=None):
 
320
        """Constructor.
 
321
 
 
322
        :param count: the total number of bytes needed by the current state.
 
323
            May be None if the number of bytes needed is unknown.
 
324
        """
 
325
        self.count = count
 
326
 
 
327
 
 
328
class _StatefulDecoder(object):
 
329
    """Base class for writing state machines to decode byte streams.
 
330
 
 
331
    Subclasses should provide a self.state_accept attribute that accepts bytes
 
332
    and, if appropriate, updates self.state_accept to a different function.
 
333
    accept_bytes will call state_accept as often as necessary to make sure the
 
334
    state machine has progressed as far as possible before it returns.
 
335
 
 
336
    See ProtocolThreeDecoder for an example subclass.
 
337
    """
 
338
 
204
339
    def __init__(self):
 
340
        self.finished_reading = False
 
341
        self._in_buffer_list = []
 
342
        self._in_buffer_len = 0
 
343
        self.unused_data = ''
205
344
        self.bytes_left = None
206
 
        self.finished_reading = False
207
 
        self.unused_data = ''
208
 
        self.state_accept = self._state_accept_expecting_length
209
 
        self.state_read = self._state_read_no_data
210
 
        self._in_buffer = ''
211
 
        self._trailer_buffer = ''
212
 
    
 
345
        self._number_needed_bytes = None
 
346
 
 
347
    def _get_in_buffer(self):
 
348
        if len(self._in_buffer_list) == 1:
 
349
            return self._in_buffer_list[0]
 
350
        in_buffer = ''.join(self._in_buffer_list)
 
351
        if len(in_buffer) != self._in_buffer_len:
 
352
            raise AssertionError(
 
353
                "Length of buffer did not match expected value: %s != %s"
 
354
                % self._in_buffer_len, len(in_buffer))
 
355
        self._in_buffer_list = [in_buffer]
 
356
        return in_buffer
 
357
 
 
358
    def _get_in_bytes(self, count):
 
359
        """Grab X bytes from the input_buffer.
 
360
 
 
361
        Callers should have already checked that self._in_buffer_len is >
 
362
        count. Note, this does not consume the bytes from the buffer. The
 
363
        caller will still need to call _get_in_buffer() and then
 
364
        _set_in_buffer() if they actually need to consume the bytes.
 
365
        """
 
366
        # check if we can yield the bytes from just the first entry in our list
 
367
        if len(self._in_buffer_list) == 0:
 
368
            raise AssertionError('Callers must be sure we have buffered bytes'
 
369
                ' before calling _get_in_bytes')
 
370
        if len(self._in_buffer_list[0]) > count:
 
371
            return self._in_buffer_list[0][:count]
 
372
        # We can't yield it from the first buffer, so collapse all buffers, and
 
373
        # yield it from that
 
374
        in_buf = self._get_in_buffer()
 
375
        return in_buf[:count]
 
376
 
 
377
    def _set_in_buffer(self, new_buf):
 
378
        if new_buf is not None:
 
379
            self._in_buffer_list = [new_buf]
 
380
            self._in_buffer_len = len(new_buf)
 
381
        else:
 
382
            self._in_buffer_list = []
 
383
            self._in_buffer_len = 0
 
384
 
213
385
    def accept_bytes(self, bytes):
214
386
        """Decode as much of bytes as possible.
215
387
 
220
392
        data will be appended to self.unused_data.
221
393
        """
222
394
        # accept_bytes is allowed to change the state
223
 
        current_state = self.state_accept
224
 
        self.state_accept(bytes)
225
 
        while current_state != self.state_accept:
 
395
        self._number_needed_bytes = None
 
396
        # lsprof puts a very large amount of time on this specific call for
 
397
        # large readv arrays
 
398
        self._in_buffer_list.append(bytes)
 
399
        self._in_buffer_len += len(bytes)
 
400
        try:
 
401
            # Run the function for the current state.
226
402
            current_state = self.state_accept
227
 
            self.state_accept('')
 
403
            self.state_accept()
 
404
            while current_state != self.state_accept:
 
405
                # The current state has changed.  Run the function for the new
 
406
                # current state, so that it can:
 
407
                #   - decode any unconsumed bytes left in a buffer, and
 
408
                #   - signal how many more bytes are expected (via raising
 
409
                #     _NeedMoreBytes).
 
410
                current_state = self.state_accept
 
411
                self.state_accept()
 
412
        except _NeedMoreBytes, e:
 
413
            self._number_needed_bytes = e.count
 
414
 
 
415
 
 
416
class ChunkedBodyDecoder(_StatefulDecoder):
 
417
    """Decoder for chunked body data.
 
418
 
 
419
    This is very similar the HTTP's chunked encoding.  See the description of
 
420
    streamed body data in `doc/developers/network-protocol.txt` for details.
 
421
    """
 
422
 
 
423
    def __init__(self):
 
424
        _StatefulDecoder.__init__(self)
 
425
        self.state_accept = self._state_accept_expecting_header
 
426
        self.chunk_in_progress = None
 
427
        self.chunks = collections.deque()
 
428
        self.error = False
 
429
        self.error_in_progress = None
 
430
 
 
431
    def next_read_size(self):
 
432
        # Note: the shortest possible chunk is 2 bytes: '0\n', and the
 
433
        # end-of-body marker is 4 bytes: 'END\n'.
 
434
        if self.state_accept == self._state_accept_reading_chunk:
 
435
            # We're expecting more chunk content.  So we're expecting at least
 
436
            # the rest of this chunk plus an END chunk.
 
437
            return self.bytes_left + 4
 
438
        elif self.state_accept == self._state_accept_expecting_length:
 
439
            if self._in_buffer_len == 0:
 
440
                # We're expecting a chunk length.  There's at least two bytes
 
441
                # left: a digit plus '\n'.
 
442
                return 2
 
443
            else:
 
444
                # We're in the middle of reading a chunk length.  So there's at
 
445
                # least one byte left, the '\n' that terminates the length.
 
446
                return 1
 
447
        elif self.state_accept == self._state_accept_reading_unused:
 
448
            return 1
 
449
        elif self.state_accept == self._state_accept_expecting_header:
 
450
            return max(0, len('chunked\n') - self._in_buffer_len)
 
451
        else:
 
452
            raise AssertionError("Impossible state: %r" % (self.state_accept,))
 
453
 
 
454
    def read_next_chunk(self):
 
455
        try:
 
456
            return self.chunks.popleft()
 
457
        except IndexError:
 
458
            return None
 
459
 
 
460
    def _extract_line(self):
 
461
        in_buf = self._get_in_buffer()
 
462
        pos = in_buf.find('\n')
 
463
        if pos == -1:
 
464
            # We haven't read a complete line yet, so request more bytes before
 
465
            # we continue.
 
466
            raise _NeedMoreBytes(1)
 
467
        line = in_buf[:pos]
 
468
        # Trim the prefix (including '\n' delimiter) from the _in_buffer.
 
469
        self._set_in_buffer(in_buf[pos+1:])
 
470
        return line
 
471
 
 
472
    def _finished(self):
 
473
        self.unused_data = self._get_in_buffer()
 
474
        self._in_buffer_list = []
 
475
        self._in_buffer_len = 0
 
476
        self.state_accept = self._state_accept_reading_unused
 
477
        if self.error:
 
478
            error_args = tuple(self.error_in_progress)
 
479
            self.chunks.append(request.FailedSmartServerResponse(error_args))
 
480
            self.error_in_progress = None
 
481
        self.finished_reading = True
 
482
 
 
483
    def _state_accept_expecting_header(self):
 
484
        prefix = self._extract_line()
 
485
        if prefix == 'chunked':
 
486
            self.state_accept = self._state_accept_expecting_length
 
487
        else:
 
488
            raise errors.SmartProtocolError(
 
489
                'Bad chunked body header: "%s"' % (prefix,))
 
490
 
 
491
    def _state_accept_expecting_length(self):
 
492
        prefix = self._extract_line()
 
493
        if prefix == 'ERR':
 
494
            self.error = True
 
495
            self.error_in_progress = []
 
496
            self._state_accept_expecting_length()
 
497
            return
 
498
        elif prefix == 'END':
 
499
            # We've read the end-of-body marker.
 
500
            # Any further bytes are unused data, including the bytes left in
 
501
            # the _in_buffer.
 
502
            self._finished()
 
503
            return
 
504
        else:
 
505
            self.bytes_left = int(prefix, 16)
 
506
            self.chunk_in_progress = ''
 
507
            self.state_accept = self._state_accept_reading_chunk
 
508
 
 
509
    def _state_accept_reading_chunk(self):
 
510
        in_buf = self._get_in_buffer()
 
511
        in_buffer_len = len(in_buf)
 
512
        self.chunk_in_progress += in_buf[:self.bytes_left]
 
513
        self._set_in_buffer(in_buf[self.bytes_left:])
 
514
        self.bytes_left -= in_buffer_len
 
515
        if self.bytes_left <= 0:
 
516
            # Finished with chunk
 
517
            self.bytes_left = None
 
518
            if self.error:
 
519
                self.error_in_progress.append(self.chunk_in_progress)
 
520
            else:
 
521
                self.chunks.append(self.chunk_in_progress)
 
522
            self.chunk_in_progress = None
 
523
            self.state_accept = self._state_accept_expecting_length
 
524
 
 
525
    def _state_accept_reading_unused(self):
 
526
        self.unused_data += self._get_in_buffer()
 
527
        self._in_buffer_list = []
 
528
 
 
529
 
 
530
class LengthPrefixedBodyDecoder(_StatefulDecoder):
 
531
    """Decodes the length-prefixed bulk data."""
 
532
 
 
533
    def __init__(self):
 
534
        _StatefulDecoder.__init__(self)
 
535
        self.state_accept = self._state_accept_expecting_length
 
536
        self.state_read = self._state_read_no_data
 
537
        self._body = ''
 
538
        self._trailer_buffer = ''
228
539
 
229
540
    def next_read_size(self):
230
541
        if self.bytes_left is not None:
241
552
        else:
242
553
            # Reading excess data.  Either way, 1 byte at a time is fine.
243
554
            return 1
244
 
        
 
555
 
245
556
    def read_pending_data(self):
246
557
        """Return any pending data that has been decoded."""
247
558
        return self.state_read()
248
559
 
249
 
    def _state_accept_expecting_length(self, bytes):
250
 
        self._in_buffer += bytes
251
 
        pos = self._in_buffer.find('\n')
 
560
    def _state_accept_expecting_length(self):
 
561
        in_buf = self._get_in_buffer()
 
562
        pos = in_buf.find('\n')
252
563
        if pos == -1:
253
564
            return
254
 
        self.bytes_left = int(self._in_buffer[:pos])
255
 
        self._in_buffer = self._in_buffer[pos+1:]
256
 
        self.bytes_left -= len(self._in_buffer)
 
565
        self.bytes_left = int(in_buf[:pos])
 
566
        self._set_in_buffer(in_buf[pos+1:])
257
567
        self.state_accept = self._state_accept_reading_body
258
 
        self.state_read = self._state_read_in_buffer
 
568
        self.state_read = self._state_read_body_buffer
259
569
 
260
 
    def _state_accept_reading_body(self, bytes):
261
 
        self._in_buffer += bytes
262
 
        self.bytes_left -= len(bytes)
 
570
    def _state_accept_reading_body(self):
 
571
        in_buf = self._get_in_buffer()
 
572
        self._body += in_buf
 
573
        self.bytes_left -= len(in_buf)
 
574
        self._set_in_buffer(None)
263
575
        if self.bytes_left <= 0:
264
576
            # Finished with body
265
577
            if self.bytes_left != 0:
266
 
                self._trailer_buffer = self._in_buffer[self.bytes_left:]
267
 
                self._in_buffer = self._in_buffer[:self.bytes_left]
 
578
                self._trailer_buffer = self._body[self.bytes_left:]
 
579
                self._body = self._body[:self.bytes_left]
268
580
            self.bytes_left = None
269
581
            self.state_accept = self._state_accept_reading_trailer
270
 
        
271
 
    def _state_accept_reading_trailer(self, bytes):
272
 
        self._trailer_buffer += bytes
 
582
 
 
583
    def _state_accept_reading_trailer(self):
 
584
        self._trailer_buffer += self._get_in_buffer()
 
585
        self._set_in_buffer(None)
273
586
        # TODO: what if the trailer does not match "done\n"?  Should this raise
274
587
        # a ProtocolViolation exception?
275
588
        if self._trailer_buffer.startswith('done\n'):
276
589
            self.unused_data = self._trailer_buffer[len('done\n'):]
277
590
            self.state_accept = self._state_accept_reading_unused
278
591
            self.finished_reading = True
279
 
    
280
 
    def _state_accept_reading_unused(self, bytes):
281
 
        self.unused_data += bytes
 
592
 
 
593
    def _state_accept_reading_unused(self):
 
594
        self.unused_data += self._get_in_buffer()
 
595
        self._set_in_buffer(None)
282
596
 
283
597
    def _state_read_no_data(self):
284
598
        return ''
285
599
 
286
 
    def _state_read_in_buffer(self):
287
 
        result = self._in_buffer
288
 
        self._in_buffer = ''
 
600
    def _state_read_body_buffer(self):
 
601
        result = self._body
 
602
        self._body = ''
289
603
        return result
290
604
 
291
605
 
292
 
class SmartClientRequestProtocolOne(SmartProtocolBase):
 
606
class SmartClientRequestProtocolOne(SmartProtocolBase, Requester,
 
607
                                    message.ResponseHandler):
293
608
    """The client-side protocol for smart version 1."""
294
609
 
295
610
    def __init__(self, request):
301
616
        self._request = request
302
617
        self._body_buffer = None
303
618
        self._request_start_time = None
 
619
        self._last_verb = None
 
620
        self._headers = None
 
621
 
 
622
    def set_headers(self, headers):
 
623
        self._headers = dict(headers)
304
624
 
305
625
    def call(self, *args):
306
626
        if 'hpss' in debug.debug_flags:
307
627
            mutter('hpss call:   %s', repr(args)[1:-1])
308
 
            self._request_start_time = time.time()
 
628
            if getattr(self._request._medium, 'base', None) is not None:
 
629
                mutter('             (to %s)', self._request._medium.base)
 
630
            self._request_start_time = osutils.timer_func()
309
631
        self._write_args(args)
310
632
        self._request.finished_writing()
 
633
        self._last_verb = args[0]
311
634
 
312
635
    def call_with_body_bytes(self, args, body):
313
636
        """Make a remote call of args with body bytes 'body'.
316
639
        """
317
640
        if 'hpss' in debug.debug_flags:
318
641
            mutter('hpss call w/body: %s (%r...)', repr(args)[1:-1], body[:20])
 
642
            if getattr(self._request._medium, '_path', None) is not None:
 
643
                mutter('                  (to %s)', self._request._medium._path)
319
644
            mutter('              %d bytes', len(body))
320
 
            self._request_start_time = time.time()
 
645
            self._request_start_time = osutils.timer_func()
 
646
            if 'hpssdetail' in debug.debug_flags:
 
647
                mutter('hpss body content: %s', body)
321
648
        self._write_args(args)
322
649
        bytes = self._encode_bulk_data(body)
323
650
        self._request.accept_bytes(bytes)
324
651
        self._request.finished_writing()
 
652
        self._last_verb = args[0]
325
653
 
326
654
    def call_with_body_readv_array(self, args, body):
327
655
        """Make a remote call with a readv array.
328
656
 
329
657
        The body is encoded with one line per readv offset pair. The numbers in
330
 
        each pair are separated by a comma, and no trailing \n is emitted.
 
658
        each pair are separated by a comma, and no trailing \\n is emitted.
331
659
        """
332
660
        if 'hpss' in debug.debug_flags:
333
661
            mutter('hpss call w/readv: %s', repr(args)[1:-1])
334
 
            self._request_start_time = time.time()
 
662
            if getattr(self._request._medium, '_path', None) is not None:
 
663
                mutter('                  (to %s)', self._request._medium._path)
 
664
            self._request_start_time = osutils.timer_func()
335
665
        self._write_args(args)
336
666
        readv_bytes = self._serialise_offsets(body)
337
667
        bytes = self._encode_bulk_data(readv_bytes)
339
669
        self._request.finished_writing()
340
670
        if 'hpss' in debug.debug_flags:
341
671
            mutter('              %d bytes in readv request', len(readv_bytes))
 
672
        self._last_verb = args[0]
 
673
 
 
674
    def call_with_body_stream(self, args, stream):
 
675
        # Protocols v1 and v2 don't support body streams.  So it's safe to
 
676
        # assume that a v1/v2 server doesn't support whatever method we're
 
677
        # trying to call with a body stream.
 
678
        self._request.finished_writing()
 
679
        self._request.finished_reading()
 
680
        raise errors.UnknownSmartMethod(args[0])
342
681
 
343
682
    def cancel_read_body(self):
344
683
        """After expecting a body, a response code may indicate one otherwise.
349
688
        """
350
689
        self._request.finished_reading()
351
690
 
352
 
    def read_response_tuple(self, expect_body=False):
353
 
        """Read a response tuple from the wire.
354
 
 
355
 
        This should only be called once.
356
 
        """
 
691
    def _read_response_tuple(self):
357
692
        result = self._recv_tuple()
358
693
        if 'hpss' in debug.debug_flags:
359
694
            if self._request_start_time is not None:
360
695
                mutter('   result:   %6.3fs  %s',
361
 
                       time.time() - self._request_start_time,
 
696
                       osutils.timer_func() - self._request_start_time,
362
697
                       repr(result)[1:-1])
363
698
                self._request_start_time = None
364
699
            else:
365
700
                mutter('   result:   %s', repr(result)[1:-1])
 
701
        return result
 
702
 
 
703
    def read_response_tuple(self, expect_body=False):
 
704
        """Read a response tuple from the wire.
 
705
 
 
706
        This should only be called once.
 
707
        """
 
708
        result = self._read_response_tuple()
 
709
        self._response_is_unknown_method(result)
 
710
        self._raise_args_if_error(result)
366
711
        if not expect_body:
367
712
            self._request.finished_reading()
368
713
        return result
369
714
 
 
715
    def _raise_args_if_error(self, result_tuple):
 
716
        # Later protocol versions have an explicit flag in the protocol to say
 
717
        # if an error response is "failed" or not.  In version 1 we don't have
 
718
        # that luxury.  So here is a complete list of errors that can be
 
719
        # returned in response to existing version 1 smart requests.  Responses
 
720
        # starting with these codes are always "failed" responses.
 
721
        v1_error_codes = [
 
722
            'norepository',
 
723
            'NoSuchFile',
 
724
            'FileExists',
 
725
            'DirectoryNotEmpty',
 
726
            'ShortReadvError',
 
727
            'UnicodeEncodeError',
 
728
            'UnicodeDecodeError',
 
729
            'ReadOnlyError',
 
730
            'nobranch',
 
731
            'NoSuchRevision',
 
732
            'nosuchrevision',
 
733
            'LockContention',
 
734
            'UnlockableTransport',
 
735
            'LockFailed',
 
736
            'TokenMismatch',
 
737
            'ReadError',
 
738
            'PermissionDenied',
 
739
            ]
 
740
        if result_tuple[0] in v1_error_codes:
 
741
            self._request.finished_reading()
 
742
            raise errors.ErrorFromSmartServer(result_tuple)
 
743
 
 
744
    def _response_is_unknown_method(self, result_tuple):
 
745
        """Raise UnexpectedSmartServerResponse if the response is an 'unknonwn
 
746
        method' response to the request.
 
747
 
 
748
        :param response: The response from a smart client call_expecting_body
 
749
            call.
 
750
        :param verb: The verb used in that call.
 
751
        :raises: UnexpectedSmartServerResponse
 
752
        """
 
753
        if (result_tuple == ('error', "Generic bzr smart protocol error: "
 
754
                "bad request '%s'" % self._last_verb) or
 
755
              result_tuple == ('error', "Generic bzr smart protocol error: "
 
756
                "bad request u'%s'" % self._last_verb)):
 
757
            # The response will have no body, so we've finished reading.
 
758
            self._request.finished_reading()
 
759
            raise errors.UnknownSmartMethod(self._last_verb)
 
760
 
370
761
    def read_body_bytes(self, count=-1):
371
762
        """Read bytes from the body, decoding into a byte stream.
372
 
        
373
 
        We read all bytes at once to ensure we've checked the trailer for 
 
763
 
 
764
        We read all bytes at once to ensure we've checked the trailer for
374
765
        errors, and then feed the buffer back as read_body_bytes is called.
375
766
        """
376
767
        if self._body_buffer is not None:
378
769
        _body_decoder = LengthPrefixedBodyDecoder()
379
770
 
380
771
        while not _body_decoder.finished_reading:
381
 
            bytes_wanted = _body_decoder.next_read_size()
382
 
            bytes = self._request.read_bytes(bytes_wanted)
 
772
            bytes = self._request.read_bytes(_body_decoder.next_read_size())
 
773
            if bytes == '':
 
774
                # end of file encountered reading from server
 
775
                raise errors.ConnectionReset(
 
776
                    "Connection lost while reading response body.")
383
777
            _body_decoder.accept_bytes(bytes)
384
778
        self._request.finished_reading()
385
779
        self._body_buffer = StringIO(_body_decoder.read_pending_data())
391
785
 
392
786
    def _recv_tuple(self):
393
787
        """Receive a tuple from the medium request."""
394
 
        return _decode_tuple(self._recv_line())
395
 
 
396
 
    def _recv_line(self):
397
 
        """Read an entire line from the medium request."""
398
 
        line = ''
399
 
        while not line or line[-1] != '\n':
400
 
            # TODO: this is inefficient - but tuples are short.
401
 
            new_char = self._request.read_bytes(1)
402
 
            if new_char == '':
403
 
                # end of file encountered reading from server
404
 
                raise errors.ConnectionReset(
405
 
                    "please check connectivity and permissions",
406
 
                    "(and try -Dhpss if further diagnosis is required)")
407
 
            line += new_char
408
 
        return line
 
788
        return _decode_tuple(self._request.read_line())
409
789
 
410
790
    def query_version(self):
411
791
        """Return protocol version number of the server."""
425
805
 
426
806
    def _write_protocol_version(self):
427
807
        """Write any prefixes this protocol requires.
428
 
        
 
808
 
429
809
        Version one doesn't send protocol versions.
430
810
        """
431
811
 
432
812
 
433
813
class SmartClientRequestProtocolTwo(SmartClientRequestProtocolOne):
434
814
    """Version two of the client side of the smart protocol.
435
 
    
 
815
 
436
816
    This prefixes the request with the value of REQUEST_VERSION_TWO.
437
817
    """
438
818
 
 
819
    response_marker = RESPONSE_VERSION_TWO
 
820
    request_marker = REQUEST_VERSION_TWO
 
821
 
439
822
    def read_response_tuple(self, expect_body=False):
440
823
        """Read a response tuple from the wire.
441
824
 
442
825
        This should only be called once.
443
826
        """
444
827
        version = self._request.read_line()
445
 
        if version != RESPONSE_VERSION_TWO:
446
 
            raise errors.SmartProtocolError('bad protocol marker %r' % version)
447
 
        response_status = self._recv_line()
448
 
        if response_status not in ('success\n', 'failed\n'):
 
828
        if version != self.response_marker:
 
829
            self._request.finished_reading()
 
830
            raise errors.UnexpectedProtocolVersionMarker(version)
 
831
        response_status = self._request.read_line()
 
832
        result = SmartClientRequestProtocolOne._read_response_tuple(self)
 
833
        self._response_is_unknown_method(result)
 
834
        if response_status == 'success\n':
 
835
            self.response_status = True
 
836
            if not expect_body:
 
837
                self._request.finished_reading()
 
838
            return result
 
839
        elif response_status == 'failed\n':
 
840
            self.response_status = False
 
841
            self._request.finished_reading()
 
842
            raise errors.ErrorFromSmartServer(result)
 
843
        else:
449
844
            raise errors.SmartProtocolError(
450
845
                'bad protocol status %r' % response_status)
451
 
        self.response_status = response_status == 'success\n'
452
 
        return SmartClientRequestProtocolOne.read_response_tuple(self, expect_body)
453
846
 
454
847
    def _write_protocol_version(self):
455
 
        r"""Write any prefixes this protocol requires.
456
 
        
 
848
        """Write any prefixes this protocol requires.
 
849
 
457
850
        Version two sends the value of REQUEST_VERSION_TWO.
458
851
        """
459
 
        self._request.accept_bytes(REQUEST_VERSION_TWO)
 
852
        self._request.accept_bytes(self.request_marker)
 
853
 
 
854
    def read_streamed_body(self):
 
855
        """Read bytes from the body, decoding into a byte stream.
 
856
        """
 
857
        # Read no more than 64k at a time so that we don't risk error 10055 (no
 
858
        # buffer space available) on Windows.
 
859
        _body_decoder = ChunkedBodyDecoder()
 
860
        while not _body_decoder.finished_reading:
 
861
            bytes = self._request.read_bytes(_body_decoder.next_read_size())
 
862
            if bytes == '':
 
863
                # end of file encountered reading from server
 
864
                raise errors.ConnectionReset(
 
865
                    "Connection lost while reading streamed body.")
 
866
            _body_decoder.accept_bytes(bytes)
 
867
            for body_bytes in iter(_body_decoder.read_next_chunk, None):
 
868
                if 'hpss' in debug.debug_flags and type(body_bytes) is str:
 
869
                    mutter('              %d byte chunk read',
 
870
                           len(body_bytes))
 
871
                yield body_bytes
 
872
        self._request.finished_reading()
 
873
 
 
874
 
 
875
def build_server_protocol_three(backing_transport, write_func,
 
876
                                root_client_path, jail_root=None):
 
877
    request_handler = request.SmartServerRequestHandler(
 
878
        backing_transport, commands=request.request_handlers,
 
879
        root_client_path=root_client_path, jail_root=jail_root)
 
880
    responder = ProtocolThreeResponder(write_func)
 
881
    message_handler = message.ConventionalRequestHandler(request_handler, responder)
 
882
    return ProtocolThreeDecoder(message_handler)
 
883
 
 
884
 
 
885
class ProtocolThreeDecoder(_StatefulDecoder):
 
886
 
 
887
    response_marker = RESPONSE_VERSION_THREE
 
888
    request_marker = REQUEST_VERSION_THREE
 
889
 
 
890
    def __init__(self, message_handler, expect_version_marker=False):
 
891
        _StatefulDecoder.__init__(self)
 
892
        self._has_dispatched = False
 
893
        # Initial state
 
894
        if expect_version_marker:
 
895
            self.state_accept = self._state_accept_expecting_protocol_version
 
896
            # We're expecting at least the protocol version marker + some
 
897
            # headers.
 
898
            self._number_needed_bytes = len(MESSAGE_VERSION_THREE) + 4
 
899
        else:
 
900
            self.state_accept = self._state_accept_expecting_headers
 
901
            self._number_needed_bytes = 4
 
902
        self.decoding_failed = False
 
903
        self.request_handler = self.message_handler = message_handler
 
904
 
 
905
    def accept_bytes(self, bytes):
 
906
        self._number_needed_bytes = None
 
907
        try:
 
908
            _StatefulDecoder.accept_bytes(self, bytes)
 
909
        except KeyboardInterrupt:
 
910
            raise
 
911
        except errors.SmartMessageHandlerError, exception:
 
912
            # We do *not* set self.decoding_failed here.  The message handler
 
913
            # has raised an error, but the decoder is still able to parse bytes
 
914
            # and determine when this message ends.
 
915
            if not isinstance(exception.exc_value, errors.UnknownSmartMethod):
 
916
                log_exception_quietly()
 
917
            self.message_handler.protocol_error(exception.exc_value)
 
918
            # The state machine is ready to continue decoding, but the
 
919
            # exception has interrupted the loop that runs the state machine.
 
920
            # So we call accept_bytes again to restart it.
 
921
            self.accept_bytes('')
 
922
        except Exception, exception:
 
923
            # The decoder itself has raised an exception.  We cannot continue
 
924
            # decoding.
 
925
            self.decoding_failed = True
 
926
            if isinstance(exception, errors.UnexpectedProtocolVersionMarker):
 
927
                # This happens during normal operation when the client tries a
 
928
                # protocol version the server doesn't understand, so no need to
 
929
                # log a traceback every time.
 
930
                # Note that this can only happen when
 
931
                # expect_version_marker=True, which is only the case on the
 
932
                # client side.
 
933
                pass
 
934
            else:
 
935
                log_exception_quietly()
 
936
            self.message_handler.protocol_error(exception)
 
937
 
 
938
    def _extract_length_prefixed_bytes(self):
 
939
        if self._in_buffer_len < 4:
 
940
            # A length prefix by itself is 4 bytes, and we don't even have that
 
941
            # many yet.
 
942
            raise _NeedMoreBytes(4)
 
943
        (length,) = struct.unpack('!L', self._get_in_bytes(4))
 
944
        end_of_bytes = 4 + length
 
945
        if self._in_buffer_len < end_of_bytes:
 
946
            # We haven't yet read as many bytes as the length-prefix says there
 
947
            # are.
 
948
            raise _NeedMoreBytes(end_of_bytes)
 
949
        # Extract the bytes from the buffer.
 
950
        in_buf = self._get_in_buffer()
 
951
        bytes = in_buf[4:end_of_bytes]
 
952
        self._set_in_buffer(in_buf[end_of_bytes:])
 
953
        return bytes
 
954
 
 
955
    def _extract_prefixed_bencoded_data(self):
 
956
        prefixed_bytes = self._extract_length_prefixed_bytes()
 
957
        try:
 
958
            decoded = bdecode_as_tuple(prefixed_bytes)
 
959
        except ValueError:
 
960
            raise errors.SmartProtocolError(
 
961
                'Bytes %r not bencoded' % (prefixed_bytes,))
 
962
        return decoded
 
963
 
 
964
    def _extract_single_byte(self):
 
965
        if self._in_buffer_len == 0:
 
966
            # The buffer is empty
 
967
            raise _NeedMoreBytes(1)
 
968
        in_buf = self._get_in_buffer()
 
969
        one_byte = in_buf[0]
 
970
        self._set_in_buffer(in_buf[1:])
 
971
        return one_byte
 
972
 
 
973
    def _state_accept_expecting_protocol_version(self):
 
974
        needed_bytes = len(MESSAGE_VERSION_THREE) - self._in_buffer_len
 
975
        in_buf = self._get_in_buffer()
 
976
        if needed_bytes > 0:
 
977
            # We don't have enough bytes to check if the protocol version
 
978
            # marker is right.  But we can check if it is already wrong by
 
979
            # checking that the start of MESSAGE_VERSION_THREE matches what
 
980
            # we've read so far.
 
981
            # [In fact, if the remote end isn't bzr we might never receive
 
982
            # len(MESSAGE_VERSION_THREE) bytes.  So if the bytes we have so far
 
983
            # are wrong then we should just raise immediately rather than
 
984
            # stall.]
 
985
            if not MESSAGE_VERSION_THREE.startswith(in_buf):
 
986
                # We have enough bytes to know the protocol version is wrong
 
987
                raise errors.UnexpectedProtocolVersionMarker(in_buf)
 
988
            raise _NeedMoreBytes(len(MESSAGE_VERSION_THREE))
 
989
        if not in_buf.startswith(MESSAGE_VERSION_THREE):
 
990
            raise errors.UnexpectedProtocolVersionMarker(in_buf)
 
991
        self._set_in_buffer(in_buf[len(MESSAGE_VERSION_THREE):])
 
992
        self.state_accept = self._state_accept_expecting_headers
 
993
 
 
994
    def _state_accept_expecting_headers(self):
 
995
        decoded = self._extract_prefixed_bencoded_data()
 
996
        if type(decoded) is not dict:
 
997
            raise errors.SmartProtocolError(
 
998
                'Header object %r is not a dict' % (decoded,))
 
999
        self.state_accept = self._state_accept_expecting_message_part
 
1000
        try:
 
1001
            self.message_handler.headers_received(decoded)
 
1002
        except:
 
1003
            raise errors.SmartMessageHandlerError(sys.exc_info())
 
1004
 
 
1005
    def _state_accept_expecting_message_part(self):
 
1006
        message_part_kind = self._extract_single_byte()
 
1007
        if message_part_kind == 'o':
 
1008
            self.state_accept = self._state_accept_expecting_one_byte
 
1009
        elif message_part_kind == 's':
 
1010
            self.state_accept = self._state_accept_expecting_structure
 
1011
        elif message_part_kind == 'b':
 
1012
            self.state_accept = self._state_accept_expecting_bytes
 
1013
        elif message_part_kind == 'e':
 
1014
            self.done()
 
1015
        else:
 
1016
            raise errors.SmartProtocolError(
 
1017
                'Bad message kind byte: %r' % (message_part_kind,))
 
1018
 
 
1019
    def _state_accept_expecting_one_byte(self):
 
1020
        byte = self._extract_single_byte()
 
1021
        self.state_accept = self._state_accept_expecting_message_part
 
1022
        try:
 
1023
            self.message_handler.byte_part_received(byte)
 
1024
        except:
 
1025
            raise errors.SmartMessageHandlerError(sys.exc_info())
 
1026
 
 
1027
    def _state_accept_expecting_bytes(self):
 
1028
        # XXX: this should not buffer whole message part, but instead deliver
 
1029
        # the bytes as they arrive.
 
1030
        prefixed_bytes = self._extract_length_prefixed_bytes()
 
1031
        self.state_accept = self._state_accept_expecting_message_part
 
1032
        try:
 
1033
            self.message_handler.bytes_part_received(prefixed_bytes)
 
1034
        except:
 
1035
            raise errors.SmartMessageHandlerError(sys.exc_info())
 
1036
 
 
1037
    def _state_accept_expecting_structure(self):
 
1038
        structure = self._extract_prefixed_bencoded_data()
 
1039
        self.state_accept = self._state_accept_expecting_message_part
 
1040
        try:
 
1041
            self.message_handler.structure_part_received(structure)
 
1042
        except:
 
1043
            raise errors.SmartMessageHandlerError(sys.exc_info())
 
1044
 
 
1045
    def done(self):
 
1046
        self.unused_data = self._get_in_buffer()
 
1047
        self._set_in_buffer(None)
 
1048
        self.state_accept = self._state_accept_reading_unused
 
1049
        try:
 
1050
            self.message_handler.end_received()
 
1051
        except:
 
1052
            raise errors.SmartMessageHandlerError(sys.exc_info())
 
1053
 
 
1054
    def _state_accept_reading_unused(self):
 
1055
        self.unused_data += self._get_in_buffer()
 
1056
        self._set_in_buffer(None)
 
1057
 
 
1058
    def next_read_size(self):
 
1059
        if self.state_accept == self._state_accept_reading_unused:
 
1060
            return 0
 
1061
        elif self.decoding_failed:
 
1062
            # An exception occured while processing this message, probably from
 
1063
            # self.message_handler.  We're not sure that this state machine is
 
1064
            # in a consistent state, so just signal that we're done (i.e. give
 
1065
            # up).
 
1066
            return 0
 
1067
        else:
 
1068
            if self._number_needed_bytes is not None:
 
1069
                return self._number_needed_bytes - self._in_buffer_len
 
1070
            else:
 
1071
                raise AssertionError("don't know how many bytes are expected!")
 
1072
 
 
1073
 
 
1074
class _ProtocolThreeEncoder(object):
 
1075
 
 
1076
    response_marker = request_marker = MESSAGE_VERSION_THREE
 
1077
    BUFFER_SIZE = 1024*1024 # 1 MiB buffer before flushing
 
1078
 
 
1079
    def __init__(self, write_func):
 
1080
        self._buf = []
 
1081
        self._buf_len = 0
 
1082
        self._real_write_func = write_func
 
1083
 
 
1084
    def _write_func(self, bytes):
 
1085
        # TODO: Another possibility would be to turn this into an async model.
 
1086
        #       Where we let another thread know that we have some bytes if
 
1087
        #       they want it, but we don't actually block for it
 
1088
        #       Note that osutils.send_all always sends 64kB chunks anyway, so
 
1089
        #       we might just push out smaller bits at a time?
 
1090
        self._buf.append(bytes)
 
1091
        self._buf_len += len(bytes)
 
1092
        if self._buf_len > self.BUFFER_SIZE:
 
1093
            self.flush()
 
1094
 
 
1095
    def flush(self):
 
1096
        if self._buf:
 
1097
            self._real_write_func(''.join(self._buf))
 
1098
            del self._buf[:]
 
1099
            self._buf_len = 0
 
1100
 
 
1101
    def _serialise_offsets(self, offsets):
 
1102
        """Serialise a readv offset list."""
 
1103
        txt = []
 
1104
        for start, length in offsets:
 
1105
            txt.append('%d,%d' % (start, length))
 
1106
        return '\n'.join(txt)
 
1107
 
 
1108
    def _write_protocol_version(self):
 
1109
        self._write_func(MESSAGE_VERSION_THREE)
 
1110
 
 
1111
    def _write_prefixed_bencode(self, structure):
 
1112
        bytes = bencode(structure)
 
1113
        self._write_func(struct.pack('!L', len(bytes)))
 
1114
        self._write_func(bytes)
 
1115
 
 
1116
    def _write_headers(self, headers):
 
1117
        self._write_prefixed_bencode(headers)
 
1118
 
 
1119
    def _write_structure(self, args):
 
1120
        self._write_func('s')
 
1121
        utf8_args = []
 
1122
        for arg in args:
 
1123
            if type(arg) is unicode:
 
1124
                utf8_args.append(arg.encode('utf8'))
 
1125
            else:
 
1126
                utf8_args.append(arg)
 
1127
        self._write_prefixed_bencode(utf8_args)
 
1128
 
 
1129
    def _write_end(self):
 
1130
        self._write_func('e')
 
1131
        self.flush()
 
1132
 
 
1133
    def _write_prefixed_body(self, bytes):
 
1134
        self._write_func('b')
 
1135
        self._write_func(struct.pack('!L', len(bytes)))
 
1136
        self._write_func(bytes)
 
1137
 
 
1138
    def _write_chunked_body_start(self):
 
1139
        self._write_func('oC')
 
1140
 
 
1141
    def _write_error_status(self):
 
1142
        self._write_func('oE')
 
1143
 
 
1144
    def _write_success_status(self):
 
1145
        self._write_func('oS')
 
1146
 
 
1147
 
 
1148
class ProtocolThreeResponder(_ProtocolThreeEncoder):
 
1149
 
 
1150
    def __init__(self, write_func):
 
1151
        _ProtocolThreeEncoder.__init__(self, write_func)
 
1152
        self.response_sent = False
 
1153
        self._headers = {'Software version': bzrlib.__version__}
 
1154
        if 'hpss' in debug.debug_flags:
 
1155
            self._thread_id = thread.get_ident()
 
1156
            self._response_start_time = None
 
1157
 
 
1158
    def _trace(self, action, message, extra_bytes=None, include_time=False):
 
1159
        if self._response_start_time is None:
 
1160
            self._response_start_time = osutils.timer_func()
 
1161
        if include_time:
 
1162
            t = '%5.3fs ' % (time.clock() - self._response_start_time)
 
1163
        else:
 
1164
            t = ''
 
1165
        if extra_bytes is None:
 
1166
            extra = ''
 
1167
        else:
 
1168
            extra = ' ' + repr(extra_bytes[:40])
 
1169
            if len(extra) > 33:
 
1170
                extra = extra[:29] + extra[-1] + '...'
 
1171
        mutter('%12s: [%s] %s%s%s'
 
1172
               % (action, self._thread_id, t, message, extra))
 
1173
 
 
1174
    def send_error(self, exception):
 
1175
        if self.response_sent:
 
1176
            raise AssertionError(
 
1177
                "send_error(%s) called, but response already sent."
 
1178
                % (exception,))
 
1179
        if isinstance(exception, errors.UnknownSmartMethod):
 
1180
            failure = request.FailedSmartServerResponse(
 
1181
                ('UnknownMethod', exception.verb))
 
1182
            self.send_response(failure)
 
1183
            return
 
1184
        if 'hpss' in debug.debug_flags:
 
1185
            self._trace('error', str(exception))
 
1186
        self.response_sent = True
 
1187
        self._write_protocol_version()
 
1188
        self._write_headers(self._headers)
 
1189
        self._write_error_status()
 
1190
        self._write_structure(('error', str(exception)))
 
1191
        self._write_end()
 
1192
 
 
1193
    def send_response(self, response):
 
1194
        if self.response_sent:
 
1195
            raise AssertionError(
 
1196
                "send_response(%r) called, but response already sent."
 
1197
                % (response,))
 
1198
        self.response_sent = True
 
1199
        self._write_protocol_version()
 
1200
        self._write_headers(self._headers)
 
1201
        if response.is_successful():
 
1202
            self._write_success_status()
 
1203
        else:
 
1204
            self._write_error_status()
 
1205
        if 'hpss' in debug.debug_flags:
 
1206
            self._trace('response', repr(response.args))
 
1207
        self._write_structure(response.args)
 
1208
        if response.body is not None:
 
1209
            self._write_prefixed_body(response.body)
 
1210
            if 'hpss' in debug.debug_flags:
 
1211
                self._trace('body', '%d bytes' % (len(response.body),),
 
1212
                            response.body, include_time=True)
 
1213
        elif response.body_stream is not None:
 
1214
            count = num_bytes = 0
 
1215
            first_chunk = None
 
1216
            for exc_info, chunk in _iter_with_errors(response.body_stream):
 
1217
                count += 1
 
1218
                if exc_info is not None:
 
1219
                    self._write_error_status()
 
1220
                    error_struct = request._translate_error(exc_info[1])
 
1221
                    self._write_structure(error_struct)
 
1222
                    break
 
1223
                else:
 
1224
                    if isinstance(chunk, request.FailedSmartServerResponse):
 
1225
                        self._write_error_status()
 
1226
                        self._write_structure(chunk.args)
 
1227
                        break
 
1228
                    num_bytes += len(chunk)
 
1229
                    if first_chunk is None:
 
1230
                        first_chunk = chunk
 
1231
                    self._write_prefixed_body(chunk)
 
1232
                    self.flush()
 
1233
                    if 'hpssdetail' in debug.debug_flags:
 
1234
                        # Not worth timing separately, as _write_func is
 
1235
                        # actually buffered
 
1236
                        self._trace('body chunk',
 
1237
                                    '%d bytes' % (len(chunk),),
 
1238
                                    chunk, suppress_time=True)
 
1239
            if 'hpss' in debug.debug_flags:
 
1240
                self._trace('body stream',
 
1241
                            '%d bytes %d chunks' % (num_bytes, count),
 
1242
                            first_chunk)
 
1243
        self._write_end()
 
1244
        if 'hpss' in debug.debug_flags:
 
1245
            self._trace('response end', '', include_time=True)
 
1246
 
 
1247
 
 
1248
def _iter_with_errors(iterable):
 
1249
    """Handle errors from iterable.next().
 
1250
 
 
1251
    Use like::
 
1252
 
 
1253
        for exc_info, value in _iter_with_errors(iterable):
 
1254
            ...
 
1255
 
 
1256
    This is a safer alternative to::
 
1257
 
 
1258
        try:
 
1259
            for value in iterable:
 
1260
               ...
 
1261
        except:
 
1262
            ...
 
1263
 
 
1264
    Because the latter will catch errors from the for-loop body, not just
 
1265
    iterable.next()
 
1266
 
 
1267
    If an error occurs, exc_info will be a exc_info tuple, and the generator
 
1268
    will terminate.  Otherwise exc_info will be None, and value will be the
 
1269
    value from iterable.next().  Note that KeyboardInterrupt and SystemExit
 
1270
    will not be itercepted.
 
1271
    """
 
1272
    iterator = iter(iterable)
 
1273
    while True:
 
1274
        try:
 
1275
            yield None, iterator.next()
 
1276
        except StopIteration:
 
1277
            return
 
1278
        except (KeyboardInterrupt, SystemExit):
 
1279
            raise
 
1280
        except Exception:
 
1281
            mutter('_iter_with_errors caught error')
 
1282
            log_exception_quietly()
 
1283
            yield sys.exc_info(), None
 
1284
            return
 
1285
 
 
1286
 
 
1287
class ProtocolThreeRequester(_ProtocolThreeEncoder, Requester):
 
1288
 
 
1289
    def __init__(self, medium_request):
 
1290
        _ProtocolThreeEncoder.__init__(self, medium_request.accept_bytes)
 
1291
        self._medium_request = medium_request
 
1292
        self._headers = {}
 
1293
        self.body_stream_started = None
 
1294
 
 
1295
    def set_headers(self, headers):
 
1296
        self._headers = headers.copy()
 
1297
 
 
1298
    def call(self, *args):
 
1299
        if 'hpss' in debug.debug_flags:
 
1300
            mutter('hpss call:   %s', repr(args)[1:-1])
 
1301
            base = getattr(self._medium_request._medium, 'base', None)
 
1302
            if base is not None:
 
1303
                mutter('             (to %s)', base)
 
1304
            self._request_start_time = osutils.timer_func()
 
1305
        self._write_protocol_version()
 
1306
        self._write_headers(self._headers)
 
1307
        self._write_structure(args)
 
1308
        self._write_end()
 
1309
        self._medium_request.finished_writing()
 
1310
 
 
1311
    def call_with_body_bytes(self, args, body):
 
1312
        """Make a remote call of args with body bytes 'body'.
 
1313
 
 
1314
        After calling this, call read_response_tuple to find the result out.
 
1315
        """
 
1316
        if 'hpss' in debug.debug_flags:
 
1317
            mutter('hpss call w/body: %s (%r...)', repr(args)[1:-1], body[:20])
 
1318
            path = getattr(self._medium_request._medium, '_path', None)
 
1319
            if path is not None:
 
1320
                mutter('                  (to %s)', path)
 
1321
            mutter('              %d bytes', len(body))
 
1322
            self._request_start_time = osutils.timer_func()
 
1323
        self._write_protocol_version()
 
1324
        self._write_headers(self._headers)
 
1325
        self._write_structure(args)
 
1326
        self._write_prefixed_body(body)
 
1327
        self._write_end()
 
1328
        self._medium_request.finished_writing()
 
1329
 
 
1330
    def call_with_body_readv_array(self, args, body):
 
1331
        """Make a remote call with a readv array.
 
1332
 
 
1333
        The body is encoded with one line per readv offset pair. The numbers in
 
1334
        each pair are separated by a comma, and no trailing \\n is emitted.
 
1335
        """
 
1336
        if 'hpss' in debug.debug_flags:
 
1337
            mutter('hpss call w/readv: %s', repr(args)[1:-1])
 
1338
            path = getattr(self._medium_request._medium, '_path', None)
 
1339
            if path is not None:
 
1340
                mutter('                  (to %s)', path)
 
1341
            self._request_start_time = osutils.timer_func()
 
1342
        self._write_protocol_version()
 
1343
        self._write_headers(self._headers)
 
1344
        self._write_structure(args)
 
1345
        readv_bytes = self._serialise_offsets(body)
 
1346
        if 'hpss' in debug.debug_flags:
 
1347
            mutter('              %d bytes in readv request', len(readv_bytes))
 
1348
        self._write_prefixed_body(readv_bytes)
 
1349
        self._write_end()
 
1350
        self._medium_request.finished_writing()
 
1351
 
 
1352
    def call_with_body_stream(self, args, stream):
 
1353
        if 'hpss' in debug.debug_flags:
 
1354
            mutter('hpss call w/body stream: %r', args)
 
1355
            path = getattr(self._medium_request._medium, '_path', None)
 
1356
            if path is not None:
 
1357
                mutter('                  (to %s)', path)
 
1358
            self._request_start_time = osutils.timer_func()
 
1359
        self.body_stream_started = False
 
1360
        self._write_protocol_version()
 
1361
        self._write_headers(self._headers)
 
1362
        self._write_structure(args)
 
1363
        # TODO: notice if the server has sent an early error reply before we
 
1364
        #       have finished sending the stream.  We would notice at the end
 
1365
        #       anyway, but if the medium can deliver it early then it's good
 
1366
        #       to short-circuit the whole request...
 
1367
        # Provoke any ConnectionReset failures before we start the body stream.
 
1368
        self.flush()
 
1369
        self.body_stream_started = True
 
1370
        for exc_info, part in _iter_with_errors(stream):
 
1371
            if exc_info is not None:
 
1372
                # Iterating the stream failed.  Cleanly abort the request.
 
1373
                self._write_error_status()
 
1374
                # Currently the client unconditionally sends ('error',) as the
 
1375
                # error args.
 
1376
                self._write_structure(('error',))
 
1377
                self._write_end()
 
1378
                self._medium_request.finished_writing()
 
1379
                raise exc_info[0], exc_info[1], exc_info[2]
 
1380
            else:
 
1381
                self._write_prefixed_body(part)
 
1382
                self.flush()
 
1383
        self._write_end()
 
1384
        self._medium_request.finished_writing()
460
1385