# Copyright (C) 2006-2011 Canonical Ltd
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA

"""The 'medium' layer for the smart servers and clients.

"Medium" here is the noun meaning "a means of transmission", not the adjective
for "the quality between big and small."

Media carry the bytes of the requests somehow (e.g. via TCP, wrapped in HTTP, or
over SSH), and pass them to and from the protocol logic.  See the overview in
bzrlib/transport/smart/__init__.py.
"""
34
import errno
import os
import sys
import time
import urllib

import bzrlib
from bzrlib.lazy_import import lazy_import
lazy_import(globals(), """
import select
import socket
import thread
import weakref

from bzrlib import (
    debug,
    errors,
    trace,
    ui,
    urlutils,
    )
from bzrlib.i18n import gettext
from bzrlib.smart import client, protocol, request, signals, vfs
from bzrlib.transport import ssh
""")
from bzrlib import osutils
54
# Throughout this module buffer size parameters are either limited to be at
# most _MAX_READ_SIZE, or are ignored and _MAX_READ_SIZE is used instead.
# For this module's purposes, MAX_SOCKET_CHUNK is a reasonable size for reads
# from non-sockets as well.
_MAX_READ_SIZE = osutils.MAX_SOCKET_CHUNK
60
def _get_protocol_factory_for_bytes(bytes):
    """Determine the right protocol factory for 'bytes'.

    This will return an appropriate protocol factory depending on the version
    of the protocol being used, as determined by inspecting the given bytes.
    The bytes should have at least one newline byte (i.e. be a whole line),
    otherwise it's possible that a request will be incorrectly identified as
    protocol version one.

    Typical use would be::

         factory, unused_bytes = _get_protocol_factory_for_bytes(bytes)
         server_protocol = factory(transport, write_func, root_client_path)
         server_protocol.accept_bytes(unused_bytes)

    :param bytes: a str of bytes of the start of the request.
    :returns: 2-tuple of (protocol_factory, unused_bytes).  protocol_factory is
        a callable that takes three args: transport, write_func,
        root_client_path.  unused_bytes are any bytes that were not part of a
        protocol version marker.
    """
    if bytes.startswith(protocol.MESSAGE_VERSION_THREE):
        protocol_factory = protocol.build_server_protocol_three
        bytes = bytes[len(protocol.MESSAGE_VERSION_THREE):]
    elif bytes.startswith(protocol.REQUEST_VERSION_TWO):
        protocol_factory = protocol.SmartServerRequestProtocolTwo
        bytes = bytes[len(protocol.REQUEST_VERSION_TWO):]
    else:
        # No recognised version marker: fall back to protocol version one and
        # hand every byte on unconsumed.
        protocol_factory = protocol.SmartServerRequestProtocolOne
    return protocol_factory, bytes
