~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/smart/protocol.py

  • Committer: Canonical.com Patch Queue Manager
  • Date: 2007-12-20 04:20:19 UTC
  • mfrom: (3062.2.13 fast-plan-merge2)
  • Revision ID: pqm@pqm.ubuntu.com-20071220042019-wsij5vgvhgw4qhdt
Annotate merge can do cherrypicks (abentley)

Show diffs side-by-side

added added

removed removed

Lines of Context:
20
20
 
21
21
import collections
22
22
from cStringIO import StringIO
23
 
import struct
24
 
import sys
25
23
import time
26
24
 
27
 
import bzrlib
28
25
from bzrlib import debug
29
26
from bzrlib import errors
30
 
from bzrlib.smart import message, request
 
27
from bzrlib.smart import request
31
28
from bzrlib.trace import log_exception_quietly, mutter
32
 
from bzrlib.util.bencode import bdecode, bencode
33
29
 
34
30
 
35
31
# Protocol version strings.  These are sent as prefixes of bzr requests and
38
34
REQUEST_VERSION_TWO = 'bzr request 2\n'
39
35
RESPONSE_VERSION_TWO = 'bzr response 2\n'
40
36
 
41
 
MESSAGE_VERSION_THREE = 'bzr message 3 (bzr 1.6)\n'
42
 
RESPONSE_VERSION_THREE = REQUEST_VERSION_THREE = MESSAGE_VERSION_THREE
43
 
 
44
37
 
45
38
def _recv_tuple(from_file):
46
39
    req_line = from_file.readline()
48
41
 
49
42
 
50
43
def _decode_tuple(req_line):
51
 
    if req_line is None or req_line == '':
 
44
    if req_line == None or req_line == '':
52
45
        return None
53
46
    if req_line[-1] != '\n':
54
47
        raise errors.SmartProtocolError("request %r not terminated" % req_line)
60
53
    return '\x01'.join(args) + '\n'
61
54
 
62
55
 
63
 
class Requester(object):
64
 
    """Abstract base class for an object that can issue requests on a smart
65
 
    medium.
66
 
    """
67
 
 
68
 
    def call(self, *args):
69
 
        """Make a remote call.
70
 
 
71
 
        :param args: the arguments of this call.
72
 
        """
73
 
        raise NotImplementedError(self.call)
74
 
 
75
 
    def call_with_body_bytes(self, args, body):
76
 
        """Make a remote call with a body.
77
 
 
78
 
        :param args: the arguments of this call.
79
 
        :type body: str
80
 
        :param body: the body to send with the request.
81
 
        """
82
 
        raise NotImplementedError(self.call_with_body_bytes)
83
 
 
84
 
    def call_with_body_readv_array(self, args, body):
85
 
        """Make a remote call with a readv array.
86
 
 
87
 
        :param args: the arguments of this call.
88
 
        :type body: iterable of (start, length) tuples.
89
 
        :param body: the readv ranges to send with this request.
90
 
        """
91
 
        raise NotImplementedError(self.call_with_body_readv_array)
92
 
 
93
 
    def set_headers(self, headers):
94
 
        raise NotImplementedError(self.set_headers)
95
 
 
96
 
 
97
56
class SmartProtocolBase(object):
98
57
    """Methods common to client and server"""
99
58
 
114
73
class SmartServerRequestProtocolOne(SmartProtocolBase):
115
74
    """Server-side encoding and decoding logic for smart version 1."""
116
75
    
117
 
    def __init__(self, backing_transport, write_func, root_client_path='/'):
 
76
    def __init__(self, backing_transport, write_func):
118
77
        self._backing_transport = backing_transport
119
 
        self._root_client_path = root_client_path
120
 
        self.unused_data = ''
 
78
        self.excess_buffer = ''
121
79
        self._finished = False
122
80
        self.in_buffer = ''
123
 
        self._has_dispatched = False
 
81
        self.has_dispatched = False
124
82
        self.request = None
125
83
        self._body_decoder = None
126
84
        self._write_func = write_func
130
88
        
131
89
        :param bytes: must be a byte string
