24
24
bzrlib/transport/smart/__init__.py.
33
from bzrlib.lazy_import import lazy_import
34
lazy_import(globals(), """
32
37
from bzrlib import (
38
from bzrlib.smart.protocol import (
39
MESSAGE_VERSION_THREE,
41
SmartClientRequestProtocolOne,
42
SmartServerRequestProtocolOne,
43
SmartServerRequestProtocolTwo,
44
build_server_protocol_three
45
from bzrlib.smart import client, protocol
46
46
from bzrlib.transport import ssh
50
# We must not read any more than 64k at a time so we don't risk "no buffer
51
# space available" errors on some platforms. Windows in particular is likely
52
# to give error 10053 or 10055 if we read more than 64k from a socket.
53
_MAX_READ_SIZE = 64 * 1024
49
56
def _get_protocol_factory_for_bytes(bytes):
67
74
root_client_path. unused_bytes are any bytes that were not part of a
68
75
protocol version marker.
70
if bytes.startswith(MESSAGE_VERSION_THREE):
71
protocol_factory = build_server_protocol_three
72
bytes = bytes[len(MESSAGE_VERSION_THREE):]
73
elif bytes.startswith(REQUEST_VERSION_TWO):
74
protocol_factory = SmartServerRequestProtocolTwo
75
bytes = bytes[len(REQUEST_VERSION_TWO):]
77
if bytes.startswith(protocol.MESSAGE_VERSION_THREE):
78
protocol_factory = protocol.build_server_protocol_three
79
bytes = bytes[len(protocol.MESSAGE_VERSION_THREE):]
80
elif bytes.startswith(protocol.REQUEST_VERSION_TWO):
81
protocol_factory = protocol.SmartServerRequestProtocolTwo
82
bytes = bytes[len(protocol.REQUEST_VERSION_TWO):]
77
protocol_factory = SmartServerRequestProtocolOne
84
protocol_factory = protocol.SmartServerRequestProtocolOne
78
85
return protocol_factory, bytes
81
class SmartServerStreamMedium(object):
88
def _get_line(read_bytes_func):
89
"""Read bytes using read_bytes_func until a newline byte.
91
This isn't particularly efficient, so should only be used when the
92
expected size of the line is quite short.
94
:returns: a tuple of two strs: (line, excess)
98
while newline_pos == -1:
99
new_bytes = read_bytes_func(1)
102
# Ran out of bytes before receiving a complete line.
104
newline_pos = bytes.find('\n')
105
line = bytes[:newline_pos+1]
106
excess = bytes[newline_pos+1:]
110
class SmartMedium(object):
111
"""Base class for smart protocol media, both client- and server-side."""
114
self._push_back_buffer = None
116
def _push_back(self, bytes):
117
"""Return unused bytes to the medium, because they belong to the next
120
This sets the _push_back_buffer to the given bytes.
122
if self._push_back_buffer is not None:
123
raise AssertionError(
124
"_push_back called when self._push_back_buffer is %r"
125
% (self._push_back_buffer,))
128
self._push_back_buffer = bytes
130
def _get_push_back_buffer(self):
131
if self._push_back_buffer == '':
132
raise AssertionError(
133
'%s._push_back_buffer should never be the empty string, '
134
'which can be confused with EOF' % (self,))
135
bytes = self._push_back_buffer
136
self._push_back_buffer = None
139
def read_bytes(self, desired_count):
140
"""Read some bytes from this medium.
142
:returns: some bytes, possibly more or less than the number requested
143
in 'desired_count' depending on the medium.
145
if self._push_back_buffer is not None:
146
return self._get_push_back_buffer()
147
bytes_to_read = min(desired_count, _MAX_READ_SIZE)
148
return self._read_bytes(bytes_to_read)
150
def _read_bytes(self, count):
151
raise NotImplementedError(self._read_bytes)
154
"""Read bytes from this request's response until a newline byte.
156
This isn't particularly efficient, so should only be used when the
157
expected size of the line is quite short.
159
:returns: a string of bytes ending in a newline (byte 0x0A).
161
line, excess = _get_line(self.read_bytes)
162
self._push_back(excess)
166
class SmartServerStreamMedium(SmartMedium):
82
167
"""Handles smart commands coming over a stream.
84
169
The stream may be a pipe connected to sshd, or a tcp socket, or an
105
190
self.backing_transport = backing_transport
106
191
self.root_client_path = root_client_path
107
192
self.finished = False
108
self._push_back_buffer = None
110
def _push_back(self, bytes):
111
"""Return unused bytes to the medium, because they belong to the next
114
This sets the _push_back_buffer to the given bytes.
116
if self._push_back_buffer is not None:
117
raise AssertionError(
118
"_push_back called when self._push_back_buffer is %r"
119
% (self._push_back_buffer,))
122
self._push_back_buffer = bytes
124
def _get_push_back_buffer(self):
125
if self._push_back_buffer == '':
126
raise AssertionError(
127
'%s._push_back_buffer should never be the empty string, '
128
'which can be confused with EOF' % (self,))
129
bytes = self._push_back_buffer
130
self._push_back_buffer = None
193
SmartMedium.__init__(self)
134
196
"""Serve requests until the client disconnects."""
175
237
"""Called when an unhandled exception from the protocol occurs."""
176
238
raise NotImplementedError(self.terminate_due_to_error)
178
def _get_bytes(self, desired_count):
240
def _read_bytes(self, desired_count):
179
241
"""Get some bytes from the medium.
181
243
:param desired_count: number of bytes we want to read.
183
raise NotImplementedError(self._get_bytes)
186
"""Read bytes from this request's response until a newline byte.
188
This isn't particularly efficient, so should only be used when the
189
expected size of the line is quite short.
191
:returns: a string of bytes ending in a newline (byte 0x0A).
195
while newline_pos == -1:
196
new_bytes = self._get_bytes(1)
199
# Ran out of bytes before receiving a complete line.
201
newline_pos = bytes.find('\n')
202
line = bytes[:newline_pos+1]
203
self._push_back(bytes[newline_pos+1:])
245
raise NotImplementedError(self._read_bytes)
207
248
class SmartServerSocketStreamMedium(SmartServerStreamMedium):
228
272
self._push_back(protocol.unused_data)
230
def _get_bytes(self, desired_count):
231
if self._push_back_buffer is not None:
232
return self._get_push_back_buffer()
274
def _read_bytes(self, desired_count):
233
275
# We ignore the desired_count because on sockets it's more efficient to
235
return self.socket.recv(4096)
276
# read large chunks (of _MAX_READ_SIZE bytes) at a time.
277
return self.socket.recv(_MAX_READ_SIZE)
237
279
def terminate_due_to_error(self):
238
280
# TODO: This should log to a server log file, but no such thing
239
281
# exists yet. Andrew Bennetts 2006-09-29.
401
444
return self._read_bytes(count)
403
446
def _read_bytes(self, count):
404
"""Helper for read_bytes.
447
"""Helper for SmartClientMediumRequest.read_bytes.
406
449
read_bytes checks the state of the request to determine if bytes
407
450
should be read. After that it hands off to _read_bytes to do the
453
By default this forwards to self._medium.read_bytes because we are
454
operating on the medium's stream.
410
raise NotImplementedError(self._read_bytes)
456
return self._medium.read_bytes(count)
412
458
def read_line(self):
413
"""Read bytes from this request's response until a newline byte.
415
This isn't particularly efficient, so should only be used when the
416
expected size of the line is quite short.
418
:returns: a string of bytes ending in a newline (byte 0x0A).
420
# XXX: this duplicates SmartClientRequestProtocolOne._recv_tuple
422
while not line or line[-1] != '\n':
423
new_char = self.read_bytes(1)
426
# end of file encountered reading from server
427
raise errors.ConnectionReset(
428
"please check connectivity and permissions",
429
"(and try -Dhpss if further diagnosis is required)")
459
line = self._read_line()
460
if not line.endswith('\n'):
461
# end of file encountered reading from server
462
raise errors.ConnectionReset(
463
"please check connectivity and permissions",
464
"(and try -Dhpss if further diagnosis is required)")
433
class SmartClientMedium(object):
467
def _read_line(self):
468
"""Helper for SmartClientMediumRequest.read_line.
470
By default this forwards to self._medium._get_line because we are
471
operating on the medium's stream.
473
return self._medium._get_line()
476
class _DebugCounter(object):
477
"""An object that counts the HPSS calls made to each client medium.
479
When a medium is garbage-collected, or failing that when atexit functions
480
are run, the total number of calls made on that medium are reported via
485
self.counts = weakref.WeakKeyDictionary()
486
client._SmartClient.hooks.install_named_hook(
487
'call', self.increment_call_count, 'hpss call counter')
488
atexit.register(self.flush_all)
490
def track(self, medium):
491
"""Start tracking calls made to a medium.
493
This only keeps a weakref to the medium, so shouldn't affect the
496
medium_repr = repr(medium)
497
# Add this medium to the WeakKeyDictionary
498
self.counts[medium] = [0, medium_repr]
499
# Weakref callbacks are fired in reverse order of their association
500
# with the referenced object. So we add a weakref *after* adding to
501
# the WeakKeyDict so that we can report the value from it before the
502
# entry is removed by the WeakKeyDict's own callback.
503
ref = weakref.ref(medium, self.done)
505
def increment_call_count(self, params):
506
# Increment the count in the WeakKeyDictionary
507
value = self.counts[params.medium]
511
value = self.counts[ref]
512
count, medium_repr = value
513
# In case this callback is invoked for the same ref twice (by the
514
# weakref callback and by the atexit function), set the call count back
515
# to 0 so this item won't be reported twice.
518
trace.note('HPSS calls: %d %s', count, medium_repr)
521
for ref in list(self.counts.keys()):
524
_debug_counter = None
527
class SmartClientMedium(SmartMedium):
434
528
"""Smart client is a medium for sending smart protocol requests over."""
436
530
def __init__(self, base):
440
534
self._protocol_version = None
441
535
self._done_hello = False
442
536
# Be optimistic: we assume the remote end can accept new remote
443
# requests until we get an error saying otherwise. (1.2 adds some
444
# requests that send bodies, which confuses older servers.)
445
self._remote_is_at_least_1_2 = True
537
# requests until we get an error saying otherwise.
538
# _remote_version_is_before tracks the bzr version the remote side
539
# can be based on what we've seen so far.
540
self._remote_version_is_before = None
541
# Install debug hook function if debug flag is set.
542
if 'hpss' in debug.debug_flags:
543
global _debug_counter
544
if _debug_counter is None:
545
_debug_counter = _DebugCounter()
546
_debug_counter.track(self)
548
def _is_remote_before(self, version_tuple):
549
"""Is it possible the remote side supports RPCs for a given version?
553
needed_version = (1, 2)
554
if medium._is_remote_before(needed_version):
555
fallback_to_pre_1_2_rpc()
559
except UnknownSmartMethod:
560
medium._remember_remote_is_before(needed_version)
561
fallback_to_pre_1_2_rpc()
563
:seealso: _remember_remote_is_before
565
if self._remote_version_is_before is None:
566
# So far, the remote side seems to support everything
568
return version_tuple >= self._remote_version_is_before
570
def _remember_remote_is_before(self, version_tuple):
571
"""Tell this medium that the remote side is older than the given version.
573
:seealso: _is_remote_before
575
if (self._remote_version_is_before is not None and
576
version_tuple > self._remote_version_is_before):
577
raise AssertionError(
578
"_remember_remote_is_before(%r) called, but "
579
"_remember_remote_is_before(%r) was called previously."
580
% (version_tuple, self._remote_version_is_before))
581
self._remote_version_is_before = version_tuple
447
583
def protocol_version(self):
448
584
"""Find out if 'hello' smart request works."""
664
798
"""Connect this medium if not already connected."""
665
799
if self._connected:
667
self._socket = socket.socket()
668
self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
669
801
if self._port is None:
670
802
port = BZR_DEFAULT_PORT
672
804
port = int(self._port)
674
self._socket.connect((self._host, port))
675
except socket.error, err:
806
sockaddrs = socket.getaddrinfo(self._host, port, socket.AF_UNSPEC,
807
socket.SOCK_STREAM, 0, 0)
808
except socket.gaierror, (err_num, err_msg):
809
raise errors.ConnectionError("failed to lookup %s:%d: %s" %
810
(self._host, port, err_msg))
811
# Initialize err in case there are no addresses returned:
812
err = socket.error("no address found for %s" % self._host)
813
for (family, socktype, proto, canonname, sockaddr) in sockaddrs:
815
self._socket = socket.socket(family, socktype, proto)
816
self._socket.setsockopt(socket.IPPROTO_TCP,
817
socket.TCP_NODELAY, 1)
818
self._socket.connect(sockaddr)
819
except socket.error, err:
820
if self._socket is not None:
825
if self._socket is None:
676
826
# socket errors either have a (string) or (errno, string) as their
678
828
if type(err.args) is str: