~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/smart/protocol.py

(vila) Fix test failures blocking package builds. (Vincent Ladeuil)

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# Copyright (C) 2006-2010 Canonical Ltd
 
2
#
 
3
# This program is free software; you can redistribute it and/or modify
 
4
# it under the terms of the GNU General Public License as published by
 
5
# the Free Software Foundation; either version 2 of the License, or
 
6
# (at your option) any later version.
 
7
#
 
8
# This program is distributed in the hope that it will be useful,
 
9
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
10
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
11
# GNU General Public License for more details.
 
12
#
 
13
# You should have received a copy of the GNU General Public License
 
14
# along with this program; if not, write to the Free Software
 
15
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 
16
 
 
17
"""Wire-level encoding and decoding of requests and responses for the smart
 
18
client and server.
 
19
"""
 
20
 
 
21
from __future__ import absolute_import
 
22
 
 
23
import collections
 
24
from cStringIO import StringIO
 
25
import struct
 
26
import sys
 
27
import thread
 
28
import time
 
29
 
 
30
import bzrlib
 
31
from bzrlib import (
 
32
    debug,
 
33
    errors,
 
34
    osutils,
 
35
    )
 
36
from bzrlib.smart import message, request
 
37
from bzrlib.trace import log_exception_quietly, mutter
 
38
from bzrlib.bencode import bdecode_as_tuple, bencode
 
39
 
 
40
 
 
41
# Protocol version strings.  These are sent as prefixes of bzr requests and
 
42
# responses to identify the protocol version being used. (There are no version
 
43
# one strings because that version doesn't send any).
 
44
REQUEST_VERSION_TWO = 'bzr request 2\n'
 
45
RESPONSE_VERSION_TWO = 'bzr response 2\n'
 
46
 
 
47
MESSAGE_VERSION_THREE = 'bzr message 3 (bzr 1.6)\n'
 
48
RESPONSE_VERSION_THREE = REQUEST_VERSION_THREE = MESSAGE_VERSION_THREE
 
49
 
 
50
 
 
51
def _recv_tuple(from_file):
 
52
    req_line = from_file.readline()
 
53
    return _decode_tuple(req_line)
 
54
 
 
55
 
 
56
def _decode_tuple(req_line):
 
57
    if req_line is None or req_line == '':
 
58
        return None
 
59
    if req_line[-1] != '\n':
 
60
        raise errors.SmartProtocolError("request %r not terminated" % req_line)
 
61
    return tuple(req_line[:-1].split('\x01'))
 
62
 
 
63
 
 
64
def _encode_tuple(args):
 
65
    """Encode the tuple args to a bytestream."""
 
66
    joined = '\x01'.join(args) + '\n'
 
67
    if type(joined) is unicode:
 
68
        # XXX: We should fix things so this never happens!  -AJB, 20100304
 
69
        mutter('response args contain unicode, should be only bytes: %r',
 
70
               joined)
 
71
        joined = joined.encode('ascii')
 
72
    return joined
 
73
 
 
74
 
 
75
class Requester(object):
 
76
    """Abstract base class for an object that can issue requests on a smart
 
77
    medium.
 
78
    """
 
79
 
 
80
    def call(self, *args):
 
81
        """Make a remote call.
 
82
 
 
83
        :param args: the arguments of this call.
 
84
        """
 
85
        raise NotImplementedError(self.call)
 
86
 
 
87
    def call_with_body_bytes(self, args, body):
 
88
        """Make a remote call with a body.
 
89
 
 
90
        :param args: the arguments of this call.
 
91
        :type body: str
 
92
        :param body: the body to send with the request.
 
93
        """
 
94
        raise NotImplementedError(self.call_with_body_bytes)
 
95
 
 
96
    def call_with_body_readv_array(self, args, body):
 
97
        """Make a remote call with a readv array.
 
98
 
 
99
        :param args: the arguments of this call.
 
100
        :type body: iterable of (start, length) tuples.
 
101
        :param body: the readv ranges to send with this request.
 
102
        """
 
103
        raise NotImplementedError(self.call_with_body_readv_array)
 
104
 
 
105
    def set_headers(self, headers):
 
106
        raise NotImplementedError(self.set_headers)
 
107
 
 
108
 
 
109
class SmartProtocolBase(object):
 
110
    """Methods common to client and server"""
 
111
 
 
112
    # TODO: this only actually accomodates a single block; possibly should
 
113
    # support multiple chunks?
 
114
    def _encode_bulk_data(self, body):
 
115
        """Encode body as a bulk data chunk."""
 
116
        return ''.join(('%d\n' % len(body), body, 'done\n'))
 
117
 
 
118
    def _serialise_offsets(self, offsets):
 
119
        """Serialise a readv offset list."""
 
120
        txt = []
 
121
        for start, length in offsets:
 
122
            txt.append('%d,%d' % (start, length))
 
123
        return '\n'.join(txt)
 
124
 
 
125
 
 
126
class SmartServerRequestProtocolOne(SmartProtocolBase):
 
127
    """Server-side encoding and decoding logic for smart version 1."""
 
128
 
 
129
    def __init__(self, backing_transport, write_func, root_client_path='/',
 
130
            jail_root=None):
 
131
        self._backing_transport = backing_transport
 
132
        self._root_client_path = root_client_path
 
133
        self._jail_root = jail_root
 
134
        self.unused_data = ''
 
135
        self._finished = False
 
136
        self.in_buffer = ''
 
137
        self._has_dispatched = False
 
138
        self.request = None
 
139
        self._body_decoder = None
 
140
        self._write_func = write_func
 
141
 
 
142
    def accept_bytes(self, bytes):
 
143
        """Take bytes, and advance the internal state machine appropriately.
 
144
 
 
145
        :param bytes: must be a byte string
 
146
        """
 
147
        if not isinstance(bytes, str):
 
148
            raise ValueError(bytes)
 
149
        self.in_buffer += bytes
 
150
        if not self._has_dispatched:
 
151
            if '\n' not in self.in_buffer:
 
152
                # no command line yet
 
153
                return
 
154
            self._has_dispatched = True
 
155
            try:
 
156
                first_line, self.in_buffer = self.in_buffer.split('\n', 1)
 
157
                first_line += '\n'
 
158
                req_args = _decode_tuple(first_line)
 
159
                self.request = request.SmartServerRequestHandler(
 
160
                    self._backing_transport, commands=request.request_handlers,
 
161
                    root_client_path=self._root_client_path,
 
162
                    jail_root=self._jail_root)
 
163
                self.request.args_received(req_args)
 
164
                if self.request.finished_reading:
 
165
                    # trivial request
 
166
                    self.unused_data = self.in_buffer
 
167
                    self.in_buffer = ''
 
168
                    self._send_response(self.request.response)
 
169
            except KeyboardInterrupt:
 
170
                raise
 
171
            except errors.UnknownSmartMethod, err:
 
172
                protocol_error = errors.SmartProtocolError(
 
173
                    "bad request %r" % (err.verb,))
 
174
                failure = request.FailedSmartServerResponse(
 
175
                    ('error', str(protocol_error)))
 
176
                self._send_response(failure)
 
177
                return
 
178
            except Exception, exception:
 
179
                # everything else: pass to client, flush, and quit
 
180
                log_exception_quietly()
 
181
                self._send_response(request.FailedSmartServerResponse(
 
182
                    ('error', str(exception))))
 
183
                return
 
184
 
 
185
        if self._has_dispatched:
 
186
            if self._finished:
 
187
                # nothing to do.XXX: this routine should be a single state
 
188
                # machine too.
 
189
                self.unused_data += self.in_buffer
 
190
                self.in_buffer = ''
 
191
                return
 
192
            if self._body_decoder is None:
 