92
def _get_line(read_bytes_func):
93
"""Read bytes using read_bytes_func until a newline byte.
95
This isn't particularly efficient, so should only be used when the
96
expected size of the line is quite short.
98
:returns: a tuple of two strs: (line, excess)
102
while newline_pos == -1:
103
new_bytes = read_bytes_func(1)
106
# Ran out of bytes before receiving a complete line.
108
newline_pos = bytes.find('\n')
109
line = bytes[:newline_pos+1]
110
excess = bytes[newline_pos+1:]
114
class SmartMedium(object):
    """Base class for smart protocol media, both client- and server-side."""

    def __init__(self):
        # Bytes read from the underlying stream but not yet consumed, or None.
        # Never '' (see _get_push_back_buffer), as that would look like EOF.
        self._push_back_buffer = None

    def _push_back(self, bytes):
        """Return unused bytes to the medium, because they belong to the next
        request(s).

        This sets the _push_back_buffer to the given bytes.  Pushing back the
        empty string is a no-op.

        :raises AssertionError: if earlier pushed-back bytes are still unread.
        """
        if self._push_back_buffer is not None:
            raise AssertionError(
                "_push_back called when self._push_back_buffer is %r"
                % (self._push_back_buffer,))
        if bytes == '':
            return
        self._push_back_buffer = bytes

    def _get_push_back_buffer(self):
        """Remove and return the pushed-back bytes."""
        if self._push_back_buffer == '':
            raise AssertionError(
                '%s._push_back_buffer should never be the empty string, '
                'which can be confused with EOF' % (self,))
        bytes = self._push_back_buffer
        self._push_back_buffer = None
        return bytes

    def read_bytes(self, desired_count):
        """Read some bytes from this medium.

        Pushed-back bytes are returned first (all of them, regardless of
        desired_count); otherwise at most _MAX_READ_SIZE bytes are requested
        from _read_bytes.

        :returns: some bytes, possibly more or less than the number requested
            in 'desired_count' depending on the medium.
        """
        if self._push_back_buffer is not None:
            return self._get_push_back_buffer()
        bytes_to_read = min(desired_count, _MAX_READ_SIZE)
        return self._read_bytes(bytes_to_read)

    def _read_bytes(self, count):
        raise NotImplementedError(self._read_bytes)

    def _get_line(self):
        """Read bytes from this request's response until a newline byte.

        This isn't particularly efficient, so should only be used when the
        expected size of the line is quite short.

        :returns: a string of bytes ending in a newline (byte 0x0A).
        """
        line, excess = _get_line(self.read_bytes)
        # Anything read past the newline belongs to whatever is read next.
        self._push_back(excess)
        return line

    def _report_activity(self, bytes, direction):
        """Notify that this medium has activity.

        Implementations should call this from all methods that actually do IO.
        Be careful that it's not called twice, if one method is implemented on
        top of another.

        :param bytes: Number of bytes read or written.
        :param direction: 'read' or 'write' or None.
        """
        ui.ui_factory.report_transport_activity(self, bytes, direction)
182
_bad_file_descriptor = (errno.EBADF,)
183
if sys.platform == 'win32':
184
# Given on Windows if you pass a closed socket to select.select. Probably
185
# also given if you pass a file handle to select.
187
_bad_file_descriptor += (WSAENOTSOCK,)
190
class SmartServerStreamMedium(SmartMedium):
    """Handles smart commands coming over a stream.

    The stream may be a pipe connected to sshd, or a tcp socket, or an
    in-process fifo for testing.

    One instance is created for each connected client; it can serve multiple
    requests in the lifetime of the connection.

    The server passes requests through to an underlying backing transport,
    which will typically be a LocalTransport looking at the server's filesystem.

    :ivar _push_back_buffer: a str of bytes that have been read from the stream
        but not used yet, or None if there are no buffered bytes.  Subclasses
        should make sure to exhaust this buffer before reading more bytes from
        the stream.  See also the _push_back method.
    """

    # Clock used by _wait_on_descriptor; a class attribute so it can be
    # overridden without changing any method.
    _timer = time.time

    def __init__(self, backing_transport, root_client_path='/', timeout=None):
        """Construct new server.

        :param backing_transport: Transport for the directory served.
        :param root_client_path: passed to each protocol object built by
            _build_protocol.
        :param timeout: seconds to wait for a new request before disconnecting
            the client.  Required, despite the None default.
        :raises AssertionError: if timeout is None.
        """
        # backing_transport could be passed to serve instead of __init__
        self.backing_transport = backing_transport
        self.root_client_path = root_client_path
        self.finished = False
        if timeout is None:
            raise AssertionError('You must supply a timeout.')
        self._client_timeout = timeout
        # Poll at most every second (or a tenth of the timeout) so that
        # _wait_on_descriptor notices self.finished being set promptly.
        self._client_poll_timeout = min(timeout / 10.0, 1.0)
        SmartMedium.__init__(self)

    def serve(self):
        """Serve requests until the client disconnects."""
        # Keep a reference to stderr because the sys module's globals get set to
        # None during interpreter shutdown.
        from sys import stderr
        try:
            while not self.finished:
                server_protocol = self._build_protocol()
                # TODO: This seems inelegant:
                if server_protocol is None:
                    # We could 'continue' only to notice that self.finished is
                    # True...
                    break
                self._serve_one_request(server_protocol)
        except errors.ConnectionTimeout as e:
            trace.note('%s' % (e,))
            trace.log_exception_quietly()
            self._disconnect_client()
            # We reported it, no reason to make a big fuss.
            return
        except Exception as e:
            stderr.write("%s terminating on exception %s\n" % (self, e))
            raise
        self._disconnect_client()

    def _stop_gracefully(self):
        """When we finish this message, stop looking for more."""
        trace.mutter('Stopping %s' % (self,))
        self.finished = True

    def _disconnect_client(self):
        """Close the current connection. We stopped due to a timeout/etc."""
        # The default implementation is a no-op, because that is all we used to
        # do when disconnecting from a client. I suppose we never had the
        # *server* initiate a disconnect, before

    def _wait_for_bytes_with_timeout(self, timeout_seconds):
        """Wait for more bytes to be read, but timeout if none available.

        This allows us to detect idle connections, and stop trying to read from
        them, without setting the socket itself to non-blocking. This also
        allows us to specify when we watch for idle timeouts.

        :return: None.  Implementations raise errors.ConnectionTimeout if no
            data becomes available within timeout_seconds.
        """
        raise NotImplementedError(self._wait_for_bytes_with_timeout)

    def _build_protocol(self):
        """Identify the version of the incoming request, and return a
        protocol object that can interpret it.

        If more bytes than the version prefix of the request are read, they will
        be fed into the protocol before it is returned.

        :returns: a SmartServerRequestProtocol, or None if this medium was
            asked to stop while waiting for the request.
        """
        self._wait_for_bytes_with_timeout(self._client_timeout)
        if self.finished:
            # We're stopping, so don't try to do any more work
            return None
        bytes = self._get_line()
        protocol_factory, unused_bytes = _get_protocol_factory_for_bytes(bytes)
        server_protocol = protocol_factory(
            self.backing_transport, self._write_out, self.root_client_path)
        server_protocol.accept_bytes(unused_bytes)
        return server_protocol

    def _wait_on_descriptor(self, fd, timeout_seconds):
        """select() on a file descriptor, waiting for nonblocking read()

        This will raise a ConnectionTimeout exception if we do not get a
        readable handle before timeout_seconds.

        :param fd: the socket or file object to watch (as select.select takes).
        :param timeout_seconds: total seconds to wait.
        :return: None
        :raises errors.ConnectionTimeout: if fd is not ready in time.
        """
        t_end = self._timer() + timeout_seconds
        poll_timeout = min(timeout_seconds, self._client_poll_timeout)
        rs = xs = None
        while not rs and not xs and self._timer() < t_end:
            if self.finished:
                return
            try:
                rs, _, xs = select.select([fd], [], [fd], poll_timeout)
            except (select.error, socket.error) as e:
                err = getattr(e, 'errno', None)
                if err is None and getattr(e, 'args', None) is not None:
                    # select.error doesn't have 'errno', it just has args[0]
                    err = e.args[0]
                if err in _bad_file_descriptor:
                    return # Not a socket indicates read() will fail
                elif err == errno.EINTR:
                    # Interrupted, keep looping.
                    continue
                raise
        if rs or xs:
            return
        raise errors.ConnectionTimeout('disconnecting client after %.1f seconds'
                                       % (timeout_seconds,))

    def _serve_one_request(self, protocol):
        """Read one request from input, process, send back a response.

        Any exception other than KeyboardInterrupt is swallowed after calling
        terminate_due_to_error().

        :param protocol: a SmartServerRequestProtocol.
        """
        try:
            self._serve_one_request_unguarded(protocol)
        except KeyboardInterrupt:
            raise
        except Exception:
            self.terminate_due_to_error()

    def terminate_due_to_error(self):
        """Called when an unhandled exception from the protocol occurs."""
        raise NotImplementedError(self.terminate_due_to_error)

    def _read_bytes(self, desired_count):
        """Get some bytes from the medium.

        :param desired_count: number of bytes we want to read.
        """
        raise NotImplementedError(self._read_bytes)
348
class SmartServerSocketStreamMedium(SmartServerStreamMedium):
    """Server-side medium that reads requests from and writes responses to a
    socket."""

    def __init__(self, sock, backing_transport, root_client_path='/',
                 timeout=None):
        """Constructor.

        :param sock: the socket the server will read from.  It will be put
            into blocking mode.
        """
        SmartServerStreamMedium.__init__(
            self, backing_transport, root_client_path=root_client_path,
            timeout=timeout)
        sock.setblocking(True)
        self.socket = sock
        # Get the getpeername now, as we might be closed later when we care.
        try:
            self._client_info = sock.getpeername()
        except socket.error:
            self._client_info = '<unknown>'

    def __str__(self):
        return '%s(client=%s)' % (self.__class__.__name__, self._client_info)

    def __repr__(self):
        return '%s.%s(client=%s)' % (self.__module__, self.__class__.__name__,
            self._client_info)

    def _serve_one_request_unguarded(self, protocol):
        while protocol.next_read_size():
            # We can safely try to read large chunks.  If there is less data
            # than MAX_SOCKET_CHUNK ready, the socket will just return a
            # short read immediately rather than block.
            bytes = self.read_bytes(osutils.MAX_SOCKET_CHUNK)
            if bytes == '':
                # The client closed the connection.
                self.finished = True
                return
            protocol.accept_bytes(bytes)

        # Bytes beyond this request belong to the next one.
        self._push_back(protocol.unused_data)

    def _disconnect_client(self):
        """Close the current connection. We stopped due to a timeout/etc."""
        self.socket.close()

    def _wait_for_bytes_with_timeout(self, timeout_seconds):
        """Wait for more bytes to be read, but timeout if none available.

        This allows us to detect idle connections, and stop trying to read from
        them, without setting the socket itself to non-blocking. This also
        allows us to specify when we watch for idle timeouts.

        :return: None, this will raise ConnectionTimeout if we time out before
            data is available.
        """
        return self._wait_on_descriptor(self.socket, timeout_seconds)

    def _read_bytes(self, desired_count):
        # desired_count is not passed on: as noted at the top of this module,
        # size parameters may be ignored in favour of the default chunk size.
        return osutils.read_bytes_from_socket(
            self.socket, self._report_activity)

    def terminate_due_to_error(self):
        # TODO: This should log to a server log file, but no such thing
        # exists yet.  Andrew Bennetts 2006-09-29.
        self.socket.close()
        self.finished = True

    def _write_out(self, bytes):
        tstart = osutils.timer_func()
        osutils.send_all(self.socket, bytes, self._report_activity)
        if 'hpss' in debug.debug_flags:
            thread_id = thread.get_ident()
            trace.mutter('%12s: [%s] %d bytes to the socket in %.3fs'
                         % ('wrote', thread_id, len(bytes),
                            osutils.timer_func() - tstart))
424
class SmartServerPipeStreamMedium(SmartServerStreamMedium):
    """Server-side medium that reads requests from one file object and writes
    responses to another (e.g. stdin/stdout)."""

    def __init__(self, in_file, out_file, backing_transport, timeout=None):
        """Construct new server.

        :param in_file: Python file from which requests can be read.
        :param out_file: Python file to write responses.
        :param backing_transport: Transport for the directory served.
        """
        SmartServerStreamMedium.__init__(self, backing_transport,
            timeout=timeout)
        if sys.platform == 'win32':
            # force binary mode for files
            import msvcrt
            for f in (in_file, out_file):
                fileno = getattr(f, 'fileno', None)
                if fileno:
                    msvcrt.setmode(fileno(), os.O_BINARY)
        self._in = in_file
        self._out = out_file

    def serve(self):
        """See SmartServerStreamMedium.serve"""
        # This is the regular serve, except it adds signal trapping for soft
        # shutdown.
        stop_gracefully = self._stop_gracefully
        signals.register_on_hangup(id(self), stop_gracefully)
        try:
            return super(SmartServerPipeStreamMedium, self).serve()
        finally:
            signals.unregister_on_hangup(id(self))

    def _serve_one_request_unguarded(self, protocol):
        while True:
            # We need to be careful not to read past the end of the current
            # request, or else the read from the pipe will block, so we use
            # protocol.next_read_size().
            bytes_to_read = protocol.next_read_size()
            if bytes_to_read == 0:
                # Finished serving this request.
                self._out.flush()
                return
            bytes = self.read_bytes(bytes_to_read)
            if bytes == '':
                # Connection has been closed.
                self.finished = True
                self._out.flush()
                return
            protocol.accept_bytes(bytes)

    def _disconnect_client(self):
        """Close both pipes, flushing any pending output first."""
        self._in.close()
        self._out.flush()
        self._out.close()

    def _wait_for_bytes_with_timeout(self, timeout_seconds):
        """Wait for more bytes to be read, but timeout if none available.

        This allows us to detect idle connections, and stop trying to read from
        them, without setting the socket itself to non-blocking. This also
        allows us to specify when we watch for idle timeouts.

        :return: None, this will raise ConnectionTimeout if we time out before
            data is available.
        """
        if (getattr(self._in, 'fileno', None) is None
            or sys.platform == 'win32'):
            # You can't select() file descriptors on Windows.
            return
        return self._wait_on_descriptor(self._in, timeout_seconds)

    def _read_bytes(self, desired_count):
        return self._in.read(desired_count)

    def terminate_due_to_error(self):
        # TODO: This should log to a server log file, but no such thing
        # exists yet.  Andrew Bennetts 2006-09-29.
        self._out.close()
        self.finished = True

    def _write_out(self, bytes):
        self._out.write(bytes)
508
class SmartClientMediumRequest(object):
    """A request on a SmartClientMedium.

    Each request allows bytes to be provided to it via accept_bytes, and then
    the response bytes to be read via read_bytes.

    For instance::

        request.accept_bytes('123')
        request.finished_writing()
        result = request.read_bytes(3)
        request.finished_reading()

    It is up to the individual SmartClientMedium whether multiple concurrent
    requests can exist. See SmartClientMedium.get_request to obtain instances
    of SmartClientMediumRequest, and the concrete Medium you are using for
    details on concurrency and pipelining.
    """

    def __init__(self, medium):
        """Construct a SmartClientMediumRequest for the medium medium."""
        self._medium = medium
        # we track state by constants - we may want to use the same
        # pattern as BodyReader if it gets more complex.
        # valid states are: "writing", "reading", "done"
        self._state = "writing"

    def accept_bytes(self, bytes):
        """Accept bytes for inclusion in this request.

        This method may not be called after finished_writing() has been
        called.  It depends upon the Medium whether or not the bytes will be
        immediately transmitted. Message based Mediums will tend to buffer the
        bytes until finished_writing() is called.

        :param bytes: A bytestring.
        :raises errors.WritingCompleted: if finished_writing() was called.
        """
        if self._state != "writing":
            raise errors.WritingCompleted(self)
        self._accept_bytes(bytes)

    def _accept_bytes(self, bytes):
        """Helper for accept_bytes.

        accept_bytes checks the state of the request to determine if bytes
        should be accepted. After that it hands off to _accept_bytes to do the
        actual acceptance.
        """
        raise NotImplementedError(self._accept_bytes)

    def finished_reading(self):
        """Inform the request that all desired data has been read.

        This will remove the request from the pipeline for its medium (if the
        medium supports pipelining) and any further calls to methods on the
        request will raise ReadingCompleted.
        """
        if self._state == "writing":
            raise errors.WritingNotComplete(self)
        if self._state != "reading":
            raise errors.ReadingCompleted(self)
        self._state = "done"
        self._finished_reading()

    def _finished_reading(self):
        """Helper for finished_reading.

        finished_reading checks the state of the request to determine if
        finished_reading is allowed, and if it is hands off to _finished_reading
        to perform the action.
        """
        raise NotImplementedError(self._finished_reading)

    def finished_writing(self):
        """Finish the writing phase of this request.

        This will flush all pending data for this request along the medium.
        After calling finished_writing, you may not call accept_bytes anymore.
        """
        if self._state != "writing":
            raise errors.WritingCompleted(self)
        self._state = "reading"
        self._finished_writing()

    def _finished_writing(self):
        """Helper for finished_writing.

        finished_writing checks the state of the request to determine if
        finished_writing is allowed, and if it is hands off to _finished_writing
        to perform the action.
        """
        raise NotImplementedError(self._finished_writing)

    def read_bytes(self, count):
        """Read bytes from this request's response.

        This method will block and wait for count bytes to be read. It may not
        be invoked until finished_writing() has been called - this is to ensure
        a message-based approach to requests, for compatibility with message
        based mediums like HTTP.
        """
        if self._state == "writing":
            raise errors.WritingNotComplete(self)
        if self._state != "reading":
            raise errors.ReadingCompleted(self)
        return self._read_bytes(count)

    def _read_bytes(self, count):
        """Helper for SmartClientMediumRequest.read_bytes.

        read_bytes checks the state of the request to determine if bytes
        should be read. After that it hands off to _read_bytes to do the
        actual read.

        By default this forwards to self._medium.read_bytes because we are
        operating on the medium's stream.
        """
        return self._medium.read_bytes(count)

    def read_line(self):
        """Read one newline-terminated line of the response.

        :raises errors.ConnectionReset: if the stream ends before a newline.
        """
        line = self._read_line()
        if not line.endswith('\n'):
            # end of file encountered reading from server
            raise errors.ConnectionReset(
                "Unexpected end of message. Please check connectivity "
                "and permissions, and report a bug if problems persist.")
        return line

    def _read_line(self):
        """Helper for SmartClientMediumRequest.read_line.

        By default this forwards to self._medium._get_line because we are
        operating on the medium's stream.
        """
        return self._medium._get_line()
644
class _VfsRefuser(object):
    """An object that refuses all VFS requests.

    On construction it installs a 'call' hook on client._SmartClient that
    raises errors.HpssVfsRequestNotAllowed for any VFS request method.
    """

    def __init__(self):
        client._SmartClient.hooks.install_named_hook(
            'call', self.check_vfs, 'vfs refuser')

    def check_vfs(self, params):
        """'call' hook: reject params.method if it is a VFS request."""
        try:
            request_method = request.request_handlers.get(params.method)
        except KeyError:
            # A method we don't know about doesn't count as a VFS method.
            return
        if issubclass(request_method, vfs.VfsRequest):
            raise errors.HpssVfsRequestNotAllowed(params.method, params.args)
663
class _DebugCounter(object):
    """An object that counts the HPSS calls made to each client medium.

    When a medium is garbage-collected, or failing that when
    bzrlib.global_state exits, the total number of calls made on that medium
    are reported via trace.note.
    """

    def __init__(self):
        # medium -> dict(count=..., vfs_count=..., medium_repr=...)
        self.counts = weakref.WeakKeyDictionary()
        # weakref (made in track) -> the same counter dict as in self.counts.
        # A weakref callback only fires while the weakref object itself is
        # alive, so the refs must be held here; it also lets done() find the
        # counters after the medium's WeakKeyDictionary entry has vanished.
        self._refs = {}
        client._SmartClient.hooks.install_named_hook(
            'call', self.increment_call_count, 'hpss call counter')
        bzrlib.global_state.cleanups.add_cleanup(self.flush_all)

    def track(self, medium):
        """Start tracking calls made to a medium.

        This only keeps a weakref to the medium, so shouldn't affect the
        medium's lifetime.
        """
        medium_repr = repr(medium)
        value = dict(count=0, vfs_count=0, medium_repr=medium_repr)
        # Add this medium to the WeakKeyDictionary
        self.counts[medium] = value
        # Report the totals via self.done once the medium is collected.
        ref = weakref.ref(medium, self.done)
        self._refs[ref] = value

    def increment_call_count(self, params):
        """'call' hook: count one call, and one VFS call if applicable."""
        # Increment the count in the WeakKeyDictionary
        value = self.counts[params.medium]
        value['count'] += 1
        try:
            request_method = request.request_handlers.get(params.method)
        except KeyError:
            # A method we don't know about doesn't count as a VFS method.
            return
        if issubclass(request_method, vfs.VfsRequest):
            value['vfs_count'] += 1

    def done(self, ref):
        """Report (via trace.note) and reset the call counts of one medium.

        :param ref: the weakref created by track (when called as the weakref
            callback), or a still-live medium (when called by flush_all).
        """
        if isinstance(ref, weakref.ref):
            value = self._refs.pop(ref, None)
        else:
            value = self.counts.get(ref)
        if value is None:
            return
        count, vfs_count, medium_repr = (
            value['count'], value['vfs_count'], value['medium_repr'])
        # In case this callback is invoked for the same ref twice (by the
        # weakref callback and by the atexit function), set the call count back
        # to 0 so this item won't be reported twice.
        value['count'] = 0
        value['vfs_count'] = 0
        if count != 0:
            trace.note(gettext('HPSS calls: {0} ({1} vfs) {2}').format(
                       count, vfs_count, medium_repr))

    def flush_all(self):
        """Report the counts of every medium that is still alive."""
        for ref in list(self.counts.keys()):
            self.done(ref)
722
_debug_counter = None
726
class SmartClientMedium(SmartMedium):
    """Smart client is a medium for sending smart protocol requests over."""

    def __init__(self, base):
        global _debug_counter, _vfs_refuser
        super(SmartClientMedium, self).__init__()
        self.base = base
        self._protocol_version_error = None
        self._protocol_version = None
        self._done_hello = False
        # Be optimistic: we assume the remote end can accept new remote
        # requests until we get an error saying otherwise.
        # _remote_version_is_before tracks the bzr version the remote side
        # can be based on what we've seen so far.
        self._remote_version_is_before = None
        # Install debug hook function if debug flag is set.
        if 'hpss' in debug.debug_flags:
            if _debug_counter is None:
                _debug_counter = _DebugCounter()
            _debug_counter.track(self)
        if 'hpss_client_no_vfs' in debug.debug_flags:
            if _vfs_refuser is None:
                _vfs_refuser = _VfsRefuser()

    def _is_remote_before(self, version_tuple):
        """Is the remote side known to be older than version_tuple?

        Returns True if _remember_remote_is_before was previously called with
        a version no newer than version_tuple, i.e. RPCs that need
        version_tuple should not be tried.  Typical use::

            needed_version = (1, 2)
            if medium._is_remote_before(needed_version):
                fallback_to_pre_1_2_rpc()
            else:
                try:
                    do_1_2_rpc()
                except UnknownSmartMethod:
                    medium._remember_remote_is_before(needed_version)
                    fallback_to_pre_1_2_rpc()

        :seealso: _remember_remote_is_before
        """
        if self._remote_version_is_before is None:
            # So far, the remote side seems to support everything
            return False
        return version_tuple >= self._remote_version_is_before

    def _remember_remote_is_before(self, version_tuple):
        """Tell this medium that the remote side is older than the given
        version.

        :seealso: _is_remote_before
        """
        if (self._remote_version_is_before is not None and
            version_tuple > self._remote_version_is_before):
            # We have been told that the remote side is older than some version
            # which is newer than a previously supplied older-than version.
            # This indicates that some smart verb call is not guarded
            # appropriately (it should simply not have been tried).
            trace.mutter(
                "_remember_remote_is_before(%r) called, but "
                "_remember_remote_is_before(%r) was called previously."
                , version_tuple, self._remote_version_is_before)
            if 'hpss' in debug.debug_flags:
                ui.ui_factory.show_warning(
                    "_remember_remote_is_before(%r) called, but "
                    "_remember_remote_is_before(%r) was called previously."
                    % (version_tuple, self._remote_version_is_before))
            return
        self._remote_version_is_before = version_tuple

    def protocol_version(self):
        """Find out if 'hello' smart request works."""
        if self._protocol_version_error is not None:
            raise self._protocol_version_error
        if not self._done_hello:
            try:
                medium_request = self.get_request()
                # Send a 'hello' request in protocol version one, for maximum
                # backwards compatibility.
                client_protocol = protocol.SmartClientRequestProtocolOne(medium_request)
                client_protocol.query_version()
                self._done_hello = True
            except errors.SmartProtocolError as e:
                # Cache the error, just like we would cache a successful
                # result.
                self._protocol_version_error = e
                raise
        return '2'

    def should_probe(self):
        """Should RemoteBzrDirFormat.probe_transport send a smart request on
        this medium?

        Some transports are unambiguously smart-only; there's no need to check
        if the transport is able to carry smart requests, because that's all
        it is for.  In those cases, this method should return False.

        But some HTTP transports can sometimes fail to carry smart requests,
        but still be usable for accessing remote bzrdirs via plain file
        accesses.  So for those transports, their media should return True here
        so that RemoteBzrDirFormat can determine if it is appropriate for that
        medium.
        """
        return False

    def disconnect(self):
        """If this medium maintains a persistent connection, close it.

        The default implementation does nothing.
        """

    def remote_path_from_transport(self, transport):
        """Convert transport into a path suitable for using in a request.

        Note that the resulting remote path doesn't encode the host name or
        anything but path, so it is only safe to use it in requests sent over
        the medium from the matching transport.
        """
        medium_base = urlutils.join(self.base, '/')
        rel_url = urlutils.relative_url(medium_base, transport.base)
        return urllib.unquote(rel_url)
849
class SmartClientStreamMedium(SmartClientMedium):
    """Stream based medium common class.

    SmartClientStreamMediums operate on a stream. All subclasses use a common
    SmartClientStreamMediumRequest for their requests, and should implement
    _accept_bytes and _read_bytes to allow the request objects to send and
    receive bytes.
    """

    def __init__(self, base):
        SmartClientMedium.__init__(self, base)
        self._current_request = None

    def accept_bytes(self, bytes):
        """Send bytes via the subclass's _accept_bytes."""
        self._accept_bytes(bytes)

    def __del__(self):
        """The SmartClientStreamMedium knows how to close the stream when it is
        finished with it.
        """
        self.disconnect()

    def _flush(self):
        """Flush the output stream.

        This method is used by the SmartClientStreamMediumRequest to ensure that
        all data for a request is sent, to avoid long timeouts or deadlocks.
        """
        raise NotImplementedError(self._flush)

    def get_request(self):
        """See SmartClientMedium.get_request().

        SmartClientStreamMedium always returns a SmartClientStreamMediumRequest
        for get_request.
        """
        return SmartClientStreamMediumRequest(self)
