24
24
bzrlib/transport/smart/__init__.py.
34
from bzrlib.lazy_import import lazy_import
35
lazy_import(globals(), """
31
from bzrlib import errors
32
from bzrlib.smart.protocol import (
34
SmartServerRequestProtocolOne,
35
SmartServerRequestProtocolTwo,
39
from bzrlib.transport import ssh
40
except errors.ParamikoNotPresent:
41
# no paramiko. SmartSSHClientMedium will break.
45
class SmartServerStreamMedium(object):
49
from bzrlib.i18n import gettext
50
from bzrlib.smart import client, protocol, request, signals, vfs
51
from bzrlib.transport import ssh
53
from bzrlib import osutils
55
# Throughout this module buffer size parameters are either limited to be at
56
# most _MAX_READ_SIZE, or are ignored and _MAX_READ_SIZE is used instead.
57
# For this module's purposes, MAX_SOCKET_CHUNK is a reasonable size for reads
58
# from non-sockets as well.
59
_MAX_READ_SIZE = osutils.MAX_SOCKET_CHUNK
61
def _get_protocol_factory_for_bytes(bytes):
62
"""Determine the right protocol factory for 'bytes'.
64
This will return an appropriate protocol factory depending on the version
65
of the protocol being used, as determined by inspecting the given bytes.
66
The bytes should have at least one newline byte (i.e. be a whole line),
67
otherwise it's possible that a request will be incorrectly identified as
70
Typical use would be::
72
factory, unused_bytes = _get_protocol_factory_for_bytes(bytes)
73
server_protocol = factory(transport, write_func, root_client_path)
74
server_protocol.accept_bytes(unused_bytes)
76
:param bytes: a str of bytes of the start of the request.
77
:returns: 2-tuple of (protocol_factory, unused_bytes). protocol_factory is
78
a callable that takes three args: transport, write_func,
79
root_client_path. unused_bytes are any bytes that were not part of a
80
protocol version marker.
82
if bytes.startswith(protocol.MESSAGE_VERSION_THREE):
83
protocol_factory = protocol.build_server_protocol_three
84
bytes = bytes[len(protocol.MESSAGE_VERSION_THREE):]
85
elif bytes.startswith(protocol.REQUEST_VERSION_TWO):
86
protocol_factory = protocol.SmartServerRequestProtocolTwo
87
bytes = bytes[len(protocol.REQUEST_VERSION_TWO):]
89
protocol_factory = protocol.SmartServerRequestProtocolOne
90
return protocol_factory, bytes
93
def _get_line(read_bytes_func):
94
"""Read bytes using read_bytes_func until a newline byte.
96
This isn't particularly efficient, so should only be used when the
97
expected size of the line is quite short.
99
:returns: a tuple of two strs: (line, excess)
103
while newline_pos == -1:
104
new_bytes = read_bytes_func(1)
107
# Ran out of bytes before receiving a complete line.
109
newline_pos = bytes.find('\n')
110
line = bytes[:newline_pos+1]
111
excess = bytes[newline_pos+1:]
115
class SmartMedium(object):
116
"""Base class for smart protocol media, both client- and server-side."""
119
self._push_back_buffer = None
121
def _push_back(self, bytes):
122
"""Return unused bytes to the medium, because they belong to the next
125
This sets the _push_back_buffer to the given bytes.
127
if self._push_back_buffer is not None:
128
raise AssertionError(
129
"_push_back called when self._push_back_buffer is %r"
130
% (self._push_back_buffer,))
133
self._push_back_buffer = bytes
135
def _get_push_back_buffer(self):
136
if self._push_back_buffer == '':
137
raise AssertionError(
138
'%s._push_back_buffer should never be the empty string, '
139
'which can be confused with EOF' % (self,))
140
bytes = self._push_back_buffer
141
self._push_back_buffer = None
144
def read_bytes(self, desired_count):
145
"""Read some bytes from this medium.
147
:returns: some bytes, possibly more or less than the number requested
148
in 'desired_count' depending on the medium.
150
if self._push_back_buffer is not None:
151
return self._get_push_back_buffer()
152
bytes_to_read = min(desired_count, _MAX_READ_SIZE)
153
return self._read_bytes(bytes_to_read)
155
def _read_bytes(self, count):
156
raise NotImplementedError(self._read_bytes)
159
"""Read bytes from this request's response until a newline byte.
161
This isn't particularly efficient, so should only be used when the
162
expected size of the line is quite short.
164
:returns: a string of bytes ending in a newline (byte 0x0A).
166
line, excess = _get_line(self.read_bytes)
167
self._push_back(excess)
170
def _report_activity(self, bytes, direction):
171
"""Notify that this medium has activity.
173
Implementations should call this from all methods that actually do IO.
174
Be careful that it's not called twice, if one method is implemented on
177
:param bytes: Number of bytes read or written.
178
:param direction: 'read' or 'write' or None.
180
ui.ui_factory.report_transport_activity(self, bytes, direction)
183
_bad_file_descriptor = (errno.EBADF,)
184
if sys.platform == 'win32':
185
# Given on Windows if you pass a closed socket to select.select. Probably
186
# also given if you pass a file handle to select.
188
_bad_file_descriptor += (WSAENOTSOCK,)
191
class SmartServerStreamMedium(SmartMedium):
46
192
"""Handles smart commands coming over a stream.
48
194
The stream may be a pipe connected to sshd, or a tcp socket, or an
87
276
:returns: a SmartServerRequestProtocol.
89
# Identify the protocol version.
278
self._wait_for_bytes_with_timeout(self._client_timeout)
280
# We're stopping, so don't try to do any more work
90
282
bytes = self._get_line()
91
if bytes.startswith(REQUEST_VERSION_TWO):
92
protocol_class = SmartServerRequestProtocolTwo
93
bytes = bytes[len(REQUEST_VERSION_TWO):]
95
protocol_class = SmartServerRequestProtocolOne
96
protocol = protocol_class(self.backing_transport, self._write_out)
97
protocol.accept_bytes(bytes)
283
protocol_factory, unused_bytes = _get_protocol_factory_for_bytes(bytes)
284
protocol = protocol_factory(
285
self.backing_transport, self._write_out, self.root_client_path)
286
protocol.accept_bytes(unused_bytes)
289
def _wait_on_descriptor(self, fd, timeout_seconds):
290
"""select() on a file descriptor, waiting for nonblocking read()
292
This will raise a ConnectionTimeout exception if we do not get a
293
readable handle before timeout_seconds.
296
t_end = self._timer() + timeout_seconds
297
poll_timeout = min(timeout_seconds, self._client_poll_timeout)
299
while not rs and not xs and self._timer() < t_end:
303
rs, _, xs = select.select([fd], [], [fd], poll_timeout)
304
except (select.error, socket.error) as e:
305
err = getattr(e, 'errno', None)
306
if err is None and getattr(e, 'args', None) is not None:
307
# select.error doesn't have 'errno', it just has args[0]
309
if err in _bad_file_descriptor:
310
return # Not a socket indicates read() will fail
311
elif err == errno.EINTR:
312
# Interrupted, keep looping.
317
raise errors.ConnectionTimeout('disconnecting client after %.1f seconds'
318
% (timeout_seconds,))
100
320
def _serve_one_request(self, protocol):
101
321
"""Read one request from input, process, send back a response.
103
323
:param protocol: a SmartServerRequestProtocol.
106
328
self._serve_one_request_unguarded(protocol)
107
329
except KeyboardInterrupt:
113
335
"""Called when an unhandled exception from the protocol occurs."""
114
336
raise NotImplementedError(self.terminate_due_to_error)
116
def _get_bytes(self, desired_count):
338
def _read_bytes(self, desired_count):
117
339
"""Get some bytes from the medium.
119
341
:param desired_count: number of bytes we want to read.
121
raise NotImplementedError(self._get_bytes)
124
"""Read bytes from this request's response until a newline byte.
126
This isn't particularly efficient, so should only be used when the
127
expected size of the line is quite short.
129
:returns: a string of bytes ending in a newline (byte 0x0A).
131
# XXX: this duplicates SmartClientRequestProtocolOne._recv_tuple
133
while not line or line[-1] != '\n':
134
new_char = self._get_bytes(1)
137
# Ran out of bytes before receiving a complete line.
343
raise NotImplementedError(self._read_bytes)
142
346
class SmartServerSocketStreamMedium(SmartServerStreamMedium):
144
def __init__(self, sock, backing_transport):
348
def __init__(self, sock, backing_transport, root_client_path='/',
147
352
:param sock: the socket the server will read from. It will be put
148
353
into blocking mode.
150
SmartServerStreamMedium.__init__(self, backing_transport)
355
SmartServerStreamMedium.__init__(
356
self, backing_transport, root_client_path=root_client_path,
152
358
sock.setblocking(True)
153
359
self.socket = sock
360
# Get the getpeername now, as we might be closed later when we care.
362
self._client_info = sock.getpeername()
364
self._client_info = '<unknown>'
367
return '%s(client=%s)' % (self.__class__.__name__, self._client_info)
370
return '%s.%s(client=%s)' % (self.__module__, self.__class__.__name__,
155
373
def _serve_one_request_unguarded(self, protocol):
156
374
while protocol.next_read_size():
158
protocol.accept_bytes(self.push_back)
161
bytes = self._get_bytes(4096)
165
protocol.accept_bytes(bytes)
167
self.push_back = protocol.excess_buffer
169
def _get_bytes(self, desired_count):
170
# We ignore the desired_count because on sockets it's more efficient to
172
return self.socket.recv(4096)
375
# We can safely try to read large chunks. If there is less data
376
# than MAX_SOCKET_CHUNK ready, the socket will just return a
377
# short read immediately rather than block.
378
bytes = self.read_bytes(osutils.MAX_SOCKET_CHUNK)
382
protocol.accept_bytes(bytes)
384
self._push_back(protocol.unused_data)
386
def _disconnect_client(self):
387
"""Close the current connection. We stopped due to a timeout/etc."""
390
def _wait_for_bytes_with_timeout(self, timeout_seconds):
391
"""Wait for more bytes to be read, but timeout if none available.
393
This allows us to detect idle connections, and stop trying to read from
394
them, without setting the socket itself to non-blocking. This also
395
allows us to specify when we watch for idle timeouts.
397
:return: None, this will raise ConnectionTimeout if we time out before
400
return self._wait_on_descriptor(self.socket, timeout_seconds)
402
def _read_bytes(self, desired_count):
403
return osutils.read_bytes_from_socket(
404
self.socket, self._report_activity)
174
406
def terminate_due_to_error(self):
175
"""Called when an unhandled exception from the protocol occurs."""
176
407
# TODO: This should log to a server log file, but no such thing
177
408
# exists yet. Andrew Bennetts 2006-09-29.
178
409
self.socket.close()
179
410
self.finished = True
181
412
def _write_out(self, bytes):
182
self.socket.sendall(bytes)
413
tstart = osutils.timer_func()
414
osutils.send_all(self.socket, bytes, self._report_activity)
415
if 'hpss' in debug.debug_flags:
416
thread_id = thread.get_ident()
417
trace.mutter('%12s: [%s] %d bytes to the socket in %.3fs'
418
% ('wrote', thread_id, len(bytes),
419
osutils.timer_func() - tstart))
185
422
class SmartServerPipeStreamMedium(SmartServerStreamMedium):
187
def __init__(self, in_file, out_file, backing_transport):
424
def __init__(self, in_file, out_file, backing_transport, timeout=None):
188
425
"""Construct new server.
190
427
:param in_file: Python file from which requests can be read.
191
428
:param out_file: Python file to write responses.
192
429
:param backing_transport: Transport for the directory served.
194
SmartServerStreamMedium.__init__(self, backing_transport)
431
SmartServerStreamMedium.__init__(self, backing_transport,
195
433
if sys.platform == 'win32':
196
434
# force binary mode for files
337
610
return self._read_bytes(count)
339
612
def _read_bytes(self, count):
340
"""Helper for read_bytes.
613
"""Helper for SmartClientMediumRequest.read_bytes.
342
615
read_bytes checks the state of the request to determing if bytes
343
616
should be read. After that it hands off to _read_bytes to do the
619
By default this forwards to self._medium.read_bytes because we are
620
operating on the medium's stream.
346
raise NotImplementedError(self._read_bytes)
622
return self._medium.read_bytes(count)
348
624
def read_line(self):
349
"""Read bytes from this request's response until a newline byte.
351
This isn't particularly efficient, so should only be used when the
352
expected size of the line is quite short.
354
:returns: a string of bytes ending in a newline (byte 0x0A).
356
# XXX: this duplicates SmartClientRequestProtocolOne._recv_tuple
358
while not line or line[-1] != '\n':
359
new_char = self.read_bytes(1)
362
raise errors.SmartProtocolError(
363
'unexpected end of file reading from server')
625
line = self._read_line()
626
if not line.endswith('\n'):
627
# end of file encountered reading from server
628
raise errors.ConnectionReset(
629
"Unexpected end of message. Please check connectivity "
630
"and permissions, and report a bug if problems persist.")
367
class SmartClientMedium(object):
633
def _read_line(self):
634
"""Helper for SmartClientMediumRequest.read_line.
636
By default this forwards to self._medium._get_line because we are
637
operating on the medium's stream.
639
return self._medium._get_line()
642
class _VfsRefuser(object):
643
"""An object that refuses all VFS requests.
648
client._SmartClient.hooks.install_named_hook(
649
'call', self.check_vfs, 'vfs refuser')
651
def check_vfs(self, params):
653
request_method = request.request_handlers.get(params.method)
655
# A method we don't know about doesn't count as a VFS method.
657
if issubclass(request_method, vfs.VfsRequest):
658
raise errors.HpssVfsRequestNotAllowed(params.method, params.args)
661
class _DebugCounter(object):
662
"""An object that counts the HPSS calls made to each client medium.
664
When a medium is garbage-collected, or failing that when
665
bzrlib.global_state exits, the total number of calls made on that medium
666
are reported via trace.note.
670
self.counts = weakref.WeakKeyDictionary()
671
client._SmartClient.hooks.install_named_hook(
672
'call', self.increment_call_count, 'hpss call counter')
673
bzrlib.global_state.cleanups.add_cleanup(self.flush_all)
675
def track(self, medium):
676
"""Start tracking calls made to a medium.
678
This only keeps a weakref to the medium, so shouldn't affect the
681
medium_repr = repr(medium)
682
# Add this medium to the WeakKeyDictionary
683
self.counts[medium] = dict(count=0, vfs_count=0,
684
medium_repr=medium_repr)
685
# Weakref callbacks are fired in reverse order of their association
686
# with the referenced object. So we add a weakref *after* adding to
687
# the WeakKeyDict so that we can report the value from it before the
688
# entry is removed by the WeakKeyDict's own callback.
689
ref = weakref.ref(medium, self.done)
691
def increment_call_count(self, params):
692
# Increment the count in the WeakKeyDictionary
693
value = self.counts[params.medium]
696
request_method = request.request_handlers.get(params.method)
698
# A method we don't know about doesn't count as a VFS method.
700
if issubclass(request_method, vfs.VfsRequest):
701
value['vfs_count'] += 1
704
value = self.counts[ref]
705
count, vfs_count, medium_repr = (
706
value['count'], value['vfs_count'], value['medium_repr'])
707
# In case this callback is invoked for the same ref twice (by the
708
# weakref callback and by the atexit function), set the call count back
709
# to 0 so this item won't be reported twice.
711
value['vfs_count'] = 0
713
trace.note(gettext('HPSS calls: {0} ({1} vfs) {2}').format(
714
count, vfs_count, medium_repr))
717
for ref in list(self.counts.keys()):
720
_debug_counter = None
724
class SmartClientMedium(SmartMedium):
368
725
"""Smart client is a medium for sending smart protocol requests over."""
727
def __init__(self, base):
728
super(SmartClientMedium, self).__init__()
730
self._protocol_version_error = None
731
self._protocol_version = None
732
self._done_hello = False
733
# Be optimistic: we assume the remote end can accept new remote
734
# requests until we get an error saying otherwise.
735
# _remote_version_is_before tracks the bzr version the remote side
736
# can be based on what we've seen so far.
737
self._remote_version_is_before = None
738
# Install debug hook function if debug flag is set.
739
if 'hpss' in debug.debug_flags:
740
global _debug_counter
741
if _debug_counter is None:
742
_debug_counter = _DebugCounter()
743
_debug_counter.track(self)
744
if 'hpss_client_no_vfs' in debug.debug_flags:
746
if _vfs_refuser is None:
747
_vfs_refuser = _VfsRefuser()
749
def _is_remote_before(self, version_tuple):
750
"""Is it possible the remote side supports RPCs for a given version?
754
needed_version = (1, 2)
755
if medium._is_remote_before(needed_version):
756
fallback_to_pre_1_2_rpc()
760
except UnknownSmartMethod:
761
medium._remember_remote_is_before(needed_version)
762
fallback_to_pre_1_2_rpc()
764
:seealso: _remember_remote_is_before
766
if self._remote_version_is_before is None:
767
# So far, the remote side seems to support everything
769
return version_tuple >= self._remote_version_is_before
771
def _remember_remote_is_before(self, version_tuple):
772
"""Tell this medium that the remote side is older the given version.
774
:seealso: _is_remote_before
776
if (self._remote_version_is_before is not None and
777
version_tuple > self._remote_version_is_before):
778
# We have been told that the remote side is older than some version
779
# which is newer than a previously supplied older-than version.
780
# This indicates that some smart verb call is not guarded
781
# appropriately (it should simply not have been tried).
783
"_remember_remote_is_before(%r) called, but "
784
"_remember_remote_is_before(%r) was called previously."
785
, version_tuple, self._remote_version_is_before)
786
if 'hpss' in debug.debug_flags:
787
ui.ui_factory.show_warning(
788
"_remember_remote_is_before(%r) called, but "
789
"_remember_remote_is_before(%r) was called previously."
790
% (version_tuple, self._remote_version_is_before))
792
self._remote_version_is_before = version_tuple
794
def protocol_version(self):
795
"""Find out if 'hello' smart request works."""
796
if self._protocol_version_error is not None:
797
raise self._protocol_version_error
798
if not self._done_hello:
800
medium_request = self.get_request()
801
# Send a 'hello' request in protocol version one, for maximum
802
# backwards compatibility.
803
client_protocol = protocol.SmartClientRequestProtocolOne(medium_request)
804
client_protocol.query_version()
805
self._done_hello = True
806
except errors.SmartProtocolError, e:
807
# Cache the error, just like we would cache a successful
809
self._protocol_version_error = e
813
def should_probe(self):
814
"""Should RemoteBzrDirFormat.probe_transport send a smart request on
817
Some transports are unambiguously smart-only; there's no need to check
818
if the transport is able to carry smart requests, because that's all
819
it is for. In those cases, this method should return False.
821
But some HTTP transports can sometimes fail to carry smart requests,
822
but still be usuable for accessing remote bzrdirs via plain file
823
accesses. So for those transports, their media should return True here
824
so that RemoteBzrDirFormat can determine if it is appropriate for that
370
829
def disconnect(self):
371
830
"""If this medium maintains a persistent connection, close it.
373
832
The default implementation does nothing.
835
def remote_path_from_transport(self, transport):
836
"""Convert transport into a path suitable for using in a request.
838
Note that the resulting remote path doesn't encode the host name or
839
anything but path, so it is only safe to use it in requests sent over
840
the medium from the matching transport.
842
medium_base = urlutils.join(self.base, '/')
843
rel_url = urlutils.relative_url(medium_base, transport.base)
844
return urllib.unquote(rel_url)
377
847
class SmartClientStreamMedium(SmartClientMedium):
378
848
"""Stream based medium common class.
412
883
return SmartClientStreamMediumRequest(self)
414
def read_bytes(self, count):
415
return self._read_bytes(count)
886
"""We have been disconnected, reset current state.
888
This resets things like _current_request and connected state.
891
self._current_request = None
418
894
class SmartSimplePipesClientMedium(SmartClientStreamMedium):
419
895
"""A client medium using simple pipes.
421
897
This client does not manage the pipes: it assumes they will always be open.
424
def __init__(self, readable_pipe, writeable_pipe):
425
SmartClientStreamMedium.__init__(self)
900
def __init__(self, readable_pipe, writeable_pipe, base):
901
SmartClientStreamMedium.__init__(self, base)
426
902
self._readable_pipe = readable_pipe
427
903
self._writeable_pipe = writeable_pipe
429
905
def _accept_bytes(self, bytes):
430
906
"""See SmartClientStreamMedium.accept_bytes."""
431
self._writeable_pipe.write(bytes)
908
self._writeable_pipe.write(bytes)
910
if e.errno in (errno.EINVAL, errno.EPIPE):
911
raise errors.ConnectionReset(
912
"Error trying to write to subprocess:\n%s" % (e,))
914
self._report_activity(len(bytes), 'write')
433
916
def _flush(self):
434
917
"""See SmartClientStreamMedium._flush()."""
918
# Note: If flush were to fail, we'd like to raise ConnectionReset, etc.
919
# However, testing shows that even when the child process is
920
# gone, this doesn't error.
435
921
self._writeable_pipe.flush()
437
923
def _read_bytes(self, count):
438
924
"""See SmartClientStreamMedium._read_bytes."""
439
return self._readable_pipe.read(count)
925
bytes_to_read = min(count, _MAX_READ_SIZE)
926
bytes = self._readable_pipe.read(bytes_to_read)
927
self._report_activity(len(bytes), 'read')
931
class SSHParams(object):
932
"""A set of parameters for starting a remote bzr via SSH."""
934
def __init__(self, host, port=None, username=None, password=None,
935
bzr_remote_path='bzr'):
938
self.username = username
939
self.password = password
940
self.bzr_remote_path = bzr_remote_path
442
943
class SmartSSHClientMedium(SmartClientStreamMedium):
443
"""A client medium using SSH."""
445
def __init__(self, host, port=None, username=None, password=None,
944
"""A client medium using SSH.
946
It delegates IO to a SmartSimplePipesClientMedium or
947
SmartClientAlreadyConnectedSocketMedium (depending on platform).
950
def __init__(self, base, ssh_params, vendor=None):
447
951
"""Creates a client that will connect on the first use.
953
:param ssh_params: A SSHParams instance.
449
954
:param vendor: An optional override for the ssh vendor to use. See
450
955
bzrlib.transport.ssh for details on ssh vendors.
452
SmartClientStreamMedium.__init__(self)
453
self._connected = False
455
self._password = password
457
self._username = username
458
self._read_from = None
957
self._real_medium = None
958
self._ssh_params = ssh_params
959
# for the benefit of progress making a short description of this
961
self._scheme = 'bzr+ssh'
962
# SmartClientStreamMedium stores the repr of this object in its
963
# _DebugCounter so we have to store all the values used in our repr
964
# method before calling the super init.
965
SmartClientStreamMedium.__init__(self, base)
966
self._vendor = vendor
459
967
self._ssh_connection = None
460
self._vendor = vendor
461
self._write_to = None
970
if self._ssh_params.port is None:
973
maybe_port = ':%s' % self._ssh_params.port
974
if self._ssh_params.username is None:
977
maybe_user = '%s@' % self._ssh_params.username
978
return "%s(%s://%s%s%s/)" % (
979
self.__class__.__name__,
982
self._ssh_params.host,
463
985
def _accept_bytes(self, bytes):
464
986
"""See SmartClientStreamMedium.accept_bytes."""
465
987
self._ensure_connection()
466
self._write_to.write(bytes)
988
self._real_medium.accept_bytes(bytes)
468
990
def disconnect(self):
469
991
"""See SmartClientMedium.disconnect()."""
470
if not self._connected:
472
self._read_from.close()
473
self._write_to.close()
474
self._ssh_connection.close()
475
self._connected = False
992
if self._real_medium is not None:
993
self._real_medium.disconnect()
994
self._real_medium = None
995
if self._ssh_connection is not None:
996
self._ssh_connection.close()
997
self._ssh_connection = None
477
999
def _ensure_connection(self):
478
1000
"""Connect this medium if not already connected."""
1001
if self._real_medium is not None:
481
executable = os.environ.get('BZR_REMOTE_PATH', 'bzr')
482
1003
if self._vendor is None:
483
1004
vendor = ssh._get_ssh_vendor()
485
1006
vendor = self._vendor
486
self._ssh_connection = vendor.connect_ssh(self._username,
487
self._password, self._host, self._port,
488
command=[executable, 'serve', '--inet', '--directory=/',
490
self._read_from, self._write_to = \
491
self._ssh_connection.get_filelike_channels()
492
self._connected = True
1007
self._ssh_connection = vendor.connect_ssh(self._ssh_params.username,
1008
self._ssh_params.password, self._ssh_params.host,
1009
self._ssh_params.port,
1010
command=[self._ssh_params.bzr_remote_path, 'serve', '--inet',
1011
'--directory=/', '--allow-writes'])
1012
io_kind, io_object = self._ssh_connection.get_sock_or_pipes()
1013
if io_kind == 'socket':
1014
self._real_medium = SmartClientAlreadyConnectedSocketMedium(
1015
self.base, io_object)
1016
elif io_kind == 'pipes':
1017
read_from, write_to = io_object
1018
self._real_medium = SmartSimplePipesClientMedium(
1019
read_from, write_to, self.base)
1021
raise AssertionError(
1022
"Unexpected io_kind %r from %r"
1023
% (io_kind, self._ssh_connection))
1024
for hook in transport.Transport.hooks["post_connect"]:
494
1027
def _flush(self):
495
1028
"""See SmartClientStreamMedium._flush()."""
496
self._write_to.flush()
1029
self._real_medium._flush()
498
1031
def _read_bytes(self, count):
499
1032
"""See SmartClientStreamMedium.read_bytes."""
500
if not self._connected:
1033
if self._real_medium is None:
501
1034
raise errors.MediumNotConnected(self)
502
return self._read_from.read(count)
505
class SmartTCPClientMedium(SmartClientStreamMedium):
506
"""A client medium using TCP."""
1035
return self._real_medium.read_bytes(count)
1038
# Port 4155 is the default port for bzr://, registered with IANA.
1039
BZR_DEFAULT_INTERFACE = None
1040
BZR_DEFAULT_PORT = 4155
1043
class SmartClientSocketMedium(SmartClientStreamMedium):
1044
"""A client medium using a socket.
508
def __init__(self, host, port):
509
"""Creates a client that will connect on the first use."""
510
SmartClientStreamMedium.__init__(self)
1046
This class isn't usable directly. Use one of its subclasses instead.
1049
def __init__(self, base):
1050
SmartClientStreamMedium.__init__(self, base)
511
1052
self._connected = False
516
1054
def _accept_bytes(self, bytes):
517
1055
"""See SmartClientMedium.accept_bytes."""
518
1056
self._ensure_connection()
519
self._socket.sendall(bytes)
1057
osutils.send_all(self._socket, bytes, self._report_activity)
1059
def _ensure_connection(self):
1060
"""Connect this medium if not already connected."""
1061
raise NotImplementedError(self._ensure_connection)
1064
"""See SmartClientStreamMedium._flush().
1066
For sockets we do no flushing. For TCP sockets we may want to turn off
1067
TCP_NODELAY and add a means to do a flush, but that can be done in the
1071
def _read_bytes(self, count):
1072
"""See SmartClientMedium.read_bytes."""
1073
if not self._connected:
1074
raise errors.MediumNotConnected(self)
1075
return osutils.read_bytes_from_socket(
1076
self._socket, self._report_activity)
521
1078
def disconnect(self):
522
1079
"""See SmartClientMedium.disconnect()."""
526
1083
self._socket = None
527
1084
self._connected = False
1087
class SmartTCPClientMedium(SmartClientSocketMedium):
1088
"""A client medium that creates a TCP connection."""
1090
def __init__(self, host, port, base):
1091
"""Creates a client that will connect on the first use."""
1092
SmartClientSocketMedium.__init__(self, base)
529
1096
def _ensure_connection(self):
530
1097
"""Connect this medium if not already connected."""
531
1098
if self._connected:
533
self._socket = socket.socket()
534
self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
535
result = self._socket.connect_ex((self._host, int(self._port)))
1100
if self._port is None:
1101
port = BZR_DEFAULT_PORT
1103
port = int(self._port)
1105
sockaddrs = socket.getaddrinfo(self._host, port, socket.AF_UNSPEC,
1106
socket.SOCK_STREAM, 0, 0)
1107
except socket.gaierror, (err_num, err_msg):
1108
raise errors.ConnectionError("failed to lookup %s:%d: %s" %
1109
(self._host, port, err_msg))
1110
# Initialize err in case there are no addresses returned:
1111
err = socket.error("no address found for %s" % self._host)
1112
for (family, socktype, proto, canonname, sockaddr) in sockaddrs:
1114
self._socket = socket.socket(family, socktype, proto)
1115
self._socket.setsockopt(socket.IPPROTO_TCP,
1116
socket.TCP_NODELAY, 1)
1117
self._socket.connect(sockaddr)
1118
except socket.error, err:
1119
if self._socket is not None:
1120
self._socket.close()
1124
if self._socket is None:
1125
# socket errors either have a (string) or (errno, string) as their
1127
if type(err.args) is str:
1130
err_msg = err.args[1]
537
1131
raise errors.ConnectionError("failed to connect to %s:%d: %s" %
538
(self._host, self._port, os.strerror(result)))
539
self._connected = True
542
"""See SmartClientStreamMedium._flush().
544
For TCP we do no flushing. We may want to turn off TCP_NODELAY and
545
add a means to do a flush, but that can be done in the future.
548
def _read_bytes(self, count):
549
"""See SmartClientMedium.read_bytes."""
550
if not self._connected:
551
raise errors.MediumNotConnected(self)
552
return self._socket.recv(count)
1132
(self._host, port, err_msg))
1133
self._connected = True
1134
for hook in transport.Transport.hooks["post_connect"]:
1138
class SmartClientAlreadyConnectedSocketMedium(SmartClientSocketMedium):
1139
"""A client medium for an already connected socket.
1141
Note that this class will assume it "owns" the socket, so it will close it
1142
when its disconnect method is called.
1145
def __init__(self, base, sock):
1146
SmartClientSocketMedium.__init__(self, base)
1148
self._connected = True
1150
def _ensure_connection(self):
1151
# Already connected, by definition! So nothing to do.
555
1155
class SmartClientStreamMediumRequest(SmartClientMediumRequest):