193
                self._body_decoder = LengthPrefixedBodyDecoder()
 
194
            self._body_decoder.accept_bytes(self.in_buffer)
 
195
            self.in_buffer = self._body_decoder.unused_data
 
196
            body_data = self._body_decoder.read_pending_data()
 
197
            self.request.accept_body(body_data)
 
198
            if self._body_decoder.finished_reading:
 
199
                self.request.end_of_body()
 
200
                if not self.request.finished_reading:
 
201
                    raise AssertionError("no more body, request not finished")
 
202
            if self.request.response is not None:
 
203
                self._send_response(self.request.response)
 
204
                self.unused_data = self.in_buffer
 
205
                self.in_buffer = ''
 
206
            else:
 
207
                if self.request.finished_reading:
 
208
                    raise AssertionError(
 
209
                        "no response and we have finished reading.")
 
210
 
 
211
    def _send_response(self, response):
 
212
        """Send a smart server response down the output stream."""
 
213
        if self._finished:
 
214
            raise AssertionError('response already sent')
 
215
        args = response.args
 
216
        body = response.body
 
217
        self._finished = True
 
218
        self._write_protocol_version()
 
219
        self._write_success_or_failure_prefix(response)
 
220
        self._write_func(_encode_tuple(args))
 
221
        if body is not None:
 
222
            if not isinstance(body, str):
 
223
                raise ValueError(body)
 
224
            bytes = self._encode_bulk_data(body)
 
225
            self._write_func(bytes)
 
226
 
 
227
    def _write_protocol_version(self):
 
228
        """Write any prefixes this protocol requires.
 
229
 
 
230
        Version one doesn't send protocol versions.
 
231
        """
 
232
 
 
233
    def _write_success_or_failure_prefix(self, response):
 
234
        """Write the protocol specific success/failure prefix.
 
235
 
 
236
        For SmartServerRequestProtocolOne this is omitted but we
 
237
        call is_successful to ensure that the response is valid.
 
238
        """
 
239
        response.is_successful()
 
240
 
 
241
    def next_read_size(self):
 
242
        if self._finished:
 
243
            return 0
 
244
        if self._body_decoder is None:
 
245
            return 1
 
246
        else:
 
247
            return self._body_decoder.next_read_size()
 
248
 
 
249
 
 
250
class SmartServerRequestProtocolTwo(SmartServerRequestProtocolOne):
 
251
    r"""Version two of the server side of the smart protocol.
 
252
 
 
253
    This prefixes responses with the value of RESPONSE_VERSION_TWO.
 
254
    """
 
255
 
 
256
    response_marker = RESPONSE_VERSION_TWO
 
257
    request_marker = REQUEST_VERSION_TWO
 
258
 
 
259
    def _write_success_or_failure_prefix(self, response):
 
260
        """Write the protocol specific success/failure prefix."""
 
261
        if response.is_successful():
 
262
            self._write_func('success\n')
 
263
        else:
 
264
            self._write_func('failed\n')
 
265
 
 
266
    def _write_protocol_version(self):
 
267
        r"""Write any prefixes this protocol requires.
 
268
 
 
269
        Version two sends the value of RESPONSE_VERSION_TWO.
 
270
        """
 
271
        self._write_func(self.response_marker)
 
272
 
 
273
    def _send_response(self, response):
 
274
        """Send a smart server response down the output stream."""
 
275
        if (self._finished):
 
276
            raise AssertionError('response already sent')
 
277
        self._finished = True
 
278
        self._write_protocol_version()
 
279
        self._write_success_or_failure_prefix(response)
 
280
        self._write_func(_encode_tuple(response.args))
 
281
        if response.body is not None:
 
282
            if not isinstance(response.body, str):
 
283
                raise AssertionError('body must be a str')
 
284
            if not (response.body_stream is None):
 
285
                raise AssertionError(
 
286
                    'body_stream and body cannot both be set')
 
287
            bytes = self._encode_bulk_data(response.body)
 
288
            self._write_func(bytes)
 
289
        elif response.body_stream is not None:
 
290
            _send_stream(response.body_stream, self._write_func)
 
291
 
 
292
 
 
293
def _send_stream(stream, write_func):
 
294
    write_func('chunked\n')
 
295
    _send_chunks(stream, write_func)
 
296
    write_func('END\n')
 
297
 
 
298
 
 
299
def _send_chunks(stream, write_func):
 
300
    for chunk in stream:
 
301
        if isinstance(chunk, str):
 
302
            bytes = "%x\n%s" % (len(chunk), chunk)
 
303
            write_func(bytes)
 
304
        elif isinstance(chunk, request.FailedSmartServerResponse):
 
305
            write_func('ERR\n')
 
306
            _send_chunks(chunk.args, write_func)
 
307
            return
 
308
        else:
 
309
            raise errors.BzrError(
 
310
                'Chunks must be str or FailedSmartServerResponse, got %r'
 
311
                % chunk)
 
312
 
 
313
 
 
314
class _NeedMoreBytes(Exception):
 
315
    """Raise this inside a _StatefulDecoder to stop decoding until more bytes
 
316
    have been received.
 
317
    """
 
318
 
 
319
    def __init__(self, count=None):
 
320
        """Constructor.
 
321
 
 
322
        :param count: the total number of bytes needed by the current state.
 
323
            May be None if the number of bytes needed is unknown.
 
324
        """
 
325
        self.count = count
 
326
 
 
327
 
 
328
class _StatefulDecoder(object):
 
329
    """Base class for writing state machines to decode byte streams.
 
330
 
 
331
    Subclasses should provide a self.state_accept attribute that accepts bytes
 
332
    and, if appropriate, updates self.state_accept to a different function.
 
333
    accept_bytes will call state_accept as often as necessary to make sure the
 
334
    state machine has progressed as far as possible before it returns.
 
335
 
 
336
    See ProtocolThreeDecoder for an example subclass.
 
337
    """
 
338
 
 
339
    def __init__(self):
 
340
        self.finished_reading = False
 
341
        self._in_buffer_list = []
 
342
        self._in_buffer_len = 0
 
343
        self.unused_data = ''
 
344
        self.bytes_left = None
 
345
        self._number_needed_bytes = None
 
346
 
 
347
    def _get_in_buffer(self):
 
348
        if len(self._in_buffer_list) == 1:
 
349
            return self._in_buffer_list[0]
 
350
        in_buffer = ''.join(self._in_buffer_list)
 
351
        if len(in_buffer) != self._in_buffer_len:
 
352
            raise AssertionError(
 
353
                "Length of buffer did not match expected value: %s != %s"
 
354
                % self._in_buffer_len, len(in_buffer))
 
355
        self._in_buffer_list = [in_buffer]
 
356
        return in_buffer
 
357
 
 
358
    def _get_in_bytes(self, count):
 
359
        """Grab X bytes from the input_buffer.
 
360
 
 
361
        Callers should have already checked that self._in_buffer_len is >
 
362
        count. Note, this does not consume the bytes from the buffer. The
 
363
        caller will still need to call _get_in_buffer() and then
 
364
        _set_in_buffer() if they actually need to consume the bytes.
 
365
        """
 
366
        # check if we can yield the bytes from just the first entry in our list
 
367
        if len(self._in_buffer_list) == 0:
 
368
            raise AssertionError('Callers must be sure we have buffered bytes'
 
369
                ' before calling _get_in_bytes')
 
370
        if len(self._in_buffer_list[0]) > count:
 
371
            return self._in_buffer_list[0][:count]
 
372
        # We can't yield it from the first buffer, so collapse all buffers, and
 
373
        # yield it from that
 
374
        in_buf = self._get_in_buffer()
 
375
        return in_buf[:count]
 
