~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/tests/test_smart_transport.py

  • Committer: Samuel Bronson
  • Date: 2012-08-30 20:36:18 UTC
  • mto: (6015.57.3 2.4)
  • mto: This revision was merged to the branch mainline in revision 6558.
  • Revision ID: naesten@gmail.com-20120830203618-y2dzw91nqpvpgxvx
Update INSTALL for switch to Python 2.6 and up.

Show diffs side-by-side

added added

removed removed

Lines of Context:
18
18
 
19
19
# all of this deals with byte strings so this is safe
20
20
from cStringIO import StringIO
21
 
import doctest
22
 
import errno
23
21
import os
24
22
import socket
25
 
import subprocess
26
 
import sys
27
23
import threading
28
 
import time
29
 
 
30
 
from testtools.matchers import DocTestMatches
31
24
 
32
25
import bzrlib
33
26
from bzrlib import (
34
27
        bzrdir,
35
 
        debug,
36
28
        errors,
37
29
        osutils,
38
30
        tests,
39
 
        transport as _mod_transport,
 
31
        transport,
40
32
        urlutils,
41
33
        )
42
34
from bzrlib.smart import (
45
37
        message,
46
38
        protocol,
47
39
        request as _mod_request,
48
 
        server as _mod_server,
 
40
        server,
49
41
        vfs,
50
42
)
51
43
from bzrlib.tests import (
52
 
    features,
53
44
    test_smart,
54
45
    test_server,
55
46
    )
62
53
        )
63
54
 
64
55
 
65
 
def create_file_pipes():
    """Create an OS pipe and wrap both ends as unbuffered file objects.

    :return: (read_file, write_file) -- the read end opened 'rb' and the
        write end opened 'wb', both with buffering disabled.
    """
    read_fd, write_fd = os.pipe()
    # Buffering has to be off on both ends, otherwise tests that close one
    # side of the pipe get undefined results.
    read_file = os.fdopen(read_fd, 'rb', 0)
    write_file = os.fdopen(write_fd, 'wb', 0)
    return read_file, write_file
71
 
 
72
 
 
73
 
def portable_socket_pair():
    """Return a pair of TCP sockets connected to each other.

    Unlike socket.socketpair, this should work on Windows.

    The temporary listening socket is always closed before returning, even
    if binding, connecting or accepting fails. On failure the half-built
    client socket is closed too, so no descriptors are leaked.

    :return: (server_sock, client_sock)
    """
    listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        listen_sock.bind(('127.0.0.1', 0))
        listen_sock.listen(1)
        client_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            client_sock.connect(listen_sock.getsockname())
            server_sock, addr = listen_sock.accept()
        except BaseException:
            # Don't leak the client end if the connection never completed.
            client_sock.close()
            raise
    finally:
        listen_sock.close()
    return server_sock, client_sock
86
 
 
87
 
 
88
56
class StringIOSSHVendor(object):
89
57
    """A SSH vendor that uses StringIO to buffer writes and answer reads."""
90
58
 
99
67
        return StringIOSSHConnection(self)
100
68
 
101
69
 
102
 
class FirstRejectedStringIOSSHVendor(StringIOSSHVendor):
    """An SSH vendor whose first connection comes back already closed.

    Every later connection succeeds normally, using a StringIOSSHConnection.
    """

    def __init__(self, read_from, write_to, fail_at_write=True):
        super(FirstRejectedStringIOSSHVendor, self).__init__(
            read_from, write_to)
        # Read by ClosedSSHConnection.get_sock_or_pipes: when true, writes
        # fail as well as reads.
        self.fail_at_write = fail_at_write
        self._first = True

    def connect_ssh(self, username, password, host, port, command):
        """Record the call, then hand out a closed or a working connection."""
        self.calls.append(
            ('connect_ssh', username, password, host, port, command))
        if not self._first:
            return StringIOSSHConnection(self)
        self._first = False
        return ClosedSSHConnection(self)
121
 
 
122
 
 
123
70
class StringIOSSHConnection(ssh.SSHConnection):
124
71
    """A SSH connection that uses StringIO to buffer writes and answer reads."""
125
72
 
135
82
        return 'pipes', (self.vendor.read_from, self.vendor.write_to)
136
83
 
137
84
 
138
 
class ClosedSSHConnection(ssh.SSHConnection):
    """An SSH connection whose channels are already closed."""

    def __init__(self, vendor):
        self.vendor = vendor

    def close(self):
        self.vendor.calls.append(('close', ))

    def get_sock_or_pipes(self):
        """Return ('pipes', (read_file, write_file)) with dead channels.

        The read side is always dead: the ssh end of its pipe is closed
        straight away. The write side is dead too when the vendor's
        fail_at_write is set; otherwise bzr writes go to vendor.write_to.
        """
        bzr_read, ssh_write = create_file_pipes()
        ssh_write.close()
        if not self.vendor.fail_at_write:
            bzr_write = self.vendor.write_to
        else:
            # Close the ssh-side reader so bzr's writes have nowhere to go.
            ssh_read, bzr_write = create_file_pipes()
            ssh_read.close()
        return 'pipes', (bzr_read, bzr_write)
159
 
 
160
 
 
161
 
class _InvalidHostnameFeature(features.Feature):
 