888
class SmartSimplePipesClientMedium(SmartClientStreamMedium):
    """A client medium using simple pipes.

    This client does not manage the pipes: it assumes they will always be open.
    """

    def __init__(self, readable_pipe, writeable_pipe, base):
        SmartClientStreamMedium.__init__(self, base)
        self._readable_pipe = readable_pipe
        self._writeable_pipe = writeable_pipe

    def _accept_bytes(self, bytes):
        """See SmartClientStreamMedium.accept_bytes."""
        self._writeable_pipe.write(bytes)
        self._report_activity(len(bytes), 'write')

    def _flush(self):
        """See SmartClientStreamMedium._flush()."""
        self._writeable_pipe.flush()

    def _read_bytes(self, count):
        """See SmartClientStreamMedium._read_bytes."""
        bytes_to_read = min(count, _MAX_READ_SIZE)
        bytes = self._readable_pipe.read(bytes_to_read)
        self._report_activity(len(bytes), 'read')
        return bytes
916
class SSHParams(object):
    """A set of parameters for starting a remote bzr via SSH.

    :ivar host: host to connect to.
    :ivar port: port to connect to, or None for the SSH default.
    :ivar username: remote user name, or None.
    :ivar password: password, or None.
    :ivar bzr_remote_path: command used to run bzr on the remote host.
    """

    def __init__(self, host, port=None, username=None, password=None,
            bzr_remote_path='bzr'):
        self.host = host
        self.port = port
        self.username = username
        self.password = password
        self.bzr_remote_path = bzr_remote_path
