~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/tests/test_smart_transport.py

  • Committer: Jelmer Vernooij
  • Date: 2012-02-20 12:19:29 UTC
  • mfrom: (6437.23.11 2.5)
  • mto: (6581.1.1 trunk)
  • mto: This revision was merged to the branch mainline in revision 6582.
  • Revision ID: jelmer@samba.org-20120220121929-7ni2psvjoatm1yp4
Merge bzr/2.5.

Show diffs side-by-side

added added

removed removed

Lines of Context:
18
18
 
19
19
# all of this deals with byte strings so this is safe
20
20
from cStringIO import StringIO
 
21
import doctest
 
22
import errno
21
23
import os
22
24
import socket
 
25
import subprocess
 
26
import sys
23
27
import threading
 
28
import time
 
29
 
 
30
from testtools.matchers import DocTestMatches
24
31
 
25
32
import bzrlib
26
33
from bzrlib import (
27
34
        bzrdir,
 
35
        debug,
28
36
        errors,
29
37
        osutils,
30
38
        tests,
31
 
        transport,
 
39
        transport as _mod_transport,
32
40
        urlutils,
33
41
        )
34
42
from bzrlib.smart import (
37
45
        message,
38
46
        protocol,
39
47
        request as _mod_request,
40
 
        server,
 
48
        server as _mod_server,
41
49
        vfs,
42
50
)
43
51
from bzrlib.tests import (
54
62
        )
55
63
 
56
64
 
 
65
def create_file_pipes():
    """Create an OS pipe and wrap both ends in unbuffered file objects.

    :return: (rf, wf) where rf is the binary read end and wf is the binary
        write end of the same pipe.
    """
    r, w = os.pipe()
    # These must be opened without buffering, or we get undefined results
    rf = os.fdopen(r, 'rb', 0)
    wf = os.fdopen(w, 'wb', 0)
    return rf, wf
 
71
 
 
72
 
 
73
def portable_socket_pair():
    """Return a pair of TCP sockets connected to each other.

    Unlike socket.socketpair, this should work on Windows.

    A temporary listening socket is bound to an ephemeral port on 127.0.0.1
    and is always closed before returning, even if connecting fails.

    :return: (server_sock, client_sock)
    """
    listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        listen_sock.bind(('127.0.0.1', 0))
        listen_sock.listen(1)
        client_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            client_sock.connect(listen_sock.getsockname())
            server_sock, addr = listen_sock.accept()
        except BaseException:
            # Don't leak the half-connected client socket on failure.
            client_sock.close()
            raise
    finally:
        # The listener is only needed to establish the single connection.
        listen_sock.close()
    return server_sock, client_sock
 
86
 
 
87
 
57
88
class StringIOSSHVendor(object):
58
89
    """A SSH vendor that uses StringIO to buffer writes and answer reads."""
59
90
 
68
99
        return StringIOSSHConnection(self)
69
100
 
70
101
 
 
102
class FirstRejectedStringIOSSHVendor(StringIOSSHVendor):
    """SSH vendor whose first connection is already closed.

    Only the first connect_ssh() call hands back a ClosedSSHConnection;
    every later call returns a normal StringIOSSHConnection.
    """

    def __init__(self, read_from, write_to, fail_at_write=True):
        super(FirstRejectedStringIOSSHVendor, self).__init__(
            read_from, write_to)
        # Read by ClosedSSHConnection.get_sock_or_pipes to decide whether
        # the write side should be broken as well as the read side.
        self.fail_at_write = fail_at_write
        # True until the first connect_ssh() call has been made.
        self._first = True

    def connect_ssh(self, username, password, host, port, command):
        call_record = ('connect_ssh', username, password, host, port,
                       command)
        self.calls.append(call_record)
        if not self._first:
            return StringIOSSHConnection(self)
        self._first = False
        return ClosedSSHConnection(self)
 
121
 
 
122
 
71
123
class StringIOSSHConnection(ssh.SSHConnection):
72
124
    """A SSH connection that uses StringIO to buffer writes and answer reads."""
73
125
 
83
135
        return 'pipes', (self.vendor.read_from, self.vendor.write_to)
84
136
 
85
137
 
 
138
class ClosedSSHConnection(ssh.SSHConnection):
    """An SSH connection whose channels are already closed."""

    def __init__(self, vendor):
        self.vendor = vendor

    def close(self):
        # Record the call on the vendor so tests can inspect it.
        self.vendor.calls.append(('close', ))

    def get_sock_or_pipes(self):
        # Build a pipe pair and close the ssh (write) end straight away, so
        # that bzr always fails when it goes to read.
        read_end, dead_writer = create_file_pipes()
        dead_writer.close()
        if not self.vendor.fail_at_write:
            # Writes go to the vendor's own write_to object.
            return 'pipes', (read_end, self.vendor.write_to)
        # Otherwise also close the reading end of the pipe bzr writes to, so
        # that bzr fails when it goes to write as well.
        dead_reader, write_end = create_file_pipes()
        dead_reader.close()
        return 'pipes', (read_end, write_end)
 
159
 
 
160
 