85
class _InvalidHostnameFeature(tests.Feature):
162
86
    """Does 'non_existent.invalid' fail to resolve?
163
87
 
164
88
    RFC 2606 states that .invalid is reserved for invalid domain names, and
253
177
        client_medium._accept_bytes('abc')
254
178
        self.assertEqual('abc', output.getvalue())
255
179
 
256
 
    def test_simple_pipes__accept_bytes_subprocess_closed(self):
257
 
        # It is unfortunate that we have to use Popen for this. However,
258
 
        # os.pipe() does not behave the same as subprocess.Popen().
259
 
        # On Windows, if you use os.pipe() and close the write side,
260
 
        # read.read() hangs. On Linux, read.read() returns the empty string.
261
 
        p = subprocess.Popen([sys.executable, '-c',
262
 
            'import sys\n'
263
 
            'sys.stdout.write(sys.stdin.read(4))\n'
264
 
            'sys.stdout.close()\n'],
265
 
            stdout=subprocess.PIPE, stdin=subprocess.PIPE)
266
 
        client_medium = medium.SmartSimplePipesClientMedium(
267
 
            p.stdout, p.stdin, 'base')
268
 
        client_medium._accept_bytes('abc\n')
269
 
        self.assertEqual('abc', client_medium._read_bytes(3))
270
 
        p.wait()
271
 
        # While writing to the underlying pipe,
272
 
        #   Windows py2.6.6 we get IOError(EINVAL)
273
 
        #   Lucid py2.6.5, we get IOError(EPIPE)
274
 
        # In both cases, it should be wrapped to ConnectionReset
275
 
        self.assertRaises(errors.ConnectionReset,
276
 
                          client_medium._accept_bytes, 'more')
277
 
 
278
 
    def test_simple_pipes__accept_bytes_pipe_closed(self):
279
 
        child_read, client_write = create_file_pipes()
280
 
        client_medium = medium.SmartSimplePipesClientMedium(
281
 
            None, client_write, 'base')
282
 
        client_medium._accept_bytes('abc\n')
283
 
        self.assertEqual('abc\n', child_read.read(4))
284
 
        # While writing to the underlying pipe,
285
 
        #   Windows py2.6.6 we get IOError(EINVAL)
286
 
        #   Lucid py2.6.5, we get IOError(EPIPE)
287
 
        # In both cases, it should be wrapped to ConnectionReset
288
 
        child_read.close()
289
 
        self.assertRaises(errors.ConnectionReset,
290
 
                          client_medium._accept_bytes, 'more')
291
 
 
292
 
    def test_simple_pipes__flush_pipe_closed(self):
293
 
        child_read, client_write = create_file_pipes()
294
 
        client_medium = medium.SmartSimplePipesClientMedium(
295
 
            None, client_write, 'base')
296
 
        client_medium._accept_bytes('abc\n')
297
 
        child_read.close()
298
 
        # Even though the pipe is closed, flush on the write side seems to be a
299
 
        # no-op, rather than a failure.
300
 
        client_medium._flush()
301
 
 
302
 
    def test_simple_pipes__flush_subprocess_closed(self):
303
 
        p = subprocess.Popen([sys.executable, '-c',
304
 
            'import sys\n'
305
 
            'sys.stdout.write(sys.stdin.read(4))\n'
306
 
            'sys.stdout.close()\n'],
307
 
            stdout=subprocess.PIPE, stdin=subprocess.PIPE)
308
 
        client_medium = medium.SmartSimplePipesClientMedium(
309
 
            p.stdout, p.stdin, 'base')
310
 
        client_medium._accept_bytes('abc\n')
311
 
        p.wait()
312
 
        # Even though the child process is dead, flush seems to be a no-op.
313
 
        client_medium._flush()
314
 
 
315
 
    def test_simple_pipes__read_bytes_pipe_closed(self):
316
 
        child_read, client_write = create_file_pipes()
317
 
        client_medium = medium.SmartSimplePipesClientMedium(
318
 
            child_read, client_write, 'base')
319
 
        client_medium._accept_bytes('abc\n')
320
 
        client_write.close()
321
 
        self.assertEqual('abc\n', client_medium._read_bytes(4))
322
 
        self.assertEqual('', client_medium._read_bytes(4))
323
 
 
324
 
    def test_simple_pipes__read_bytes_subprocess_closed(self):
325
 
        p = subprocess.Popen([sys.executable, '-c',
326
 
            'import sys\n'
327
 
            'if sys.platform == "win32":\n'
328
 
            '    import msvcrt, os\n'
329
 
            '    msvcrt.setmode(sys.stdout.fileno(), os.O_BINARY)\n'
330
 
            '    msvcrt.setmode(sys.stdin.fileno(), os.O_BINARY)\n'
331
 
            'sys.stdout.write(sys.stdin.read(4))\n'
332
 
            'sys.stdout.close()\n'],
333
 
            stdout=subprocess.PIPE, stdin=subprocess.PIPE)
334
 
        client_medium = medium.SmartSimplePipesClientMedium(
335
 
            p.stdout, p.stdin, 'base')
336
 
        client_medium._accept_bytes('abc\n')
337
 
        p.wait()
338
 
        self.assertEqual('abc\n', client_medium._read_bytes(4))
339
 
        self.assertEqual('', client_medium._read_bytes(4))
340
 
 
341
180
    def test_simple_pipes_client_disconnect_does_nothing(self):
342
181
        # calling disconnect does nothing.
343
182
        input = StringIO()
499
338
            ],
500
339
            vendor.calls)
501
340
 
502
 
    def test_ssh_client_repr(self):
503
 
        client_medium = medium.SmartSSHClientMedium(
504
 
            'base', medium.SSHParams("example.com", "4242", "username"))
505
 
        self.assertEquals(
506
 
            "SmartSSHClientMedium(bzr+ssh://username@example.com:4242/)",
507
 
            repr(client_medium))
508
 
 
509
 
    def test_ssh_client_repr_no_port(self):
510
 
        client_medium = medium.SmartSSHClientMedium(
511
 
            'base', medium.SSHParams("example.com", None, "username"))
512
 
        self.assertEquals(
513
 
            "SmartSSHClientMedium(bzr+ssh://username@example.com/)",
514
 
            repr(client_medium))
515
 
 
516
 
    def test_ssh_client_repr_no_username(self):
517
 
        client_medium = medium.SmartSSHClientMedium(
518
 
            'base', medium.SSHParams("example.com", None, None))
519
 
        self.assertEquals(
520
 
            "SmartSSHClientMedium(bzr+ssh://example.com/)",
521
 
            repr(client_medium))
522
 
 
523
341
    def test_ssh_client_ignores_disconnect_when_not_connected(self):
524
342
        # Doing a disconnect on a new (and thus unconnected) SSH medium
525
343
        # does not fail.  It's ok to disconnect an unconnected medium.
746
564
        request.finished_reading()
747
565
        self.assertRaises(errors.ReadingCompleted, request.read_bytes, None)
748
566
 
749
 
    def test_reset(self):
750
 
        server_sock, client_sock = portable_socket_pair()
751
 
        # TODO: Use SmartClientAlreadyConnectedSocketMedium for the versions of
752
 
        #       bzr where it exists.
753
 
        client_medium = medium.SmartTCPClientMedium(None, None, None)
754
 
        client_medium._socket = client_sock
755
 
        client_medium._connected = True
756
 
        req = client_medium.get_request()
757
 
        self.assertRaises(errors.TooManyConcurrentRequests,
758
 
            client_medium.get_request)
759
 
        client_medium.reset()
760
 
        # The stream should be reset, marked as disconnected, though ready for
761
 
        # us to make a new request
762
 
        self.assertFalse(client_medium._connected)
763
 
        self.assertIs(None, client_medium._socket)
764
 
        try:
765
 
            self.assertEqual('', client_sock.recv(1))
766
 
        except socket.error, e:
767
 
            if e.errno not in (errno.EBADF,):
768
 
                raise
769
 
        req = client_medium.get_request()
770
 
 
771
567
 
772
568
class RemoteTransportTests(test_smart.TestCaseWithSmartMedium):
773
569
 
821
617
        super(TestSmartServerStreamMedium, self).setUp()
822
618
        self.overrideEnv('BZR_NO_SMART_VFS', None)
823
619
 
824
 
    def create_pipe_medium(self, to_server, from_server, transport,
825
 
                           timeout=4.0):
826
 
        """Create a new SmartServerPipeStreamMedium."""
827
 
        return medium.SmartServerPipeStreamMedium(to_server, from_server,
828
 
            transport, timeout=timeout)
829
 
 
830
 
    def create_pipe_context(self, to_server_bytes, transport):
831
 
        """Create a SmartServerSocketStreamMedium.
832
 
 
833
 
        This differes from create_pipe_medium, in that we initialize the
834
 
        request that is sent to the server, and return the StringIO class that
835
 
        will hold the response.
836
 
        """
837
 
        to_server = StringIO(to_server_bytes)
838
 
        from_server = StringIO()
839
 
        m = self.create_pipe_medium(to_server, from_server, transport)
840
 
        return m, from_server
841
 
 
842
 
    def create_socket_medium(self, server_sock, transport, timeout=4.0):
843
 
        """Initialize a new medium.SmartServerSocketStreamMedium."""
844
 
        return medium.SmartServerSocketStreamMedium(server_sock, transport,
845
 
            timeout=timeout)
846
 
 
847
 
    def create_socket_context(self, transport, timeout=4.0):
848
 
        """Create a new SmartServerSocketStreamMedium with default context.
849
 
 
850
 
        This will call portable_socket_pair and pass the server side to
851
 
        create_socket_medium along with transport.
852
 
        It then returns the client_sock and the server.
853
 
        """
854
 
        server_sock, client_sock = portable_socket_pair()
855
 
        server = self.create_socket_medium(server_sock, transport,
856
 
                                           timeout=timeout)
857
 
        return server, client_sock
 
620
    def portable_socket_pair(self):
 
621
        """Return a pair of TCP sockets connected to each other.
 
622
 
 
623
        Unlike socket.socketpair, this should work on Windows.
 
