~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/smart/medium.py

  • Committer: Jelmer Vernooij
  • Date: 2011-12-16 16:40:10 UTC
  • mto: This revision was merged to the branch mainline in revision 6391.
  • Revision ID: jelmer@samba.org-20111216164010-z3hy00xrnclnkf7a
Update tests.

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2006-2010 Canonical Ltd
 
1
# Copyright (C) 2006-2011 Canonical Ltd
2
2
#
3
3
# This program is free software; you can redistribute it and/or modify
4
4
# it under the terms of the GNU General Public License as published by
24
24
bzrlib/transport/smart/__init__.py.
25
25
"""
26
26
 
 
27
import errno
27
28
import os
28
29
import sys
 
30
import time
29
31
import urllib
30
32
 
31
33
import bzrlib
32
34
from bzrlib.lazy_import import lazy_import
33
35
lazy_import(globals(), """
 
36
import select
34
37
import socket
35
38
import thread
36
39
import weakref
38
41
from bzrlib import (
39
42
    debug,
40
43
    errors,
41
 
    symbol_versioning,
42
44
    trace,
43
45
    ui,
44
46
    urlutils,
45
47
    )
46
 
from bzrlib.smart import client, protocol, request, vfs
 
48
from bzrlib.i18n import gettext
 
49
from bzrlib.smart import client, protocol, request, signals, vfs
47
50
from bzrlib.transport import ssh
48
51
""")
49
52
from bzrlib import osutils
176
179
        ui.ui_factory.report_transport_activity(self, bytes, direction)
177
180
 
178
181
 
 
182
_bad_file_descriptor = (errno.EBADF,)
 
183
if sys.platform == 'win32':
 
184
    # Given on Windows if you pass a closed socket to select.select. Probably
 
185
    # also given if you pass a file handle to select.
 
186
    WSAENOTSOCK = 10038
 
187
    _bad_file_descriptor += (WSAENOTSOCK,)
 
188
 
 
189
 
179
190
class SmartServerStreamMedium(SmartMedium):
180
191
    """Handles smart commands coming over a stream.
181
192
 
194
205
        the stream.  See also the _push_back method.
195
206
    """
196
207
 
197
 
    def __init__(self, backing_transport, root_client_path='/'):
 
208
    _timer = time.time
 
209
 
 
210
    def __init__(self, backing_transport, root_client_path='/', timeout=None):
198
211
        """Construct new server.
199
212
 
200
213
        :param backing_transport: Transport for the directory served.
203
216
        self.backing_transport = backing_transport
204
217
        self.root_client_path = root_client_path
205
218
        self.finished = False
 
219
        if timeout is None:
 
220
            raise AssertionError('You must supply a timeout.')
 
221
        self._client_timeout = timeout
 
222
        self._client_poll_timeout = min(timeout / 10.0, 1.0)
206
223
        SmartMedium.__init__(self)
207
224
 
208
225
    def serve(self):
214
231
            while not self.finished:
215
232
                server_protocol = self._build_protocol()
216
233
                self._serve_one_request(server_protocol)
 
234
        except errors.ConnectionTimeout, e:
 
235
            trace.note('%s' % (e,))
 
236
            trace.log_exception_quietly()
 
237
            self._disconnect_client()
 
238
            # We reported it, no reason to make a big fuss.
 
239
            return
217
240
        except Exception, e:
218
241
            stderr.write("%s terminating on exception %s\n" % (self, e))
219
242
            raise
 
243
        self._disconnect_client()
 
244
 
 
245
    def _stop_gracefully(self):
 
246
        """When we finish this message, stop looking for more."""
 
247
        trace.mutter('Stopping %s' % (self,))
 
248
        self.finished = True
 
249
 
 
250
    def _disconnect_client(self):
 
251
        """Close the current connection. We stopped due to a timeout/etc."""
 
252
        # The default implementation is a no-op, because that is all we used to
 
253
        # do when disconnecting from a client. I suppose we never had the
 
254
        # *server* initiate a disconnect, before
 
255
 
 
256
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
 