376
 
 
377
    def _set_in_buffer(self, new_buf):
 
378
        if new_buf is not None:
 
379
            self._in_buffer_list = [new_buf]
 
380
            self._in_buffer_len = len(new_buf)
 
381
        else:
 
382
            self._in_buffer_list = []
 
383
            self._in_buffer_len = 0
 
384
 
 
385
    def accept_bytes(self, bytes):
 
386
        """Decode as much of bytes as possible.
 
387
 
 
388
        If 'bytes' contains too much data it will be appended to
 
389
        self.unused_data.
 
390
 
 
391
        finished_reading will be set when no more data is required.  Further
 
392
        data will be appended to self.unused_data.
 
393
        """
 
394
        # accept_bytes is allowed to change the state
 
395
        self._number_needed_bytes = None
 
396
        # lsprof puts a very large amount of time on this specific call for
 
397
        # large readv arrays
 
398
        self._in_buffer_list.append(bytes)
 
399
        self._in_buffer_len += len(bytes)
 
400
        try:
 
401
            # Run the function for the current state.
 
402
            current_state = self.state_accept
 
403
            self.state_accept()
 
404
            while current_state != self.state_accept:
 
405
                # The current state has changed.  Run the function for the new
 
406
                # current state, so that it can:
 
407
                #   - decode any unconsumed bytes left in a buffer, and
 
408
                #   - signal how many more bytes are expected (via raising
 
409
                #     _NeedMoreBytes).
 
410
                current_state = self.state_accept
 
411
                self.state_accept()
 
412
        except _NeedMoreBytes, e:
 
413
            self._number_needed_bytes = e.count
 
414
 
 
415
 
 
416
class ChunkedBodyDecoder(_StatefulDecoder):
 
417
    """Decoder for chunked body data.
 
418
 
 
419
    This is very similar the HTTP's chunked encoding.  See the description of
 
420
    streamed body data in `doc/developers/network-protocol.txt` for details.
 
421
    """
 
422
 
 
423
    def __init__(self):
 
424
        _StatefulDecoder.__init__(self)
 
425
        self.state_accept = self._state_accept_expecting_header
 
426
        self.chunk_in_progress = None
 
427
        self.chunks = collections.deque()
 
428
        self.error = False
 
429
        self.error_in_progress = None
 
430
 
 
431
    def next_read_size(self):
 
432
        # Note: the shortest possible chunk is 2 bytes: '0\n', and the
 
433
        # end-of-body marker is 4 bytes: 'END\n'.
 
434
        if self.state_accept == self._state_accept_reading_chunk:
 
435
            # We're expecting more chunk content.  So we're expecting at least
 
436
            # the rest of this chunk plus an END chunk.
 
437
            return self.bytes_left + 4
 
438
        elif self.state_accept == self._state_accept_expecting_length:
 
439
            if self._in_buffer_len == 0:
 
440
                # We're expecting a chunk length.  There's at least two bytes
 
441
                # left: a digit plus '\n'.
 
442
                return 2
 
443
            else:
 
444
                # We're in the middle of reading a chunk length.  So there's at
 
445
                # least one byte left, the '\n' that terminates the length.
 
446
                return 1
 
447
        elif self.state_accept == self._state_accept_reading_unused:
 
448
            return 1
 
449
        elif self.state_accept == self._state_accept_expecting_header:
 
450
            return max(0, len('chunked\n') - self._in_buffer_len)
 
451
        else:
 
452
            raise AssertionError("Impossible state: %r" % (self.state_accept,))
 
453
 
 
454
    def read_next_chunk(self):
 
455
        try:
 
456
            return self.chunks.popleft()
 
457
        except IndexError:
 
458
            return None
 
459
 
 
460
    def _extract_line(self):
 
461
        in_buf = self._get_in_buffer()
 
462
        pos = in_buf.find('\n')
 
463
        if pos == -1:
 
464
            # We haven't read a complete line yet, so request more bytes before
 
465
            # we continue.
 
466
            raise _NeedMoreBytes(1)
 
467
        line = in_buf[:pos]
 
468
        # Trim the prefix (including '\n' delimiter) from the _in_buffer.
 
469
        self._set_in_buffer(in_buf[pos+1:])
 
470
        return line
 
471
 
 
472
    def _finished(self):
 
473
        self.unused_data = self._get_in_buffer()
 
474
        self._in_buffer_list = []
 
475
        self._in_buffer_len = 0
 
476
        self.state_accept = self._state_accept_reading_unused
 
477
        if self.error:
 
478
            error_args = tuple(self.error_in_progress)
 
479
            self.chunks.append(request.FailedSmartServerResponse(error_args))
 
480
            self.error_in_progress = None
 
481
        self.finished_reading = True
 
482
 
 
483
    def _state_accept_expecting_header(self):
 
484
        prefix = self._extract_line()
 
485
        if prefix == 'chunked':
 
486
            self.state_accept = self._state_accept_expecting_length
 
487
        else:
 
488
            raise errors.SmartProtocolError(
 
489
                'Bad chunked body header: "%s"' % (prefix,))
 
490
 
 
491
    def _state_accept_expecting_length(self):
 
492
        prefix = self._extract_line()
 
493
        if prefix == 'ERR':
 
494
            self.error = True
 
495
            self.error_in_progress = []
 
496
            self._state_accept_expecting_length()
 
497
            return
 
498
        elif prefix == 'END':
 
499
            # We've read the end-of-body marker.
 
500
            # Any further bytes are unused data, including the bytes left in
 
501
            # the _in_buffer.
 
502
            self._finished()
 
503
            return
 
504
        else:
 
505
            self.bytes_left = int(prefix, 16)
 
506
            self.chunk_in_progress = ''
 
507
            self.state_accept = self._state_accept_reading_chunk
 
508
 
 
509
    def _state_accept_reading_chunk(self):
 
510
        in_buf = self._get_in_buffer()
 
511
        in_buffer_len = len(in_buf)
 
512
        self.chunk_in_progress += in_buf[:self.bytes_left]
 
513
        self._set_in_buffer(in_buf[self.bytes_left:])
 
514
        self.bytes_left -= in_buffer_len
 
515
        if self.bytes_left <= 0:
 
516
            # Finished with chunk
 
517
            self.bytes_left = None
 
518
            if self.error:
 
519
                self.error_in_progress.append(self.chunk_in_progress)
 
520
            else:
 
521
                self.chunks.append(self.chunk_in_progress)
 
522
            self.chunk_in_progress = None
 
523
            self.state_accept = self._state_accept_expecting_length
 
524
 
 
525
    def _state_accept_reading_unused(self):
 
526
        self.unused_data += self._get_in_buffer()
 
527
        self._in_buffer_list = []
 
528
 
 
529
 
 
530
class LengthPrefixedBodyDecoder(_StatefulDecoder):
 
531
    """Decodes the length-prefixed bulk data."""
 
532
 
 
533
    def __init__(self):
 
534
        _StatefulDecoder.__init__(self)
 
535
        self.state_accept = self._state_accept_expecting_length
 
536
        self.state_read = self._state_read_no_data
 
537
        self._body = ''
 
538
        self._trailer_buffer = ''
 
539
 
 
540
    def next_read_size(self):
 
541
        if self.bytes_left is not None:
 
542
            # Ideally we want to read all the remainder of the body and the
 
543
            # trailer in one go.
 
544
            return self.bytes_left + 5
 
545
        elif self.state_accept == self._state_accept_reading_trailer:
 
546
            # Just the trailer left
 
547
            return 5 - len(self._trailer_buffer)
 
548
        elif self.state_accept == self._state_accept_expecting_length:
 
549
            # There's still at least 6 bytes left ('\n' to end the length, plus
 