624
        """
 
625
        listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 
626
        listen_sock.bind(('127.0.0.1', 0))
 
627
        listen_sock.listen(1)
 
628
        client_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 
629
        client_sock.connect(listen_sock.getsockname())
 
630
        server_sock, addr = listen_sock.accept()
 
631
        listen_sock.close()
 
632
        return server_sock, client_sock
858
633
 
859
634
    def test_smart_query_version(self):
860
635
        """Feed a canned query version to a server"""
861
636
        # wire-to-wire, using the whole stack
 
637
        to_server = StringIO('hello\n')
 
638
        from_server = StringIO()
862
639
        transport = local.LocalTransport(urlutils.local_path_to_url('/'))
863
 
        server, from_server = self.create_pipe_context('hello\n', transport)
 
640
        server = medium.SmartServerPipeStreamMedium(
 
641
            to_server, from_server, transport)
864
642
        smart_protocol = protocol.SmartServerRequestProtocolOne(transport,
865
643
                from_server.write)
866
644
        server._serve_one_request(smart_protocol)
870
648
    def test_response_to_canned_get(self):
871
649
        transport = memory.MemoryTransport('memory:///')
872
650
        transport.put_bytes('testfile', 'contents\nof\nfile\n')
873
 
        server, from_server = self.create_pipe_context('get\001./testfile\n',
874
 
            transport)
 
651
        to_server = StringIO('get\001./testfile\n')
 
652
        from_server = StringIO()
 
653
        server = medium.SmartServerPipeStreamMedium(
 
654
            to_server, from_server, transport)
875
655
        smart_protocol = protocol.SmartServerRequestProtocolOne(transport,
876
656
                from_server.write)
877
657
        server._serve_one_request(smart_protocol)
888
668
        # VFS requests use filenames, not raw UTF-8.
889
669
        hpss_path = urlutils.escape(utf8_filename)
890
670
        transport.put_bytes(utf8_filename, 'contents\nof\nfile\n')
891
 
        server, from_server = self.create_pipe_context(
892
 
                'get\001' + hpss_path + '\n', transport)
 
671
        to_server = StringIO('get\001' + hpss_path + '\n')
 
672
        from_server = StringIO()
 
673
        server = medium.SmartServerPipeStreamMedium(
 
674
            to_server, from_server, transport)
893
675
        smart_protocol = protocol.SmartServerRequestProtocolOne(transport,
894
676
                from_server.write)
895
677
        server._serve_one_request(smart_protocol)
901
683
 
902
684
    def test_pipe_like_stream_with_bulk_data(self):
903
685
        sample_request_bytes = 'command\n9\nbulk datadone\n'
904
 
        server, from_server = self.create_pipe_context(
905
 
            sample_request_bytes, None)
 
686
        to_server = StringIO(sample_request_bytes)
 
687
        from_server = StringIO()
 
688
        server = medium.SmartServerPipeStreamMedium(
 
689
            to_server, from_server, None)
906
690
        sample_protocol = SampleRequest(expected_bytes=sample_request_bytes)
907
691
        server._serve_one_request(sample_protocol)
908
692
        self.assertEqual('', from_server.getvalue())
911
695
 
912
696
    def test_socket_stream_with_bulk_data(self):
913
697
        sample_request_bytes = 'command\n9\nbulk datadone\n'
914
 
        server, client_sock = self.create_socket_context(None)
 
698
        server_sock, client_sock = self.portable_socket_pair()
 
699
        server = medium.SmartServerSocketStreamMedium(
 
700
            server_sock, None)
915
701
        sample_protocol = SampleRequest(expected_bytes=sample_request_bytes)
916
702
        client_sock.sendall(sample_request_bytes)
917
703
        server._serve_one_request(sample_protocol)
918
 
        server._disconnect_client()
 
704
        server_sock.close()
919
705
        self.assertEqual('', client_sock.recv(1))
920
706
        self.assertEqual(sample_request_bytes, sample_protocol.accepted_bytes)
921
707
        self.assertFalse(server.finished)
922
708
 
923
709
    def test_pipe_like_stream_shutdown_detection(self):
924
 
        server, _ = self.create_pipe_context('', None)
 
710
        to_server = StringIO('')
 
711
        from_server = StringIO()
 
712
        server = medium.SmartServerPipeStreamMedium(to_server, from_server, None)
925
713
        server._serve_one_request(SampleRequest('x'))
926
714
        self.assertTrue(server.finished)
927
715
 
928
716
    def test_socket_stream_shutdown_detection(self):
929
 
        server, client_sock = self.create_socket_context(None)
 
717
        server_sock, client_sock = self.portable_socket_pair()
930
718
        client_sock.close()
 
719
        server = medium.SmartServerSocketStreamMedium(
 
720
            server_sock, None)
931
721
        server._serve_one_request(SampleRequest('x'))
932
722
        self.assertTrue(server.finished)
933
723
 
944
734
        rest_of_request_bytes = 'lo\n'
945
735
        expected_response = (
946
736
            protocol.RESPONSE_VERSION_TWO + 'success\nok\x012\n')
947
 
        server, client_sock = self.create_socket_context(None)
 
737
        server_sock, client_sock = self.portable_socket_pair()
 
738
        server = medium.SmartServerSocketStreamMedium(
 
739
            server_sock, None)
948
740
        client_sock.sendall(incomplete_request_bytes)
949
741
        server_protocol = server._build_protocol()
950
742
        client_sock.sendall(rest_of_request_bytes)
951
743
        server._serve_one_request(server_protocol)
952
 
        server._disconnect_client()
 
744
        server_sock.close()
953
745
        self.assertEqual(expected_response, osutils.recv_all(client_sock, 50),
954
746
                         "Not a version 2 response to 'hello' request.")
955
747
        self.assertEqual('', client_sock.recv(1))
974
766
        to_server_w = os.fdopen(to_server_w, 'w', 0)
975
767
        from_server_r = os.fdopen(from_server_r, 'r', 0)
976
768
        from_server = os.fdopen(from_server, 'w', 0)
977
 
        server = self.create_pipe_medium(to_server, from_server, None)
 
769
        server = medium.SmartServerPipeStreamMedium(
 
770
            to_server, from_server, None)
978
771
        # Like test_socket_stream_incomplete_request, write an incomplete
979
772
        # request (that does not end in '\n') and build a protocol from it.
980
773
        to_server_w.write(incomplete_request_bytes)
995
788
        # _serve_one_request should still process both of them as if they had
996
789
        # been received separately.
997
790
        sample_request_bytes = 'command\n'
998
 
        server, from_server = self.create_pipe_context(
999
 
            sample_request_bytes * 2, None)
 
791
        to_server = StringIO(sample_request_bytes * 2)
 
792
        from_server = StringIO()
 
793
        server = medium.SmartServerPipeStreamMedium(
 
794
            to_server, from_server, None)
1000
795
        first_protocol = SampleRequest(expected_bytes=sample_request_bytes)
1001
796
        server._serve_one_request(first_protocol)
1002
797
        self.assertEqual(0, first_protocol.next_read_size())
1015
810
        # _serve_one_request should still process both of them as if they had
1016
811
        # been received separately.
1017
812
        sample_request_bytes = 'command\n'
1018
 
        server, client_sock = self.create_socket_context(None)
 
813
        server_sock, client_sock = self.portable_socket_pair()
 
814
        server = medium.SmartServerSocketStreamMedium(
 
815
            server_sock, None)
1019
816
        first_protocol = SampleRequest(expected_bytes=sample_request_bytes)
1020
817
        # Put two whole requests on the wire.
1021
818
        client_sock.sendall(sample_request_bytes * 2)
1028
825
        stream_still_open = server._serve_one_request(second_protocol)
1029
826
        self.assertEqual(sample_request_bytes, second_protocol.accepted_bytes)
1030
827
        self.assertFalse(server.finished)
1031
 
        server._disconnect_client()
 
828
        server_sock.close()
1032
829
        self.assertEqual('', client_sock.recv(1))
1033
830
 
1034
831
    def test_pipe_like_stream_error_handling(self):
1041
838
        def close():
1042
839
            self.closed = True
1043
840
        from_server.close = close
1044
 
        server = self.create_pipe_medium(
 
841
        server = medium.SmartServerPipeStreamMedium(
1045
842
            to_server, from_server, None)
1046
843
        fake_protocol = ErrorRaisingProtocol(Exception('boom'))
1047
844
        server._serve_one_request(fake_protocol)
1050
847
        self.assertTrue(server.finished)
1051
848
 
1052
849
    def test_socket_stream_error_handling(self):
1053
 
        server, client_sock = self.create_socket_context(None)
 
850
        server_sock, client_sock = self.portable_socket_pair()
 
851
        server = medium.SmartServerSocketStreamMedium(
 
852
            server_sock, None)
1054
853
        fake_protocol = ErrorRaisingProtocol(Exception('boom'))
1055
854
        server._serve_one_request(fake_protocol)
1056
855
        # recv should not block, because the other end of the socket has been
1059
858
        self.assertTrue(server.finished)
1060
859
 
1061
860
    def test_pipe_like_stream_keyboard_interrupt_handling(self):
1062
 
        server, from_server = self.create_pipe_context('', None)
 
861
        to_server = StringIO('')
 
862
        from_server = StringIO()
 
863
        server = medium.SmartServerPipeStreamMedium(
 
864
            to_server, from_server, None)
1063
865
        fake_protocol = ErrorRaisingProtocol(KeyboardInterrupt('boom'))
1064
866
        self.assertRaises(
1065
867
            KeyboardInterrupt, server._serve_one_request, fake_protocol)
1066
868
        self.assertEqual('', from_server.getvalue())
1067
869
 
1068
870
    def test_socket_stream_keyboard_interrupt_handling(self):
1069
 
        server, client_sock = self.create_socket_context(None)
 
871
        server_sock, client_sock = self.portable_socket_pair()
 
872
        server = medium.SmartServerSocketStreamMedium(
 
873
            server_sock, None)
1070
874
        fake_protocol = ErrorRaisingProtocol(KeyboardInterrupt('boom'))
1071
875
        self.assertRaises(
1072
876
            KeyboardInterrupt, server._serve_one_request, fake_protocol)
1073
 
        server._disconnect_client()
 
877
        server_sock.close()
1074
878
        self.assertEqual('', client_sock.recv(1))
1075
879
 
1076
880
    def build_protocol_pipe_like(self, bytes):
1077
 
        server, _ = self.create_pipe_context(bytes, None)
 
881
        to_server = StringIO(bytes)
 
882
        from_server = StringIO()
 
883
        server = medium.SmartServerPipeStreamMedium(
 
884
            to_server, from_server, None)
1078
885
        return server._build_protocol()
1079
886
 
1080
887
    def build_protocol_socket(self, bytes):
1081
 
        server, client_sock = self.create_socket_context(None)
 
888
        server_sock, client_sock = self.portable_socket_pair()
 
889
        server = medium.SmartServerSocketStreamMedium(
 
890
            server_sock, None)
1082
891
        client_sock.sendall(bytes)
1083
892
        client_sock.close()
1084
893
        return server._build_protocol()
1124
933
        server_protocol = self.build_protocol_socket('bzr request 2\n')
1125
934
        self.assertProtocolTwo(server_protocol)
1126
935
 
1127
 
    def test__build_protocol_returns_if_stopping(self):
1128
 
        # _build_protocol should notice that we are stopping, and return
1129
 
        # without waiting for bytes from the client.
1130
 
        server, client_sock = self.create_socket_context(None)
1131
 
        server._stop_gracefully()
1132
 
        self.assertIs(None, server._build_protocol())
1133
 
 
1134
 
    def test_socket_set_timeout(self):
1135
 
        server, _ = self.create_socket_context(None, timeout=1.23)
1136
 
        self.assertEqual(1.23, server._client_timeout)
1137
 
 
1138
 
    def test_pipe_set_timeout(self):
1139
 
        server = self.create_pipe_medium(None, None, None,
1140
 
            timeout=1.23)
1141
 
        self.assertEqual(1.23, server._client_timeout)
1142
 
 
1143
 
    def test_socket_wait_for_bytes_with_timeout_with_data(self):
1144
 
        server, client_sock = self.create_socket_context(None)
1145
 
        client_sock.sendall('data\n')
1146
 
        # This should not block or consume any actual content
1147
 
        self.assertFalse(server._wait_for_bytes_with_timeout(0.1))
1148
 
        data = server.read_bytes(5)
1149
 
        self.assertEqual('data\n', data)
1150
 
 
1151
 
    def test_socket_wait_for_bytes_with_timeout_no_data(self):
1152
 
        server, client_sock = self.create_socket_context(None)
1153
 
        # This should timeout quickly, reporting that there wasn't any data
1154
 
        self.assertRaises(errors.ConnectionTimeout,
1155
 
                          server._wait_for_bytes_with_timeout, 0.01)
1156
 
        client_sock.close()
1157
 
        data = server.read_bytes(1)
1158
 
        self.assertEqual('', data)
1159
 
 
1160
 
    def test_socket_wait_for_bytes_with_timeout_closed(self):
1161
 
        server, client_sock = self.create_socket_context(None)
1162
 
        # With the socket closed, this should return right away.
1163
 
        # It seems select.select() returns that you *can* read on the socket,
1164
 
        # even though it closed. Presumably as a way to tell it is closed?
1165
 
        # Testing shows that without sock.close() this times-out failing the
1166
 
        # test, but with it, it returns False immediately.
1167
 
        client_sock.close()
1168
 
        self.assertFalse(server._wait_for_bytes_with_timeout(10))
1169
 
        data = server.read_bytes(1)
1170
 
        self.assertEqual('', data)
1171
 
 
1172
 
    def test_socket_wait_for_bytes_with_shutdown(self):
1173
 
        server, client_sock = self.create_socket_context(None)
1174
 
        t = time.time()
1175
 
        # Override the _timer functionality, so that time never increments,
1176
 
        # this way, we can be sure we stopped because of the flag, and not
1177
 
        # because of a timeout, etc.
1178
 
        server._timer = lambda: t
1179
 
        server._client_poll_timeout = 0.1
1180
 
        server._stop_gracefully()
1181
 
        server._wait_for_bytes_with_timeout(1.0)
1182
 
 
1183
 
    def test_socket_serve_timeout_closes_socket(self):
1184
 
        server, client_sock = self.create_socket_context(None, timeout=0.1)
1185
 
        # This should timeout quickly, and then close the connection so that
1186
 
        # client_sock recv doesn't block.
1187
 
        server.serve()
1188
 
        self.assertEqual('', client_sock.recv(1))
1189
 
 
1190
 
    def test_pipe_wait_for_bytes_with_timeout_with_data(self):
1191
 
        # We intentionally use a real pipe here, so that we can 'select' on it.
1192
 
        # You can't select() on a StringIO
1193
 
        (r_server, w_client) = os.pipe()
1194
 
        self.addCleanup(os.close, w_client)
1195
 
        with os.fdopen(r_server, 'rb') as rf_server:
1196
 
            server = self.create_pipe_medium(
1197
 
                rf_server, None, None)
1198
 
            os.write(w_client, 'data\n')
1199
 
            # This should not block or consume any actual content
1200
 
            server._wait_for_bytes_with_timeout(0.1)
1201
 
            data = server.read_bytes(5)
1202
 
            self.assertEqual('data\n', data)
1203
 
 
1204
 
    def test_pipe_wait_for_bytes_with_timeout_no_data(self):
1205
 
        # We intentionally use a real pipe here, so that we can 'select' on it.
1206
 
        # You can't select() on a StringIO
1207
 
        (r_server, w_client) = os.pipe()
1208
 
        # We can't add an os.close cleanup here, because we need to control
1209
 
        # when the file handle gets closed ourselves.
1210
 
        with os.fdopen(r_server, 'rb') as rf_server:
1211
 
            server = self.create_pipe_medium(
1212
 
                rf_server, None, None)
1213
 
            if sys.platform == 'win32':
1214
 
                # Windows cannot select() on a pipe, so we just always return
1215
 
                server._wait_for_bytes_with_timeout(0.01)
1216
 
            else:
1217
 
                self.assertRaises(errors.ConnectionTimeout,
1218
 
                                  server._wait_for_bytes_with_timeout, 0.01)
1219
 
            os.close(w_client)
1220
 
            data = server.read_bytes(5)
1221
 
            self.assertEqual('', data)
1222
 
 
1223
 
    def test_pipe_wait_for_bytes_no_fileno(self):
1224
 
        server, _ = self.create_pipe_context('', None)
1225
 
        # Our file doesn't support polling, so we should always just return
1226
 
        # 'you have data to consume.
1227
 
        server._wait_for_bytes_with_timeout(0.01)
1228
 
 
1229
936
 
1230
937
class TestGetProtocolFactoryForBytes(tests.TestCase):
1231
938
    """_get_protocol_factory_for_bytes identifies the protocol factory a server