257
        """Wait for more bytes to be read, but timeout if none available.
 
258
 
 
259
        This allows us to detect idle connections, and stop trying to read from
 
260
        them, without setting the socket itself to non-blocking. This also
 
261
        allows us to specify when we watch for idle timeouts.
 
262
 
 
263
        :return: Did we timeout? (True if we timed out, False if there is data
 
264
            to be read)
 
265
        """
 
266
        raise NotImplementedError(self._wait_for_bytes_with_timeout)
220
267
 
221
268
    def _build_protocol(self):
222
269
        """Identifies the version of the incoming request, and returns an
227
274
 
228
275
        :returns: a SmartServerRequestProtocol.
229
276
        """
 
277
        self._wait_for_bytes_with_timeout(self._client_timeout)
 
278
        if self.finished:
 
279
            # We're stopping, so don't try to do any more work
 
280
            return None
230
281
        bytes = self._get_line()
231
282
        protocol_factory, unused_bytes = _get_protocol_factory_for_bytes(bytes)
232
283
        protocol = protocol_factory(
234
285
        protocol.accept_bytes(unused_bytes)
235
286
        return protocol
236
287
 
 
288
    def _wait_on_descriptor(self, fd, timeout_seconds):
 
289
        """select() on a file descriptor, waiting for nonblocking read()
 
290
 
 
291
        This will raise a ConnectionTimeout exception if we do not get a
 
292
        readable handle before timeout_seconds.
 
293
        :return: None
 
294
        """
 
295
        t_end = self._timer() + timeout_seconds
 
296
        poll_timeout = min(timeout_seconds, self._client_poll_timeout)
 
297
        rs = xs = None
 
298
        while not rs and not xs and self._timer() < t_end:
 
299
            if self.finished:
 
300
                return
 
301
            try:
 
302
                rs, _, xs = select.select([fd], [], [fd], poll_timeout)
 
303
            except (select.error, socket.error) as e:
 
304
                err = getattr(e, 'errno', None)
 
305
                if err is None and getattr(e, 'args', None) is not None:
 
306
                    # select.error doesn't have 'errno', it just has args[0]
 
307
                    err = e.args[0]
 
308
                if err in _bad_file_descriptor:
 
309
                    return # Not a socket indicates read() will fail
 
310
                elif err == errno.EINTR:
 
311
                    # Interrupted, keep looping.
 
312
                    continue
 
313
                raise
 
314
        if rs or xs:
 
315
            return
 
316
        raise errors.ConnectionTimeout('disconnecting client after %.1f seconds'
 
317
                                       % (timeout_seconds,))
 
318
 
237
319
    def _serve_one_request(self, protocol):
238
320
        """Read one request from input, process, send back a response.
239
321
 
240
322
        :param protocol: a SmartServerRequestProtocol.
241
323
        """
 
324
        if protocol is None:
 
325
            return
242
326
        try:
243
327
            self._serve_one_request_unguarded(protocol)
244
328
        except KeyboardInterrupt:
260
344
 
261
345
class SmartServerSocketStreamMedium(SmartServerStreamMedium):
262
346
 
263
 
    def __init__(self, sock, backing_transport, root_client_path='/'):
 
347
    def __init__(self, sock, backing_transport, root_client_path='/',
 
348
                 timeout=None):
264
349
        """Constructor.
265
350
 
266
351
        :param sock: the socket the server will read from.  It will be put
267
352
            into blocking mode.
268
353
        """
269
354
        SmartServerStreamMedium.__init__(
270
 
            self, backing_transport, root_client_path=root_client_path)
 
355
            self, backing_transport, root_client_path=root_client_path,
 
356
            timeout=timeout)
271
357
        sock.setblocking(True)
272
358
        self.socket = sock
 
359
        # Get the getpeername now, as we might be closed later when we care.
 
360
        try:
 
361
            self._client_info = sock.getpeername()
 
362
        except socket.error:
 
363
            self._client_info = '<unknown>'
 
364
 
 
365
    def __str__(self):
 
366
        return '%s(client=%s)' % (self.__class__.__name__, self._client_info)
 
367
 
 
368
    def __repr__(self):
 
369
        return '%s.%s(client=%s)' % (self.__module__, self.__class__.__name__,
 
370
            self._client_info)
273
371
 
274
372
    def _serve_one_request_unguarded(self, protocol):
