~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/smart/medium.py

(gz) Never raise KnownFailure in tests,
 use knownFailure method instead (Martin [gz])

Show diffs side-by-side

added added

removed removed

Lines of Context:
24
24
bzrlib/transport/smart/__init__.py.
25
25
"""
26
26
 
27
 
import errno
28
27
import os
29
28
import sys
30
 
import time
31
29
import urllib
32
30
 
33
31
import bzrlib
34
32
from bzrlib.lazy_import import lazy_import
35
33
lazy_import(globals(), """
36
 
import select
37
34
import socket
38
35
import thread
39
36
import weakref
45
42
    ui,
46
43
    urlutils,
47
44
    )
48
 
from bzrlib.i18n import gettext
49
 
from bzrlib.smart import client, protocol, request, signals, vfs
 
45
from bzrlib.smart import client, protocol, request, vfs
50
46
from bzrlib.transport import ssh
51
47
""")
52
48
from bzrlib import osutils
179
175
        ui.ui_factory.report_transport_activity(self, bytes, direction)
180
176
 
181
177
 
182
 
_bad_file_descriptor = (errno.EBADF,)
183
 
if sys.platform == 'win32':
184
 
    # Given on Windows if you pass a closed socket to select.select. Probably
185
 
    # also given if you pass a file handle to select.
186
 
    WSAENOTSOCK = 10038
187
 
    _bad_file_descriptor += (WSAENOTSOCK,)
188
 
 
189
 
 
190
178
class SmartServerStreamMedium(SmartMedium):
191
179
    """Handles smart commands coming over a stream.
192
180
 
205
193
        the stream.  See also the _push_back method.
206
194
    """
207
195
 
208
 
    _timer = time.time
209
 
 
210
 
    def __init__(self, backing_transport, root_client_path='/', timeout=None):
 
196
    def __init__(self, backing_transport, root_client_path='/'):
211
197
        """Construct new server.
212
198
 
213
199
        :param backing_transport: Transport for the directory served.
216
202
        self.backing_transport = backing_transport
217
203
        self.root_client_path = root_client_path
218
204
        self.finished = False
219
 
        if timeout is None:
220
 
            raise AssertionError('You must supply a timeout.')
221
 
        self._client_timeout = timeout
222
 
        self._client_poll_timeout = min(timeout / 10.0, 1.0)
223
205
        SmartMedium.__init__(self)
224
206
 
225
207
    def serve(self):
231
213
            while not self.finished:
232
214
                server_protocol = self._build_protocol()
233
215
                self._serve_one_request(server_protocol)
234
 
        except errors.ConnectionTimeout, e:
235
 
            trace.note('%s' % (e,))
236
 
            trace.log_exception_quietly()
237
 
            self._disconnect_client()
238
 
            # We reported it, no reason to make a big fuss.
239
 
            return
240
216
        except Exception, e:
241
217
            stderr.write("%s terminating on exception %s\n" % (self, e))
242
218
            raise
243
 
        self._disconnect_client()
244
 
 
245
 
    def _stop_gracefully(self):
246
 
        """When we finish this message, stop looking for more."""
247
 
        trace.mutter('Stopping %s' % (self,))
248
 
        self.finished = True
249
 
 
250
 
    def _disconnect_client(self):
251
 
        """Close the current connection. We stopped due to a timeout/etc."""
252
 
        # The default implementation is a no-op, because that is all we used to
253
 
        # do when disconnecting from a client. I suppose we never had the
254
 
        # *server* initiate a disconnect, before
255
 
 
256
 
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
257
 
        """Wait for more bytes to be read, but timeout if none available.
258
 
 
259
 
        This allows us to detect idle connections, and stop trying to read from
260
 
        them, without setting the socket itself to non-blocking. This also
261
 
        allows us to specify when we watch for idle timeouts.
262
 
 
263
 
        :return: Did we timeout? (True if we timed out, False if there is data
264
 
            to be read)
265
 
        """
266
 
        raise NotImplementedError(self._wait_for_bytes_with_timeout)
267
219
 
268
220
    def _build_protocol(self):
