~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/tests/test_smart_transport.py

Merge bzr.dev.

Show diffs side-by-side

added

removed

Lines of Context:
18
18
 
19
19
# all of this deals with byte strings so this is safe
20
20
from cStringIO import StringIO
 
21
import doctest
21
22
import os
22
23
import socket
 
24
import sys
23
25
import threading
 
26
import time
 
27
 
 
28
from testtools.matchers import DocTestMatches
24
29
 
25
30
import bzrlib
26
31
from bzrlib import (
54
59
        )
55
60
 
56
61
 
 
62
def portable_socket_pair():
    """Return a pair of TCP sockets connected to each other.

    Unlike socket.socketpair, this should work on Windows.

    :return: (server_sock, client_sock), two connected stream sockets on the
        IPv4 loopback interface.  The temporary listening socket is always
        closed, and if connecting or accepting fails the half-built client
        socket is closed as well, so no descriptors leak on error.
    """
    listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        # Port 0 asks the OS for any free port.
        listen_sock.bind(('127.0.0.1', 0))
        listen_sock.listen(1)
        client_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            client_sock.connect(listen_sock.getsockname())
            server_sock, _ = listen_sock.accept()
        except BaseException:
            client_sock.close()
            raise
    finally:
        # Only needed to establish the pair; never handed to callers.
        listen_sock.close()
    return server_sock, client_sock
 
75
 
 
76
 
57
77
class StringIOSSHVendor(object):
58
78
    """A SSH vendor that uses StringIO to buffer writes and answer reads."""
59
79
 
618
638
        super(TestSmartServerStreamMedium, self).setUp()
619
639
        self.overrideEnv('BZR_NO_SMART_VFS', None)
620
640
 
621
 
    def portable_socket_pair(self):
622
 
        """Return a pair of TCP sockets connected to each other.
623
 
 
624
 
        Unlike socket.socketpair, this should work on Windows.
625
 
        """
626
 
        listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
627
 
        listen_sock.bind(('127.0.0.1', 0))
628
 
        listen_sock.listen(1)
629
 
        client_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
630
 
        client_sock.connect(listen_sock.getsockname())
631
 
        server_sock, addr = listen_sock.accept()
632
 
        listen_sock.close()
633
 
        return server_sock, client_sock
 
641
    def create_pipe_medium(self, to_server, from_server, transport,
 
642
                           timeout=4.0):
 
643
        """Create a new SmartServerPipeStreamMedium."""
 
644
        return medium.SmartServerPipeStreamMedium(to_server, from_server,
 
645
            transport, timeout=timeout)
 
646
 
 
647
    def create_pipe_context(self, to_server_bytes, transport):
 
648
        """Create a SmartServerSocketStreamMedium.
 
649
 
 
650
        This differes from create_pipe_medium, in that we initialize the
 
651
        request that is sent to the server, and return the StringIO class that
 
652
        will hold the response.
 
653
        """
 
654
        to_server = StringIO(to_server_bytes)
 
655
        from_server = StringIO()
 
656
        m = self.create_pipe_medium(to_server, from_server, transport)
 
657
        return m, from_server
 
658
 
 
659
    def create_socket_medium(self, server_sock, transport, timeout=4.0):
 
660
        """Initialize a new medium.SmartServerSocketStreamMedium."""
 
661
        return medium.SmartServerSocketStreamMedium(server_sock, transport,
 
662
            timeout=timeout)
 
663
 
 
664
    def create_socket_context(self, transport, timeout=4.0):
 
665
        """Create a new SmartServerSocketStreamMedium with default context.
 
666
 
 
667
        This will call portable_socket_pair and pass the server side to
 
668
        create_socket_medium along with transport.
 
669
        It then returns the client_sock and the server.
 
670
        """
 
671
        server_sock, client_sock = portable_socket_pair()
 
672
        server = self.create_socket_medium(server_sock, transport,
 
673
                                           timeout=timeout)
 
674
        return server, client_sock
634
675
 
635
676
    def test_smart_query_version(self):
636
677
        """Feed a canned query version to a server"""
