# Copyright (C) 2006-2010 Canonical Ltd
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
"""Wire-level encoding and decoding of requests and responses for the smart
client and server.
"""
from __future__ import absolute_import

import collections
from cStringIO import StringIO
import struct
import sys
import thread
import time

import bzrlib
from bzrlib import (
    debug,
    errors,
    osutils,
    )
from bzrlib.smart import message, request
from bzrlib.trace import log_exception_quietly, mutter
from bzrlib.bencode import bdecode_as_tuple, bencode
# Protocol version strings.  These are sent as prefixes of bzr requests and
# responses to identify the protocol version being used. (There are no version
# one strings because that version doesn't send any).
REQUEST_VERSION_TWO = 'bzr request 2\n'
RESPONSE_VERSION_TWO = 'bzr response 2\n'

# Version three uses a single marker for both requests and responses.
MESSAGE_VERSION_THREE = 'bzr message 3 (bzr 1.6)\n'
RESPONSE_VERSION_THREE = REQUEST_VERSION_THREE = MESSAGE_VERSION_THREE
def _recv_tuple(from_file):
    """Read one line from from_file and decode it with _decode_tuple.

    :param from_file: an object with a readline() method.
    :return: whatever _decode_tuple returns for that line.
    """
    req_line = from_file.readline()
    return _decode_tuple(req_line)
56
def _decode_tuple(req_line):
57
if req_line is None or req_line == '':
59
if req_line[-1] != '\n':
60
raise errors.SmartProtocolError("request %r not terminated" % req_line)
61
return tuple(req_line[:-1].split('\x01'))
def _encode_tuple(args):
    """Encode the tuple args to a bytestream.

    The arguments are joined with '\\x01' and terminated with '\\n' (the
    inverse of _decode_tuple).  If the joined result is unicode it is logged
    and then encoded as ASCII.
    """
    joined = '\x01'.join(args) + '\n'
    if type(joined) is unicode:
        # XXX: We should fix things so this never happens!  -AJB, 20100304
        mutter('response args contain unicode, should be only bytes: %r',
               joined)
        joined = joined.encode('ascii')
    return joined
class Requester(object):
    """Abstract base class for an object that can issue requests on a smart
    medium.

    Every method here raises NotImplementedError; subclasses override them.
    """

    def call(self, *args):
        """Make a remote call.

        :param args: the arguments of this call.
        """
        raise NotImplementedError(self.call)

    def call_with_body_bytes(self, args, body):
        """Make a remote call with a body.

        :param args: the arguments of this call.
        :type body: str
        :param body: the body to send with the request.
        """
        raise NotImplementedError(self.call_with_body_bytes)

    def call_with_body_readv_array(self, args, body):
        """Make a remote call with a readv array.

        :param args: the arguments of this call.
        :type body: iterable of (start, length) tuples.
        :param body: the readv ranges to send with this request.
        """
        raise NotImplementedError(self.call_with_body_readv_array)

    def set_headers(self, headers):
        """Record headers for this requester (abstract)."""
        raise NotImplementedError(self.set_headers)
class SmartProtocolBase(object):
    """Methods common to client and server"""

    # TODO: this only actually accommodates a single block; possibly should
    # support multiple chunks?
    def _encode_bulk_data(self, body):
        """Encode body as a bulk data chunk.

        The result is the decimal length of body, '\\n', body itself, and the
        'done\\n' trailer.
        """
        return ''.join(('%d\n' % len(body), body, 'done\n'))

    def _serialise_offsets(self, offsets):
        """Serialise a readv offset list.

        :param offsets: iterable of (start, length) integer pairs.
        :return: one 'start,length' line per pair, joined with '\\n' and with
            no trailing newline.
        """
        return '\n'.join('%d,%d' % (start, length)
                         for start, length in offsets)
class SmartServerRequestProtocolOne(SmartProtocolBase):
    """Server-side encoding and decoding logic for smart version 1."""

    def __init__(self, backing_transport, write_func, root_client_path='/',
                 jail_root=None):
        self._backing_transport = backing_transport
        self._root_client_path = root_client_path
        self._jail_root = jail_root
        # Bytes received beyond the end of this request.
        self.unused_data = ''
        # Set once a response has been written by _send_response.
        self._finished = False
        # Bytes received but not yet consumed.
        self.in_buffer = ''
        # Set once the request line has been parsed and dispatched.
        self._has_dispatched = False
        self.request = None
        # Created lazily when body bytes need decoding.
        self._body_decoder = None
        self._write_func = write_func

    def accept_bytes(self, bytes):
        """Take bytes, and advance the internal state machine appropriately.

        :param bytes: must be a byte string
        :raises ValueError: if bytes is not a str.
        """
        if not isinstance(bytes, str):
            raise ValueError(bytes)
        self.in_buffer += bytes
        if not self._has_dispatched:
            if '\n' not in self.in_buffer:
                # no command line yet
                return
            self._has_dispatched = True
            try:
                first_line, self.in_buffer = self.in_buffer.split('\n', 1)
                # _decode_tuple requires the terminating newline.
                first_line += '\n'
                req_args = _decode_tuple(first_line)
                self.request = request.SmartServerRequestHandler(
                    self._backing_transport, commands=request.request_handlers,
                    root_client_path=self._root_client_path,
                    jail_root=self._jail_root)
                self.request.args_received(req_args)
                if self.request.finished_reading:
                    # trivial request
                    self.unused_data = self.in_buffer
                    self.in_buffer = ''
                    self._send_response(self.request.response)
            except KeyboardInterrupt:
                raise
            except errors.UnknownSmartMethod as err:
                protocol_error = errors.SmartProtocolError(
                    "bad request %r" % (err.verb,))
                failure = request.FailedSmartServerResponse(
                    ('error', str(protocol_error)))
                self._send_response(failure)
                return
            except Exception as exception:
                # everything else: pass to client, flush, and quit
                log_exception_quietly()
                self._send_response(request.FailedSmartServerResponse(
                    ('error', str(exception))))
                return

        if self._has_dispatched:
            if self._finished:
                # nothing to do.XXX: this routine should be a single state
                # machine too.
                self.unused_data += self.in_buffer
                self.in_buffer = ''
                return
            if self._body_decoder is None:
                self._body_decoder = LengthPrefixedBodyDecoder()
            self._body_decoder.accept_bytes(self.in_buffer)
            self.in_buffer = self._body_decoder.unused_data
            body_data = self._body_decoder.read_pending_data()
            self.request.accept_body(body_data)
            if self._body_decoder.finished_reading:
                self.request.end_of_body()
                if not self.request.finished_reading:
                    raise AssertionError("no more body, request not finished")
            if self.request.response is not None:
                self._send_response(self.request.response)
                self.unused_data = self.in_buffer
                self.in_buffer = ''
            elif self.request.finished_reading:
                raise AssertionError(
                    "no response and we have finished reading.")

    def _send_response(self, response):
        """Send a smart server response down the output stream.

        :raises AssertionError: if a response has already been sent.
        :raises ValueError: if the response body is not a str.
        """
        if self._finished:
            raise AssertionError('response already sent')
        args = response.args
        body = response.body
        self._finished = True
        self._write_protocol_version()
        self._write_success_or_failure_prefix(response)
        self._write_func(_encode_tuple(args))
        if body is not None:
            if not isinstance(body, str):
                raise ValueError(body)
            encoded = self._encode_bulk_data(body)
            self._write_func(encoded)

    def _write_protocol_version(self):
        """Write any prefixes this protocol requires.

        Version one doesn't send protocol versions.
        """

    def _write_success_or_failure_prefix(self, response):
        """Write the protocol specific success/failure prefix.

        For SmartServerRequestProtocolOne this is omitted but we
        call is_successful to ensure that the response is valid.
        """
        response.is_successful()

    def next_read_size(self):
        """Return how many more bytes this protocol wants.

        0 once a response has been sent, 1 until a body decoder exists, and
        otherwise whatever the body decoder requests.
        """
        if self._finished:
            return 0
        if self._body_decoder is None:
            return 1
        return self._body_decoder.next_read_size()
