1
# Copyright (C) 2006-2011 Canonical Ltd
3
# This program is free software; you can redistribute it and/or modify
4
# it under the terms of the GNU General Public License as published by
5
# the Free Software Foundation; either version 2 of the License, or
6
# (at your option) any later version.
8
# This program is distributed in the hope that it will be useful,
9
# but WITHOUT ANY WARRANTY; without even the implied warranty of
10
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11
# GNU General Public License for more details.
13
# You should have received a copy of the GNU General Public License
14
# along with this program; if not, write to the Free Software
15
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17
"""The 'medium' layer for the smart servers and clients.
19
"Medium" here is the noun meaning "a means of transmission", not the adjective
20
for "the quality between big and small."
22
Media carry the bytes of the requests somehow (e.g. via TCP, wrapped in HTTP, or
23
over SSH), and pass them to and from the protocol logic. See the overview in
24
bzrlib/transport/smart/__init__.py.
34
from bzrlib.lazy_import import lazy_import
35
lazy_import(globals(), """
48
from bzrlib.i18n import gettext
49
from bzrlib.smart import client, protocol, request, signals, vfs
50
from bzrlib.transport import ssh
52
from bzrlib import osutils
54
# Throughout this module buffer size parameters are either limited to be at
55
# most _MAX_READ_SIZE, or are ignored and _MAX_READ_SIZE is used instead.
56
# For this module's purposes, MAX_SOCKET_CHUNK is a reasonable size for reads
57
# from non-sockets as well.
58
_MAX_READ_SIZE = osutils.MAX_SOCKET_CHUNK
60
def _get_protocol_factory_for_bytes(bytes):
61
"""Determine the right protocol factory for 'bytes'.
63
This will return an appropriate protocol factory depending on the version
64
of the protocol being used, as determined by inspecting the given bytes.
65
The bytes should have at least one newline byte (i.e. be a whole line),
66
otherwise it's possible that a request will be incorrectly identified as
69
Typical use would be::
71
factory, unused_bytes = _get_protocol_factory_for_bytes(bytes)
72
server_protocol = factory(transport, write_func, root_client_path)
73
server_protocol.accept_bytes(unused_bytes)
75
:param bytes: a str of bytes of the start of the request.
76
:returns: 2-tuple of (protocol_factory, unused_bytes). protocol_factory is
77
a callable that takes three args: transport, write_func,
78
root_client_path. unused_bytes are any bytes that were not part of a
79
protocol version marker.
81
if bytes.startswith(protocol.MESSAGE_VERSION_THREE):
82
protocol_factory = protocol.build_server_protocol_three
83
bytes = bytes[len(protocol.MESSAGE_VERSION_THREE):]
84
elif bytes.startswith(protocol.REQUEST_VERSION_TWO):
85
protocol_factory = protocol.SmartServerRequestProtocolTwo
86
bytes = bytes[len(protocol.REQUEST_VERSION_TWO):]
88
protocol_factory = protocol.SmartServerRequestProtocolOne
89
return protocol_factory, bytes
92
def _get_line(read_bytes_func):
93
"""Read bytes using read_bytes_func until a newline byte.
95
This isn't particularly efficient, so should only be used when the
96
expected size of the line is quite short.
98
:returns: a tuple of two strs: (line, excess)
102
while newline_pos == -1:
103
new_bytes = read_bytes_func(1)
106
# Ran out of bytes before receiving a complete line.
108
newline_pos = bytes.find('\n')
109
line = bytes[:newline_pos+1]
110
excess = bytes[newline_pos+1:]
114
class SmartMedium(object):
115
"""Base class for smart protocol media, both client- and server-side."""
118
self._push_back_buffer = None
120
def _push_back(self, bytes):
121
"""Return unused bytes to the medium, because they belong to the next
124
This sets the _push_back_buffer to the given bytes.
126
if self._push_back_buffer is not None:
127
raise AssertionError(
128
"_push_back called when self._push_back_buffer is %r"
129
% (self._push_back_buffer,))
132
self._push_back_buffer = bytes
134
def _get_push_back_buffer(self):
135
if self._push_back_buffer == '':
136
raise AssertionError(
137
'%s._push_back_buffer should never be the empty string, '
138
'which can be confused with EOF' % (self,))
139
bytes = self._push_back_buffer
140
self._push_back_buffer = None
143
def read_bytes(self, desired_count):
144
"""Read some bytes from this medium.
146
:returns: some bytes, possibly more or less than the number requested
147
in 'desired_count' depending on the medium.
149
if self._push_back_buffer is not None:
150
return self._get_push_back_buffer()
151
bytes_to_read = min(desired_count, _MAX_READ_SIZE)
152
return self._read_bytes(bytes_to_read)
154
def _read_bytes(self, count):
155
raise NotImplementedError(self._read_bytes)
158
"""Read bytes from this request's response until a newline byte.
160
This isn't particularly efficient, so should only be used when the
161
expected size of the line is quite short.
163
:returns: a string of bytes ending in a newline (byte 0x0A).
165
line, excess = _get_line(self.read_bytes)
166
self._push_back(excess)
169
def _report_activity(self, bytes, direction):
170
"""Notify that this medium has activity.
172
Implementations should call this from all methods that actually do IO.
173
Be careful that it's not called twice, if one method is implemented on
176
:param bytes: Number of bytes read or written.
177
:param direction: 'read' or 'write' or None.
179
ui.ui_factory.report_transport_activity(self, bytes, direction)
182
_bad_file_descriptor = (errno.EBADF,)
183
if sys.platform == 'win32':
184
# Given on Windows if you pass a closed socket to select.select. Probably
185
# also given if you pass a file handle to select.
187
_bad_file_descriptor += (WSAENOTSOCK,)
190
class SmartServerStreamMedium(SmartMedium):
191
"""Handles smart commands coming over a stream.
193
The stream may be a pipe connected to sshd, or a tcp socket, or an
194
in-process fifo for testing.
196
One instance is created for each connected client; it can serve multiple
197
requests in the lifetime of the connection.
199
The server passes requests through to an underlying backing transport,
200
which will typically be a LocalTransport looking at the server's filesystem.
202
:ivar _push_back_buffer: a str of bytes that have been read from the stream
203
but not used yet, or None if there are no buffered bytes. Subclasses
204
should make sure to exhaust this buffer before reading more bytes from
205
the stream. See also the _push_back method.
210
def __init__(self, backing_transport, root_client_path='/', timeout=None):
211
"""Construct new server.
213
:param backing_transport: Transport for the directory served.
215
# backing_transport could be passed to serve instead of __init__
216
self.backing_transport = backing_transport
217
self.root_client_path = root_client_path
218
self.finished = False
220
raise AssertionError('You must supply a timeout.')
221
self._client_timeout = timeout
222
self._client_poll_timeout = min(timeout / 10.0, 1.0)
223
SmartMedium.__init__(self)
226
"""Serve requests until the client disconnects."""
227
# Keep a reference to stderr because the sys module's globals get set to
228
# None during interpreter shutdown.
229
from sys import stderr
231
while not self.finished:
232
server_protocol = self._build_protocol()
233
self._serve_one_request(server_protocol)
234
except errors.ConnectionTimeout, e:
235
trace.note('%s' % (e,))
236
trace.log_exception_quietly()
237
self._disconnect_client()
238
# We reported it, no reason to make a big fuss.
241
stderr.write("%s terminating on exception %s\n" % (self, e))
243
self._disconnect_client()
245
def _stop_gracefully(self):
246
"""When we finish this message, stop looking for more."""
247
trace.mutter('Stopping %s' % (self,))
250
def _disconnect_client(self):
251
"""Close the current connection. We stopped due to a timeout/etc."""
252
# The default implementation is a no-op, because that is all we used to
253
# do when disconnecting from a client. I suppose we never had the
254
# *server* initiate a disconnect, before
256
def _wait_for_bytes_with_timeout(self, timeout_seconds):
257
"""Wait for more bytes to be read, but timeout if none available.
259
This allows us to detect idle connections, and stop trying to read from
260
them, without setting the socket itself to non-blocking. This also
261
allows us to specify when we watch for idle timeouts.
263
:return: Did we timeout? (True if we timed out, False if there is data
266
raise NotImplementedError(self._wait_for_bytes_with_timeout)
268
def _build_protocol(self):
269
"""Identifies the version of the incoming request, and returns an
270
a protocol object that can interpret it.
272
If more bytes than the version prefix of the request are read, they will
273
be fed into the protocol before it is returned.
275
:returns: a SmartServerRequestProtocol.
277
self._wait_for_bytes_with_timeout(self._client_timeout)
279
# We're stopping, so don't try to do any more work
281
bytes = self._get_line()
282
protocol_factory, unused_bytes = _get_protocol_factory_for_bytes(bytes)
283
protocol = protocol_factory(
284
self.backing_transport, self._write_out, self.root_client_path)
285
protocol.accept_bytes(unused_bytes)
288
def _wait_on_descriptor(self, fd, timeout_seconds):
289
"""select() on a file descriptor, waiting for nonblocking read()
291
This will raise a ConnectionTimeout exception if we do not get a
292
readable handle before timeout_seconds.
295
t_end = self._timer() + timeout_seconds
296
poll_timeout = min(timeout_seconds, self._client_poll_timeout)
298
while not rs and not xs and self._timer() < t_end:
302
rs, _, xs = select.select([fd], [], [fd], poll_timeout)
303
except (select.error, socket.error) as e:
304
err = getattr(e, 'errno', None)
305
if err is None and getattr(e, 'args', None) is not None:
306
# select.error doesn't have 'errno', it just has args[0]
308
if err in _bad_file_descriptor:
309
return # Not a socket indicates read() will fail
310
elif err == errno.EINTR:
311
# Interrupted, keep looping.
316
raise errors.ConnectionTimeout('disconnecting client after %.1f seconds'
317
% (timeout_seconds,))
319
def _serve_one_request(self, protocol):
320
"""Read one request from input, process, send back a response.
322
:param protocol: a SmartServerRequestProtocol.
327
self._serve_one_request_unguarded(protocol)
328
except KeyboardInterrupt:
331
self.terminate_due_to_error()
333
def terminate_due_to_error(self):
334
"""Called when an unhandled exception from the protocol occurs."""
335
raise NotImplementedError(self.terminate_due_to_error)
337
def _read_bytes(self, desired_count):
338
"""Get some bytes from the medium.
340
:param desired_count: number of bytes we want to read.
342
raise NotImplementedError(self._read_bytes)
345
class SmartServerSocketStreamMedium(SmartServerStreamMedium):
347
def __init__(self, sock, backing_transport, root_client_path='/',
351
:param sock: the socket the server will read from. It will be put
354
SmartServerStreamMedium.__init__(
355
self, backing_transport, root_client_path=root_client_path,
357
sock.setblocking(True)
359
# Get the getpeername now, as we might be closed later when we care.
361
self._client_info = sock.getpeername()
363
self._client_info = '<unknown>'
366
return '%s(client=%s)' % (self.__class__.__name__, self._client_info)
369
return '%s.%s(client=%s)' % (self.__module__, self.__class__.__name__,
372
def _serve_one_request_unguarded(self, protocol):
373
while protocol.next_read_size():
374
# We can safely try to read large chunks. If there is less data
375
# than MAX_SOCKET_CHUNK ready, the socket will just return a
376
# short read immediately rather than block.
377
bytes = self.read_bytes(osutils.MAX_SOCKET_CHUNK)
381
protocol.accept_bytes(bytes)
383
self._push_back(protocol.unused_data)
385
def _disconnect_client(self):
386
"""Close the current connection. We stopped due to a timeout/etc."""
389
def _wait_for_bytes_with_timeout(self, timeout_seconds):
390
"""Wait for more bytes to be read, but timeout if none available.
392
This allows us to detect idle connections, and stop trying to read from
393
them, without setting the socket itself to non-blocking. This also
394
allows us to specify when we watch for idle timeouts.
396
:return: None, this will raise ConnectionTimeout if we time out before
399
return self._wait_on_descriptor(self.socket, timeout_seconds)
401
def _read_bytes(self, desired_count):
402
return osutils.read_bytes_from_socket(
403
self.socket, self._report_activity)
405
def terminate_due_to_error(self):
406
# TODO: This should log to a server log file, but no such thing
407
# exists yet. Andrew Bennetts 2006-09-29.
411
def _write_out(self, bytes):
412
tstart = osutils.timer_func()
413
osutils.send_all(self.socket, bytes, self._report_activity)
414
if 'hpss' in debug.debug_flags:
415
thread_id = thread.get_ident()
416
trace.mutter('%12s: [%s] %d bytes to the socket in %.3fs'
417
% ('wrote', thread_id, len(bytes),
418
osutils.timer_func() - tstart))
421
class SmartServerPipeStreamMedium(SmartServerStreamMedium):
423
def __init__(self, in_file, out_file, backing_transport, timeout=None):
424
"""Construct new server.
426
:param in_file: Python file from which requests can be read.
427
:param out_file: Python file to write responses.
428
:param backing_transport: Transport for the directory served.
430
SmartServerStreamMedium.__init__(self, backing_transport,
432
if sys.platform == 'win32':
433
# force binary mode for files
435
for f in (in_file, out_file):
436
fileno = getattr(f, 'fileno', None)
438
msvcrt.setmode(fileno(), os.O_BINARY)
443
"""See SmartServerStreamMedium.serve"""
444
# This is the regular serve, except it adds signal trapping for soft
446
stop_gracefully = self._stop_gracefully
447
signals.register_on_hangup(id(self), stop_gracefully)
449
return super(SmartServerPipeStreamMedium, self).serve()
451
signals.unregister_on_hangup(id(self))
453
def _serve_one_request_unguarded(self, protocol):
455
# We need to be careful not to read past the end of the current
456
# request, or else the read from the pipe will block, so we use
457
# protocol.next_read_size().
458
bytes_to_read = protocol.next_read_size()
459
if bytes_to_read == 0:
460
# Finished serving this request.
463
bytes = self.read_bytes(bytes_to_read)
465
# Connection has been closed.
469
protocol.accept_bytes(bytes)
471
def _disconnect_client(self):
476
def _wait_for_bytes_with_timeout(self, timeout_seconds):
477
"""Wait for more bytes to be read, but timeout if none available.
479
This allows us to detect idle connections, and stop trying to read from
480
them, without setting the socket itself to non-blocking. This also
481
allows us to specify when we watch for idle timeouts.
483
:return: None, this will raise ConnectionTimeout if we time out before
486
if (getattr(self._in, 'fileno', None) is None
487
or sys.platform == 'win32'):
488
# You can't select() file descriptors on Windows.
490
return self._wait_on_descriptor(self._in, timeout_seconds)
492
def _read_bytes(self, desired_count):
493
return self._in.read(desired_count)
495
def terminate_due_to_error(self):
496
# TODO: This should log to a server log file, but no such thing
497
# exists yet. Andrew Bennetts 2006-09-29.
501
def _write_out(self, bytes):
502
self._out.write(bytes)
505
class SmartClientMediumRequest(object):
506
"""A request on a SmartClientMedium.
508
Each request allows bytes to be provided to it via accept_bytes, and then
509
the response bytes to be read via read_bytes.
512
request.accept_bytes('123')
513
request.finished_writing()
514
result = request.read_bytes(3)
515
request.finished_reading()
517
It is up to the individual SmartClientMedium whether multiple concurrent
518
requests can exist. See SmartClientMedium.get_request to obtain instances
519
of SmartClientMediumRequest, and the concrete Medium you are using for
520
details on concurrency and pipelining.
523
def __init__(self, medium):
524
"""Construct a SmartClientMediumRequest for the medium medium."""
525
self._medium = medium
526
# we track state by constants - we may want to use the same
527
# pattern as BodyReader if it gets more complex.
528
# valid states are: "writing", "reading", "done"
529
self._state = "writing"
531
def accept_bytes(self, bytes):
532
"""Accept bytes for inclusion in this request.
534
This method may not be called after finished_writing() has been
535
called. It depends upon the Medium whether or not the bytes will be
536
immediately transmitted. Message based Mediums will tend to buffer the
537
bytes until finished_writing() is called.
539
:param bytes: A bytestring.
541
if self._state != "writing":
542
raise errors.WritingCompleted(self)
543
self._accept_bytes(bytes)
545
def _accept_bytes(self, bytes):
546
"""Helper for accept_bytes.
548
Accept_bytes checks the state of the request to determing if bytes
549
should be accepted. After that it hands off to _accept_bytes to do the
552
raise NotImplementedError(self._accept_bytes)
554
def finished_reading(self):
555
"""Inform the request that all desired data has been read.
557
This will remove the request from the pipeline for its medium (if the
558
medium supports pipelining) and any further calls to methods on the
559
request will raise ReadingCompleted.
561
if self._state == "writing":
562
raise errors.WritingNotComplete(self)
563
if self._state != "reading":
564
raise errors.ReadingCompleted(self)
566
self._finished_reading()
568
def _finished_reading(self):
569
"""Helper for finished_reading.
571
finished_reading checks the state of the request to determine if
572
finished_reading is allowed, and if it is hands off to _finished_reading
573
to perform the action.
575
raise NotImplementedError(self._finished_reading)
577
def finished_writing(self):
578
"""Finish the writing phase of this request.
580
This will flush all pending data for this request along the medium.
581
After calling finished_writing, you may not call accept_bytes anymore.
583
if self._state != "writing":
584
raise errors.WritingCompleted(self)
585
self._state = "reading"
586
self._finished_writing()
588
def _finished_writing(self):
589
"""Helper for finished_writing.
591
finished_writing checks the state of the request to determine if
592
finished_writing is allowed, and if it is hands off to _finished_writing
593
to perform the action.
595
raise NotImplementedError(self._finished_writing)
597
def read_bytes(self, count):
598
"""Read bytes from this requests response.
600
This method will block and wait for count bytes to be read. It may not
601
be invoked until finished_writing() has been called - this is to ensure
602
a message-based approach to requests, for compatibility with message
603
based mediums like HTTP.
605
if self._state == "writing":
606
raise errors.WritingNotComplete(self)
607
if self._state != "reading":
608
raise errors.ReadingCompleted(self)
609
return self._read_bytes(count)
611
def _read_bytes(self, count):
612
"""Helper for SmartClientMediumRequest.read_bytes.
614
read_bytes checks the state of the request to determing if bytes
615
should be read. After that it hands off to _read_bytes to do the
618
By default this forwards to self._medium.read_bytes because we are
619
operating on the medium's stream.
621
return self._medium.read_bytes(count)
624
line = self._read_line()
625
if not line.endswith('\n'):
626
# end of file encountered reading from server
627
raise errors.ConnectionReset(
628
"Unexpected end of message. Please check connectivity "
629
"and permissions, and report a bug if problems persist.")
632
def _read_line(self):
633
"""Helper for SmartClientMediumRequest.read_line.
635
By default this forwards to self._medium._get_line because we are
636
operating on the medium's stream.
638
return self._medium._get_line()
641
class _VfsRefuser(object):
642
"""An object that refuses all VFS requests.
647
client._SmartClient.hooks.install_named_hook(
648
'call', self.check_vfs, 'vfs refuser')
650
def check_vfs(self, params):
652
request_method = request.request_handlers.get(params.method)
654
# A method we don't know about doesn't count as a VFS method.
656
if issubclass(request_method, vfs.VfsRequest):
657
raise errors.HpssVfsRequestNotAllowed(params.method, params.args)
660
class _DebugCounter(object):
661
"""An object that counts the HPSS calls made to each client medium.
663
When a medium is garbage-collected, or failing that when
664
bzrlib.global_state exits, the total number of calls made on that medium
665
are reported via trace.note.
669
self.counts = weakref.WeakKeyDictionary()
670
client._SmartClient.hooks.install_named_hook(
671
'call', self.increment_call_count, 'hpss call counter')
672
bzrlib.global_state.cleanups.add_cleanup(self.flush_all)
674
def track(self, medium):
675
"""Start tracking calls made to a medium.
677
This only keeps a weakref to the medium, so shouldn't affect the
680
medium_repr = repr(medium)
681
# Add this medium to the WeakKeyDictionary
682
self.counts[medium] = dict(count=0, vfs_count=0,
683
medium_repr=medium_repr)
684
# Weakref callbacks are fired in reverse order of their association
685
# with the referenced object. So we add a weakref *after* adding to
686
# the WeakKeyDict so that we can report the value from it before the
687
# entry is removed by the WeakKeyDict's own callback.
688
ref = weakref.ref(medium, self.done)
690
def increment_call_count(self, params):
691
# Increment the count in the WeakKeyDictionary
692
value = self.counts[params.medium]
695
request_method = request.request_handlers.get(params.method)
697
# A method we don't know about doesn't count as a VFS method.
699
if issubclass(request_method, vfs.VfsRequest):
700
value['vfs_count'] += 1
703
value = self.counts[ref]
704
count, vfs_count, medium_repr = (
705
value['count'], value['vfs_count'], value['medium_repr'])
706
# In case this callback is invoked for the same ref twice (by the
707
# weakref callback and by the atexit function), set the call count back
708
# to 0 so this item won't be reported twice.
710
value['vfs_count'] = 0
712
trace.note(gettext('HPSS calls: {0} ({1} vfs) {2}').format(
713
count, vfs_count, medium_repr))
716
for ref in list(self.counts.keys()):
719
_debug_counter = None
723
class SmartClientMedium(SmartMedium):
724
"""Smart client is a medium for sending smart protocol requests over."""
726
def __init__(self, base):
727
super(SmartClientMedium, self).__init__()
729
self._protocol_version_error = None
730
self._protocol_version = None
731
self._done_hello = False
732
# Be optimistic: we assume the remote end can accept new remote
733
# requests until we get an error saying otherwise.
734
# _remote_version_is_before tracks the bzr version the remote side
735
# can be based on what we've seen so far.
736
self._remote_version_is_before = None
737
# Install debug hook function if debug flag is set.
738
if 'hpss' in debug.debug_flags:
739
global _debug_counter
740
if _debug_counter is None:
741
_debug_counter = _DebugCounter()
742
_debug_counter.track(self)
743
if 'hpss_client_no_vfs' in debug.debug_flags:
745
if _vfs_refuser is None:
746
_vfs_refuser = _VfsRefuser()
748
def _is_remote_before(self, version_tuple):
749
"""Is it possible the remote side supports RPCs for a given version?
753
needed_version = (1, 2)
754
if medium._is_remote_before(needed_version):
755
fallback_to_pre_1_2_rpc()
759
except UnknownSmartMethod:
760
medium._remember_remote_is_before(needed_version)
761
fallback_to_pre_1_2_rpc()
763
:seealso: _remember_remote_is_before
765
if self._remote_version_is_before is None:
766
# So far, the remote side seems to support everything
768
return version_tuple >= self._remote_version_is_before
770
def _remember_remote_is_before(self, version_tuple):
771
"""Tell this medium that the remote side is older the given version.
773
:seealso: _is_remote_before
775
if (self._remote_version_is_before is not None and
776
version_tuple > self._remote_version_is_before):
777
# We have been told that the remote side is older than some version
778
# which is newer than a previously supplied older-than version.
779
# This indicates that some smart verb call is not guarded
780
# appropriately (it should simply not have been tried).
782
"_remember_remote_is_before(%r) called, but "
783
"_remember_remote_is_before(%r) was called previously."
784
, version_tuple, self._remote_version_is_before)
785
if 'hpss' in debug.debug_flags:
786
ui.ui_factory.show_warning(
787
"_remember_remote_is_before(%r) called, but "
788
"_remember_remote_is_before(%r) was called previously."
789
% (version_tuple, self._remote_version_is_before))
791
self._remote_version_is_before = version_tuple
793
def protocol_version(self):
794
"""Find out if 'hello' smart request works."""
795
if self._protocol_version_error is not None:
796
raise self._protocol_version_error
797
if not self._done_hello:
799
medium_request = self.get_request()
800
# Send a 'hello' request in protocol version one, for maximum
801
# backwards compatibility.
802
client_protocol = protocol.SmartClientRequestProtocolOne(medium_request)
803
client_protocol.query_version()
804
self._done_hello = True
805
except errors.SmartProtocolError, e:
806
# Cache the error, just like we would cache a successful
808
self._protocol_version_error = e
812
def should_probe(self):
813
"""Should RemoteBzrDirFormat.probe_transport send a smart request on
816
Some transports are unambiguously smart-only; there's no need to check
817
if the transport is able to carry smart requests, because that's all
818
it is for. In those cases, this method should return False.
820
But some HTTP transports can sometimes fail to carry smart requests,
821
but still be usuable for accessing remote bzrdirs via plain file
822
accesses. So for those transports, their media should return True here
823
so that RemoteBzrDirFormat can determine if it is appropriate for that
828
def disconnect(self):
829
"""If this medium maintains a persistent connection, close it.
831
The default implementation does nothing.
834
def remote_path_from_transport(self, transport):
835
"""Convert transport into a path suitable for using in a request.
837
Note that the resulting remote path doesn't encode the host name or
838
anything but path, so it is only safe to use it in requests sent over
839
the medium from the matching transport.
841
medium_base = urlutils.join(self.base, '/')
842
rel_url = urlutils.relative_url(medium_base, transport.base)
843
return urllib.unquote(rel_url)
846
class SmartClientStreamMedium(SmartClientMedium):
847
"""Stream based medium common class.
849
SmartClientStreamMediums operate on a stream. All subclasses use a common
850
SmartClientStreamMediumRequest for their requests, and should implement
851
_accept_bytes and _read_bytes to allow the request objects to send and
855
def __init__(self, base):
856
SmartClientMedium.__init__(self, base)
857
self._current_request = None
859
def accept_bytes(self, bytes):
860
self._accept_bytes(bytes)
863
"""The SmartClientStreamMedium knows how to close the stream when it is
869
"""Flush the output stream.
871
This method is used by the SmartClientStreamMediumRequest to ensure that
872
all data for a request is sent, to avoid long timeouts or deadlocks.
874
raise NotImplementedError(self._flush)
876
def get_request(self):
877
"""See SmartClientMedium.get_request().
879
SmartClientStreamMedium always returns a SmartClientStreamMediumRequest
882
return SmartClientStreamMediumRequest(self)
885
class SmartSimplePipesClientMedium(SmartClientStreamMedium):
886
"""A client medium using simple pipes.
888
This client does not manage the pipes: it assumes they will always be open.
891
def __init__(self, readable_pipe, writeable_pipe, base):
892
SmartClientStreamMedium.__init__(self, base)
893
self._readable_pipe = readable_pipe
894
self._writeable_pipe = writeable_pipe
896
def _accept_bytes(self, bytes):
897
"""See SmartClientStreamMedium.accept_bytes."""
898
self._writeable_pipe.write(bytes)
899
self._report_activity(len(bytes), 'write')
902
"""See SmartClientStreamMedium._flush()."""
903
self._writeable_pipe.flush()
905
def _read_bytes(self, count):
906
"""See SmartClientStreamMedium._read_bytes."""
907
bytes_to_read = min(count, _MAX_READ_SIZE)
908
bytes = self._readable_pipe.read(bytes_to_read)
909
self._report_activity(len(bytes), 'read')
913
class SSHParams(object):
914
"""A set of parameters for starting a remote bzr via SSH."""
916
def __init__(self, host, port=None, username=None, password=None,
917
bzr_remote_path='bzr'):
920
self.username = username
921
self.password = password
922
self.bzr_remote_path = bzr_remote_path
925
class SmartSSHClientMedium(SmartClientStreamMedium):
926
"""A client medium using SSH.
928
It delegates IO to a SmartClientSocketMedium or
929
SmartClientAlreadyConnectedSocketMedium (depending on platform).
932
def __init__(self, base, ssh_params, vendor=None):
933
"""Creates a client that will connect on the first use.
935
:param ssh_params: A SSHParams instance.
936
:param vendor: An optional override for the ssh vendor to use. See
937
bzrlib.transport.ssh for details on ssh vendors.
939
self._real_medium = None
940
self._ssh_params = ssh_params
941
# for the benefit of progress making a short description of this
943
self._scheme = 'bzr+ssh'
944
# SmartClientStreamMedium stores the repr of this object in its
945
# _DebugCounter so we have to store all the values used in our repr
946
# method before calling the super init.
947
SmartClientStreamMedium.__init__(self, base)
948
self._vendor = vendor
949
self._ssh_connection = None
952
if self._ssh_params.port is None:
955
maybe_port = ':%s' % self._ssh_params.port
956
return "%s(%s://%s@%s%s/)" % (
957
self.__class__.__name__,
959
self._ssh_params.username,
960
self._ssh_params.host,
963
def _accept_bytes(self, bytes):
964
"""See SmartClientStreamMedium.accept_bytes."""
965
self._ensure_connection()
966
self._real_medium.accept_bytes(bytes)
968
def disconnect(self):
969
"""See SmartClientMedium.disconnect()."""
970
if self._real_medium is not None:
971
self._real_medium.disconnect()
972
self._real_medium = None
973
if self._ssh_connection is not None:
974
self._ssh_connection.close()
975
self._ssh_connection = None
977
def _ensure_connection(self):
978
"""Connect this medium if not already connected."""
979
if self._real_medium is not None:
981
if self._vendor is None:
982
vendor = ssh._get_ssh_vendor()
984
vendor = self._vendor
985
self._ssh_connection = vendor.connect_ssh(self._ssh_params.username,
986
self._ssh_params.password, self._ssh_params.host,
987
self._ssh_params.port,
988
command=[self._ssh_params.bzr_remote_path, 'serve', '--inet',
989
'--directory=/', '--allow-writes'])
990
io_kind, io_object = self._ssh_connection.get_sock_or_pipes()
991
if io_kind == 'socket':
992
self._real_medium = SmartClientAlreadyConnectedSocketMedium(
993
self.base, io_object)
994
elif io_kind == 'pipes':
995
read_from, write_to = io_object
996
self._real_medium = SmartSimplePipesClientMedium(
997
read_from, write_to, self.base)
999
raise AssertionError(
1000
"Unexpected io_kind %r from %r"
1001
% (io_kind, self._ssh_connection))
1004
"""See SmartClientStreamMedium._flush()."""
1005
self._real_medium._flush()
1007
def _read_bytes(self, count):
1008
"""See SmartClientStreamMedium.read_bytes."""
1009
if self._real_medium is None:
1010
raise errors.MediumNotConnected(self)
1011
return self._real_medium.read_bytes(count)
1014
# Port 4155 is the default port for bzr://, registered with IANA.
1015
BZR_DEFAULT_INTERFACE = None
1016
BZR_DEFAULT_PORT = 4155
class SmartClientSocketMedium(SmartClientStreamMedium):
    """A client medium using a socket.

    This class isn't usable directly. Use one of its subclasses instead.
    """

    def __init__(self, base):
        SmartClientStreamMedium.__init__(self, base)
        # The connected socket. It is None until a subclass provides one,
        # and it is reset to None by disconnect().
        self._socket = None
        self._connected = False

    def _accept_bytes(self, bytes):
        """See SmartClientMedium.accept_bytes."""
        self._ensure_connection()
        osutils.send_all(self._socket, bytes, self._report_activity)

    def _ensure_connection(self):
        """Connect this medium if not already connected."""
        raise NotImplementedError(self._ensure_connection)

    def _flush(self):
        """See SmartClientStreamMedium._flush().

        For sockets we do no flushing. For TCP sockets we may want to turn off
        TCP_NODELAY and add a means to do a flush, but that can be done in the
        future.
        """

    def _read_bytes(self, count):
        """See SmartClientMedium.read_bytes.

        ``count`` is deliberately not forwarded. See the module-level note on
        _MAX_READ_SIZE.

        :raises MediumNotConnected: if the medium is not connected.
        """
        if not self._connected:
            raise errors.MediumNotConnected(self)
        return osutils.read_bytes_from_socket(
            self._socket, self._report_activity)

    def disconnect(self):
        """See SmartClientMedium.disconnect()."""
        if not self._connected:
            return
        sock = self._socket
        # Reset state before closing, so the medium is marked disconnected
        # even if close() raises.
        self._socket = None
        self._connected = False
        sock.close()
class SmartTCPClientMedium(SmartClientSocketMedium):
    """A client medium that creates a TCP connection."""

    def __init__(self, host, port, base):
        """Creates a client that will connect on the first use."""
        SmartClientSocketMedium.__init__(self, base)
        self._host = host
        # None selects BZR_DEFAULT_PORT at connect time. Any other value is
        # passed through int().
        self._port = port
        self._socket = None

    @staticmethod
    def _socket_error_message(err):
        """Return the human-readable message carried by socket error ``err``.

        Socket errors carry either ``(message,)`` or ``(errno, message)`` as
        their args. ``args`` is always a tuple, so the decision is made on
        its length.
        """
        args = err.args
        if len(args) >= 2:
            return args[1]
        if args:
            return args[0]
        return str(err)

    def _ensure_connection(self):
        """Connect this medium if not already connected.

        Tries each address that getaddrinfo returns for the host and port,
        in order, and keeps the first socket that connects.

        :raises errors.ConnectionError: if the host cannot be resolved, or if
            no resolved address accepts a connection.
        """
        if self._connected:
            return
        if self._port is None:
            port = BZR_DEFAULT_PORT
        else:
            port = int(self._port)
        try:
            sockaddrs = socket.getaddrinfo(self._host, port, socket.AF_UNSPEC,
                socket.SOCK_STREAM, 0, 0)
        except socket.gaierror as e:
            raise errors.ConnectionError("failed to lookup %s:%d: %s" %
                (self._host, port, self._socket_error_message(e)))
        # This error is reported if getaddrinfo returns no addresses at all.
        last_err = socket.error("no address found for %s" % self._host)
        sock = None
        for (family, socktype, proto, canonname, sockaddr) in sockaddrs:
            try:
                sock = socket.socket(family, socktype, proto)
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
                sock.connect(sockaddr)
            except socket.error as e:
                # Keep our own reference to the error: Python 3 unbinds the
                # 'as' target when the handler exits.
                last_err = e
                if sock is not None:
                    sock.close()
                    sock = None
                continue
            break
        if sock is None:
            raise errors.ConnectionError("failed to connect to %s:%d: %s" %
                (self._host, port, self._socket_error_message(last_err)))
        self._socket = sock
        self._connected = True
class SmartClientAlreadyConnectedSocketMedium(SmartClientSocketMedium):
    """A client medium for an already connected socket.

    Note that this class will assume it "owns" the socket, so it will close it
    when its disconnect method is called.
    """

    def __init__(self, base, sock):
        SmartClientSocketMedium.__init__(self, base)
        self._socket = sock
        self._connected = True

    def _ensure_connection(self):
        """Check that the socket handed to the constructor is still open.

        The medium is connected at construction. It has no way to re-create
        the socket, so once disconnect() has been called it cannot reconnect.

        :raises MediumNotConnected: if disconnect() has been called.
        """
        if not self._connected:
            raise errors.MediumNotConnected(self)
1129
class SmartClientStreamMediumRequest(SmartClientMediumRequest):
1130
"""A SmartClientMediumRequest that works with an SmartClientStreamMedium."""
def __init__(self, medium):
    """Register this request as ``medium``'s current request.

    :raises TooManyConcurrentRequests: if ``medium`` already has a current
        request.
    """
    SmartClientMediumRequest.__init__(self, medium)
    # Stream media carry one request at a time. If some streams start
    # allowing concurrent requests (e.g. via multiplexing), this check and
    # the setting/clearing of _current_request should move into
    # SmartClientStreamMedium.get_request. Doing that now would be unneeded
    # overhead. (RBC 20060922)
    busy = self._medium._current_request is not None
    if busy:
        raise errors.TooManyConcurrentRequests(self._medium)
    self._medium._current_request = self
def _accept_bytes(self, bytes):
    """See SmartClientMediumRequest._accept_bytes.

    A stream request shares its medium's stream, so the bytes go straight to
    the medium's own _accept_bytes.
    """
    medium = self._medium
    medium._accept_bytes(bytes)
def _finished_reading(self):
    """See SmartClientMediumRequest._finished_reading.

    Clears the medium's _current_request so that a new request can be
    created on it.

    :raises AssertionError: if this request is not the medium's current
        request.
    """
    medium = self._medium
    if medium._current_request is not self:
        raise AssertionError()
    medium._current_request = None
def _finished_writing(self):
    """See SmartClientMediumRequest._finished_writing.

    Calls the medium's _flush so that every byte written for this request
    is transmitted.
    """
    medium = self._medium
    medium._flush()