32
from bzrlib.lazy_import import lazy_import
33
lazy_import(globals(), """
45
from bzrlib.smart import client, protocol, request, vfs
46
from bzrlib.transport import ssh
48
from bzrlib import osutils
50
# Throughout this module buffer size parameters are either limited to be at
51
# most _MAX_READ_SIZE, or are ignored and _MAX_READ_SIZE is used instead.
52
# For this module's purposes, MAX_SOCKET_CHUNK is a reasonable size for reads
53
# from non-sockets as well.
54
_MAX_READ_SIZE = osutils.MAX_SOCKET_CHUNK
56
def _get_protocol_factory_for_bytes(bytes):
    """Determine the right protocol factory for 'bytes'.

    This will return an appropriate protocol factory depending on the version
    of the protocol being used, as determined by inspecting the given bytes.
    The bytes should have at least one newline byte (i.e. be a whole line),
    otherwise it's possible that a request will be incorrectly identified as
    protocol version one, which is the fallback used when no version marker
    is recognised.

    Typical use would be::

        factory, unused_bytes = _get_protocol_factory_for_bytes(bytes)
        server_protocol = factory(transport, write_func, root_client_path)
        server_protocol.accept_bytes(unused_bytes)

    :param bytes: a str of bytes of the start of the request.
    :returns: 2-tuple of (protocol_factory, unused_bytes).  protocol_factory is
        a callable that takes three args: transport, write_func,
        root_client_path.  unused_bytes are any bytes that were not part of a
        protocol version marker.
    """
    if bytes.startswith(protocol.MESSAGE_VERSION_THREE):
        protocol_factory = protocol.build_server_protocol_three
        bytes = bytes[len(protocol.MESSAGE_VERSION_THREE):]
    elif bytes.startswith(protocol.REQUEST_VERSION_TWO):
        protocol_factory = protocol.SmartServerRequestProtocolTwo
        bytes = bytes[len(protocol.REQUEST_VERSION_TWO):]
    else:
        # No version marker was recognised, so nothing is stripped and the
        # bytes are handed unchanged to the version-one protocol.  This must
        # stay in an else branch: unguarded, it would override the choices
        # made above.
        protocol_factory = protocol.SmartServerRequestProtocolOne
    return protocol_factory, bytes
88
def _get_line(read_bytes_func):
    """Read bytes using read_bytes_func until a newline byte.

    This isn't particularly efficient, so should only be used when the
    expected size of the line is quite short.

    :param read_bytes_func: a callable taking a byte count and returning a
        str.  It is always asked for a single byte, but may return more; an
        empty string is treated as end of input.
    :returns: a tuple of two strs: (line, excess).  line ends with the first
        newline seen and excess holds whatever was read beyond it.  If the
        input ends before a newline arrives, line is everything that was read
        (without a trailing newline) and excess is ''.
    """
    buffered = ''
    newline_pos = -1
    while newline_pos == -1:
        new_bytes = read_bytes_func(1)
        if new_bytes == '':
            # Ran out of bytes before receiving a complete line.
            return buffered, ''
        buffered += new_bytes
        newline_pos = buffered.find('\n')
    line = buffered[:newline_pos + 1]
    excess = buffered[newline_pos + 1:]
    return line, excess
110
class SmartMedium(object):
    """Base class for smart protocol media, both client- and server-side."""

    def __init__(self):
        # Bytes that were read from the medium but not consumed, or None when
        # nothing is pending.  It is never the empty string, because callers
        # of read_bytes could confuse that with EOF.
        self._push_back_buffer = None

    def _push_back(self, bytes):
        """Return unused bytes to the medium for a later read.

        This sets the _push_back_buffer to the given bytes.

        :param bytes: the str to hand back.  An empty string is ignored.
        :raises AssertionError: if bytes pushed back earlier have not been
            read yet.
        """
        if self._push_back_buffer is not None:
            raise AssertionError(
                "_push_back called when self._push_back_buffer is %r"
                % (self._push_back_buffer,))
        if bytes == '':
            # Nothing to hand back.  Storing '' would trip the check in
            # _get_push_back_buffer on the next read.
            return
        self._push_back_buffer = bytes

    def _get_push_back_buffer(self):
        """Return the pushed-back bytes and clear the buffer.

        :raises AssertionError: if the buffer holds the empty string.
        """
        if self._push_back_buffer == '':
            raise AssertionError(
                '%s._push_back_buffer should never be the empty string, '
                'which can be confused with EOF' % (self,))
        bytes = self._push_back_buffer
        self._push_back_buffer = None
        return bytes

    def read_bytes(self, desired_count):
        """Read some bytes from this medium.

        Pushed-back bytes, if any, are returned first and in full, whatever
        desired_count is.  Otherwise the request is capped at _MAX_READ_SIZE
        and passed to _read_bytes.

        :returns: some bytes, possibly more or less than the number requested
            in 'desired_count' depending on the medium.
        """
        if self._push_back_buffer is not None:
            return self._get_push_back_buffer()
        bytes_to_read = min(desired_count, _MAX_READ_SIZE)
        return self._read_bytes(bytes_to_read)

    def _read_bytes(self, count):
        """Read up to count bytes from the medium; subclasses override."""
        raise NotImplementedError(self._read_bytes)

    def _get_line(self):
        """Read bytes from this medium until a newline byte.

        This isn't particularly efficient, so should only be used when the
        expected size of the line is quite short.  Bytes read beyond the
        newline are pushed back for the next read.

        :returns: a string of bytes ending in a newline (byte 0x0A), unless
            the input ended first.
        """
        line, excess = _get_line(self.read_bytes)
        self._push_back(excess)
        return line

    def _report_activity(self, bytes, direction):
        """Notify that this medium has activity.

        Implementations should call this from all methods that actually do IO.
        Be careful that it's not called twice when one IO method is
        implemented in terms of another.

        :param bytes: Number of bytes read or written.
        :param direction: 'read' or 'write' or None.
        """
        ui.ui_factory.report_transport_activity(self, bytes, direction)
178
class SmartServerStreamMedium(SmartMedium):
30
from bzrlib import errors
31
from bzrlib.smart.protocol import SmartServerRequestProtocolOne
34
from bzrlib.transport import ssh
35
except errors.ParamikoNotPresent:
36
# no paramiko. SmartSSHClientMedium will break.
40
class SmartServerStreamMedium(object):
179
41
"""Handles smart commands coming over a stream.
181
43
The stream may be a pipe connected to sshd, or a tcp socket, or an
249
89
"""Called when an unhandled exception from the protocol occurs."""
250
90
raise NotImplementedError(self.terminate_due_to_error)
252
def _read_bytes(self, desired_count):
253
"""Get some bytes from the medium.
255
:param desired_count: number of bytes we want to read.
257
raise NotImplementedError(self._read_bytes)
260
93
class SmartServerSocketStreamMedium(SmartServerStreamMedium):
262
def __init__(self, sock, backing_transport, root_client_path='/'):
95
def __init__(self, sock, backing_transport):
265
98
:param sock: the socket the server will read from. It will be put
266
99
into blocking mode.
268
SmartServerStreamMedium.__init__(
269
self, backing_transport, root_client_path=root_client_path)
101
SmartServerStreamMedium.__init__(self, backing_transport)
270
103
sock.setblocking(True)
271
104
self.socket = sock
273
106
def _serve_one_request_unguarded(self, protocol):
274
107
while protocol.next_read_size():
275
# We can safely try to read large chunks. If there is less data
276
# than MAX_SOCKET_CHUNK ready, the socket will just return a
277
# short read immediately rather than block.
278
bytes = self.read_bytes(osutils.MAX_SOCKET_CHUNK)
282
protocol.accept_bytes(bytes)
284
self._push_back(protocol.unused_data)
286
def _read_bytes(self, desired_count):
287
return osutils.read_bytes_from_socket(
288
self.socket, self._report_activity)
109
protocol.accept_bytes(self.push_back)
112
bytes = self.socket.recv(4096)
116
protocol.accept_bytes(bytes)
118
self.push_back = protocol.excess_buffer
290
120
def terminate_due_to_error(self):
121
"""Called when an unhandled exception from the protocol occurs."""
291
122
# TODO: This should log to a server log file, but no such thing
292
123
# exists yet. Andrew Bennetts 2006-09-29.
293
124
self.socket.close()
294
125
self.finished = True
296
127
def _write_out(self, bytes):
297
tstart = osutils.timer_func()
298
osutils.send_all(self.socket, bytes, self._report_activity)
299
if 'hpss' in debug.debug_flags:
300
thread_id = thread.get_ident()
301
trace.mutter('%12s: [%s] %d bytes to the socket in %.3fs'
302
% ('wrote', thread_id, len(bytes),
303
osutils.timer_func() - tstart))
128
self.socket.sendall(bytes)
306
131
class SmartServerPipeStreamMedium(SmartServerStreamMedium):
461
280
return self._read_bytes(count)
463
282
def _read_bytes(self, count):
464
"""Helper for SmartClientMediumRequest.read_bytes.
283
"""Helper for read_bytes.
466
285
read_bytes checks the state of the request to determine if bytes
467
286
should be read. After that it hands off to _read_bytes to do the
470
By default this forwards to self._medium.read_bytes because we are
471
operating on the medium's stream.
473
return self._medium.read_bytes(count)
476
line = self._read_line()
477
if not line.endswith('\n'):
478
# end of file encountered reading from server
479
raise errors.ConnectionReset(
480
"Unexpected end of message. Please check connectivity "
481
"and permissions, and report a bug if problems persist.")
484
def _read_line(self):
485
"""Helper for SmartClientMediumRequest.read_line.
487
By default this forwards to self._medium._get_line because we are
488
operating on the medium's stream.
490
return self._medium._get_line()
493
class _VfsRefuser(object):
    """An object that refuses all VFS requests.

    Creating an instance installs check_vfs as a 'call' hook on
    client._SmartClient, so every hooked call is checked.
    """

    def __init__(self):
        client._SmartClient.hooks.install_named_hook(
            'call', self.check_vfs, 'vfs refuser')

    def check_vfs(self, params):
        """Reject the call described by params if it is a VFS request.

        :param params: the hook parameters; its 'method' and 'args'
            attributes are read.
        :raises errors.HpssVfsRequestNotAllowed: if the handler registered
            for params.method is a subclass of vfs.VfsRequest.
        """
        try:
            request_method = request.request_handlers.get(params.method)
        except KeyError:
            # A method we don't know about doesn't count as a VFS method.
            # NOTE(review): assumes request_handlers.get raises KeyError for
            # an unregistered name -- confirm against the registry API.
            return
        if issubclass(request_method, vfs.VfsRequest):
            raise errors.HpssVfsRequestNotAllowed(params.method, params.args)
512
class _DebugCounter(object):
    """An object that counts the HPSS calls made to each client medium.

    When a medium is garbage-collected, or failing that when
    bzrlib.global_state exits, the total number of calls made on that medium
    are reported via trace.note.
    """

    def __init__(self):
        # Maps each tracked medium to a dict with the keys 'count',
        # 'vfs_count' and 'medium_repr'.
        self.counts = weakref.WeakKeyDictionary()
        client._SmartClient.hooks.install_named_hook(
            'call', self.increment_call_count, 'hpss call counter')
        bzrlib.global_state.cleanups.add_cleanup(self.flush_all)

    def track(self, medium):
        """Start tracking calls made to a medium.

        This only keeps a weakref to the medium, so shouldn't affect the
        medium's lifetime.
        """
        # Take the repr now; the medium will be gone by the time the counts
        # are reported from a weakref callback.
        medium_repr = repr(medium)
        # Add this medium to the WeakKeyDictionary
        self.counts[medium] = dict(count=0, vfs_count=0,
                                   medium_repr=medium_repr)
        # Weakref callbacks are fired in reverse order of their association
        # with the referenced object.  So we add a weakref *after* adding to
        # the WeakKeyDict so that we can report the value from it before the
        # entry is removed by the WeakKeyDict's own callback.
        # NOTE(review): 'ref' is dropped when this method returns, and a
        # weakref callback only fires while the weakref object itself is
        # alive -- confirm that per-medium reporting at collection time
        # really happens, or keep the weakref somewhere.
        ref = weakref.ref(medium, self.done)

    def increment_call_count(self, params):
        """Hook function: record one call made on params.medium."""
        # Increment the count in the WeakKeyDictionary
        value = self.counts[params.medium]
        value['count'] += 1
        try:
            request_method = request.request_handlers.get(params.method)
        except KeyError:
            # A method we don't know about doesn't count as a VFS method.
            # NOTE(review): assumes request_handlers.get raises KeyError for
            # an unregistered name -- confirm against the registry API.
            return
        if issubclass(request_method, vfs.VfsRequest):
            value['vfs_count'] += 1

    def done(self, ref):
        """Report and reset the counts stored under the key 'ref'."""
        value = self.counts[ref]
        count, vfs_count, medium_repr = (
            value['count'], value['vfs_count'], value['medium_repr'])
        # In case this callback is invoked for the same ref twice (by the
        # weakref callback and by the atexit function), set the call count back
        # to 0 so this item won't be reported twice.
        value['count'] = 0
        value['vfs_count'] = 0
        if count != 0:
            trace.note('HPSS calls: %d (%d vfs) %s',
                       count, vfs_count, medium_repr)

    def flush_all(self):
        """Report the counts of every medium that is still tracked."""
        # Iterate over a snapshot: the dictionary can shrink while we loop
        # if a medium is collected.
        for ref in list(self.counts.keys()):
            self.done(ref)
571
_debug_counter = None
575
class SmartClientMedium(SmartMedium):
289
raise NotImplementedError(self._read_bytes)
292
class SmartClientMedium(object):
576
293
"""Smart client is a medium for sending smart protocol requests over."""
578
def __init__(self, base):
579
super(SmartClientMedium, self).__init__()
581
self._protocol_version_error = None
582
self._protocol_version = None
583
self._done_hello = False
584
# Be optimistic: we assume the remote end can accept new remote
585
# requests until we get an error saying otherwise.
586
# _remote_version_is_before tracks the bzr version the remote side
587
# can be based on what we've seen so far.
588
self._remote_version_is_before = None
589
# Install debug hook function if debug flag is set.
590
if 'hpss' in debug.debug_flags:
591
global _debug_counter
592
if _debug_counter is None:
593
_debug_counter = _DebugCounter()
594
_debug_counter.track(self)
595
if 'hpss_client_no_vfs' in debug.debug_flags:
597
if _vfs_refuser is None:
598
_vfs_refuser = _VfsRefuser()
600
def _is_remote_before(self, version_tuple):
601
"""Is it possible the remote side supports RPCs for a given version?
605
needed_version = (1, 2)
606
if medium._is_remote_before(needed_version):
607
fallback_to_pre_1_2_rpc()
611
except UnknownSmartMethod:
612
medium._remember_remote_is_before(needed_version)
613
fallback_to_pre_1_2_rpc()
615
:seealso: _remember_remote_is_before
617
if self._remote_version_is_before is None:
618
# So far, the remote side seems to support everything
620
return version_tuple >= self._remote_version_is_before
622
def _remember_remote_is_before(self, version_tuple):
623
"""Tell this medium that the remote side is older than the given version.
625
:seealso: _is_remote_before
627
if (self._remote_version_is_before is not None and
628
version_tuple > self._remote_version_is_before):
629
# We have been told that the remote side is older than some version
630
# which is newer than a previously supplied older-than version.
631
# This indicates that some smart verb call is not guarded
632
# appropriately (it should simply not have been tried).
634
"_remember_remote_is_before(%r) called, but "
635
"_remember_remote_is_before(%r) was called previously."
636
, version_tuple, self._remote_version_is_before)
637
if 'hpss' in debug.debug_flags:
638
ui.ui_factory.show_warning(
639
"_remember_remote_is_before(%r) called, but "
640
"_remember_remote_is_before(%r) was called previously."
641
% (version_tuple, self._remote_version_is_before))
643
self._remote_version_is_before = version_tuple
645
def protocol_version(self):
646
"""Find out if 'hello' smart request works."""
647
if self._protocol_version_error is not None:
648
raise self._protocol_version_error
649
if not self._done_hello:
651
medium_request = self.get_request()
652
# Send a 'hello' request in protocol version one, for maximum
653
# backwards compatibility.
654
client_protocol = protocol.SmartClientRequestProtocolOne(medium_request)
655
client_protocol.query_version()
656
self._done_hello = True
657
except errors.SmartProtocolError, e:
658
# Cache the error, just like we would cache a successful
660
self._protocol_version_error = e
664
def should_probe(self):
665
"""Should RemoteBzrDirFormat.probe_transport send a smart request on
668
Some transports are unambiguously smart-only; there's no need to check
669
if the transport is able to carry smart requests, because that's all
670
it is for. In those cases, this method should return False.
672
But some HTTP transports can sometimes fail to carry smart requests,
673
but still be usable for accessing remote bzrdirs via plain file
674
accesses. So for those transports, their media should return True here
675
so that RemoteBzrDirFormat can determine if it is appropriate for that
680
295
def disconnect(self):
681
296
"""If this medium maintains a persistent connection, close it.
683
298
The default implementation does nothing.
686
def remote_path_from_transport(self, transport):
687
"""Convert transport into a path suitable for using in a request.
689
Note that the resulting remote path doesn't encode the host name or
690
anything but path, so it is only safe to use it in requests sent over
691
the medium from the matching transport.
693
medium_base = urlutils.join(self.base, '/')
694
rel_url = urlutils.relative_url(medium_base, transport.base)
695
return urllib.unquote(rel_url)
698
302
class SmartClientStreamMedium(SmartClientMedium):
699
303
"""Stream based medium common class.
757
362
def _read_bytes(self, count):
758
363
"""See SmartClientStreamMedium._read_bytes."""
759
bytes_to_read = min(count, _MAX_READ_SIZE)
760
bytes = self._readable_pipe.read(bytes_to_read)
761
self._report_activity(len(bytes), 'read')
765
class SSHParams(object):
    """A set of parameters for starting a remote bzr via SSH.

    The constructor arguments are stored unchanged on attributes of the same
    names: host, port, username, password and bzr_remote_path.
    """

    def __init__(self, host, port=None, username=None, password=None,
            bzr_remote_path='bzr'):
        """Record the connection parameters.

        :param host: the host to connect to.
        :param port: the port to connect to, or None.
        :param username: the user name to connect as, or None.
        :param password: the password to use, or None.
        :param bzr_remote_path: the bzr command to run on the remote side.
        """
        # host and port must be stored as well: the SSH medium reads
        # ssh_params.host and ssh_params.port when it connects.
        self.host = host
        self.port = port
        self.username = username
        self.password = password
        self.bzr_remote_path = bzr_remote_path
