~bzr-pqm/bzr/bzr.dev

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
# Copyright (C) 2005 Robey Pointer <robey@lag.net>
# Copyright (C) 2005, 2006, 2007 Canonical Ltd
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

import os
import socket
import sys
import threading
import time

try:
    import paramiko
    paramiko_loaded = True
except ImportError:
    paramiko_loaded = False

from bzrlib import (
    bzrdir,
    errors,
    )
from bzrlib.osutils import (
    pathjoin,
    lexists,
    set_or_unset_env,
    )
from bzrlib.tests import (
    TestCaseWithTransport,
    TestCase,
    TestSkipped,
    )
from bzrlib.tests.http_server import HttpServer
from bzrlib.transport import get_transport
import bzrlib.transport.http

if paramiko_loaded:
    from bzrlib.transport.sftp import (
        SFTPAbsoluteServer,
        SFTPHomeDirServer,
        SFTPTransport,
        )

from bzrlib.workingtree import WorkingTree


def set_test_transport_to_sftp(testcase):
    """Point a test case's transport servers at the sftp test servers.

    If ``testcase._get_remote_is_absolute`` is missing or None it is set to
    True.  A true value selects SFTPAbsoluteServer as ``transport_server``;
    a false value selects SFTPHomeDirServer.  ``transport_readonly_server``
    is always set to HttpServer.
    """
    remote_is_absolute = getattr(testcase, '_get_remote_is_absolute', None)
    if remote_is_absolute is None:
        # The test case expressed no preference: default to absolute paths
        # and record that choice on the test case.
        remote_is_absolute = testcase._get_remote_is_absolute = True
    if remote_is_absolute:
        server_class = SFTPAbsoluteServer
    else:
        server_class = SFTPHomeDirServer
    testcase.transport_server = server_class
    testcase.transport_readonly_server = HttpServer


class TestCaseWithSFTPServer(TestCaseWithTransport):
    """Base class for tests that need a sftp server on localhost.

    Tests are skipped when paramiko could not be imported.
    """

    def setUp(self):
        super(TestCaseWithSFTPServer, self).setUp()
        if paramiko_loaded:
            set_test_transport_to_sftp(self)
        else:
            # The sftp server classes are only imported when paramiko is
            # available, so there is nothing to configure without it.
            raise TestSkipped('you must have paramiko to run this test')


class SFTPLockTests(TestCaseWithSFTPServer):
    """Exercise the '.write-lock' files used by the sftp transport."""

    def test_sftp_locks(self):
        """lock_write() creates '<name>.write-lock' and refuses to relock."""
        from bzrlib.errors import LockError
        t = self.get_transport()

        l = t.lock_write('bogus')
        self.failUnlessExists('bogus.write-lock')

        # Don't wait for the lock: locking an already locked file should
        # raise LockError.
        self.assertRaises(LockError, t.lock_write, 'bogus')

        l.unlock()
        self.failIf(lexists('bogus.write-lock'))

        # A lock file that already exists on disk must block the lock too.
        # Close the file explicitly so its content is flushed and the handle
        # is released before the file is removed below.
        f = open('something.write-lock', 'wb')
        try:
            f.write('fake lock\n')
        finally:
            f.close()
        self.assertRaises(LockError, t.lock_write, 'something')
        os.remove('something.write-lock')

        # With the fake lock file gone the lock can be taken, and locks on
        # two different names can be held at the same time.
        l = t.lock_write('something')

        l2 = t.lock_write('bogus')

        l.unlock()
        l2.unlock()


class SFTPTransportTestRelative(TestCaseWithSFTPServer):
    """Check how the SFTP transport maps relative paths to remote paths."""

    def test__remote_path(self):
        """_remote_path() resolves relpaths against the test directory."""
        if sys.platform == 'darwin':
            # This test is about sftp absolute path handling; self.test_dir
            # is only used as a convenient absolute path to compare against.
            # That fails on OSX because the sftp server doesn't resolve the
            # links (and it doesn't have to). --vila 20070924
            self.knownFailure('Mac OSX symlinks /tmp to /private/tmp,'
                              ' testing against self.test_dir'
                              ' is not appropriate')
        transport = self.get_transport()
        # This test requires a unix-like absolute path.
        local_root = self.test_dir
        if sys.platform == 'win32':
            # Hack suggested by John Meinel.
            # TODO: write another mock server for this test
            #       and use absolute path without drive letter
            local_root = '/' + local_root
        # Try what is currently used:
        # remote path = self._abspath(relpath)
        expected_child = local_root + '/relative'
        self.assertIsSameRealPath(expected_child,
                                  transport._remote_path('relative'))
        # We don't os.path.join because windows gives us the wrong path.
        parent_dir = '/'.join(local_root.split('/')[:-1])
        # '..' should be honoured.
        expected_sibling = parent_dir + '/sibling'
        self.assertIsSameRealPath(expected_sibling,
                                  transport._remote_path('../sibling'))
        # '/' should be illegal ?
        ### FIXME decide and then test for all transports. RBC20051208


