~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/tests/test_smart_transport.py

Merge the fix for bug #819604 into trunk, resolve conflicts.

Show diffs side-by-side

added

removed

Lines of Context:
19
19
# all of this deals with byte strings so this is safe
20
20
from cStringIO import StringIO
21
21
import doctest
 
22
import errno
22
23
import os
23
24
import socket
 
25
import subprocess
24
26
import sys
25
27
import threading
26
28
import time
30
32
import bzrlib
31
33
from bzrlib import (
32
34
        bzrdir,
 
35
        debug,
33
36
        errors,
34
37
        osutils,
35
38
        tests,
59
62
        )
60
63
 
61
64
 
 
65
def create_file_pipes():
    """Return a connected ``(reader, writer)`` pair of pipe file objects.

    Both ends are opened with buffering disabled: buffered pipe ends give
    undefined results in the tests that rely on them.
    """
    ends = os.pipe()
    modes = ('rb', 'wb')
    # fdopen is called for the read end first, then the write end.
    return tuple(os.fdopen(fd, mode, 0) for fd, mode in zip(ends, modes))
 
71
 
 
72
 
62
73
def portable_socket_pair():
63
74
    """Return a pair of TCP sockets connected to each other.
64
75
 
88
99
        return StringIOSSHConnection(self)
89
100
 
90
101
 
 
102
class FirstRejectedStringIOSSHVendor(StringIOSSHVendor):
    """An SSH vendor whose first connection is handed back already closed.

    Every connection made after the first one is a normal
    StringIOSSHConnection.
    """

    def __init__(self, read_from, write_to, fail_at_write=True):
        super(FirstRejectedStringIOSSHVendor, self).__init__(
            read_from, write_to)
        # Read by ClosedSSHConnection.get_sock_or_pipes: when true, writes
        # on the rejected connection fail as well as reads.
        self.fail_at_write = fail_at_write
        self._first = True

    def connect_ssh(self, username, password, host, port, command):
        self.calls.append(
            ('connect_ssh', username, password, host, port, command))
        if not self._first:
            return StringIOSSHConnection(self)
        # Only the very first connection is rejected.
        self._first = False
        return ClosedSSHConnection(self)
 
121
 
 
122
 
91
123
class StringIOSSHConnection(ssh.SSHConnection):
92
124
    """A SSH connection that uses StringIO to buffer writes and answer reads."""
93
125
 
103
135
        return 'pipes', (self.vendor.read_from, self.vendor.write_to)
104
136
 
105
137
 
 
138
class ClosedSSHConnection(ssh.SSHConnection):
    """An SSH connection whose channels have already been closed."""

    def __init__(self, vendor):
        self.vendor = vendor

    def close(self):
        self.vendor.calls.append(('close', ))

    def get_sock_or_pipes(self):
        # Build a pipe for bzr to read from, then close the ssh (writing)
        # end straight away: reads by bzr always fail.
        bzr_read, ssh_write = create_file_pipes()
        ssh_write.close()
        if not self.vendor.fail_at_write:
            # Writes go to the vendor's buffer and succeed.
            return 'pipes', (bzr_read, self.vendor.write_to)
        # Close the ssh (reading) end of a second pipe as well, so writes
        # by bzr fail too.
        ssh_read, bzr_write = create_file_pipes()
        ssh_read.close()
        return 'pipes', (bzr_read, bzr_write)
 
159
 
 
160
 
106
161
class _InvalidHostnameFeature(features.Feature):
107
162
    """Does 'non_existent.invalid' fail to resolve?
108
163
 
198
253
        client_medium._accept_bytes('abc')
199
254
        self.assertEqual('abc', output.getvalue())
200
255
 
 
256
    def test_simple_pipes__accept_bytes_subprocess_closed(self):
 
257
        # It is unfortunate that we have to use Popen for this. However,
 
258
        # os.pipe() does not behave the same as subprocess.Popen().
 
259
        # On Windows, if you use os.pipe() and close the write side,
 
260
        # read.read() hangs. On Linux, read.read() returns the empty string.
 
261
        p = subprocess.Popen([sys.executable, '-c',
 
262
            'import sys\n'
 
263
            'sys.stdout.write(sys.stdin.read(4))\n'
 
264
            'sys.stdout.close()\n'],
 
265
            stdout=subprocess.PIPE, stdin=subprocess.PIPE)
 