269
221
        """Identifies the version of the incoming request, and returns an
274
226
 
275
227
        :returns: a SmartServerRequestProtocol.
276
228
        """
277
 
        self._wait_for_bytes_with_timeout(self._client_timeout)
278
 
        if self.finished:
279
 
            # We're stopping, so don't try to do any more work
280
 
            return None
281
229
        bytes = self._get_line()
282
230
        protocol_factory, unused_bytes = _get_protocol_factory_for_bytes(bytes)
283
231
        protocol = protocol_factory(
285
233
        protocol.accept_bytes(unused_bytes)
286
234
        return protocol
287
235
 
288
 
    def _wait_on_descriptor(self, fd, timeout_seconds):
289
 
        """select() on a file descriptor, waiting for nonblocking read()
290
 
 
291
 
        This will raise a ConnectionTimeout exception if we do not get a
292
 
        readable handle before timeout_seconds.
293
 
        :return: None
294
 
        """
295
 
        t_end = self._timer() + timeout_seconds
296
 
        poll_timeout = min(timeout_seconds, self._client_poll_timeout)
297
 
        rs = xs = None
298
 
        while not rs and not xs and self._timer() < t_end:
299
 
            if self.finished:
300
 
                return
301
 
            try:
302
 
                rs, _, xs = select.select([fd], [], [fd], poll_timeout)
303
 
            except (select.error, socket.error) as e:
304
 
                err = getattr(e, 'errno', None)
305
 
                if err is None and getattr(e, 'args', None) is not None:
306
 
                    # select.error doesn't have 'errno', it just has args[0]
307
 
                    err = e.args[0]
308
 
                if err in _bad_file_descriptor:
309
 
                    return # Not a socket indicates read() will fail
310
 
                elif err == errno.EINTR:
311
 
                    # Interrupted, keep looping.
312
 
                    continue
313
 
                raise
314
 
        if rs or xs:
315
 
            return
316
 
        raise errors.ConnectionTimeout('disconnecting client after %.1f seconds'
317
 
                                       % (timeout_seconds,))
318
 
 
319
236
    def _serve_one_request(self, protocol):
320
237
        """Read one request from input, process, send back a response.
321
238
 
322
239
        :param protocol: a SmartServerRequestProtocol.
323
240
        """
324
 
        if protocol is None:
325
 
            return
326
241
        try:
327
242
            self._serve_one_request_unguarded(protocol)
328
243
        except KeyboardInterrupt:
344
259
 
345
260
class SmartServerSocketStreamMedium(SmartServerStreamMedium):
346
261
 
347
 
    def __init__(self, sock, backing_transport, root_client_path='/',
348
 
                 timeout=None):
 
262
    def __init__(self, sock, backing_transport, root_client_path='/'):
349
263
        """Constructor.
350
264
 
351
265
        :param sock: the socket the server will read from.  It will be put
352
266
            into blocking mode.
353
267
        """
354
268
        SmartServerStreamMedium.__init__(
355
 
            self, backing_transport, root_client_path=root_client_path,
356
 
            timeout=timeout)
 
269
            self, backing_transport, root_client_path=root_client_path)
357
270
        sock.setblocking(True)
358
271
        self.socket = sock
359
 
        # Get the getpeername now, as we might be closed later when we care.
360
 
        try:
361
 
            self._client_info = sock.getpeername()
362
 
        except socket.error:
363
 
            self._client_info = '<unknown>'
364
 
 
365
 
    def __str__(self):
366
 
        return '%s(client=%s)' % (self.__class__.__name__, self._client_info)
367
 
 
368
 
    def __repr__(self):
369
 
        return '%s.%s(client=%s)' % (self.__module__, self.__class__.__name__,
370
 
            self._client_info)
371
272
 
372
273
    def _serve_one_request_unguarded(self, protocol):
373
274
        while protocol.next_read_size():
382
283
 
383
284
        self._push_back(protocol.unused_data)
384
285
 
385
 
    def _disconnect_client(self):
386
 
        """Close the current connection. We stopped due to a timeout/etc."""
