~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/selftest/testsftp.py

Exclude more files from dumb-rsync upload

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2005 Robey Pointer <robey@lag.net>
2
 
# Copyright (C) 2005, 2006, 2007 Canonical Ltd
3
 
#
 
1
# Copyright (C) 2005 Robey Pointer <robey@lag.net>, Canonical Ltd
 
2
 
4
3
# This program is free software; you can redistribute it and/or modify
5
4
# it under the terms of the GNU General Public License as published by
6
5
# the Free Software Foundation; either version 2 of the License, or
7
6
# (at your option) any later version.
8
 
#
 
7
 
9
8
# This program is distributed in the hope that it will be useful,
10
9
# but WITHOUT ANY WARRANTY; without even the implied warranty of
11
10
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12
11
# GNU General Public License for more details.
13
 
#
 
12
 
14
13
# You should have received a copy of the GNU General Public License
15
14
# along with this program; if not, write to the Free Software
16
 
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 
15
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
17
16
 
18
17
import os
19
18
import socket
20
 
import sys
21
19
import threading
22
 
import time
 
20
import unittest
 
21
 
 
22
from bzrlib.selftest import TestCaseInTempDir
 
23
from bzrlib.selftest.testtransport import TestTransportMixIn
23
24
 
24
25
try:
25
26
    import paramiko
 
27
    from stub_sftp import StubServer, StubSFTPServer
26
28
    paramiko_loaded = True
27
29
except ImportError:
28
30
    paramiko_loaded = False
29
31
 
30
 
from bzrlib import (
31
 
    bzrdir,
32
 
    config,
33
 
    errors,
34
 
    tests,
35
 
    transport as _mod_transport,
36
 
    ui,
37
 
    )
38
 
from bzrlib.osutils import (
39
 
    pathjoin,
40
 
    lexists,
41
 
    set_or_unset_env,
42
 
    )
43
 
from bzrlib.tests import (
44
 
    TestCaseWithTransport,
45
 
    TestCase,
46
 
    TestSkipped,
47
 
    )
48
 
from bzrlib.tests.http_server import HttpServer
49
 
from bzrlib.transport import get_transport
50
 
import bzrlib.transport.http
51
 
 
52
 
if paramiko_loaded:
53
 
    from bzrlib.transport import sftp as _mod_sftp
54
 
    from bzrlib.transport.sftp import (
55
 
        SFTPAbsoluteServer,
56
 
        SFTPHomeDirServer,
57
 
        SFTPTransport,
58
 
        )
59
 
 
60
 
from bzrlib.workingtree import WorkingTree
61
 
 
62
 
 
63
 
def set_test_transport_to_sftp(testcase):
64
 
    """A helper to set transports on test case instances."""
65
 
    if getattr(testcase, '_get_remote_is_absolute', None) is None:
66
 
        testcase._get_remote_is_absolute = True
67
 
    if testcase._get_remote_is_absolute:
68
 
        testcase.transport_server = SFTPAbsoluteServer
69
 
    else:
70
 
        testcase.transport_server = SFTPHomeDirServer
71
 
    testcase.transport_readonly_server = HttpServer
72
 
 
73
 
 
74
 
class TestCaseWithSFTPServer(TestCaseWithTransport):
75
 
    """A test case base class that provides a sftp server on localhost."""
76
 
 
77
 
    def setUp(self):
78
 
        super(TestCaseWithSFTPServer, self).setUp()
79
 
        if not paramiko_loaded:
80
 
            raise TestSkipped('you must have paramiko to run this test')
81
 
        set_test_transport_to_sftp(self)
82
 
 
83
 
 
84
 
class SFTPLockTests(TestCaseWithSFTPServer):
85
 
 
86
 
    def test_sftp_locks(self):
87
 
        from bzrlib.errors import LockError
88
 
        t = self.get_transport()
89
 
 
90
 
        l = t.lock_write('bogus')
91
 
        self.failUnlessExists('bogus.write-lock')