1261
968
 
1262
969
class TestSmartTCPServer(tests.TestCase):
1263
970
 
1264
 
    def make_server(self):
1265
 
        """Create a SmartTCPServer that we can exercise.
1266
 
 
1267
 
        Note: we don't use SmartTCPServer_for_testing because the testing
1268
 
        version overrides lots of functionality like 'serve', and we want to
1269
 
        test the raw service.
1270
 
 
1271
 
        This will start the server in another thread, and wait for it to
1272
 
        indicate it has finished starting up.
1273
 
 
1274
 
        :return: (server, server_thread)
1275
 
        """
1276
 
        t = _mod_transport.get_transport_from_url('memory:///')
1277
 
        server = _mod_server.SmartTCPServer(t, client_timeout=4.0)
1278
 
        server._ACCEPT_TIMEOUT = 0.1
1279
 
        # We don't use 'localhost' because that might be an IPv6 address.
1280
 
        server.start_server('127.0.0.1', 0)
1281
 
        server_thread = threading.Thread(target=server.serve,
1282
 
                                         args=(self.id(),))
1283
 
        server_thread.start()
1284
 
        # Ensure this gets called at some point
1285
 
        self.addCleanup(server._stop_gracefully)
1286
 
        server._started.wait()
1287
 
        return server, server_thread