928
class SmartSSHClientMedium(SmartClientStreamMedium):
    """A client medium using SSH.

    It delegates IO to a SmartClientSocketMedium or
    SmartClientAlreadyConnectedSocketMedium (depending on platform).
    """

    def __init__(self, base, ssh_params, vendor=None):
        """Creates a client that will connect on the first use.

        :param ssh_params: A SSHParams instance.
        :param vendor: An optional override for the ssh vendor to use. See
            bzrlib.transport.ssh for details on ssh vendors.
        """
        self._real_medium = None
        self._ssh_params = ssh_params
        # for the benefit of progress making a short description of this
        # transport
        self._scheme = 'bzr+ssh'
        # SmartClientStreamMedium stores the repr of this object in its
        # _DebugCounter so we have to store all the values used in our repr
        # method before calling the super init.
        SmartClientStreamMedium.__init__(self, base)
        self._vendor = vendor
        self._ssh_connection = None

    def __repr__(self):
        if self._ssh_params.port is None:
            maybe_port = ''
        else:
            maybe_port = ':%s' % self._ssh_params.port
        return "%s(%s://%s@%s%s/)" % (
            self.__class__.__name__,
            self._scheme,
            self._ssh_params.username,
            self._ssh_params.host,
            maybe_port)

    def _accept_bytes(self, bytes):
        """See SmartClientStreamMedium.accept_bytes."""
        self._ensure_connection()
        self._real_medium.accept_bytes(bytes)

    def disconnect(self):
        """See SmartClientMedium.disconnect()."""
        if self._real_medium is not None:
            self._real_medium.disconnect()
            self._real_medium = None
        if self._ssh_connection is not None:
            self._ssh_connection.close()
            self._ssh_connection = None

    def _ensure_connection(self):
        """Connect this medium if not already connected."""
        if self._real_medium is not None:
            return
        if self._vendor is None:
            vendor = ssh._get_ssh_vendor()
        else:
            vendor = self._vendor
        self._ssh_connection = vendor.connect_ssh(self._ssh_params.username,
                self._ssh_params.password, self._ssh_params.host,
                self._ssh_params.port,
                command=[self._ssh_params.bzr_remote_path, 'serve', '--inet',
                         '--directory=/', '--allow-writes'])
        io_kind, io_object = self._ssh_connection.get_sock_or_pipes()
        if io_kind == 'socket':
            self._real_medium = SmartClientAlreadyConnectedSocketMedium(
                self.base, io_object)
        elif io_kind == 'pipes':
            read_from, write_to = io_object
            self._real_medium = SmartSimplePipesClientMedium(
                read_from, write_to, self.base)
        else:
            raise AssertionError(
                "Unexpected io_kind %r from %r"
                % (io_kind, self._ssh_connection))

    def _flush(self):
        """See SmartClientStreamMedium._flush()."""
        # Same precondition as _read_bytes: fail clearly rather than with an
        # AttributeError on None.
        if self._real_medium is None:
            raise errors.MediumNotConnected(self)
        self._real_medium._flush()

    def _read_bytes(self, count):
        """See SmartClientStreamMedium.read_bytes."""
        if self._real_medium is None:
            raise errors.MediumNotConnected(self)
        return self._real_medium.read_bytes(count)