92
 
 
93
 
        # Don't wait for the lock, locking an already locked
94
 
        # file should raise an assert
95
 
        self.assertRaises(LockError, t.lock_write, 'bogus')
96
 
 
97
 
        l.unlock()
98
 
        self.failIf(lexists('bogus.write-lock'))
99
 
 
100
 
        open('something.write-lock', 'wb').write('fake lock\n')
101
 
        self.assertRaises(LockError, t.lock_write, 'something')
102
 
        os.remove('something.write-lock')
103
 
 
104
 
        l = t.lock_write('something')
105
 
 
106
 
        l2 = t.lock_write('bogus')
107
 
 
108
 
        l.unlock()
109
 
        l2.unlock()
110
 
 
111
 
 
112
 
class SFTPTransportTestRelative(TestCaseWithSFTPServer):
113
 
    """Test the SFTP transport with homedir based relative paths."""
114
 
 
115
 
    def test__remote_path(self):
116
 
        if sys.platform == 'darwin':
117
 
            # This test is about sftp absolute path handling. There is already
118
 
            # (in this test) a TODO about windows needing an absolute path
119
 
            # without drive letter. To me, using self.test_dir is a trick to
120
 
            # get an absolute path for comparison purposes.  That fails for OSX
121
 
            # because the sftp server doesn't resolve the links (and it doesn't
122
 
            # have to). --vila 20070924
123
 
            self.knownFailure('Mac OSX symlinks /tmp to /private/tmp,'
124
 
                              ' testing against self.test_dir'
125
 
                              ' is not appropriate')
126
 
        t = self.get_transport()
127
 
        # This test require unix-like absolute path
128
 
        test_dir = self.test_dir
129
 
        if sys.platform == 'win32':
130
 
            # using hack suggested by John Meinel.
131
 
            # TODO: write another mock server for this test
132
 
            #       and use absolute path without drive letter
133
 
            test_dir = '/' + test_dir
134
 
        # try what is currently used:
135
 
        # remote path = self._abspath(relpath)
136
 
        self.assertIsSameRealPath(test_dir + '/relative',
137
 
                                  t._remote_path('relative'))
138
 
        # we dont os.path.join because windows gives us the wrong path
139
 
        root_segments = test_dir.split('/')
140
 
        root_parent = '/'.join(root_segments[:-1])
141
 
        # .. should be honoured
142
 
        self.assertIsSameRealPath(root_parent + '/sibling',
143
 
                                  t._remote_path('../sibling'))
144
 
        # /  should be illegal ?
145
 
        ### FIXME decide and then test for all transports. RBC20051208
146
 
 
147
 
 
148
 
class SFTPTransportTestRelativeRoot(TestCaseWithSFTPServer):
149
 
    """Test the SFTP transport with homedir based relative paths."""
150
 
 
151
 
    def setUp(self):
152
 
        # Only SFTPHomeDirServer is tested here
153
 
        self._get_remote_is_absolute = False
154
 
        super(SFTPTransportTestRelativeRoot, self).setUp()
155
 
 
156
 
    def test__remote_path_relative_root(self):
157
 
        # relative paths are preserved
158
 
        t = self.get_transport('')
159
 
        self.assertEqual('/~/', t._path)
160
 
        # the remote path should be relative to home dir
161
 
        # (i.e. not begining with a '/')
162
 
        self.assertEqual('a', t._remote_path('a'))
163
 
 
164
 
 
165
 
class SFTPNonServerTest(TestCase):
166
 
    def setUp(self):
167
 
        TestCase.setUp(self)
168
 
        if not paramiko_loaded:
169
 
            raise TestSkipped('you must have paramiko to run this test')
170
 
 
171
 
    def test_parse_url_with_home_dir(self):
172
 
        s = SFTPTransport('sftp://ro%62ey:h%40t@example.com:2222/~/relative')
173
 
        self.assertEquals(s._host, 'example.com')
174
 
        self.assertEquals(s._port, 2222)
175
 
        self.assertEquals(s._user, 'robey')