86
161
class _InvalidHostnameFeature(features.Feature):
87
162
    """Does 'non_existent.invalid' fail to resolve?
88
163
 
178
253
        client_medium._accept_bytes('abc')
179
254
        self.assertEqual('abc', output.getvalue())
180
255
 
 
256
    def test_simple_pipes__accept_bytes_subprocess_closed(self):
 
257
        # It is unfortunate that we have to use Popen for this. However,
 
258
        # os.pipe() does not behave the same as subprocess.Popen().
 
259
        # On Windows, if you use os.pipe() and close the write side,
 
260
        # read.read() hangs. On Linux, read.read() returns the empty string.
 
261
        p = subprocess.Popen([sys.executable, '-c',
 
262
            'import sys\n'
 
263
            'sys.stdout.write(sys.stdin.read(4))\n'
 
264
            'sys.stdout.close()\n'],
 
265
            stdout=subprocess.PIPE, stdin=subprocess.PIPE)
 
266
        client_medium = medium.SmartSimplePipesClientMedium(
 
267
            p.stdout, p.stdin, 'base')
 
268
        client_medium._accept_bytes('abc\n')
 
269
        self.assertEqual('abc', client_medium._read_bytes(3))
 
270
        p.wait()
 
271
        # While writing to the underlying pipe,
 
272
        #   Windows py2.6.6 we get IOError(EINVAL)
 
273
        #   Lucid py2.6.5, we get IOError(EPIPE)
 
274
        # In both cases, it should be wrapped to ConnectionReset
 
275
        self.assertRaises(errors.ConnectionReset,
 
276
                          client_medium._accept_bytes, 'more')
 
277
 
 
278
    def test_simple_pipes__accept_bytes_pipe_closed(self):
 
279
        child_read, client_write = create_file_pipes()
 
280
        client_medium = medium.SmartSimplePipesClientMedium(
 
281
            None, client_write, 'base')
 
282
        client_medium._accept_bytes('abc\n')
 
283
        self.assertEqual('abc\n', child_read.read(4))
 
284
        # While writing to the underlying pipe,
 
285
        #   Windows py2.6.6 we get IOError(EINVAL)
 
286
        #   Lucid py2.6.5, we get IOError(EPIPE)
 
287
        # In both cases, it should be wrapped to ConnectionReset
 
288
        child_read.close()
 
289
        self.assertRaises(errors.ConnectionReset,
 
290
                          client_medium._accept_bytes, 'more')
 
291
 
 
292
    def test_simple_pipes__flush_pipe_closed(self):
 
293
        child_read, client_write = create_file_pipes()
 
294
        client_medium = medium.SmartSimplePipesClientMedium(
 
295
            None, client_write, 'base')
 
296
        client_medium._accept_bytes('abc\n')
 
297
        child_read.close()
 
298
        # Even though the pipe is closed, flush on the write side seems to be a
 
299
        # no-op, rather than a failure.
 
300
        client_medium._flush()
 
301
 
 
302
    def test_simple_pipes__flush_subprocess_closed(self):
 
303
        p = subprocess.Popen([sys.executable, '-c',
 
304
            'import sys\n'
 
305
            'sys.stdout.write(sys.stdin.read(4))\n'
 
306
            'sys.stdout.close()\n'],
 
307
            stdout=subprocess.PIPE, stdin=subprocess.PIPE)
 
308
        client_medium = medium.SmartSimplePipesClientMedium(
 
309
            p.stdout, p.stdin, 'base')
 
310
        client_medium._accept_bytes('abc\n')
 
311
        p.wait()
 
312
        # Even though the child process is dead, flush seems to be a no-op.
 
313
        client_medium._flush()
 
314
 
 
315
    def test_simple_pipes__read_bytes_pipe_closed(self):
 
316
        child_read, client_write = create_file_pipes()
 
317
        client_medium = medium.SmartSimplePipesClientMedium(
 
318
            child_read, client_write, 'base')
 
319
        client_medium._accept_bytes('abc\n')
 
320
        client_write.close()
 
321
        self.assertEqual('abc\n', client_medium._read_bytes(4))
 
322
        self.assertEqual('', client_medium._read_bytes(4))
 
323
 
 
324
    def test_simple_pipes__read_bytes_subprocess_closed(self):
 
325
        p = subprocess.Popen([sys.executable, '-c',
 
326
            'import sys\n'
 
327
            'if sys.platform == "win32":\n'
 
328
            '    import msvcrt, os\n'
 
329
            '    msvcrt.setmode(sys.stdout.fileno(), os.O_BINARY)\n'
 
330
            '    msvcrt.setmode(sys.stdin.fileno(), os.O_BINARY)\n'
 
331
            'sys.stdout.write(sys.stdin.read(4))\n'
 
332
            'sys.stdout.close()\n'],
 
333
            stdout=subprocess.PIPE, stdin=subprocess.PIPE)
 
334
        client_medium = medium.SmartSimplePipesClientMedium(
 
335
            p.stdout, p.stdin, 'base')
 
336
        client_medium._accept_bytes('abc\n')
 
337
        p.wait()
 
338
        self.assertEqual('abc\n', client_medium._read_bytes(4))
 
339
        self.assertEqual('', client_medium._read_bytes(4))
 
340
 
181
341
    def test_simple_pipes_client_disconnect_does_nothing(self):
182
342
        # calling disconnect does nothing.
183
343
        input = StringIO()
339
499
            ],
340
500
            vendor.calls)
341
501
 
 
502
    def test_ssh_client_repr(self):
 
503
        client_medium = medium.SmartSSHClientMedium(
 
504
            'base', medium.SSHParams("example.com", "4242", "username"))
 
505
        self.assertEquals(
 
506
            "SmartSSHClientMedium(bzr+ssh://username@example.com:4242/)",
 
507
            repr(client_medium))
 
508
 
 
509
    def test_ssh_client_repr_no_port(self):
 
510
        client_medium = medium.SmartSSHClientMedium(
 
511
            'base', medium.SSHParams("example.com", None, "username"))
 
512
        self.assertEquals(
 
513
            "SmartSSHClientMedium(bzr+ssh://username@example.com/)",
 
514
            repr(client_medium))
 
515
 
 
516
    def test_ssh_client_repr_no_username(self):
 
517
        client_medium = medium.SmartSSHClientMedium(
 
518
            'base', medium.SSHParams("example.com", None, None))
 
519
        self.assertEquals(
 
520
            "SmartSSHClientMedium(bzr+ssh://example.com/)",
 
521
            repr(client_medium))
 
522
 
342
523
    def test_ssh_client_ignores_disconnect_when_not_connected(self):
343
524
        # Doing a disconnect on a new (and thus unconnected) SSH medium
344
525
        # does not fail.  It's ok to disconnect an unconnected medium.
565
746
        request.finished_reading()
566
747
        self.assertRaises(errors.ReadingCompleted, request.read_bytes, None)
567
748
 
 
749
    def test_reset(self):
 
750
        server_sock, client_sock = portable_socket_pair()
 
751
        # TODO: Use SmartClientAlreadyConnectedSocketMedium for the versions of
 
752
        #       bzr where it exists.
 
753
        client_medium = medium.SmartTCPClientMedium(None, None, None)
 
754
        client_medium._socket = client_sock
 
755
        client_medium._connected = True
 
756
        req = client_medium.get_request()
 
757
        self.assertRaises(errors.TooManyConcurrentRequests,
 
758
            client_medium.get_request)
 
759
        client_medium.reset()
 
760
        # The stream should be reset, marked as disconnected, though ready for
 
761
        # us to make a new request
 
762
        self.assertFalse(client_medium._connected)
 
763
        self.assertIs(None, client_medium._socket)
 
764
        try:
 
765
            self.assertEqual('', client_sock.recv(1))
 
766
        except socket.error, e:
 
767
            if e.errno not in (errno.EBADF,):
 
768
                raise
 
769
        req = client_medium.get_request()
 
770
 
568
771
 
569
772
class RemoteTransportTests(test_smart.TestCaseWithSmartMedium):
570
773
 
618
821
        super(TestSmartServerStreamMedium, self).setUp()
619
822
        self.overrideEnv('BZR_NO_SMART_VFS', None)
620
823
 
621
 
    def portable_socket_pair(self):
622
 
        """Return a pair of TCP sockets connected to each other.
623
 
 
624
 
        Unlike socket.socketpair, this should work on Windows.
625
 
        """
626
 
        listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
627
 
        listen_sock.bind(('127.0.0.1', 0))
628
 
        listen_sock.listen(1)
629
 
        client_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
630
 
        client_sock.connect(listen_sock.getsockname())
631
 
        server_sock, addr = listen_sock.accept()
632
 
        listen_sock.close()
633
 
        return server_sock, client_sock
 
824
    def create_pipe_medium(self, to_server, from_server, transport,
 
825
                           timeout=4.0):
 
826
        """Create a new SmartServerPipeStreamMedium."""
 
827
        return medium.SmartServerPipeStreamMedium(to_server, from_server,
 
828
            transport, timeout=timeout)
 
829
 
 
830
    def create_pipe_context(self, to_server_bytes, transport):
 
831
        """Create a SmartServerSocketStreamMedium.
 
832
 
 
833
        This differes from create_pipe_medium, in that we initialize the
 
834
        request that is sent to the server, and return the StringIO class that
 
835
        will hold the response.
 
836
        """
 
837
        to_server = StringIO(to_server_bytes)
 
838
        from_server = StringIO()
 
839
        m = self.create_pipe_medium(to_server, from_server, transport)
 
840
        return m, from_server
 
841
 
 
842
    def create_socket_medium(self, server_sock, transport, timeout=4.0):
 
843
        """Initialize a new medium.SmartServerSocketStreamMedium."""
 
844
        return medium.SmartServerSocketStreamMedium(server_sock, transport,
 
845
            timeout=timeout)
 
846
 
 
847
    def create_socket_context(self, transport, timeout=4.0):
 
848
        """Create a new SmartServerSocketStreamMedium with default context.
 