275
373
        while protocol.next_read_size():
284
382
 
285
383
        self._push_back(protocol.unused_data)
286
384
 
 
385
    def _disconnect_client(self):
 
386
        """Close the current connection. We stopped due to a timeout/etc."""
 
387
        self.socket.close()
 
388
 
 
389
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
 
390
        """Wait for more bytes to be read, but timeout if none available.
 
391
 
 
392
        This allows us to detect idle connections, and stop trying to read from
 
393
        them, without setting the socket itself to non-blocking. This also
 
394
        allows us to specify when we watch for idle timeouts.
 
395
 
 
396
        :return: None, this will raise ConnectionTimeout if we time out before
 
397
            data is available.
 
398
        """
 
399
        return self._wait_on_descriptor(self.socket, timeout_seconds)
 
400
 
287
401
    def _read_bytes(self, desired_count):
288
402
        return osutils.read_bytes_from_socket(
289
403
            self.socket, self._report_activity)
306
420
 
307
421
class SmartServerPipeStreamMedium(SmartServerStreamMedium):
308
422
 
309
 
    def __init__(self, in_file, out_file, backing_transport):
 
423
    def __init__(self, in_file, out_file, backing_transport, timeout=None):
310
424
        """Construct new server.
311
425
 
312
426
        :param in_file: Python file from which requests can be read.
313
427
        :param out_file: Python file to write responses.
314
428
        :param backing_transport: Transport for the directory served.
315
429
        """
316
 
        SmartServerStreamMedium.__init__(self, backing_transport)
 
430
        SmartServerStreamMedium.__init__(self, backing_transport,
 
431
            timeout=timeout)
317
432
        if sys.platform == 'win32':
318
433
            # force binary mode for files
319
434
            import msvcrt
324
439
        self._in = in_file
325
440
        self._out = out_file
326
441
 
 
442
    def serve(self):
 
443
        """See SmartServerStreamMedium.serve"""
 
444
        # This is the regular serve, except it adds signal trapping for soft
 
445
        # shutdown.
 
446
        stop_gracefully = self._stop_gracefully
 
447
        signals.register_on_hangup(id(self), stop_gracefully)
 
448
        try:
 
449
            return super(SmartServerPipeStreamMedium, self).serve()
 
450
        finally:
 
451
            signals.unregister_on_hangup(id(self))
 
452
 
327
453
    def _serve_one_request_unguarded(self, protocol):
328
454
        while True:
329
455
            # We need to be careful not to read past the end of the current
342
468
                return
343
469
            protocol.accept_bytes(bytes)
344
470
 
 
471
    def _disconnect_client(self):
 
472
        self._in.close()
 
473
        self._out.flush()
 
474
        self._out.close()
 
475
 
 
476
    def _wait_for_bytes_with_timeout(self, timeout_seconds):
 
477
        """Wait for more bytes to be read, but timeout if none available.
 
478
 
 
479
        This allows us to detect idle connections, and stop trying to read from
 
480
        them, without setting the socket itself to non-blocking. This also
 
481
        allows us to specify when we watch for idle timeouts.
 
482
 
 
483
        :return: None, this will raise ConnectionTimeout if we time out before
 
484
            data is available.
 
485
        """
 
486
        if (getattr(self._in, 'fileno', None) is None
 
487
            or sys.platform == 'win32'):
 
488
            # You can't select() file descriptors on Windows.
 
489
            return
 
490
        return self._wait_on_descriptor(self._in, timeout_seconds)
 
491
 
345
492
    def _read_bytes(self, desired_count):
346
493
        return self._in.read(desired_count)
347
494
 
491
638
        return self._medium._get_line()
492
639
 
493
640
 
 
641
class _VfsRefuser(object):
 
642
    """An object that refuses all VFS requests.
 
643
 
 
644
    """
 
645
 
 
646
    def __init__(self):
 
647
        client._SmartClient.hooks.install_named_hook(
 
648
            'call', self.check_vfs, 'vfs refuser')
 
649
 
 
650
    def check_vfs(self, params):
 
651
        try:
 
652
            request_method = request.request_handlers.get(params.method)
 
653
        except KeyError:
 
654
            # A method we don't know about doesn't count as a VFS method.
 