class SmartServerRequestProtocolTwo(SmartServerRequestProtocolOne):
    r"""Version two of the server side of the smart protocol.

    This prefixes responses with the value of RESPONSE_VERSION_TWO.
    """

    response_marker = RESPONSE_VERSION_TWO
    request_marker = REQUEST_VERSION_TWO

    def _write_success_or_failure_prefix(self, response):
        """Write the protocol specific success/failure prefix."""
        if response.is_successful():
            self._write_func('success\n')
        else:
            self._write_func('failed\n')

    def _write_protocol_version(self):
        r"""Write any prefixes this protocol requires.

        Version two sends the value of RESPONSE_VERSION_TWO.
        """
        self._write_func(self.response_marker)

    def _send_response(self, response):
        """Send a smart server response down the output stream.

        Writes the version marker, the success/failure line and the encoded
        args, then either a length-prefixed body or a chunked body stream.

        :raises AssertionError: if a response was already sent, if the body
            is not a str, or if both body and body_stream are set.
        """
        if self._finished:
            raise AssertionError('response already sent')
        self._finished = True
        self._write_protocol_version()
        self._write_success_or_failure_prefix(response)
        self._write_func(_encode_tuple(response.args))
        if response.body is not None:
            if not isinstance(response.body, str):
                raise AssertionError('body must be a str')
            if response.body_stream is not None:
                raise AssertionError(
                    'body_stream and body cannot both be set')
            encoded = self._encode_bulk_data(response.body)
            self._write_func(encoded)
        elif response.body_stream is not None:
            _send_stream(response.body_stream, self._write_func)
def _send_stream(stream, write_func):
    """Write stream to write_func using the chunked body encoding.

    The output is the 'chunked\\n' header, each chunk as encoded by
    _send_chunks, and finally the 'END\\n' end-of-body marker.
    """
    write_func('chunked\n')
    _send_chunks(stream, write_func)
    write_func('END\n')