176
 
        self.assertEquals(s._password, 'h@t')
177
 
        self.assertEquals(s._path, '/~/relative/')
178
 
 
179
 
    def test_relpath(self):
180
 
        s = SFTPTransport('sftp://user@host.com/abs/path')
181
 
        self.assertRaises(errors.PathNotChild, s.relpath,
182
 
                          'sftp://user@host.com/~/rel/path/sub')
183
 
 
184
 
    def test_get_paramiko_vendor(self):
185
 
        """Test that if no 'ssh' is available we get builtin paramiko"""
186
 
        from bzrlib.transport import ssh
187
 
        # set '.' as the only location in the path, forcing no 'ssh' to exist
188
 
        orig_vendor = ssh._ssh_vendor_manager._cached_ssh_vendor
189
 
        orig_path = set_or_unset_env('PATH', '.')
190
 
        try:
191
 
            # No vendor defined yet, query for one
192
 
            ssh._ssh_vendor_manager.clear_cache()
193
 
            vendor = ssh._get_ssh_vendor()
194
 
            self.assertIsInstance(vendor, ssh.ParamikoVendor)
195
 
        finally:
196
 
            set_or_unset_env('PATH', orig_path)
197
 
            ssh._ssh_vendor_manager._cached_ssh_vendor = orig_vendor
198
 
 
199
 
    def test_abspath_root_sibling_server(self):
200
 
        from bzrlib.transport.sftp import SFTPSiblingAbsoluteServer
201
 
        server = SFTPSiblingAbsoluteServer()
202
 
        server.setUp()
203
 
        try:
204
 
            transport = get_transport(server.get_url())
205
 
            self.assertFalse(transport.abspath('/').endswith('/~/'))
206
 
            self.assertTrue(transport.abspath('/').endswith('/'))
207
 
            del transport
208
 
        finally:
209
 
            server.tearDown()
210
 
 
211
 
 
212
 
class SFTPBranchTest(TestCaseWithSFTPServer):
213
 
    """Test some stuff when accessing a bzr Branch over sftp"""
214
 
 
215
 
    def test_lock_file(self):
216
 
        # old format branches use a special lock file on sftp.
217
 
        b = self.make_branch('', format=bzrdir.BzrDirFormat6())
218
 
        b = bzrlib.branch.Branch.open(self.get_url())
219
 
        self.failUnlessExists('.bzr/')
220
 
        self.failUnlessExists('.bzr/branch-format')
221
 
        self.failUnlessExists('.bzr/branch-lock')
222
 
 
223
 
        self.failIf(lexists('.bzr/branch-lock.write-lock'))
224
 
        b.lock_write()
225
 
        self.failUnlessExists('.bzr/branch-lock.write-lock')
226
 
        b.unlock()
227
 
        self.failIf(lexists('.bzr/branch-lock.write-lock'))
228
 
 
229
 
    def test_push_support(self):
230
 
        self.build_tree(['a/', 'a/foo'])
231
 
        t = bzrdir.BzrDir.create_standalone_workingtree('a')
232
 
        b = t.branch
233
 
        t.add('foo')
234
 
        t.commit('foo', rev_id='a1')
235
 
 
236
 
        b2 = bzrdir.BzrDir.create_branch_and_repo(self.get_url('/b'))
237
 
        b2.pull(b)
238
 
 
239
 
        self.assertEquals(b2.revision_history(), ['a1'])
240
 
 
241
 
        open('a/foo', 'wt').write('something new in foo\n')
242
 
        t.commit('new', rev_id='a2')
243
 
        b2.pull(b)
244
 
 
245
 
        self.assertEquals(b2.revision_history(), ['a1', 'a2'])
246
 
 
247
 
 
248
 