1288
 
 
1289
 
    def ensure_client_disconnected(self, client_sock):
1290
 
        """Ensure that a socket is closed, discarding all errors."""
1291
 
        try:
1292
 
            client_sock.close()
1293
 
        except Exception:
1294
 
            pass
1295
 
 
1296
 
    def connect_to_server(self, server):
1297
 
        """Create a client socket that can talk to the server."""
1298
 
        client_sock = socket.socket()
1299
 
        server_info = server._server_socket.getsockname()
1300
 
        client_sock.connect(server_info)
1301
 
        self.addCleanup(self.ensure_client_disconnected, client_sock)
1302
 
        return client_sock
1303
 
 
1304
 
    def connect_to_server_and_hangup(self, server):
1305
 
        """Connect to the server, and then hang up.
1306
 
        That way it doesn't sit waiting for 'accept()' to timeout.
1307
 
        """
1308
 
        # If the server has already signaled that the socket is closed, we
1309
 
        # don't need to try to connect to it. Not being set, though, the server
1310
 
        # might still close the socket while we try to connect to it. So we
1311
 
        # still have to catch the exception.
1312
 
        if server._stopped.isSet():
1313
 
            return
1314
 
        try:
1315
 
            client_sock = self.connect_to_server(server)
1316
 
            client_sock.close()
1317
 
        except socket.error, e:
1318
 
            # If the server has hung up already, that is fine.
1319
 
            pass
1320
 
 
1321
 
    def say_hello(self, client_sock):
1322
 
        """Send the 'hello' smart RPC, and expect the response."""
1323
 
        client_sock.send('hello\n')
1324
 
        self.assertEqual('ok\x012\n', client_sock.recv(5))
1325
 
 
1326
 
    def shutdown_server_cleanly(self, server, server_thread):
1327
 
        server._stop_gracefully()
1328
 
        self.connect_to_server_and_hangup(server)
1329
 
        server._stopped.wait()
1330
 
        server._fully_stopped.wait()
1331
 
        server_thread.join()
1332
 
 
1333
971
    def test_get_error_unexpected(self):
1334
972
        """Error reported by server with no specific representation"""
1335
973
        self.overrideEnv('BZR_NO_SMART_VFS', None)
1353
991
                                t.get, 'something')
1354
992
        self.assertContainsRe(str(err), 'some random exception')
1355
993
 
1356
 
    def test_propagates_timeout(self):
1357
 
        server = _mod_server.SmartTCPServer(None, client_timeout=1.23)
1358
 
        server_sock, client_sock = portable_socket_pair()
1359
 
        handler = server._make_handler(server_sock)
1360
 
        self.assertEqual(1.23, handler._client_timeout)
1361
 
 
1362
 
    def test_serve_conn_tracks_connections(self):
1363
 
        server = _mod_server.SmartTCPServer(None, client_timeout=4.0)
1364
 
        server_sock, client_sock = portable_socket_pair()
1365
 
        server.serve_conn(server_sock, '-%s' % (self.id(),))
1366
 
        self.assertEqual(1, len(server._active_connections))
1367
 
        # We still want to talk on the connection. Polling should indicate it
1368
 
        # is still active.
1369
 
        server._poll_active_connections()
1370
 
        self.assertEqual(1, len(server._active_connections))
1371
 
        # Closing the socket will end the active thread, and polling will
1372
 
        # notice and remove it from the active set.
1373
 
        client_sock.close()
1374
 
        server._poll_active_connections(0.1)
1375
 
        self.assertEqual(0, len(server._active_connections))
1376
 
 
1377
 
    def test_serve_closes_out_finished_connections(self):
1378
 
        server, server_thread = self.make_server()
1379
 
        # The server is started, connect to it.
1380
 
        client_sock = self.connect_to_server(server)
1381
 
        # We send and receive on the connection, so that we know the
1382
 
        # server-side has seen the connect, and started handling the
1383
 
        # results.
1384
 
        self.say_hello(client_sock)
1385
 
        self.assertEqual(1, len(server._active_connections))
1386
 
        # Grab a handle to the thread that is processing our request
1387
 
        _, server_side_thread = server._active_connections[0]
1388
 
        # Close the connection, ask the server to stop, and wait for the
1389
 
        # server to stop, as well as the thread that was servicing the
1390
 
        # client request.
1391
 
        client_sock.close()
1392
 
        # Wait for the server-side request thread to notice we are closed.
1393
 
        server_side_thread.join()
1394
 
        # Stop the server, it should notice the connection has finished.
1395
 
        self.shutdown_server_cleanly(server, server_thread)
1396
 
        # The server should have noticed that all clients are gone before
1397
 
        # exiting.