class SFTPTransportTestRelativeRoot(TestCaseWithSFTPServer):
    """Check remote path handling against the home-directory sftp server."""

    def setUp(self):
        # Only SFTPHomeDirServer is tested here.  The flag has to be set
        # before the base setUp runs, because that is where the transport
        # server gets chosen.
        self._get_remote_is_absolute = False
        super(SFTPTransportTestRelativeRoot, self).setUp()

    def test__remote_path_relative_root(self):
        """Relative paths stay relative to the home directory."""
        transport = self.get_transport('')
        # The transport itself is rooted at the home directory marker.
        self.assertEqual('/~/', transport._path)
        # The remote path should be relative to the home dir
        # (i.e. not beginning with a '/').
        self.assertEqual('a', transport._remote_path('a'))


class SFTPNonServerTest(TestCase):
    """SFTP tests that do not use the TestCaseWithSFTPServer fixture.

    Tests are skipped when paramiko could not be imported.
    """

    def setUp(self):
        TestCase.setUp(self)
        if not paramiko_loaded:
            raise TestSkipped('you must have paramiko to run this test')

    def test_parse_url_with_home_dir(self):
        """URL components are unquoted and a '/~/' path is preserved."""
        s = SFTPTransport('sftp://ro%62ey:h%40t@example.com:2222/~/relative')
        self.assertEqual(s._host, 'example.com')
        self.assertEqual(s._port, 2222)
        # '%62' and '%40' are the quoted forms of 'b' and '@'.
        self.assertEqual(s._user, 'robey')
        self.assertEqual(s._password, 'h@t')
        self.assertEqual(s._path, '/~/relative/')

    def test_relpath(self):
        """A home-relative URL is not a child of an absolute-path transport."""
        s = SFTPTransport('sftp://user@host.com/abs/path')
        self.assertRaises(errors.PathNotChild, s.relpath,
                          'sftp://user@host.com/~/rel/path/sub')

    def test_get_paramiko_vendor(self):
        """Test that if no 'ssh' is available we get builtin paramiko"""
        from bzrlib.transport import ssh
        # set '.' as the only location in the path, forcing no 'ssh' to exist
        orig_vendor = ssh._ssh_vendor_manager._cached_ssh_vendor
        orig_path = set_or_unset_env('PATH', '.')
        try:
            # No vendor defined yet, query for one
            ssh._ssh_vendor_manager.clear_cache()
            vendor = ssh._get_ssh_vendor()
            self.assertIsInstance(vendor, ssh.ParamikoVendor)
        finally:
            # Put back both the environment and the cached vendor so other
            # tests are not affected.
            set_or_unset_env('PATH', orig_path)
            ssh._ssh_vendor_manager._cached_ssh_vendor = orig_vendor

    def test_abspath_root_sibling_server(self):
        """abspath('/') ends with '/' but not with the '/~/' home marker."""
        from bzrlib.transport.sftp import SFTPSiblingAbsoluteServer
        server = SFTPSiblingAbsoluteServer()
        server.setUp()
        try:
            transport = get_transport(server.get_url())
            self.assertFalse(transport.abspath('/').endswith('/~/'))
            self.assertTrue(transport.abspath('/').endswith('/'))
            del transport
        finally:
            server.tearDown()


class SFTPBranchTest(TestCaseWithSFTPServer):
    """Test some stuff when accessing a bzr Branch over sftp"""

    def test_lock_file(self):
        """Locking an old format branch creates and removes a lock file."""
        # Import the branch module explicitly instead of relying on another
        # module having loaded bzrlib.branch as a side effect.
        from bzrlib.branch import Branch
        # old format branches use a special lock file on sftp.
        self.make_branch('', format=bzrdir.BzrDirFormat6())
        b = Branch.open(self.get_url())
        self.failUnlessExists('.bzr/')
        self.failUnlessExists('.bzr/branch-format')
        self.failUnlessExists('.bzr/branch-lock')

        # The write-lock file only exists while the branch is write locked.
        self.failIf(lexists('.bzr/branch-lock.write-lock'))
        b.lock_write()
        self.failUnlessExists('.bzr/branch-lock.write-lock')
        b.unlock()
        self.failIf(lexists('.bzr/branch-lock.write-lock'))

    def test_push_support(self):
        """Revisions committed locally can be pulled into a sftp branch."""
        self.build_tree(['a/', 'a/foo'])
        t = bzrdir.BzrDir.create_standalone_workingtree('a')
        b = t.branch
        t.add('foo')
        t.commit('foo', rev_id='a1')

        b2 = bzrdir.BzrDir.create_branch_and_repo(self.get_url('/b'))
        b2.pull(b)

        self.assertEqual(b2.revision_history(), ['a1'])

        # Close the file explicitly so the new content is flushed to disk
        # before the commit below reads it.
        f = open('a/foo', 'wt')
        try:
            f.write('something new in foo\n')
        finally:
            f.close()
        t.commit('new', rev_id='a2')
        b2.pull(b)

        self.assertEqual(b2.revision_history(), ['a1', 'a2'])