class SSHVendorConnection(TestCaseWithSFTPServer):
249
 
    """Test that the ssh vendors can all connect.
250
 
 
251
 
    Verify that a full-handshake (SSH over loopback TCP) sftp connection works.
252
 
 
253
 
    We have 3 sftp implementations in the test suite:
254
 
      'loopback': Doesn't use ssh, just uses a local socket. Most tests are
255
 
                  done this way to save the handshaking time, so it is not
256
 
                  tested again here
257
 
      'none':     This uses paramiko's built-in ssh client and server, and layers
258
 
                  sftp on top of it.
259
 
      None:       If 'ssh' exists on the machine, then it will be spawned as a
260
 
                  child process.
261
 
    """
262
 
 
263
 
    def setUp(self):
264
 
        super(SSHVendorConnection, self).setUp()
265
 
        from bzrlib.transport.sftp import SFTPFullAbsoluteServer
266
 
 
267
 
        def create_server():
268
 
            """Just a wrapper so that when created, it will set _vendor"""
269
 
            # SFTPFullAbsoluteServer can handle any vendor,
270
 
            # it just needs to be set between the time it is instantiated
271
 
            # and the time .setUp() is called
272
 
            server = SFTPFullAbsoluteServer()
273
 
            server._vendor = self._test_vendor
274
 
            return server
275
 
        self._test_vendor = 'loopback'
276
 
        self.vfs_transport_server = create_server
277
 
        f = open('a_file', 'wb')
278
 
        try:
279
 
            f.write('foobar\n')
280
 
        finally:
281
 
            f.close()
282
 
 
283
 
    def set_vendor(self, vendor):
284
 
        self._test_vendor = vendor
285
 
 
286
 
    def test_connection_paramiko(self):
287
 
        from bzrlib.transport import ssh
288
 
        self.set_vendor(ssh.ParamikoVendor())
289
 
        t = self.get_transport()
290
 
        self.assertEqual('foobar\n', t.get('a_file').read())
291
 
 
292
 
    def test_connection_vendor(self):
293
 
        raise TestSkipped("We don't test spawning real ssh,"
294
 
                          " because it prompts for a password."
295
 
                          " Enable this test if we figure out"
296
 
                          " how to prevent this.")
297
 
        self.set_vendor(None)
298
 
        t = self.get_transport()
299
 
        self.assertEqual('foobar\n', t.get('a_file').read())
300
 
 
301
 
 
302
 
class SSHVendorBadConnection(TestCaseWithTransport):
303
 
    """Test that the ssh vendors handle bad connection properly
304
 
 
305
 
    We don't subclass TestCaseWithSFTPServer, because we don't actually
306
 
    need an SFTP connection.
307
 
    """
308
 
 
309
 
    def setUp(self):
310
 
        if not paramiko_loaded:
311
 
            raise TestSkipped('you must have paramiko to run this test')
312
 
        super(SSHVendorBadConnection, self).setUp()
313
 
        import bzrlib.transport.ssh
314
 
 
315
 
        # open a random port, so we know nobody else is using it
316
 
        # but don't actually listen on the port.
317
 
        s = socket.socket()
318
 
        s.bind(('localhost', 0))
319
 
        self.bogus_url = 'sftp://%s:%s/' % s.getsockname()
320
 
 
321
 
        orig_vendor = bzrlib.transport.ssh._ssh_vendor_manager._cached_ssh_vendor
322
 
        def reset():
323
 
            bzrlib.transport.ssh._ssh_vendor_manager._cached_ssh_vendor = orig_vendor
324
 
            s.close()
325
 
        self.addCleanup(reset)
326
 
 
327
 
    def set_vendor(self, vendor):
328
 
        import bzrlib.transport.ssh
329
 
        bzrlib.transport.ssh._ssh_vendor_manager._cached_ssh_vendor = vendor
330
 
 
331
 
    def test_bad_connection_paramiko(self):
332
 
        """Test that a real connection attempt raises the right error"""
333
 
        from bzrlib.transport import ssh
334
 
        self.set_vendor(ssh.ParamikoVendor())
335
 
        t = bzrlib.transport.get_transport(self.bogus_url)
336
 
        self.assertRaises(errors.ConnectionError, t.get, 'foobar')