849
 
 
850
        This will call portable_socket_pair and pass the server side to
 
851
        create_socket_medium along with transport.
 
852
        It then returns the client_sock and the server.
 
853
        """
 
854
        server_sock, client_sock = portable_socket_pair()
 
855
        server = self.create_socket_medium(server_sock, transport,
 
856
                                           timeout=timeout)
 
857
        return server, client_sock
634
858
 
635
859
    def test_smart_query_version(self):
636
860
        """Feed a canned query version to a server"""
637
861
        # wire-to-wire, using the whole stack
638
 
        to_server = StringIO('hello\n')
639
 
        from_server = StringIO()
640
862
        transport = local.LocalTransport(urlutils.local_path_to_url('/'))
641
 
        server = medium.SmartServerPipeStreamMedium(
642
 
            to_server, from_server, transport)
 
863
        server, from_server = self.create_pipe_context('hello\n', transport)
643
864
        smart_protocol = protocol.SmartServerRequestProtocolOne(transport,
644
865
                from_server.write)
645
866
        server._serve_one_request(smart_protocol)
649
870
    def test_response_to_canned_get(self):
650
871
        transport = memory.MemoryTransport('memory:///')
651
872
        transport.put_bytes('testfile', 'contents\nof\nfile\n')
652
 
        to_server = StringIO('get\001./testfile\n')
653
 
        from_server = StringIO()
654
 
        server = medium.SmartServerPipeStreamMedium(
655
 
            to_server, from_server, transport)
 
873
        server, from_server = self.create_pipe_context('get\001./testfile\n',
 
874
            transport)
656
875
        smart_protocol = protocol.SmartServerRequestProtocolOne(transport,
657
876
                from_server.write)
658
877
        server._serve_one_request(smart_protocol)
669
888
        # VFS requests use filenames, not raw UTF-8.
670
889
        hpss_path = urlutils.escape(utf8_filename)
671
890
        transport.put_bytes(utf8_filename, 'contents\nof\nfile\n')
672
 
        to_server = StringIO('get\001' + hpss_path + '\n')
673
 
        from_server = StringIO()
674
 
        server = medium.SmartServerPipeStreamMedium(
675
 
            to_server, from_server, transport)
 
891
        server, from_server = self.create_pipe_context(
 
892
                'get\001' + hpss_path + '\n', transport)
676
893
        smart_protocol = protocol.SmartServerRequestProtocolOne(transport,
677
894
                from_server.write)
678
895
        server._serve_one_request(smart_protocol)
684
901
 
685
902
    def test_pipe_like_stream_with_bulk_data(self):
686
903
        sample_request_bytes = 'command\n9\nbulk datadone\n'
687
 
        to_server = StringIO(sample_request_bytes)
688
 
        from_server = StringIO()
689
 
        server = medium.SmartServerPipeStreamMedium(
690
 
            to_server, from_server, None)
 
904
        server, from_server = self.create_pipe_context(
 
905
            sample_request_bytes, None)
691
906
        sample_protocol = SampleRequest(expected_bytes=sample_request_bytes)
692
907
        server._serve_one_request(sample_protocol)
693
908
        self.assertEqual('', from_server.getvalue())
696
911
 
697
912
    def test_socket_stream_with_bulk_data(self):
698
913
        sample_request_bytes = 'command\n9\nbulk datadone\n'
699
 
        server_sock, client_sock = self.portable_socket_pair()
700
 
        server = medium.SmartServerSocketStreamMedium(
701
 
            server_sock, None)
 
914
        server, client_sock = self.create_socket_context(None)
702
915
        sample_protocol = SampleRequest(expected_bytes=sample_request_bytes)
703
916
        client_sock.sendall(sample_request_bytes)
704
917
        server._serve_one_request(sample_protocol)
705
 
        server_sock.close()
 
918
        server._disconnect_client()
706
919
        self.assertEqual('', client_sock.recv(1))
707
920
        self.assertEqual(sample_request_bytes, sample_protocol.accepted_bytes)
708
921
        self.assertFalse(server.finished)
709
922
 
710
923
    def test_pipe_like_stream_shutdown_detection(self):
711
 
        to_server = StringIO('')
712
 
        from_server = StringIO()
713
 
        server = medium.SmartServerPipeStreamMedium(to_server, from_server, None)
 
924
        server, _ = self.create_pipe_context('', None)
714
925
        server._serve_one_request(SampleRequest('x'))
715
926
        self.assertTrue(server.finished)
716
927
 
717
928
    def test_socket_stream_shutdown_detection(self):
718
 
        server_sock, client_sock = self.portable_socket_pair()
 
929
        server, client_sock = self.create_socket_context(None)
719
930
        client_sock.close()
720
 
        server = medium.SmartServerSocketStreamMedium(
721
 
            server_sock, None)
722
931
        server._serve_one_request(SampleRequest('x'))
723
932
        self.assertTrue(server.finished)
724
933
 
735
944
        rest_of_request_bytes = 'lo\n'
736
945
        expected_response = (
737
946
            protocol.RESPONSE_VERSION_TWO + 'success\nok\x012\n')
738
 
        server_sock, client_sock = self.portable_socket_pair()
739
 
        server = medium.SmartServerSocketStreamMedium(
740
 
            server_sock, None)
 
947
        server, client_sock = self.create_socket_context(None)
741
948
        client_sock.sendall(incomplete_request_bytes)
742
949
        server_protocol = server._build_protocol()
743
950
        client_sock.sendall(rest_of_request_bytes)
744
951
        server._serve_one_request(server_protocol)
745
 
        server_sock.close()
 
952
        server._disconnect_client()
746
953
        self.assertEqual(expected_response, osutils.recv_all(client_sock, 50),
747
954
                         "Not a version 2 response to 'hello' request.")
748
955
        self.assertEqual('', client_sock.recv(1))
767
974
        to_server_w = os.fdopen(to_server_w, 'w', 0)
768
975
        from_server_r = os.fdopen(from_server_r, 'r', 0)
769
976
        from_server = os.fdopen(from_server, 'w', 0)
770
 
        server = medium.SmartServerPipeStreamMedium(
771
 
            to_server, from_server, None)
 
977
        server = self.create_pipe_medium(to_server, from_server, None)
772
978
        # Like test_socket_stream_incomplete_request, write an incomplete
773
979
        # request (that does not end in '\n') and build a protocol from it.
774
980
        to_server_w.write(incomplete_request_bytes)
789
995
        # _serve_one_request should still process both of them as if they had
790
996
        # been received separately.
791
997
        sample_request_bytes = 'command\n'
792
 
        to_server = StringIO(sample_request_bytes * 2)
793
 
        from_server = StringIO()
794
 
        server = medium.SmartServerPipeStreamMedium(
795
 
            to_server, from_server, None)
 
998
        server, from_server = self.create_pipe_context(
 
999
            sample_request_bytes * 2, None)
796
1000
        first_protocol = SampleRequest(expected_bytes=sample_request_bytes)
797
1001
        server._serve_one_request(first_protocol)
798
1002
        self.assertEqual(0, first_protocol.next_read_size())
811
1015
        # _serve_one_request should still process both of them as if they had
812
1016
        # been received separately.
813
1017
        sample_request_bytes = 'command\n'
814
 
        server_sock, client_sock = self.portable_socket_pair()
815
 
        server = medium.SmartServerSocketStreamMedium(
816
 
            server_sock, None)
 
1018
        server, client_sock = self.create_socket_context(None)
817
1019
        first_protocol = SampleRequest(expected_bytes=sample_request_bytes)
818
1020
        # Put two whole requests on the wire.
819
1021
        client_sock.sendall(sample_request_bytes * 2)
826
1028
        stream_still_open = server._serve_one_request(second_protocol)
827
1029
        self.assertEqual(sample_request_bytes, second_protocol.accepted_bytes)
828
1030
        self.assertFalse(server.finished)
829
 
        server_sock.close()
 
1031
        server._disconnect_client()
830
1032
        self.assertEqual('', client_sock.recv(1))
831
1033
 
832
1034
    def test_pipe_like_stream_error_handling(self):
839
1041
        def close():
840
1042
            self.closed = True
841
1043
        from_server.close = close
842
 
        server = medium.SmartServerPipeStreamMedium(
 
1044
        server = self.create_pipe_medium(
843
1045
            to_server, from_server, None)
844
1046
        fake_protocol = ErrorRaisingProtocol(Exception('boom'))
845
1047
        server._serve_one_request(fake_protocol)
848
1050
        self.assertTrue(server.finished)
849
1051
 
850
1052
    def test_socket_stream_error_handling(self):
851
 
        server_sock, client_sock = self.portable_socket_pair()
852
 
        server = medium.SmartServerSocketStreamMedium(
853
 
            server_sock, None)
 
1053
        server, client_sock = self.create_socket_context(None)
854
1054
        fake_protocol = ErrorRaisingProtocol(Exception('boom'))
855
1055
        server._serve_one_request(fake_protocol)
856
1056
        # recv should not block, because the other end of the socket has been
859
1059
        self.assertTrue(server.finished)
860
1060
 
861
1061
    def test_pipe_like_stream_keyboard_interrupt_handling(self):
862
 
        to_server = StringIO('')
863
 
        from_server = StringIO()
864
 
        server = medium.SmartServerPipeStreamMedium(
865
 
            to_server, from_server, None)
 
1062
        server, from_server = self.create_pipe_context('', None)
866
1063
        fake_protocol = ErrorRaisingProtocol(KeyboardInterrupt('boom'))
867
1064
        self.assertRaises(
868
1065
            KeyboardInterrupt, server._serve_one_request, fake_protocol)
869
1066
        self.assertEqual('', from_server.getvalue())
870
1067
 
871
1068
    def test_socket_stream_keyboard_interrupt_handling(self):
872
 
        server_sock, client_sock = self.portable_socket_pair()
873
 
        server = medium.SmartServerSocketStreamMedium(
874
 
            server_sock, None)
 
1069
        server, client_sock = self.create_socket_context(None)
875
1070
        fake_protocol = ErrorRaisingProtocol(KeyboardInterrupt('boom'))
876
1071
        self.assertRaises(
877
1072
            KeyboardInterrupt, server._serve_one_request, fake_protocol)
878
 
        server_sock.close()
 
1073
        server._disconnect_client()
879
1074
        self.assertEqual('', client_sock.recv(1))
880
1075
 
881
1076
    def build_protocol_pipe_like(self, bytes):
882
 
        to_server = StringIO(bytes)
883
 
        from_server = StringIO()
884
 
        server = medium.SmartServerPipeStreamMedium(
885
 
            to_server, from_server, None)
 
1077
        server, _ = self.create_pipe_context(bytes, None)
886
1078
        return server._build_protocol()
887
1079
 
888
1080
    def build_protocol_socket(self, bytes):
889
 
        server_sock, client_sock = self.portable_socket_pair()
890
 
        server = medium.SmartServerSocketStreamMedium(
891
 
            server_sock, None)
 
1081
        server, client_sock = self.create_socket_context(None)
892
1082
        client_sock.sendall(bytes)
893
1083
        client_sock.close()
894
1084
        return server._build_protocol()
934
1124
        server_protocol = self.build_protocol_socket('bzr request 2\n')
935
1125
        self.assertProtocolTwo(server_protocol)
936
1126
 
 
1127
    def test__build_protocol_returns_if_stopping(self):
 
1128
        # _build_protocol should notice that we are stopping, and return
 
1129
        # without waiting for bytes from the client.
 
1130
        server, client_sock = self.create_socket_context(None)
 
1131
        server._stop_gracefully()
 
1132
        self.assertIs(None, server._build_protocol())
 
1133
 
 
1134
    def test_socket_set_timeout(self):
 
1135
        server, _ = self.create_socket_context(None, timeout=1.23)
 
1136
        self.assertEqual(1.23, server._client_timeout)
 
1137
 
 
1138
    def test_pipe_set_timeout(self):
 
1139
        server = self.create_pipe_medium(None, None, None,
 
1140
            timeout=1.23)
 
1141
        self.assertEqual(1.23, server._client_timeout)
 
1142
 
 
1143
    def test_socket_wait_for_bytes_with_timeout_with_data(self):
 
1144
        server, client_sock = self.create_socket_context(None)
 
1145
        client_sock.sendall('data\n')
 
1146
        # This should not block or consume any actual content
 
1147
        self.assertFalse(server._wait_for_bytes_with_timeout(0.1))
 
1148
        data = server.read_bytes(5)
 
1149
        self.assertEqual('data\n', data)
 
1150
 
 
1151
    def test_socket_wait_for_bytes_with_timeout_no_data(self):
 
1152
        server, client_sock = self.create_socket_context(None)
 
1153
        # This should timeout quickly, reporting that there wasn't any data
 
1154
        self.assertRaises(errors.ConnectionTimeout,
 
1155
                          server._wait_for_bytes_with_timeout, 0.01)
 
1156
        client_sock.close()
 
1157
        data = server.read_bytes(1)
 
1158
        self.assertEqual('', data)
 
1159
 
 
1160
    def test_socket_wait_for_bytes_with_timeout_closed(self):
 
1161
        server, client_sock = self.create_socket_context(None)
 
1162
        # With the socket closed, this should return right away.
 
1163
        # It seems select.select() returns that you *can* read on the socket,
 
1164
        # even though it closed. Presumably as a way to tell it is closed?
 
1165
        # Testing shows that without sock.close() this times-out failing the
 
1166
        # test, but with it, it returns False immediately.
 
1167
        client_sock.close()
 
1168
        self.assertFalse(server._wait_for_bytes_with_timeout(10))
 
1169
        data = server.read_bytes(1)
 
1170
        self.assertEqual('', data)
 
1171
 
 
1172
    def test_socket_wait_for_bytes_with_shutdown(self):
 
1173
        server, client_sock = self.create_socket_context(None)
 
1174
        t = time.time()
 
1175
        # Override the _timer functionality, so that time never increments,
 
1176
        # this way, we can be sure we stopped because of the flag, and not
 
1177
        # because of a timeout, etc.
 
1178
        server._timer = lambda: t
 
1179
        server._client_poll_timeout = 0.1
 
1180
        server._stop_gracefully()
 
1181
        server._wait_for_bytes_with_timeout(1.0)
 
1182
 
 
1183
    def test_socket_serve_timeout_closes_socket(self):
 
1184
        server, client_sock = self.create_socket_context(None, timeout=0.1)
 
1185
        # This should timeout quickly, and then close the connection so that
 
1186
        # client_sock recv doesn't block.
 
1187
        server.serve()
 
1188
        self.assertEqual('', client_sock.recv(1))
 
1189
 
 
1190
    def test_pipe_wait_for_bytes_with_timeout_with_data(self):
 
1191
        # We intentionally use a real pipe here, so that we can 'select' on it.
 
1192
        # You can't select() on a StringIO
 
1193
        (r_server, w_client) = os.pipe()
 
1194
        self.addCleanup(os.close, w_client)
 
1195
        with os.fdopen(r_server, 'rb') as rf_server:
 
1196
            server = self.create_pipe_medium(
 
1197
                rf_server, None, None)
 
1198
            os.write(w_client, 'data\n')
 
1199
            # This should not block or consume any actual content
 
1200
            server._wait_for_bytes_with_timeout(0.1)
 
1201
            data = server.read_bytes(5)
 
1202
            self.assertEqual('data\n', data)
 
1203
 
 
1204
    def test_pipe_wait_for_bytes_with_timeout_no_data(self):
 
1205
        # We intentionally use a real pipe here, so that we can 'select' on it.
 
1206
        # You can't select() on a StringIO
 
1207
        (r_server, w_client) = os.pipe()
 
1208
        # We can't add an os.close cleanup here, because we need to control
 
1209
        # when the file handle gets closed ourselves.
 
1210
        with os.fdopen(r_server, 'rb') as rf_server:
 
1211
            server = self.create_pipe_medium(
 
1212
                rf_server, None, None)
 
1213
            if sys.platform == 'win32':
 
1214
                # Windows cannot select() on a pipe, so we just always return
 
1215
                server._wait_for_bytes_with_timeout(0.01)
 
1216
            else:
 
1217
                self.assertRaises(errors.ConnectionTimeout,
 
1218
                                  server._wait_for_bytes_with_timeout, 0.01)
 
1219
            os.close(w_client)
 
1220
            data = server.read_bytes(5)
 
1221
            self.assertEqual('', data)
 
1222
 
 
1223
    def test_pipe_wait_for_bytes_no_fileno(self):
 
1224
        server, _ = self.create_pipe_context('', None)
 
1225
        # Our file doesn't support polling, so we should always just return
 
1226
        # 'you have data to consume.
 
1227
        server._wait_for_bytes_with_timeout(0.01)
 
1228
 
937
1229
 
938
1230
class TestGetProtocolFactoryForBytes(tests.TestCase):
939
1231
    """_get_protocol_factory_for_bytes identifies the protocol factory a server
