1
# Copyright (C) 2006-2011 Canonical Ltd
3
# This program is free software; you can redistribute it and/or modify
4
# it under the terms of the GNU General Public License as published by
5
# the Free Software Foundation; either version 2 of the License, or
6
# (at your option) any later version.
8
# This program is distributed in the hope that it will be useful,
9
# but WITHOUT ANY WARRANTY; without even the implied warranty of
10
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11
# GNU General Public License for more details.
13
# You should have received a copy of the GNU General Public License
14
# along with this program; if not, write to the Free Software
15
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17
"""The 'medium' layer for the smart servers and clients.
19
"Medium" here is the noun meaning "a means of transmission", not the adjective
20
for "the quality between big and small."
22
Media carry the bytes of the requests somehow (e.g. via TCP, wrapped in HTTP, or
23
over SSH), and pass them to and from the protocol logic. See the overview in
24
bzrlib/transport/smart/__init__.py.
27
from __future__ import absolute_import
35
from bzrlib.lazy_import import lazy_import
36
lazy_import(globals(), """
49
from bzrlib.i18n import gettext
50
from bzrlib.smart import client, protocol, request, signals, vfs
51
from bzrlib.transport import ssh
53
from bzrlib import osutils
55
# Throughout this module buffer size parameters are either limited to be at
56
# most _MAX_READ_SIZE, or are ignored and _MAX_READ_SIZE is used instead.
57
# For this module's purposes, MAX_SOCKET_CHUNK is a reasonable size for reads
58
# from non-sockets as well.
59
_MAX_READ_SIZE = osutils.MAX_SOCKET_CHUNK
61
def _get_protocol_factory_for_bytes(bytes):
62
"""Determine the right protocol factory for 'bytes'.
64
This will return an appropriate protocol factory depending on the version
65
of the protocol being used, as determined by inspecting the given bytes.
66
The bytes should have at least one newline byte (i.e. be a whole line),
67
otherwise it's possible that a request will be incorrectly identified as
70
Typical use would be::
72
factory, unused_bytes = _get_protocol_factory_for_bytes(bytes)
73
server_protocol = factory(transport, write_func, root_client_path)
74
server_protocol.accept_bytes(unused_bytes)
76
:param bytes: a str of bytes of the start of the request.
77
:returns: 2-tuple of (protocol_factory, unused_bytes). protocol_factory is
78
a callable that takes three args: transport, write_func,
79
root_client_path. unused_bytes are any bytes that were not part of a
80
protocol version marker.
82
if bytes.startswith(protocol.MESSAGE_VERSION_THREE):
83
protocol_factory = protocol.build_server_protocol_three
84
bytes = bytes[len(protocol.MESSAGE_VERSION_THREE):]
85
elif bytes.startswith(protocol.REQUEST_VERSION_TWO):
86
protocol_factory = protocol.SmartServerRequestProtocolTwo
87
bytes = bytes[len(protocol.REQUEST_VERSION_TWO):]
89
protocol_factory = protocol.SmartServerRequestProtocolOne
90
return protocol_factory, bytes
93
def _get_line(read_bytes_func):
94
"""Read bytes using read_bytes_func until a newline byte.
96
This isn't particularly efficient, so should only be used when the
97
expected size of the line is quite short.
99
:returns: a tuple of two strs: (line, excess)
103
while newline_pos == -1:
104
new_bytes = read_bytes_func(1)
107
# Ran out of bytes before receiving a complete line.
109
newline_pos = bytes.find('\n')
110
line = bytes[:newline_pos+1]
111
excess = bytes[newline_pos+1:]
115
class SmartMedium(object):
116
"""Base class for smart protocol media, both client- and server-side."""
119
self._push_back_buffer = None
121
def _push_back(self, bytes):
122
"""Return unused bytes to the medium, because they belong to the next
125
This sets the _push_back_buffer to the given bytes.
127
if self._push_back_buffer is not None:
128
raise AssertionError(
129
"_push_back called when self._push_back_buffer is %r"
130
% (self._push_back_buffer,))
133
self._push_back_buffer = bytes
135
def _get_push_back_buffer(self):
136
if self._push_back_buffer == '':
137
raise AssertionError(
138
'%s._push_back_buffer should never be the empty string, '
139
'which can be confused with EOF' % (self,))
140
bytes = self._push_back_buffer
141
self._push_back_buffer = None
144
def read_bytes(self, desired_count):
145
"""Read some bytes from this medium.
147
:returns: some bytes, possibly more or less than the number requested
148
in 'desired_count' depending on the medium.
150
if self._push_back_buffer is not None:
151
return self._get_push_back_buffer()
152
bytes_to_read = min(desired_count, _MAX_READ_SIZE)
153
return self._read_bytes(bytes_to_read)
155
def _read_bytes(self, count):
156
raise NotImplementedError(self._read_bytes)
159
"""Read bytes from this request's response until a newline byte.
161
This isn't particularly efficient, so should only be used when the
162
expected size of the line is quite short.
164
:returns: a string of bytes ending in a newline (byte 0x0A).
166
line, excess = _get_line(self.read_bytes)
167
self._push_back(excess)
170
def _report_activity(self, bytes, direction):
171
"""Notify that this medium has activity.
173
Implementations should call this from all methods that actually do IO.
174
Be careful that it's not called twice, if one method is implemented on
177
:param bytes: Number of bytes read or written.
178
:param direction: 'read' or 'write' or None.
180
ui.ui_factory.report_transport_activity(self, bytes, direction)
183
_bad_file_descriptor = (errno.EBADF,)
184
if sys.platform == 'win32':
185
# Given on Windows if you pass a closed socket to select.select. Probably
186
# also given if you pass a file handle to select.
188
_bad_file_descriptor += (WSAENOTSOCK,)
191
class SmartServerStreamMedium(SmartMedium):
192
"""Handles smart commands coming over a stream.
194
The stream may be a pipe connected to sshd, or a tcp socket, or an
195
in-process fifo for testing.
197
One instance is created for each connected client; it can serve multiple
198
requests in the lifetime of the connection.
200
The server passes requests through to an underlying backing transport,
201
which will typically be a LocalTransport looking at the server's filesystem.
203
:ivar _push_back_buffer: a str of bytes that have been read from the stream
204
but not used yet, or None if there are no buffered bytes. Subclasses
205
should make sure to exhaust this buffer before reading more bytes from
206
the stream. See also the _push_back method.
211
def __init__(self, backing_transport, root_client_path='/', timeout=None):
212
"""Construct new server.
214
:param backing_transport: Transport for the directory served.
216
# backing_transport could be passed to serve instead of __init__
217
self.backing_transport = backing_transport
218
self.root_client_path = root_client_path
219
self.finished = False
221
raise AssertionError('You must supply a timeout.')
222
self._client_timeout = timeout
223
self._client_poll_timeout = min(timeout / 10.0, 1.0)
224
SmartMedium.__init__(self)
227
"""Serve requests until the client disconnects."""
228
# Keep a reference to stderr because the sys module's globals get set to
229
# None during interpreter shutdown.
230
from sys import stderr
232
while not self.finished:
233
server_protocol = self._build_protocol()
234
self._serve_one_request(server_protocol)
235
except errors.ConnectionTimeout, e:
236
trace.note('%s' % (e,))
237
trace.log_exception_quietly()
238
self._disconnect_client()
239
# We reported it, no reason to make a big fuss.
242
stderr.write("%s terminating on exception %s\n" % (self, e))
244
self._disconnect_client()
246
def _stop_gracefully(self):
247
"""When we finish this message, stop looking for more."""
248
trace.mutter('Stopping %s' % (self,))
251
def _disconnect_client(self):
252
"""Close the current connection. We stopped due to a timeout/etc."""
253
# The default implementation is a no-op, because that is all we used to
254
# do when disconnecting from a client. I suppose we never had the
255
# *server* initiate a disconnect, before
257
def _wait_for_bytes_with_timeout(self, timeout_seconds):
258
"""Wait for more bytes to be read, but timeout if none available.
260
This allows us to detect idle connections, and stop trying to read from
261
them, without setting the socket itself to non-blocking. This also
262
allows us to specify when we watch for idle timeouts.
264
:return: Did we timeout? (True if we timed out, False if there is data
267
raise NotImplementedError(self._wait_for_bytes_with_timeout)
269
def _build_protocol(self):
270
"""Identifies the version of the incoming request, and returns an
271
a protocol object that can interpret it.
273
If more bytes than the version prefix of the request are read, they will
274
be fed into the protocol before it is returned.
276
:returns: a SmartServerRequestProtocol.
278
self._wait_for_bytes_with_timeout(self._client_timeout)
280
# We're stopping, so don't try to do any more work
282
bytes = self._get_line()
283
protocol_factory, unused_bytes = _get_protocol_factory_for_bytes(bytes)
284
protocol = protocol_factory(
285
self.backing_transport, self._write_out, self.root_client_path)
286
protocol.accept_bytes(unused_bytes)
289
def _wait_on_descriptor(self, fd, timeout_seconds):
290
"""select() on a file descriptor, waiting for nonblocking read()
292
This will raise a ConnectionTimeout exception if we do not get a
293
readable handle before timeout_seconds.
296
t_end = self._timer() + timeout_seconds
297
poll_timeout = min(timeout_seconds, self._client_poll_timeout)
299
while not rs and not xs and self._timer() < t_end:
303
rs, _, xs = select.select([fd], [], [fd], poll_timeout)
304
except (select.error, socket.error) as e:
305
err = getattr(e, 'errno', None)
306
if err is None and getattr(e, 'args', None) is not None:
307
# select.error doesn't have 'errno', it just has args[0]
309
if err in _bad_file_descriptor:
310
return # Not a socket indicates read() will fail
311
elif err == errno.EINTR:
312
# Interrupted, keep looping.
317
raise errors.ConnectionTimeout('disconnecting client after %.1f seconds'
318
% (timeout_seconds,))
320
def _serve_one_request(self, protocol):
321
"""Read one request from input, process, send back a response.
323
:param protocol: a SmartServerRequestProtocol.
328
self._serve_one_request_unguarded(protocol)
329
except KeyboardInterrupt:
332
self.terminate_due_to_error()
334
def terminate_due_to_error(self):
335
"""Called when an unhandled exception from the protocol occurs."""
336
raise NotImplementedError(self.terminate_due_to_error)
338
def _read_bytes(self, desired_count):
339
"""Get some bytes from the medium.
341
:param desired_count: number of bytes we want to read.
343
raise NotImplementedError(self._read_bytes)
346
class SmartServerSocketStreamMedium(SmartServerStreamMedium):
348
def __init__(self, sock, backing_transport, root_client_path='/',
352
:param sock: the socket the server will read from. It will be put
355
SmartServerStreamMedium.__init__(
356
self, backing_transport, root_client_path=root_client_path,
358
sock.setblocking(True)
360
# Get the getpeername now, as we might be closed later when we care.
362
self._client_info = sock.getpeername()
364
self._client_info = '<unknown>'
367
return '%s(client=%s)' % (self.__class__.__name__, self._client_info)
370
return '%s.%s(client=%s)' % (self.__module__, self.__class__.__name__,
373
def _serve_one_request_unguarded(self, protocol):
374
while protocol.next_read_size():
375
# We can safely try to read large chunks. If there is less data
376
# than MAX_SOCKET_CHUNK ready, the socket will just return a
377
# short read immediately rather than block.
378
bytes = self.read_bytes(osutils.MAX_SOCKET_CHUNK)
382
protocol.accept_bytes(bytes)
384
self._push_back(protocol.unused_data)
386
def _disconnect_client(self):
387
"""Close the current connection. We stopped due to a timeout/etc."""
390
def _wait_for_bytes_with_timeout(self, timeout_seconds):
391
"""Wait for more bytes to be read, but timeout if none available.
393
This allows us to detect idle connections, and stop trying to read from
394
them, without setting the socket itself to non-blocking. This also
395
allows us to specify when we watch for idle timeouts.
397
:return: None, this will raise ConnectionTimeout if we time out before
400
return self._wait_on_descriptor(self.socket, timeout_seconds)
402
def _read_bytes(self, desired_count):
403
return osutils.read_bytes_from_socket(
404
self.socket, self._report_activity)
406
def terminate_due_to_error(self):
407
# TODO: This should log to a server log file, but no such thing
408
# exists yet. Andrew Bennetts 2006-09-29.
412
def _write_out(self, bytes):
413
tstart = osutils.timer_func()
414
osutils.send_all(self.socket, bytes, self._report_activity)
415
if 'hpss' in debug.debug_flags:
416
thread_id = thread.get_ident()
417
trace.mutter('%12s: [%s] %d bytes to the socket in %.3fs'
418
% ('wrote', thread_id, len(bytes),
419
osutils.timer_func() - tstart))
422
class SmartServerPipeStreamMedium(SmartServerStreamMedium):
424
def __init__(self, in_file, out_file, backing_transport, timeout=None):
425
"""Construct new server.
427
:param in_file: Python file from which requests can be read.
428
:param out_file: Python file to write responses.
429
:param backing_transport: Transport for the directory served.
431
SmartServerStreamMedium.__init__(self, backing_transport,
433
if sys.platform == 'win32':
434
# force binary mode for files
436
for f in (in_file, out_file):
437
fileno = getattr(f, 'fileno', None)
439
msvcrt.setmode(fileno(), os.O_BINARY)
444
"""See SmartServerStreamMedium.serve"""
445
# This is the regular serve, except it adds signal trapping for soft
447
stop_gracefully = self._stop_gracefully
448
signals.register_on_hangup(id(self), stop_gracefully)
450
return super(SmartServerPipeStreamMedium, self).serve()
452
signals.unregister_on_hangup(id(self))
454
def _serve_one_request_unguarded(self, protocol):
456
# We need to be careful not to read past the end of the current
457
# request, or else the read from the pipe will block, so we use
458
# protocol.next_read_size().
459
bytes_to_read = protocol.next_read_size()
460
if bytes_to_read == 0:
461
# Finished serving this request.
464
bytes = self.read_bytes(bytes_to_read)
466
# Connection has been closed.
470
protocol.accept_bytes(bytes)
472
def _disconnect_client(self):
477
def _wait_for_bytes_with_timeout(self, timeout_seconds):
478
"""Wait for more bytes to be read, but timeout if none available.
480
This allows us to detect idle connections, and stop trying to read from
481
them, without setting the socket itself to non-blocking. This also
482
allows us to specify when we watch for idle timeouts.
484
:return: None, this will raise ConnectionTimeout if we time out before
487
if (getattr(self._in, 'fileno', None) is None
488
or sys.platform == 'win32'):
489
# You can't select() file descriptors on Windows.
491
return self._wait_on_descriptor(self._in, timeout_seconds)
493
def _read_bytes(self, desired_count):
494
return self._in.read(desired_count)
496
def terminate_due_to_error(self):
497
# TODO: This should log to a server log file, but no such thing
498
# exists yet. Andrew Bennetts 2006-09-29.
502
def _write_out(self, bytes):
503
self._out.write(bytes)
506
class SmartClientMediumRequest(object):
507
"""A request on a SmartClientMedium.
509
Each request allows bytes to be provided to it via accept_bytes, and then
510
the response bytes to be read via read_bytes.
513
request.accept_bytes('123')
514
request.finished_writing()
515
result = request.read_bytes(3)
516
request.finished_reading()
518
It is up to the individual SmartClientMedium whether multiple concurrent
519
requests can exist. See SmartClientMedium.get_request to obtain instances
520
of SmartClientMediumRequest, and the concrete Medium you are using for
521
details on concurrency and pipelining.
524
def __init__(self, medium):
525
"""Construct a SmartClientMediumRequest for the medium medium."""
526
self._medium = medium
527
# we track state by constants - we may want to use the same
528
# pattern as BodyReader if it gets more complex.
529
# valid states are: "writing", "reading", "done"
530
self._state = "writing"
532
def accept_bytes(self, bytes):
533
"""Accept bytes for inclusion in this request.
535
This method may not be called after finished_writing() has been
536
called. It depends upon the Medium whether or not the bytes will be
537
immediately transmitted. Message based Mediums will tend to buffer the
538
bytes until finished_writing() is called.
540
:param bytes: A bytestring.
542
if self._state != "writing":
543
raise errors.WritingCompleted(self)
544
self._accept_bytes(bytes)
546
def _accept_bytes(self, bytes):
547
"""Helper for accept_bytes.
549
Accept_bytes checks the state of the request to determing if bytes
550
should be accepted. After that it hands off to _accept_bytes to do the
553
raise NotImplementedError(self._accept_bytes)
555
def finished_reading(self):
556
"""Inform the request that all desired data has been read.
558
This will remove the request from the pipeline for its medium (if the
559
medium supports pipelining) and any further calls to methods on the
560
request will raise ReadingCompleted.
562
if self._state == "writing":
563
raise errors.WritingNotComplete(self)
564
if self._state != "reading":
565
raise errors.ReadingCompleted(self)
567
self._finished_reading()
569
def _finished_reading(self):
570
"""Helper for finished_reading.
572
finished_reading checks the state of the request to determine if
573
finished_reading is allowed, and if it is hands off to _finished_reading
574
to perform the action.
576
raise NotImplementedError(self._finished_reading)
578
def finished_writing(self):
579
"""Finish the writing phase of this request.
581
This will flush all pending data for this request along the medium.
582
After calling finished_writing, you may not call accept_bytes anymore.
584
if self._state != "writing":
585
raise errors.WritingCompleted(self)
586
self._state = "reading"
587
self._finished_writing()
589
def _finished_writing(self):
590
"""Helper for finished_writing.
592
finished_writing checks the state of the request to determine if
593
finished_writing is allowed, and if it is hands off to _finished_writing
594
to perform the action.
596
raise NotImplementedError(self._finished_writing)
598
def read_bytes(self, count):
599
"""Read bytes from this requests response.
601
This method will block and wait for count bytes to be read. It may not
602
be invoked until finished_writing() has been called - this is to ensure
603
a message-based approach to requests, for compatibility with message
604
based mediums like HTTP.
606
if self._state == "writing":
607
raise errors.WritingNotComplete(self)
608
if self._state != "reading":
609
raise errors.ReadingCompleted(self)
610
return self._read_bytes(count)
612
def _read_bytes(self, count):
613
"""Helper for SmartClientMediumRequest.read_bytes.
615
read_bytes checks the state of the request to determing if bytes
616
should be read. After that it hands off to _read_bytes to do the
619
By default this forwards to self._medium.read_bytes because we are
620
operating on the medium's stream.
622
return self._medium.read_bytes(count)
625
line = self._read_line()
626
if not line.endswith('\n'):
627
# end of file encountered reading from server
628
raise errors.ConnectionReset(
629
"Unexpected end of message. Please check connectivity "
630
"and permissions, and report a bug if problems persist.")
633
def _read_line(self):
634
"""Helper for SmartClientMediumRequest.read_line.
636
By default this forwards to self._medium._get_line because we are
637
operating on the medium's stream.
639
return self._medium._get_line()
642
class _VfsRefuser(object):
643
"""An object that refuses all VFS requests.
648
client._SmartClient.hooks.install_named_hook(
649
'call', self.check_vfs, 'vfs refuser')
651
def check_vfs(self, params):
653
request_method = request.request_handlers.get(params.method)
655
# A method we don't know about doesn't count as a VFS method.
657
if issubclass(request_method, vfs.VfsRequest):
658
raise errors.HpssVfsRequestNotAllowed(params.method, params.args)
661
class _DebugCounter(object):
662
"""An object that counts the HPSS calls made to each client medium.
664
When a medium is garbage-collected, or failing that when
665
bzrlib.global_state exits, the total number of calls made on that medium
666
are reported via trace.note.
670
self.counts = weakref.WeakKeyDictionary()
671
client._SmartClient.hooks.install_named_hook(
672
'call', self.increment_call_count, 'hpss call counter')
673
bzrlib.global_state.cleanups.add_cleanup(self.flush_all)
675
def track(self, medium):
676
"""Start tracking calls made to a medium.
678
This only keeps a weakref to the medium, so shouldn't affect the
681
medium_repr = repr(medium)
682
# Add this medium to the WeakKeyDictionary
683
self.counts[medium] = dict(count=0, vfs_count=0,
684
medium_repr=medium_repr)
685
# Weakref callbacks are fired in reverse order of their association
686
# with the referenced object. So we add a weakref *after* adding to
687
# the WeakKeyDict so that we can report the value from it before the
688
# entry is removed by the WeakKeyDict's own callback.
689
ref = weakref.ref(medium, self.done)
691
def increment_call_count(self, params):
692
# Increment the count in the WeakKeyDictionary
693
value = self.counts[params.medium]
696
request_method = request.request_handlers.get(params.method)
698
# A method we don't know about doesn't count as a VFS method.
700
if issubclass(request_method, vfs.VfsRequest):
701
value['vfs_count'] += 1
704
value = self.counts[ref]
705
count, vfs_count, medium_repr = (
706
value['count'], value['vfs_count'], value['medium_repr'])
707
# In case this callback is invoked for the same ref twice (by the
708
# weakref callback and by the atexit function), set the call count back
709
# to 0 so this item won't be reported twice.
711
value['vfs_count'] = 0
713
trace.note(gettext('HPSS calls: {0} ({1} vfs) {2}').format(
714
count, vfs_count, medium_repr))
717
for ref in list(self.counts.keys()):
720
_debug_counter = None
724
class SmartClientMedium(SmartMedium):
725
"""Smart client is a medium for sending smart protocol requests over."""
727
def __init__(self, base):
728
super(SmartClientMedium, self).__init__()
730
self._protocol_version_error = None
731
self._protocol_version = None
732
self._done_hello = False
733
# Be optimistic: we assume the remote end can accept new remote
734
# requests until we get an error saying otherwise.
735
# _remote_version_is_before tracks the bzr version the remote side
736
# can be based on what we've seen so far.
737
self._remote_version_is_before = None
738
# Install debug hook function if debug flag is set.
739
if 'hpss' in debug.debug_flags:
740
global _debug_counter
741
if _debug_counter is None:
742
_debug_counter = _DebugCounter()
743
_debug_counter.track(self)
744
if 'hpss_client_no_vfs' in debug.debug_flags:
746
if _vfs_refuser is None:
747
_vfs_refuser = _VfsRefuser()
749
def _is_remote_before(self, version_tuple):
750
"""Is it possible the remote side supports RPCs for a given version?
754
needed_version = (1, 2)
755
if medium._is_remote_before(needed_version):
756
fallback_to_pre_1_2_rpc()
760
except UnknownSmartMethod:
761
medium._remember_remote_is_before(needed_version)
762
fallback_to_pre_1_2_rpc()
764
:seealso: _remember_remote_is_before
766
if self._remote_version_is_before is None:
767
# So far, the remote side seems to support everything
769
return version_tuple >= self._remote_version_is_before
771
def _remember_remote_is_before(self, version_tuple):
772
"""Tell this medium that the remote side is older the given version.
774
:seealso: _is_remote_before
776
if (self._remote_version_is_before is not None and
777
version_tuple > self._remote_version_is_before):
778
# We have been told that the remote side is older than some version
779
# which is newer than a previously supplied older-than version.
780
# This indicates that some smart verb call is not guarded
781
# appropriately (it should simply not have been tried).
783
"_remember_remote_is_before(%r) called, but "
784
"_remember_remote_is_before(%r) was called previously."
785
, version_tuple, self._remote_version_is_before)
786
if 'hpss' in debug.debug_flags:
787
ui.ui_factory.show_warning(
788
"_remember_remote_is_before(%r) called, but "
789
"_remember_remote_is_before(%r) was called previously."
790
% (version_tuple, self._remote_version_is_before))
792
self._remote_version_is_before = version_tuple
794
def protocol_version(self):
795
"""Find out if 'hello' smart request works."""
796
if self._protocol_version_error is not None:
797
raise self._protocol_version_error
798
if not self._done_hello:
800
medium_request = self.get_request()
801
# Send a 'hello' request in protocol version one, for maximum
802
# backwards compatibility.
803
client_protocol = protocol.SmartClientRequestProtocolOne(medium_request)
804
client_protocol.query_version()
805
self._done_hello = True
806
except errors.SmartProtocolError, e:
807
# Cache the error, just like we would cache a successful
809
self._protocol_version_error = e
813
def should_probe(self):
814
"""Should RemoteBzrDirFormat.probe_transport send a smart request on
817
Some transports are unambiguously smart-only; there's no need to check
818
if the transport is able to carry smart requests, because that's all
819
it is for. In those cases, this method should return False.
821
But some HTTP transports can sometimes fail to carry smart requests,
822
but still be usuable for accessing remote bzrdirs via plain file
823
accesses. So for those transports, their media should return True here
824
so that RemoteBzrDirFormat can determine if it is appropriate for that
829
def disconnect(self):
830
"""If this medium maintains a persistent connection, close it.
832
The default implementation does nothing.
835
def remote_path_from_transport(self, transport):
836
"""Convert transport into a path suitable for using in a request.
838
Note that the resulting remote path doesn't encode the host name or
839
anything but path, so it is only safe to use it in requests sent over
840
the medium from the matching transport.
842
medium_base = urlutils.join(self.base, '/')
843
rel_url = urlutils.relative_url(medium_base, transport.base)
844
return urlutils.unquote(rel_url)
847
class SmartClientStreamMedium(SmartClientMedium):
848
"""Stream based medium common class.
850
SmartClientStreamMediums operate on a stream. All subclasses use a common
851
SmartClientStreamMediumRequest for their requests, and should implement
852
_accept_bytes and _read_bytes to allow the request objects to send and
856
def __init__(self, base):
857
SmartClientMedium.__init__(self, base)
858
self._current_request = None
860
def accept_bytes(self, bytes):
861
self._accept_bytes(bytes)
864
"""The SmartClientStreamMedium knows how to close the stream when it is
870
"""Flush the output stream.
872
This method is used by the SmartClientStreamMediumRequest to ensure that
873
all data for a request is sent, to avoid long timeouts or deadlocks.
875
raise NotImplementedError(self._flush)
877
def get_request(self):
878
"""See SmartClientMedium.get_request().
880
SmartClientStreamMedium always returns a SmartClientStreamMediumRequest
883
return SmartClientStreamMediumRequest(self)
886
"""We have been disconnected, reset current state.
888
This resets things like _current_request and connected state.
891
self._current_request = None
894
class SmartSimplePipesClientMedium(SmartClientStreamMedium):
895
"""A client medium using simple pipes.
897
This client does not manage the pipes: it assumes they will always be open.
900
def __init__(self, readable_pipe, writeable_pipe, base):
901
SmartClientStreamMedium.__init__(self, base)
902
self._readable_pipe = readable_pipe
903
self._writeable_pipe = writeable_pipe
905
def _accept_bytes(self, bytes):
906
"""See SmartClientStreamMedium.accept_bytes."""
908
self._writeable_pipe.write(bytes)
910
if e.errno in (errno.EINVAL, errno.EPIPE):
911
raise errors.ConnectionReset(
912
"Error trying to write to subprocess:\n%s" % (e,))
914
self._report_activity(len(bytes), 'write')
917
"""See SmartClientStreamMedium._flush()."""
918
# Note: If flush were to fail, we'd like to raise ConnectionReset, etc.
919
# However, testing shows that even when the child process is
920
# gone, this doesn't error.
921
self._writeable_pipe.flush()
923
def _read_bytes(self, count):
924
"""See SmartClientStreamMedium._read_bytes."""
925
bytes_to_read = min(count, _MAX_READ_SIZE)
926
bytes = self._readable_pipe.read(bytes_to_read)
927
self._report_activity(len(bytes), 'read')
931
class SSHParams(object):
932
"""A set of parameters for starting a remote bzr via SSH."""
934
def __init__(self, host, port=None, username=None, password=None,
935
bzr_remote_path='bzr'):
938
self.username = username
939
self.password = password
940
self.bzr_remote_path = bzr_remote_path
943
class SmartSSHClientMedium(SmartClientStreamMedium):
944
"""A client medium using SSH.
946
It delegates IO to a SmartSimplePipesClientMedium or
947
SmartClientAlreadyConnectedSocketMedium (depending on platform).
950
def __init__(self, base, ssh_params, vendor=None):
951
"""Creates a client that will connect on the first use.
953
:param ssh_params: A SSHParams instance.
954
:param vendor: An optional override for the ssh vendor to use. See
955
bzrlib.transport.ssh for details on ssh vendors.
957
self._real_medium = None
958
self._ssh_params = ssh_params
959
# for the benefit of progress making a short description of this
961
self._scheme = 'bzr+ssh'
962
# SmartClientStreamMedium stores the repr of this object in its
963
# _DebugCounter so we have to store all the values used in our repr
964
# method before calling the super init.
965
SmartClientStreamMedium.__init__(self, base)
966
self._vendor = vendor
967
self._ssh_connection = None
970
if self._ssh_params.port is None:
973
maybe_port = ':%s' % self._ssh_params.port
974
if self._ssh_params.username is None:
977
maybe_user = '%s@' % self._ssh_params.username
978
return "%s(%s://%s%s%s/)" % (
979
self.__class__.__name__,
982
self._ssh_params.host,
985
def _accept_bytes(self, bytes):
986
"""See SmartClientStreamMedium.accept_bytes."""
987
self._ensure_connection()
988
self._real_medium.accept_bytes(bytes)
990
def disconnect(self):
991
"""See SmartClientMedium.disconnect()."""
992
if self._real_medium is not None:
993
self._real_medium.disconnect()
994
self._real_medium = None
995
if self._ssh_connection is not None:
996
self._ssh_connection.close()
997
self._ssh_connection = None
999
def _ensure_connection(self):
1000
"""Connect this medium if not already connected."""
1001
if self._real_medium is not None:
1003
if self._vendor is None:
1004
vendor = ssh._get_ssh_vendor()
1006
vendor = self._vendor
1007
self._ssh_connection = vendor.connect_ssh(self._ssh_params.username,
1008
self._ssh_params.password, self._ssh_params.host,
1009
self._ssh_params.port,
1010
command=[self._ssh_params.bzr_remote_path, 'serve', '--inet',
1011
'--directory=/', '--allow-writes'])
1012
io_kind, io_object = self._ssh_connection.get_sock_or_pipes()
1013
if io_kind == 'socket':
1014
self._real_medium = SmartClientAlreadyConnectedSocketMedium(
1015
self.base, io_object)
1016
elif io_kind == 'pipes':
1017
read_from, write_to = io_object
1018
self._real_medium = SmartSimplePipesClientMedium(
1019
read_from, write_to, self.base)
1021
raise AssertionError(
1022
"Unexpected io_kind %r from %r"
1023
% (io_kind, self._ssh_connection))
1026
"""See SmartClientStreamMedium._flush()."""
1027
self._real_medium._flush()
1029
def _read_bytes(self, count):
    """See SmartClientStreamMedium.read_bytes.

    Delegates to the inner medium.  Unlike _accept_bytes, this does not
    connect on demand.

    :raises errors.MediumNotConnected: if there is no inner medium yet.
    """
    real_medium = self._real_medium
    if real_medium is None:
        raise errors.MediumNotConnected(self)
    return real_medium.read_bytes(count)
1036
# bzr:// connections default to TCP port 4155, the port registered with IANA
# for this protocol.
BZR_DEFAULT_PORT = 4155
# No default interface is configured.  NOTE(review): presumably callers treat
# None as "no specific interface"; confirm at the use sites.
BZR_DEFAULT_INTERFACE = None
1041
class SmartClientSocketMedium(SmartClientStreamMedium):
1042
"""A client medium using a socket.
1044
This class isn't usable directly. Use one of its subclasses instead.
1047
def __init__(self, base):
1048
SmartClientStreamMedium.__init__(self, base)
1050
self._connected = False
1052
def _accept_bytes(self, bytes):
    """See SmartClientMedium.accept_bytes.

    Connects on demand (through the subclass's _ensure_connection), then
    writes all of ``bytes`` to self._socket using osutils.send_all, passing
    self._report_activity along (presumably an activity-reporting callback).
    """
    self._ensure_connection()
    osutils.send_all(self._socket, bytes, self._report_activity)
1057
def _ensure_connection(self):
    """Connect this medium if not already connected.

    Abstract in this base class: concrete socket media (such as the TCP and
    already-connected-socket subclasses) override it.

    :raises NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError(self._ensure_connection)
1062
"""See SmartClientStreamMedium._flush().
1064
For sockets we do no flushing. For TCP sockets we may want to turn off
1065
TCP_NODELAY and add a means to do a flush, but that can be done in the
1069
def _read_bytes(self, count):
    """See SmartClientMedium.read_bytes.

    ``count`` is not forwarded.  This follows the module-wide convention that
    buffer-size parameters may be ignored in favour of _MAX_READ_SIZE;
    presumably osutils.read_bytes_from_socket picks the read size itself.

    :raises errors.MediumNotConnected: if the medium is not connected.
    """
    if not self._connected:
        raise errors.MediumNotConnected(self)
    return osutils.read_bytes_from_socket(
        self._socket, self._report_activity)
1076
def disconnect(self):
1077
"""See SmartClientMedium.disconnect()."""
1078
if not self._connected:
1080
self._socket.close()
1082
self._connected = False
1085
class SmartTCPClientMedium(SmartClientSocketMedium):
1086
"""A client medium that creates a TCP connection."""
1088
def __init__(self, host, port, base):
1089
"""Creates a client that will connect on the first use."""
1090
SmartClientSocketMedium.__init__(self, base)
1094
def _ensure_connection(self):
1095
"""Connect this medium if not already connected."""
1098
if self._port is None:
1099
port = BZR_DEFAULT_PORT
1101
port = int(self._port)
1103
sockaddrs = socket.getaddrinfo(self._host, port, socket.AF_UNSPEC,
1104
socket.SOCK_STREAM, 0, 0)
1105
except socket.gaierror, (err_num, err_msg):
1106
raise errors.ConnectionError("failed to lookup %s:%d: %s" %
1107
(self._host, port, err_msg))
1108
# Initialize err in case there are no addresses returned:
1109
err = socket.error("no address found for %s" % self._host)
1110
for (family, socktype, proto, canonname, sockaddr) in sockaddrs:
1112
self._socket = socket.socket(family, socktype, proto)
1113
self._socket.setsockopt(socket.IPPROTO_TCP,
1114
socket.TCP_NODELAY, 1)
1115
self._socket.connect(sockaddr)
1116
except socket.error, err:
1117
if self._socket is not None:
1118
self._socket.close()
1122
if self._socket is None:
1123
# socket errors either have a (string) or (errno, string) as their
1125
if type(err.args) is str:
1128
err_msg = err.args[1]
1129
raise errors.ConnectionError("failed to connect to %s:%d: %s" %
1130
(self._host, port, err_msg))
1131
self._connected = True
1134
class SmartClientAlreadyConnectedSocketMedium(SmartClientSocketMedium):
1135
"""A client medium for an already connected socket.
1137
Note that this class will assume it "owns" the socket, so it will close it
1138
when its disconnect method is called.
1141
def __init__(self, base, sock):
1142
SmartClientSocketMedium.__init__(self, base)
1144
self._connected = True
1146
def _ensure_connection(self):
1147
# Already connected, by definition! So nothing to do.
1151
class SmartClientStreamMediumRequest(SmartClientMediumRequest):
1152
"""A SmartClientMediumRequest that works with an SmartClientStreamMedium."""
1154
def __init__(self, medium):
    """Create a request on ``medium`` and claim it as the current request.

    The claim is released again by _finished_reading.

    :raises errors.TooManyConcurrentRequests: if ``medium`` already has a
        request in progress.
    """
    SmartClientMediumRequest.__init__(self, medium)
    # check that we are safe concurrency wise. If some streams start
    # allowing concurrent requests - i.e. via multiplexing - then this
    # assert should be moved to SmartClientStreamMedium.get_request,
    # and the setting/unsetting of _current_request likewise moved into
    # that class : but its unneeded overhead for now. RBC 20060922
    if self._medium._current_request is not None:
        raise errors.TooManyConcurrentRequests(self._medium)
    self._medium._current_request = self
1165
def _accept_bytes(self, bytes):
1166
"""See SmartClientMediumRequest._accept_bytes.
1168
This forwards to self._medium._accept_bytes because we are operating
1169
on the mediums stream.
1171
self._medium._accept_bytes(bytes)
1173
def _finished_reading(self):
1174
"""See SmartClientMediumRequest._finished_reading.
1176
This clears the _current_request on self._medium to allow a new
1177
request to be created.
1179
if self._medium._current_request is not self:
1180
raise AssertionError()
1181
self._medium._current_request = None
1183
def _finished_writing(self):
1184
"""See SmartClientMediumRequest._finished_writing.
1186
This invokes self._medium._flush to ensure all bytes are transmitted.
1188
self._medium._flush()