337
 
 
338
 
    def test_bad_connection_ssh(self):
339
 
        """None => auto-detect vendor"""
340
 
        self.set_vendor(None)
341
 
        # This is how I would normally test the connection code
342
 
        # it makes it very clear what we are testing.
343
 
        # However, 'ssh' will create stipple on the output, so instead
344
 
        # I'm using run_bzr_subprocess, and parsing the output
345
 
        # try:
346
 
        #     t = bzrlib.transport.get_transport(self.bogus_url)
347
 
        # except errors.ConnectionError:
348
 
        #     # Correct error
349
 
        #     pass
350
 
        # except errors.NameError, e:
351
 
        #     if 'SSHException' in str(e):
352
 
        #         raise TestSkipped('Known NameError bug in paramiko 1.6.1')
353
 
        #     raise
354
 
        # else:
355
 
        #     self.fail('Excepted ConnectionError to be raised')
356
 
 
357
 
        out, err = self.run_bzr_subprocess(['log', self.bogus_url], retcode=3)
358
 
        self.assertEqual('', out)
359
 
        if "NameError: global name 'SSHException'" in err:
360
 
            # We aren't fixing this bug, because it is a bug in
361
 
            # paramiko, but we know about it, so we don't have to
362
 
            # fail the test
363
 
            raise TestSkipped('Known NameError bug with paramiko-1.6.1')
364
 
        self.assertContainsRe(err, r'bzr: ERROR: Unable to connect to SSH host'
365
 
                                   r' 127\.0\.0\.1:\d+; ')
366
 
 
367
 
 
368
 
class SFTPLatencyKnob(TestCaseWithSFTPServer):
369
 
    """Test that the testing SFTPServer's latency knob works."""
370
 
 
371
 
    def test_latency_knob_slows_transport(self):
372
 
        # change the latency knob to 500ms. We take about 40ms for a
373
 
        # loopback connection ordinarily.
374
 
        start_time = time.time()
375
 
        self.get_server().add_latency = 0.5
376
 
        transport = self.get_transport()
377
 
        transport.has('not me') # Force connection by issuing a request
378
 
        with_latency_knob_time = time.time() - start_time
379
 
        self.assertTrue(with_latency_knob_time > 0.4)
380
 
 
381
 
    def test_default(self):
382
 
        # This test is potentially brittle: under extremely high machine load
383
 
        # it could fail, but that is quite unlikely
384
 
        raise TestSkipped('Timing-sensitive test')
385
 
        start_time = time.time()
386
 
        transport = self.get_transport()
387
 
        transport.has('not me') # Force connection by issuing a request
388
 
        regular_time = time.time() - start_time
389
 
        self.assertTrue(regular_time < 0.5)
390
 
 
391
 
 
392
 
class FakeSocket(object):
393
 
    """Fake socket object used to test the SocketDelay wrapper without
394
 
    using a real socket.
395
 
    """
396
 
 
397
 
    def __init__(self):
398
 
        self._data = ""
399
 
 
400
 
    def send(self, data, flags=0):
401
 
        self._data += data
402
 
        return len(data)
403
 
 
404
 
    def sendall(self, data, flags=0):
405
 
        self._data += data
406
 
        return len(data)
407
 
 
408
 
    def recv(self, size, flags=0):
409
 
        if size < len(self._data):
410
 
            result = self._data[:size]
411
 
            self._data = self._data[size:]
412
 
            return result
413
 
        else:
414
 
            result = self._data
415
 
            self._data = ""
416
 
            return result
417
 
 
418
 
 
419
 
class TestSocketDelay(TestCase):
420
 
 
421
 
    def setUp(self):
422
 
        TestCase.setUp(self)
423
 
        if not paramiko_loaded:
424
 
            raise TestSkipped('you must have paramiko to run this test')
425
 
 
426
 
    def test_delay(self):
427
 
        from bzrlib.transport.sftp import SocketDelay
428
 
        sending = FakeSocket()
429
 
        receiving = SocketDelay(sending, 0.1, bandwidth=1000000,
430
 
                                really_sleep=False)