969
1261
 
970
1262
class TestSmartTCPServer(tests.TestCase):
971
1263
 
 
1264
    def make_server(self):
 
1265
        """Create a SmartTCPServer that we can exercise.
 
1266
 
 
1267
        Note: we don't use SmartTCPServer_for_testing because the testing
 
1268
        version overrides lots of functionality like 'serve', and we want to
 
1269
        test the raw service.
 
1270
 
 
1271
        This will start the server in another thread, and wait for it to
 
1272
        indicate it has finished starting up.
 
1273
 
 
1274
        :return: (server, server_thread)
 
1275
        """
 
1276
        t = _mod_transport.get_transport_from_url('memory:///')
 
1277
        server = _mod_server.SmartTCPServer(t, client_timeout=4.0)
 
1278
        server._ACCEPT_TIMEOUT = 0.1
 
1279
        # We don't use 'localhost' because that might be an IPv6 address.
 
1280
        server.start_server('127.0.0.1', 0)
 
1281
        server_thread = threading.Thread(target=server.serve,
 
1282
                                         args=(self.id(),))
 
1283
        server_thread.start()
 
1284
        # Ensure this gets called at some point
 
1285
        self.addCleanup(server._stop_gracefully)
 
1286
        server._started.wait()
 
1287
        return server, server_thread
 