266
        client_medium = medium.SmartSimplePipesClientMedium(
 
267
            p.stdout, p.stdin, 'base')
 
268
        client_medium._accept_bytes('abc\n')
 
269
        self.assertEqual('abc', client_medium._read_bytes(3))
 
270
        p.wait()
 
271
        # While writing to the underlying pipe,
 
272
        #   Windows py2.6.6 we get IOError(EINVAL)
 
273
        #   Lucid py2.6.5, we get IOError(EPIPE)
 
274
        # In both cases, it should be wrapped to ConnectionReset
 
275
        self.assertRaises(errors.ConnectionReset,
 
276
                          client_medium._accept_bytes, 'more')
 
277
 
 
278
    def test_simple_pipes__accept_bytes_pipe_closed(self):
 
279
        child_read, client_write = create_file_pipes()
 
280
        client_medium = medium.SmartSimplePipesClientMedium(
 
281
            None, client_write, 'base')
 
282
        client_medium._accept_bytes('abc\n')
 
283
        self.assertEqual('abc\n', child_read.read(4))
 
284
        # While writing to the underlying pipe,
 
285
        #   Windows py2.6.6 we get IOError(EINVAL)
 
286
        #   Lucid py2.6.5, we get IOError(EPIPE)
 
287
        # In both cases, it should be wrapped to ConnectionReset
 
288
        child_read.close()
 
289
        self.assertRaises(errors.ConnectionReset,
 
290
                          client_medium._accept_bytes, 'more')
 
291
 
 
292
    def test_simple_pipes__flush_pipe_closed(self):
 
293
        child_read, client_write = create_file_pipes()
 
294
        client_medium = medium.SmartSimplePipesClientMedium(
 
295
            None, client_write, 'base')
 
296
        client_medium._accept_bytes('abc\n')
 
297
        child_read.close()
 
298
        # Even though the pipe is closed, flush on the write side seems to be a
 
299
        # no-op, rather than a failure.
 
300
        client_medium._flush()
 
301
 
 
302
    def test_simple_pipes__flush_subprocess_closed(self):
 
303
        p = subprocess.Popen([sys.executable, '-c',
 
304
            'import sys\n'
 
305
            'sys.stdout.write(sys.stdin.read(4))\n'
 
306
            'sys.stdout.close()\n'],
 
307
            stdout=subprocess.PIPE, stdin=subprocess.PIPE)
 
308
        client_medium = medium.SmartSimplePipesClientMedium(
 
309
            p.stdout, p.stdin, 'base')
 
310
        client_medium._accept_bytes('abc\n')
 
311
        p.wait()
 
312
        # Even though the child process is dead, flush seems to be a no-op.
 
313
        client_medium._flush()
 
314
 
 
315
    def test_simple_pipes__read_bytes_pipe_closed(self):
 
316
        child_read, client_write = create_file_pipes()
 
317
        client_medium = medium.SmartSimplePipesClientMedium(
 
318
            child_read, client_write, 'base')
 
319
        client_medium._accept_bytes('abc\n')
 
320
        client_write.close()
 
321
        self.assertEqual('abc\n', client_medium._read_bytes(4))
 
322
        self.assertEqual('', client_medium._read_bytes(4))
 
323
 
 
324
    def test_simple_pipes__read_bytes_subprocess_closed(self):
 
325
        p = subprocess.Popen([sys.executable, '-c',
 
326
            'import sys\n'
 
327
            'if sys.platform == "win32":\n'
 
328
            '    import msvcrt, os\n'
 
329
            '    msvcrt.setmode(sys.stdout.fileno(), os.O_BINARY)\n'
 
330
            '    msvcrt.setmode(sys.stdin.fileno(), os.O_BINARY)\n'
 
331
            'sys.stdout.write(sys.stdin.read(4))\n'
 
332
            'sys.stdout.close()\n'],
 
333
            stdout=subprocess.PIPE, stdin=subprocess.PIPE)
 
334
        client_medium = medium.SmartSimplePipesClientMedium(
 
335
            p.stdout, p.stdin, 'base')
 
336
        client_medium._accept_bytes('abc\n')
 
337
        p.wait()
 
338
        self.assertEqual('abc\n', client_medium._read_bytes(4))
 