1398
 
        self.assertEqual(0, len(server._active_connections))
1399
 
 
1400
 
    def test_serve_reaps_finished_connections(self):
1401
 
        server, server_thread = self.make_server()
1402
 
        client_sock1 = self.connect_to_server(server)
1403
 
        # We send and receive on the connection, so that we know the
1404
 
        # server-side has seen the connect, and started handling the
1405
 
        # results.
1406
 
        self.say_hello(client_sock1)
1407
 
        server_handler1, server_side_thread1 = server._active_connections[0]
1408
 
        client_sock1.close()
1409
 
        server_side_thread1.join()
1410
 
        # By waiting until the first connection is fully done, the server
1411
 
        # should notice after another connection that the first has finished.
1412
 
        client_sock2 = self.connect_to_server(server)
1413
 
        self.say_hello(client_sock2)
1414
 
        server_handler2, server_side_thread2 = server._active_connections[-1]
1415
 
        # There is a race condition. We know that client_sock2 has been
1416
 
        # registered, but not that _poll_active_connections has been called. We
1417
 
        # know that it will be called before the server will accept a new
1418
 
        # connection, however. So connect one more time, and assert that we
1419
 
        # either have 1 or 2 active connections (never 3), and that the 'first'
1420
 
        # connection is not connection 1
1421
 
        client_sock3 = self.connect_to_server(server)
1422
 
        self.say_hello(client_sock3)
1423
 
        # Copy the list, so we don't have it mutating behind our back
1424
 
        conns = list(server._active_connections)
1425
 
        self.assertEqual(2, len(conns))
1426
 
        self.assertNotEqual((server_handler1, server_side_thread1), conns[0])
1427
 
        self.assertEqual((server_handler2, server_side_thread2), conns[0])
1428
 
        client_sock2.close()
1429
 
        client_sock3.close()
1430
 
        self.shutdown_server_cleanly(server, server_thread)
1431
 
 
1432
 
    def test_graceful_shutdown_waits_for_clients_to_stop(self):
1433
 
        server, server_thread = self.make_server()
1434
 
        # We need something big enough that it won't fit in a single recv. So
1435
 
        # the server thread gets blocked writing content to the client until we
1436
 
        # finish reading on the client.
1437
 
        server.backing_transport.put_bytes('bigfile',
1438
 
            'a'*1024*1024)
1439
 
        client_sock = self.connect_to_server(server)
1440
 
        self.say_hello(client_sock)
1441
 
        _, server_side_thread = server._active_connections[0]
1442
 
        # Start the RPC, but don't finish reading the response
1443
 
        client_medium = medium.SmartClientAlreadyConnectedSocketMedium(
1444
 
            'base', client_sock)
1445
 
        client_client = client._SmartClient(client_medium)
1446
 
        resp, response_handler = client_client.call_expecting_body('get',
1447
 
            'bigfile')
1448
 
        self.assertEqual(('ok',), resp)
1449
 
        # Ask the server to stop gracefully, and wait for it.
1450
 
        server._stop_gracefully()
1451
 
        self.connect_to_server_and_hangup(server)
1452
 
        server._stopped.wait()
1453
 
        # It should not be accepting another connection.
1454
 
        self.assertRaises(socket.error, self.connect_to_server, server)
1455
 
        # It should also not be fully stopped
1456
 
        server._fully_stopped.wait(0.01)
1457
 
        self.assertFalse(server._fully_stopped.isSet())
1458
 
        response_handler.read_body_bytes()
1459
 
        client_sock.close()
1460
 
        server_side_thread.join()
1461
 
        server_thread.join()
1462
 
        self.assertTrue(server._fully_stopped.isSet())
1463
 
        log = self.get_log()
1464
 
        self.assertThat(log, DocTestMatches("""\
1465
 
    INFO  Requested to stop gracefully
1466
 
... Stopping SmartServerSocketStreamMedium(client=('127.0.0.1', ...
1467
 
    INFO  Waiting for 1 client(s) to finish
1468
 
""", flags=doctest.ELLIPSIS|doctest.REPORT_UDIFF))
1469
 
 
1470
 
    def test_stop_gracefully_tells_handlers_to_stop(self):
1471
 
        server, server_thread = self.make_server()
1472
 
        client_sock = self.connect_to_server(server)
1473
 
        self.say_hello(client_sock)
1474
 
        server_handler, server_side_thread = server._active_connections[0]
1475
 
        self.assertFalse(server_handler.finished)
1476
 
        server._stop_gracefully()
1477
 
        self.assertTrue(server_handler.finished)
1478
 
        client_sock.close()
1479
 
        self.connect_to_server_and_hangup(server)
1480
 
        server_thread.join()
1481
 
 
1482
994
 
