24
24
bzrlib/transport/smart/__init__.py.
33
from bzrlib.lazy_import import lazy_import
34
lazy_import(globals(), """
31
from bzrlib import errors
32
from bzrlib.smart.protocol import (
34
SmartServerRequestProtocolOne,
35
SmartServerRequestProtocolTwo,
46
from bzrlib.smart import client, protocol
47
from bzrlib.transport import ssh
51
# We must not read any more than 64k at a time so we don't risk "no buffer
52
# space available" errors on some platforms. Windows in particular is likely
53
# to give error 10053 or 10055 if we read more than 64k from a socket.
54
_MAX_READ_SIZE = 64 * 1024
57
def _get_protocol_factory_for_bytes(bytes):
58
"""Determine the right protocol factory for 'bytes'.
60
This will return an appropriate protocol factory depending on the version
61
of the protocol being used, as determined by inspecting the given bytes.
62
The bytes should have at least one newline byte (i.e. be a whole line),
63
otherwise it's possible that a request will be incorrectly identified as
66
Typical use would be::
68
factory, unused_bytes = _get_protocol_factory_for_bytes(bytes)
69
server_protocol = factory(transport, write_func, root_client_path)
70
server_protocol.accept_bytes(unused_bytes)
72
:param bytes: a str of bytes of the start of the request.
73
:returns: 2-tuple of (protocol_factory, unused_bytes). protocol_factory is
74
a callable that takes three args: transport, write_func,
75
root_client_path. unused_bytes are any bytes that were not part of a
76
protocol version marker.
78
if bytes.startswith(protocol.MESSAGE_VERSION_THREE):
79
protocol_factory = protocol.build_server_protocol_three
80
bytes = bytes[len(protocol.MESSAGE_VERSION_THREE):]
81
elif bytes.startswith(protocol.REQUEST_VERSION_TWO):
82
protocol_factory = protocol.SmartServerRequestProtocolTwo
83
bytes = bytes[len(protocol.REQUEST_VERSION_TWO):]
85
protocol_factory = protocol.SmartServerRequestProtocolOne
86
return protocol_factory, bytes
89
def _get_line(read_bytes_func):
90
"""Read bytes using read_bytes_func until a newline byte.
92
This isn't particularly efficient, so should only be used when the
93
expected size of the line is quite short.
95
:returns: a tuple of two strs: (line, excess)
99
while newline_pos == -1:
100
new_bytes = read_bytes_func(1)
103
# Ran out of bytes before receiving a complete line.
105
newline_pos = bytes.find('\n')
106
line = bytes[:newline_pos+1]
107
excess = bytes[newline_pos+1:]
111
class SmartMedium(object):
112
"""Base class for smart protocol media, both client- and server-side."""
115
self._push_back_buffer = None
117
def _push_back(self, bytes):
118
"""Return unused bytes to the medium, because they belong to the next
121
This sets the _push_back_buffer to the given bytes.
123
if self._push_back_buffer is not None:
124
raise AssertionError(
125
"_push_back called when self._push_back_buffer is %r"
126
% (self._push_back_buffer,))
129
self._push_back_buffer = bytes
131
def _get_push_back_buffer(self):
132
if self._push_back_buffer == '':
133
raise AssertionError(
134
'%s._push_back_buffer should never be the empty string, '
135
'which can be confused with EOF' % (self,))
136
bytes = self._push_back_buffer
137
self._push_back_buffer = None
140
def read_bytes(self, desired_count):
141
"""Read some bytes from this medium.
143
:returns: some bytes, possibly more or less than the number requested
144
in 'desired_count' depending on the medium.
146
if self._push_back_buffer is not None:
147
return self._get_push_back_buffer()
148
bytes_to_read = min(desired_count, _MAX_READ_SIZE)
149
return self._read_bytes(bytes_to_read)
151
def _read_bytes(self, count):
152
raise NotImplementedError(self._read_bytes)
155
"""Read bytes from this request's response until a newline byte.
157
This isn't particularly efficient, so should only be used when the
158
expected size of the line is quite short.
160
:returns: a string of bytes ending in a newline (byte 0x0A).
162
line, excess = _get_line(self.read_bytes)
163
self._push_back(excess)
166
def _report_activity(self, bytes, direction):
167
"""Notify that this medium has activity.
169
Implementations should call this from all methods that actually do IO.
170
Be careful that it's not called twice, if one method is implemented on
173
:param bytes: Number of bytes read or written.
174
:param direction: 'read' or 'write' or None.
176
ui.ui_factory.report_transport_activity(self, bytes, direction)
179
class SmartServerStreamMedium(SmartMedium):
39
from bzrlib.transport import ssh
40
except errors.ParamikoNotPresent:
41
# no paramiko. SmartSSHClientMedium will break.
45
class SmartServerStreamMedium(object):
180
46
"""Handles smart commands coming over a stream.
182
48
The stream may be a pipe connected to sshd, or a tcp socket, or an
250
113
"""Called when an unhandled exception from the protocol occurs."""
251
114
raise NotImplementedError(self.terminate_due_to_error)
253
def _read_bytes(self, desired_count):
116
def _get_bytes(self, desired_count):
254
117
"""Get some bytes from the medium.
256
119
:param desired_count: number of bytes we want to read.
258
raise NotImplementedError(self._read_bytes)
121
raise NotImplementedError(self._get_bytes)
124
"""Read bytes from this request's response until a newline byte.
126
This isn't particularly efficient, so should only be used when the
127
expected size of the line is quite short.
129
:returns: a string of bytes ending in a newline (byte 0x0A).
131
# XXX: this duplicates SmartClientRequestProtocolOne._recv_tuple
133
while not line or line[-1] != '\n':
134
new_char = self._get_bytes(1)
137
# Ran out of bytes before receiving a complete line.
261
142
class SmartServerSocketStreamMedium(SmartServerStreamMedium):
263
def __init__(self, sock, backing_transport, root_client_path='/'):
144
def __init__(self, sock, backing_transport):
266
147
:param sock: the socket the server will read from. It will be put
267
148
into blocking mode.
269
SmartServerStreamMedium.__init__(
270
self, backing_transport, root_client_path=root_client_path)
150
SmartServerStreamMedium.__init__(self, backing_transport)
271
152
sock.setblocking(True)
272
153
self.socket = sock
274
155
def _serve_one_request_unguarded(self, protocol):
275
156
while protocol.next_read_size():
276
# We can safely try to read large chunks. If there is less data
277
# than _MAX_READ_SIZE ready, the socket wil just return a short
278
# read immediately rather than block.
279
bytes = self.read_bytes(_MAX_READ_SIZE)
283
protocol.accept_bytes(bytes)
285
self._push_back(protocol.unused_data)
287
def _read_bytes(self, desired_count):
158
protocol.accept_bytes(self.push_back)
161
bytes = self._get_bytes(4096)
165
protocol.accept_bytes(bytes)
167
self.push_back = protocol.excess_buffer
169
def _get_bytes(self, desired_count):
288
170
# We ignore the desired_count because on sockets it's more efficient to
289
# read large chunks (of _MAX_READ_SIZE bytes) at a time.
290
bytes = osutils.until_no_eintr(self.socket.recv, _MAX_READ_SIZE)
291
self._report_activity(len(bytes), 'read')
172
return self.socket.recv(4096)
294
174
def terminate_due_to_error(self):
175
"""Called when an unhandled exception from the protocol occurs."""
295
176
# TODO: This should log to a server log file, but no such thing
296
177
# exists yet. Andrew Bennetts 2006-09-29.
297
178
self.socket.close()
298
179
self.finished = True
300
181
def _write_out(self, bytes):
301
osutils.send_all(self.socket, bytes, self._report_activity)
182
self.socket.sendall(bytes)
304
185
class SmartServerPipeStreamMedium(SmartServerStreamMedium):
459
337
return self._read_bytes(count)
461
339
def _read_bytes(self, count):
462
"""Helper for SmartClientMediumRequest.read_bytes.
340
"""Helper for read_bytes.
464
342
read_bytes checks the state of the request to determing if bytes
465
343
should be read. After that it hands off to _read_bytes to do the
468
By default this forwards to self._medium.read_bytes because we are
469
operating on the medium's stream.
471
return self._medium.read_bytes(count)
346
raise NotImplementedError(self._read_bytes)
473
348
def read_line(self):
474
line = self._read_line()
475
if not line.endswith('\n'):
476
# end of file encountered reading from server
477
raise errors.ConnectionReset(
478
"please check connectivity and permissions")
349
"""Read bytes from this request's response until a newline byte.
351
This isn't particularly efficient, so should only be used when the
352
expected size of the line is quite short.
354
:returns: a string of bytes ending in a newline (byte 0x0A).
356
# XXX: this duplicates SmartClientRequestProtocolOne._recv_tuple
358
while not line or line[-1] != '\n':
359
new_char = self.read_bytes(1)
362
raise errors.SmartProtocolError(
363
'unexpected end of file reading from server')
481
def _read_line(self):
482
"""Helper for SmartClientMediumRequest.read_line.
484
By default this forwards to self._medium._get_line because we are
485
operating on the medium's stream.
487
return self._medium._get_line()
490
class _DebugCounter(object):
491
"""An object that counts the HPSS calls made to each client medium.
493
When a medium is garbage-collected, or failing that when atexit functions
494
are run, the total number of calls made on that medium are reported via
499
self.counts = weakref.WeakKeyDictionary()
500
client._SmartClient.hooks.install_named_hook(
501
'call', self.increment_call_count, 'hpss call counter')
502
atexit.register(self.flush_all)
504
def track(self, medium):
505
"""Start tracking calls made to a medium.
507
This only keeps a weakref to the medium, so shouldn't affect the
510
medium_repr = repr(medium)
511
# Add this medium to the WeakKeyDictionary
512
self.counts[medium] = [0, medium_repr]
513
# Weakref callbacks are fired in reverse order of their association
514
# with the referenced object. So we add a weakref *after* adding to
515
# the WeakKeyDict so that we can report the value from it before the
516
# entry is removed by the WeakKeyDict's own callback.
517
ref = weakref.ref(medium, self.done)
519
def increment_call_count(self, params):
520
# Increment the count in the WeakKeyDictionary
521
value = self.counts[params.medium]
525
value = self.counts[ref]
526
count, medium_repr = value
527
# In case this callback is invoked for the same ref twice (by the
528
# weakref callback and by the atexit function), set the call count back
529
# to 0 so this item won't be reported twice.
532
trace.note('HPSS calls: %d %s', count, medium_repr)
535
for ref in list(self.counts.keys()):
538
_debug_counter = None
541
class SmartClientMedium(SmartMedium):
367
class SmartClientMedium(object):
542
368
"""Smart client is a medium for sending smart protocol requests over."""
544
def __init__(self, base):
545
super(SmartClientMedium, self).__init__()
547
self._protocol_version_error = None
548
self._protocol_version = None
549
self._done_hello = False
550
# Be optimistic: we assume the remote end can accept new remote
551
# requests until we get an error saying otherwise.
552
# _remote_version_is_before tracks the bzr version the remote side
553
# can be based on what we've seen so far.
554
self._remote_version_is_before = None
555
# Install debug hook function if debug flag is set.
556
if 'hpss' in debug.debug_flags:
557
global _debug_counter
558
if _debug_counter is None:
559
_debug_counter = _DebugCounter()
560
_debug_counter.track(self)
562
def _is_remote_before(self, version_tuple):
563
"""Is it possible the remote side supports RPCs for a given version?
567
needed_version = (1, 2)
568
if medium._is_remote_before(needed_version):
569
fallback_to_pre_1_2_rpc()
573
except UnknownSmartMethod:
574
medium._remember_remote_is_before(needed_version)
575
fallback_to_pre_1_2_rpc()
577
:seealso: _remember_remote_is_before
579
if self._remote_version_is_before is None:
580
# So far, the remote side seems to support everything
582
return version_tuple >= self._remote_version_is_before
584
def _remember_remote_is_before(self, version_tuple):
585
"""Tell this medium that the remote side is older the given version.
587
:seealso: _is_remote_before
589
if (self._remote_version_is_before is not None and
590
version_tuple > self._remote_version_is_before):
591
# We have been told that the remote side is older than some version
592
# which is newer than a previously supplied older-than version.
593
# This indicates that some smart verb call is not guarded
594
# appropriately (it should simply not have been tried).
595
raise AssertionError(
596
"_remember_remote_is_before(%r) called, but "
597
"_remember_remote_is_before(%r) was called previously."
598
% (version_tuple, self._remote_version_is_before))
599
self._remote_version_is_before = version_tuple
601
def protocol_version(self):
602
"""Find out if 'hello' smart request works."""
603
if self._protocol_version_error is not None:
604
raise self._protocol_version_error
605
if not self._done_hello:
607
medium_request = self.get_request()
608
# Send a 'hello' request in protocol version one, for maximum
609
# backwards compatibility.
610
client_protocol = protocol.SmartClientRequestProtocolOne(medium_request)
611
client_protocol.query_version()
612
self._done_hello = True
613
except errors.SmartProtocolError, e:
614
# Cache the error, just like we would cache a successful
616
self._protocol_version_error = e
620
def should_probe(self):
621
"""Should RemoteBzrDirFormat.probe_transport send a smart request on
624
Some transports are unambiguously smart-only; there's no need to check
625
if the transport is able to carry smart requests, because that's all
626
it is for. In those cases, this method should return False.
628
But some HTTP transports can sometimes fail to carry smart requests,
629
but still be usuable for accessing remote bzrdirs via plain file
630
accesses. So for those transports, their media should return True here
631
so that RemoteBzrDirFormat can determine if it is appropriate for that
636
370
def disconnect(self):
637
371
"""If this medium maintains a persistent connection, close it.
639
373
The default implementation does nothing.
642
def remote_path_from_transport(self, transport):
643
"""Convert transport into a path suitable for using in a request.
645
Note that the resulting remote path doesn't encode the host name or
646
anything but path, so it is only safe to use it in requests sent over
647
the medium from the matching transport.
649
medium_base = urlutils.join(self.base, '/')
650
rel_url = urlutils.relative_url(medium_base, transport.base)
651
return urllib.unquote(rel_url)
654
377
class SmartClientStreamMedium(SmartClientMedium):
655
378
"""Stream based medium common class.
713
437
def _read_bytes(self, count):
714
438
"""See SmartClientStreamMedium._read_bytes."""
715
bytes = self._readable_pipe.read(count)
716
self._report_activity(len(bytes), 'read')
439
return self._readable_pipe.read(count)
720
442
class SmartSSHClientMedium(SmartClientStreamMedium):
721
443
"""A client medium using SSH."""
723
445
def __init__(self, host, port=None, username=None, password=None,
724
base=None, vendor=None, bzr_remote_path=None):
725
447
"""Creates a client that will connect on the first use.
727
449
:param vendor: An optional override for the ssh vendor to use. See
728
450
bzrlib.transport.ssh for details on ssh vendors.
452
SmartClientStreamMedium.__init__(self)
730
453
self._connected = False
731
454
self._host = host
732
455
self._password = password
733
456
self._port = port
734
457
self._username = username
735
# SmartClientStreamMedium stores the repr of this object in its
736
# _DebugCounter so we have to store all the values used in our repr
737
# method before calling the super init.
738
SmartClientStreamMedium.__init__(self, base)
739
458
self._read_from = None
740
459
self._ssh_connection = None
741
460
self._vendor = vendor
742
461
self._write_to = None
743
self._bzr_remote_path = bzr_remote_path
744
# for the benefit of progress making a short description of this
746
self._scheme = 'bzr+ssh'
749
return "%s(connected=%r, username=%r, host=%r, port=%r)" % (
750
self.__class__.__name__,
756
463
def _accept_bytes(self, bytes):
757
464
"""See SmartClientStreamMedium.accept_bytes."""
758
465
self._ensure_connection()
759
466
self._write_to.write(bytes)
760
self._report_activity(len(bytes), 'write')
762
468
def disconnect(self):
763
469
"""See SmartClientMedium.disconnect()."""
831
530
"""Connect this medium if not already connected."""
832
531
if self._connected:
834
if self._port is None:
835
port = BZR_DEFAULT_PORT
837
port = int(self._port)
839
sockaddrs = socket.getaddrinfo(self._host, port, socket.AF_UNSPEC,
840
socket.SOCK_STREAM, 0, 0)
841
except socket.gaierror, (err_num, err_msg):
842
raise errors.ConnectionError("failed to lookup %s:%d: %s" %
843
(self._host, port, err_msg))
844
# Initialize err in case there are no addresses returned:
845
err = socket.error("no address found for %s" % self._host)
846
for (family, socktype, proto, canonname, sockaddr) in sockaddrs:
848
self._socket = socket.socket(family, socktype, proto)
849
self._socket.setsockopt(socket.IPPROTO_TCP,
850
socket.TCP_NODELAY, 1)
851
self._socket.connect(sockaddr)
852
except socket.error, err:
853
if self._socket is not None:
858
if self._socket is None:
859
# socket errors either have a (string) or (errno, string) as their
861
if type(err.args) is str:
864
err_msg = err.args[1]
533
self._socket = socket.socket()
534
self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
535
result = self._socket.connect_ex((self._host, int(self._port)))
865
537
raise errors.ConnectionError("failed to connect to %s:%d: %s" %
866
(self._host, port, err_msg))
538
(self._host, self._port, os.strerror(result)))
867
539
self._connected = True
869
541
def _flush(self):
870
542
"""See SmartClientStreamMedium._flush().
872
For TCP we do no flushing. We may want to turn off TCP_NODELAY and
544
For TCP we do no flushing. We may want to turn off TCP_NODELAY and
873
545
add a means to do a flush, but that can be done in the future.