431
 
        # check that simulated time is charged only per round-trip:
432
 
        t1 = SocketDelay.simulated_time
433
 
        receiving.send("connect1")
434
 
        self.assertEqual(sending.recv(1024), "connect1")
435
 
        t2 = SocketDelay.simulated_time
436
 
        self.assertAlmostEqual(t2 - t1, 0.1)
437
 
        receiving.send("connect2")
438
 
        self.assertEqual(sending.recv(1024), "connect2")
439
 
        sending.send("hello")
440
 
        self.assertEqual(receiving.recv(1024), "hello")
441
 
        t3 = SocketDelay.simulated_time
442
 
        self.assertAlmostEqual(t3 - t2, 0.1)
443
 
        sending.send("hello")
444
 
        self.assertEqual(receiving.recv(1024), "hello")
445
 
        sending.send("hello")
446
 
        self.assertEqual(receiving.recv(1024), "hello")
447
 
        sending.send("hello")
448
 
        self.assertEqual(receiving.recv(1024), "hello")
449
 
        t4 = SocketDelay.simulated_time
450
 
        self.assertAlmostEqual(t4, t3)
451
 
 
452
 
    def test_bandwidth(self):
453
 
        from bzrlib.transport.sftp import SocketDelay
454
 
        sending = FakeSocket()
455
 
        receiving = SocketDelay(sending, 0, bandwidth=8.0/(1024*1024),
456
 
                                really_sleep=False)
457
 
        # check that simulated time is charged only per round-trip:
458
 
        t1 = SocketDelay.simulated_time
459
 
        receiving.send("connect")
460
 
        self.assertEqual(sending.recv(1024), "connect")
461
 
        sending.send("a" * 100)
462
 
        self.assertEqual(receiving.recv(1024), "a" * 100)
463
 
        t2 = SocketDelay.simulated_time
464
 
        self.assertAlmostEqual(t2 - t1, 100 + 7)
465
 
 
466
 
 
467
 
class ReadvFile(object):
468
 
    """An object that acts like Paramiko's SFTPFile.readv()"""
469
 
 
470
 
    def __init__(self, data):
471
 
        self._data = data
472
 
 
473
 
    def readv(self, requests):
474
 
        for start, length in requests:
475
 
            yield self._data[start:start+length]
476
 
 
477
 
 
478
 
def _null_report_activity(*a, **k):
479
 
    pass
480
 
 
481
 
 
482
 
class Test_SFTPReadvHelper(tests.TestCase):
483
 
 
484
 
    def checkGetRequests(self, expected_requests, offsets):
485
 
        if not paramiko_loaded:
486
 
            raise TestSkipped('you must have paramiko to run this test')
487
 
        helper = _mod_sftp._SFTPReadvHelper(offsets, 'artificial_test',
488
 
            _null_report_activity)
489
 
        self.assertEqual(expected_requests, helper._get_requests())
490
 
 
491
 
    def test__get_requests(self):
492
 
        # Small single requests become a single readv request
493
 
        self.checkGetRequests([(0, 100)],
494
 
                              [(0, 20), (30, 50), (20, 10), (80, 20)])
495
 
        # Non-contiguous ranges are given as multiple requests
496
 
        self.checkGetRequests([(0, 20), (30, 50)],
497
 
                              [(10, 10), (30, 20), (0, 10), (50, 30)])
498
 
        # Ranges larger than _max_request_size (32kB) are broken up into
499
 
        # multiple requests, even if it actually spans multiple logical
500
 
        # requests
501
 
        self.checkGetRequests([(0, 32768), (32768, 32768), (65536, 464)],
502
 
                              [(0, 40000), (40000, 100), (40100, 1900),
503
 
                               (42000, 24000)])
504
 
 
505
 
    def checkRequestAndYield(self, expected, data, offsets):
506
 
        if not paramiko_loaded:
507
 
            raise TestSkipped('you must have paramiko to run this test')