1288
 
 
1289
    def ensure_client_disconnected(self, client_sock):
 
1290
        """Ensure that a socket is closed, discarding all errors."""
 
1291
        try:
 
1292
            client_sock.close()
 
1293
        except Exception:
 
1294
            pass
 
1295
 
 
1296
    def connect_to_server(self, server):
 
1297
        """Create a client socket that can talk to the server."""
 
1298
        client_sock = socket.socket()
 
1299
        server_info = server._server_socket.getsockname()
 
1300
        client_sock.connect(server_info)
 
1301
        self.addCleanup(self.ensure_client_disconnected, client_sock)
 
1302
        return client_sock
 
1303
 
 
1304
    def connect_to_server_and_hangup(self, server):
 
1305
        """Connect to the server, and then hang up.
 
1306
        That way it doesn't sit waiting for 'accept()' to timeout.
 
1307
        """
 
1308
        # If the server has already signaled that the socket is closed, we
 
1309
        # don't need to try to connect to it. Not being set, though, the server
 
1310
        # might still close the socket while we try to connect to it. So we
 
1311
        # still have to catch the exception.
 
1312
        if server._stopped.isSet():
 
1313
            return
 
1314
        try:
 
1315
            client_sock = self.connect_to_server(server)
 
1316
            client_sock.close()
 
1317
        except socket.error, e:
 
1318
            # If the server has hung up already, that is fine.
 
1319
            pass
 
1320
 
 
1321
    def say_hello(self, client_sock):
 
1322
        """Send the 'hello' smart RPC, and expect the response."""
 
1323
        client_sock.send('hello\n')
 
1324
        self.assertEqual('ok\x012\n', client_sock.recv(5))
 
1325
 
 
1326
    def shutdown_server_cleanly(self, server, server_thread):
 
1327
        server._stop_gracefully()
 
1328
        self.connect_to_server_and_hangup(server)
 
1329
        server._stopped.wait()
 
1330
        server._fully_stopped.wait()
 
1331
        server_thread.join()
 
1332
 
972
1333
    def test_get_error_unexpected(self):
973
1334
        """Error reported by server with no specific representation"""
974
1335
        self.overrideEnv('BZR_NO_SMART_VFS', None)
992
1353
                                t.get, 'something')
993
1354
        self.assertContainsRe(str(err), 'some random exception')
994
1355
 
 
1356
    def test_propagates_timeout(self):
 
1357
        server = _mod_server.SmartTCPServer(None, client_timeout=1.23)
 
1358
        server_sock, client_sock = portable_socket_pair()
 
1359
        handler = server._make_handler(server_sock)
 
1360
        self.assertEqual(1.23, handler._client_timeout)
 
1361
 
 
1362
    def test_serve_conn_tracks_connections(self):
 
1363
        server = _mod_server.SmartTCPServer(None, client_timeout=4.0)
 
1364
        server_sock, client_sock = portable_socket_pair()
 
1365
        server.serve_conn(server_sock, '-%s' % (self.id(),))
 
1366
        self.assertEqual(1, len(server._active_connections))
 
1367
        # We still want to talk on the connection. Polling should indicate it
 
1368
        # is still active.
 
1369
        server._poll_active_connections()
 
1370
        self.assertEqual(1, len(server._active_connections))
 
1371
        # Closing the socket will end the active thread, and polling will
 
1372
        # notice and remove it from the active set.
 
1373
        client_sock.close()
 
1374
        server._poll_active_connections(0.1)
 
1375
        self.assertEqual(0, len(server._active_connections))
 
1376
 
 
1377
    def test_serve_closes_out_finished_connections(self):
 
1378
        server, server_thread = self.make_server()
 
1379
        # The server is started, connect to it.
 
1380
        client_sock = self.connect_to_server(server)
 
1381
        # We send and receive on the connection, so that we know the
 