550
            # 'done\n').
 
551
            return 6
 
552
        else:
 
553
            # Reading excess data.  Either way, 1 byte at a time is fine.
 
554
            return 1
 
555
 
 
556
    def read_pending_data(self):
 
557
        """Return any pending data that has been decoded."""
 
558
        return self.state_read()
 
559
 
 
560
    def _state_accept_expecting_length(self):
 
561
        in_buf = self._get_in_buffer()
 
562
        pos = in_buf.find('\n')
 
563
        if pos == -1:
 
564
            return
 
565
        self.bytes_left = int(in_buf[:pos])
 
566
        self._set_in_buffer(in_buf[pos+1:])
 
567
        self.state_accept = self._state_accept_reading_body
 
568
        self.state_read = self._state_read_body_buffer
 
569
 
 
570
    def _state_accept_reading_body(self):
 
571
        in_buf = self._get_in_buffer()
 
572
        self._body += in_buf
 
573
        self.bytes_left -= len(in_buf)
 
574
        self._set_in_buffer(None)
 
575
        if self.bytes_left <= 0:
 
576
            # Finished with body
 
577
            if self.bytes_left != 0:
 
578
                self._trailer_buffer = self._body[self.bytes_left:]
 
579
                self._body = self._body[:self.bytes_left]
 
580
            self.bytes_left = None
 
581
            self.state_accept = self._state_accept_reading_trailer
 
582
 
 
583
    def _state_accept_reading_trailer(self):
 
584
        self._trailer_buffer += self._get_in_buffer()
 
585
        self._set_in_buffer(None)
 
586
        # TODO: what if the trailer does not match "done\n"?  Should this raise
 
587
        # a ProtocolViolation exception?
 
588
        if self._trailer_buffer.startswith('done\n'):
 
589
            self.unused_data = self._trailer_buffer[len('done\n'):]
 
590
            self.state_accept = self._state_accept_reading_unused
 
591
            self.finished_reading = True
 
592
 
 
593
    def _state_accept_reading_unused(self):
 
594
        self.unused_data += self._get_in_buffer()
 
595
        self._set_in_buffer(None)
 
596
 
 
597
    def _state_read_no_data(self):
 
598
        return ''
 
599
 
 
600
    def _state_read_body_buffer(self):
 
601
        result = self._body
 
602
        self._body = ''
 
603
        return result
 
604
 
 
605
 
 
606
class SmartClientRequestProtocolOne(SmartProtocolBase, Requester,
 
607
                                    message.ResponseHandler):
 
608
    """The client-side protocol for smart version 1."""
 
609
 
 
610
    def __init__(self, request):
 
611
        """Construct a SmartClientRequestProtocolOne.
 
612
 
 
613
        :param request: A SmartClientMediumRequest to serialise onto and
 
614
            deserialise from.
 
615
        """
 
616
        self._request = request
 
617
        self._body_buffer = None
 
618
        self._request_start_time = None
 
619
        self._last_verb = None
 
620
        self._headers = None
 
621
 
 
622
    def set_headers(self, headers):
 
623
        self._headers = dict(headers)
 
624
 
 
625
    def call(self, *args):
 
626
        if 'hpss' in debug.debug_flags:
 
627
            mutter('hpss call:   %s', repr(args)[1:-1])
 
628
            if getattr(self._request._medium, 'base', None) is not None:
 
629
                mutter('             (to %s)', self._request._medium.base)
 
630
            self._request_start_time = osutils.timer_func()
 
631
        self._write_args(args)
 
632
        self._request.finished_writing()
 
633
        self._last_verb = args[0]
 
634
 
 
635
    def call_with_body_bytes(self, args, body):
 
636
        """Make a remote call of args with body bytes 'body'.
 
637
 
 
638
        After calling this, call read_response_tuple to find the result out.
 
639
        """
 
640
        if 'hpss' in debug.debug_flags:
 
641
            mutter('hpss call w/body: %s (%r...)', repr(args)[1:-1], body[:20])
 
642
            if getattr(self._request._medium, '_path', None) is not None:
 
643
                mutter('                  (to %s)', self._request._medium._path)
 
644
            mutter('              %d bytes', len(body))
 
645
            self._request_start_time = osutils.timer_func()
 
646
            if 'hpssdetail' in debug.debug_flags:
 
647
                mutter('hpss body content: %s', body)
 
648
        self._write_args(args)
 
649
        bytes = self._encode_bulk_data(body)
 
650
        self._request.accept_bytes(bytes)
 
651
        self._request.finished_writing()
 
652
        self._last_verb = args[0]
 
653
 
 
654
    def call_with_body_readv_array(self, args, body):
 
655
        """Make a remote call with a readv array.
 
656
 
 
657
        The body is encoded with one line per readv offset pair. The numbers in
 
658
        each pair are separated by a comma, and no trailing \\n is emitted.
 
659
        """
 
660
        if 'hpss' in debug.debug_flags:
 
661
            mutter('hpss call w/readv: %s', repr(args)[1:-1])
 
662
            if getattr(self._request._medium, '_path', None) is not None:
 
663
                mutter('                  (to %s)', self._request._medium._path)
 
664
            self._request_start_time = osutils.timer_func()
 
665
        self._write_args(args)
 
666
        readv_bytes = self._serialise_offsets(body)
 
667
        bytes = self._encode_bulk_data(readv_bytes)
 
668
        self._request.accept_bytes(bytes)
 
669
        self._request.finished_writing()
 
670
        if 'hpss' in debug.debug_flags:
 
671
            mutter('              %d bytes in readv request', len(readv_bytes))
 
672
        self._last_verb = args[0]
 
673
 
 
674
    def call_with_body_stream(self, args, stream):
 
675
        # Protocols v1 and v2 don't support body streams.  So it's safe to
 
676
        # assume that a v1/v2 server doesn't support whatever method we're
 
677
        # trying to call with a body stream.
 
678
        self._request.finished_writing()
 
679
        self._request.finished_reading()
 
680
        raise errors.UnknownSmartMethod(args[0])
 
681
 
 
682
    def cancel_read_body(self):
 
683
        """After expecting a body, a response code may indicate one otherwise.
 
684
 
 
685
        This method lets the domain client inform the protocol that no body
 
686
        will be transmitted. This is a terminal method: after calling it the
 
687
        protocol is not able to be used further.
 
688
        """
 
689
        self._request.finished_reading()
 
690
 
 
691
    def _read_response_tuple(self):
 
692
        result = self._recv_tuple()
 
693
        if 'hpss' in debug.debug_flags:
 
694
            if self._request_start_time is not None:
 
695
                mutter('   result:   %6.3fs  %s',
 
696
                       osutils.timer_func() - self._request_start_time,
 
697
                       repr(result)[1:-1])
 
698
                self._request_start_time = None
 
699
            else:
 
700
                mutter('   result:   %s', repr(result)[1:-1])
 
701
        return result
 
702
 
 
703
    def read_response_tuple(self, expect_body=False):
 
704
        """Read a response tuple from the wire.
 
705
 
 
706
        This should only be called once.
 
707
        """
 
708
        result = self._read_response_tuple()
 
709
        self._response_is_unknown_method(result)
 
710
        self._raise_args_if_error(result)
 
711
        if not expect_body:
 
712
            self._request.finished_reading()
 
713
        return result
 
714
 
 
715
    def _raise_args_if_error(self, result_tuple):
 
716
        # Later protocol versions have an explicit flag in the protocol to say
 
717
        # if an error response is "failed" or not.  In version 1 we don't have
 
718
        # that luxury.  So here is a complete list of errors that can be
 
719
        # returned in response to existing version 1 smart requests.  Responses
 
720
        # starting with these codes are always "failed" responses.
 