339
        self.assertEqual('', client_medium._read_bytes(4))
 
340
 
201
341
    def test_simple_pipes_client_disconnect_does_nothing(self):
202
342
        # calling disconnect does nothing.
203
343
        input = StringIO()
585
725
        request.finished_reading()
586
726
        self.assertRaises(errors.ReadingCompleted, request.read_bytes, None)
587
727
 
 
728
    def test_reset(self):
 
729
        server_sock, client_sock = portable_socket_pair()
 
730
        # TODO: Use SmartClientAlreadyConnectedSocketMedium for the versions of
 
731
        #       bzr where it exists.
 
732
        client_medium = medium.SmartTCPClientMedium(None, None, None)
 
733
        client_medium._socket = client_sock
 
734
        client_medium._connected = True
 
735
        req = client_medium.get_request()
 
736
        self.assertRaises(errors.TooManyConcurrentRequests,
 
737
            client_medium.get_request)
 
738
        client_medium.reset()
 
739
        # The stream should be reset, marked as disconnected, though ready for
 
740
        # us to make a new request
 
741
        self.assertFalse(client_medium._connected)
 
742
        self.assertIs(None, client_medium._socket)
 
743
        try:
 
744
            self.assertEqual('', client_sock.recv(1))
 
745
        except socket.error, e:
 
746
            if e.errno not in (errno.EBADF,):
 
747
                raise
 
748
        req = client_medium.get_request()
 
749
 
588
750
 
589
751
class RemoteTransportTests(test_smart.TestCaseWithSmartMedium):
590
752
 
3101
3263
            'e', # end
3102
3264
            output.getvalue())
3103
3265
 
 
3266
    def test_records_start_of_body_stream(self):
 
3267
        requester, output = self.make_client_encoder_and_output()
 
3268
        requester.set_headers({})
 
3269
        in_stream = [False]
 
3270
        def stream_checker():
 
3271
            self.assertTrue(requester.body_stream_started)
 
3272
            in_stream[0] = True
 
3273
            yield 'content'
 
3274
        flush_called = []
 
3275
        orig_flush = requester.flush
 
3276
        def tracked_flush():
 
3277
            flush_called.append(in_stream[0])
 
3278
            if in_stream[0]:
 
3279
                self.assertTrue(requester.body_stream_started)
 
3280
            else:
 
3281
                self.assertFalse(requester.body_stream_started)
 
3282
            return orig_flush()
 
3283
        requester.flush = tracked_flush
 
3284
        requester.call_with_body_stream(('one arg',), stream_checker())
 
3285
        self.assertEqual(
 
3286
            'bzr message 3 (bzr 1.6)\n' # protocol version
 
3287
            '\x00\x00\x00\x02de' # headers
 
3288
            's\x00\x00\x00\x0bl7:one arge' # args
 
3289
            'b\x00\x00\x00\x07content' # body
 
3290
            'e', output.getvalue())
 
3291
        self.assertEqual([False, True, True], flush_called)
 
3292
 
3104
3293
 