1382
        # server-side has seen the connect, and started handling the
 
1383
        # results.
 
1384
        self.say_hello(client_sock)
 
1385
        self.assertEqual(1, len(server._active_connections))
 
1386
        # Grab a handle to the thread that is processing our request
 
1387
        _, server_side_thread = server._active_connections[0]
 
1388
        # Close the connection, ask the server to stop, and wait for the
 
1389
        # server to stop, as well as the thread that was servicing the
 
1390
        # client request.
 
1391
        client_sock.close()
 
1392
        # Wait for the server-side request thread to notice we are closed.
 
1393
        server_side_thread.join()
 
1394
        # Stop the server, it should notice the connection has finished.
 
1395
        self.shutdown_server_cleanly(server, server_thread)
 
1396
        # The server should have noticed that all clients are gone before
 
1397
        # exiting.
 
1398
        self.assertEqual(0, len(server._active_connections))
 
1399
 
 
1400
    def test_serve_reaps_finished_connections(self):
 
1401
        server, server_thread = self.make_server()
 
1402
        client_sock1 = self.connect_to_server(server)
 
1403
        # We send and receive on the connection, so that we know the
 
1404
        # server-side has seen the connect, and started handling the
 
1405
        # results.
 
1406
        self.say_hello(client_sock1)
 
1407
        server_handler1, server_side_thread1 = server._active_connections[0]
 
1408
        client_sock1.close()
 
1409
        server_side_thread1.join()
 
1410
        # By waiting until the first connection is fully done, the server
 
1411
        # should notice after another connection that the first has finished.
 
1412
        client_sock2 = self.connect_to_server(server)
 
1413
        self.say_hello(client_sock2)
 
1414
        server_handler2, server_side_thread2 = server._active_connections[-1]
 
1415
        # There is a race condition. We know that client_sock2 has been
 
1416
        # registered, but not that _poll_active_connections has been called. We
 
1417
        # know that it will be called before the server will accept a new
 
1418
        # connection, however. So connect one more time, and assert that we
 
1419
        # either have 1 or 2 active connections (never 3), and that the 'first'
 
1420
        # connection is not connection 1
 
1421
        client_sock3 = self.connect_to_server(server)
 
1422
        self.say_hello(client_sock3)
 
1423
        # Copy the list, so we don't have it mutating behind our back
 
1424
        conns = list(server._active_connections)
 
1425
        self.assertEqual(2, len(conns))
 
1426
        self.assertNotEqual((server_handler1, server_side_thread1), conns[0])
 
1427
        self.assertEqual((server_handler2, server_side_thread2), conns[0])
 
1428
        client_sock2.close()
 
1429
        client_sock3.close()
 
1430
        self.shutdown_server_cleanly(server, server_thread)
 
1431
 
 
1432
    def test_graceful_shutdown_waits_for_clients_to_stop(self):
 
1433
        server, server_thread = self.make_server()
 
1434
        # We need something big enough that it won't fit in a single recv. So
 
1435
        # the server thread gets blocked writing content to the client until we
 
1436
        # finish reading on the client.
 
1437
        server.backing_transport.put_bytes('bigfile',
 
1438
            'a'*1024*1024)
 
1439
        client_sock = self.connect_to_server(server)
 
1440
        self.say_hello(client_sock)
 
1441
        _, server_side_thread = server._active_connections[0]
 
1442
        # Start the RPC, but don't finish reading the response
 
1443
        client_medium = medium.SmartClientAlreadyConnectedSocketMedium(
 
1444
            'base', client_sock)
 
1445
        client_client = client._SmartClient(client_medium)
 
1446
        resp, response_handler = client_client.call_expecting_body('get',
 
1447
            'bigfile')
 
1448
        self.assertEqual(('ok',), resp)
 
1449
        # Ask the server to stop gracefully, and wait for it.
 
1450
        server._stop_gracefully()
 
1451
        self.connect_to_server_and_hangup(server)
 
1452
        server._stopped.wait()
 
1453
        # It should not be accepting another connection.
 
1454
        self.assertRaises(socket.error, self.connect_to_server, server)
 
1455
        # It should also not be fully stopped
 
1456
        server._fully_stopped.wait(0.01)
 
1457
        self.assertFalse(server._fully_stopped.isSet())
 
1458
        response_handler.read_body_bytes()
 
1459
        client_sock.close()
 
1460
        server_side_thread.join()
 
1461
        server_thread.join()
 
1462
        self.assertTrue(server._fully_stopped.isSet())
 
1463
        log = self.get_log()
 
1464
        self.assertThat(log, DocTestMatches("""\
 
1465
    INFO  Requested to stop gracefully
 
1466
... Stopping SmartServerSocketStreamMedium(client=('127.0.0.1', ...
 
1467
    INFO  Waiting for 1 client(s) to finish
 
1468
""", flags=doctest.ELLIPSIS|doctest.REPORT_UDIFF))
 
1469
 
 
1470
    def test_stop_gracefully_tells_handlers_to_stop(self):
 
1471
        server, server_thread = self.make_server()
 
1472
        client_sock = self.connect_to_server(server)
 
1473
        self.say_hello(client_sock)
 
1474
        server_handler, server_side_thread = server._active_connections[0]
 
1475
        self.assertFalse(server_handler.finished)
 
1476
        server._stop_gracefully()
 
1477
        self.assertTrue(server_handler.finished)
 
1478
        client_sock.close()
 
1479
        self.connect_to_server_and_hangup(server)
 
1480
        server_thread.join()
 
1481
 
995
1482
 
