~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/tests/stub_sftp.py

  • Committer: Robert Collins
  • Date: 2010-05-11 08:44:59 UTC
  • mfrom: (5221 +trunk)
  • mto: This revision was merged to the branch mainline in revision 5223.
  • Revision ID: robertc@robertcollins.net-20100511084459-pb0uinna9zs3wu59
Merge trunk - resolve conflicts.

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2005, 2006, 2008-2011 Robey Pointer <robey@lag.net>, Canonical Ltd
 
1
# Copyright (C) 2005, 2006, 2008, 2009, 2010 Robey Pointer <robey@lag.net>, Canonical Ltd
2
2
#
3
3
# This program is free software; you can redistribute it and/or modify
4
4
# it under the terms of the GNU General Public License as published by
21
21
 
22
22
import os
23
23
import paramiko
 
24
import select
24
25
import socket
25
 
import SocketServer
26
26
import sys
 
27
import threading
27
28
import time
28
29
 
29
30
from bzrlib import (
37
38
from bzrlib.tests import test_server
38
39
 
39
40
 
40
 
class StubServer(paramiko.ServerInterface):
 
41
class StubServer (paramiko.ServerInterface):
41
42
 
42
 
    def __init__(self, test_case_server):
 
43
    def __init__(self, test_case):
43
44
        paramiko.ServerInterface.__init__(self)
44
 
        self.log = test_case_server.log
 
45
        self._test_case = test_case
45
46
 
46
47
    def check_auth_password(self, username, password):
47
48
        # all are allowed
48
 
        self.log('sftpserver - authorizing: %s' % (username,))
 
49
        self._test_case.log('sftpserver - authorizing: %s' % (username,))
49
50
        return paramiko.AUTH_SUCCESSFUL
50
51
 
51
52
    def check_channel_request(self, kind, chanid):
52
 
        self.log('sftpserver - channel request: %s, %s' % (kind, chanid))
 
53
        self._test_case.log(
 
54
            'sftpserver - channel request: %s, %s' % (kind, chanid))
53
55
        return paramiko.OPEN_SUCCEEDED
54
56
 
55
57
 
56
 
class StubSFTPHandle(paramiko.SFTPHandle):
57
 
 
 
58
class StubSFTPHandle (paramiko.SFTPHandle):
58
59
    def stat(self):
59
60
        try:
60
61
            return paramiko.SFTPAttributes.from_stat(
72
73
            return paramiko.SFTPServer.convert_errno(e.errno)
73
74
 
74
75
 
75
 
class StubSFTPServer(paramiko.SFTPServerInterface):
 
76
class StubSFTPServer (paramiko.SFTPServerInterface):
76
77
 
77
78
    def __init__(self, server, root, home=None):
78
79
        paramiko.SFTPServerInterface.__init__(self, server)
89
90
            self.home = home[len(self.root):]
90
91
        if self.home.startswith('/'):
91
92
            self.home = self.home[1:]
92
 
        server.log('sftpserver - new connection')
 
93
        server._test_case.log('sftpserver - new connection')
93
94
 
94
95
    def _realpath(self, path):
95
96
        # paths returned from self.canonicalize() always start with
134
135
        try:
135
136
            out = [ ]
136
137
            # TODO: win32 incorrectly lists paths with non-ascii if path is not
137
 
            # unicode. However on unix the server should only deal with
 
138
            # unicode. However on Linux the server should only deal with
138
139
            # bytestreams and posix.listdir does the right thing
139
140
            if sys.platform == 'win32':
140
141
                flist = [f.encode('utf8') for f in os.listdir(path)]
240
241
    # removed: chattr, symlink, readlink
241
242
    # (nothing in bzr's sftp transport uses those)
242
243
 
243
 
 
244
244
# ------------- server test implementation --------------
245
245
 
246
246
STUB_SERVER_KEY = """
262
262
"""
263
263
 
264
264
 
 
265
class SocketListener(threading.Thread):
 
266
 
 
267
    def __init__(self, callback):
 
268
        threading.Thread.__init__(self)
 
269
        self._callback = callback
 
270
        self._socket = socket.socket()
 
271
        self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
 
272
        self._socket.bind(('localhost', 0))
 
273
        self._socket.listen(1)
 
274
        self.host, self.port = self._socket.getsockname()[:2]
 
275
        self._stop_event = threading.Event()
 
276
 
 
277
    def stop(self):
 
278
        # called from outside this thread
 
279
        self._stop_event.set()
 
280
        # use a timeout here, because if the test fails, the server thread may
 
281
        # never notice the stop_event.
 
282
        self.join(5.0)
 
283
        self._socket.close()
 
284
 
 
285
    def run(self):
 
286
        trace.mutter('SocketListener %r has started', self)
 
287
        while True:
 
288
            readable, writable_unused, exception_unused = \
 
289
                select.select([self._socket], [], [], 0.1)
 
290
            if self._stop_event.isSet():
 
291
                trace.mutter('SocketListener %r has stopped', self)
 
292
                return
 
293
            if len(readable) == 0:
 
294
                continue
 
295
            try:
 
296
                s, addr_unused = self._socket.accept()
 
297
                trace.mutter('SocketListener %r has accepted connection %r',
 
298
                    self, s)
 
299
                # because the loopback socket is inline, and transports are
 
300
                # never explicitly closed, best to launch a new thread.
 
301
                threading.Thread(target=self._callback, args=(s,)).start()
 
302
            except socket.error, x:
 
303
                sys.excepthook(*sys.exc_info())
 
304
                trace.warning('Socket error during accept() '
 
305
                              'within unit test server thread: %r' % x)
 
306
            except Exception, x:
 
307
                # probably a failed test; unit test thread will log the
 
308
                # failure/error
 
309
                sys.excepthook(*sys.exc_info())
 
310
                trace.warning(
 
311
                    'Exception from within unit test server thread: %r' % x)
 
312
 
 
313
 
265
314
class SocketDelay(object):
266
315
    """A socket decorator to make TCP appear slower.
267
316
 
337
386
        return bytes_sent
338
387
 
339
388
 
340
 
class TestingSFTPConnectionHandler(SocketServer.BaseRequestHandler):
341
 
 
342
 
    def setup(self):
343
 
        self.wrap_for_latency()
344
 
        tcs = self.server.test_case_server
345
 
        ptrans = paramiko.Transport(self.request)
346
 
        self.paramiko_transport = ptrans
347
 
        # Set it to a channel under 'bzr' so that we get debug info
348
 
        ptrans.set_log_channel('bzr.paramiko.transport')
349
 
        ptrans.add_server_key(tcs.get_host_key())
350
 
        ptrans.set_subsystem_handler('sftp', paramiko.SFTPServer,
351
 
                                     StubSFTPServer, root=tcs._root,
352
 
                                     home=tcs._server_homedir)
353
 
        server = tcs._server_interface(tcs)
354
 
        # This blocks until the key exchange has been done
355
 
        ptrans.start_server(None, server)
356
 
 
357
 
    def finish(self):
358
 
        # Wait for the conversation to finish, when the paramiko.Transport
359
 
        # thread finishes
360
 
        # TODO: Consider timing out after XX seconds rather than hanging.
361
 
        #       Also we could check paramiko_transport.active and possibly
362
 
        #       paramiko_transport.getException().
363
 
        self.paramiko_transport.join()
364
 
 
365
 
    def wrap_for_latency(self):
366
 
        tcs = self.server.test_case_server
367
 
        if tcs.add_latency:
368
 
            # Give the socket (which the request really is) a latency adding
369
 
            # decorator.
370
 
            self.request = SocketDelay(self.request, tcs.add_latency)
371
 
 
372
 
 
373
 
class TestingSFTPWithoutSSHConnectionHandler(TestingSFTPConnectionHandler):
374
 
 
375
 
    def setup(self):
376
 
        self.wrap_for_latency()
377
 
        # Re-import these as locals, so that they're still accessible during
378
 
        # interpreter shutdown (when all module globals get set to None, leading
379
 
        # to confusing errors like "'NoneType' object has no attribute 'error'".
380
 
        class FakeChannel(object):
381
 
            def get_transport(self):
382
 
                return self
383
 
            def get_log_channel(self):
384
 
                return 'bzr.paramiko'
385
 
            def get_name(self):
386
 
                return '1'
387
 
            def get_hexdump(self):
388
 
                return False
389
 
            def close(self):
390
 
                pass
391
 
 
392
 
        tcs = self.server.test_case_server
393
 
        sftp_server = paramiko.SFTPServer(
394
 
            FakeChannel(), 'sftp', StubServer(tcs), StubSFTPServer,
395
 
            root=tcs._root, home=tcs._server_homedir)
396
 
        self.sftp_server = sftp_server
397
 
        sys_stderr = sys.stderr # Used in error reporting during shutdown
398
 
        try:
399
 
            sftp_server.start_subsystem(
400
 
                'sftp', None, ssh.SocketAsChannelAdapter(self.request))
401
 
        except socket.error, e:
402
 
            if (len(e.args) > 0) and (e.args[0] == errno.EPIPE):
403
 
                # it's okay for the client to disconnect abruptly
404
 
                # (bug in paramiko 1.6: it should absorb this exception)
405
 
                pass
406
 
            else:
407
 
                raise
408
 
        except Exception, e:
409
 
            # This typically seems to happen during interpreter shutdown, so
410
 
            # most of the useful ways to report this error won't work.
411
 
            # Writing the exception type, and then the text of the exception,
412
 
            # seems to be the best we can do.
413
 
            # FIXME: All interpreter shutdown errors should have been related
414
 
            # to daemon threads, cleanup needed -- vila 20100623
415
 
            sys_stderr.write('\nEXCEPTION %r: ' % (e.__class__,))
416
 
            sys_stderr.write('%s\n\n' % (e,))
417
 
 
418
 
    def finish(self):
419
 
        self.sftp_server.finish_subsystem()
420
 
 
421
 
 
422
 
class TestingSFTPServer(test_server.TestingThreadingTCPServer):
423
 
 
424
 
    def __init__(self, server_address, request_handler_class, test_case_server):
425
 
        test_server.TestingThreadingTCPServer.__init__(
426
 
            self, server_address, request_handler_class)
427
 
        self.test_case_server = test_case_server
428
 
 
429
 
 
430
 
class SFTPServer(test_server.TestingTCPServerInAThread):
 
389
class SFTPServer(test_server.TestServer):
431
390
    """Common code for SFTP server facilities."""
432
391
 
433
392
    def __init__(self, server_interface=StubServer):
434
 
        self.host = '127.0.0.1'
435
 
        self.port = 0
436
 
        super(SFTPServer, self).__init__((self.host, self.port),
437
 
                                         TestingSFTPServer,
438
 
                                         TestingSFTPConnectionHandler)
439
393
        self._original_vendor = None
 
394
        self._homedir = None
 
395
        self._server_homedir = None
 
396
        self._listener = None
 
397
        self._root = None
440
398
        self._vendor = ssh.ParamikoVendor()
441
399
        self._server_interface = server_interface
442
 
        self._host_key = None
 
400
        # sftp server logs
443
401
        self.logs = []
444
402
        self.add_latency = 0
445
 
        self._homedir = None
446
 
        self._server_homedir = None
447
 
        self._root = None
448
403
 
449
404
    def _get_sftp_url(self, path):
450
405
        """Calculate an sftp url to this server for path."""
451
 
        return "sftp://foo:bar@%s:%s/%s" % (self.host, self.port, path)
 
406
        return 'sftp://foo:bar@%s:%d/%s' % (self._listener.host,
 
407
                                            self._listener.port, path)
452
408
 
453
409
    def log(self, message):
454
410
        """StubServer uses this to log when a new server is created."""
455
411
        self.logs.append(message)
456
412
 
457
 
    def create_server(self):
458
 
        server = self.server_class((self.host, self.port),
459
 
                                   self.request_handler_class,
460
 
                                   self)
461
 
        return server
462
 
 
463
 
    def get_host_key(self):
464
 
        if self._host_key is None:
465
 
            key_file = osutils.pathjoin(self._homedir, 'test_rsa.key')
466
 
            f = open(key_file, 'w')
467
 
            try:
468
 
                f.write(STUB_SERVER_KEY)
469
 
            finally:
470
 
                f.close()
471
 
            self._host_key = paramiko.RSAKey.from_private_key_file(key_file)
472
 
        return self._host_key
 
413
    def _run_server_entry(self, sock):
 
414
        """Entry point for all implementations of _run_server.
 
415
 
 
416
        If self.add_latency is > 0.000001 then sock is given a latency adding
 
417
        decorator.
 
418
        """
 
419
        if self.add_latency > 0.000001:
 
420
            sock = SocketDelay(sock, self.add_latency)
 
421
        return self._run_server(sock)
 
422
 
 
423
    def _run_server(self, s):
 
424
        ssh_server = paramiko.Transport(s)
 
425
        key_file = osutils.pathjoin(self._homedir, 'test_rsa.key')
 
426
        f = open(key_file, 'w')
 
427
        f.write(STUB_SERVER_KEY)
 
428
        f.close()
 
429
        host_key = paramiko.RSAKey.from_private_key_file(key_file)
 
430
        ssh_server.add_server_key(host_key)
 
431
        server = self._server_interface(self)
 
432
        ssh_server.set_subsystem_handler('sftp', paramiko.SFTPServer,
 
433
                                         StubSFTPServer, root=self._root,
 
434
                                         home=self._server_homedir)
 
435
        event = threading.Event()
 
436
        ssh_server.start_server(event, server)
 
437
        event.wait(5.0)
473
438
 
474
439
    def start_server(self, backing_server=None):
475
440
        # XXX: TODO: make sftpserver back onto backing_server rather than local
481
446
                'the local current working directory.' % (backing_server,))
482
447
        self._original_vendor = ssh._ssh_vendor_manager._cached_ssh_vendor
483
448
        ssh._ssh_vendor_manager._cached_ssh_vendor = self._vendor
 
449
        # FIXME: the following block should certainly just be self._homedir =
 
450
        # osutils.getcwd() but that fails badly on Unix -- vila 20100224
484
451
        if sys.platform == 'win32':
485
452
            # Win32 needs to use the UNICODE api
486
453
            self._homedir = os.getcwdu()
487
 
            # Normalize the path or it will be wrongly escaped
488
 
            self._homedir = osutils.normpath(self._homedir)
489
454
        else:
490
 
            # But unix SFTP servers should just deal in bytestreams
 
455
            # But Linux SFTP servers should just deal in bytestreams
491
456
            self._homedir = os.getcwd()
492
457
        if self._server_homedir is None:
493
458
            self._server_homedir = self._homedir
494
459
        self._root = '/'
495
460
        if sys.platform == 'win32':
496
461
            self._root = ''
497
 
        super(SFTPServer, self).start_server()
 
462
        self._listener = SocketListener(self._run_server_entry)
 
463
        self._listener.setDaemon(True)
 
464
        self._listener.start()
498
465
 
499
466
    def stop_server(self):
500
 
        try:
501
 
            super(SFTPServer, self).stop_server()
502
 
        finally:
503
 
            ssh._ssh_vendor_manager._cached_ssh_vendor = self._original_vendor
 
467
        self._listener.stop()
 
468
        ssh._ssh_vendor_manager._cached_ssh_vendor = self._original_vendor
504
469
 
505
470
    def get_bogus_url(self):
506
471
        """See bzrlib.transport.Server.get_bogus_url."""
507
 
        # this is chosen to try to prevent trouble with proxies, weird dns, etc
 
472
        # this is chosen to try to prevent trouble with proxies, wierd dns, etc
508
473
        # we bind a random socket, so that we get a guaranteed unused port
509
474
        # we just never listen on that port
510
475
        s = socket.socket()
530
495
    def __init__(self):
531
496
        super(SFTPServerWithoutSSH, self).__init__()
532
497
        self._vendor = ssh.LoopbackVendor()
533
 
        self.request_handler_class = TestingSFTPWithoutSSHConnectionHandler
534
 
 
535
 
    def get_host_key():
536
 
        return None
 
498
 
 
499
    def _run_server(self, sock):
 
500
        # Re-import these as locals, so that they're still accessible during
 
501
        # interpreter shutdown (when all module globals get set to None, leading
 
502
        # to confusing errors like "'NoneType' object has no attribute 'error'".
 
503
        class FakeChannel(object):
 
504
            def get_transport(self):
 
505
                return self
 
506
            def get_log_channel(self):
 
507
                return 'paramiko'
 
508
            def get_name(self):
 
509
                return '1'
 
510
            def get_hexdump(self):
 
511
                return False
 
512
            def close(self):
 
513
                pass
 
514
 
 
515
        server = paramiko.SFTPServer(
 
516
            FakeChannel(), 'sftp', StubServer(self), StubSFTPServer,
 
517
            root=self._root, home=self._server_homedir)
 
518
        try:
 
519
            server.start_subsystem(
 
520
                'sftp', None, ssh.SocketAsChannelAdapter(sock))
 
521
        except socket.error, e:
 
522
            if (len(e.args) > 0) and (e.args[0] == errno.EPIPE):
 
523
                # it's okay for the client to disconnect abruptly
 
524
                # (bug in paramiko 1.6: it should absorb this exception)
 
525
                pass
 
526
            else:
 
527
                raise
 
528
        except Exception, e:
 
529
            # This typically seems to happen during interpreter shutdown, so
 
530
            # most of the useful ways to report this error are won't work.
 
531
            # Writing the exception type, and then the text of the exception,
 
532
            # seems to be the best we can do.
 
533
            import sys
 
534
            sys.stderr.write('\nEXCEPTION %r: ' % (e.__class__,))
 
535
            sys.stderr.write('%s\n\n' % (e,))
 
536
        server.finish_subsystem()
537
537
 
538
538
 
539
539
class SFTPAbsoluteServer(SFTPServerWithoutSSH):
553
553
 
554
554
    def get_url(self):
555
555
        """See bzrlib.transport.Server.get_url."""
556
 
        return self._get_sftp_url("%7E/")
 
556
        return self._get_sftp_url("~/")
557
557
 
558
558
 
559
559
class SFTPSiblingAbsoluteServer(SFTPAbsoluteServer):
562
562
    It does this by serving from a deeply-nested directory that doesn't exist.
563
563
    """
564
564
 
565
 
    def create_server(self):
566
 
        # FIXME: Can't we do that in a cleaner way ? -- vila 20100623
567
 
        server = super(SFTPSiblingAbsoluteServer, self).create_server()
568
 
        server._server_homedir = '/dev/noone/runs/tests/here'
569
 
        return server
 
565
    def start_server(self, backing_server=None):
 
566
        self._server_homedir = '/dev/noone/runs/tests/here'
 
567
        super(SFTPSiblingAbsoluteServer, self).start_server(backing_server)
570
568