508
 
        helper = _mod_sftp._SFTPReadvHelper(offsets, 'artificial_test',
509
 
            _null_report_activity)
510
 
        data_f = ReadvFile(data)
511
 
        result = list(helper.request_and_yield_offsets(data_f))
512
 
        self.assertEqual(expected, result)
513
 
 
514
 
    def test_request_and_yield_offsets(self):
515
 
        data = 'abcdefghijklmnopqrstuvwxyz'
516
 
        self.checkRequestAndYield([(0, 'a'), (5, 'f'), (10, 'klm')], data,
517
 
                                  [(0, 1), (5, 1), (10, 3)])
518
 
        # Should combine requests, and split them again
519
 
        self.checkRequestAndYield([(0, 'a'), (1, 'b'), (10, 'klm')], data,
520
 
                                  [(0, 1), (1, 1), (10, 3)])
521
 
        # Out of order requests. The requests should get combined, but then be
522
 
        # yielded out-of-order. We also need one that is at the end of a
523
 
        # previous range. See bug #293746
524
 
        self.checkRequestAndYield([(0, 'a'), (10, 'k'), (4, 'efg'), (1, 'bcd')],
525
 
                                  data, [(0, 1), (10, 1), (4, 3), (1, 3)])
526
 
 
527
 
 
528
 
class TestUsesAuthConfig(TestCaseWithSFTPServer):
529
 
    """Test that AuthenticationConfig can supply default usernames."""
530
 
 
531
 
    def get_transport_for_connection(self, set_config):
532
 
        port = self.get_server()._listener.port
533
 
        if set_config:
534
 
            conf = config.AuthenticationConfig()
535
 
            conf._get_config().update(
536
 
                {'sftptest': {'scheme': 'ssh', 'port': port, 'user': 'bar'}})
537
 
            conf._save()
538
 
        t = get_transport('sftp://localhost:%d' % port)
539
 
        # force a connection to be performed.
540
 
        t.has('foo')
541
 
        return t
542
 
 
543
 
    def test_sftp_uses_config(self):
544
 
        t = self.get_transport_for_connection(set_config=True)
545
 
        self.assertEqual('bar', t._get_credentials()[0])
546
 
 
547
 
    def test_sftp_is_none_if_no_config(self):
548
 
        t = self.get_transport_for_connection(set_config=False)
549
 
        self.assertIs(None, t._get_credentials()[0])
550
 
 
551
 
    def test_sftp_doesnt_prompt_username(self):
552
 
        stdout = tests.StringIOWrapper()
553
 
        ui.ui_factory = tests.TestUIFactory(stdin='joe\nfoo\n', stdout=stdout)
554
 
        t = self.get_transport_for_connection(set_config=False)
555
 
        self.assertIs(None, t._get_credentials()[0])
556
 
        # No prompts should've been printed, stdin shouldn't have been read
557
 
        self.assertEquals("", stdout.getvalue())
558
 
        self.assertEquals(0, ui.ui_factory.stdin.tell())
 