3105
3294
class StubMediumRequest(object):
3106
3295
    """A stub medium request that tracks the number of times accept_bytes is
3526
3715
        # encoder.
3527
3716
 
3528
3717
 
 
3718
class Test_SmartClientRequest(tests.TestCase):
    """Tests for client._SmartClientRequest's send and retry behaviour.

    Most tests use FirstRejectedStringIOSSHVendor. Its first SSH connection
    comes back already closed; later connections work normally. That lets
    the tests check when a request reconnects and retries, and when it must
    give up with ConnectionReset instead.
    """

    def make_client_with_failing_medium(self, fail_at_write=True, response=''):
        """Return (output, vendor, smart_client) with a rejected first connect.

        :param fail_at_write: if true, writes on the rejected first
            connection fail too; otherwise they go into ``output``.
        :param response: the bytes the vendor gives to readers.
        :return: ``output`` is the StringIO that collects the bytes the
            client writes; ``vendor`` records its calls in ``vendor.calls``.
        """
        response_io = StringIO(response)
        output = StringIO()
        vendor = FirstRejectedStringIOSSHVendor(response_io, output,
                    fail_at_write=fail_at_write)
        ssh_params = medium.SSHParams('a host', 'a port', 'a user', 'a pass')
        client_medium = medium.SmartSSHClientMedium('base', ssh_params, vendor)
        smart_client = client._SmartClient(client_medium, headers={})
        return output, vendor, smart_client

    def make_response(self, args, body=None, body_stream=None):
        """Return the protocol-3 encoding of a successful response."""
        response_io = StringIO()
        response = _mod_request.SuccessfulSmartServerResponse(args, body=body,
            body_stream=body_stream)
        responder = protocol.ProtocolThreeResponder(response_io.write)
        responder.send_response(response)
        return response_io.getvalue()

    def test__call_doesnt_retry_append(self):
        # The first connection accepts the write but fails on read; an
        # 'append' request must not be retried, so the error propagates.
        response = self.make_response(('appended', '8'))
        output, vendor, smart_client = self.make_client_with_failing_medium(
            fail_at_write=False, response=response)
        smart_request = client._SmartClientRequest(smart_client, 'append',
            ('foo', ''), body='content\n')
        self.assertRaises(errors.ConnectionReset, smart_request._call, 3)

    def test__call_retries_get_bytes(self):
        # A 'get' is retried after the read failure and succeeds.
        response = self.make_response(('ok',), 'content\n')
        output, vendor, smart_client = self.make_client_with_failing_medium(
            fail_at_write=False, response=response)
        smart_request = client._SmartClientRequest(smart_client, 'get',
            ('foo',))
        response, response_handler = smart_request._call(3)
        self.assertEqual(('ok',), response)
        self.assertEqual('content\n', response_handler.read_body_bytes())

    def test__call_noretry_get_bytes(self):
        # The 'noretry' debug flag disables the retry 'get' would otherwise get.
        debug.debug_flags.add('noretry')
        response = self.make_response(('ok',), 'content\n')
        output, vendor, smart_client = self.make_client_with_failing_medium(
            fail_at_write=False, response=response)
        smart_request = client._SmartClientRequest(smart_client, 'get',
            ('foo',))
        self.assertRaises(errors.ConnectionReset, smart_request._call, 3)

    def test__send_no_retry_pipes(self):
        client_read, server_write = create_file_pipes()
        server_read, client_write = create_file_pipes()
        # server_read is closed explicitly below; release the other ends
        # when the test finishes so they do not leak.
        self.addCleanup(client_read.close)
        self.addCleanup(server_write.close)
        self.addCleanup(client_write.close)
        client_medium = medium.SmartSimplePipesClientMedium(client_read,
            client_write, base='/')
        smart_client = client._SmartClient(client_medium)
        smart_request = client._SmartClientRequest(smart_client,
            'hello', ())
        # Close the server side
        server_read.close()
        encoder, response_handler = smart_request._construct_protocol(3)
        self.assertRaises(errors.ConnectionReset,
            smart_request._send_no_retry, encoder)

    def test__send_read_response_sockets(self):
        listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.addCleanup(listen_sock.close)
        listen_sock.bind(('127.0.0.1', 0))
        listen_sock.listen(1)
        host, port = listen_sock.getsockname()
        client_medium = medium.SmartTCPClientMedium(host, port, '/')
        client_medium._ensure_connection()
        # Close the client's connected socket when the test finishes.
        self.addCleanup(client_medium.disconnect)
        smart_client = client._SmartClient(client_medium)
        smart_request = client._SmartClientRequest(smart_client, 'hello', ())
        # Accept the connection, but don't actually talk to the client.
        server_sock, _ = listen_sock.accept()
        server_sock.close()
        # Sockets buffer and don't really notice that the server has closed the
        # connection until we try to read again.
        handler = smart_request._send(3)
        self.assertRaises(errors.ConnectionReset,
            handler.read_response_tuple, expect_body=False)

    def test__send_retries_on_write(self):
        # The first connection fails on write, so the request reconnects
        # and is sent exactly once on the second connection.
        output, vendor, smart_client = self.make_client_with_failing_medium()
        smart_request = client._SmartClientRequest(smart_client, 'hello', ())
        handler = smart_request._send(3)
        self.assertEqual('bzr message 3 (bzr 1.6)\n' # protocol
                         '\x00\x00\x00\x02de'   # empty headers
                         's\x00\x00\x00\tl5:helloee',
                         output.getvalue())
        self.assertEqual(
            [('connect_ssh', 'a user', 'a pass', 'a host', 'a port',
              ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
             ('close',),
             ('connect_ssh', 'a user', 'a pass', 'a host', 'a port',
              ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
            ],
            vendor.calls)

    def test__send_doesnt_retry_read_failure(self):
        # The write succeeds, so _send does not reconnect; the failure only
        # shows up when the response is read.
        output, vendor, smart_client = self.make_client_with_failing_medium(
            fail_at_write=False)
        smart_request = client._SmartClientRequest(smart_client, 'hello', ())
        handler = smart_request._send(3)
        self.assertEqual('bzr message 3 (bzr 1.6)\n' # protocol
                         '\x00\x00\x00\x02de'   # empty headers
                         's\x00\x00\x00\tl5:helloee',
                         output.getvalue())
        self.assertEqual(
            [('connect_ssh', 'a user', 'a pass', 'a host', 'a port',
              ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
            ],
            vendor.calls)
        self.assertRaises(errors.ConnectionReset, handler.read_response_tuple)

    def test__send_request_retries_body_stream_if_not_started(self):
        output, vendor, smart_client = self.make_client_with_failing_medium()
        smart_request = client._SmartClientRequest(smart_client, 'hello', (),
            body_stream=['a', 'b'])
        response_handler = smart_request._send(3)
        # We connect, get disconnected, and notice before consuming the stream,
        # so we try again one time and succeed.
        self.assertEqual(
            [('connect_ssh', 'a user', 'a pass', 'a host', 'a port',
              ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
             ('close',),
             ('connect_ssh', 'a user', 'a pass', 'a host', 'a port',
              ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
            ],
            vendor.calls)
        self.assertEqual('bzr message 3 (bzr 1.6)\n' # protocol
                         '\x00\x00\x00\x02de'   # empty headers
                         's\x00\x00\x00\tl5:helloe'
                         'b\x00\x00\x00\x01a'
                         'b\x00\x00\x00\x01b'
                         'e',
                         output.getvalue())

    def test__send_request_stops_if_body_started(self):
        # We intentionally use the python StringIO so that we can subclass it.
        from StringIO import StringIO
        response = StringIO()

        class FailAfterFirstWrite(StringIO):
            """Allow one 'write' call to pass, fail the rest"""
            def __init__(self):
                StringIO.__init__(self)
                self._first = True

            def write(self, s):
                if self._first:
                    self._first = False
                    return StringIO.write(self, s)
                raise IOError(errno.EINVAL, 'invalid file handle')
        output = FailAfterFirstWrite()

        vendor = FirstRejectedStringIOSSHVendor(response, output,
            fail_at_write=False)
        ssh_params = medium.SSHParams('a host', 'a port', 'a user', 'a pass')
        client_medium = medium.SmartSSHClientMedium('base', ssh_params, vendor)
        smart_client = client._SmartClient(client_medium, headers={})
        smart_request = client._SmartClientRequest(smart_client, 'hello', (),
            body_stream=['a', 'b'])
        self.assertRaises(errors.ConnectionReset, smart_request._send, 3)
        # We connect, and manage to get to the point that we start consuming
        # the body stream. The next write fails, so we just stop.
        self.assertEqual(
            [('connect_ssh', 'a user', 'a pass', 'a host', 'a port',
              ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
             ('close',),
            ],
            vendor.calls)
        self.assertEqual('bzr message 3 (bzr 1.6)\n' # protocol
                         '\x00\x00\x00\x02de'   # empty headers
                         's\x00\x00\x00\tl5:helloe',
                         output.getvalue())

    def test__send_disabled_retry(self):
        # With 'noretry', the write failure on the first connection is fatal.
        debug.debug_flags.add('noretry')
        output, vendor, smart_client = self.make_client_with_failing_medium()
        smart_request = client._SmartClientRequest(smart_client, 'hello', ())
        self.assertRaises(errors.ConnectionReset, smart_request._send, 3)
        self.assertEqual(
            [('connect_ssh', 'a user', 'a pass', 'a host', 'a port',
              ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
             ('close',),
            ],
            vendor.calls)
 
3903
 
 
3904
 
3529
3905
class LengthPrefixedBodyDecoder(tests.TestCase):
3530
3906
 
3531
3907
    # XXX: TODO: make accept_reading_trailer invoke translate_response or