1017
# Port 4155 is the default port for bzr://, registered with IANA.
BZR_DEFAULT_INTERFACE = None
BZR_DEFAULT_PORT = 4155
1022
class SmartClientSocketMedium(SmartClientStreamMedium):
    """A client medium using a socket.

    This class isn't usable directly.  Use one of its subclasses instead.
    """

    def __init__(self, base):
        SmartClientStreamMedium.__init__(self, base)
        # Set by subclasses' _ensure_connection; reset by disconnect().
        self._socket = None
        self._connected = False

    def _accept_bytes(self, bytes):
        """See SmartClientMedium.accept_bytes."""
        self._ensure_connection()
        osutils.send_all(self._socket, bytes, self._report_activity)

    def _ensure_connection(self):
        """Connect this medium if not already connected."""
        raise NotImplementedError(self._ensure_connection)

    def _flush(self):
        """See SmartClientStreamMedium._flush().

        For sockets we do no flushing. For TCP sockets we may want to turn off
        TCP_NODELAY and add a means to do a flush, but that can be done in the
        future.
        """

    def _read_bytes(self, count):
        """See SmartClientMedium.read_bytes."""
        if not self._connected:
            raise errors.MediumNotConnected(self)
        return osutils.read_bytes_from_socket(
            self._socket, self._report_activity)

    def disconnect(self):
        """See SmartClientMedium.disconnect()."""
        if not self._connected:
            return
        self._socket.close()
        self._socket = None
        self._connected = False