996
1483
class SmartTCPTests(tests.TestCase):
997
1484
    """Tests for connection/end to end behaviour using the TCP server.
1015
1502
            mem_server.start_server()
1016
1503
            self.addCleanup(mem_server.stop_server)
1017
1504
            self.permit_url(mem_server.get_url())
1018
 
            self.backing_transport = transport.get_transport_from_url(
 
1505
            self.backing_transport = _mod_transport.get_transport_from_url(
1019
1506
                mem_server.get_url())
1020
1507
        else:
1021
1508
            self.backing_transport = backing_transport
1022
1509
        if readonly:
1023
1510
            self.real_backing_transport = self.backing_transport
1024
 
            self.backing_transport = transport.get_transport_from_url(
 
1511
            self.backing_transport = _mod_transport.get_transport_from_url(
1025
1512
                "readonly+" + self.backing_transport.abspath('.'))
1026
 
        self.server = server.SmartTCPServer(self.backing_transport)
 
1513
        self.server = _mod_server.SmartTCPServer(self.backing_transport,
 
1514
                                                 client_timeout=4.0)
1027
1515
        self.server.start_server('127.0.0.1', 0)
1028
1516
        self.server.start_background_thread('-' + self.id())
1029
1517
        self.transport = remote.RemoteTCPTransport(self.server.get_url())
1163
1651
    def test_server_started_hook_memory(self):
1164
1652
        """The server_started hook fires when the server is started."""
1165
1653
        self.hook_calls = []
1166
 
        server.SmartTCPServer.hooks.install_named_hook('server_started',
 
1654
        _mod_server.SmartTCPServer.hooks.install_named_hook('server_started',
1167
1655
            self.capture_server_call, None)
1168
1656
        self.start_server()
1169
1657
        # at this point, the server will be starting a thread up.
1177
1665
    def test_server_started_hook_file(self):
1178
1666
        """The server_started hook fires when the server is started."""
1179
1667
        self.hook_calls = []
1180
 
        server.SmartTCPServer.hooks.install_named_hook('server_started',
 
1668
        _mod_server.SmartTCPServer.hooks.install_named_hook('server_started',
1181
1669
            self.capture_server_call, None)
1182
1670
        self.start_server(
1183
 
            backing_transport=transport.get_transport_from_path("."))
 
1671
            backing_transport=_mod_transport.get_transport_from_path("."))
1184
1672
        # at this point, the server will be starting a thread up.
1185
1673
        # there is no indicator at the moment, so bodge it by doing a request.
1186
1674
        self.transport.has('.')
1194
1682
    def test_server_stopped_hook_simple_memory(self):
1195
1683
        """The server_stopped hook fires when the server is stopped."""
1196
1684
        self.hook_calls = []
1197
 
        server.SmartTCPServer.hooks.install_named_hook('server_stopped',
 
1685
        _mod_server.SmartTCPServer.hooks.install_named_hook('server_stopped',
1198
1686
            self.capture_server_call, None)
1199
1687
        self.start_server()
1200
1688
        result = [([self.backing_transport.base], self.transport.base)]
1211
1699
    def test_server_stopped_hook_simple_file(self):
1212
1700
        """The server_stopped hook fires when the server is stopped."""
1213
1701
        self.hook_calls = []
1214
 
        server.SmartTCPServer.hooks.install_named_hook('server_stopped',
 
1702
        _mod_server.SmartTCPServer.hooks.install_named_hook('server_stopped',
1215
1703
            self.capture_server_call, None)
1216
1704
        self.start_server(
1217
 
            backing_transport=transport.get_transport_from_path("."))
 
1705
            backing_transport=_mod_transport.get_transport_from_path("."))
1218
1706
        result = [(
1219
1707
            [self.backing_transport.base, self.backing_transport.external_url()]
1220
1708
            , self.transport.base)]
1356
1844
class RemoteTransportRegistration(tests.TestCase):
1357
1845
 
1358
1846
    def test_registration(self):
1359
 
        t = transport.get_transport_from_url('bzr+ssh://example.com/path')
 
1847
        t = _mod_transport.get_transport_from_url('bzr+ssh://example.com/path')
1360
1848
        self.assertIsInstance(t, remote.RemoteSSHTransport)
1361
1849
        self.assertEqual('example.com', t._parsed_url.host)
1362
1850
 
1363
1851
    def test_bzr_https(self):
1364
1852
        # https://bugs.launchpad.net/bzr/+bug/128456
1365
 
        t = transport.get_transport_from_url('bzr+https://example.com/path')
 
1853
        t = _mod_transport.get_transport_from_url('bzr+https://example.com/path')
1366
1854
        self.assertIsInstance(t, remote.RemoteHTTPTransport)
1367
1855
        self.assertStartsWith(
1368
1856
            t._http_transport.base,
2545
3033
        from_server = StringIO()
2546
3034
        transport = memory.MemoryTransport('memory:///')
2547
3035
        server = medium.SmartServerPipeStreamMedium(
2548
 
            to_server, from_server, transport)
 
3036
            to_server, from_server, transport, timeout=4.0)
2549
3037
        proto = server._build_protocol()
2550
3038
        message_handler = proto.message_handler
2551
3039
        server._serve_one_request(proto)
2796
3284
            'e', # end
2797
3285
            output.getvalue())
2798
3286
 
 
3287
    def test_records_start_of_body_stream(self):
 
3288
        requester, output = self.make_client_encoder_and_output()
 
3289
        requester.set_headers({})
 
3290
        in_stream = [False]
 
3291
        def stream_checker():
 
3292
            self.assertTrue(requester.body_stream_started)
 
3293
            in_stream[0] = True
 
3294
            yield 'content'
 
3295
        flush_called = []
 
3296
        orig_flush = requester.flush
 
3297
        def tracked_flush():
 
3298
            flush_called.append(in_stream[0])
 
3299
            if in_stream[0]:
 
3300
                self.assertTrue(requester.body_stream_started)
 
3301
            else:
 
3302
                self.assertFalse(requester.body_stream_started)
 
3303
            return orig_flush()
 
3304
        requester.flush = tracked_flush
 
3305
        requester.call_with_body_stream(('one arg',), stream_checker())
 
3306
        self.assertEqual(
 
3307
            'bzr message 3 (bzr 1.6)\n' # protocol version
 
3308
            '\x00\x00\x00\x02de' # headers
 
3309
            's\x00\x00\x00\x0bl7:one arge' # args
 
3310
            'b\x00\x00\x00\x07content' # body
 
3311
            'e', output.getvalue())
 
3312
        self.assertEqual([False, True, True], flush_called)
 
3313
 
2799
3314
 
2800
3315
class StubMediumRequest(object):
2801
3316
    """A stub medium request that tracks the number of times accept_bytes is
3221
3736
        # encoder.
3222
3737
 