721
        v1_error_codes = [
 
722
            'norepository',
 
723
            'NoSuchFile',
 
724
            'FileExists',
 
725
            'DirectoryNotEmpty',
 
726
            'ShortReadvError',
 
727
            'UnicodeEncodeError',
 
728
            'UnicodeDecodeError',
 
729
            'ReadOnlyError',
 
730
            'nobranch',
 
731
            'NoSuchRevision',
 
732
            'nosuchrevision',
 
733
            'LockContention',
 
734
            'UnlockableTransport',
 
735
            'LockFailed',
 
736
            'TokenMismatch',
 
737
            'ReadError',
 
738
            'PermissionDenied',
 
739
            ]
 
740
        if result_tuple[0] in v1_error_codes:
 
741
            self._request.finished_reading()
 
742
            raise errors.ErrorFromSmartServer(result_tuple)
 
743
 
 
744
    def _response_is_unknown_method(self, result_tuple):
 
745
        """Raise UnexpectedSmartServerResponse if the response is an 'unknonwn
 
746
        method' response to the request.
 
747
 
 
748
        :param response: The response from a smart client call_expecting_body
 
749
            call.
 
750
        :param verb: The verb used in that call.
 
751
        :raises: UnexpectedSmartServerResponse
 
752
        """
 
753
        if (result_tuple == ('error', "Generic bzr smart protocol error: "
 
754
                "bad request '%s'" % self._last_verb) or
 
755
              result_tuple == ('error', "Generic bzr smart protocol error: "
 
756
                "bad request u'%s'" % self._last_verb)):
 
757
            # The response will have no body, so we've finished reading.
 
758
            self._request.finished_reading()
 
759
            raise errors.UnknownSmartMethod(self._last_verb)
 
760
 
 
761
    def read_body_bytes(self, count=-1):
 
762
        """Read bytes from the body, decoding into a byte stream.
 
763
 
 
764
        We read all bytes at once to ensure we've checked the trailer for
 
765
        errors, and then feed the buffer back as read_body_bytes is called.
 
766
        """
 
767
        if self._body_buffer is not None:
 
768
            return self._body_buffer.read(count)
 
769
        _body_decoder = LengthPrefixedBodyDecoder()
 
770
 
 
771
        while not _body_decoder.finished_reading:
 
772
            bytes = self._request.read_bytes(_body_decoder.next_read_size())
 
773
            if bytes == '':
 
774
                # end of file encountered reading from server
 
775
                raise errors.ConnectionReset(
 
776
                    "Connection lost while reading response body.")
 
777
            _body_decoder.accept_bytes(bytes)
 
778
        self._request.finished_reading()
 
779
        self._body_buffer = StringIO(_body_decoder.read_pending_data())
 
780
        # XXX: TODO check the trailer result.
 
781
        if 'hpss' in debug.debug_flags:
 
782
            mutter('              %d body bytes read',
 
783
                   len(self._body_buffer.getvalue()))
 
784
        return self._body_buffer.read(count)
 
785
 
 
786
    def _recv_tuple(self):
 
787
        """Receive a tuple from the medium request."""
 
788
        return _decode_tuple(self._request.read_line())
 
789
 
 
790
    def query_version(self):
 
791
        """Return protocol version number of the server."""
 
792
        self.call('hello')
 
793
        resp = self.read_response_tuple()
 
794
        if resp == ('ok', '1'):
 
795
            return 1
 
796
        elif resp == ('ok', '2'):
 
797
            return 2
 
798
        else:
 
799
            raise errors.SmartProtocolError("bad response %r" % (resp,))
 
800
 
 
801
    def _write_args(self, args):
 
802
        self._write_protocol_version()
 
803
        bytes = _encode_tuple(args)
 
804
        self._request.accept_bytes(bytes)
 
805
 
 
806
    def _write_protocol_version(self):
 
807
        """Write any prefixes this protocol requires.
 
808
 
 
809
        Version one doesn't send protocol versions.
 
810
        """
 
811
 
 
812
 
 
813
class SmartClientRequestProtocolTwo(SmartClientRequestProtocolOne):
 
814
    """Version two of the client side of the smart protocol.
 
815
 
 
816
    This prefixes the request with the value of REQUEST_VERSION_TWO.
 
817
    """
 
818
 
 
819
    response_marker = RESPONSE_VERSION_TWO
 
820
    request_marker = REQUEST_VERSION_TWO
 
821
 
 
822
    def read_response_tuple(self, expect_body=False):
 
823
        """Read a response tuple from the wire.
 
824
 
 
825
        This should only be called once.
 
826
        """
 
827
        version = self._request.read_line()
 
828
        if version != self.response_marker:
 
829
            self._request.finished_reading()
 
830
            raise errors.UnexpectedProtocolVersionMarker(version)
 
831
        response_status = self._request.read_line()
 
832
        result = SmartClientRequestProtocolOne._read_response_tuple(self)
 
833
        self._response_is_unknown_method(result)
 
834
        if response_status == 'success\n':
 
835
            self.response_status = True
 
836
            if not expect_body:
 
837
                self._request.finished_reading()
 
838
            return result
 
839
        elif response_status == 'failed\n':
 
840
            self.response_status = False
 
841
            self._request.finished_reading()
 
842
            raise errors.ErrorFromSmartServer(result)
 
843
        else:
 
844
            raise errors.SmartProtocolError(
 
845
                'bad protocol status %r' % response_status)
 
846
 
 
847
    def _write_protocol_version(self):
 
848
        """Write any prefixes this protocol requires.
 
849
 
 
850
        Version two sends the value of REQUEST_VERSION_TWO.
 
851
        """
 
852
        self._request.accept_bytes(self.request_marker)
 
853
 
 
854
    def read_streamed_body(self):
 
855
        """Read bytes from the body, decoding into a byte stream.
 
856
        """
 
857
        # Read no more than 64k at a time so that we don't risk error 10055 (no
 
858
        # buffer space available) on Windows.
 
859
        _body_decoder = ChunkedBodyDecoder()
 
860
        while not _body_decoder.finished_reading:
 
861
            bytes = self._request.read_bytes(_body_decoder.next_read_size())
 
862
            if bytes == '':
 
863
                # end of file encountered reading from server
 
864
                raise errors.ConnectionReset(
 
865
                    "Connection lost while reading streamed body.")
 
866
            _body_decoder.accept_bytes(bytes)
 
867
            for body_bytes in iter(_body_decoder.read_next_chunk, None):
 
868
                if 'hpss' in debug.debug_flags and type(body_bytes) is str:
 
869
                    mutter('              %d byte chunk read',
 
870
                           len(body_bytes))
 
871
                yield body_bytes
 
872
        self._request.finished_reading()
 
873
 
 
874
 
 
875
def build_server_protocol_three(backing_transport, write_func,
 
876
                                root_client_path, jail_root=None):
 
877
    request_handler = request.SmartServerRequestHandler(
 
878
        backing_transport, commands=request.request_handlers,
 
879
        root_client_path=root_client_path, jail_root=jail_root)
 
880
    responder = ProtocolThreeResponder(write_func)
 
881
    message_handler = message.ConventionalRequestHandler(request_handler, responder)
 
882
    return ProtocolThreeDecoder(message_handler)
 
883
 
 
884
 
 
885
class ProtocolThreeDecoder(_StatefulDecoder):
 
886
 
 
887
    response_marker = RESPONSE_VERSION_THREE
 
888
    request_marker = REQUEST_VERSION_THREE
 
889
 
 
890
    def __init__(self, message_handler, expect_version_marker=False):
 
891
        _StatefulDecoder.__init__(self)
 
892
        self._has_dispatched = False
 
893
        # Initial state
 