1483
995
class SmartTCPTests(tests.TestCase):
1484
996
    """Tests for connection/end to end behaviour using the TCP server.
1502
1014
            mem_server.start_server()
1503
1015
            self.addCleanup(mem_server.stop_server)
1504
1016
            self.permit_url(mem_server.get_url())
1505
 
            self.backing_transport = _mod_transport.get_transport_from_url(
 
1017
            self.backing_transport = transport.get_transport(
1506
1018
                mem_server.get_url())
1507
1019
        else:
1508
1020
            self.backing_transport = backing_transport
1509
1021
        if readonly:
1510
1022
            self.real_backing_transport = self.backing_transport
1511
 
            self.backing_transport = _mod_transport.get_transport_from_url(
 
1023
            self.backing_transport = transport.get_transport(
1512
1024
                "readonly+" + self.backing_transport.abspath('.'))
1513
 
        self.server = _mod_server.SmartTCPServer(self.backing_transport,
1514
 
                                                 client_timeout=4.0)
 
1025
        self.server = server.SmartTCPServer(self.backing_transport)
1515
1026
        self.server.start_server('127.0.0.1', 0)
1516
1027
        self.server.start_background_thread('-' + self.id())
1517
1028
        self.transport = remote.RemoteTCPTransport(self.server.get_url())
1651
1162
    def test_server_started_hook_memory(self):
1652
1163
        """The server_started hook fires when the server is started."""
1653
1164
        self.hook_calls = []
1654
 
        _mod_server.SmartTCPServer.hooks.install_named_hook('server_started',
 
1165
        server.SmartTCPServer.hooks.install_named_hook('server_started',
1655
1166
            self.capture_server_call, None)
1656
1167
        self.start_server()
1657
1168
        # at this point, the server will be starting a thread up.
1665
1176
    def test_server_started_hook_file(self):
1666
1177
        """The server_started hook fires when the server is started."""
1667
1178
        self.hook_calls = []
1668
 
        _mod_server.SmartTCPServer.hooks.install_named_hook('server_started',
 
1179
        server.SmartTCPServer.hooks.install_named_hook('server_started',
1669
1180
            self.capture_server_call, None)
1670
 
        self.start_server(
1671
 
            backing_transport=_mod_transport.get_transport_from_path("."))
 
1181
        self.start_server(backing_transport=transport.get_transport("."))
1672
1182
        # at this point, the server will be starting a thread up.
1673
1183
        # there is no indicator at the moment, so bodge it by doing a request.
1674
1184
        self.transport.has('.')
1682
1192
    def test_server_stopped_hook_simple_memory(self):
1683
1193
        """The server_stopped hook fires when the server is stopped."""
1684
1194
        self.hook_calls = []
1685
 
        _mod_server.SmartTCPServer.hooks.install_named_hook('server_stopped',
 
1195
        server.SmartTCPServer.hooks.install_named_hook('server_stopped',
1686
1196
            self.capture_server_call, None)
1687
1197
        self.start_server()
1688
1198
        result = [([self.backing_transport.base], self.transport.base)]
1699
1209
    def test_server_stopped_hook_simple_file(self):
1700
1210
        """The server_stopped hook fires when the server is stopped."""
1701
1211
        self.hook_calls = []
1702
 
        _mod_server.SmartTCPServer.hooks.install_named_hook('server_stopped',
 
1212
        server.SmartTCPServer.hooks.install_named_hook('server_stopped',
1703
1213
            self.capture_server_call, None)
1704
 
        self.start_server(
1705
 
            backing_transport=_mod_transport.get_transport_from_path("."))
 
1214
        self.start_server(backing_transport=transport.get_transport("."))
1706
1215
        result = [(
1707
1216
            [self.backing_transport.base, self.backing_transport.external_url()]
1708
1217
            , self.transport.base)]
1844
1353
class RemoteTransportRegistration(tests.TestCase):
1845
1354
 
1846
1355
    def test_registration(self):
1847
 
        t = _mod_transport.get_transport_from_url('bzr+ssh://example.com/path')
 
1356
        t = transport.get_transport('bzr+ssh://example.com/path')
1848
1357
        self.assertIsInstance(t, remote.RemoteSSHTransport)
1849
 
        self.assertEqual('example.com', t._parsed_url.host)
 
1358
        self.assertEqual('example.com', t._host)
1850
1359
 
1851
1360
    def test_bzr_https(self):
1852
1361
        # https://bugs.launchpad.net/bzr/+bug/128456
1853
 
        t = _mod_transport.get_transport_from_url('bzr+https://example.com/path')
 
1362
        t = transport.get_transport('bzr+https://example.com/path')
1854
1363
        self.assertIsInstance(t, remote.RemoteHTTPTransport)
1855
1364
        self.assertStartsWith(
1856
1365
            t._http_transport.base,
3033
2542
        from_server = StringIO()
3034
2543
        transport = memory.MemoryTransport('memory:///')
3035
2544
        server = medium.SmartServerPipeStreamMedium(
3036
 
            to_server, from_server, transport, timeout=4.0)
 
2545
            to_server, from_server, transport)
3037
2546
        proto = server._build_protocol()
3038
2547
        message_handler = proto.message_handler
3039
2548
        server._serve_one_request(proto)
3284
2793
            'e', # end
3285
2794
            output.getvalue())
3286
2795
 
3287
 
    def test_records_start_of_body_stream(self):
3288
 
        requester, output = self.make_client_encoder_and_output()
3289
 
        requester.set_headers({})
3290
 
        in_stream = [False]
3291
 
        def stream_checker():
3292
 
            self.assertTrue(requester.body_stream_started)
3293
 
            in_stream[0] = True
3294
 
            yield 'content'
3295
 
        flush_called = []
3296
 
        orig_flush = requester.flush
3297
 
        def tracked_flush():
3298
 
            flush_called.append(in_stream[0])
3299
 
            if in_stream[0]:
3300
 
                self.assertTrue(requester.body_stream_started)
3301
 
            else:
3302
 
                self.assertFalse(requester.body_stream_started)
3303
 
            return orig_flush()
3304
 
        requester.flush = tracked_flush
3305
 
        requester.call_with_body_stream(('one arg',), stream_checker())
3306
 
        self.assertEqual(
3307
 
            'bzr message 3 (bzr 1.6)\n' # protocol version
3308
 
            '\x00\x00\x00\x02de' # headers
3309
 
            's\x00\x00\x00\x0bl7:one arge' # args
3310
 
            'b\x00\x00\x00\x07content' # body
3311
 
            'e', output.getvalue())
3312
 
        self.assertEqual([False, True, True], flush_called)
3313
 
 
3314
2796
 
3315
2797
class StubMediumRequest(object):
3316
2798
    """A stub medium request that tracks the number of times accept_bytes is
3736
3218
        # encoder.
3737
3219
 
3738
3220
 
3739
 
class Test_SmartClientRequest(tests.TestCase):
3740
 
 
3741
 
    def make_client_with_failing_medium(self, fail_at_write=True, response=''):
3742
 
        response_io = StringIO(response)
3743
 
        output = StringIO()
3744
 
        vendor = FirstRejectedStringIOSSHVendor(response_io, output,
3745
 
                    fail_at_write=fail_at_write)
3746
 
        ssh_params = medium.SSHParams('a host', 'a port', 'a user', 'a pass')
3747
 
        client_medium = medium.SmartSSHClientMedium('base', ssh_params, vendor)
3748
 
        smart_client = client._SmartClient(client_medium, headers={})
3749
 
        return output, vendor, smart_client
3750
 
 
3751
 
    def make_response(self, args, body=None, body_stream=None):
3752
 
        response_io = StringIO()
3753
 
        response = _mod_request.SuccessfulSmartServerResponse(args, body=body,
3754
 
            body_stream=body_stream)
3755
 
        responder = protocol.ProtocolThreeResponder(response_io.write)
3756
 
        responder.send_response(response)
3757
 
        return response_io.getvalue()
3758
 
 
3759
 
    def test__call_doesnt_retry_append(self):
3760
 
        response = self.make_response(('appended', '8'))
3761
 
        output, vendor, smart_client = self.make_client_with_failing_medium(
3762
 
            fail_at_write=False, response=response)
3763
 
        smart_request = client._SmartClientRequest(smart_client, 'append',
3764
 
            ('foo', ''), body='content\n')
3765
 
        self.assertRaises(errors.ConnectionReset, smart_request._call, 3)
3766
 
 
3767
 
    def test__call_retries_get_bytes(self):
3768
 
        response = self.make_response(('ok',), 'content\n')
3769
 
        output, vendor, smart_client = self.make_client_with_failing_medium(
3770
 
            fail_at_write=False, response=response)
3771
 
        smart_request = client._SmartClientRequest(smart_client, 'get',
3772
 
            ('foo',))
3773
 
        response, response_handler = smart_request._call(3)
3774
 
        self.assertEqual(('ok',), response)
3775
 
        self.assertEqual('content\n', response_handler.read_body_bytes())
3776
 
 
3777
 
    def test__call_noretry_get_bytes(self):
3778
 
        debug.debug_flags.add('noretry')
3779
 
        response = self.make_response(('ok',), 'content\n')
3780
 
        output, vendor, smart_client = self.make_client_with_failing_medium(
3781
 
            fail_at_write=False, response=response)
3782
 
        smart_request = client._SmartClientRequest(smart_client, 'get',
3783
 
            ('foo',))
3784
 
        self.assertRaises(errors.ConnectionReset, smart_request._call, 3)
3785
 
 
3786
 
    def test__send_no_retry_pipes(self):
3787
 
        client_read, server_write = create_file_pipes()
3788
 
        server_read, client_write = create_file_pipes()
3789
 
        client_medium = medium.SmartSimplePipesClientMedium(client_read,
3790
 
            client_write, base='/')
3791
 
        smart_client = client._SmartClient(client_medium)
3792
 
        smart_request = client._SmartClientRequest(smart_client,
3793
 
            'hello', ())
3794
 
        # Close the server side
3795
 
        server_read.close()
3796
 
        encoder, response_handler = smart_request._construct_protocol(3)
3797
 
        self.assertRaises(errors.ConnectionReset,
3798
 
            smart_request._send_no_retry, encoder)
3799
 
 
3800
 
    def test__send_read_response_sockets(self):
3801
 
        listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
3802
 
        listen_sock.bind(('127.0.0.1', 0))
3803
 
        listen_sock.listen(1)
3804
 
        host, port = listen_sock.getsockname()
3805
 
        client_medium = medium.SmartTCPClientMedium(host, port, '/')
3806
 
        client_medium._ensure_connection()
3807
 
        smart_client = client._SmartClient(client_medium)
