24
24
bzrlib/transport/smart/__init__.py.
31
from bzrlib import errors
32
from bzrlib.smart.protocol import (
34
SmartServerRequestProtocolOne,
35
SmartServerRequestProtocolTwo,
33
from bzrlib.lazy_import import lazy_import
34
lazy_import(globals(), """
39
from bzrlib.transport import ssh
40
except errors.ParamikoNotPresent:
41
# no paramiko. SmartSSHClientMedium will break.
45
class SmartServerStreamMedium(object):
45
from bzrlib.smart import client, protocol, request, vfs
46
from bzrlib.transport import ssh
48
#usually already imported, and getting IllegalScoperReplacer on it here.
49
from bzrlib import osutils
51
# We must not read any more than 64k at a time so we don't risk "no buffer
52
# space available" errors on some platforms. Windows in particular is likely
53
# to give error 10053 or 10055 if we read more than 64k from a socket.
54
_MAX_READ_SIZE = 64 * 1024
57
def _get_protocol_factory_for_bytes(bytes):
58
"""Determine the right protocol factory for 'bytes'.
60
This will return an appropriate protocol factory depending on the version
61
of the protocol being used, as determined by inspecting the given bytes.
62
The bytes should have at least one newline byte (i.e. be a whole line),
63
otherwise it's possible that a request will be incorrectly identified as
66
Typical use would be::
68
factory, unused_bytes = _get_protocol_factory_for_bytes(bytes)
69
server_protocol = factory(transport, write_func, root_client_path)
70
server_protocol.accept_bytes(unused_bytes)
72
:param bytes: a str of bytes of the start of the request.
73
:returns: 2-tuple of (protocol_factory, unused_bytes). protocol_factory is
74
a callable that takes three args: transport, write_func,
75
root_client_path. unused_bytes are any bytes that were not part of a
76
protocol version marker.
78
if bytes.startswith(protocol.MESSAGE_VERSION_THREE):
79
protocol_factory = protocol.build_server_protocol_three
80
bytes = bytes[len(protocol.MESSAGE_VERSION_THREE):]
81
elif bytes.startswith(protocol.REQUEST_VERSION_TWO):
82
protocol_factory = protocol.SmartServerRequestProtocolTwo
83
bytes = bytes[len(protocol.REQUEST_VERSION_TWO):]
85
protocol_factory = protocol.SmartServerRequestProtocolOne
86
return protocol_factory, bytes
89
def _get_line(read_bytes_func):
90
"""Read bytes using read_bytes_func until a newline byte.
92
This isn't particularly efficient, so should only be used when the
93
expected size of the line is quite short.
95
:returns: a tuple of two strs: (line, excess)
99
while newline_pos == -1:
100
new_bytes = read_bytes_func(1)
103
# Ran out of bytes before receiving a complete line.
105
newline_pos = bytes.find('\n')
106
line = bytes[:newline_pos+1]
107
excess = bytes[newline_pos+1:]
111
class SmartMedium(object):
112
"""Base class for smart protocol media, both client- and server-side."""
115
self._push_back_buffer = None
117
def _push_back(self, bytes):
118
"""Return unused bytes to the medium, because they belong to the next
121
This sets the _push_back_buffer to the given bytes.
123
if self._push_back_buffer is not None:
124
raise AssertionError(
125
"_push_back called when self._push_back_buffer is %r"
126
% (self._push_back_buffer,))
129
self._push_back_buffer = bytes
131
def _get_push_back_buffer(self):
132
if self._push_back_buffer == '':
133
raise AssertionError(
134
'%s._push_back_buffer should never be the empty string, '
135
'which can be confused with EOF' % (self,))
136
bytes = self._push_back_buffer
137
self._push_back_buffer = None
140
def read_bytes(self, desired_count):
141
"""Read some bytes from this medium.
143
:returns: some bytes, possibly more or less than the number requested
144
in 'desired_count' depending on the medium.
146
if self._push_back_buffer is not None:
147
return self._get_push_back_buffer()
148
bytes_to_read = min(desired_count, _MAX_READ_SIZE)
149
return self._read_bytes(bytes_to_read)
151
def _read_bytes(self, count):
152
raise NotImplementedError(self._read_bytes)
155
"""Read bytes from this request's response until a newline byte.
157
This isn't particularly efficient, so should only be used when the
158
expected size of the line is quite short.
160
:returns: a string of bytes ending in a newline (byte 0x0A).
162
line, excess = _get_line(self.read_bytes)
163
self._push_back(excess)
166
def _report_activity(self, bytes, direction):
167
"""Notify that this medium has activity.
169
Implementations should call this from all methods that actually do IO.
170
Be careful that it's not called twice, if one method is implemented on
173
:param bytes: Number of bytes read or written.
174
:param direction: 'read' or 'write' or None.
176
ui.ui_factory.report_transport_activity(self, bytes, direction)
179
class SmartServerStreamMedium(SmartMedium):
46
180
"""Handles smart commands coming over a stream.
48
182
The stream may be a pipe connected to sshd, or a tcp socket, or an
113
250
"""Called when an unhandled exception from the protocol occurs."""
114
251
raise NotImplementedError(self.terminate_due_to_error)
116
def _get_bytes(self, desired_count):
253
def _read_bytes(self, desired_count):
117
254
"""Get some bytes from the medium.
119
256
:param desired_count: number of bytes we want to read.
121
raise NotImplementedError(self._get_bytes)
124
"""Read bytes from this request's response until a newline byte.
126
This isn't particularly efficient, so should only be used when the
127
expected size of the line is quite short.
129
:returns: a string of bytes ending in a newline (byte 0x0A).
131
# XXX: this duplicates SmartClientRequestProtocolOne._recv_tuple
133
while not line or line[-1] != '\n':
134
new_char = self._get_bytes(1)
137
# Ran out of bytes before receiving a complete line.
258
raise NotImplementedError(self._read_bytes)
142
261
class SmartServerSocketStreamMedium(SmartServerStreamMedium):
144
def __init__(self, sock, backing_transport):
263
def __init__(self, sock, backing_transport, root_client_path='/'):
147
266
:param sock: the socket the server will read from. It will be put
148
267
into blocking mode.
150
SmartServerStreamMedium.__init__(self, backing_transport)
269
SmartServerStreamMedium.__init__(
270
self, backing_transport, root_client_path=root_client_path)
152
271
sock.setblocking(True)
153
272
self.socket = sock
155
274
def _serve_one_request_unguarded(self, protocol):
156
275
while protocol.next_read_size():
158
protocol.accept_bytes(self.push_back)
161
bytes = self._get_bytes(4096)
165
protocol.accept_bytes(bytes)
167
self.push_back = protocol.excess_buffer
169
def _get_bytes(self, desired_count):
170
# We ignore the desired_count because on sockets it's more efficient to
172
return self.socket.recv(4096)
276
# We can safely try to read large chunks. If there is less data
277
# than _MAX_READ_SIZE ready, the socket wil just return a short
278
# read immediately rather than block.
279
bytes = self.read_bytes(_MAX_READ_SIZE)
283
protocol.accept_bytes(bytes)
285
self._push_back(protocol.unused_data)
287
def _read_bytes(self, desired_count):
288
return _read_bytes_from_socket(
289
self.socket.recv, desired_count, self._report_activity)
174
291
def terminate_due_to_error(self):
175
"""Called when an unhandled exception from the protocol occurs."""
176
292
# TODO: This should log to a server log file, but no such thing
177
293
# exists yet. Andrew Bennetts 2006-09-29.
294
osutils.until_no_eintr(self.socket.close)
179
295
self.finished = True
181
297
def _write_out(self, bytes):
182
self.socket.sendall(bytes)
298
osutils.send_all(self.socket, bytes, self._report_activity)
185
301
class SmartServerPipeStreamMedium(SmartServerStreamMedium):
337
456
return self._read_bytes(count)
339
458
def _read_bytes(self, count):
340
"""Helper for read_bytes.
459
"""Helper for SmartClientMediumRequest.read_bytes.
342
461
read_bytes checks the state of the request to determing if bytes
343
462
should be read. After that it hands off to _read_bytes to do the
465
By default this forwards to self._medium.read_bytes because we are
466
operating on the medium's stream.
346
raise NotImplementedError(self._read_bytes)
468
return self._medium.read_bytes(count)
348
470
def read_line(self):
349
"""Read bytes from this request's response until a newline byte.
351
This isn't particularly efficient, so should only be used when the
352
expected size of the line is quite short.
354
:returns: a string of bytes ending in a newline (byte 0x0A).
356
# XXX: this duplicates SmartClientRequestProtocolOne._recv_tuple
358
while not line or line[-1] != '\n':
359
new_char = self.read_bytes(1)
362
raise errors.SmartProtocolError(
363
'unexpected end of file reading from server')
471
line = self._read_line()
472
if not line.endswith('\n'):
473
# end of file encountered reading from server
474
raise errors.ConnectionReset(
475
"Unexpected end of message. Please check connectivity "
476
"and permissions, and report a bug if problems persist.")
367
class SmartClientMedium(object):
479
def _read_line(self):
480
"""Helper for SmartClientMediumRequest.read_line.
482
By default this forwards to self._medium._get_line because we are
483
operating on the medium's stream.
485
return self._medium._get_line()
488
class _DebugCounter(object):
489
"""An object that counts the HPSS calls made to each client medium.
491
When a medium is garbage-collected, or failing that when atexit functions
492
are run, the total number of calls made on that medium are reported via
497
self.counts = weakref.WeakKeyDictionary()
498
client._SmartClient.hooks.install_named_hook(
499
'call', self.increment_call_count, 'hpss call counter')
500
atexit.register(self.flush_all)
502
def track(self, medium):
503
"""Start tracking calls made to a medium.
505
This only keeps a weakref to the medium, so shouldn't affect the
508
medium_repr = repr(medium)
509
# Add this medium to the WeakKeyDictionary
510
self.counts[medium] = dict(count=0, vfs_count=0,
511
medium_repr=medium_repr)
512
# Weakref callbacks are fired in reverse order of their association
513
# with the referenced object. So we add a weakref *after* adding to
514
# the WeakKeyDict so that we can report the value from it before the
515
# entry is removed by the WeakKeyDict's own callback.
516
ref = weakref.ref(medium, self.done)
518
def increment_call_count(self, params):
519
# Increment the count in the WeakKeyDictionary
520
value = self.counts[params.medium]
523
request_method = request.request_handlers.get(params.method)
525
# A method we don't know about doesn't count as a VFS method.
527
if issubclass(request_method, vfs.VfsRequest):
528
value['vfs_count'] += 1
531
value = self.counts[ref]
532
count, vfs_count, medium_repr = (
533
value['count'], value['vfs_count'], value['medium_repr'])
534
# In case this callback is invoked for the same ref twice (by the
535
# weakref callback and by the atexit function), set the call count back
536
# to 0 so this item won't be reported twice.
538
value['vfs_count'] = 0
540
trace.note('HPSS calls: %d (%d vfs) %s',
541
count, vfs_count, medium_repr)
544
for ref in list(self.counts.keys()):
547
_debug_counter = None
550
class SmartClientMedium(SmartMedium):
368
551
"""Smart client is a medium for sending smart protocol requests over."""
553
def __init__(self, base):
554
super(SmartClientMedium, self).__init__()
556
self._protocol_version_error = None
557
self._protocol_version = None
558
self._done_hello = False
559
# Be optimistic: we assume the remote end can accept new remote
560
# requests until we get an error saying otherwise.
561
# _remote_version_is_before tracks the bzr version the remote side
562
# can be based on what we've seen so far.
563
self._remote_version_is_before = None
564
# Install debug hook function if debug flag is set.
565
if 'hpss' in debug.debug_flags:
566
global _debug_counter
567
if _debug_counter is None:
568
_debug_counter = _DebugCounter()
569
_debug_counter.track(self)
571
def _is_remote_before(self, version_tuple):
572
"""Is it possible the remote side supports RPCs for a given version?
576
needed_version = (1, 2)
577
if medium._is_remote_before(needed_version):
578
fallback_to_pre_1_2_rpc()
582
except UnknownSmartMethod:
583
medium._remember_remote_is_before(needed_version)
584
fallback_to_pre_1_2_rpc()
586
:seealso: _remember_remote_is_before
588
if self._remote_version_is_before is None:
589
# So far, the remote side seems to support everything
591
return version_tuple >= self._remote_version_is_before
593
def _remember_remote_is_before(self, version_tuple):
594
"""Tell this medium that the remote side is older the given version.
596
:seealso: _is_remote_before
598
if (self._remote_version_is_before is not None and
599
version_tuple > self._remote_version_is_before):
600
# We have been told that the remote side is older than some version
601
# which is newer than a previously supplied older-than version.
602
# This indicates that some smart verb call is not guarded
603
# appropriately (it should simply not have been tried).
604
raise AssertionError(
605
"_remember_remote_is_before(%r) called, but "
606
"_remember_remote_is_before(%r) was called previously."
607
% (version_tuple, self._remote_version_is_before))
608
self._remote_version_is_before = version_tuple
610
def protocol_version(self):
611
"""Find out if 'hello' smart request works."""
612
if self._protocol_version_error is not None:
613
raise self._protocol_version_error
614
if not self._done_hello:
616
medium_request = self.get_request()
617
# Send a 'hello' request in protocol version one, for maximum
618
# backwards compatibility.
619
client_protocol = protocol.SmartClientRequestProtocolOne(medium_request)
620
client_protocol.query_version()
621
self._done_hello = True
622
except errors.SmartProtocolError, e:
623
# Cache the error, just like we would cache a successful
625
self._protocol_version_error = e
629
def should_probe(self):
630
"""Should RemoteBzrDirFormat.probe_transport send a smart request on
633
Some transports are unambiguously smart-only; there's no need to check
634
if the transport is able to carry smart requests, because that's all
635
it is for. In those cases, this method should return False.
637
But some HTTP transports can sometimes fail to carry smart requests,
638
but still be usuable for accessing remote bzrdirs via plain file
639
accesses. So for those transports, their media should return True here
640
so that RemoteBzrDirFormat can determine if it is appropriate for that
370
645
def disconnect(self):
371
646
"""If this medium maintains a persistent connection, close it.
373
648
The default implementation does nothing.
651
def remote_path_from_transport(self, transport):
652
"""Convert transport into a path suitable for using in a request.
654
Note that the resulting remote path doesn't encode the host name or
655
anything but path, so it is only safe to use it in requests sent over
656
the medium from the matching transport.
658
medium_base = urlutils.join(self.base, '/')
659
rel_url = urlutils.relative_url(medium_base, transport.base)
660
return urllib.unquote(rel_url)
377
663
class SmartClientStreamMedium(SmartClientMedium):
378
664
"""Stream based medium common class.
412
699
return SmartClientStreamMediumRequest(self)
414
def read_bytes(self, count):
415
return self._read_bytes(count)
418
702
class SmartSimplePipesClientMedium(SmartClientStreamMedium):
419
703
"""A client medium using simple pipes.
421
705
This client does not manage the pipes: it assumes they will always be open.
424
def __init__(self, readable_pipe, writeable_pipe):
425
SmartClientStreamMedium.__init__(self)
708
def __init__(self, readable_pipe, writeable_pipe, base):
709
SmartClientStreamMedium.__init__(self, base)
426
710
self._readable_pipe = readable_pipe
427
711
self._writeable_pipe = writeable_pipe
429
713
def _accept_bytes(self, bytes):
430
714
"""See SmartClientStreamMedium.accept_bytes."""
431
self._writeable_pipe.write(bytes)
715
osutils.until_no_eintr(self._writeable_pipe.write, bytes)
716
self._report_activity(len(bytes), 'write')
433
718
def _flush(self):
434
719
"""See SmartClientStreamMedium._flush()."""
435
self._writeable_pipe.flush()
720
osutils.until_no_eintr(self._writeable_pipe.flush)
437
722
def _read_bytes(self, count):
438
723
"""See SmartClientStreamMedium._read_bytes."""
439
return self._readable_pipe.read(count)
724
bytes = osutils.until_no_eintr(self._readable_pipe.read, count)
725
self._report_activity(len(bytes), 'read')
442
729
class SmartSSHClientMedium(SmartClientStreamMedium):
443
730
"""A client medium using SSH."""
445
732
def __init__(self, host, port=None, username=None, password=None,
733
base=None, vendor=None, bzr_remote_path=None):
447
734
"""Creates a client that will connect on the first use.
449
736
:param vendor: An optional override for the ssh vendor to use. See
450
737
bzrlib.transport.ssh for details on ssh vendors.
452
SmartClientStreamMedium.__init__(self)
453
739
self._connected = False
454
740
self._host = host
455
741
self._password = password
456
742
self._port = port
457
743
self._username = username
744
# SmartClientStreamMedium stores the repr of this object in its
745
# _DebugCounter so we have to store all the values used in our repr
746
# method before calling the super init.
747
SmartClientStreamMedium.__init__(self, base)
458
748
self._read_from = None
459
749
self._ssh_connection = None
460
750
self._vendor = vendor
461
751
self._write_to = None
752
self._bzr_remote_path = bzr_remote_path
753
# for the benefit of progress making a short description of this
755
self._scheme = 'bzr+ssh'
758
return "%s(connected=%r, username=%r, host=%r, port=%r)" % (
759
self.__class__.__name__,
463
765
def _accept_bytes(self, bytes):
464
766
"""See SmartClientStreamMedium.accept_bytes."""
465
767
self._ensure_connection()
466
self._write_to.write(bytes)
768
osutils.until_no_eintr(self._write_to.write, bytes)
769
self._report_activity(len(bytes), 'write')
468
771
def disconnect(self):
469
772
"""See SmartClientMedium.disconnect()."""
470
773
if not self._connected:
472
self._read_from.close()
473
self._write_to.close()
775
osutils.until_no_eintr(self._read_from.close)
776
osutils.until_no_eintr(self._write_to.close)
474
777
self._ssh_connection.close()
475
778
self._connected = False
530
840
"""Connect this medium if not already connected."""
531
841
if self._connected:
533
self._socket = socket.socket()
534
self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
535
result = self._socket.connect_ex((self._host, int(self._port)))
843
if self._port is None:
844
port = BZR_DEFAULT_PORT
846
port = int(self._port)
848
sockaddrs = socket.getaddrinfo(self._host, port, socket.AF_UNSPEC,
849
socket.SOCK_STREAM, 0, 0)
850
except socket.gaierror, (err_num, err_msg):
851
raise errors.ConnectionError("failed to lookup %s:%d: %s" %
852
(self._host, port, err_msg))
853
# Initialize err in case there are no addresses returned:
854
err = socket.error("no address found for %s" % self._host)
855
for (family, socktype, proto, canonname, sockaddr) in sockaddrs:
857
self._socket = socket.socket(family, socktype, proto)
858
self._socket.setsockopt(socket.IPPROTO_TCP,
859
socket.TCP_NODELAY, 1)
860
self._socket.connect(sockaddr)
861
except socket.error, err:
862
if self._socket is not None:
867
if self._socket is None:
868
# socket errors either have a (string) or (errno, string) as their
870
if type(err.args) is str:
873
err_msg = err.args[1]
537
874
raise errors.ConnectionError("failed to connect to %s:%d: %s" %
538
(self._host, self._port, os.strerror(result)))
875
(self._host, port, err_msg))
539
876
self._connected = True
541
878
def _flush(self):
542
879
"""See SmartClientStreamMedium._flush().
544
For TCP we do no flushing. We may want to turn off TCP_NODELAY and
881
For TCP we do no flushing. We may want to turn off TCP_NODELAY and
545
882
add a means to do a flush, but that can be done in the future.