637
678
        # wire-to-wire, using the whole stack
638
 
        to_server = StringIO('hello\n')
639
 
        from_server = StringIO()
640
679
        transport = local.LocalTransport(urlutils.local_path_to_url('/'))
641
 
        server = medium.SmartServerPipeStreamMedium(
642
 
            to_server, from_server, transport)
 
680
        server, from_server = self.create_pipe_context('hello\n', transport)
643
681
        smart_protocol = protocol.SmartServerRequestProtocolOne(transport,
644
682
                from_server.write)
645
683
        server._serve_one_request(smart_protocol)
649
687
    def test_response_to_canned_get(self):
650
688
        transport = memory.MemoryTransport('memory:///')
651
689
        transport.put_bytes('testfile', 'contents\nof\nfile\n')
652
 
        to_server = StringIO('get\001./testfile\n')
653
 
        from_server = StringIO()
654
 
        server = medium.SmartServerPipeStreamMedium(
655
 
            to_server, from_server, transport)
 
690
        server, from_server = self.create_pipe_context('get\001./testfile\n',
 
691
            transport)
656
692
        smart_protocol = protocol.SmartServerRequestProtocolOne(transport,
657
693
                from_server.write)
658
694
        server._serve_one_request(smart_protocol)
669
705
        # VFS requests use filenames, not raw UTF-8.
670
706
        hpss_path = urlutils.escape(utf8_filename)
671
707
        transport.put_bytes(utf8_filename, 'contents\nof\nfile\n')
672
 
        to_server = StringIO('get\001' + hpss_path + '\n')
673
 
        from_server = StringIO()
674
 
        server = medium.SmartServerPipeStreamMedium(
675
 
            to_server, from_server, transport)
 
708
        server, from_server = self.create_pipe_context(
 
709
                'get\001' + hpss_path + '\n', transport)
676
710
        smart_protocol = protocol.SmartServerRequestProtocolOne(transport,
677
711
                from_server.write)
678
712
        server._serve_one_request(smart_protocol)
684
718
 
685
719
    def test_pipe_like_stream_with_bulk_data(self):
686
720
        sample_request_bytes = 'command\n9\nbulk datadone\n'
687
 
        to_server = StringIO(sample_request_bytes)
688
 
        from_server = StringIO()
689
 
        server = medium.SmartServerPipeStreamMedium(
690
 
            to_server, from_server, None)
 
721
        server, from_server = self.create_pipe_context(
 
722
            sample_request_bytes, None)
691
723
        sample_protocol = SampleRequest(expected_bytes=sample_request_bytes)
692
724
        server._serve_one_request(sample_protocol)
693
725
        self.assertEqual('', from_server.getvalue())
696
728
 
697
729
    def test_socket_stream_with_bulk_data(self):
698
730
        sample_request_bytes = 'command\n9\nbulk datadone\n'
699
 
        server_sock, client_sock = self.portable_socket_pair()
700
 
        server = medium.SmartServerSocketStreamMedium(
701
 
            server_sock, None)
 
731
        server, client_sock = self.create_socket_context(None)
702
732
        sample_protocol = SampleRequest(expected_bytes=sample_request_bytes)
703
733
        client_sock.sendall(sample_request_bytes)
704
734
        server._serve_one_request(sample_protocol)
705
 
        server_sock.close()
 
735
        server._disconnect_client()
706
736
        self.assertEqual('', client_sock.recv(1))
707
737
        self.assertEqual(sample_request_bytes, sample_protocol.accepted_bytes)
708
738
        self.assertFalse(server.finished)
709
739
 
710
740
    def test_pipe_like_stream_shutdown_detection(self):
711
 
        to_server = StringIO('')
712
 
        from_server = StringIO()
713
 
        server = medium.SmartServerPipeStreamMedium(to_server, from_server, None)
 
741
        server, _ = self.create_pipe_context('', None)
714
742
        server._serve_one_request(SampleRequest('x'))
715
743
        self.assertTrue(server.finished)
716
744
 
717
745
    def test_socket_stream_shutdown_detection(self):