3808
 
        smart_request = client._SmartClientRequest(smart_client, 'hello', ())
3809
 
        # Accept the connection, but don't actually talk to the client.
3810
 
        server_sock, _ = listen_sock.accept()
3811
 
        server_sock.close()
3812
 
        # Sockets buffer and don't really notice that the server has closed the
3813
 
        # connection until we try to read again.
3814
 
        handler = smart_request._send(3)
3815
 
        self.assertRaises(errors.ConnectionReset,
3816
 
            handler.read_response_tuple, expect_body=False)
3817
 
 
3818
 
    def test__send_retries_on_write(self):
3819
 
        output, vendor, smart_client = self.make_client_with_failing_medium()
3820
 
        smart_request = client._SmartClientRequest(smart_client, 'hello', ())
3821
 
        handler = smart_request._send(3)
3822
 
        self.assertEqual('bzr message 3 (bzr 1.6)\n' # protocol
3823
 
                         '\x00\x00\x00\x02de'   # empty headers
3824
 
                         's\x00\x00\x00\tl5:helloee',
3825
 
                         output.getvalue())
3826
 
        self.assertEqual(
3827
 
            [('connect_ssh', 'a user', 'a pass', 'a host', 'a port',
3828
 
              ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
3829
 
             ('close',),
3830
 
             ('connect_ssh', 'a user', 'a pass', 'a host', 'a port',
3831
 
              ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
3832
 
            ],
3833
 
            vendor.calls)
3834
 
 
3835
 
    def test__send_doesnt_retry_read_failure(self):
3836
 
        output, vendor, smart_client = self.make_client_with_failing_medium(
3837
 
            fail_at_write=False)
3838
 
        smart_request = client._SmartClientRequest(smart_client, 'hello', ())
3839
 
        handler = smart_request._send(3)
3840
 
        self.assertEqual('bzr message 3 (bzr 1.6)\n' # protocol
3841
 
                         '\x00\x00\x00\x02de'   # empty headers
3842
 
                         's\x00\x00\x00\tl5:helloee',
3843
 
                         output.getvalue())
3844
 
        self.assertEqual(
3845
 
            [('connect_ssh', 'a user', 'a pass', 'a host', 'a port',
3846
 
              ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
3847
 
            ],
3848
 
            vendor.calls)
3849
 
        self.assertRaises(errors.ConnectionReset, handler.read_response_tuple)
3850
 
 
3851
 
    def test__send_request_retries_body_stream_if_not_started(self):
3852
 
        output, vendor, smart_client = self.make_client_with_failing_medium()
3853
 
        smart_request = client._SmartClientRequest(smart_client, 'hello', (),
3854
 
            body_stream=['a', 'b'])
3855
 
        response_handler = smart_request._send(3)
3856
 
        # We connect, get disconnected, and notice before consuming the stream,
3857
 
        # so we try again one time and succeed.
3858
 
        self.assertEqual(
3859
 
            [('connect_ssh', 'a user', 'a pass', 'a host', 'a port',
3860
 
              ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
3861
 
             ('close',),
3862
 
             ('connect_ssh', 'a user', 'a pass', 'a host', 'a port',
3863
 
              ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
3864
 
            ],
3865
 
            vendor.calls)
3866
 
        self.assertEqual('bzr message 3 (bzr 1.6)\n' # protocol
3867
 
                         '\x00\x00\x00\x02de'   # empty headers
3868
 
                         's\x00\x00\x00\tl5:helloe'
3869
 
                         'b\x00\x00\x00\x01a'
3870
 
                         'b\x00\x00\x00\x01b'
3871
 
                         'e',
3872
 
                         output.getvalue())
3873
 
 
3874
 
    def test__send_request_stops_if_body_started(self):
3875
 
        # We intentionally use the python StringIO so that we can subclass it.
3876
 
        from StringIO import StringIO
3877
 
        response = StringIO()
3878
 
 
3879
 
        class FailAfterFirstWrite(StringIO):
3880
 
            """Allow one 'write' call to pass, fail the rest"""
3881
 
            def __init__(self):
3882
 
                StringIO.__init__(self)
3883
 
                self._first = True
3884
 
 
3885
 
            def write(self, s):
3886
 
                if self._first:
3887
 
                    self._first = False
3888
 
                    return StringIO.write(self, s)
3889
 
                raise IOError(errno.EINVAL, 'invalid file handle')
3890
 
        output = FailAfterFirstWrite()
3891
 
 
3892
 
        vendor = FirstRejectedStringIOSSHVendor(response, output,
3893
 
            fail_at_write=False)
3894
 
        ssh_params = medium.SSHParams('a host', 'a port', 'a user', 'a pass')
3895
 
        client_medium = medium.SmartSSHClientMedium('base', ssh_params, vendor)
3896
 
        smart_client = client._SmartClient(client_medium, headers={})
3897
 
        smart_request = client._SmartClientRequest(smart_client, 'hello', (),
3898
 
            body_stream=['a', 'b'])
3899
 
        self.assertRaises(errors.ConnectionReset, smart_request._send, 3)
3900
 
        # We connect, and manage to get to the point that we start consuming
3901
 
        # the body stream. The next write fails, so we just stop.
3902
 
        self.assertEqual(
3903
 
            [('connect_ssh', 'a user', 'a pass', 'a host', 'a port',
3904
 
              ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
3905
 
             ('close',),
3906
 
            ],
3907
 
            vendor.calls)
3908
 
        self.assertEqual('bzr message 3 (bzr 1.6)\n' # protocol
3909
 
                         '\x00\x00\x00\x02de'   # empty headers
3910
 
                         's\x00\x00\x00\tl5:helloe',
3911
 
                         output.getvalue())
3912
 
 
3913
 
    def test__send_disabled_retry(self):
3914
 
        debug.debug_flags.add('noretry')
3915
 
        output, vendor, smart_client = self.make_client_with_failing_medium()
3916
 
        smart_request = client._SmartClientRequest(smart_client, 'hello', ())
3917
 
        self.assertRaises(errors.ConnectionReset, smart_request._send, 3)
3918
 
        self.assertEqual(
3919
 
            [('connect_ssh', 'a user', 'a pass', 'a host', 'a port',
3920
 
              ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
3921
 
             ('close',),
3922
 
            ],
3923
 
            vendor.calls)
3924
 
 
3925
 
 
3926
3221
class LengthPrefixedBodyDecoder(tests.TestCase):
3927
3222
 
3928
3223
    # XXX: TODO: make accept_reading_trailer invoke translate_response or
4262
3557
        # still work correctly.
4263
3558
        base_transport = remote.RemoteHTTPTransport('bzr+http://host/%7Ea/b')
4264
3559
        new_transport = base_transport.clone('c')
4265
 
        self.assertEqual(base_transport.base + 'c/', new_transport.base)
 
3560
        self.assertEqual('bzr+http://host/~a/b/c/', new_transport.base)
4266
3561
        self.assertEqual(
4267
3562
            'c/',
4268
3563
            new_transport._client.remote_path_from_transport(new_transport))
4285
3580
        r = t._redirected_to('http://www.example.com/foo',
4286
3581
                             'http://www.example.com/bar')
4287
3582
        self.assertEquals(type(r), type(t))
4288
 
        self.assertEquals('joe', t._parsed_url.user)
4289
 
        self.assertEquals(t._parsed_url.user, r._parsed_url.user)
 
3583
        self.assertEquals('joe', t._user)
 
3584
        self.assertEquals(t._user, r._user)
4290
3585
 
4291
3586
    def test_redirected_to_same_host_different_protocol(self):
4292
3587
        t = remote.RemoteHTTPTransport('bzr+http://joe@www.example.com/foo')