132
90
        """
133
 
        if not isinstance(bytes, str):
134
 
            raise ValueError(bytes)
 
91
        assert isinstance(bytes, str)
135
92
        self.in_buffer += bytes
136
 
        if not self._has_dispatched:
 
93
        if not self.has_dispatched:
137
94
            if '\n' not in self.in_buffer:
138
95
                # no command line yet
139
96
                return
140
 
            self._has_dispatched = True
 
97
            self.has_dispatched = True
141
98
            try:
142
99
                first_line, self.in_buffer = self.in_buffer.split('\n', 1)
143
100
                first_line += '\n'
144
101
                req_args = _decode_tuple(first_line)
145
102
                self.request = request.SmartServerRequestHandler(
146
 
                    self._backing_transport, commands=request.request_handlers,
147
 
                    root_client_path=self._root_client_path)
 
103
                    self._backing_transport, commands=request.request_handlers)
148
104
                self.request.dispatch_command(req_args[0], req_args[1:])
149
105
                if self.request.finished_reading:
150
106
                    # trivial request
151
 
                    self.unused_data = self.in_buffer
 
107
                    self.excess_buffer = self.in_buffer
152
108
                    self.in_buffer = ''
153
109
                    self._send_response(self.request.response)
154
110
            except KeyboardInterrupt:
155
111
                raise
156
 
            except errors.UnknownSmartMethod, err:
157
 
                protocol_error = errors.SmartProtocolError(
158
 
                    "bad request %r" % (err.verb,))
159
 
                failure = request.FailedSmartServerResponse(
160
 
                    ('error', str(protocol_error)))
161
 
                self._send_response(failure)
162
 
                return
163
112
            except Exception, exception:
164
113
                # everything else: pass to client, flush, and quit
165
114
                log_exception_quietly()
167
116
                    ('error', str(exception))))
168
117
                return
169
118
 
170
 
        if self._has_dispatched:
 
119
        if self.has_dispatched:
171
120
            if self._finished:
172
121
                # nothing to do.XXX: this routine should be a single state 
173
122
                # machine too.
174
 
                self.unused_data += self.in_buffer
 
123
                self.excess_buffer += self.in_buffer
175
124
                self.in_buffer = ''
176
125
                return
177
126
            if self._body_decoder is None:
182
131
            self.request.accept_body(body_data)
183
132
            if self._body_decoder.finished_reading:
184
133
                self.request.end_of_body()
185
 
                if not self.request.finished_reading:
186
 
                    raise AssertionError("no more body, request not finished")
 
134
                assert self.request.finished_reading, \
 
135
                    "no more body, request not finished"
187
136
            if self.request.response is not None:
188
137
                self._send_response(self.request.response)
189
 
                self.unused_data = self.in_buffer
 
138
                self.excess_buffer = self.in_buffer
190
139
                self.in_buffer = ''
191
140
            else:
192
 
                if self.request.finished_reading:
193
 
                    raise AssertionError(
194
 
                        "no response and we have finished reading.")
 
141
                assert not self.request.finished_reading, \
 
142
                    "no response and we have finished reading."
195
143
 
196
144
    def _send_response(self, response):
197
145
        """Send a smart server response down the output stream."""
198
 
        if self._finished:
199
 
            raise AssertionError('response already sent')
 
146
        assert not self._finished, 'response already sent'
200
147
        args = response.args
201
148
        body = response.body
202
149
        self._finished = True
204
151
        self._write_success_or_failure_prefix(response)
205
152
        self._write_func(_encode_tuple(args))
206
153
        if body is not None:
207
 
            if not isinstance(body, str):
208
 
                raise ValueError(body)
 
154
            assert isinstance(body, str), 'body must be a str'
209
155
            bytes = self._encode_bulk_data(body)
210
156
            self._write_func(bytes)
211
157
 
238
184
    This prefixes responses with the value of RESPONSE_VERSION_TWO.
239
185
    """
240
186
 
241
 
    response_marker = RESPONSE_VERSION_TWO
242
 
    request_marker = REQUEST_VERSION_TWO
243
 
 
244
187
    def _write_success_or_failure_prefix(self, response):
245
188
        """Write the protocol specific success/failure prefix."""
246
189
        if response.is_successful():
253
196
        
254
197
        Version two sends the value of RESPONSE_VERSION_TWO.
255
198
        """
256
 
        self._write_func(self.response_marker)
 
199
        self._write_func(RESPONSE_VERSION_TWO)
257
200
 
258
201
    def _send_response(self, response):
259
202
        """Send a smart server response down the output stream."""
260
 
        if (self._finished):
261
 
            raise AssertionError('response already sent')
 
203
        assert not self._finished, 'response already sent'
262
204
        self._finished = True
263
205
        self._write_protocol_version()
264
206
        self._write_success_or_failure_prefix(response)
265
207
        self._write_func(_encode_tuple(response.args))
266
208
        if response.body is not None:
267
 
            if not isinstance(response.body, str):
268
 
                raise AssertionError('body must be a str')
269
 
            if not (response.body_stream is None):
270
 
                raise AssertionError(
271
 
                    'body_stream and body cannot both be set')
 
209
            assert isinstance(response.body, str), 'body must be a str'
 
210
            assert response.body_stream is None, (
 
211
                'body_stream and body cannot both be set')
272
212
            bytes = self._encode_bulk_data(response.body)
273
213
            self._write_func(bytes)
274
214
        elif response.body_stream is not None:
296
236
                % chunk)
297
237
 
298
238
 
299
 
class _NeedMoreBytes(Exception):
300
 
    """Raise this inside a _StatefulDecoder to stop decoding until more bytes
301
 
    have been received.
302
 
    """
303
 
 
304
 
    def __init__(self, count=None):
305
 
        """Constructor.
306
 
 
307
 
        :param count: the total number of bytes needed by the current state.
308
 
            May be None if the number of bytes needed is unknown.
309
 
        """
310
 
        self.count = count
311
 
 
312
 
 
313
239
class _StatefulDecoder(object):
314
 
    """Base class for writing state machines to decode byte streams.
315
 
 
316
 
    Subclasses should provide a self.state_accept attribute that accepts bytes
317
 
    and, if appropriate, updates self.state_accept to a different function.
318
 
    accept_bytes will call state_accept as often as necessary to make sure the
319
 
    state machine has progressed as far as possible before it returns.
320
 
 
321
 
    See ProtocolThreeDecoder for an example subclass.
322
 
    """
323
240
 
324
241
    def __init__(self):
325
242
        self.finished_reading = False
326
 
        self._in_buffer = ''
327
243
        self.unused_data = ''
328
244
        self.bytes_left = None
329
 
        self._number_needed_bytes = None
330
245
 
331
246
    def accept_bytes(self, bytes):
332
247
        """Decode as much of bytes as possible.