718
 
        server_sock, client_sock = self.portable_socket_pair()
 
746
        server, client_sock = self.create_socket_context(None)
719
747
        client_sock.close()
720
 
        server = medium.SmartServerSocketStreamMedium(
721
 
            server_sock, None)
722
748
        server._serve_one_request(SampleRequest('x'))
723
749
        self.assertTrue(server.finished)
724
750
 
735
761
        rest_of_request_bytes = 'lo\n'
736
762
        expected_response = (
737
763
            protocol.RESPONSE_VERSION_TWO + 'success\nok\x012\n')
738
 
        server_sock, client_sock = self.portable_socket_pair()
739
 
        server = medium.SmartServerSocketStreamMedium(
740
 
            server_sock, None)
 
764
        server, client_sock = self.create_socket_context(None)
741
765
        client_sock.sendall(incomplete_request_bytes)
742
766
        server_protocol = server._build_protocol()
743
767
        client_sock.sendall(rest_of_request_bytes)
744
768
        server._serve_one_request(server_protocol)
745
 
        server_sock.close()
 
769
        server._disconnect_client()
746
770
        self.assertEqual(expected_response, osutils.recv_all(client_sock, 50),
747
771
                         "Not a version 2 response to 'hello' request.")
748
772
        self.assertEqual('', client_sock.recv(1))
767
791
        to_server_w = os.fdopen(to_server_w, 'w', 0)
768
792
        from_server_r = os.fdopen(from_server_r, 'r', 0)
769
793
        from_server = os.fdopen(from_server, 'w', 0)
770
 
        server = medium.SmartServerPipeStreamMedium(
771
 
            to_server, from_server, None)
 
794
        server = self.create_pipe_medium(to_server, from_server, None)
772
795
        # Like test_socket_stream_incomplete_request, write an incomplete
773
796
        # request (that does not end in '\n') and build a protocol from it.
774
797
        to_server_w.write(incomplete_request_bytes)
789
812
        # _serve_one_request should still process both of them as if they had
790
813
        # been received separately.
791
814
        sample_request_bytes = 'command\n'
792
 
        to_server = StringIO(sample_request_bytes * 2)
793
 
        from_server = StringIO()
794
 
        server = medium.SmartServerPipeStreamMedium(
795
 
            to_server, from_server, None)
 
815
        server, from_server = self.create_pipe_context(
 
816
            sample_request_bytes * 2, None)
796
817
        first_protocol = SampleRequest(expected_bytes=sample_request_bytes)
797
818
        server._serve_one_request(first_protocol)
798
819
        self.assertEqual(0, first_protocol.next_read_size())
811
832
        # _serve_one_request should still process both of them as if they had
812
833
        # been received separately.
813
834
        sample_request_bytes = 'command\n'
814
 
        server_sock, client_sock = self.portable_socket_pair()
815
 
        server = medium.SmartServerSocketStreamMedium(
816
 
            server_sock, None)
 
835
        server, client_sock = self.create_socket_context(None)
817
836
        first_protocol = SampleRequest(expected_bytes=sample_request_bytes)
818
837
        # Put two whole requests on the wire.
819
838
        client_sock.sendall(sample_request_bytes * 2)
826
845
        stream_still_open = server._serve_one_request(second_protocol)
827
846
        self.assertEqual(sample_request_bytes, second_protocol.accepted_bytes)
828
847
        self.assertFalse(server.finished)
829
 
        server_sock.close()
 
848
        server._disconnect_client()
830
849
        self.assertEqual('', client_sock.recv(1))
831
850
 
832
851
    def test_pipe_like_stream_error_handling(self):
839
858
        def close():
840
859
            self.closed = True
841
860
        from_server.close = close