777
class SmartSSHClientMedium(SmartClientStreamMedium):
778
"""A client medium using SSH.
780
It delegates IO to a SmartClientSocketMedium or
781
SmartClientAlreadyConnectedSocketMedium (depending on platform).
784
def __init__(self, base, ssh_params, vendor=None):
785
372
"""Creates a client that will connect on the first use.
787
:param ssh_params: A SSHParams instance.
788
374
:param vendor: An optional override for the ssh vendor to use. See
789
375
bzrlib.transport.ssh for details on ssh vendors.
791
self._real_medium = None
792
self._ssh_params = ssh_params
793
# for the benefit of progress making a short description of this
795
self._scheme = 'bzr+ssh'
796
# SmartClientStreamMedium stores the repr of this object in its
797
# _DebugCounter so we have to store all the values used in our repr
798
# method before calling the super init.
799
SmartClientStreamMedium.__init__(self, base)
377
SmartClientStreamMedium.__init__(self)
378
self._connected = False
380
self._password = password
382
self._username = username
383
self._read_from = None
384
self._ssh_connection = None
800
385
self._vendor = vendor
801
self._ssh_connection = None
804
if self._ssh_params.port is None:
807
maybe_port = ':%s' % self._ssh_params.port
808
return "%s(%s://%s@%s%s/)" % (
809
self.__class__.__name__,
811
self._ssh_params.username,
812
self._ssh_params.host,
386
self._write_to = None
815
388
def _accept_bytes(self, bytes):
816
389
"""See SmartClientStreamMedium.accept_bytes."""
817
390
self._ensure_connection()
818
self._real_medium.accept_bytes(bytes)
391
self._write_to.write(bytes)
820
393
def disconnect(self):
821
394
"""See SmartClientMedium.disconnect()."""
822
if self._real_medium is not None:
823
self._real_medium.disconnect()
824
self._real_medium = None
825
if self._ssh_connection is not None:
826
self._ssh_connection.close()
827
self._ssh_connection = None
395
if not self._connected:
397
self._read_from.close()
398
self._write_to.close()
399
self._ssh_connection.close()
400
self._connected = False
829
402
def _ensure_connection(self):
830
403
"""Connect this medium if not already connected."""
831
if self._real_medium is not None:
406
executable = os.environ.get('BZR_REMOTE_PATH', 'bzr')
833
407
if self._vendor is None:
834
408
vendor = ssh._get_ssh_vendor()
836
410
vendor = self._vendor
837
self._ssh_connection = vendor.connect_ssh(self._ssh_params.username,
838
self._ssh_params.password, self._ssh_params.host,
839
self._ssh_params.port,
840
command=[self._ssh_params.bzr_remote_path, 'serve', '--inet',
841
'--directory=/', '--allow-writes'])
842
io_kind, io_object = self._ssh_connection.get_sock_or_pipes()
843
if io_kind == 'socket':
844
self._real_medium = SmartClientAlreadyConnectedSocketMedium(
845
self.base, io_object)
846
elif io_kind == 'pipes':
847
read_from, write_to = io_object
848
self._real_medium = SmartSimplePipesClientMedium(
849
read_from, write_to, self.base)
851
raise AssertionError(
852
"Unexpected io_kind %r from %r"
853
% (io_kind, self._ssh_connection))
411
self._ssh_connection = vendor.connect_ssh(self._username,
412
self._password, self._host, self._port,
413
command=[executable, 'serve', '--inet', '--directory=/',
415
self._read_from, self._write_to = \
416
self._ssh_connection.get_filelike_channels()
417
self._connected = True
855
419
def _flush(self):
856
420
"""See SmartClientStreamMedium._flush()."""
857
self._real_medium._flush()
421
self._write_to.flush()
859
423
def _read_bytes(self, count):
860
424
"""See SmartClientStreamMedium.read_bytes."""
861
if self._real_medium is None:
425
if not self._connected:
862
426
raise errors.MediumNotConnected(self)
863
return self._real_medium.read_bytes(count)
866
# Port 4155 is the default port for bzr://, registered with IANA.
867
BZR_DEFAULT_INTERFACE = None
868
BZR_DEFAULT_PORT = 4155
871
class SmartClientSocketMedium(SmartClientStreamMedium):
872
"""A client medium using a socket.
427
return self._read_from.read(count)
430
class SmartTCPClientMedium(SmartClientStreamMedium):
431
"""A client medium using TCP."""
874
This class isn't usable directly. Use one of its subclasses instead.
877
def __init__(self, base):
878
SmartClientStreamMedium.__init__(self, base)
433
def __init__(self, host, port):
434
"""Creates a client that will connect on the first use."""
435
SmartClientStreamMedium.__init__(self)
436
self._connected = False
879
439
self._socket = None
880
self._connected = False
882
441
def _accept_bytes(self, bytes):
883
442
"""See SmartClientMedium.accept_bytes."""
884
443
self._ensure_connection()
885
osutils.send_all(self._socket, bytes, self._report_activity)
444
self._socket.sendall(bytes)
446
def disconnect(self):
447
"""See SmartClientMedium.disconnect()."""
448
if not self._connected:
452
self._connected = False
887
454
def _ensure_connection(self):
888
455
"""Connect this medium if not already connected."""
889
raise NotImplementedError(self._ensure_connection)
458
self._socket = socket.socket()
459
self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
460
result = self._socket.connect_ex((self._host, int(self._port)))
462
raise errors.ConnectionError("failed to connect to %s:%d: %s" %
463
(self._host, self._port, os.strerror(result)))
464
self._connected = True
891
466
def _flush(self):
892
467
"""See SmartClientStreamMedium._flush().
894
For sockets we do no flushing. For TCP sockets we may want to turn off
895
TCP_NODELAY and add a means to do a flush, but that can be done in the
469
For TCP we do no flushing. We may want to turn off TCP_NODELAY and
470
add a means to do a flush, but that can be done in the future.
899
473
def _read_bytes(self, count):
900
474
"""See SmartClientMedium.read_bytes."""
901
475
if not self._connected:
902
476
raise errors.MediumNotConnected(self)
903
return osutils.read_bytes_from_socket(
904
self._socket, self._report_activity)
906
def disconnect(self):
907
"""See SmartClientMedium.disconnect()."""
908
if not self._connected:
912
self._connected = False
915
class SmartTCPClientMedium(SmartClientSocketMedium):
    """A client medium that creates a TCP connection."""

    def __init__(self, host, port, base):
        """Creates a client that will connect on the first use.

        :param host: the host name to connect to.
        :param port: the port to connect to, or None for BZR_DEFAULT_PORT.
        :param base: passed through to SmartClientSocketMedium.
        """
        SmartClientSocketMedium.__init__(self, base)
        self._host = host
        self._port = port

    def _ensure_connection(self):
        """Connect this medium if not already connected.

        Every address returned for the host is tried in turn until one
        accepts the connection.

        :raises errors.ConnectionError: if the host cannot be looked up or
            none of its addresses can be connected to.
        """
        # sys.exc_info() is used instead of binding the exception in the
        # except clause so this code does not depend on either spelling of
        # that syntax.
        import sys
        if self._connected:
            return
        if self._port is None:
            port = BZR_DEFAULT_PORT
        else:
            port = int(self._port)
        try:
            sockaddrs = socket.getaddrinfo(self._host, port, socket.AF_UNSPEC,
                socket.SOCK_STREAM, 0, 0)
        except socket.gaierror:
            # gaierror args are (error number, message).
            err_msg = sys.exc_info()[1].args[-1]
            raise errors.ConnectionError("failed to lookup %s:%d: %s" %
                    (self._host, port, err_msg))
        # Initialize err in case there are no addresses returned:
        err = socket.error("no address found for %s" % self._host)
        for (family, socktype, proto, canonname, sockaddr) in sockaddrs:
            try:
                self._socket = socket.socket(family, socktype, proto)
                self._socket.setsockopt(socket.IPPROTO_TCP,
                                        socket.TCP_NODELAY, 1)
                self._socket.connect(sockaddr)
            except socket.error:
                err = sys.exc_info()[1]
                # Don't leak the half-set-up socket; move on to the next
                # address.
                if self._socket is not None:
                    self._socket.close()
                self._socket = None
                continue
            break
        if self._socket is None:
            # socket errors either have a (string) or (errno, string) as their
            # args, and args is always a tuple, so the message is the last
            # element in both cases.  (Testing "type(err.args) is str" never
            # matched, so one-element args raised IndexError here.)
            if err.args:
                err_msg = err.args[-1]
            else:
                err_msg = str(err)
            raise errors.ConnectionError("failed to connect to %s:%d: %s" %
                    (self._host, port, err_msg))
        self._connected = True
964
class SmartClientAlreadyConnectedSocketMedium(SmartClientSocketMedium):
    """A client medium for an already connected socket.

    Note that this class will assume it "owns" the socket, so it will close it
    when its disconnect method is called.
    """

    def __init__(self, base, sock):
        """Wrap an already connected socket.

        :param base: passed through to SmartClientSocketMedium.
        :param sock: the connected socket this medium takes ownership of.
        """
        SmartClientSocketMedium.__init__(self, base)
        # Set these after the base initialiser, which resets both to their
        # "not connected" values.
        self._socket = sock
        self._connected = True

    def _ensure_connection(self):
        # Already connected, by definition!  So nothing to do.
        pass
477
return self._socket.recv(count)
981
480
class SmartClientStreamMediumRequest(SmartClientMediumRequest):