339
254
        """
340
255
        # accept_bytes is allowed to change the state
341
256
        current_state = self.state_accept
342
 
        self._number_needed_bytes = None
343
 
        self._in_buffer += bytes
344
 
        try:
345
 
            # Run the function for the current state.
346
 
            self.state_accept()
347
 
            while current_state != self.state_accept:
348
 
                # The current state has changed.  Run the function for the new
349
 
                # current state, so that it can:
350
 
                #   - decode any unconsumed bytes left in a buffer, and
351
 
                #   - signal how many more bytes are expected (via raising
352
 
                #     _NeedMoreBytes).
353
 
                current_state = self.state_accept
354
 
                self.state_accept()
355
 
        except _NeedMoreBytes, e:
356
 
            self._number_needed_bytes = e.count
 
257
        self.state_accept(bytes)
 
258
        while current_state != self.state_accept:
 
259
            current_state = self.state_accept
 
260
            self.state_accept('')
357
261
 
358
262
 
359
263
class ChunkedBodyDecoder(_StatefulDecoder):
366
270
    def __init__(self):
367
271
        _StatefulDecoder.__init__(self)
368
272
        self.state_accept = self._state_accept_expecting_header
 
273
        self._in_buffer = ''
369
274
        self.chunk_in_progress = None
370
275
        self.chunks = collections.deque()
371
276
        self.error = False
403
308
    def _extract_line(self):
404
309
        pos = self._in_buffer.find('\n')
405
310
        if pos == -1:
406
 
            # We haven't read a complete line yet, so request more bytes before
407
 
            # we continue.
408
 
            raise _NeedMoreBytes(1)
 
311
            # We haven't read a complete length prefix yet, so there's nothing
 
312
            # to do.
 
313
            return None
409
314
        line = self._in_buffer[:pos]
410
315
        # Trim the prefix (including '\n' delimiter) from the _in_buffer.
411
316
        self._in_buffer = self._in_buffer[pos+1:]
413
318
 
414
319
    def _finished(self):
415
320
        self.unused_data = self._in_buffer
416
 
        self._in_buffer = ''
 
321
        self._in_buffer = None
417
322
        self.state_accept = self._state_accept_reading_unused
418
323
        if self.error:
419
324
            error_args = tuple(self.error_in_progress)
421
326
            self.error_in_progress = None
422
327
        self.finished_reading = True
423
328
 
424
 
    def _state_accept_expecting_header(self):
 
329
    def _state_accept_expecting_header(self, bytes):
 
330
        self._in_buffer += bytes
425
331
        prefix = self._extract_line()
426
 
        if prefix == 'chunked':
 
332
        if prefix is None:
 
333
            # We haven't read a complete length prefix yet, so there's nothing
 
334
            # to do.
 
335
            return
 
336
        elif prefix == 'chunked':
427
337
            self.state_accept = self._state_accept_expecting_length
428
338
        else:
429
339
            raise errors.SmartProtocolError(
430
340
                'Bad chunked body header: "%s"' % (prefix,))
431
341
 
432
 
    def _state_accept_expecting_length(self):
 
342
    def _state_accept_expecting_length(self, bytes):
 
343
        self._in_buffer += bytes
433
344
        prefix = self._extract_line()
434
 
        if prefix == 'ERR':
 
345
        if prefix is None:
 
346
            # We haven't read a complete length prefix yet, so there's nothing
 
347
            # to do.
 
348
            return
 
349
        elif prefix == 'ERR':
435
350
            self.error = True
436
351
            self.error_in_progress = []
437
 
            self._state_accept_expecting_length()
 
352
            self._state_accept_expecting_length('')
438
353
            return
439
354
        elif prefix == 'END':
440
355
            # We've read the end-of-body marker.
447
362
            self.chunk_in_progress = ''
448
363
            self.state_accept = self._state_accept_reading_chunk
449
364
 
450
 
    def _state_accept_reading_chunk(self):
 
365
    def _state_accept_reading_chunk(self, bytes):
 
366
        self._in_buffer += bytes
451
367
        in_buffer_len = len(self._in_buffer)
452
368
        self.chunk_in_progress += self._in_buffer[:self.bytes_left]
453
369
        self._in_buffer = self._in_buffer[self.bytes_left:]
462
378
            self.chunk_in_progress = None
463
379
            self.state_accept = self._state_accept_expecting_length
464
380
        
465
 
    def _state_accept_reading_unused(self):
466
 
        self.unused_data += self._in_buffer
467
 
        self._in_buffer = ''
 
381
    def _state_accept_reading_unused(self, bytes):
 
382
        self.unused_data += bytes
468
383
 
469
384
 
470
385
class LengthPrefixedBodyDecoder(_StatefulDecoder):
474
389
        _StatefulDecoder.__init__(self)
475
390
        self.state_accept = self._state_accept_expecting_length
476
391
        self.state_read = self._state_read_no_data
477
 
        self._body = ''
 
392
        self._in_buffer = ''
478
393
        self._trailer_buffer = ''
479
394
    
480
395
    def next_read_size(self):
497
412
        """Return any pending data that has been decoded."""
498
413
        return self.state_read()
499
414
 
500
 
    def _state_accept_expecting_length(self):
 
415
    def _state_accept_expecting_length(self, bytes):
 
416
        self._in_buffer += bytes
501
417
        pos = self._in_buffer.find('\n')
502
418
        if pos == -1:
503
419
            return
504
420
        self.bytes_left = int(self._in_buffer[:pos])
505
421
        self._in_buffer = self._in_buffer[pos+1:]
 
422
        self.bytes_left -= len(self._in_buffer)
506
423
        self.state_accept = self._state_accept_reading_body
507
 
        self.state_read = self._state_read_body_buffer
 
424
        self.state_read = self._state_read_in_buffer
508
425
 
509
 
    def _state_accept_reading_body(self):
510
 
        self._body += self._in_buffer
511
 
        self.bytes_left -= len(self._in_buffer)
512
 
        self._in_buffer = ''
 
426
    def _state_accept_reading_body(self, bytes):
 
427
        self._in_buffer += bytes
 
428
        self.bytes_left -= len(bytes)
513
429
        if self.bytes_left <= 0:
514
430
            # Finished with body
515
431
            if self.bytes_left != 0:
516
 
                self._trailer_buffer = self._body[self.bytes_left:]
517
 
                self._body = self._body[:self.bytes_left]
 
432
                self._trailer_buffer = self._in_buffer[self.bytes_left:]
 
433
                self._in_buffer = self._in_buffer[:self.bytes_left]
518
434
            self.bytes_left = None
519
435
            self.state_accept = self._state_accept_reading_trailer
520
436
        
521
 
    def _state_accept_reading_trailer(self):
522
 
        self._trailer_buffer += self._in_buffer
523
 
        self._in_buffer = ''
 
437
    def _state_accept_reading_trailer(self, bytes):
 
438
        self._trailer_buffer += bytes
524
439
        # TODO: what if the trailer does not match "done\n"?  Should this raise
525
440
        # a ProtocolViolation exception?
526
441
        if self._trailer_buffer.startswith('done\n'):
528
443
            self.state_accept = self._state_accept_reading_unused
529
444
            self.finished_reading = True
530
445
    
531
 
    def _state_accept_reading_unused(self):
532
 
        self.unused_data += self._in_buffer
533
 
        self._in_buffer = ''
 
446
    def _state_accept_reading_unused(self, bytes):
 
447
        self.unused_data += bytes
534
448
 
535
449
    def _state_read_no_data(self):
536
450
        return ''
537
451
 
538
 
    def _state_read_body_buffer(self):
539
 
        result = self._body
540
 
        self._body = ''
 
452
    def _state_read_in_buffer(self):
 
453
        result = self._in_buffer
 
454
        self._in_buffer = ''
541
455
        return result
542
456
 
543
457
 
544
 
class SmartClientRequestProtocolOne(SmartProtocolBase, Requester,
545
 
                                    message.ResponseHandler):
 
458
class SmartClientRequestProtocolOne(SmartProtocolBase):
546
459
    """The client-side protocol for smart version 1."""
547
460
 
548
461
    def __init__(self, request):
554
467
        self._request = request
555
468
        self._body_buffer = None
556
469
        self._request_start_time = None
557
 
        self._last_verb = None
558
 
        self._headers = None
559
 
 
560
 
    def set_headers(self, headers):
561
 
        self._headers = dict(headers)
562
470
 
563
471
    def call(self, *args):
564
472
        if 'hpss' in debug.debug_flags:
565
473
            mutter('hpss call:   %s', repr(args)[1:-1])
566
 
            if getattr(self._request._medium, 'base', None) is not None:
567
 
                mutter('             (to %s)', self._request._medium.base)
568
474
            self._request_start_time = time.time()
569
475
        self._write_args(args)
570
476
        self._request.finished_writing()
571
 
        self._last_verb = args[0]
572
477
 
573
478
    def call_with_body_bytes(self, args, body):
574
479
        """Make a remote call of args with body bytes 'body'.