class SSHVendorConnection(TestCaseWithSFTPServer):
    """Test that the ssh vendors can all connect.

    Verify that a full-handshake (SSH over loopback TCP) sftp connection works.

    We have 3 sftp implementations in the test suite:
      'loopback': Doesn't use ssh, just uses a local socket. Most tests are
                  done this way to save the handshaking time, so it is not
                  tested again here
      'none':     This uses paramiko's built-in ssh client and server, and layers
                  sftp on top of it.
      None:       If 'ssh' exists on the machine, then it will be spawned as a
                  child process.
    """

    def setUp(self):
        super(SSHVendorConnection, self).setUp()
        from bzrlib.transport.sftp import SFTPFullAbsoluteServer
        self._test_vendor = 'loopback'

        def create_server():
            """Build the server and stamp it with the vendor under test."""
            # SFTPFullAbsoluteServer can handle any vendor; the vendor just
            # has to be set after the server is instantiated and before its
            # .setUp() is called.  self._test_vendor is read at call time, so
            # a later set_vendor() is honoured.
            sftp_server = SFTPFullAbsoluteServer()
            sftp_server._vendor = self._test_vendor
            return sftp_server
        self.vfs_transport_server = create_server
        # A small file for the connection tests to read back.
        sample = open('a_file', 'wb')
        try:
            sample.write('foobar\n')
        finally:
            sample.close()

    def set_vendor(self, vendor):
        """Choose the vendor handed to servers created after this call."""
        self._test_vendor = vendor

    def test_connection_paramiko(self):
        """A file can be read back using the paramiko vendor."""
        from bzrlib.transport import ssh
        self.set_vendor(ssh.ParamikoVendor())
        content = self.get_transport().get('a_file').read()
        self.assertEqual('foobar\n', content)

    def test_connection_vendor(self):
        """Disabled: would read a file back through a spawned 'ssh'."""
        raise TestSkipped("We don't test spawning real ssh,"
                          " because it prompts for a password."
                          " Enable this test if we figure out"
                          " how to prevent this.")
        # Unreachable until the skip above is removed.
        self.set_vendor(None)
        content = self.get_transport().get('a_file').read()
        self.assertEqual('foobar\n', content)


class SSHVendorBadConnection(TestCaseWithTransport):
    """Test that the ssh vendors handle bad connection properly

    We don't subclass TestCaseWithSFTPServer, because we don't actually
    need an SFTP connection.
    """

    def setUp(self):
        if not paramiko_loaded:
            raise TestSkipped('you must have paramiko to run this test')
        super(SSHVendorBadConnection, self).setUp()
        from bzrlib.transport import ssh

        # Bind a random port so we know nobody else is using it, but never
        # listen on it: connection attempts to this url should fail.
        unused_socket = socket.socket()
        unused_socket.bind(('localhost', 0))
        self.bogus_url = 'sftp://%s:%s/' % unused_socket.getsockname()

        saved_vendor = ssh._ssh_vendor_manager._cached_ssh_vendor

        def reset():
            # Restore the cached vendor that set_vendor() may have replaced
            # and release the reserved port.
            ssh._ssh_vendor_manager._cached_ssh_vendor = saved_vendor
            unused_socket.close()
        self.addCleanup(reset)

    def set_vendor(self, vendor):
        """Install vendor as the cached ssh vendor."""
        from bzrlib.transport import ssh
        ssh._ssh_vendor_manager._cached_ssh_vendor = vendor

    def test_bad_connection_paramiko(self):
        """Test that a real connection attempt raises the right error"""
        from bzrlib.transport import ssh
        self.set_vendor(ssh.ParamikoVendor())
        transport = bzrlib.transport.get_transport(self.bogus_url)
        self.assertRaises(errors.ConnectionError, transport.get, 'foobar')

    def test_bad_connection_ssh(self):
        """None => auto-detect vendor"""
        self.set_vendor(None)
        # The most direct test would call
        # bzrlib.transport.get_transport(self.bogus_url) in-process, expect
        # errors.ConnectionError and skip on the known paramiko 1.6.1
        # NameError mentioning 'SSHException'.  That would make it very
        # clear what is being tested, but 'ssh' writes noise to the output,
        # so run_bzr_subprocess is used and its output parsed instead.
        out, err = self.run_bzr_subprocess(['log', self.bogus_url], retcode=3)
        self.assertEqual('', out)
        if "NameError: global name 'SSHException'" in err:
            # We aren't fixing this bug, because it is a bug in
            # paramiko, but we know about it, so we don't have to
            # fail the test
            raise TestSkipped('Known NameError bug with paramiko-1.6.1')
        self.assertContainsRe(err, r'bzr: ERROR: Unable to connect to SSH host'
                                   r' 127\.0\.0\.1:\d+; ')