387
 
        self.socket.close()
388
 
 
389
 
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
390
 
        """Wait for more bytes to be read, but timeout if none available.
391
 
 
392
 
        This allows us to detect idle connections, and stop trying to read from
393
 
        them, without setting the socket itself to non-blocking. This also
394
 
        allows us to specify when we watch for idle timeouts.
395
 
 
396
 
        :return: None, this will raise ConnectionTimeout if we time out before
397
 
            data is available.
398
 
        """
399
 
        return self._wait_on_descriptor(self.socket, timeout_seconds)
400
 
 
401
286
    def _read_bytes(self, desired_count):
402
287
        return osutils.read_bytes_from_socket(
403
288
            self.socket, self._report_activity)
420
305
 
421
306
class SmartServerPipeStreamMedium(SmartServerStreamMedium):
422
307
 
423
 
    def __init__(self, in_file, out_file, backing_transport, timeout=None):
 
308
    def __init__(self, in_file, out_file, backing_transport):
424
309
        """Construct new server.
425
310
 
426
311
        :param in_file: Python file from which requests can be read.
427
312
        :param out_file: Python file to write responses.
428
313
        :param backing_transport: Transport for the directory served.
429
314
        """
430
 
        SmartServerStreamMedium.__init__(self, backing_transport,
431
 
            timeout=timeout)
 
315
        SmartServerStreamMedium.__init__(self, backing_transport)
432
316
        if sys.platform == 'win32':
433
317
            # force binary mode for files
434
318
            import msvcrt
439
323
        self._in = in_file
440
324
        self._out = out_file
441
325
 
442
 
    def serve(self):
443
 
        """See SmartServerStreamMedium.serve"""
444
 
        # This is the regular serve, except it adds signal trapping for soft
445
 
        # shutdown.
446
 
        stop_gracefully = self._stop_gracefully
447
 
        signals.register_on_hangup(id(self), stop_gracefully)
448
 
        try:
449
 
            return super(SmartServerPipeStreamMedium, self).serve()
450
 
        finally:
451
 
            signals.unregister_on_hangup(id(self))
452
 
 
453
326
    def _serve_one_request_unguarded(self, protocol):
454
327
        while True:
455
328
            # We need to be careful not to read past the end of the current
468
341
                return
469
342
            protocol.accept_bytes(bytes)
470
343
 
471
 
    def _disconnect_client(self):
472
 
        self._in.close()
473
 
        self._out.flush()
474
 
        self._out.close()
475
 
 
476
 
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
477
 
        """Wait for more bytes to be read, but timeout if none available.
478
 
 
479
 
        This allows us to detect idle connections, and stop trying to read from
480
 
        them, without setting the socket itself to non-blocking. This also
481
 
        allows us to specify when we watch for idle timeouts.
482
 
 
483
 
        :return: None, this will raise ConnectionTimeout if we time out before
484
 
            data is available.
485
 
        """
486
 
        if (getattr(self._in, 'fileno', None) is None
487
 
            or sys.platform == 'win32'):
488
 
            # You can't select() file descriptors on Windows.
489
 
            return
490
 
        return self._wait_on_descriptor(self._in, timeout_seconds)
491
 
 
492
344
    def _read_bytes(self, desired_count):
493
345
        return self._in.read(desired_count)
494
346
 
638
490
        return self._medium._get_line()
639
491
 
640
492
 
641
 
class _VfsRefuser(object):
642
 
    """An object that refuses all VFS requests.
643
 
 
644
 
    """
645
 
 
646
 
    def __init__(self):
647
 
        client._SmartClient.hooks.install_named_hook(
648
 
            'call', self.check_vfs, 'vfs refuser')
649
 
 
650
 
    def check_vfs(self, params):
651
 
        try:
652
 
            request_method = request.request_handlers.get(params.method)
653
 
        except KeyError:
654
 
            # A method we don't know about doesn't count as a VFS method.
655
 
            return
656
 
        if issubclass(request_method, vfs.VfsRequest):
657
 
            raise errors.HpssVfsRequestNotAllowed(params.method, params.args)