577
482
        """
578
483
        if 'hpss' in debug.debug_flags:
579
484
            mutter('hpss call w/body: %s (%r...)', repr(args)[1:-1], body[:20])
580
 
            if getattr(self._request._medium, '_path', None) is not None:
581
 
                mutter('                  (to %s)', self._request._medium._path)
582
485
            mutter('              %d bytes', len(body))
583
486
            self._request_start_time = time.time()
584
 
            if 'hpssdetail' in debug.debug_flags:
585
 
                mutter('hpss body content: %s', body)
586
487
        self._write_args(args)
587
488
        bytes = self._encode_bulk_data(body)
588
489
        self._request.accept_bytes(bytes)
589
490
        self._request.finished_writing()
590
 
        self._last_verb = args[0]
591
491
 
592
492
    def call_with_body_readv_array(self, args, body):
593
493
        """Make a remote call with a readv array.
597
497
        """
598
498
        if 'hpss' in debug.debug_flags:
599
499
            mutter('hpss call w/readv: %s', repr(args)[1:-1])
600
 
            if getattr(self._request._medium, '_path', None) is not None:
601
 
                mutter('                  (to %s)', self._request._medium._path)
602
500
            self._request_start_time = time.time()
603
501
        self._write_args(args)
604
502
        readv_bytes = self._serialise_offsets(body)
607
505
        self._request.finished_writing()
608
506
        if 'hpss' in debug.debug_flags:
609
507
            mutter('              %d bytes in readv request', len(readv_bytes))
610
 
        self._last_verb = args[0]
611
508
 
612
509
    def cancel_read_body(self):
613
510
        """After expecting a body, a response code may indicate one otherwise.
618
515
        """
619
516
        self._request.finished_reading()
620
517
 
621
 
    def _read_response_tuple(self):
 
518
    def read_response_tuple(self, expect_body=False):
 
519
        """Read a response tuple from the wire.
 
520
 
 
521
        This should only be called once.
 
522
        """
622
523
        result = self._recv_tuple()
623
524
        if 'hpss' in debug.debug_flags:
624
525
            if self._request_start_time is not None:
628
529
                self._request_start_time = None
629
530
            else:
630
531
                mutter('   result:   %s', repr(result)[1:-1])
631
 
        return result
632
 
 
633
 
    def read_response_tuple(self, expect_body=False):
634
 
        """Read a response tuple from the wire.
635
 
 
636
 
        This should only be called once.
637
 
        """
638
 
        result = self._read_response_tuple()
639
 
        self._response_is_unknown_method(result)
640
 
        self._raise_args_if_error(result)
641
532
        if not expect_body:
642
533
            self._request.finished_reading()
643
534
        return result
644
535
 
645
 
    def _raise_args_if_error(self, result_tuple):
646
 
        # Later protocol versions have an explicit flag in the protocol to say
647
 
        # if an error response is "failed" or not.  In version 1 we don't have
648
 
        # that luxury.  So here is a complete list of errors that can be
649
 
        # returned in response to existing version 1 smart requests.  Responses
650
 
        # starting with these codes are always "failed" responses.
651
 
        v1_error_codes = [
652
 
            'norepository',
653
 
            'NoSuchFile',
654
 
            'FileExists',
655
 
            'DirectoryNotEmpty',
656
 
            'ShortReadvError',
657
 
            'UnicodeEncodeError',
658
 
            'UnicodeDecodeError',
659
 
            'ReadOnlyError',
660
 
            'nobranch',
661
 
            'NoSuchRevision',
662
 
            'nosuchrevision',
663
 
            'LockContention',
664
 
            'UnlockableTransport',
665
 
            'LockFailed',
666
 
            'TokenMismatch',
667
 
            'ReadError',
668
 
            'PermissionDenied',
669
 
            ]
670
 
        if result_tuple[0] in v1_error_codes:
671
 
            self._request.finished_reading()
672
 
            raise errors.ErrorFromSmartServer(result_tuple)
673
 
 
674
 
    def _response_is_unknown_method(self, result_tuple):
675
 
        """Raise UnexpectedSmartServerResponse if the response is an 'unknonwn