3223
3738
 
 
3739
class Test_SmartClientRequest(tests.TestCase):
    """Tests for how client._SmartClientRequest sends and retries requests."""

    def make_client_with_failing_medium(self, fail_at_write=True, response=''):
        """Build a _SmartClient over an SSH medium that uses a stub vendor.

        :param fail_at_write: passed straight through to
            FirstRejectedStringIOSSHVendor.
        :param response: initial content of the StringIO handed to the vendor
            as its response file.
        :return: (output, vendor, smart_client), where output is the StringIO
            handed to the vendor as its output file.
        """
        response_io = StringIO(response)
        output = StringIO()
        vendor = FirstRejectedStringIOSSHVendor(response_io, output,
                    fail_at_write=fail_at_write)
        ssh_params = medium.SSHParams('a host', 'a port', 'a user', 'a pass')
        client_medium = medium.SmartSSHClientMedium('base', ssh_params, vendor)
        smart_client = client._SmartClient(client_medium, headers={})
        return output, vendor, smart_client

    def make_response(self, args, body=None, body_stream=None):
        """Return the bytes of a successful protocol 3 response.

        :param args: the response args.
        :param body: optional response body.
        :param body_stream: optional response body stream.
        :return: everything ProtocolThreeResponder wrote for the response.
        """
        response_io = StringIO()
        response = _mod_request.SuccessfulSmartServerResponse(args, body=body,
            body_stream=body_stream)
        responder = protocol.ProtocolThreeResponder(response_io.write)
        responder.send_response(response)
        return response_io.getvalue()

    def test__call_doesnt_retry_append(self):
        # An 'append' request that hits a read failure must surface
        # ConnectionReset rather than be sent again.
        response = self.make_response(('appended', '8'))
        output, vendor, smart_client = self.make_client_with_failing_medium(
            fail_at_write=False, response=response)
        smart_request = client._SmartClientRequest(smart_client, 'append',
            ('foo', ''), body='content\n')
        self.assertRaises(errors.ConnectionReset, smart_request._call, 3)

    def test__call_retries_get_bytes(self):
        # A 'get' request in the same situation is expected to succeed and
        # deliver the prepared response.
        response = self.make_response(('ok',), 'content\n')
        output, vendor, smart_client = self.make_client_with_failing_medium(
            fail_at_write=False, response=response)
        smart_request = client._SmartClientRequest(smart_client, 'get',
            ('foo',))
        response, response_handler = smart_request._call(3)
        self.assertEqual(('ok',), response)
        self.assertEqual('content\n', response_handler.read_body_bytes())

    def test__call_noretry_get_bytes(self):
        # With the 'noretry' debug flag set, the same 'get' request must
        # raise ConnectionReset instead.
        debug.debug_flags.add('noretry')
        response = self.make_response(('ok',), 'content\n')
        output, vendor, smart_client = self.make_client_with_failing_medium(
            fail_at_write=False, response=response)
        smart_request = client._SmartClientRequest(smart_client, 'get',
            ('foo',))
        self.assertRaises(errors.ConnectionReset, smart_request._call, 3)

    def test__send_no_retry_pipes(self):
        client_read, server_write = create_file_pipes()
        server_read, client_write = create_file_pipes()
        # server_read is closed explicitly below; the other three pipe ends
        # are not closed anywhere else in this test, so release them when
        # the test finishes.
        self.addCleanup(client_read.close)
        self.addCleanup(server_write.close)
        self.addCleanup(client_write.close)
        client_medium = medium.SmartSimplePipesClientMedium(client_read,
            client_write, base='/')
        smart_client = client._SmartClient(client_medium)
        smart_request = client._SmartClientRequest(smart_client,
            'hello', ())
        # Close the server side
        server_read.close()
        encoder, response_handler = smart_request._construct_protocol(3)
        self.assertRaises(errors.ConnectionReset,
            smart_request._send_no_retry, encoder)

    def test__send_read_response_sockets(self):
        listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # The listening socket is not closed anywhere else in this test.
        self.addCleanup(listen_sock.close)
        listen_sock.bind(('127.0.0.1', 0))
        listen_sock.listen(1)
        host, port = listen_sock.getsockname()
        client_medium = medium.SmartTCPClientMedium(host, port, '/')
        client_medium._ensure_connection()
        smart_client = client._SmartClient(client_medium)
        smart_request = client._SmartClientRequest(smart_client, 'hello', ())
        # Accept the connection, but don't actually talk to the client.
        server_sock, _ = listen_sock.accept()
        server_sock.close()
        # Sockets buffer and don't really notice that the server has closed the
        # connection until we try to read again.
        handler = smart_request._send(3)
        self.assertRaises(errors.ConnectionReset,
            handler.read_response_tuple, expect_body=False)

    def test__send_retries_on_write(self):
        output, vendor, smart_client = self.make_client_with_failing_medium()
        smart_request = client._SmartClientRequest(smart_client, 'hello', ())
        handler = smart_request._send(3)
        self.assertEqual('bzr message 3 (bzr 1.6)\n' # protocol
                         '\x00\x00\x00\x02de'   # empty headers
                         's\x00\x00\x00\tl5:helloee',
                         output.getvalue())
        # The vendor should have been asked to connect, close, and then
        # connect a second time.
        self.assertEqual(
            [('connect_ssh', 'a user', 'a pass', 'a host', 'a port',
              ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
             ('close',),
             ('connect_ssh', 'a user', 'a pass', 'a host', 'a port',
              ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
            ],
            vendor.calls)

    def test__send_doesnt_retry_read_failure(self):
        output, vendor, smart_client = self.make_client_with_failing_medium(
            fail_at_write=False)
        smart_request = client._SmartClientRequest(smart_client, 'hello', ())
        handler = smart_request._send(3)
        self.assertEqual('bzr message 3 (bzr 1.6)\n' # protocol
                         '\x00\x00\x00\x02de'   # empty headers
                         's\x00\x00\x00\tl5:helloee',
                         output.getvalue())
        # Only a single connect is expected: _send itself does not reconnect
        # here, and the failure shows up when the response is read.
        self.assertEqual(
            [('connect_ssh', 'a user', 'a pass', 'a host', 'a port',
              ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
            ],
            vendor.calls)
        self.assertRaises(errors.ConnectionReset, handler.read_response_tuple)

    def test__send_request_retries_body_stream_if_not_started(self):
        output, vendor, smart_client = self.make_client_with_failing_medium()
        smart_request = client._SmartClientRequest(smart_client, 'hello', (),
            body_stream=['a', 'b'])
        response_handler = smart_request._send(3)
        # We connect, get disconnected, and notice before consuming the stream,
        # so we try again one time and succeed.
        self.assertEqual(
            [('connect_ssh', 'a user', 'a pass', 'a host', 'a port',
              ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
             ('close',),
             ('connect_ssh', 'a user', 'a pass', 'a host', 'a port',
              ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
            ],
            vendor.calls)
        self.assertEqual('bzr message 3 (bzr 1.6)\n' # protocol
                         '\x00\x00\x00\x02de'   # empty headers
                         's\x00\x00\x00\tl5:helloe'
                         'b\x00\x00\x00\x01a'
                         'b\x00\x00\x00\x01b'
                         'e',
                         output.getvalue())

    def test__send_request_stops_if_body_started(self):
        # We intentionally use the python StringIO so that we can subclass it.
        from StringIO import StringIO
        response = StringIO()

        class FailAfterFirstWrite(StringIO):
            """Allow one 'write' call to pass, fail the rest"""
            def __init__(self):
                StringIO.__init__(self)
                self._first = True

            def write(self, s):
                if self._first:
                    self._first = False
                    return StringIO.write(self, s)
                raise IOError(errno.EINVAL, 'invalid file handle')
        output = FailAfterFirstWrite()

        vendor = FirstRejectedStringIOSSHVendor(response, output,
            fail_at_write=False)
        ssh_params = medium.SSHParams('a host', 'a port', 'a user', 'a pass')
        client_medium = medium.SmartSSHClientMedium('base', ssh_params, vendor)
        smart_client = client._SmartClient(client_medium, headers={})
        smart_request = client._SmartClientRequest(smart_client, 'hello', (),
            body_stream=['a', 'b'])
        self.assertRaises(errors.ConnectionReset, smart_request._send, 3)
        # We connect, and manage to get to the point that we start consuming
        # the body stream. The next write fails, so we just stop.
        self.assertEqual(
            [('connect_ssh', 'a user', 'a pass', 'a host', 'a port',
              ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
             ('close',),
            ],
            vendor.calls)
        self.assertEqual('bzr message 3 (bzr 1.6)\n' # protocol
                         '\x00\x00\x00\x02de'   # empty headers
                         's\x00\x00\x00\tl5:helloe',
                         output.getvalue())

    def test__send_disabled_retry(self):
        # With the 'noretry' debug flag set, a failed send must raise
        # ConnectionReset after one connect and one close, with no second
        # connection attempt.
        debug.debug_flags.add('noretry')
        output, vendor, smart_client = self.make_client_with_failing_medium()
        smart_request = client._SmartClientRequest(smart_client, 'hello', ())
        self.assertRaises(errors.ConnectionReset, smart_request._send, 3)
        self.assertEqual(
            [('connect_ssh', 'a user', 'a pass', 'a host', 'a port',
              ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
             ('close',),
            ],
            vendor.calls)
 
3924
 
 
3925
 
3224
3926
class LengthPrefixedBodyDecoder(tests.TestCase):
3225
3927
 
3226
3928
    # XXX: TODO: make accept_reading_trailer invoke translate_response or
3560
4262
        # still work correctly.
3561
4263
        base_transport = remote.RemoteHTTPTransport('bzr+http://host/%7Ea/b')
3562
4264
        new_transport = base_transport.clone('c')
3563
 
        self.assertEqual('bzr+http://host/%7Ea/b/c/', new_transport.base)
 
4265
        self.assertEqual(base_transport.base + 'c/', new_transport.base)
3564
4266
        self.assertEqual(
3565
4267
            'c/',
3566
4268
            new_transport._client.remote_path_from_transport(new_transport))