class SmartTCPClientMedium(SmartClientSocketMedium):
    """A client medium that creates a TCP connection."""

    def __init__(self, host, port, base):
        """Creates a client that will connect on the first use.

        :param host: the host to connect to.
        :param port: the port to connect to, or None for BZR_DEFAULT_PORT.
            Any other value is converted with int() when connecting.
        :param base: passed unchanged to SmartClientSocketMedium.__init__.
        """
        SmartClientSocketMedium.__init__(self, base)
        self._host = host
        self._port = port

    @staticmethod
    def _socket_error_message(err):
        """Return the human-readable message carried by a socket error.

        Socket errors are raised either with a single (string) argument or
        with (errno, string) arguments. Exception args are always a tuple, so
        the number of args decides which one holds the message; str() is the
        fallback when there are none.
        """
        args = getattr(err, 'args', ())
        if len(args) >= 2:
            return args[1]
        if len(args) == 1:
            return args[0]
        return str(err)

    def _ensure_connection(self):
        """Connect this medium if not already connected.

        Every address getaddrinfo returns for the host is tried in turn, with
        TCP_NODELAY set, until one connects.

        :raises errors.ConnectionError: if the host cannot be resolved or no
            address accepts the connection.
        """
        if self._connected:
            return
        if self._port is None:
            port = BZR_DEFAULT_PORT
        else:
            port = int(self._port)
        try:
            sockaddrs = socket.getaddrinfo(self._host, port, socket.AF_UNSPEC,
                socket.SOCK_STREAM, 0, 0)
        except socket.gaierror as e:
            raise errors.ConnectionError("failed to lookup %s:%d: %s" %
                    (self._host, port, self._socket_error_message(e)))
        # Initialize err in case there are no addresses returned:
        err = socket.error("no address found for %s" % self._host)
        for (family, socktype, proto, canonname, sockaddr) in sockaddrs:
            try:
                self._socket = socket.socket(family, socktype, proto)
                self._socket.setsockopt(socket.IPPROTO_TCP,
                                        socket.TCP_NODELAY, 1)
                self._socket.connect(sockaddr)
            except socket.error as e:
                # Remember the latest failure for the final error message.
                err = e
                if self._socket is not None:
                    self._socket.close()
                self._socket = None
                continue
            break
        if self._socket is None:
            # The message is extracted by argument count: a single-argument
            # error (e.g. the "no address found" one above) has no args[1].
            raise errors.ConnectionError("failed to connect to %s:%d: %s" %
                    (self._host, port, self._socket_error_message(err)))
        self._connected = True