894
        if expect_version_marker:
 
895
            self.state_accept = self._state_accept_expecting_protocol_version
 
896
            # We're expecting at least the protocol version marker + some
 
897
            # headers.
 
898
            self._number_needed_bytes = len(MESSAGE_VERSION_THREE) + 4
 
899
        else:
 
900
            self.state_accept = self._state_accept_expecting_headers
 
901
            self._number_needed_bytes = 4
 
902
        self.decoding_failed = False
 
903
        self.request_handler = self.message_handler = message_handler
 
904
 
 
905
    def accept_bytes(self, bytes):
 
906
        self._number_needed_bytes = None
 
907
        try:
 
908
            _StatefulDecoder.accept_bytes(self, bytes)
 
909
        except KeyboardInterrupt:
 
910
            raise
 
911
        except errors.SmartMessageHandlerError, exception:
 
912
            # We do *not* set self.decoding_failed here.  The message handler
 
913
            # has raised an error, but the decoder is still able to parse bytes
 
914
            # and determine when this message ends.
 
915
            if not isinstance(exception.exc_value, errors.UnknownSmartMethod):
 
916
                log_exception_quietly()
 
917
            self.message_handler.protocol_error(exception.exc_value)
 
918
            # The state machine is ready to continue decoding, but the
 
919
            # exception has interrupted the loop that runs the state machine.
 
920
            # So we call accept_bytes again to restart it.
 
921
            self.accept_bytes('')
 
922
        except Exception, exception:
 
923
            # The decoder itself has raised an exception.  We cannot continue
 
924
            # decoding.
 
925
            self.decoding_failed = True
 
926
            if isinstance(exception, errors.UnexpectedProtocolVersionMarker):
 
927
                # This happens during normal operation when the client tries a
 
928
                # protocol version the server doesn't understand, so no need to
 
929
                # log a traceback every time.
 
930
                # Note that this can only happen when
 
931
                # expect_version_marker=True, which is only the case on the
 
932
                # client side.
 
933
                pass
 
934
            else:
 
935
                log_exception_quietly()
 
936
            self.message_handler.protocol_error(exception)
 
937
 
 
938
    def _extract_length_prefixed_bytes(self):
 
939
        if self._in_buffer_len < 4:
 
940
            # A length prefix by itself is 4 bytes, and we don't even have that
 
941
            # many yet.
 
942
            raise _NeedMoreBytes(4)
 
943
        (length,) = struct.unpack('!L', self._get_in_bytes(4))
 
944
        end_of_bytes = 4 + length
 
945
        if self._in_buffer_len < end_of_bytes:
 
946
            # We haven't yet read as many bytes as the length-prefix says there
 
947
            # are.
 
948
            raise _NeedMoreBytes(end_of_bytes)
 
949
        # Extract the bytes from the buffer.
 
950
        in_buf = self._get_in_buffer()
 
951
        bytes = in_buf[4:end_of_bytes]
 
952
        self._set_in_buffer(in_buf[end_of_bytes:])
 
953
        return bytes
 
954
 
 
955
    def _extract_prefixed_bencoded_data(self):
 
956
        prefixed_bytes = self._extract_length_prefixed_bytes()
 
957
        try:
 
958
            decoded = bdecode_as_tuple(prefixed_bytes)
 
959
        except ValueError:
 
960
            raise errors.SmartProtocolError(
 
961
                'Bytes %r not bencoded' % (prefixed_bytes,))
 
962
        return decoded
 
963
 
 
964
    def _extract_single_byte(self):
 
965
        if self._in_buffer_len == 0:
 
966
            # The buffer is empty
 
967
            raise _NeedMoreBytes(1)
 
968
        in_buf = self._get_in_buffer()
 
969
        one_byte = in_buf[0]
 
970
        self._set_in_buffer(in_buf[1:])
 
971
        return one_byte
 
972
 
 
973
    def _state_accept_expecting_protocol_version(self):
 
974
        needed_bytes = len(MESSAGE_VERSION_THREE) - self._in_buffer_len
 
975
        in_buf = self._get_in_buffer()
 
976
        if needed_bytes > 0:
 
977
            # We don't have enough bytes to check if the protocol version
 
978
            # marker is right.  But we can check if it is already wrong by
 
979
            # checking that the start of MESSAGE_VERSION_THREE matches what
 
980
            # we've read so far.
 
981
            # [In fact, if the remote end isn't bzr we might never receive
 
982
            # len(MESSAGE_VERSION_THREE) bytes.  So if the bytes we have so far
 
983
            # are wrong then we should just raise immediately rather than
 
984
            # stall.]
 
985
            if not MESSAGE_VERSION_THREE.startswith(in_buf):
 
986
                # We have enough bytes to know the protocol version is wrong
 
987
                raise errors.UnexpectedProtocolVersionMarker(in_buf)
 
988
            raise _NeedMoreBytes(len(MESSAGE_VERSION_THREE))
 
989
        if not in_buf.startswith(MESSAGE_VERSION_THREE):
 
990
            raise errors.UnexpectedProtocolVersionMarker(in_buf)
 
991
        self._set_in_buffer(in_buf[len(MESSAGE_VERSION_THREE):])
 
992
        self.state_accept = self._state_accept_expecting_headers
 
993
 
 
994
    def _state_accept_expecting_headers(self):
 
995
        decoded = self._extract_prefixed_bencoded_data()
 
996
        if type(decoded) is not dict:
 
997
            raise errors.SmartProtocolError(
 
998
                'Header object %r is not a dict' % (decoded,))
 
999
        self.state_accept = self._state_accept_expecting_message_part
 
1000
        try:
 
1001
            self.message_handler.headers_received(decoded)
 
1002
        except:
 
1003
            raise errors.SmartMessageHandlerError(sys.exc_info())
 
1004
 
 
1005
    def _state_accept_expecting_message_part(self):
 
1006
        message_part_kind = self._extract_single_byte()
 
1007
        if message_part_kind == 'o':
 
1008
            self.state_accept = self._state_accept_expecting_one_byte
 
1009
        elif message_part_kind == 's':
 
1010
            self.state_accept = self._state_accept_expecting_structure
 
1011
        elif message_part_kind == 'b':
 
1012
            self.state_accept = self._state_accept_expecting_bytes
 
1013
        elif message_part_kind == 'e':
 
1014
            self.done()
 
1015
        else:
 
1016
            raise errors.SmartProtocolError(
 
1017
                'Bad message kind byte: %r' % (message_part_kind,))
 
1018
 
 
1019
    def _state_accept_expecting_one_byte(self):
 
1020
        byte = self._extract_single_byte()
 
1021
        self.state_accept = self._state_accept_expecting_message_part
 
1022
        try:
 
1023
            self.message_handler.byte_part_received(byte)
 
1024
        except:
 
1025
            raise errors.SmartMessageHandlerError(sys.exc_info())
 
1026
 
 
1027
    def _state_accept_expecting_bytes(self):
 
1028
        # XXX: this should not buffer whole message part, but instead deliver
 
1029
        # the bytes as they arrive.
 
1030
        prefixed_bytes = self._extract_length_prefixed_bytes()
 
1031
        self.state_accept = self._state_accept_expecting_message_part
 
1032
        try:
 
1033
            self.message_handler.bytes_part_received(prefixed_bytes)
 
1034
        except:
 
1035
            raise errors.SmartMessageHandlerError(sys.exc_info())
 
1036
 
 
1037
    def _state_accept_expecting_structure(self):
 
1038
        structure = self._extract_prefixed_bencoded_data()
 
1039
        self.state_accept = self._state_accept_expecting_message_part
 
1040
        try:
 
1041
            self.message_handler.structure_part_received(structure)
 
1042
        except:
 