32
 
 
33
STUB_SERVER_KEY = """
 
34
-----BEGIN RSA PRIVATE KEY-----
 
35
MIICWgIBAAKBgQDTj1bqB4WmayWNPB+8jVSYpZYk80Ujvj680pOTh2bORBjbIAyz
 
36
oWGW+GUjzKxTiiPvVmxFgx5wdsFvF03v34lEVVhMpouqPAYQ15N37K/ir5XY+9m/
 
37
d8ufMCkjeXsQkKqFbAlQcnWMCRnOoPHS3I4vi6hmnDDeeYTSRvfLbW0fhwIBIwKB
 
38
gBIiOqZYaoqbeD9OS9z2K9KR2atlTxGxOJPXiP4ESqP3NVScWNwyZ3NXHpyrJLa0
 
39
EbVtzsQhLn6rF+TzXnOlcipFvjsem3iYzCpuChfGQ6SovTcOjHV9z+hnpXvQ/fon
 
40
soVRZY65wKnF7IAoUwTmJS9opqgrN6kRgCd3DASAMd1bAkEA96SBVWFt/fJBNJ9H
 
41
tYnBKZGw0VeHOYmVYbvMSstssn8un+pQpUm9vlG/bp7Oxd/m+b9KWEh2xPfv6zqU
 
42
avNwHwJBANqzGZa/EpzF4J8pGti7oIAPUIDGMtfIcmqNXVMckrmzQ2vTfqtkEZsA
 
43
4rE1IERRyiJQx6EJsz21wJmGV9WJQ5kCQQDwkS0uXqVdFzgHO6S++tjmjYcxwr3g
 
44
H0CoFYSgbddOT6miqRskOQF3DZVkJT3kyuBgU2zKygz52ukQZMqxCb1fAkASvuTv
 
45
qfpH87Qq5kQhNKdbbwbmd2NxlNabazPijWuphGTdW0VfJdWfklyS2Kr+iqrs/5wV
 
46
HhathJt636Eg7oIjAkA8ht3MQ+XSl9yIJIS8gVpbPxSw5OMfw0PjVE7tBdQruiSc
 
47
nvuQES5C9BMHjF39LZiGH1iLQy7FgdHyoP+eodI7
 
48
-----END RSA PRIVATE KEY-----
 
49
"""
 
50
    
 
51
 
 
52
class SingleListener (threading.Thread):
 
53
    def __init__(self, callback):
 
54
        threading.Thread.__init__(self)
 
55
        self._callback = callback
 
56
        self._socket = socket.socket()
 
57
        self._socket.listen(1)
 
58
        self.port = self._socket.getsockname()[1]
 
59
        self.stop_event = threading.Event()
 
60
 
 
61
    def run(self):
 
62
        s, _ = self._socket.accept()
 
63
        # now close the listen socket
 
64
        self._socket.close()
 
65
        self._callback(s, self.stop_event)
 
66
    
 
67
    def stop(self):
 
68
        self.stop_event.set()
 
69
        
 
70
        
 
71
class TestCaseWithSFTPServer (TestCaseInTempDir):
 
72
    """
 
73
    Execute a test case with a stub SFTP server, serving files from the local
 
74
    filesystem over the loopback network.
 
75
    """
 
76
    
 
77
    def _run_server(self, s, stop_event):
 
78
        ssh_server = paramiko.Transport(s)
 
79
        key_file = os.path.join(self._root, 'test_rsa.key')
 
80
        file(key_file, 'w').write(STUB_SERVER_KEY)
 
81
        host_key = paramiko.RSAKey.from_private_key_file(key_file)
 
82
        ssh_server.add_server_key(host_key)
 
83
        server = StubServer()
 
84
        ssh_server.set_subsystem_handler('sftp', paramiko.SFTPServer, StubSFTPServer, root=self._root)
 
85
        event = threading.Event()
 
86
        ssh_server.start_server(event, server)
 
87
        event.wait(5.0)
 
88
        stop_event.wait(30.0)
 
89
 
 
90
    def setUp(self):
 
91
        TestCaseInTempDir.setUp(self)
 
92
        self._root = self.test_dir
 
93
 
 
94
        self._listener = SingleListener(self._run_server)
 
95
        self._listener.setDaemon(True)
 
96
        self._listener.start()        
 
97
        self._sftp_url = 'sftp://foo:bar@localhost:%d/' % (self._listener.port,)
 
98
        
 
99
    def tearDown(self):
 
100
        self._listener.stop()
 
101
        TestCaseInTempDir.tearDown(self)
 
102
 
 
103
        
 
104
class SFTPTransportTest (TestCaseWithSFTPServer, TestTransportMixIn):
 
105
    readonly = False
 
106
 
 
107
    def get_transport(self):
 
108
        from bzrlib.transport.sftp import SFTPTransport
 
109
        url = self._sftp_url
 
110
        return SFTPTransport(url)
 
111
 
 
112
if not paramiko_loaded:
 
113
    del SFTPTransportTest