24
24
bzrlib/transport/smart/__init__.py.
30
from bzrlib import errors
31
from bzrlib.smart.protocol import SmartServerRequestProtocolOne
34
from bzrlib.transport import ssh
35
except errors.ParamikoNotPresent:
36
# no paramiko. SmartSSHClientMedium will break.
40
class SmartServerStreamMedium(object):
33
from bzrlib.lazy_import import lazy_import
34
lazy_import(globals(), """
45
from bzrlib.smart import client, protocol, request, vfs
46
from bzrlib.transport import ssh
48
#usually already imported, and getting IllegalScoperReplacer on it here.
49
from bzrlib import osutils
51
# We must not read any more than 64k at a time so we don't risk "no buffer
52
# space available" errors on some platforms. Windows in particular is likely
53
# to give error 10053 or 10055 if we read more than 64k from a socket.
54
_MAX_READ_SIZE = 64 * 1024
57
def _get_protocol_factory_for_bytes(bytes):
58
"""Determine the right protocol factory for 'bytes'.
60
This will return an appropriate protocol factory depending on the version
61
of the protocol being used, as determined by inspecting the given bytes.
62
The bytes should have at least one newline byte (i.e. be a whole line),
63
otherwise it's possible that a request will be incorrectly identified as
66
Typical use would be::
68
factory, unused_bytes = _get_protocol_factory_for_bytes(bytes)
69
server_protocol = factory(transport, write_func, root_client_path)
70
server_protocol.accept_bytes(unused_bytes)
72
:param bytes: a str of bytes of the start of the request.
73
:returns: 2-tuple of (protocol_factory, unused_bytes). protocol_factory is
74
a callable that takes three args: transport, write_func,
75
root_client_path. unused_bytes are any bytes that were not part of a
76
protocol version marker.
78
if bytes.startswith(protocol.MESSAGE_VERSION_THREE):
79
protocol_factory = protocol.build_server_protocol_three
80
bytes = bytes[len(protocol.MESSAGE_VERSION_THREE):]
81
elif bytes.startswith(protocol.REQUEST_VERSION_TWO):
82
protocol_factory = protocol.SmartServerRequestProtocolTwo
83
bytes = bytes[len(protocol.REQUEST_VERSION_TWO):]
85
protocol_factory = protocol.SmartServerRequestProtocolOne
86
return protocol_factory, bytes
89
def _get_line(read_bytes_func):
90
"""Read bytes using read_bytes_func until a newline byte.
92
This isn't particularly efficient, so should only be used when the
93
expected size of the line is quite short.
95
:returns: a tuple of two strs: (line, excess)
99
while newline_pos == -1:
100
new_bytes = read_bytes_func(1)
103
# Ran out of bytes before receiving a complete line.
105
newline_pos = bytes.find('\n')
106
line = bytes[:newline_pos+1]
107
excess = bytes[newline_pos+1:]
111
class SmartMedium(object):
112
"""Base class for smart protocol media, both client- and server-side."""
115
self._push_back_buffer = None
117
def _push_back(self, bytes):
118
"""Return unused bytes to the medium, because they belong to the next
121
This sets the _push_back_buffer to the given bytes.
123
if self._push_back_buffer is not None:
124
raise AssertionError(
125
"_push_back called when self._push_back_buffer is %r"
126
% (self._push_back_buffer,))
129
self._push_back_buffer = bytes
131
def _get_push_back_buffer(self):
132
if self._push_back_buffer == '':
133
raise AssertionError(
134
'%s._push_back_buffer should never be the empty string, '
135
'which can be confused with EOF' % (self,))
136
bytes = self._push_back_buffer
137
self._push_back_buffer = None
140
def read_bytes(self, desired_count):
141
"""Read some bytes from this medium.
143
:returns: some bytes, possibly more or less than the number requested
144
in 'desired_count' depending on the medium.
146
if self._push_back_buffer is not None:
147
return self._get_push_back_buffer()
148
bytes_to_read = min(desired_count, _MAX_READ_SIZE)
149
return self._read_bytes(bytes_to_read)
151
def _read_bytes(self, count):
152
raise NotImplementedError(self._read_bytes)
155
"""Read bytes from this request's response until a newline byte.
157
This isn't particularly efficient, so should only be used when the
158
expected size of the line is quite short.
160
:returns: a string of bytes ending in a newline (byte 0x0A).
162
line, excess = _get_line(self.read_bytes)
163
self._push_back(excess)
166
def _report_activity(self, bytes, direction):
167
"""Notify that this medium has activity.
169
Implementations should call this from all methods that actually do IO.
170
Be careful that it's not called twice, if one method is implemented on
173
:param bytes: Number of bytes read or written.
174
:param direction: 'read' or 'write' or None.
176
ui.ui_factory.report_transport_activity(self, bytes, direction)
179
class SmartServerStreamMedium(SmartMedium):
41
180
"""Handles smart commands coming over a stream.
43
182
The stream may be a pipe connected to sshd, or a tcp socket, or an
89
250
"""Called when an unhandled exception from the protocol occurs."""
90
251
raise NotImplementedError(self.terminate_due_to_error)
253
def _read_bytes(self, desired_count):
254
"""Get some bytes from the medium.
256
:param desired_count: number of bytes we want to read.
258
raise NotImplementedError(self._read_bytes)
93
261
class SmartServerSocketStreamMedium(SmartServerStreamMedium):
95
def __init__(self, sock, backing_transport):
263
def __init__(self, sock, backing_transport, root_client_path='/'):
98
266
:param sock: the socket the server will read from. It will be put
99
267
into blocking mode.
101
SmartServerStreamMedium.__init__(self, backing_transport)
269
SmartServerStreamMedium.__init__(
270
self, backing_transport, root_client_path=root_client_path)
103
271
sock.setblocking(True)
104
272
self.socket = sock
106
274
def _serve_one_request_unguarded(self, protocol):
107
275
while protocol.next_read_size():
109
protocol.accept_bytes(self.push_back)
112
bytes = self.socket.recv(4096)
116
protocol.accept_bytes(bytes)
118
self.push_back = protocol.excess_buffer
276
# We can safely try to read large chunks. If there is less data
277
# than _MAX_READ_SIZE ready, the socket wil just return a short
278
# read immediately rather than block.
279
bytes = self.read_bytes(_MAX_READ_SIZE)
283
protocol.accept_bytes(bytes)
285
self._push_back(protocol.unused_data)
287
def _read_bytes(self, desired_count):
288
return _read_bytes_from_socket(
289
self.socket.recv, desired_count, self._report_activity)
120
291
def terminate_due_to_error(self):
121
"""Called when an unhandled exception from the protocol occurs."""
122
292
# TODO: This should log to a server log file, but no such thing
123
293
# exists yet. Andrew Bennetts 2006-09-29.
294
osutils.until_no_eintr(self.socket.close)
125
295
self.finished = True
127
297
def _write_out(self, bytes):
128
self.socket.sendall(bytes)
298
osutils.send_all(self.socket, bytes, self._report_activity)
131
301
class SmartServerPipeStreamMedium(SmartServerStreamMedium):
280
456
return self._read_bytes(count)
282
458
def _read_bytes(self, count):
283
"""Helper for read_bytes.
459
"""Helper for SmartClientMediumRequest.read_bytes.
285
461
read_bytes checks the state of the request to determing if bytes
286
462
should be read. After that it hands off to _read_bytes to do the
289
raise NotImplementedError(self._read_bytes)
292
class SmartClientMedium(object):
465
By default this forwards to self._medium.read_bytes because we are
466
operating on the medium's stream.
468
return self._medium.read_bytes(count)
471
line = self._read_line()
472
if not line.endswith('\n'):
473
# end of file encountered reading from server
474
raise errors.ConnectionReset(
475
"Unexpected end of message. Please check connectivity "
476
"and permissions, and report a bug if problems persist.")
479
def _read_line(self):
480
"""Helper for SmartClientMediumRequest.read_line.
482
By default this forwards to self._medium._get_line because we are
483
operating on the medium's stream.
485
return self._medium._get_line()
488
class _DebugCounter(object):
489
"""An object that counts the HPSS calls made to each client medium.
491
When a medium is garbage-collected, or failing that when atexit functions
492
are run, the total number of calls made on that medium are reported via
497
self.counts = weakref.WeakKeyDictionary()
498
client._SmartClient.hooks.install_named_hook(
499
'call', self.increment_call_count, 'hpss call counter')
500
atexit.register(self.flush_all)
502
def track(self, medium):
503
"""Start tracking calls made to a medium.
505
This only keeps a weakref to the medium, so shouldn't affect the
508
medium_repr = repr(medium)
509
# Add this medium to the WeakKeyDictionary
510
self.counts[medium] = dict(count=0, vfs_count=0,
511
medium_repr=medium_repr)
512
# Weakref callbacks are fired in reverse order of their association
513
# with the referenced object. So we add a weakref *after* adding to
514
# the WeakKeyDict so that we can report the value from it before the
515
# entry is removed by the WeakKeyDict's own callback.
516
ref = weakref.ref(medium, self.done)
518
def increment_call_count(self, params):
519
# Increment the count in the WeakKeyDictionary
520
value = self.counts[params.medium]
523
request_method = request.request_handlers.get(params.method)
525
# A method we don't know about doesn't count as a VFS method.
527
if issubclass(request_method, vfs.VfsRequest):
528
value['vfs_count'] += 1
531
value = self.counts[ref]
532
count, vfs_count, medium_repr = (
533
value['count'], value['vfs_count'], value['medium_repr'])
534
# In case this callback is invoked for the same ref twice (by the
535
# weakref callback and by the atexit function), set the call count back
536
# to 0 so this item won't be reported twice.
538
value['vfs_count'] = 0
540
trace.note('HPSS calls: %d (%d vfs) %s',
541
count, vfs_count, medium_repr)
544
for ref in list(self.counts.keys()):
547
_debug_counter = None
550
class SmartClientMedium(SmartMedium):
293
551
"""Smart client is a medium for sending smart protocol requests over."""
553
def __init__(self, base):
554
super(SmartClientMedium, self).__init__()
556
self._protocol_version_error = None
557
self._protocol_version = None
558
self._done_hello = False
559
# Be optimistic: we assume the remote end can accept new remote
560
# requests until we get an error saying otherwise.
561
# _remote_version_is_before tracks the bzr version the remote side
562
# can be based on what we've seen so far.
563
self._remote_version_is_before = None
564
# Install debug hook function if debug flag is set.
565
if 'hpss' in debug.debug_flags:
566
global _debug_counter
567
if _debug_counter is None:
568
_debug_counter = _DebugCounter()
569
_debug_counter.track(self)
571
def _is_remote_before(self, version_tuple):
572
"""Is it possible the remote side supports RPCs for a given version?
576
needed_version = (1, 2)
577
if medium._is_remote_before(needed_version):
578
fallback_to_pre_1_2_rpc()
582
except UnknownSmartMethod:
583
medium._remember_remote_is_before(needed_version)
584
fallback_to_pre_1_2_rpc()
586
:seealso: _remember_remote_is_before
588
if self._remote_version_is_before is None:
589
# So far, the remote side seems to support everything
591
return version_tuple >= self._remote_version_is_before
593
def _remember_remote_is_before(self, version_tuple):
594
"""Tell this medium that the remote side is older the given version.
596
:seealso: _is_remote_before
598
if (self._remote_version_is_before is not None and
599
version_tuple > self._remote_version_is_before):
600
# We have been told that the remote side is older than some version
601
# which is newer than a previously supplied older-than version.
602
# This indicates that some smart verb call is not guarded
603
# appropriately (it should simply not have been tried).
604
raise AssertionError(
605
"_remember_remote_is_before(%r) called, but "
606
"_remember_remote_is_before(%r) was called previously."
607
% (version_tuple, self._remote_version_is_before))
608
self._remote_version_is_before = version_tuple
610
def protocol_version(self):
611
"""Find out if 'hello' smart request works."""
612
if self._protocol_version_error is not None:
613
raise self._protocol_version_error
614
if not self._done_hello:
616
medium_request = self.get_request()
617
# Send a 'hello' request in protocol version one, for maximum
618
# backwards compatibility.
619
client_protocol = protocol.SmartClientRequestProtocolOne(medium_request)
620
client_protocol.query_version()
621
self._done_hello = True
622
except errors.SmartProtocolError, e:
623
# Cache the error, just like we would cache a successful
625
self._protocol_version_error = e
629
def should_probe(self):
630
"""Should RemoteBzrDirFormat.probe_transport send a smart request on
633
Some transports are unambiguously smart-only; there's no need to check
634
if the transport is able to carry smart requests, because that's all
635
it is for. In those cases, this method should return False.
637
But some HTTP transports can sometimes fail to carry smart requests,
638
but still be usuable for accessing remote bzrdirs via plain file
639
accesses. So for those transports, their media should return True here
640
so that RemoteBzrDirFormat can determine if it is appropriate for that
295
645
def disconnect(self):
296
646
"""If this medium maintains a persistent connection, close it.
298
648
The default implementation does nothing.
651
def remote_path_from_transport(self, transport):
652
"""Convert transport into a path suitable for using in a request.
654
Note that the resulting remote path doesn't encode the host name or
655
anything but path, so it is only safe to use it in requests sent over
656
the medium from the matching transport.
658
medium_base = urlutils.join(self.base, '/')
659
rel_url = urlutils.relative_url(medium_base, transport.base)
660
return urllib.unquote(rel_url)
302
663
class SmartClientStreamMedium(SmartClientMedium):
303
664
"""Stream based medium common class.
337
699
return SmartClientStreamMediumRequest(self)
339
def read_bytes(self, count):
340
return self._read_bytes(count)
343
702
class SmartSimplePipesClientMedium(SmartClientStreamMedium):
344
703
"""A client medium using simple pipes.
346
705
This client does not manage the pipes: it assumes they will always be open.
349
def __init__(self, readable_pipe, writeable_pipe):
350
SmartClientStreamMedium.__init__(self)
708
def __init__(self, readable_pipe, writeable_pipe, base):
709
SmartClientStreamMedium.__init__(self, base)
351
710
self._readable_pipe = readable_pipe
352
711
self._writeable_pipe = writeable_pipe
354
713
def _accept_bytes(self, bytes):
355
714
"""See SmartClientStreamMedium.accept_bytes."""
356
self._writeable_pipe.write(bytes)
715
osutils.until_no_eintr(self._writeable_pipe.write, bytes)
716
self._report_activity(len(bytes), 'write')
358
718
def _flush(self):
359
719
"""See SmartClientStreamMedium._flush()."""
360
self._writeable_pipe.flush()
720
osutils.until_no_eintr(self._writeable_pipe.flush)
362
722
def _read_bytes(self, count):
363
723
"""See SmartClientStreamMedium._read_bytes."""
364
return self._readable_pipe.read(count)
724
bytes = osutils.until_no_eintr(self._readable_pipe.read, count)
725
self._report_activity(len(bytes), 'read')
367
729
class SmartSSHClientMedium(SmartClientStreamMedium):
368
730
"""A client medium using SSH."""
370
732
def __init__(self, host, port=None, username=None, password=None,
733
base=None, vendor=None, bzr_remote_path=None):
372
734
"""Creates a client that will connect on the first use.
374
736
:param vendor: An optional override for the ssh vendor to use. See
375
737
bzrlib.transport.ssh for details on ssh vendors.
377
SmartClientStreamMedium.__init__(self)
378
739
self._connected = False
379
740
self._host = host
380
741
self._password = password
381
742
self._port = port
382
743
self._username = username
744
# SmartClientStreamMedium stores the repr of this object in its
745
# _DebugCounter so we have to store all the values used in our repr
746
# method before calling the super init.
747
SmartClientStreamMedium.__init__(self, base)
383
748
self._read_from = None
384
749
self._ssh_connection = None
385
750
self._vendor = vendor
386
751
self._write_to = None
752
self._bzr_remote_path = bzr_remote_path
753
# for the benefit of progress making a short description of this
755
self._scheme = 'bzr+ssh'
758
return "%s(connected=%r, username=%r, host=%r, port=%r)" % (
759
self.__class__.__name__,
388
765
def _accept_bytes(self, bytes):
389
766
"""See SmartClientStreamMedium.accept_bytes."""
390
767
self._ensure_connection()
391
self._write_to.write(bytes)
768
osutils.until_no_eintr(self._write_to.write, bytes)
769
self._report_activity(len(bytes), 'write')
393
771
def disconnect(self):
394
772
"""See SmartClientMedium.disconnect()."""
395
773
if not self._connected:
397
self._read_from.close()
398
self._write_to.close()
775
osutils.until_no_eintr(self._read_from.close)
776
osutils.until_no_eintr(self._write_to.close)
399
777
self._ssh_connection.close()
400
778
self._connected = False
455
840
"""Connect this medium if not already connected."""
456
841
if self._connected:
458
self._socket = socket.socket()
459
self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
460
result = self._socket.connect_ex((self._host, int(self._port)))
843
if self._port is None:
844
port = BZR_DEFAULT_PORT
846
port = int(self._port)
848
sockaddrs = socket.getaddrinfo(self._host, port, socket.AF_UNSPEC,
849
socket.SOCK_STREAM, 0, 0)
850
except socket.gaierror, (err_num, err_msg):
851
raise errors.ConnectionError("failed to lookup %s:%d: %s" %
852
(self._host, port, err_msg))
853
# Initialize err in case there are no addresses returned:
854
err = socket.error("no address found for %s" % self._host)
855
for (family, socktype, proto, canonname, sockaddr) in sockaddrs:
857
self._socket = socket.socket(family, socktype, proto)
858
self._socket.setsockopt(socket.IPPROTO_TCP,
859
socket.TCP_NODELAY, 1)
860
self._socket.connect(sockaddr)
861
except socket.error, err:
862
if self._socket is not None:
867
if self._socket is None:
868
# socket errors either have a (string) or (errno, string) as their
870
if type(err.args) is str:
873
err_msg = err.args[1]
462
874
raise errors.ConnectionError("failed to connect to %s:%d: %s" %
463
(self._host, self._port, os.strerror(result)))
875
(self._host, port, err_msg))
464
876
self._connected = True
466
878
def _flush(self):
467
879
"""See SmartClientStreamMedium._flush().
469
For TCP we do no flushing. We may want to turn off TCP_NODELAY and
881
For TCP we do no flushing. We may want to turn off TCP_NODELAY and
470
882
add a means to do a flush, but that can be done in the future.