299
def _send_chunks(stream, write_func):
301
if isinstance(chunk, str):
302
bytes = "%x\n%s" % (len(chunk), chunk)
304
elif isinstance(chunk, request.FailedSmartServerResponse):
306
_send_chunks(chunk.args, write_func)
309
raise errors.BzrError(
310
'Chunks must be str or FailedSmartServerResponse, got %r'
314
class _NeedMoreBytes(Exception):
315
"""Raise this inside a _StatefulDecoder to stop decoding until more bytes
319
def __init__(self, count=None):
322
:param count: the total number of bytes needed by the current state.
323
May be None if the number of bytes needed is unknown.
328
class _StatefulDecoder(object):
329
"""Base class for writing state machines to decode byte streams.
331
Subclasses should provide a self.state_accept attribute that accepts bytes
332
and, if appropriate, updates self.state_accept to a different function.
333
accept_bytes will call state_accept as often as necessary to make sure the
334
state machine has progressed as far as possible before it returns.
336
See ProtocolThreeDecoder for an example subclass.
340
self.finished_reading = False
341
self._in_buffer_list = []
342
self._in_buffer_len = 0
343
self.unused_data = ''
344
self.bytes_left = None
345
self._number_needed_bytes = None
347
def _get_in_buffer(self):
348
if len(self._in_buffer_list) == 1:
349
return self._in_buffer_list[0]
350
in_buffer = ''.join(self._in_buffer_list)
351
if len(in_buffer) != self._in_buffer_len:
352
raise AssertionError(
353
"Length of buffer did not match expected value: %s != %s"
354
% self._in_buffer_len, len(in_buffer))
355
self._in_buffer_list = [in_buffer]
358
def _get_in_bytes(self, count):
359
"""Grab X bytes from the input_buffer.
361
Callers should have already checked that self._in_buffer_len is >
362
count. Note, this does not consume the bytes from the buffer. The
363
caller will still need to call _get_in_buffer() and then
364
_set_in_buffer() if they actually need to consume the bytes.
366
# check if we can yield the bytes from just the first entry in our list
367
if len(self._in_buffer_list) == 0:
368
raise AssertionError('Callers must be sure we have buffered bytes'
369
' before calling _get_in_bytes')
370
if len(self._in_buffer_list[0]) > count:
371
return self._in_buffer_list[0][:count]
372
# We can't yield it from the first buffer, so collapse all buffers, and
374
in_buf = self._get_in_buffer()
375
return in_buf[:count]
377
def _set_in_buffer(self, new_buf):
378
if new_buf is not None:
379
self._in_buffer_list = [new_buf]
380
self._in_buffer_len = len(new_buf)
382
self._in_buffer_list = []
383
self._in_buffer_len = 0
385
def accept_bytes(self, bytes):
386
"""Decode as much of bytes as possible.
388
If 'bytes' contains too much data it will be appended to
391
finished_reading will be set when no more data is required. Further
392
data will be appended to self.unused_data.
394
# accept_bytes is allowed to change the state
395
self._number_needed_bytes = None
396
# lsprof puts a very large amount of time on this specific call for
398
self._in_buffer_list.append(bytes)
399
self._in_buffer_len += len(bytes)
401
# Run the function for the current state.
402
current_state = self.state_accept
404
while current_state != self.state_accept:
405
# The current state has changed. Run the function for the new
406
# current state, so that it can:
407
# - decode any unconsumed bytes left in a buffer, and
408
# - signal how many more bytes are expected (via raising
410
current_state = self.state_accept
412
except _NeedMoreBytes, e:
413
self._number_needed_bytes = e.count
class ChunkedBodyDecoder(_StatefulDecoder):
    """Decoder for chunked body data.

    This is very similar the HTTP's chunked encoding.  See the description of
    streamed body data in `doc/developers/network-protocol.txt` for details.
    """

    def __init__(self):
        _StatefulDecoder.__init__(self)
        self.state_accept = self._state_accept_expecting_header
        # Bytes of the chunk currently being read, or None between chunks.
        self.chunk_in_progress = None
        # Decoded chunks waiting to be returned by read_next_chunk.
        self.chunks = collections.deque()
        # Set once an 'ERR' marker has been read.
        self.error = False
        # Chunks read after an 'ERR' marker; at the end of the body they
        # become the args of a FailedSmartServerResponse.
        self.error_in_progress = None

    def next_read_size(self):
        """Return a lower bound on the number of bytes still to be read."""
        # Note: the shortest possible chunk is 2 bytes: '0\n', and the
        # end-of-body marker is 4 bytes: 'END\n'.
        if self.state_accept == self._state_accept_reading_chunk:
            # We're expecting more chunk content.  So we're expecting at least
            # the rest of this chunk plus an END chunk.
            return self.bytes_left + 4
        elif self.state_accept == self._state_accept_expecting_length:
            if self._in_buffer_len == 0:
                # We're expecting a chunk length.  There's at least two bytes
                # left: a digit plus '\n'.
                return 2
            else:
                # We're in the middle of reading a chunk length.  So there's at
                # least one byte left, the '\n' that terminates the length.
                return 1
        elif self.state_accept == self._state_accept_reading_unused:
            return 1
        elif self.state_accept == self._state_accept_expecting_header:
            return max(0, len('chunked\n') - self._in_buffer_len)
        else:
            raise AssertionError("Impossible state: %r" % (self.state_accept,))

    def read_next_chunk(self):
        """Return the next decoded chunk, or None if none is available."""
        try:
            return self.chunks.popleft()
        except IndexError:
            return None

    def _extract_line(self):
        """Remove and return the next line (without its '\\n') from the buffer.

        :raises _NeedMoreBytes: if no complete line is buffered yet.
        """
        in_buf = self._get_in_buffer()
        pos = in_buf.find('\n')
        if pos == -1:
            # We haven't read a complete line yet, so request more bytes before
            # we continue.
            raise _NeedMoreBytes(1)
        line = in_buf[:pos]
        # Trim the prefix (including '\n' delimiter) from the _in_buffer.
        self._set_in_buffer(in_buf[pos+1:])
        return line

    def _finished(self):
        """Finish the body: keep leftover bytes as unused data and, if an
        error was being read, queue it as a FailedSmartServerResponse.
        """
        self.unused_data = self._get_in_buffer()
        self._set_in_buffer(None)
        self.state_accept = self._state_accept_reading_unused
        if self.error:
            error_args = tuple(self.error_in_progress)
            self.chunks.append(request.FailedSmartServerResponse(error_args))
            self.error_in_progress = None
        self.finished_reading = True

    def _state_accept_expecting_header(self):
        prefix = self._extract_line()
        if prefix == 'chunked':
            self.state_accept = self._state_accept_expecting_length
        else:
            raise errors.SmartProtocolError(
                'Bad chunked body header: "%s"' % (prefix,))

    def _state_accept_expecting_length(self):
        prefix = self._extract_line()
        if prefix == 'ERR':
            self.error = True
            self.error_in_progress = []
            self._state_accept_expecting_length()
            return
        elif prefix == 'END':
            # We've read the end-of-body marker.
            # Any further bytes are unused data, including the bytes left in
            # the _in_buffer.
            self._finished()
            return
        else:
            # The chunk length is hexadecimal.
            self.bytes_left = int(prefix, 16)
            self.chunk_in_progress = ''
            self.state_accept = self._state_accept_reading_chunk

    def _state_accept_reading_chunk(self):
        in_buf = self._get_in_buffer()
        in_buffer_len = len(in_buf)
        self.chunk_in_progress += in_buf[:self.bytes_left]
        self._set_in_buffer(in_buf[self.bytes_left:])
        self.bytes_left -= in_buffer_len
        if self.bytes_left <= 0:
            # Finished with chunk
            self.bytes_left = None
            if self.error:
                self.error_in_progress.append(self.chunk_in_progress)
            else:
                self.chunks.append(self.chunk_in_progress)
            self.chunk_in_progress = None
            self.state_accept = self._state_accept_expecting_length

    def _state_accept_reading_unused(self):
        self.unused_data += self._get_in_buffer()
        # Reset both the list and the length so the buffer stays consistent.
        self._set_in_buffer(None)
class LengthPrefixedBodyDecoder(_StatefulDecoder):
    """Decodes the length-prefixed bulk data.

    The expected input is a decimal length, '\\n', that many body bytes, and
    a 'done\\n' trailer.
    """

    def __init__(self):
        _StatefulDecoder.__init__(self)
        self.state_accept = self._state_accept_expecting_length
        self.state_read = self._state_read_no_data
        # Body bytes decoded but not yet returned by read_pending_data.
        self._body = ''
        # Bytes read after the body, expected to begin with 'done\n'.
        self._trailer_buffer = ''

    def next_read_size(self):
        """Return how many bytes to read next."""
        if self.bytes_left is not None:
            # Ideally we want to read all the remainder of the body and the
            # trailer in one go.
            return self.bytes_left + 5
        elif self.state_accept == self._state_accept_reading_trailer:
            # Just the trailer left
            return 5 - len(self._trailer_buffer)
        elif self.state_accept == self._state_accept_expecting_length:
            # There's still at least 6 bytes left ('\n' to end the length, plus
            # 'done\n').
            return 6
        else:
            # Reading excess data.  Either way, 1 byte at a time is fine.
            return 1

    def read_pending_data(self):
        """Return any pending data that has been decoded."""
        return self.state_read()

    def _state_accept_expecting_length(self):
        in_buf = self._get_in_buffer()
        pos = in_buf.find('\n')
        if pos == -1:
            return
        self.bytes_left = int(in_buf[:pos])
        self._set_in_buffer(in_buf[pos+1:])
        self.state_accept = self._state_accept_reading_body
        self.state_read = self._state_read_body_buffer

    def _state_accept_reading_body(self):
        in_buf = self._get_in_buffer()
        self._body += in_buf
        self.bytes_left -= len(in_buf)
        self._set_in_buffer(None)
        if self.bytes_left <= 0:
            # Finished with body
            if self.bytes_left != 0:
                # bytes_left is negative here: move the overshoot from the
                # end of _body into the trailer buffer.
                self._trailer_buffer = self._body[self.bytes_left:]
                self._body = self._body[:self.bytes_left]
            self.bytes_left = None
            self.state_accept = self._state_accept_reading_trailer

    def _state_accept_reading_trailer(self):
        self._trailer_buffer += self._get_in_buffer()
        self._set_in_buffer(None)
        # TODO: what if the trailer does not match "done\n"?  Should this raise
        # a ProtocolViolation exception?
        if self._trailer_buffer.startswith('done\n'):
            self.unused_data = self._trailer_buffer[len('done\n'):]
            self.state_accept = self._state_accept_reading_unused
            self.finished_reading = True

    def _state_accept_reading_unused(self):
        self.unused_data += self._get_in_buffer()
        self._set_in_buffer(None)

    def _state_read_no_data(self):
        return ''

    def _state_read_body_buffer(self):
        result = self._body
        self._body = ''
        return result
class SmartClientRequestProtocolOne(SmartProtocolBase, Requester,
                                    message.ResponseHandler):
    """The client-side protocol for smart version 1."""

    def __init__(self, request):
        """Construct a SmartClientRequestProtocolOne.

        :param request: A SmartClientMediumRequest to serialise onto and
            deserialise from.
        """
        self._request = request
        # StringIO over the decoded response body, once it has been read.
        self._body_buffer = None
        # Timer value at call start; only set when 'hpss' debugging is on.
        self._request_start_time = None
        # Verb of the most recent call, used to detect 'unknown method'.
        self._last_verb = None
        self._headers = None

    def set_headers(self, headers):
        self._headers = dict(headers)

    def call(self, *args):
        if 'hpss' in debug.debug_flags:
            mutter('hpss call: %s', repr(args)[1:-1])
            if getattr(self._request._medium, 'base', None) is not None:
                mutter(' (to %s)', self._request._medium.base)
            self._request_start_time = osutils.timer_func()
        self._write_args(args)
        self._request.finished_writing()
        self._last_verb = args[0]

    def call_with_body_bytes(self, args, body):
        """Make a remote call of args with body bytes 'body'.

        After calling this, call read_response_tuple to find the result out.
        """
        if 'hpss' in debug.debug_flags:
            mutter('hpss call w/body: %s (%r...)', repr(args)[1:-1], body[:20])
            if getattr(self._request._medium, '_path', None) is not None:
                mutter(' (to %s)', self._request._medium._path)
            mutter(' %d bytes', len(body))
            self._request_start_time = osutils.timer_func()
            if 'hpssdetail' in debug.debug_flags:
                mutter('hpss body content: %s', body)
        self._write_args(args)
        encoded = self._encode_bulk_data(body)
        self._request.accept_bytes(encoded)
        self._request.finished_writing()
        self._last_verb = args[0]

    def call_with_body_readv_array(self, args, body):
        """Make a remote call with a readv array.

        The body is encoded with one line per readv offset pair. The numbers in
        each pair are separated by a comma, and no trailing \\n is emitted.
        """
        if 'hpss' in debug.debug_flags:
            mutter('hpss call w/readv: %s', repr(args)[1:-1])
            if getattr(self._request._medium, '_path', None) is not None:
                mutter(' (to %s)', self._request._medium._path)
            self._request_start_time = osutils.timer_func()
        self._write_args(args)
        readv_bytes = self._serialise_offsets(body)
        encoded = self._encode_bulk_data(readv_bytes)
        self._request.accept_bytes(encoded)
        self._request.finished_writing()
        if 'hpss' in debug.debug_flags:
            mutter(' %d bytes in readv request', len(readv_bytes))
        self._last_verb = args[0]

    def call_with_body_stream(self, args, stream):
        # Protocols v1 and v2 don't support body streams.  So it's safe to
        # assume that a v1/v2 server doesn't support whatever method we're
        # trying to call with a body stream.
        self._request.finished_writing()
        self._request.finished_reading()
        raise errors.UnknownSmartMethod(args[0])

    def cancel_read_body(self):
        """After expecting a body, a response code may indicate one otherwise.

        This method lets the domain client inform the protocol that no body
        will be transmitted. This is a terminal method: after calling it the
        protocol is not able to be used further.
        """
        self._request.finished_reading()

    def _read_response_tuple(self):
        """Read and decode the response tuple, logging it if 'hpss' is set."""
        result = self._recv_tuple()
        if 'hpss' in debug.debug_flags:
            if self._request_start_time is not None:
                mutter(' result: %6.3fs %s',
                       osutils.timer_func() - self._request_start_time,
                       repr(result)[1:-1])
                self._request_start_time = None
            else:
                mutter(' result: %s', repr(result)[1:-1])
        return result

    def read_response_tuple(self, expect_body=False):
        """Read a response tuple from the wire.

        This should only be called once.
        """
        result = self._read_response_tuple()
        self._response_is_unknown_method(result)
        self._raise_args_if_error(result)
        if not expect_body:
            self._request.finished_reading()
        return result

    def _raise_args_if_error(self, result_tuple):
        # Later protocol versions have an explicit flag in the protocol to say
        # if an error response is "failed" or not.  In version 1 we don't have
        # that luxury.  So here is a complete list of errors that can be
        # returned in response to existing version 1 smart requests.  Responses
        # starting with these codes are always "failed" responses.
        v1_error_codes = [
            'norepository',
            'NoSuchFile',
            'FileExists',
            'DirectoryNotEmpty',
            'ShortReadvError',
            'UnicodeEncodeError',
            'UnicodeDecodeError',
            'ReadOnlyError',
            'nobranch',
            'NoSuchRevision',
            'nosuchrevision',
            'LockContention',
            'UnlockableTransport',
            'LockFailed',
            'TokenMismatch',
            'ReadError',
            'PermissionDenied',
            ]
        if result_tuple[0] in v1_error_codes:
            self._request.finished_reading()
            raise errors.ErrorFromSmartServer(result_tuple)

    def _response_is_unknown_method(self, result_tuple):
        """Raise UnknownSmartMethod if the response is an 'unknown method'
        response to the request.

        :param result_tuple: the decoded response tuple of the last call.
        :raises: UnknownSmartMethod (after marking reading as finished) if
            result_tuple is the server's "bad request" error for
            self._last_verb.
        """
        if (result_tuple == ('error', "Generic bzr smart protocol error: "
                "bad request '%s'" % self._last_verb) or
              result_tuple == ('error', "Generic bzr smart protocol error: "
                "bad request u'%s'" % self._last_verb)):
            # The response will have no body, so we've finished reading.
            self._request.finished_reading()
            raise errors.UnknownSmartMethod(self._last_verb)

    def read_body_bytes(self, count=-1):
        """Read bytes from the body, decoding into a byte stream.

        We read all bytes at once to ensure we've checked the trailer for
        errors, and then feed the buffer back as read_body_bytes is called.

        :raises ConnectionReset: if the server closes the connection first.
        """
        if self._body_buffer is not None:
            return self._body_buffer.read(count)
        _body_decoder = LengthPrefixedBodyDecoder()

        while not _body_decoder.finished_reading:
            data = self._request.read_bytes(_body_decoder.next_read_size())
            if data == '':
                # end of file encountered reading from server
                raise errors.ConnectionReset(
                    "Connection lost while reading response body.")
            _body_decoder.accept_bytes(data)
        self._request.finished_reading()
        self._body_buffer = StringIO(_body_decoder.read_pending_data())
        # XXX: TODO check the trailer result.
        if 'hpss' in debug.debug_flags:
            mutter(' %d body bytes read',
                   len(self._body_buffer.getvalue()))
        return self._body_buffer.read(count)

    def _recv_tuple(self):
        """Receive a tuple from the medium request."""
        return _decode_tuple(self._request.read_line())

    def query_version(self):
        """Return protocol version number of the server."""
        self.call('hello')
        resp = self.read_response_tuple()
        if resp == ('ok', '1'):
            return 1
        elif resp == ('ok', '2'):
            return 2
        else:
            raise errors.SmartProtocolError("bad response %r" % (resp,))

    def _write_args(self, args):
        self._write_protocol_version()
        encoded = _encode_tuple(args)
        self._request.accept_bytes(encoded)

    def _write_protocol_version(self):
        """Write any prefixes this protocol requires.

        Version one doesn't send protocol versions.
        """
class SmartClientRequestProtocolTwo(SmartClientRequestProtocolOne):
    """Version two of the client side of the smart protocol.

    This prefixes the request with the value of REQUEST_VERSION_TWO.
    """

    response_marker = RESPONSE_VERSION_TWO
    request_marker = REQUEST_VERSION_TWO

    def read_response_tuple(self, expect_body=False):
        """Read a response tuple from the wire.

        This should only be called once.

        :raises UnexpectedProtocolVersionMarker: if the response does not
            start with self.response_marker.
        :raises ErrorFromSmartServer: if the status line is 'failed'.
        :raises SmartProtocolError: for any other status line.
        """
        version = self._request.read_line()
        if version != self.response_marker:
            self._request.finished_reading()
            raise errors.UnexpectedProtocolVersionMarker(version)
        response_status = self._request.read_line()
        result = SmartClientRequestProtocolOne._read_response_tuple(self)
        self._response_is_unknown_method(result)
        if response_status == 'success\n':
            self.response_status = True
            if not expect_body:
                self._request.finished_reading()
            return result
        elif response_status == 'failed\n':
            self.response_status = False
            self._request.finished_reading()
            raise errors.ErrorFromSmartServer(result)
        else:
            raise errors.SmartProtocolError(
                'bad protocol status %r' % (response_status,))

    def _write_protocol_version(self):
        """Write any prefixes this protocol requires.

        Version two sends the value of REQUEST_VERSION_TWO.
        """
        self._request.accept_bytes(self.request_marker)

    def read_streamed_body(self):
        """Read bytes from the body, decoding into a byte stream.

        This is a generator yielding each decoded chunk as it arrives.  If
        the server reports an error mid-stream, a FailedSmartServerResponse
        is yielded.

        :raises ConnectionReset: if the server closes the connection first.
        """
        # Read no more than 64k at a time so that we don't risk error 10055 (no
        # buffer space available) on Windows.
        # NOTE(review): this loop does not cap the read size itself; the cap
        # is presumably applied by the medium's read_bytes -- confirm.
        _body_decoder = ChunkedBodyDecoder()
        while not _body_decoder.finished_reading:
            data = self._request.read_bytes(_body_decoder.next_read_size())
            if data == '':
                # end of file encountered reading from server
                raise errors.ConnectionReset(
                    "Connection lost while reading streamed body.")
            _body_decoder.accept_bytes(data)
            for body_bytes in iter(_body_decoder.read_next_chunk, None):
                if 'hpss' in debug.debug_flags and type(body_bytes) is str:
                    mutter(' %d byte chunk read',
                           len(body_bytes))
                yield body_bytes
        self._request.finished_reading()
def build_server_protocol_three(backing_transport, write_func,
                                root_client_path, jail_root=None):
    """Return a ProtocolThreeDecoder set up to serve requests.

    Decoded messages go to a ConventionalRequestHandler that dispatches them
    to a SmartServerRequestHandler over backing_transport.  Responses are
    written through write_func by a ProtocolThreeResponder.
    """
    request_handler = request.SmartServerRequestHandler(
        backing_transport, commands=request.request_handlers,
        root_client_path=root_client_path, jail_root=jail_root)
    responder = ProtocolThreeResponder(write_func)
    message_handler = message.ConventionalRequestHandler(
        request_handler, responder)
    return ProtocolThreeDecoder(message_handler)
885
class ProtocolThreeDecoder(_StatefulDecoder):
887
response_marker = RESPONSE_VERSION_THREE
888
request_marker = REQUEST_VERSION_THREE
890
def __init__(self, message_handler, expect_version_marker=False):
891
_StatefulDecoder.__init__(self)
892
self._has_dispatched = False
894
if expect_version_marker:
895
self.state_accept = self._state_accept_expecting_protocol_version
896
# We're expecting at least the protocol version marker + some
898
self._number_needed_bytes = len(MESSAGE_VERSION_THREE) + 4
900
self.state_accept = self._state_accept_expecting_headers
901
self._number_needed_bytes = 4
902
self.decoding_failed = False
903
self.request_handler = self.message_handler = message_handler
905
def accept_bytes(self, bytes):
906
self._number_needed_bytes = None
908
_StatefulDecoder.accept_bytes(self, bytes)
909
except KeyboardInterrupt:
911
except errors.SmartMessageHandlerError, exception:
912
# We do *not* set self.decoding_failed here. The message handler
913
# has raised an error, but the decoder is still able to parse bytes
914
# and determine when this message ends.
915
if not isinstance(exception.exc_value, errors.UnknownSmartMethod):
916
log_exception_quietly()
917
self.message_handler.protocol_error(exception.exc_value)
918
# The state machine is ready to continue decoding, but the
919
# exception has interrupted the loop that runs the state machine.
920
# So we call accept_bytes again to restart it.
921
self.accept_bytes('')
922
except Exception, exception:
923
# The decoder itself has raised an exception. We cannot continue
925
self.decoding_failed = True
926
if isinstance(exception, errors.UnexpectedProtocolVersionMarker):
927
# This happens during normal operation when the client tries a
928
# protocol version the server doesn't understand, so no need to
929
# log a traceback every time.
930
# Note that this can only happen when
931
# expect_version_marker=True, which is only the case on the
935
log_exception_quietly()
936
self.message_handler.protocol_error(exception)
938
def _extract_length_prefixed_bytes(self):
939
if self._in_buffer_len < 4:
940
# A length prefix by itself is 4 bytes, and we don't even have that
942
raise _NeedMoreBytes(4)
943
(length,) = struct.unpack('!L', self._get_in_bytes(4))
944
end_of_bytes = 4 + length
945
if self._in_buffer_len < end_of_bytes:
946
# We haven't yet read as many bytes as the length-prefix says there
948
raise _NeedMoreBytes(end_of_bytes)
949
# Extract the bytes from the buffer.
950
in_buf = self._get_in_buffer()
951
bytes = in_buf[4:end_of_bytes]
952
self._set_in_buffer(in_buf[end_of_bytes:])
955
def _extract_prefixed_bencoded_data(self):
956
prefixed_bytes = self._extract_length_prefixed_bytes()
958
decoded = bdecode_as_tuple(prefixed_bytes)
960
raise errors.SmartProtocolError(
961
'Bytes %r not bencoded' % (prefixed_bytes,))
964
def _extract_single_byte(self):
965
if self._in_buffer_len == 0:
966
# The buffer is empty
967
raise _NeedMoreBytes(1)
968
in_buf = self._get_in_buffer()
970
self._set_in_buffer(in_buf[1:])
973
def _state_accept_expecting_protocol_version(self):
974
needed_bytes = len(MESSAGE_VERSION_THREE) - self._in_buffer_len
975
in_buf = self._get_in_buffer()
977
# We don't have enough bytes to check if the protocol version
978
# marker is right. But we can check if it is already wrong by
979
# checking that the start of MESSAGE_VERSION_THREE matches what
981
# [In fact, if the remote end isn't bzr we might never receive
982
# len(MESSAGE_VERSION_THREE) bytes. So if the bytes we have so far
983
# are wrong then we should just raise immediately rather than
985
if not MESSAGE_VERSION_THREE.startswith(in_buf):
986
# We have enough bytes to know the protocol version is wrong
987
raise errors.UnexpectedProtocolVersionMarker(in_buf)
988
raise _NeedMoreBytes(len(MESSAGE_VERSION_THREE))
989
if not in_buf.startswith(MESSAGE_VERSION_THREE):
990
raise errors.UnexpectedProtocolVersionMarker(in_buf)
991
self._set_in_buffer(in_buf[len(MESSAGE_VERSION_THREE):])
992
self.state_accept = self._state_accept_expecting_headers
994
def _state_accept_expecting_headers(self):
995
decoded = self._extract_prefixed_bencoded_data()
996
if type(decoded) is not dict:
997
raise errors.SmartProtocolError(
998
'Header object %r is not a dict' % (decoded,))
999
self.state_accept = self._state_accept_expecting_message_part
1001
self.message_handler.headers_received(decoded)
1003
raise errors.SmartMessageHandlerError(sys.exc_info())
1005
def _state_accept_expecting_message_part(self):
1006
message_part_kind = self._extract_single_byte()
1007
if message_part_kind == 'o':
1008
self.state_accept = self._state_accept_expecting_one_byte
1009
elif message_part_kind == 's':
1010
self.state_accept = self._state_accept_expecting_structure
1011
elif message_part_kind == 'b':
1012
self.state_accept = self._state_accept_expecting_bytes
1013
elif message_part_kind == 'e':
1016
raise errors.SmartProtocolError(
1017
'Bad message kind byte: %r' % (message_part_kind,))
1019
def _state_accept_expecting_one_byte(self):
    """Deliver a one-byte message part to the message handler.

    The decoder returns to the "expecting message part" state before the
    handler runs.

    :raises SmartMessageHandlerError: wrapping sys.exc_info() of anything
        raised by message_handler.byte_part_received.
    """
    single = self._extract_single_byte()
    self.state_accept = self._state_accept_expecting_message_part
    try:
        self.message_handler.byte_part_received(single)
    except:
        raise errors.SmartMessageHandlerError(sys.exc_info())
1027
def _state_accept_expecting_bytes(self):
    """Deliver a length-prefixed bytes part to the message handler.

    :raises SmartMessageHandlerError: wrapping sys.exc_info() of anything
        raised by message_handler.bytes_part_received.
    """
    # XXX: this buffers the whole message part; ideally the bytes would be
    # delivered incrementally as they arrive.
    payload = self._extract_length_prefixed_bytes()
    self.state_accept = self._state_accept_expecting_message_part
    try:
        self.message_handler.bytes_part_received(payload)
    except:
        raise errors.SmartMessageHandlerError(sys.exc_info())
1037
def _state_accept_expecting_structure(self):
    """Deliver a bencoded structure part to the message handler.

    :raises SmartMessageHandlerError: wrapping sys.exc_info() of anything
        raised by message_handler.structure_part_received.
    """
    decoded = self._extract_prefixed_bencoded_data()
    self.state_accept = self._state_accept_expecting_message_part
    try:
        self.message_handler.structure_part_received(decoded)
    except:
        raise errors.SmartMessageHandlerError(sys.exc_info())

def done(self):
    """Finish the current message.

    Whatever is still buffered becomes unused_data, the buffer is cleared,
    and the decoder switches to the state that only accumulates further
    unused bytes; then the handler is told the message has ended.

    :raises SmartMessageHandlerError: wrapping sys.exc_info() of anything
        raised by message_handler.end_received.
    """
    leftover = self._get_in_buffer()
    self.unused_data = leftover
    self._set_in_buffer(None)
    self.state_accept = self._state_accept_reading_unused
    try:
        self.message_handler.end_received()
    except:
        raise errors.SmartMessageHandlerError(sys.exc_info())
1054
def _state_accept_reading_unused(self):
    """Move any newly buffered bytes onto unused_data and empty the buffer."""
    pending = self._get_in_buffer()
    self.unused_data += pending
    self._set_in_buffer(None)
1058
def next_read_size(self):
    """Return how many more bytes this decoder wants to read.

    Returns 0 once the message has ended (only unused trailing data is
    being collected) or after decoding has failed; otherwise the number of
    bytes still needed beyond what is already buffered.

    :raises AssertionError: if the number of needed bytes is unknown.
    """
    # After a failure (probably raised from self.message_handler) the state
    # machine may be inconsistent, so just signal that we're done, i.e.
    # give up reading.
    if (self.state_accept == self._state_accept_reading_unused
            or self.decoding_failed):
        return 0
    if self._number_needed_bytes is None:
        raise AssertionError("don't know how many bytes are expected!")
    return self._number_needed_bytes - self._in_buffer_len
1074
class _ProtocolThreeEncoder(object):
    """Shared protocol-three writer used by requesters and responders.

    Writes go through _write_func, which collects them in memory and only
    passes them, joined into one string, to the real write function once
    more than BUFFER_SIZE bytes are pending or flush() is called.
    """

    response_marker = request_marker = MESSAGE_VERSION_THREE
    BUFFER_SIZE = 1024*1024 # 1 MiB buffer before flushing

    def __init__(self, write_func):
        self._real_write_func = write_func
        # Pending output chunks and their combined length.
        self._buf = []
        self._buf_len = 0

    def _write_func(self, bytes):
        """Queue bytes for output, flushing once the buffer grows too big."""
        # TODO: Another possibility would be to turn this into an async
        #       model, where we let another thread know that we have some
        #       bytes if they want it, but we don't actually block for it.
        #       Note that osutils.send_all always sends 64kB chunks anyway,
        #       so we might just push out smaller bits at a time?
        pending = self._buf
        pending.append(bytes)
        self._buf_len += len(bytes)
        if self._buf_len > self.BUFFER_SIZE:
            self.flush()

    def flush(self):
        """Pass everything buffered so far to the real write function."""
        if not self._buf:
            return
        self._real_write_func(''.join(self._buf))
        del self._buf[:]
        self._buf_len = 0

    def _serialise_offsets(self, offsets):
        """Serialise a readv offset list as '\\n'-joined "start,length"."""
        return '\n'.join(['%d,%d' % (start, length)
                          for start, length in offsets])

    def _write_length_prefixed(self, data):
        """Write a 4-byte network-order length, then data itself."""
        self._write_func(struct.pack('!L', len(data)))
        self._write_func(data)

    def _write_protocol_version(self):
        """Write the protocol-three version marker."""
        self._write_func(MESSAGE_VERSION_THREE)

    def _write_prefixed_bencode(self, structure):
        """Write structure bencoded, with a length prefix."""
        self._write_length_prefixed(bencode(structure))

    def _write_headers(self, headers):
        """Write the headers dict as a length-prefixed bencoded value."""
        self._write_prefixed_bencode(headers)

    def _write_structure(self, args):
        """Write an 's' part holding args; unicode args are sent as UTF-8."""
        self._write_func('s')
        encoded_args = [
            arg.encode('utf8') if type(arg) is unicode else arg
            for arg in args]
        self._write_prefixed_bencode(encoded_args)

    def _write_end(self):
        """Write the end-of-message marker and flush the buffer."""
        self._write_func('e')
        self.flush()

    def _write_prefixed_body(self, bytes):
        """Write a 'b' part holding bytes with a length prefix."""
        self._write_func('b')
        self._write_length_prefixed(bytes)

    def _write_chunked_body_start(self):
        self._write_func('oC')

    def _write_error_status(self):
        self._write_func('oE')

    def _write_success_status(self):
        self._write_func('oS')
1148
class ProtocolThreeResponder(_ProtocolThreeEncoder):
    """Server-side writer for protocol-three responses.

    Each response is written as the version marker, the headers dict
    (advertising bzrlib.__version__), a status part, an args structure,
    an optional body (bytes or stream), and the end-of-message marker.
    Only one response may be sent per responder.
    """

    def __init__(self, write_func):
        _ProtocolThreeEncoder.__init__(self, write_func)
        self.response_sent = False
        self._headers = {'Software version': bzrlib.__version__}
        if 'hpss' in debug.debug_flags:
            self._thread_id = thread.get_ident()
            self._response_start_time = None

    def _trace(self, action, message, extra_bytes=None, include_time=False):
        """Write an hpss debug line about this response via mutter.

        :param action: short label, right-aligned to 12 columns.
        :param message: free-form description.
        :param extra_bytes: optional payload.  The repr of at most its
            first 40 bytes is shown; a repr longer than 33 characters is
            cut down to 29 characters plus its last character and '...'.
        :param include_time: if true, prefix the time elapsed since the
            first traced event of this response.
        """
        if self._response_start_time is None:
            self._response_start_time = osutils.timer_func()
        if include_time:
            # Measure with the same clock the start time came from.  The
            # previous time.clock() call is not guaranteed to be the clock
            # behind osutils.timer_func(), which made durations meaningless.
            t = '%5.3fs ' % (osutils.timer_func() - self._response_start_time)
        else:
            t = ''
        if extra_bytes is None:
            extra = ''
        else:
            extra = ' ' + repr(extra_bytes[:40])
            if len(extra) > 33:
                extra = extra[:29] + extra[-1] + '...'
        mutter('%12s: [%s] %s%s%s'
               % (action, self._thread_id, t, message, extra))

    def send_error(self, exception):
        """Send exception to the client as an error response.

        UnknownSmartMethod is sent as a FailedSmartServerResponse of
        ('UnknownMethod', verb); anything else as ('error', str(exception)).

        :raises AssertionError: if a response has already been sent.
        """
        if self.response_sent:
            raise AssertionError(
                "send_error(%s) called, but response already sent."
                % (exception,))
        if isinstance(exception, errors.UnknownSmartMethod):
            failure = request.FailedSmartServerResponse(
                ('UnknownMethod', exception.verb))
            self.send_response(failure)
            return
        if 'hpss' in debug.debug_flags:
            self._trace('error', str(exception))
        self.response_sent = True
        self._write_protocol_version()
        self._write_headers(self._headers)
        self._write_error_status()
        self._write_structure(('error', str(exception)))
        self._write_end()

    def send_response(self, response):
        """Send response (args plus optional body or body stream).

        While streaming a body, an exception from the stream or a
        FailedSmartServerResponse chunk writes an error status with the
        corresponding args and stops the stream.

        :raises AssertionError: if a response has already been sent.
        """
        if self.response_sent:
            raise AssertionError(
                "send_response(%r) called, but response already sent."
                % (response,))
        self.response_sent = True
        self._write_protocol_version()
        self._write_headers(self._headers)
        if response.is_successful():
            self._write_success_status()
        else:
            self._write_error_status()
        if 'hpss' in debug.debug_flags:
            self._trace('response', repr(response.args))
        self._write_structure(response.args)
        if response.body is not None:
            self._write_prefixed_body(response.body)
            if 'hpss' in debug.debug_flags:
                self._trace('body', '%d bytes' % (len(response.body),),
                            response.body, include_time=True)
        elif response.body_stream is not None:
            count = num_bytes = 0
            first_chunk = None
            for exc_info, chunk in _iter_with_errors(response.body_stream):
                count += 1
                if exc_info is not None:
                    self._write_error_status()
                    error_struct = request._translate_error(exc_info[1])
                    self._write_structure(error_struct)
                    break
                else:
                    if isinstance(chunk, request.FailedSmartServerResponse):
                        self._write_error_status()
                        self._write_structure(chunk.args)
                        break
                    num_bytes += len(chunk)
                    if first_chunk is None:
                        first_chunk = chunk
                    self._write_prefixed_body(chunk)
                    self.flush()
                    if 'hpssdetail' in debug.debug_flags:
                        # Not worth timing separately, as _write_func is
                        # buffered.  _trace has no 'suppress_time' keyword
                        # (passing one raised TypeError); leaving
                        # include_time at its default already omits timing.
                        self._trace('body chunk',
                                    '%d bytes' % (len(chunk),),
                                    chunk)
            if 'hpss' in debug.debug_flags:
                self._trace('body stream',
                            '%d bytes %d chunks' % (num_bytes, count),
                            first_chunk)
        self._write_end()
        if 'hpss' in debug.debug_flags:
            self._trace('response end', '', include_time=True)
1248
def _iter_with_errors(iterable):
    """Handle errors from fetching the next value of iterable.

    Use like::

        for exc_info, value in _iter_with_errors(iterable):
            ...

    This is a safer alternative to::

        try:
            for value in iterable:
               ...
        except:
            ...

    Because the latter will catch errors from the for-loop body, not just
    from fetching the next value of the iterable.

    If an error occurs, exc_info will be a exc_info tuple, and the generator
    will terminate.  Otherwise exc_info will be None, and value will be the
    next value of the iterable.  Note that KeyboardInterrupt and SystemExit
    will not be intercepted.
    """
    iterator = iter(iterable)
    while True:
        try:
            # The next() builtin works on Python 2.6+ (where it calls the
            # iterator's .next()) and on Python 3, where iterators have no
            # .next() method at all.
            yield None, next(iterator)
        except StopIteration:
            return
        except (KeyboardInterrupt, SystemExit):
            raise
        except Exception:
            mutter('_iter_with_errors caught error')
            log_exception_quietly()
            yield sys.exc_info(), None
            return
1287
class ProtocolThreeRequester(_ProtocolThreeEncoder, Requester):
1289
def __init__(self, medium_request):
1290
_ProtocolThreeEncoder.__init__(self, medium_request.accept_bytes)
1291
self._medium_request = medium_request
1293
self.body_stream_started = None
1295
def set_headers(self, headers):
1296
self._headers = headers.copy()
1298
def call(self, *args):
1299
if 'hpss' in debug.debug_flags:
1300
mutter('hpss call: %s', repr(args)[1:-1])
1301
base = getattr(self._medium_request._medium, 'base', None)
1302
if base is not None:
1303
mutter(' (to %s)', base)
1304
self._request_start_time = osutils.timer_func()
1305
self._write_protocol_version()
1306
self._write_headers(self._headers)
1307
self._write_structure(args)
1309
self._medium_request.finished_writing()
1311
def call_with_body_bytes(self, args, body):
1312
"""Make a remote call of args with body bytes 'body'.
1314
After calling this, call read_response_tuple to find the result out.
1316
if 'hpss' in debug.debug_flags:
1317
mutter('hpss call w/body: %s (%r...)', repr(args)[1:-1], body[:20])
1318
path = getattr(self._medium_request._medium, '_path', None)
1319
if path is not None:
1320
mutter(' (to %s)', path)
1321
mutter(' %d bytes', len(body))
1322
self._request_start_time = osutils.timer_func()
1323
self._write_protocol_version()
1324
self._write_headers(self._headers)
1325
self._write_structure(args)
1326
self._write_prefixed_body(body)
1328
self._medium_request.finished_writing()
1330
def call_with_body_readv_array(self, args, body):
1331
"""Make a remote call with a readv array.
1333
The body is encoded with one line per readv offset pair. The numbers in
1334
each pair are separated by a comma, and no trailing \\n is emitted.
1336
if 'hpss' in debug.debug_flags:
1337
mutter('hpss call w/readv: %s', repr(args)[1:-1])
1338
path = getattr(self._medium_request._medium, '_path', None)
1339
if path is not None:
1340
mutter(' (to %s)', path)
1341
self._request_start_time = osutils.timer_func()
1342
self._write_protocol_version()
1343
self._write_headers(self._headers)
1344
self._write_structure(args)
1345
readv_bytes = self._serialise_offsets(body)
1346
if 'hpss' in debug.debug_flags:
1347
mutter(' %d bytes in readv request', len(readv_bytes))
1348
self._write_prefixed_body(readv_bytes)
1350
self._medium_request.finished_writing()
1352
def call_with_body_stream(self, args, stream):
1353
if 'hpss' in debug.debug_flags:
1354
mutter('hpss call w/body stream: %r', args)
1355
path = getattr(self._medium_request._medium, '_path', None)
1356
if path is not None:
1357
mutter(' (to %s)', path)
1358
self._request_start_time = osutils.timer_func()
1359
self.body_stream_started = False
1360
self._write_protocol_version()
1361
self._write_headers(self._headers)
1362
self._write_structure(args)
1363
# TODO: notice if the server has sent an early error reply before we
1364
# have finished sending the stream. We would notice at the end
1365
# anyway, but if the medium can deliver it early then it's good
1366
# to short-circuit the whole request...
1367
# Provoke any ConnectionReset failures before we start the body stream.
1369
self.body_stream_started = True
1370
for exc_info, part in _iter_with_errors(stream):
1371
if exc_info is not None:
1372
# Iterating the stream failed. Cleanly abort the request.
1373
self._write_error_status()
1374
# Currently the client unconditionally sends ('error',) as the
1376
self._write_structure(('error',))
1378
self._medium_request.finished_writing()
1379
raise exc_info[0], exc_info[1], exc_info[2]
1381
self._write_prefixed_body(part)
1384
self._medium_request.finished_writing()