655
            return
 
656
        if issubclass(request_method, vfs.VfsRequest):
 
657
            raise errors.HpssVfsRequestNotAllowed(params.method, params.args)
 
658
 
 
659
 
494
660
class _DebugCounter(object):
495
661
    """An object that counts the HPSS calls made to each client medium.
496
662
 
543
709
        value['count'] = 0
544
710
        value['vfs_count'] = 0
545
711
        if count != 0:
546
 
            trace.note('HPSS calls: %d (%d vfs) %s',
547
 
                       count, vfs_count, medium_repr)
 
712
            trace.note(gettext('HPSS calls: {0} ({1} vfs) {2}').format(
 
713
                       count, vfs_count, medium_repr))
548
714
 
549
715
    def flush_all(self):
550
716
        for ref in list(self.counts.keys()):
551
717
            self.done(ref)
552
718
 
553
719
_debug_counter = None
 
720
_vfs_refuser = None
554
721
 
555
722
 
556
723
class SmartClientMedium(SmartMedium):
573
740
            if _debug_counter is None:
574
741
                _debug_counter = _DebugCounter()
575
742
            _debug_counter.track(self)
 
743
        if 'hpss_client_no_vfs' in debug.debug_flags:
 
744
            global _vfs_refuser
 
745
            if _vfs_refuser is None:
 
746
                _vfs_refuser = _VfsRefuser()
576
747
 
577
748
    def _is_remote_before(self, version_tuple):
578
749
        """Is it possible the remote side supports RPCs for a given version?
710
881
        """
711
882
        return SmartClientStreamMediumRequest(self)
712
883
 
 
884
    def reset(self):
 
885
        """We have been disconnected, reset current state.
 
886
 
 
887
        This resets things like _current_request and connected state.
 
888
        """
 
889
        self.disconnect()
 
890
        self._current_request = None
 
891
 
713
892
 
714
893
class SmartSimplePipesClientMedium(SmartClientStreamMedium):
715
894
    """A client medium using simple pipes.
724
903
 
725
904
    def _accept_bytes(self, bytes):
726
905
        """See SmartClientStreamMedium.accept_bytes."""
727
 
        self._writeable_pipe.write(bytes)
 
906
        try:
 
907
            self._writeable_pipe.write(bytes)
 
908
        except IOError, e:
 
909
            if e.errno in (errno.EINVAL, errno.EPIPE):
 
910
                raise errors.ConnectionReset(
 
911
                    "Error trying to write to subprocess:\n%s" % (e,))
 
912
            raise
728
913
        self._report_activity(len(bytes), 'write')
729
914
 
730
915
    def _flush(self):
731
916
        """See SmartClientStreamMedium._flush()."""
 
917
        # Note: If flush were to fail, we'd like to raise ConnectionReset, etc.
 
918
        #       However, testing shows that even when the child process is
 
919
        #       gone, this doesn't error.
732
920
        self._writeable_pipe.flush()
733
921
 
734
922
    def _read_bytes(self, count):
753
941
 
754
942
class SmartSSHClientMedium(SmartClientStreamMedium):
755
943
    """A client medium using SSH.
756
 
    
757
 
    It delegates IO to a SmartClientSocketMedium or
 
944
 
 
945
    It delegates IO to a SmartSimplePipesClientMedium or
758
946
    SmartClientAlreadyConnectedSocketMedium (depending on platform).
759
947
    """
760
948
 
782
970
            maybe_port = ''
783
971
        else:
784
972
            maybe_port = ':%s' % self._ssh_params.port
785
 
        return "%s(%s://%s@%s%s/)" % (
 
973
        if self._ssh_params.username is None:
 
974
            maybe_user = ''
 
975
        else:
 
976
            maybe_user = '%s@' % self._ssh_params.username
 
977
        return "%s(%s://%s%s%s/)" % (
786
978
            self.__class__.__name__,
787
979
            self._scheme,
788
 
            self._ssh_params.username,
 
980
            maybe_user,
789
981
            self._ssh_params.host,
790
982
            maybe_port)
791
983
 
993
1185
        This invokes self._medium._flush to ensure all bytes are transmitted.
994
1186
        """
995
1187
        self._medium._flush()
996
 
 
997