class SmartClientAlreadyConnectedSocketMedium(SmartClientSocketMedium):
    """A client medium for an already connected socket.

    Note that this class will assume it "owns" the socket, so it will close it
    when its disconnect method is called.
    """

    def __init__(self, base, sock):
        SmartClientSocketMedium.__init__(self, base)
        self._socket = sock
        self._connected = True

    def _ensure_connection(self):
        """Check that the medium is still connected.

        Connected from construction, so there is nothing to do. This medium
        cannot reconnect once disconnect() has closed its socket.

        :raises errors.MediumNotConnected: after disconnect(), instead of
            letting the caller operate on a None socket.
        """
        if not self._connected:
            raise errors.MediumNotConnected(self)
1132
class SmartClientStreamMediumRequest(SmartClientMediumRequest):
1133
"""A SmartClientMediumRequest that works with an SmartClientStreamMedium."""
def __init__(self, medium):
    """Create a request on medium and register it as its current request.

    :raises errors.TooManyConcurrentRequests: if medium already has a
        current request (one whose _finished_reading has not yet run).
    """
    SmartClientMediumRequest.__init__(self, medium)
    # check that we are safe concurrency wise. If some streams start
    # allowing concurrent requests - i.e. via multiplexing - then this
    # assert should be moved to SmartClientStreamMedium.get_request,
    # and the setting/unsetting of _current_request likewise moved into
    # that class : but its unneeded overhead for now. RBC 20060922
    if self._medium._current_request is not None:
        raise errors.TooManyConcurrentRequests(self._medium)
    self._medium._current_request = self
def _accept_bytes(self, bytes):
    """See SmartClientMediumRequest._accept_bytes.

    This forwards to self._medium._accept_bytes because we are operating
    on the medium's stream.
    """
    self._medium._accept_bytes(bytes)
def _finished_reading(self):
    """See SmartClientMediumRequest._finished_reading.

    This clears the _current_request on self._medium to allow a new
    request to be created.

    :raises AssertionError: if this request is not the medium's current
        request; the medium's state is left unchanged in that case.
    """
    current = self._medium._current_request
    if current is not self:
        raise AssertionError(
            "%r is not the current request of %r (current request: %r)"
            % (self, self._medium, current))
    self._medium._current_request = None
def _finished_writing(self):
    """See SmartClientMediumRequest._finished_writing.

    This invokes self._medium._flush to ensure all bytes are transmitted.
    """
    self._medium._flush()