676
 
        method' response to the request.
677
 
        
678
 
        :param response: The response from a smart client call_expecting_body
679
 
            call.
680
 
        :param verb: The verb used in that call.
681
 
        :raises: UnexpectedSmartServerResponse
682
 
        """
683
 
        if (result_tuple == ('error', "Generic bzr smart protocol error: "
684
 
                "bad request '%s'" % self._last_verb) or
685
 
              result_tuple == ('error', "Generic bzr smart protocol error: "
686
 
                "bad request u'%s'" % self._last_verb)):
687
 
            # The response will have no body, so we've finished reading.
688
 
            self._request.finished_reading()
689
 
            raise errors.UnknownSmartMethod(self._last_verb)
690
 
        
691
536
    def read_body_bytes(self, count=-1):
692
537
        """Read bytes from the body, decoding into a byte stream.
693
538
        
699
544
        _body_decoder = LengthPrefixedBodyDecoder()
700
545
 
701
546
        while not _body_decoder.finished_reading:
702
 
            bytes = self._request.read_bytes(_body_decoder.next_read_size())
703
 
            if bytes == '':
704
 
                # end of file encountered reading from server
705
 
                raise errors.ConnectionReset(
706
 
                    "Connection lost while reading response body.")
 
547
            bytes_wanted = _body_decoder.next_read_size()
 
548
            bytes = self._request.read_bytes(bytes_wanted)
707
549
            _body_decoder.accept_bytes(bytes)
708
550
        self._request.finished_reading()
709
551
        self._body_buffer = StringIO(_body_decoder.read_pending_data())
715
557
 
716
558
    def _recv_tuple(self):
717
559
        """Receive a tuple from the medium request."""
718
 
        return _decode_tuple(self._request.read_line())
 
560
        return _decode_tuple(self._recv_line())
 
561
 
 
562
    def _recv_line(self):
 
563
        """Read an entire line from the medium request."""
 
564
        line = ''
 
565
        while not line or line[-1] != '\n':
 
566
            # TODO: this is inefficient - but tuples are short.
 
567
            new_char = self._request.read_bytes(1)
 
568
            if new_char == '':
 
569
                # end of file encountered reading from server
 
570
                raise errors.ConnectionReset(
 
571
                    "please check connectivity and permissions",
 
572
                    "(and try -Dhpss if further diagnosis is required)")
 
573
            line += new_char
 
574
        return line
719
575
 
720
576
    def query_version(self):
721
577
        """Return protocol version number of the server."""
746
602
    This prefixes the request with the value of REQUEST_VERSION_TWO.
747
603
    """
748
604
 
749
 
    response_marker = RESPONSE_VERSION_TWO
750
 
    request_marker = REQUEST_VERSION_TWO
751
 
 
752
605
    def read_response_tuple(self, expect_body=False):
753
606
        """Read a response tuple from the wire.
754
607
 
755
608
        This should only be called once.
756
609
        """
757
610
        version = self._request.read_line()
758
 
        if version != self.response_marker:
759
 
            self._request.finished_reading()
760
 
            raise errors.UnexpectedProtocolVersionMarker(version)
761
 
        response_status = self._request.read_line()
762
 
        result = SmartClientRequestProtocolOne._read_response_tuple(self)
763
 
        self._response_is_unknown_method(result)
764
 
        if response_status == 'success\n':
765
 
            self.response_status = True
766
 
            if not expect_body:
767
 
                self._request.finished_reading()
768
 
            return result
769
 
        elif response_status == 'failed\n':
770
 
            self.response_status = False
771
 
            self._request.finished_reading()
772
 
            raise errors.ErrorFromSmartServer(result)
773
 
        else:
 
611
        if version != RESPONSE_VERSION_TWO:
 
612
            raise errors.SmartProtocolError('bad protocol marker %r' % version)
 
613
        response_status = self._recv_line()
 
614
        if response_status not in ('success\n', 'failed\n'):
774
615
            raise errors.SmartProtocolError(
775
616
                'bad protocol status %r' % response_status)
 
617
        self.response_status = response_status == 'success\n'
 
618
        return SmartClientRequestProtocolOne.read_response_tuple(self, expect_body)
776
619
 
777
620
    def _write_protocol_version(self):
778
621
        """Write any prefixes this protocol requires.
779
622
        
780
623
        Version two sends the value of REQUEST_VERSION_TWO.
781
624
        """
782
 
        self._request.accept_bytes(self.request_marker)
 
625
        self._request.accept_bytes(REQUEST_VERSION_TWO)
783
626
 
784
627
    def read_streamed_body(self):
785
628
        """Read bytes from the body, decoding into a byte stream.
