24
24
bzrlib/transport/smart/__init__.py.
27
from __future__ import absolute_import
35
from bzrlib.lazy_import import lazy_import
36
lazy_import(globals(), """
31
from bzrlib import errors
32
from bzrlib.smart.protocol import (
34
SmartServerRequestProtocolOne,
35
SmartServerRequestProtocolTwo,
49
from bzrlib.i18n import gettext
50
from bzrlib.smart import client, protocol, request, signals, vfs
51
from bzrlib.transport import ssh
53
from bzrlib import osutils
55
# Throughout this module buffer size parameters are either limited to be at
56
# most _MAX_READ_SIZE, or are ignored and _MAX_READ_SIZE is used instead.
57
# For this module's purposes, MAX_SOCKET_CHUNK is a reasonable size for reads
58
# from non-sockets as well.
59
_MAX_READ_SIZE = osutils.MAX_SOCKET_CHUNK
61
def _get_protocol_factory_for_bytes(bytes):
62
"""Determine the right protocol factory for 'bytes'.
64
This will return an appropriate protocol factory depending on the version
65
of the protocol being used, as determined by inspecting the given bytes.
66
The bytes should have at least one newline byte (i.e. be a whole line),
67
otherwise it's possible that a request will be incorrectly identified as
70
Typical use would be::
72
factory, unused_bytes = _get_protocol_factory_for_bytes(bytes)
73
server_protocol = factory(transport, write_func, root_client_path)
74
server_protocol.accept_bytes(unused_bytes)
76
:param bytes: a str of bytes of the start of the request.
77
:returns: 2-tuple of (protocol_factory, unused_bytes). protocol_factory is
78
a callable that takes three args: transport, write_func,
79
root_client_path. unused_bytes are any bytes that were not part of a
80
protocol version marker.
82
if bytes.startswith(protocol.MESSAGE_VERSION_THREE):
83
protocol_factory = protocol.build_server_protocol_three
84
bytes = bytes[len(protocol.MESSAGE_VERSION_THREE):]
85
elif bytes.startswith(protocol.REQUEST_VERSION_TWO):
86
protocol_factory = protocol.SmartServerRequestProtocolTwo
87
bytes = bytes[len(protocol.REQUEST_VERSION_TWO):]
89
protocol_factory = protocol.SmartServerRequestProtocolOne
90
return protocol_factory, bytes
93
def _get_line(read_bytes_func):
94
"""Read bytes using read_bytes_func until a newline byte.
96
This isn't particularly efficient, so should only be used when the
97
expected size of the line is quite short.
99
:returns: a tuple of two strs: (line, excess)
103
while newline_pos == -1:
104
new_bytes = read_bytes_func(1)
107
# Ran out of bytes before receiving a complete line.
109
newline_pos = bytes.find('\n')
110
line = bytes[:newline_pos+1]
111
excess = bytes[newline_pos+1:]
115
class SmartMedium(object):
116
"""Base class for smart protocol media, both client- and server-side."""
119
self._push_back_buffer = None
121
def _push_back(self, bytes):
122
"""Return unused bytes to the medium, because they belong to the next
125
This sets the _push_back_buffer to the given bytes.
127
if self._push_back_buffer is not None:
128
raise AssertionError(
129
"_push_back called when self._push_back_buffer is %r"
130
% (self._push_back_buffer,))
133
self._push_back_buffer = bytes
135
def _get_push_back_buffer(self):
136
if self._push_back_buffer == '':
137
raise AssertionError(
138
'%s._push_back_buffer should never be the empty string, '
139
'which can be confused with EOF' % (self,))
140
bytes = self._push_back_buffer
141
self._push_back_buffer = None
144
def read_bytes(self, desired_count):
145
"""Read some bytes from this medium.
147
:returns: some bytes, possibly more or less than the number requested
148
in 'desired_count' depending on the medium.
150
if self._push_back_buffer is not None:
151
return self._get_push_back_buffer()
152
bytes_to_read = min(desired_count, _MAX_READ_SIZE)
153
return self._read_bytes(bytes_to_read)
155
def _read_bytes(self, count):
156
raise NotImplementedError(self._read_bytes)
159
"""Read bytes from this request's response until a newline byte.
161
This isn't particularly efficient, so should only be used when the
162
expected size of the line is quite short.
164
:returns: a string of bytes ending in a newline (byte 0x0A).
166
line, excess = _get_line(self.read_bytes)
167
self._push_back(excess)
170
def _report_activity(self, bytes, direction):
171
"""Notify that this medium has activity.
173
Implementations should call this from all methods that actually do IO.
174
Be careful that it's not called twice, if one method is implemented on
177
:param bytes: Number of bytes read or written.
178
:param direction: 'read' or 'write' or None.
180
ui.ui_factory.report_transport_activity(self, bytes, direction)
183
_bad_file_descriptor = (errno.EBADF,)
184
if sys.platform == 'win32':
185
# Given on Windows if you pass a closed socket to select.select. Probably
186
# also given if you pass a file handle to select.
188
_bad_file_descriptor += (WSAENOTSOCK,)
191
class SmartServerStreamMedium(SmartMedium):
39
from bzrlib.transport import ssh
40
except errors.ParamikoNotPresent:
41
# no paramiko. SmartSSHClientMedium will break.
45
class SmartServerStreamMedium(object):
192
46
"""Handles smart commands coming over a stream.
194
48
The stream may be a pipe connected to sshd, or a tcp socket, or an
276
87
:returns: a SmartServerRequestProtocol.
278
self._wait_for_bytes_with_timeout(self._client_timeout)
280
# We're stopping, so don't try to do any more work
89
# Identify the protocol version.
282
90
bytes = self._get_line()
283
protocol_factory, unused_bytes = _get_protocol_factory_for_bytes(bytes)
284
protocol = protocol_factory(
285
self.backing_transport, self._write_out, self.root_client_path)
286
protocol.accept_bytes(unused_bytes)
91
if bytes.startswith(REQUEST_VERSION_TWO):
92
protocol_class = SmartServerRequestProtocolTwo
93
bytes = bytes[len(REQUEST_VERSION_TWO):]
95
protocol_class = SmartServerRequestProtocolOne
96
protocol = protocol_class(self.backing_transport, self._write_out)
97
protocol.accept_bytes(bytes)
289
def _wait_on_descriptor(self, fd, timeout_seconds):
290
"""select() on a file descriptor, waiting for nonblocking read()
292
This will raise a ConnectionTimeout exception if we do not get a
293
readable handle before timeout_seconds.
296
t_end = self._timer() + timeout_seconds
297
poll_timeout = min(timeout_seconds, self._client_poll_timeout)
299
while not rs and not xs and self._timer() < t_end:
303
rs, _, xs = select.select([fd], [], [fd], poll_timeout)
304
except (select.error, socket.error) as e:
305
err = getattr(e, 'errno', None)
306
if err is None and getattr(e, 'args', None) is not None:
307
# select.error doesn't have 'errno', it just has args[0]
309
if err in _bad_file_descriptor:
310
return # Not a socket indicates read() will fail
311
elif err == errno.EINTR:
312
# Interrupted, keep looping.
317
raise errors.ConnectionTimeout('disconnecting client after %.1f seconds'
318
% (timeout_seconds,))
320
100
def _serve_one_request(self, protocol):
321
101
"""Read one request from input, process, send back a response.
323
103
:param protocol: a SmartServerRequestProtocol.
328
106
self._serve_one_request_unguarded(protocol)
329
107
except KeyboardInterrupt:
335
113
"""Called when an unhandled exception from the protocol occurs."""
336
114
raise NotImplementedError(self.terminate_due_to_error)
338
def _read_bytes(self, desired_count):
116
def _get_bytes(self, desired_count):
339
117
"""Get some bytes from the medium.
341
119
:param desired_count: number of bytes we want to read.
343
raise NotImplementedError(self._read_bytes)
121
raise NotImplementedError(self._get_bytes)
124
"""Read bytes from this request's response until a newline byte.
126
This isn't particularly efficient, so should only be used when the
127
expected size of the line is quite short.
129
:returns: a string of bytes ending in a newline (byte 0x0A).
131
# XXX: this duplicates SmartClientRequestProtocolOne._recv_tuple
133
while not line or line[-1] != '\n':
134
new_char = self._get_bytes(1)
137
# Ran out of bytes before receiving a complete line.
346
142
class SmartServerSocketStreamMedium(SmartServerStreamMedium):
348
def __init__(self, sock, backing_transport, root_client_path='/',
144
def __init__(self, sock, backing_transport):
352
147
:param sock: the socket the server will read from. It will be put
353
148
into blocking mode.
355
SmartServerStreamMedium.__init__(
356
self, backing_transport, root_client_path=root_client_path,
150
SmartServerStreamMedium.__init__(self, backing_transport)
358
152
sock.setblocking(True)
359
153
self.socket = sock
360
# Get the getpeername now, as we might be closed later when we care.
362
self._client_info = sock.getpeername()
364
self._client_info = '<unknown>'
367
return '%s(client=%s)' % (self.__class__.__name__, self._client_info)
370
return '%s.%s(client=%s)' % (self.__module__, self.__class__.__name__,
373
155
def _serve_one_request_unguarded(self, protocol):
374
156
while protocol.next_read_size():
375
# We can safely try to read large chunks. If there is less data
376
# than MAX_SOCKET_CHUNK ready, the socket will just return a
377
# short read immediately rather than block.
378
bytes = self.read_bytes(osutils.MAX_SOCKET_CHUNK)
382
protocol.accept_bytes(bytes)
384
self._push_back(protocol.unused_data)
386
def _disconnect_client(self):
387
"""Close the current connection. We stopped due to a timeout/etc."""
390
def _wait_for_bytes_with_timeout(self, timeout_seconds):
391
"""Wait for more bytes to be read, but timeout if none available.
393
This allows us to detect idle connections, and stop trying to read from
394
them, without setting the socket itself to non-blocking. This also
395
allows us to specify when we watch for idle timeouts.
397
:return: None, this will raise ConnectionTimeout if we time out before
400
return self._wait_on_descriptor(self.socket, timeout_seconds)
402
def _read_bytes(self, desired_count):
403
return osutils.read_bytes_from_socket(
404
self.socket, self._report_activity)
158
protocol.accept_bytes(self.push_back)
161
bytes = self._get_bytes(4096)
165
protocol.accept_bytes(bytes)
167
self.push_back = protocol.excess_buffer
169
def _get_bytes(self, desired_count):
170
# We ignore the desired_count because on sockets it's more efficient to
172
return self.socket.recv(4096)
406
174
def terminate_due_to_error(self):
175
"""Called when an unhandled exception from the protocol occurs."""
407
176
# TODO: This should log to a server log file, but no such thing
408
177
# exists yet. Andrew Bennetts 2006-09-29.
409
178
self.socket.close()
410
179
self.finished = True
412
181
def _write_out(self, bytes):
413
tstart = osutils.timer_func()
414
osutils.send_all(self.socket, bytes, self._report_activity)
415
if 'hpss' in debug.debug_flags:
416
thread_id = thread.get_ident()
417
trace.mutter('%12s: [%s] %d bytes to the socket in %.3fs'
418
% ('wrote', thread_id, len(bytes),
419
osutils.timer_func() - tstart))
182
self.socket.sendall(bytes)
422
185
class SmartServerPipeStreamMedium(SmartServerStreamMedium):
424
def __init__(self, in_file, out_file, backing_transport, timeout=None):
187
def __init__(self, in_file, out_file, backing_transport):
425
188
"""Construct new server.
427
190
:param in_file: Python file from which requests can be read.
428
191
:param out_file: Python file to write responses.
429
192
:param backing_transport: Transport for the directory served.
431
SmartServerStreamMedium.__init__(self, backing_transport,
194
SmartServerStreamMedium.__init__(self, backing_transport)
433
195
if sys.platform == 'win32':
434
196
# force binary mode for files
610
337
return self._read_bytes(count)
612
339
def _read_bytes(self, count):
613
"""Helper for SmartClientMediumRequest.read_bytes.
340
"""Helper for read_bytes.
615
342
read_bytes checks the state of the request to determing if bytes
616
343
should be read. After that it hands off to _read_bytes to do the
619
By default this forwards to self._medium.read_bytes because we are
620
operating on the medium's stream.
622
return self._medium.read_bytes(count)
346
raise NotImplementedError(self._read_bytes)
624
348
def read_line(self):
625
line = self._read_line()
626
if not line.endswith('\n'):
627
# end of file encountered reading from server
628
raise errors.ConnectionReset(
629
"Unexpected end of message. Please check connectivity "
630
"and permissions, and report a bug if problems persist.")
349
"""Read bytes from this request's response until a newline byte.
351
This isn't particularly efficient, so should only be used when the
352
expected size of the line is quite short.
354
:returns: a string of bytes ending in a newline (byte 0x0A).
356
# XXX: this duplicates SmartClientRequestProtocolOne._recv_tuple
358
while not line or line[-1] != '\n':
359
new_char = self.read_bytes(1)
362
raise errors.SmartProtocolError(
363
'unexpected end of file reading from server')
633
def _read_line(self):
634
"""Helper for SmartClientMediumRequest.read_line.
636
By default this forwards to self._medium._get_line because we are
637
operating on the medium's stream.
639
return self._medium._get_line()
642
class _VfsRefuser(object):
643
"""An object that refuses all VFS requests.
648
client._SmartClient.hooks.install_named_hook(
649
'call', self.check_vfs, 'vfs refuser')
651
def check_vfs(self, params):
653
request_method = request.request_handlers.get(params.method)
655
# A method we don't know about doesn't count as a VFS method.
657
if issubclass(request_method, vfs.VfsRequest):
658
raise errors.HpssVfsRequestNotAllowed(params.method, params.args)
661
class _DebugCounter(object):
662
"""An object that counts the HPSS calls made to each client medium.
664
When a medium is garbage-collected, or failing that when
665
bzrlib.global_state exits, the total number of calls made on that medium
666
are reported via trace.note.
670
self.counts = weakref.WeakKeyDictionary()
671
client._SmartClient.hooks.install_named_hook(
672
'call', self.increment_call_count, 'hpss call counter')
673
bzrlib.global_state.cleanups.add_cleanup(self.flush_all)
675
def track(self, medium):
676
"""Start tracking calls made to a medium.
678
This only keeps a weakref to the medium, so shouldn't affect the
681
medium_repr = repr(medium)
682
# Add this medium to the WeakKeyDictionary
683
self.counts[medium] = dict(count=0, vfs_count=0,
684
medium_repr=medium_repr)
685
# Weakref callbacks are fired in reverse order of their association
686
# with the referenced object. So we add a weakref *after* adding to
687
# the WeakKeyDict so that we can report the value from it before the
688
# entry is removed by the WeakKeyDict's own callback.
689
ref = weakref.ref(medium, self.done)
691
def increment_call_count(self, params):
692
# Increment the count in the WeakKeyDictionary
693
value = self.counts[params.medium]
696
request_method = request.request_handlers.get(params.method)
698
# A method we don't know about doesn't count as a VFS method.
700
if issubclass(request_method, vfs.VfsRequest):
701
value['vfs_count'] += 1
704
value = self.counts[ref]
705
count, vfs_count, medium_repr = (
706
value['count'], value['vfs_count'], value['medium_repr'])
707
# In case this callback is invoked for the same ref twice (by the
708
# weakref callback and by the atexit function), set the call count back
709
# to 0 so this item won't be reported twice.
711
value['vfs_count'] = 0
713
trace.note(gettext('HPSS calls: {0} ({1} vfs) {2}').format(
714
count, vfs_count, medium_repr))
717
for ref in list(self.counts.keys()):
720
_debug_counter = None
724
class SmartClientMedium(SmartMedium):
367
class SmartClientMedium(object):
725
368
"""Smart client is a medium for sending smart protocol requests over."""
727
def __init__(self, base):
728
super(SmartClientMedium, self).__init__()
730
self._protocol_version_error = None
731
self._protocol_version = None
732
self._done_hello = False
733
# Be optimistic: we assume the remote end can accept new remote
734
# requests until we get an error saying otherwise.
735
# _remote_version_is_before tracks the bzr version the remote side
736
# can be based on what we've seen so far.
737
self._remote_version_is_before = None
738
# Install debug hook function if debug flag is set.
739
if 'hpss' in debug.debug_flags:
740
global _debug_counter
741
if _debug_counter is None:
742
_debug_counter = _DebugCounter()
743
_debug_counter.track(self)
744
if 'hpss_client_no_vfs' in debug.debug_flags:
746
if _vfs_refuser is None:
747
_vfs_refuser = _VfsRefuser()
749
def _is_remote_before(self, version_tuple):
750
"""Is it possible the remote side supports RPCs for a given version?
754
needed_version = (1, 2)
755
if medium._is_remote_before(needed_version):
756
fallback_to_pre_1_2_rpc()
760
except UnknownSmartMethod:
761
medium._remember_remote_is_before(needed_version)
762
fallback_to_pre_1_2_rpc()
764
:seealso: _remember_remote_is_before
766
if self._remote_version_is_before is None:
767
# So far, the remote side seems to support everything
769
return version_tuple >= self._remote_version_is_before
771
def _remember_remote_is_before(self, version_tuple):
772
"""Tell this medium that the remote side is older the given version.
774
:seealso: _is_remote_before
776
if (self._remote_version_is_before is not None and
777
version_tuple > self._remote_version_is_before):
778
# We have been told that the remote side is older than some version
779
# which is newer than a previously supplied older-than version.
780
# This indicates that some smart verb call is not guarded
781
# appropriately (it should simply not have been tried).
783
"_remember_remote_is_before(%r) called, but "
784
"_remember_remote_is_before(%r) was called previously."
785
, version_tuple, self._remote_version_is_before)
786
if 'hpss' in debug.debug_flags:
787
ui.ui_factory.show_warning(
788
"_remember_remote_is_before(%r) called, but "
789
"_remember_remote_is_before(%r) was called previously."
790
% (version_tuple, self._remote_version_is_before))
792
self._remote_version_is_before = version_tuple
794
def protocol_version(self):
795
"""Find out if 'hello' smart request works."""
796
if self._protocol_version_error is not None:
797
raise self._protocol_version_error
798
if not self._done_hello:
800
medium_request = self.get_request()
801
# Send a 'hello' request in protocol version one, for maximum
802
# backwards compatibility.
803
client_protocol = protocol.SmartClientRequestProtocolOne(medium_request)
804
client_protocol.query_version()
805
self._done_hello = True
806
except errors.SmartProtocolError, e:
807
# Cache the error, just like we would cache a successful
809
self._protocol_version_error = e
813
def should_probe(self):
814
"""Should RemoteBzrDirFormat.probe_transport send a smart request on
817
Some transports are unambiguously smart-only; there's no need to check
818
if the transport is able to carry smart requests, because that's all
819
it is for. In those cases, this method should return False.
821
But some HTTP transports can sometimes fail to carry smart requests,
822
but still be usuable for accessing remote bzrdirs via plain file
823
accesses. So for those transports, their media should return True here
824
so that RemoteBzrDirFormat can determine if it is appropriate for that
829
370
def disconnect(self):
830
371
"""If this medium maintains a persistent connection, close it.
832
373
The default implementation does nothing.
835
def remote_path_from_transport(self, transport):
836
"""Convert transport into a path suitable for using in a request.
838
Note that the resulting remote path doesn't encode the host name or
839
anything but path, so it is only safe to use it in requests sent over
840
the medium from the matching transport.
842
medium_base = urlutils.join(self.base, '/')
843
rel_url = urlutils.relative_url(medium_base, transport.base)
844
return urlutils.unquote(rel_url)
847
377
class SmartClientStreamMedium(SmartClientMedium):
848
378
"""Stream based medium common class.
883
412
return SmartClientStreamMediumRequest(self)
886
"""We have been disconnected, reset current state.
888
This resets things like _current_request and connected state.
891
self._current_request = None
414
def read_bytes(self, count):
415
return self._read_bytes(count)
894
418
class SmartSimplePipesClientMedium(SmartClientStreamMedium):
895
419
"""A client medium using simple pipes.
897
421
This client does not manage the pipes: it assumes they will always be open.
900
def __init__(self, readable_pipe, writeable_pipe, base):
901
SmartClientStreamMedium.__init__(self, base)
424
def __init__(self, readable_pipe, writeable_pipe):
425
SmartClientStreamMedium.__init__(self)
902
426
self._readable_pipe = readable_pipe
903
427
self._writeable_pipe = writeable_pipe
905
429
def _accept_bytes(self, bytes):
906
430
"""See SmartClientStreamMedium.accept_bytes."""
908
self._writeable_pipe.write(bytes)
910
if e.errno in (errno.EINVAL, errno.EPIPE):
911
raise errors.ConnectionReset(
912
"Error trying to write to subprocess:\n%s" % (e,))
914
self._report_activity(len(bytes), 'write')
431
self._writeable_pipe.write(bytes)
916
433
def _flush(self):
917
434
"""See SmartClientStreamMedium._flush()."""
918
# Note: If flush were to fail, we'd like to raise ConnectionReset, etc.
919
# However, testing shows that even when the child process is
920
# gone, this doesn't error.
921
435
self._writeable_pipe.flush()
923
437
def _read_bytes(self, count):
924
438
"""See SmartClientStreamMedium._read_bytes."""
925
bytes_to_read = min(count, _MAX_READ_SIZE)
926
bytes = self._readable_pipe.read(bytes_to_read)
927
self._report_activity(len(bytes), 'read')
931
class SSHParams(object):
932
"""A set of parameters for starting a remote bzr via SSH."""
439
return self._readable_pipe.read(count)
442
class SmartSSHClientMedium(SmartClientStreamMedium):
443
"""A client medium using SSH."""
934
445
def __init__(self, host, port=None, username=None, password=None,
935
bzr_remote_path='bzr'):
938
self.username = username
939
self.password = password
940
self.bzr_remote_path = bzr_remote_path
943
class SmartSSHClientMedium(SmartClientStreamMedium):
944
"""A client medium using SSH.
946
It delegates IO to a SmartSimplePipesClientMedium or
947
SmartClientAlreadyConnectedSocketMedium (depending on platform).
950
def __init__(self, base, ssh_params, vendor=None):
951
447
"""Creates a client that will connect on the first use.
953
:param ssh_params: A SSHParams instance.
954
449
:param vendor: An optional override for the ssh vendor to use. See
955
450
bzrlib.transport.ssh for details on ssh vendors.
957
self._real_medium = None
958
self._ssh_params = ssh_params
959
# for the benefit of progress making a short description of this
961
self._scheme = 'bzr+ssh'
962
# SmartClientStreamMedium stores the repr of this object in its
963
# _DebugCounter so we have to store all the values used in our repr
964
# method before calling the super init.
965
SmartClientStreamMedium.__init__(self, base)
452
SmartClientStreamMedium.__init__(self)
453
self._connected = False
455
self._password = password
457
self._username = username
458
self._read_from = None
459
self._ssh_connection = None
966
460
self._vendor = vendor
967
self._ssh_connection = None
970
if self._ssh_params.port is None:
973
maybe_port = ':%s' % self._ssh_params.port
974
if self._ssh_params.username is None:
977
maybe_user = '%s@' % self._ssh_params.username
978
return "%s(%s://%s%s%s/)" % (
979
self.__class__.__name__,
982
self._ssh_params.host,
461
self._write_to = None
985
463
def _accept_bytes(self, bytes):
986
464
"""See SmartClientStreamMedium.accept_bytes."""
987
465
self._ensure_connection()
988
self._real_medium.accept_bytes(bytes)
466
self._write_to.write(bytes)
990
468
def disconnect(self):
991
469
"""See SmartClientMedium.disconnect()."""
992
if self._real_medium is not None:
993
self._real_medium.disconnect()
994
self._real_medium = None
995
if self._ssh_connection is not None:
996
self._ssh_connection.close()
997
self._ssh_connection = None
470
if not self._connected:
472
self._read_from.close()
473
self._write_to.close()
474
self._ssh_connection.close()
475
self._connected = False
999
477
def _ensure_connection(self):
1000
478
"""Connect this medium if not already connected."""
1001
if self._real_medium is not None:
481
executable = os.environ.get('BZR_REMOTE_PATH', 'bzr')
1003
482
if self._vendor is None:
1004
483
vendor = ssh._get_ssh_vendor()
1006
485
vendor = self._vendor
1007
self._ssh_connection = vendor.connect_ssh(self._ssh_params.username,
1008
self._ssh_params.password, self._ssh_params.host,
1009
self._ssh_params.port,
1010
command=[self._ssh_params.bzr_remote_path, 'serve', '--inet',
1011
'--directory=/', '--allow-writes'])
1012
io_kind, io_object = self._ssh_connection.get_sock_or_pipes()
1013
if io_kind == 'socket':
1014
self._real_medium = SmartClientAlreadyConnectedSocketMedium(
1015
self.base, io_object)
1016
elif io_kind == 'pipes':
1017
read_from, write_to = io_object
1018
self._real_medium = SmartSimplePipesClientMedium(
1019
read_from, write_to, self.base)
1021
raise AssertionError(
1022
"Unexpected io_kind %r from %r"
1023
% (io_kind, self._ssh_connection))
486
self._ssh_connection = vendor.connect_ssh(self._username,
487
self._password, self._host, self._port,
488
command=[executable, 'serve', '--inet', '--directory=/',
490
self._read_from, self._write_to = \
491
self._ssh_connection.get_filelike_channels()
492
self._connected = True
1025
494
def _flush(self):
1026
495
"""See SmartClientStreamMedium._flush()."""
1027
self._real_medium._flush()
496
self._write_to.flush()
1029
498
def _read_bytes(self, count):
1030
499
"""See SmartClientStreamMedium.read_bytes."""
1031
if self._real_medium is None:
500
if not self._connected:
1032
501
raise errors.MediumNotConnected(self)
1033
return self._real_medium.read_bytes(count)
1036
# Port 4155 is the default port for bzr://, registered with IANA.
1037
BZR_DEFAULT_INTERFACE = None
1038
BZR_DEFAULT_PORT = 4155
1041
class SmartClientSocketMedium(SmartClientStreamMedium):
1042
"""A client medium using a socket.
502
return self._read_from.read(count)
505
class SmartTCPClientMedium(SmartClientStreamMedium):
506
"""A client medium using TCP."""
1044
This class isn't usable directly. Use one of its subclasses instead.
1047
def __init__(self, base):
1048
SmartClientStreamMedium.__init__(self, base)
508
def __init__(self, host, port):
509
"""Creates a client that will connect on the first use."""
510
SmartClientStreamMedium.__init__(self)
511
self._connected = False
1049
514
self._socket = None
1050
self._connected = False
1052
516
def _accept_bytes(self, bytes):
1053
517
"""See SmartClientMedium.accept_bytes."""
1054
518
self._ensure_connection()
1055
osutils.send_all(self._socket, bytes, self._report_activity)
519
self._socket.sendall(bytes)
521
def disconnect(self):
522
"""See SmartClientMedium.disconnect()."""
523
if not self._connected:
527
self._connected = False
1057
529
def _ensure_connection(self):
1058
530
"""Connect this medium if not already connected."""
1059
raise NotImplementedError(self._ensure_connection)
533
self._socket = socket.socket()
534
self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
535
result = self._socket.connect_ex((self._host, int(self._port)))
537
raise errors.ConnectionError("failed to connect to %s:%d: %s" %
538
(self._host, self._port, os.strerror(result)))
539
self._connected = True
1061
541
def _flush(self):
1062
542
"""See SmartClientStreamMedium._flush().
1064
For sockets we do no flushing. For TCP sockets we may want to turn off
1065
TCP_NODELAY and add a means to do a flush, but that can be done in the
544
For TCP we do no flushing. We may want to turn off TCP_NODELAY and
545
add a means to do a flush, but that can be done in the future.
1069
548
def _read_bytes(self, count):
1070
549
"""See SmartClientMedium.read_bytes."""
1071
550
if not self._connected:
1072
551
raise errors.MediumNotConnected(self)
1073
return osutils.read_bytes_from_socket(
1074
self._socket, self._report_activity)
1076
def disconnect(self):
1077
"""See SmartClientMedium.disconnect()."""
1078
if not self._connected:
1080
self._socket.close()
1082
self._connected = False
1085
class SmartTCPClientMedium(SmartClientSocketMedium):
1086
"""A client medium that creates a TCP connection."""
1088
def __init__(self, host, port, base):
1089
"""Creates a client that will connect on the first use."""
1090
SmartClientSocketMedium.__init__(self, base)
1094
def _ensure_connection(self):
1095
"""Connect this medium if not already connected."""
1098
if self._port is None:
1099
port = BZR_DEFAULT_PORT
1101
port = int(self._port)
1103
sockaddrs = socket.getaddrinfo(self._host, port, socket.AF_UNSPEC,
1104
socket.SOCK_STREAM, 0, 0)
1105
except socket.gaierror, (err_num, err_msg):
1106
raise errors.ConnectionError("failed to lookup %s:%d: %s" %
1107
(self._host, port, err_msg))
1108
# Initialize err in case there are no addresses returned:
1109
err = socket.error("no address found for %s" % self._host)
1110
for (family, socktype, proto, canonname, sockaddr) in sockaddrs:
1112
self._socket = socket.socket(family, socktype, proto)
1113
self._socket.setsockopt(socket.IPPROTO_TCP,
1114
socket.TCP_NODELAY, 1)
1115
self._socket.connect(sockaddr)
1116
except socket.error, err:
1117
if self._socket is not None:
1118
self._socket.close()
1122
if self._socket is None:
1123
# socket errors either have a (string) or (errno, string) as their
1125
if type(err.args) is str:
1128
err_msg = err.args[1]
1129
raise errors.ConnectionError("failed to connect to %s:%d: %s" %
1130
(self._host, port, err_msg))
1131
self._connected = True
1134
class SmartClientAlreadyConnectedSocketMedium(SmartClientSocketMedium):
1135
"""A client medium for an already connected socket.
1137
Note that this class will assume it "owns" the socket, so it will close it
1138
when its disconnect method is called.
1141
def __init__(self, base, sock):
1142
SmartClientSocketMedium.__init__(self, base)
1144
self._connected = True
1146
def _ensure_connection(self):
1147
# Already connected, by definition! So nothing to do.
552
return self._socket.recv(count)
1151
555
class SmartClientStreamMediumRequest(SmartClientMediumRequest):