1043
            raise errors.SmartMessageHandlerError(sys.exc_info())
 
1044
 
 
1045
    def done(self):
 
1046
        self.unused_data = self._get_in_buffer()
 
1047
        self._set_in_buffer(None)
 
1048
        self.state_accept = self._state_accept_reading_unused
 
1049
        try:
 
1050
            self.message_handler.end_received()
 
1051
        except:
 
1052
            raise errors.SmartMessageHandlerError(sys.exc_info())
 
1053
 
 
1054
    def _state_accept_reading_unused(self):
 
1055
        self.unused_data += self._get_in_buffer()
 
1056
        self._set_in_buffer(None)
 
1057
 
 
1058
    def next_read_size(self):
 
1059
        if self.state_accept == self._state_accept_reading_unused:
 
1060
            return 0
 
1061
        elif self.decoding_failed:
 
1062
            # An exception occured while processing this message, probably from
 
1063
            # self.message_handler.  We're not sure that this state machine is
 
1064
            # in a consistent state, so just signal that we're done (i.e. give
 
1065
            # up).
 
1066
            return 0
 
1067
        else:
 
1068
            if self._number_needed_bytes is not None:
 
1069
                return self._number_needed_bytes - self._in_buffer_len
 
1070
            else:
 
1071
                raise AssertionError("don't know how many bytes are expected!")
 
1072
 
 
1073
 
 
1074
class _ProtocolThreeEncoder(object):
 
1075
 
 
1076
    response_marker = request_marker = MESSAGE_VERSION_THREE
 
1077
    BUFFER_SIZE = 1024*1024 # 1 MiB buffer before flushing
 
1078
 
 
1079
    def __init__(self, write_func):
 
1080
        self._buf = []
 
1081
        self._buf_len = 0
 
1082
        self._real_write_func = write_func
 
1083
 
 
1084
    def _write_func(self, bytes):
 
1085
        # TODO: Another possibility would be to turn this into an async model.
 
1086
        #       Where we let another thread know that we have some bytes if
 
1087
        #       they want it, but we don't actually block for it
 
1088
        #       Note that osutils.send_all always sends 64kB chunks anyway, so
 
1089
        #       we might just push out smaller bits at a time?
 
1090
        self._buf.append(bytes)
 
1091
        self._buf_len += len(bytes)
 
1092
        if self._buf_len > self.BUFFER_SIZE:
 
1093
            self.flush()
 
1094
 
 
1095
    def flush(self):
 
1096
        if self._buf:
 
1097
            self._real_write_func(''.join(self._buf))
 
1098
            del self._buf[:]
 
1099
            self._buf_len = 0
 
1100
 
 
1101
    def _serialise_offsets(self, offsets):
 
1102
        """Serialise a readv offset list."""
 
1103
        txt = []
 
1104
        for start, length in offsets:
 
1105
            txt.append('%d,%d' % (start, length))
 
1106
        return '\n'.join(txt)
 
1107
 
 
1108
    def _write_protocol_version(self):
 
1109
        self._write_func(MESSAGE_VERSION_THREE)
 
1110
 
 
1111
    def _write_prefixed_bencode(self, structure):
 
1112
        bytes = bencode(structure)
 
1113
        self._write_func(struct.pack('!L', len(bytes)))
 
1114
        self._write_func(bytes)
 
1115
 
 
1116
    def _write_headers(self, headers):
 
1117
        self._write_prefixed_bencode(headers)
 
1118
 
 
1119
    def _write_structure(self, args):
 
1120
        self._write_func('s')
 
1121
        utf8_args = []
 
1122
        for arg in args:
 
1123
            if type(arg) is unicode:
 
1124
                utf8_args.append(arg.encode('utf8'))
 
1125
            else:
 
1126
                utf8_args.append(arg)
 
1127
        self._write_prefixed_bencode(utf8_args)
 
1128
 
 
1129
    def _write_end(self):
 
1130
        self._write_func('e')
 
1131
        self.flush()
 
1132
 
 
1133
    def _write_prefixed_body(self, bytes):
 
1134
        self._write_func('b')
 
1135
        self._write_func(struct.pack('!L', len(bytes)))
 
1136
        self._write_func(bytes)
 
1137
 
 
1138
    def _write_chunked_body_start(self):
 
1139
        self._write_func('oC')
 
1140
 
 
1141
    def _write_error_status(self):
 
1142
        self._write_func('oE')
 
1143
 
 
1144
    def _write_success_status(self):
 
1145
        self._write_func('oS')
 
1146
 
 
1147
 
 
1148
class ProtocolThreeResponder(_ProtocolThreeEncoder):
 
1149
 
 
1150
    def __init__(self, write_func):
 
1151
        _ProtocolThreeEncoder.__init__(self, write_func)
 
1152
        self.response_sent = False
 
1153
        self._headers = {'Software version': bzrlib.__version__}
 
1154
        if 'hpss' in debug.debug_flags:
 
1155
            self._thread_id = thread.get_ident()
 
1156
            self._response_start_time = None
 
1157
 
 
1158
    def _trace(self, action, message, extra_bytes=None, include_time=False):
 
1159
        if self._response_start_time is None:
 
1160
            self._response_start_time = osutils.timer_func()
 
1161
        if include_time:
 
1162
            t = '%5.3fs ' % (time.clock() - self._response_start_time)
 
1163
        else:
 
1164
            t = ''
 
1165
        if extra_bytes is None:
 
1166
            extra = ''
 
1167
        else:
 
1168
            extra = ' ' + repr(extra_bytes[:40])
 
1169
            if len(extra) > 33:
 
1170
                extra = extra[:29] + extra[-1] + '...'
 
1171
        mutter('%12s: [%s] %s%s%s'
 
1172
               % (action, self._thread_id, t, message, extra))
 
1173
 
 
1174
    def send_error(self, exception):
 
1175
        if self.response_sent:
 
1176
            raise AssertionError(
 
1177
                "send_error(%s) called, but response already sent."
 
1178
                % (exception,))
 
1179
        if isinstance(exception, errors.UnknownSmartMethod):
 
1180
            failure = request.FailedSmartServerResponse(
 
1181
                ('UnknownMethod', exception.verb))
 
1182
            self.send_response(failure)
 
1183
            return
 
1184
        if 'hpss' in debug.debug_flags:
 
1185
            self._trace('error', str(exception))
 
1186
        self.response_sent = True
 
1187
        self._write_protocol_version()
 
1188
        self._write_headers(self._headers)
 
1189
        self._write_error_status()
 
1190
        self._write_structure(('error', str(exception)))
 
1191
        self._write_end()
 
1192
 
 
1193
    def send_response(self, response):
 
1194
        if self.response_sent:
 
1195
            raise AssertionError(
 
1196
                "send_response(%r) called, but response already sent."
 
1197
                % (response,))
 
1198
        self.response_sent = True
 
1199
        self._write_protocol_version()
 
1200
        self._write_headers(self._headers)
 
1201
        if response.is_successful():
 
1202
            self._write_success_status()
 
1203
        else:
 
1204
            self._write_error_status()
 
1205
        if 'hpss' in debug.debug_flags:
 
1206
            self._trace('response', repr(response.args))
 
1207
        self._write_structure(response.args)
 
1208
        if response.body is not None:
 
1209
            self._write_prefixed_body(response.body)
 
1210
            if 'hpss' in debug.debug_flags:
 
1211
                self._trace('body', '%d bytes' % (len(response.body),),
 
1212
                            response.body, include_time=True)
 
1213
        elif response.body_stream is not None:
 
1214
            count = num_bytes = 0
 
1215
            first_chunk = None
 
1216
            for exc_info, chunk in _iter_with_errors(response.body_stream):
 