786
629
        """
787
 
        # Read no more than 64k at a time so that we don't risk error 10055 (no
788
 
        # buffer space available) on Windows.
789
630
        _body_decoder = ChunkedBodyDecoder()
790
631
        while not _body_decoder.finished_reading:
791
 
            bytes = self._request.read_bytes(_body_decoder.next_read_size())
792
 
            if bytes == '':
793
 
                # end of file encountered reading from server
794
 
                raise errors.ConnectionReset(
795
 
                    "Connection lost while reading streamed body.")
 
632
            bytes_wanted = _body_decoder.next_read_size()
 
633
            bytes = self._request.read_bytes(bytes_wanted)
796
634
            _body_decoder.accept_bytes(bytes)
797
635
            for body_bytes in iter(_body_decoder.read_next_chunk, None):
798
 
                if 'hpss' in debug.debug_flags and type(body_bytes) is str:
 
636
                if 'hpss' in debug.debug_flags:
799
637
                    mutter('              %d byte chunk read',
800
638
                           len(body_bytes))
801
639
                yield body_bytes
802
640
        self._request.finished_reading()
803
641
 
804
 
 
805
 
def build_server_protocol_three(backing_transport, write_func,
806
 
                                root_client_path):
807
 
    request_handler = request.SmartServerRequestHandler(
808
 
        backing_transport, commands=request.request_handlers,
809
 
        root_client_path=root_client_path)
810
 
    responder = ProtocolThreeResponder(write_func)
811
 
    message_handler = message.ConventionalRequestHandler(request_handler, responder)
812
 
    return ProtocolThreeDecoder(message_handler)
813
 
 
814
 
 
815
 
class ProtocolThreeDecoder(_StatefulDecoder):
816
 
 
817
 
    response_marker = RESPONSE_VERSION_THREE
818
 
    request_marker = REQUEST_VERSION_THREE
819
 
 
820
 
    def __init__(self, message_handler, expect_version_marker=False):
821
 
        _StatefulDecoder.__init__(self)
822
 
        self._has_dispatched = False
823
 
        # Initial state
824
 
        if expect_version_marker:
825
 
            self.state_accept = self._state_accept_expecting_protocol_version
826
 
            # We're expecting at least the protocol version marker + some
827
 
            # headers.
828
 
            self._number_needed_bytes = len(MESSAGE_VERSION_THREE) + 4
829
 
        else:
830
 
            self.state_accept = self._state_accept_expecting_headers
831
 
            self._number_needed_bytes = 4
832
 
        self.decoding_failed = False
833
 
        self.request_handler = self.message_handler = message_handler
834
 
 
835
 
    def accept_bytes(self, bytes):
836
 
        self._number_needed_bytes = None
837
 
        try:
838
 
            _StatefulDecoder.accept_bytes(self, bytes)
839
 
        except KeyboardInterrupt:
840
 
            raise
841
 
        except errors.SmartMessageHandlerError, exception:
842
 
            # We do *not* set self.decoding_failed here.  The message handler
843
 
            # has raised an error, but the decoder is still able to parse bytes
844
 
            # and determine when this message ends.
845
 
            log_exception_quietly()
846
 
            self.message_handler.protocol_error(exception.exc_value)
847
 
            # The state machine is ready to continue decoding, but the
848
 
            # exception has interrupted the loop that runs the state machine.
849
 
            # So we call accept_bytes again to restart it.
850
 
            self.accept_bytes('')
851
 
        except Exception, exception:
852
 
            # The decoder itself has raised an exception.  We cannot continue
853
 
            # decoding.
854
 
            self.decoding_failed = True
855
 
            if isinstance(exception, errors.UnexpectedProtocolVersionMarker):
856
 
                # This happens during normal operation when the client tries a
857
 
                # protocol version the server doesn't understand, so no need to
858
 
                # log a traceback every time.
859
 
                # Note that this can only happen when
860
 
                # expect_version_marker=True, which is only the case on the
861
 
                # client side.
862
 
                pass
863
 
            else:
864
 
                log_exception_quietly()
865
 
            self.message_handler.protocol_error(exception)
866
 
 
867
 
    def _extract_length_prefixed_bytes(self):
868
 
        if len(self._in_buffer) < 4:
869
 
            # A length prefix by itself is 4 bytes, and we don't even have that
870
 
            # many yet.
871
 
            raise _NeedMoreBytes(4)
872
 
        (length,) = struct.unpack('!L', self._in_buffer[:4])
873
 
        end_of_bytes = 4 + length
874
 
        if len(self._in_buffer) < end_of_bytes:
875
 
            # We haven't yet read as many bytes as the length-prefix says there
876
 
            # are.
877
 
            raise _NeedMoreBytes(end_of_bytes)
878
 
        # Extract the bytes from the buffer.
879
 
        bytes = self._in_buffer[4:end_of_bytes]
880
 
        self._in_buffer = self._in_buffer[end_of_bytes:]
881
 
        return bytes
882
 
 
883
 
    def _extract_prefixed_bencoded_data(self):
884
 
        prefixed_bytes = self._extract_length_prefixed_bytes()
885
 
        try:
886
 
            decoded = bdecode(prefixed_bytes)
887
 
        except ValueError:
888
 
            raise errors.SmartProtocolError(
889
 
                'Bytes %r not bencoded' % (prefixed_bytes,))
890
 
        return decoded
891
 
 
892
 
    def _extract_single_byte(self):
893
 
        if self._in_buffer == '':
894
 
            # The buffer is empty
895
 
            raise _NeedMoreBytes(1)
896
 
        one_byte = self._in_buffer[0]
897
 
        self._in_buffer = self._in_buffer[1:]
898
 
        return one_byte
899
 
 
900
 
    def _state_accept_expecting_protocol_version(self):
901
 
        needed_bytes = len(MESSAGE_VERSION_THREE) - len(self._in_buffer)
902
 
        if needed_bytes > 0:
903
 
            # We don't have enough bytes to check if the protocol version
904
 
            # marker is right.  But we can check if it is already wrong by
905
 
            # checking that the start of MESSAGE_VERSION_THREE matches what
906
 
            # we've read so far.
907
 
            # [In fact, if the remote end isn't bzr we might never receive
908
 
            # len(MESSAGE_VERSION_THREE) bytes.  So if the bytes we have so far
909
 
            # are wrong then we should just raise immediately rather than
910
 
            # stall.]
911
 
            if not MESSAGE_VERSION_THREE.startswith(self._in_buffer):
912
 
                # We have enough bytes to know the protocol version is wrong
913
 
                raise errors.UnexpectedProtocolVersionMarker(self._in_buffer)
914
 
            raise _NeedMoreBytes(len(MESSAGE_VERSION_THREE))
915
 
        if not self._in_buffer.startswith(MESSAGE_VERSION_THREE):
916
 
            raise errors.UnexpectedProtocolVersionMarker(self._in_buffer)
917
 
        self._in_buffer = self._in_buffer[len(MESSAGE_VERSION_THREE):]
918
 
        self.state_accept = self._state_accept_expecting_headers
919
 
 
920
 
    def _state_accept_expecting_headers(self):
921
 
        decoded = self._extract_prefixed_bencoded_data()
922
 
        if type(decoded) is not dict:
923
 
            raise errors.SmartProtocolError(
924
 
                'Header object %r is not a dict' % (decoded,))
925
 
        self.state_accept = self._state_accept_expecting_message_part
926
 
        try:
927
 
            self.message_handler.headers_received(decoded)
928
 
        except:
929
 
            raise errors.SmartMessageHandlerError(sys.exc_info())
930
 
    
931
 
    def _state_accept_expecting_message_part(self):
932
 
        message_part_kind = self._extract_single_byte()
933
 
        if message_part_kind == 'o':
934
 
            self.state_accept = self._state_accept_expecting_one_byte
935
 
        elif message_part_kind == 's':
936
 
            self.state_accept = self._state_accept_expecting_structure
937
 
        elif message_part_kind == 'b':
938
 
            self.state_accept = self._state_accept_expecting_bytes
939
 
        elif message_part_kind == 'e':
940
 
            self.done()
941
 
        else:
942
 
            raise errors.SmartProtocolError(
943
 
                'Bad message kind byte: %r' % (message_part_kind,))
944
 
 
945
 
    def _state_accept_expecting_one_byte(self):
946
 
        byte = self._extract_single_byte()
947
 
        self.state_accept = self._state_accept_expecting_message_part
948
 
        try:
949
 
            self.message_handler.byte_part_received(byte)
950
 
        except:
951
 
            raise errors.SmartMessageHandlerError(sys.exc_info())
952
 
 
953
 
    def _state_accept_expecting_bytes(self):
954
 
        # XXX: this should not buffer whole message part, but instead deliver
955
 
        # the bytes as they arrive.
956
 
        prefixed_bytes = self._extract_length_prefixed_bytes()
957
 
        self.state_accept = self._state_accept_expecting_message_part
958
 
        try:
959
 
            self.message_handler.bytes_part_received(prefixed_bytes)
960
 
        except:
961
 
            raise errors.SmartMessageHandlerError(sys.exc_info())
962
 
 
963
 
    def _state_accept_expecting_structure(self):
964
 
        structure = self._extract_prefixed_bencoded_data()
965
 
        self.state_accept = self._state_accept_expecting_message_part
966
 
        try:
967
 
            self.message_handler.structure_part_received(structure)
968
 
        except:
969
 
            raise errors.SmartMessageHandlerError(sys.exc_info())
970
 
 
971
 
    def done(self):
972
 
        self.unused_data = self._in_buffer
973
 
        self._in_buffer = ''
974
 
        self.state_accept = self._state_accept_reading_unused
975
 
        try:
976
 
            self.message_handler.end_received()
977
 
        except:
978
 
            raise errors.SmartMessageHandlerError(sys.exc_info())
979
 
 
980
 
    def _state_accept_reading_unused(self):
981
 
        self.unused_data += self._in_buffer
982
 
        self._in_buffer = ''
983
 
 
984
 
    def next_read_size(self):
985
 
        if self.state_accept == self._state_accept_reading_unused:
986
 
            return 0
987
 
        elif self.decoding_failed:
988
 
            # An exception occured while processing this message, probably from
989
 
            # self.message_handler.  We're not sure that this state machine is
990
 
            # in a consistent state, so just signal that we're done (i.e. give
991
 
            # up).
992
 
            return 0
993
 
        else:
994
 
            if self._number_needed_bytes is not None:
995
 
                return self._number_needed_bytes - len(self._in_buffer)
996
 
            else:
997
 
                raise AssertionError("don't know how many bytes are expected!")
998
 
 
999
 
 
1000
 
class _ProtocolThreeEncoder(object):
1001
 
 
1002
 
    response_marker = request_marker = MESSAGE_VERSION_THREE
1003
 
 
1004
 
    def __init__(self, write_func):
1005
 
        self._buf = ''
1006
 
        self._real_write_func = write_func
1007
 
 
1008
 
    def _write_func(self, bytes):
1009
 
        self._buf += bytes
1010
 
 
1011
 
    def flush(self):
1012
 
        if self._buf:
1013
 
            self._real_write_func(self._buf)
1014
 
            self._buf = ''
1015
 
 
1016
 
    def _serialise_offsets(self, offsets):
1017
 
        """Serialise a readv offset list."""
1018
 
        txt = []
1019
 
        for start, length in offsets:
1020
 
            txt.append('%d,%d' % (start, length))
1021
 
        return '\n'.join(txt)
1022
 
        
1023
 
    def _write_protocol_version(self):
1024
 
        self._write_func(MESSAGE_VERSION_THREE)
1025
 
 
1026
 
    def _write_prefixed_bencode(self, structure):
1027
 
        bytes = bencode(structure)
1028
 
        self._write_func(struct.pack('!L', len(bytes)))
1029
 
        self._write_func(bytes)
1030
 
 
1031
 
    def _write_headers(self, headers):
1032
 
        self._write_prefixed_bencode(headers)
1033
 
 
1034
 
    def _write_structure(self, args):
1035
 
        self._write_func('s')
1036
 
        utf8_args = []
1037
 
        for arg in args:
1038
 
            if type(arg) is unicode:
1039
 
                utf8_args.append(arg.encode('utf8'))
1040
 
            else:
1041
 
                utf8_args.append(arg)
1042
 
        self._write_prefixed_bencode(utf8_args)
1043
 
 
1044
 
    def _write_end(self):
1045
 
        self._write_func('e')
1046
 
        self.flush()
1047
 
 
1048
 
    def _write_prefixed_body(self, bytes):
1049
 
        self._write_func('b')
1050
 
        self._write_func(struct.pack('!L', len(bytes)))
1051
 
        self._write_func(bytes)
1052
 
 
1053
 
    def _write_error_status(self):
1054
 
        self._write_func('oE')
1055
 
 
1056
 
    def _write_success_status(self):
1057
 
        self._write_func('oS')
1058
 
 
1059
 
 
1060
 
class ProtocolThreeResponder(_ProtocolThreeEncoder):
1061
 
 
1062
 
    def __init__(self, write_func):
1063
 
        _ProtocolThreeEncoder.__init__(self, write_func)
1064
 
        self.response_sent = False
1065
 
        self._headers = {'Software version': bzrlib.__version__}
1066
 
 
1067
 
    def send_error(self, exception):
1068
 
        if self.response_sent:
1069
 
            raise AssertionError(
1070
 
                "send_error(%s) called, but response already sent."
1071
 
                % (exception,))
1072
 
        if isinstance(exception, errors.UnknownSmartMethod):
1073
 
            failure = request.FailedSmartServerResponse(
1074
 
                ('UnknownMethod', exception.verb))
1075
 
            self.send_response(failure)
1076
 
            return
1077
 
        self.response_sent = True
1078
 
        self._write_protocol_version()
1079
 
        self._write_headers(self._headers)
1080
 
        self._write_error_status()
1081
 
        self._write_structure(('error', str(exception)))
1082
 
        self._write_end()
1083
 
 
1084
 
    def send_response(self, response):
1085
 
        if self.response_sent:
1086
 
            raise AssertionError(
1087
 
                "send_response(%r) called, but response already sent."
1088
 
                % (response,))
1089
 
        self.response_sent = True
1090
 
        self._write_protocol_version()
1091
 
        self._write_headers(self._headers)
1092
 
        if response.is_successful():
1093
 
            self._write_success_status()
1094
 
        else:
1095
 
            self._write_error_status()
1096
 
        self._write_structure(response.args)
1097
 
        if response.body is not None:
1098
 
            self._write_prefixed_body(response.body)
1099
 
        elif response.body_stream is not None:
1100
 
            for chunk in response.body_stream:
1101
 
                self._write_prefixed_body(chunk)
1102
 
                self.flush()
1103
 
        self._write_end()
1104
 
        
1105
 
 
1106
 
class ProtocolThreeRequester(_ProtocolThreeEncoder, Requester):
1107
 
 
1108
 
    def __init__(self, medium_request):
1109
 
        _ProtocolThreeEncoder.__init__(self, medium_request.accept_bytes)
1110
 
        self._medium_request = medium_request
1111
 
        self._headers = {}
1112
 
 
1113
 
    def set_headers(self, headers):
1114
 
        self._headers = headers.copy()
1115
 
        
1116
 
    def call(self, *args):
1117
 
        if 'hpss' in debug.debug_flags:
1118
 
            mutter('hpss call:   %s', repr(args)[1:-1])
1119
 
            base = getattr(self._medium_request._medium, 'base', None)
1120
 
            if base is not None:
1121
 
                mutter('             (to %s)', base)
1122
 
            self._request_start_time = time.time()
1123
 
        self._write_protocol_version()
1124
 
        self._write_headers(self._headers)
1125
 
        self._write_structure(args)
1126
 
        self._write_end()
1127
 
        self._medium_request.finished_writing()
1128
 
 
1129
 
    def call_with_body_bytes(self, args, body):
1130
 
        """Make a remote call of args with body bytes 'body'.