842
 
        server = medium.SmartServerPipeStreamMedium(
 
861
        server = self.create_pipe_medium(
843
862
            to_server, from_server, None)
844
863
        fake_protocol = ErrorRaisingProtocol(Exception('boom'))
845
864
        server._serve_one_request(fake_protocol)
848
867
        self.assertTrue(server.finished)
849
868
 
850
869
    def test_socket_stream_error_handling(self):
851
 
        server_sock, client_sock = self.portable_socket_pair()
852
 
        server = medium.SmartServerSocketStreamMedium(
853
 
            server_sock, None)
 
870
        server, client_sock = self.create_socket_context(None)
854
871
        fake_protocol = ErrorRaisingProtocol(Exception('boom'))
855
872
        server._serve_one_request(fake_protocol)
856
873
        # recv should not block, because the other end of the socket has been
859
876
        self.assertTrue(server.finished)
860
877
 
861
878
    def test_pipe_like_stream_keyboard_interrupt_handling(self):
862
 
        to_server = StringIO('')
863
 
        from_server = StringIO()
864
 
        server = medium.SmartServerPipeStreamMedium(
865
 
            to_server, from_server, None)
 
879
        server, from_server = self.create_pipe_context('', None)
866
880
        fake_protocol = ErrorRaisingProtocol(KeyboardInterrupt('boom'))
867
881
        self.assertRaises(
868
882
            KeyboardInterrupt, server._serve_one_request, fake_protocol)
869
883
        self.assertEqual('', from_server.getvalue())
870
884
 
871
885
    def test_socket_stream_keyboard_interrupt_handling(self):
872
 
        server_sock, client_sock = self.portable_socket_pair()
873
 
        server = medium.SmartServerSocketStreamMedium(
874
 
            server_sock, None)
 
886
        server, client_sock = self.create_socket_context(None)
875
887
        fake_protocol = ErrorRaisingProtocol(KeyboardInterrupt('boom'))
876
888
        self.assertRaises(
877
889
            KeyboardInterrupt, server._serve_one_request, fake_protocol)
878
 
        server_sock.close()
 
890
        server._disconnect_client()
879
891
        self.assertEqual('', client_sock.recv(1))
880
892
 
881
893
    def build_protocol_pipe_like(self, bytes):
882
 
        to_server = StringIO(bytes)
883
 
        from_server = StringIO()
884
 
        server = medium.SmartServerPipeStreamMedium(
885
 
            to_server, from_server, None)
 
894
        server, _ = self.create_pipe_context(bytes, None)
886
895
        return server._build_protocol()
887
896
 
888
897
    def build_protocol_socket(self, bytes):
889
 
        server_sock, client_sock = self.portable_socket_pair()
890
 
        server = medium.SmartServerSocketStreamMedium(
891
 
            server_sock, None)
 
898
        server, client_sock = self.create_socket_context(None)
892
899
        client_sock.sendall(bytes)
893
900
        client_sock.close()
894
901
        return server._build_protocol()
934
941
        server_protocol = self.build_protocol_socket('bzr request 2\n')
935
942
        self.assertProtocolTwo(server_protocol)
936
943
 
 
944
    def test__build_protocol_returns_if_stopping(self):
 
945
        # _build_protocol should notice that we are stopping, and return
 
946
        # without waiting for bytes from the client.
 
947
        server, client_sock = self.create_socket_context(None)
 
948
        server._stop_gracefully()
 
949
        self.assertIs(None, server._build_protocol())
 
950
 
 
951
    def test_socket_set_timeout(self):
 
952
        server, _ = self.create_socket_context(None, timeout=1.23)
 
953
        self.assertEqual(1.23, server._client_timeout)
 
954
 
 
955
    def test_pipe_set_timeout(self):
 
956
        server = self.create_pipe_medium(None, None, None,
 
957
            timeout=1.23)
 
958
        self.assertEqual(1.23, server._client_timeout)
 
959
 
 
960
    def test_socket_wait_for_bytes_with_timeout_with_data(self):
 
961
        server, client_sock = self.create_socket_context(None)
 
962
        client_sock.sendall('data\n')
 
963
        # This should not block or consume any actual content
 
964
        self.assertFalse(server._wait_for_bytes_with_timeout(0.1))
 
965
        data = server.read_bytes(5)
 
966
        self.assertEqual('data\n', data)
 
967
 
 
968
    def test_socket_wait_for_bytes_with_timeout_no_data(self):
 
969
        server, client_sock = self.create_socket_context(None)
 
970
        # This should timeout quickly, reporting that there wasn't any data
 
971
        self.assertRaises(errors.ConnectionTimeout,
 
972
                          server._wait_for_bytes_with_timeout, 0.01)
 
973
        client_sock.close()
 
974
        data = server.read_bytes(1)
 
975
        self.assertEqual('', data)
 
976
 
 
977
    def test_socket_wait_for_bytes_with_timeout_closed(self):
 
978
        server, client_sock = self.create_socket_context(None)
 
979
        # With the socket closed, this should return right away.
 
980
        # It seems select.select() returns that you *can* read on the socket,
 
981
        # even though it closed. Presumably as a way to tell it is closed?
 
982
        # Testing shows that without sock.close() this times-out failing the
 
983
        # test, but with it, it returns False immediately.
 
984
        client_sock.close()
 
985
        self.assertFalse(server._wait_for_bytes_with_timeout(10))
 
986
        data = server.read_bytes(1)
 
987
        self.assertEqual('', data)
 
988
 
 
989
    def test_socket_wait_for_bytes_with_shutdown(self):
 
990
        server, client_sock = self.create_socket_context(None)
 
991
        t = time.time()
 
992
        # Override the _timer functionality, so that time never increments,
 
993
        # this way, we can be sure we stopped because of the flag, and not
 
994
        # because of a timeout, etc.
 
995
        server._timer = lambda: t
 
996
        server._client_poll_timeout = 0.1
 
997
        server._stop_gracefully()
 
998
        server._wait_for_bytes_with_timeout(1.0)
 
999
 
 
1000
    def test_socket_serve_timeout_closes_socket(self):
 
1001
        server, client_sock = self.create_socket_context(None, timeout=0.1)
 
1002
        # This should timeout quickly, and then close the connection so that
 
1003
        # client_sock recv doesn't block.
 
1004
        server.serve()
 
1005
        self.assertEqual('', client_sock.recv(1))
 
1006
 
 
1007
    def test_pipe_wait_for_bytes_with_timeout_with_data(self):
 
1008
        # We intentionally use a real pipe here, so that we can 'select' on it.
 
1009
        # You can't select() on a StringIO
 
1010
        (r_server, w_client) = os.pipe()
 
1011
        self.addCleanup(os.close, w_client)
 
1012
        with os.fdopen(r_server, 'rb') as rf_server:
 
1013
            server = self.create_pipe_medium(
 
1014
                rf_server, None, None)
 
1015
            os.write(w_client, 'data\n')
 
1016
            # This should not block or consume any actual content
 
1017
            server._wait_for_bytes_with_timeout(0.1)
 
1018
            data = server.read_bytes(5)
 
1019
            self.assertEqual('data\n', data)
 
1020
 
 
1021
    def test_pipe_wait_for_bytes_with_timeout_no_data(self):
 
1022
        # We intentionally use a real pipe here, so that we can 'select' on it.
 
1023
        # You can't select() on a StringIO
 
1024
        (r_server, w_client) = os.pipe()
 
1025
        # We can't add an os.close cleanup here, because we need to control
 
1026
        # when the file handle gets closed ourselves.
 
1027
        with os.fdopen(r_server, 'rb') as rf_server:
 
1028
            server = self.create_pipe_medium(
 
1029
                rf_server, None, None)
 
1030
            if sys.platform == 'win32':
 
1031
                # Windows cannot select() on a pipe, so we just always return
 
1032
                server._wait_for_bytes_with_timeout(0.01)
 
1033
            else:
 
1034
                self.assertRaises(errors.ConnectionTimeout,
 
1035
                                  server._wait_for_bytes_with_timeout, 0.01)
 
1036
            os.close(w_client)
 
1037
            data = server.read_bytes(5)
 
1038
            self.assertEqual('', data)
 
1039
 
 
1040
    def test_pipe_wait_for_bytes_no_fileno(self):
 
1041
        server, _ = self.create_pipe_context('', None)
 
1042
        # Our file doesn't support polling, so we should always just return
 
1043
        # 'you have data to consume.
 
1044
        server._wait_for_bytes_with_timeout(0.01)
 
1045
 
937
1046
 
938
1047
class TestGetProtocolFactoryForBytes(tests.TestCase):
939
1048
    """_get_protocol_factory_for_bytes identifies the protocol factory a server
969
1078
 
970
1079
class TestSmartTCPServer(tests.TestCase):
971
1080
 
 
1081
    def make_server(self):
 
1082
        """Create a SmartTCPServer that we can exercise.
 
1083
 
 
1084
        Note: we don't use SmartTCPServer_for_testing because the testing
 
1085
        version overrides lots of functionality like 'serve', and we want to
 
1086
        test the raw service.
 
1087
 
 
1088
        This will start the server in another thread, and wait for it to
 
1089
        indicate it has finished starting up.
 
1090
 
 
1091
        :return: (server, server_thread)
 
1092
        """
 
1093
        t = _mod_transport.get_transport_from_url('memory:///')
 
1094
        server = _mod_server.SmartTCPServer(t, client_timeout=4.0)
 
1095
        server._ACCEPT_TIMEOUT = 0.1
 
1096
        # We don't use 'localhost' because that might be an IPv6 address.
 
1097
        server.start_server('127.0.0.1', 0)
 
1098
        server_thread = threading.Thread(target=server.serve,
 
1099
                                         args=(self.id(),))
 
1100
        server_thread.start()
 
1101
        # Ensure this gets called at some point
 
1102
        self.addCleanup(server._stop_gracefully)
 
1103
        server._started.wait()
 
1104
        return server, server_thread
 
1105
 
 
1106
    def ensure_client_disconnected(self, client_sock):
 
1107
        """Ensure that a socket is closed, discarding all errors."""
 
1108
        try:
 
1109
            client_sock.close()
 
1110
        except Exception:
 
1111
            pass
 
1112
 
 
1113
    def connect_to_server(self, server):
 
1114
        """Create a client socket that can talk to the server."""
 
1115
        client_sock = socket.socket()
 
1116
        server_info = server._server_socket.getsockname()
 
1117
        client_sock.connect(server_info)
 
1118
        self.addCleanup(self.ensure_client_disconnected, client_sock)
 
1119
        return client_sock
 
1120
 
 
1121
    def connect_to_server_and_hangup(self, server):
 
1122
        """Connect to the server, and then hang up.
 
1123
        That way it doesn't sit waiting for 'accept()' to timeout.
 
1124
        """
 
1125
        # If the server has already signaled that the socket is closed, we
 
1126
        # don't need to try to connect to it. Not being set, though, the server
 
1127
        # might still close the socket while we try to connect to it. So we
 
1128
        # still have to catch the exception.
 
1129
        if server._stopped.isSet():
 
1130
            return
 
1131
        try:
 
1132
            client_sock = self.connect_to_server(server)
 
1133
            client_sock.close()
 
1134
        except socket.error, e:
 
1135
            # If the server has hung up already, that is fine.
 
1136
            pass
 
1137
 
 
1138
    def say_hello(self, client_sock):
 
1139
        """Send the 'hello' smart RPC, and expect the response."""
 
1140
        client_sock.send('hello\n')
 
1141
        self.assertEqual('ok\x012\n', client_sock.recv(5))
 
1142
 
 
1143
    def shutdown_server_cleanly(self, server, server_thread):
 
1144
        server._stop_gracefully()
 
1145
        self.connect_to_server_and_hangup(server)
 
1146
        server._stopped.wait()
 
1147
        server._fully_stopped.wait()
 
1148
        server_thread.join()
 
1149
 
972
1150
    def test_get_error_unexpected(self):
973
1151
        """Error reported by server with no specific representation"""
974
1152
        self.overrideEnv('BZR_NO_SMART_VFS', None)
992
1170
                                t.get, 'something')
993
1171
        self.assertContainsRe(str(err), 'some random exception')
994
1172
 
 
1173
    def test_propagates_timeout(self):
 
1174
        server = _mod_server.SmartTCPServer(None, client_timeout=1.23)
 
1175
        server_sock, client_sock = portable_socket_pair()
 
1176
        handler = server._make_handler(server_sock)
 
1177
        self.assertEqual(1.23, handler._client_timeout)
 
1178
 
 
1179
    def test_serve_conn_tracks_connections(self):
 
1180
        server = _mod_server.SmartTCPServer(None, client_timeout=4.0)
 
1181
        server_sock, client_sock = portable_socket_pair()
 
1182
        server.serve_conn(server_sock, '-%s' % (self.id(),))
 
1183
        self.assertEqual(1, len(server._active_connections))
 
1184
        # We still want to talk on the connection. Polling should indicate it
 
1185
        # is still active.
 
1186
        server._poll_active_connections()
 
1187
        self.assertEqual(1, len(server._active_connections))
 
1188
        # Closing the socket will end the active thread, and polling will
 
1189
        # notice and remove it from the active set.
 
1190
        client_sock.close()
 
1191
        server._poll_active_connections(0.1)
 
1192
        self.assertEqual(0, len(server._active_connections))
 
1193
 
 
1194
    def test_serve_closes_out_finished_connections(self):
 
1195
        server, server_thread = self.make_server()
 
1196
        # The server is started, connect to it.
 
1197
        client_sock = self.connect_to_server(server)
 
1198
        # We send and receive on the connection, so that we know the
 
1199
        # server-side has seen the connect, and started handling the
 
1200
        # results.
 
1201
        self.say_hello(client_sock)
 
1202
        self.assertEqual(1, len(server._active_connections))
 
1203
        # Grab a handle to the thread that is processing our request
 
1204
        _, server_side_thread = server._active_connections[0]
 
1205
        # Close the connection, ask the server to stop, and wait for the
 
1206
        # server to stop, as well as the thread that was servicing the
 
1207
        # client request.
 
1208
        client_sock.close()
 
1209
        # Wait for the server-side request thread to notice we are closed.
 
1210
        server_side_thread.join()
 
1211
        # Stop the server, it should notice the connection has finished.
 
1212
        self.shutdown_server_cleanly(server, server_thread)
 
1213
        # The server should have noticed that all clients are gone before
 
1214
        # exiting.
 
1215
        self.assertEqual(0, len(server._active_connections))
 
1216
 
 
1217
    def test_serve_reaps_finished_connections(self):
 
1218
        server, server_thread = self.make_server()
 
1219
        client_sock1 = self.connect_to_server(server)
 
1220
        # We send and receive on the connection, so that we know the
 
1221
        # server-side has seen the connect, and started handling the
 
1222
        # results.
 
1223
        self.say_hello(client_sock1)
 
1224
        server_handler1, server_side_thread1 = server._active_connections[0]
 
1225
        client_sock1.close()
 
1226
        server_side_thread1.join()
 
1227
        # By waiting until the first connection is fully done, the server
 
1228
        # should notice after another connection that the first has finished.
 
1229
        client_sock2 = self.connect_to_server(server)
 
1230
        self.say_hello(client_sock2)
 
1231
        server_handler2, server_side_thread2 = server._active_connections[-1]
 
1232
        # There is a race condition. We know that client_sock2 has been
 
1233
        # registered, but not that _poll_active_connections has been called. We
 
1234
        # know that it will be called before the server will accept a new
 
1235
        # connection, however. So connect one more time, and assert that we
 
1236
        # either have 1 or 2 active connections (never 3), and that the 'first'
 
1237
        # connection is not connection 1
 
1238
        client_sock3 = self.connect_to_server(server)
 
1239
        self.say_hello(client_sock3)
 
1240
        # Copy the list, so we don't have it mutating behind our back
 
1241
        conns = list(server._active_connections)
 
1242
        self.assertEqual(2, len(conns))
 
1243
        self.assertNotEqual((server_handler1, server_side_thread1), conns[0])
 
1244
        self.assertEqual((server_handler2, server_side_thread2), conns[0])
 
1245
        client_sock2.close()
 
1246
        client_sock3.close()
 
1247
        self.shutdown_server_cleanly(server, server_thread)
 
1248
 
 
1249
    def test_graceful_shutdown_waits_for_clients_to_stop(self):
 
1250
        server, server_thread = self.make_server()
 
1251
        # We need something big enough that it won't fit in a single recv. So
 
1252
        # the server thread gets blocked writing content to the client until we
 
1253
        # finish reading on the client.
 
1254
        server.backing_transport.put_bytes('bigfile',
 
1255
            'a'*1024*1024)
 
1256
        client_sock = self.connect_to_server(server)
 
1257
        self.say_hello(client_sock)
 
1258
        _, server_side_thread = server._active_connections[0]
 
1259
        # Start the RPC, but don't finish reading the response
 
1260
        client_medium = medium.SmartClientAlreadyConnectedSocketMedium(
 
1261
            'base', client_sock)
 
1262
        client_client = client._SmartClient(client_medium)
 
1263
        resp, response_handler = client_client.call_expecting_body('get',
 
1264
            'bigfile')
 
1265
        self.assertEqual(('ok',), resp)
 
1266
        # Ask the server to stop gracefully, and wait for it.
 
1267
        server._stop_gracefully()
 
1268
        self.connect_to_server_and_hangup(server)
 
1269
        server._stopped.wait()
 
1270
        # It should not be accepting another connection.
 
1271
        self.assertRaises(socket.error, self.connect_to_server, server)
 
1272
        # It should also not be fully stopped
 
1273
        server._fully_stopped.wait(0.01)
 
1274
        self.assertFalse(server._fully_stopped.isSet())
 
1275
        response_handler.read_body_bytes()
 
1276
        client_sock.close()
 
1277
        server_side_thread.join()
 
1278
        server_thread.join()
 
1279
        self.assertTrue(server._fully_stopped.isSet())
 
1280
        log = self.get_log()
 
1281
        self.assertThat(log, DocTestMatches("""\
 
1282
    INFO  Requested to stop gracefully
 
1283
... Stopping SmartServerSocketStreamMedium(client=('127.0.0.1', ...
 
1284
    INFO  Waiting for 1 client(s) to finish
 
1285
""", flags=doctest.ELLIPSIS|doctest.REPORT_UDIFF))
 
1286
 
 
1287
    def test_stop_gracefully_tells_handlers_to_stop(self):
 
1288
        server, server_thread = self.make_server()
 
1289
        client_sock = self.connect_to_server(server)
 
1290
        self.say_hello(client_sock)
 
1291
        server_handler, server_side_thread = server._active_connections[0]
 
1292
        self.assertFalse(server_handler.finished)
 
1293
        server._stop_gracefully()
 
1294
        self.assertTrue(server_handler.finished)
 
1295
        client_sock.close()
 
1296
        self.connect_to_server_and_hangup(server)
 
1297
        server_thread.join()
 
1298
 
995
1299
 
996
1300
class SmartTCPTests(tests.TestCase):
997
1301
    """Tests for connection/end to end behaviour using the TCP server.
1023
1327
            self.real_backing_transport = self.backing_transport
1024
1328
            self.backing_transport = _mod_transport.get_transport_from_url(
1025
1329
                "readonly+" + self.backing_transport.abspath('.'))
1026
 
        self.server = _mod_server.SmartTCPServer(self.backing_transport)
 
1330
        self.server = _mod_server.SmartTCPServer(self.backing_transport,
 
1331
                                                 client_timeout=4.0)
1027
1332
        self.server.start_server('127.0.0.1', 0)
1028
1333
        self.server.start_background_thread('-' + self.id())
1029
1334
        self.transport = remote.RemoteTCPTransport(self.server.get_url())
2545
2850
        from_server = StringIO()
2546
2851
        transport = memory.MemoryTransport('memory:///')
2547
2852
        server = medium.SmartServerPipeStreamMedium(
2548
 
            to_server, from_server, transport)
 
2853
            to_server, from_server, transport, timeout=4.0)
2549
2854
        proto = server._build_protocol()
2550
2855
        message_handler = proto.message_handler
2551
2856
        server._serve_one_request(proto)