1217
                count += 1
 
1218
                if exc_info is not None:
 
1219
                    self._write_error_status()
 
1220
                    error_struct = request._translate_error(exc_info[1])
 
1221
                    self._write_structure(error_struct)
 
1222
                    break
 
1223
                else:
 
1224
                    if isinstance(chunk, request.FailedSmartServerResponse):
 
1225
                        self._write_error_status()
 
1226
                        self._write_structure(chunk.args)
 
1227
                        break
 
1228
                    num_bytes += len(chunk)
 
1229
                    if first_chunk is None:
 
1230
                        first_chunk = chunk
 
1231
                    self._write_prefixed_body(chunk)
 
1232
                    self.flush()
 
1233
                    if 'hpssdetail' in debug.debug_flags:
 
1234
                        # Not worth timing separately, as _write_func is
 
1235
                        # actually buffered
 
1236
                        self._trace('body chunk',
 
1237
                                    '%d bytes' % (len(chunk),),
 
1238
                                    chunk, suppress_time=True)
 
1239
            if 'hpss' in debug.debug_flags:
 
1240
                self._trace('body stream',
 
1241
                            '%d bytes %d chunks' % (num_bytes, count),
 
1242
                            first_chunk)
 
1243
        self._write_end()
 
1244
        if 'hpss' in debug.debug_flags:
 
1245
            self._trace('response end', '', include_time=True)
 
1246
 
 
1247
 
 
1248
def _iter_with_errors(iterable):
 
1249
    """Handle errors from iterable.next().
 
1250
 
 
1251
    Use like::
 
1252
 
 
1253
        for exc_info, value in _iter_with_errors(iterable):
 
1254
            ...
 
1255
 
 
1256
    This is a safer alternative to::
 
1257
 
 
1258
        try:
 
1259
            for value in iterable:
 
1260
               ...
 
1261
        except:
 
1262
            ...
 
1263
 
 
1264
    Because the latter will catch errors from the for-loop body, not just
 
1265
    iterable.next()
 
1266
 
 
1267
    If an error occurs, exc_info will be a exc_info tuple, and the generator
 
1268
    will terminate.  Otherwise exc_info will be None, and value will be the
 
1269
    value from iterable.next().  Note that KeyboardInterrupt and SystemExit
 
1270
    will not be itercepted.
 
1271
    """
 
1272
    iterator = iter(iterable)
 
1273
    while True:
 
1274
        try:
 
1275
            yield None, iterator.next()
 
1276
        except StopIteration:
 
1277
            return
 
1278
        except (KeyboardInterrupt, SystemExit):
 
1279
            raise
 
1280
        except Exception:
 
1281
            mutter('_iter_with_errors caught error')
 
1282
            log_exception_quietly()
 
1283
            yield sys.exc_info(), None
 
1284
            return
 
1285
 
 
1286
 
 
1287
class ProtocolThreeRequester(_ProtocolThreeEncoder, Requester):
 
1288
 
 
1289
    def __init__(self, medium_request):
 
1290
        _ProtocolThreeEncoder.__init__(self, medium_request.accept_bytes)
 
1291
        self._medium_request = medium_request
 
1292
        self._headers = {}
 
1293
        self.body_stream_started = None
 
1294
 
 
1295
    def set_headers(self, headers):
 
1296
        self._headers = headers.copy()
 
1297
 
 
1298
    def call(self, *args):
 
1299
        if 'hpss' in debug.debug_flags:
 
1300
            mutter('hpss call:   %s', repr(args)[1:-1])
 
1301
            base = getattr(self._medium_request._medium, 'base', None)
 
1302
            if base is not None:
 
1303
                mutter('             (to %s)', base)
 
1304
            self._request_start_time = osutils.timer_func()
 
1305
        self._write_protocol_version()
 
1306
        self._write_headers(self._headers)
 
1307
        self._write_structure(args)
 
1308
        self._write_end()
 
1309
        self._medium_request.finished_writing()
 
1310
 
 
1311
    def call_with_body_bytes(self, args, body):
 
1312
        """Make a remote call of args with body bytes 'body'.
 
1313
 
 
1314
        After calling this, call read_response_tuple to find the result out.
 
1315
        """
 
1316
        if 'hpss' in debug.debug_flags:
 
1317
            mutter('hpss call w/body: %s (%r...)', repr(args)[1:-1], body[:20])
 
1318
            path = getattr(self._medium_request._medium, '_path', None)
 
1319
            if path is not None:
 
1320
                mutter('                  (to %s)', path)
 
1321
            mutter('              %d bytes', len(body))
 
1322
            self._request_start_time = osutils.timer_func()
 
1323
        self._write_protocol_version()
 
1324
        self._write_headers(self._headers)
 
1325
        self._write_structure(args)
 
1326
        self._write_prefixed_body(body)
 
1327
        self._write_end()
 
1328
        self._medium_request.finished_writing()
 
1329
 
 
1330
    def call_with_body_readv_array(self, args, body):
 
1331
        """Make a remote call with a readv array.
 
1332
 
 
1333
        The body is encoded with one line per readv offset pair. The numbers in
 
1334
        each pair are separated by a comma, and no trailing \\n is emitted.
 
1335
        """
 
1336
        if 'hpss' in debug.debug_flags:
 
1337
            mutter('hpss call w/readv: %s', repr(args)[1:-1])
 
1338
            path = getattr(self._medium_request._medium, '_path', None)
 
1339
            if path is not None:
 
1340
                mutter('                  (to %s)', path)
 
1341
            self._request_start_time = osutils.timer_func()
 
1342
        self._write_protocol_version()
 
1343
        self._write_headers(self._headers)
 
1344
        self._write_structure(args)
 
1345
        readv_bytes = self._serialise_offsets(body)
 
1346
        if 'hpss' in debug.debug_flags:
 
1347
            mutter('              %d bytes in readv request', len(readv_bytes))
 
1348
        self._write_prefixed_body(readv_bytes)
 
1349
        self._write_end()
 
1350
        self._medium_request.finished_writing()
 
1351
 
 
1352
    def call_with_body_stream(self, args, stream):
 
1353
        if 'hpss' in debug.debug_flags:
 
1354
            mutter('hpss call w/body stream: %r', args)
 
1355
            path = getattr(self._medium_request._medium, '_path', None)
 
1356
            if path is not None:
 
1357
                mutter('                  (to %s)', path)
 
1358
            self._request_start_time = osutils.timer_func()
 
1359
        self.body_stream_started = False
 
1360
        self._write_protocol_version()
 
1361
        self._write_headers(self._headers)
 
1362
        self._write_structure(args)
 
1363
        # TODO: notice if the server has sent an early error reply before we
 
1364
        #       have finished sending the stream.  We would notice at the end
 
1365
        #       anyway, but if the medium can deliver it early then it's good
 
1366
        #       to short-circuit the whole request...
 
1367
        # Provoke any ConnectionReset failures before we start the body stream.
 
1368
        self.flush()
 
1369
        self.body_stream_started = True
 
1370
        for exc_info, part in _iter_with_errors(stream):
 
1371
            if exc_info is not None:
 
1372
                # Iterating the stream failed.  Cleanly abort the request.
 
1373
                self._write_error_status()
 
1374
                # Currently the client unconditionally sends ('error',) as the
 
1375
                # error args.
 
1376
                self._write_structure(('error',))
 
1377
                self._write_end()
 
1378
                self._medium_request.finished_writing()
 
1379
                raise exc_info[0], exc_info[1], exc_info[2]
 
1380
            else:
 
1381
                self._write_prefixed_body(part)
 
1382
                self.flush()
 
1383
        self._write_end()
 
1384
        self._medium_request.finished_writing()
 
1385