1131
 
 
1132
 
        After calling this, call read_response_tuple to find the result out.
1133
 
        """
1134
 
        if 'hpss' in debug.debug_flags:
1135
 
            mutter('hpss call w/body: %s (%r...)', repr(args)[1:-1], body[:20])
1136
 
            path = getattr(self._medium_request._medium, '_path', None)
1137
 
            if path is not None:
1138
 
                mutter('                  (to %s)', path)
1139
 
            mutter('              %d bytes', len(body))
1140
 
            self._request_start_time = time.time()
1141
 
        self._write_protocol_version()
1142
 
        self._write_headers(self._headers)
1143
 
        self._write_structure(args)
1144
 
        self._write_prefixed_body(body)
1145
 
        self._write_end()
1146
 
        self._medium_request.finished_writing()
1147
 
 
1148
 
    def call_with_body_readv_array(self, args, body):
1149
 
        """Make a remote call with a readv array.
1150
 
 
1151
 
        The body is encoded with one line per readv offset pair. The numbers in
1152
 
        each pair are separated by a comma, and no trailing \n is emitted.
1153
 
        """
1154
 
        if 'hpss' in debug.debug_flags:
1155
 
            mutter('hpss call w/readv: %s', repr(args)[1:-1])
1156
 
            path = getattr(self._medium_request._medium, '_path', None)
1157
 
            if path is not None:
1158
 
                mutter('                  (to %s)', path)
1159
 
            self._request_start_time = time.time()
1160
 
        self._write_protocol_version()
1161
 
        self._write_headers(self._headers)
1162
 
        self._write_structure(args)
1163
 
        readv_bytes = self._serialise_offsets(body)
1164
 
        if 'hpss' in debug.debug_flags:
1165
 
            mutter('              %d bytes in readv request', len(readv_bytes))
1166
 
        self._write_prefixed_body(readv_bytes)
1167
 
        self._write_end()
1168
 
        self._medium_request.finished_writing()
1169