658
 
 
659
 
 
660
493
class _DebugCounter(object):
661
494
    """An object that counts the HPSS calls made to each client medium.
662
495
 
709
542
        value['count'] = 0
710
543
        value['vfs_count'] = 0
711
544
        if count != 0:
712
 
            trace.note(gettext('HPSS calls: {0} ({1} vfs) {2}').format(
713
 
                       count, vfs_count, medium_repr))
 
545
            trace.note('HPSS calls: %d (%d vfs) %s',
 
546
                       count, vfs_count, medium_repr)
714
547
 
715
548
    def flush_all(self):
716
549
        for ref in list(self.counts.keys()):
717
550
            self.done(ref)
718
551
 
719
552
_debug_counter = None
720
 
_vfs_refuser = None
721
553
 
722
554
 
723
555
class SmartClientMedium(SmartMedium):
740
572
            if _debug_counter is None:
741
573
                _debug_counter = _DebugCounter()
742
574
            _debug_counter.track(self)
743
 
        if 'hpss_client_no_vfs' in debug.debug_flags:
744
 
            global _vfs_refuser
745
 
            if _vfs_refuser is None:
746
 
                _vfs_refuser = _VfsRefuser()
747
575
 
748
576
    def _is_remote_before(self, version_tuple):
749
577
        """Is it possible the remote side supports RPCs for a given version?
881
709
        """
882
710
        return SmartClientStreamMediumRequest(self)
883
711
 
884
 
    def reset(self):
885
 
        """We have been disconnected, reset current state.
886
 
 
887
 
        This resets things like _current_request and connected state.
888
 
        """
889
 
        self.disconnect()
890
 
        self._current_request = None
891
 
 
892
712
 
893
713
class SmartSimplePipesClientMedium(SmartClientStreamMedium):
894
714
    """A client medium using simple pipes.
903
723
 
904
724
    def _accept_bytes(self, bytes):
905
725
        """See SmartClientStreamMedium.accept_bytes."""
906
 
        try:
907
 
            self._writeable_pipe.write(bytes)
908
 
        except IOError, e:
909
 
            if e.errno in (errno.EINVAL, errno.EPIPE):
910
 
                raise errors.ConnectionReset(
911
 
                    "Error trying to write to subprocess:\n%s" % (e,))
912
 
            raise
 
726
        self._writeable_pipe.write(bytes)
913
727
        self._report_activity(len(bytes), 'write')
914
728
 
915
729
    def _flush(self):
916
730
        """See SmartClientStreamMedium._flush()."""
917
 
        # Note: If flush were to fail, we'd like to raise ConnectionReset, etc.
918
 
        #       However, testing shows that even when the child process is
919
 
        #       gone, this doesn't error.
920
731
        self._writeable_pipe.flush()
921
732
 
922
733
    def _read_bytes(self, count):
941
752
 
942
753
class SmartSSHClientMedium(SmartClientStreamMedium):
943
754
    """A client medium using SSH.
944
 
 
945
 
    It delegates IO to a SmartSimplePipesClientMedium or
 
755
    
 
756
    It delegates IO to a SmartClientSocketMedium or
946
757
    SmartClientAlreadyConnectedSocketMedium (depending on platform).
947
758
    """
948
759
 
970
781
            maybe_port = ''
971
782
        else:
972
783
            maybe_port = ':%s' % self._ssh_params.port
973
 
        if self._ssh_params.username is None:
974
 
            maybe_user = ''
975
 
        else:
976
 
            maybe_user = '%s@' % self._ssh_params.username
977
 
        return "%s(%s://%s%s%s/)" % (
 
784
        return "%s(%s://%s@%s%s/)" % (
978
785
            self.__class__.__name__,
979
786
            self._scheme,
980
 
            maybe_user,
 
787
            self._ssh_params.username,
981
788
            self._ssh_params.host,
982
789
            maybe_port)
983
790
 
1185
992
        This invokes self._medium._flush to ensure all bytes are transmitted.
1186
993
        """
1187
994
        self._medium._flush()
 
995
 
 
996