class SFTPLatencyKnob(TestCaseWithSFTPServer):
    """Test that the testing SFTPServer's latency knob works."""

    def test_latency_knob_slows_transport(self):
        """With 0.5s of added latency a request takes more than 0.4s."""
        # Change the latency knob to 500ms.  We take about 40ms for a
        # loopback connection ordinarily.
        began = time.time()
        self.get_server().add_latency = 0.5
        # Force connection by issuing a request.
        self.get_transport().has('not me')
        elapsed = time.time() - began
        self.assertTrue(elapsed > 0.4)

    def test_default(self):
        """Disabled: without the knob a request should take under 0.5s."""
        # This test is potentially brittle: under extremely high machine load
        # it could fail, but that is quite unlikely
        raise TestSkipped('Timing-sensitive test')
        # Unreachable until the skip above is removed.
        began = time.time()
        # Force connection by issuing a request.
        self.get_transport().has('not me')
        elapsed = time.time() - began
        self.assertTrue(elapsed < 0.5)


class FakeSocket(object):
    """In-memory stand-in for a socket, used to test the SocketDelay wrapper.

    Everything sent is appended to an internal buffer, and recv() hands the
    buffer back from the front, so no real socket is needed.
    """

    def __init__(self):
        self._data = ""

    def send(self, data, flags=0):
        """Append data to the buffer and return len(data); flags is ignored."""
        self._data += data
        return len(data)

    def sendall(self, data, flags=0):
        """Behave exactly like send(), including the returned length."""
        return self.send(data, flags)

    def recv(self, size, flags=0):
        """Remove and return up to size items from the front of the buffer.

        When size covers the whole buffer, everything is returned and the
        buffer is left empty.  flags is ignored.
        """
        chunk, self._data = self._data[:size], self._data[size:]
        return chunk


class TestSocketDelay(TestCase):
    """Check SocketDelay's simulated-time accounting using a FakeSocket."""

    def setUp(self):
        TestCase.setUp(self)
        if not paramiko_loaded:
            raise TestSkipped('you must have paramiko to run this test')

    def test_delay(self):
        """Latency is charged per round-trip, not per call."""
        from bzrlib.transport.sftp import SocketDelay
        sending = FakeSocket()
        receiving = SocketDelay(sending, 0.1, bandwidth=1000000,
                                really_sleep=False)
        # check that simulated time is charged only per round-trip:
        before_first_send = SocketDelay.simulated_time
        receiving.send("connect1")
        self.assertEqual(sending.recv(1024), "connect1")
        after_first_send = SocketDelay.simulated_time
        # The first send is expected to cost one latency period.
        self.assertAlmostEqual(after_first_send - before_first_send, 0.1)
        receiving.send("connect2")
        self.assertEqual(sending.recv(1024), "connect2")
        sending.send("hello")
        self.assertEqual(receiving.recv(1024), "hello")
        after_round_trip = SocketDelay.simulated_time
        # A further send followed by a recv is expected to cost one more.
        self.assertAlmostEqual(after_round_trip - after_first_send, 0.1)
        # More replies with no new send through the wrapper are expected to
        # add no simulated time.
        for _ in range(3):
            sending.send("hello")
            self.assertEqual(receiving.recv(1024), "hello")
        after_extra_recvs = SocketDelay.simulated_time
        self.assertAlmostEqual(after_extra_recvs, after_round_trip)

    def test_bandwidth(self):
        """With zero latency, simulated time grows with the bytes moved."""
        from bzrlib.transport.sftp import SocketDelay
        sending = FakeSocket()
        receiving = SocketDelay(sending, 0, bandwidth=8.0/(1024*1024),
                                really_sleep=False)
        # Latency is 0 here, so only the amount of data should be charged.
        started = SocketDelay.simulated_time
        receiving.send("connect")
        self.assertEqual(sending.recv(1024), "connect")
        payload = "a" * 100
        sending.send(payload)
        self.assertEqual(receiving.recv(1024), payload)
        finished = SocketDelay.simulated_time
        # Expect one simulated second per byte: 100 for the payload plus 7
        # for "connect".
        self.assertAlmostEqual(finished - started, 100 + 7)