~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/tests/test_test_server.py

Merge bzr.dev into cleanup

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# Copyright (C) 2010 Canonical Ltd
 
2
#
 
3
# This program is free software; you can redistribute it and/or modify
 
4
# it under the terms of the GNU General Public License as published by
 
5
# the Free Software Foundation; either version 2 of the License, or
 
6
# (at your option) any later version.
 
7
#
 
8
# This program is distributed in the hope that it will be useful,
 
9
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
10
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
11
# GNU General Public License for more details.
 
12
#
 
13
# You should have received a copy of the GNU General Public License
 
14
# along with this program; if not, write to the Free Software
 
15
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 
16
 
 
17
import errno
 
18
import socket
 
19
import SocketServer
 
20
import threading
 
21
 
 
22
from bzrlib import (
 
23
    osutils,
 
24
    tests,
 
25
    )
 
26
from bzrlib.tests import test_server
 
27
 
 
28
 
 
29
def load_tests(basic_tests, module, loader):
 
30
    suite = loader.suiteClass()
 
31
    server_tests, remaining_tests = tests.split_suite_by_condition(
 
32
        basic_tests, tests.condition_isinstance(TestTCPServerInAThread))
 
33
    server_scenarios = [ (name, {'server_class': getattr(test_server, name)})
 
34
                         for name in
 
35
                         ('TestingTCPServer', 'TestingThreadingTCPServer')]
 
36
    tests.multiply_tests(server_tests, server_scenarios, suite)
 
37
    suite.addTest(remaining_tests)
 
38
    return suite
 
39
 
 
40
 
 
41
class TestThreadWithException(tests.TestCase):
 
42
 
 
43
    def test_start_and_join_smoke_test(self):
 
44
        def do_nothing():
 
45
            pass
 
46
 
 
47
        tt = test_server.ThreadWithException(target=do_nothing)
 
48
        tt.start()
 
49
        tt.join()
 
50
 
 
51
    def test_exception_is_re_raised(self):
 
52
        class MyException(Exception):
 
53
            pass
 
54
 
 
55
        def raise_my_exception():
 
56
            raise MyException()
 
57
 
 
58
        tt = test_server.ThreadWithException(target=raise_my_exception)
 
59
        tt.start()
 
60
        self.assertRaises(MyException, tt.join)
 
61
 
 
62
    def test_join_when_no_exception(self):
 
63
        resume = threading.Event()
 
64
        class MyException(Exception):
 
65
            pass
 
66
 
 
67
        def raise_my_exception():
 
68
            # Wait for the test to tell us to resume
 
69
            resume.wait()
 
70
            # Now we can raise
 
71
            raise MyException()
 
72
 
 
73
        tt = test_server.ThreadWithException(target=raise_my_exception)
 
74
        tt.start()
 
75
        tt.join(timeout=0)
 
76
        self.assertIs(None, tt.exception)
 
77
        resume.set()
 
78
        self.assertRaises(MyException, tt.join)
 
79
 
 
80
 
 
81
class TCPClient(object):
 
82
 
 
83
    def __init__(self):
 
84
        self.sock = None
 
85
 
 
86
    def connect(self, addr):
 
87
        if self.sock is not None:
 
88
            raise AssertionError('Already connected to %r'
 
89
                                 % (self.sock.getsockname(),))
 
90
        self.sock = osutils.connect_socket(addr)
 
91
 
 
92
    def disconnect(self):
 
93
        if self.sock is not None:
 
94
            try:
 
95
                self.sock.shutdown(socket.SHUT_RDWR)
 
96
                self.sock.close()
 
97
            except socket.error, e:
 
98
                if e[0] in (errno.EBADF, errno.ENOTCONN):
 
99
                    # Right, the socket is already down
 
100
                    pass
 
101
                else:
 
102
                    raise
 
103
            self.sock = None
 
104
 
 
105
    def write(self, s):
 
106
        return self.sock.sendall(s)
 
107
 
 
108
    def read(self, bufsize=4096):
 
109
        return self.sock.recv(bufsize)
 
110
 
 
111
 
 
112
class TCPConnectionHandler(SocketServer.StreamRequestHandler):
 
113
 
 
114
    def handle(self):
 
115
        self.done = False
 
116
        self.handle_connection()
 
117
        while not self.done:
 
118
            self.handle_connection()
 
119
 
 
120
    def handle_connection(self):
 
121
        req = self.rfile.readline()
 
122
        if not req:
 
123
            self.done = True
 
124
        elif req == 'ping\n':
 
125
            self.wfile.write('pong\n')
 
126
        else:
 
127
            raise ValueError('[%s] not understood' % req)
 
128
 
 
129
 
 
130
class TestTCPServerInAThread(tests.TestCase):
 
131
 
 
132
    # Set by load_tests()
 
133
    server_class = None
 
134
 
 
135
    def get_server(self, server_class=None, connection_handler_class=None):
 
136
        if server_class is not None:
 
137
            self.server_class = server_class
 
138
        if connection_handler_class is None:
 
139
            connection_handler_class = TCPConnectionHandler
 
140
        server =  test_server.TestingTCPServerInAThread(
 
141
            ('localhost', 0), self.server_class, connection_handler_class)
 
142
        server.start_server()
 
143
        self.addCleanup(server.stop_server)
 
144
        return server
 
145
 
 
146
    def get_client(self):
 
147
        client = TCPClient()
 
148
        self.addCleanup(client.disconnect)
 
149
        return client
 
150
 
 
151
    def get_server_connection(self, server, conn_rank):
 
152
        return server.server.clients[conn_rank]
 
153
 
 
154
    def assertClientAddr(self, client, server, conn_rank):
 
155
        conn = self.get_server_connection(server, conn_rank)
 
156
        self.assertEquals(client.sock.getsockname(), conn[1])
 
157
 
 
158
    def test_start_stop(self):
 
159
        server = self.get_server()
 
160
        client = self.get_client()
 
161
        server.stop_server()
 
162
        # since the server doesn't accept connections anymore attempting to
 
163
        # connect should fail
 
164
        client = self.get_client()
 
165
        self.assertRaises(socket.error,
 
166
                          client.connect, (server.host, server.port))
 
167
 
 
168
    def test_client_talks_server_respond(self):
 
169
        server = self.get_server()
 
170
        client = self.get_client()
 
171
        client.connect((server.host, server.port))
 
172
        self.assertIs(None, client.write('ping\n'))
 
173
        resp = client.read()
 
174
        self.assertClientAddr(client, server, 0)
 
175
        self.assertEquals('pong\n', resp)
 
176
 
 
177
    def test_server_fails_to_start(self):
 
178
        class CantStart(Exception):
 
179
            pass
 
180
 
 
181
        class CantStartServer(test_server.TestingTCPServer):
 
182
 
 
183
            def server_bind(self):
 
184
                raise CantStart()
 
185
 
 
186
        # The exception is raised in the main thread
 
187
        self.assertRaises(CantStart,
 
188
                          self.get_server, server_class=CantStartServer)
 
189
 
 
190
    def test_server_fails_while_serving_or_stopping(self):
 
191
        class CantConnect(Exception):
 
192
            pass
 
193
 
 
194
        class FailingConnectionHandler(TCPConnectionHandler):
 
195
 
 
196
            def handle(self):
 
197
                raise CantConnect()
 
198
 
 
199
        server = self.get_server(
 
200
            connection_handler_class=FailingConnectionHandler)
 
201
        # The server won't fail until a client connect
 
202
        client = self.get_client()
 
203
        client.connect((server.host, server.port))
 
204
        try:
 
205
            # Now we must force the server to answer by sending the request and
 
206
            # waiting for some answer. But since we don't control when the
 
207
            # server thread will be given cycles, we don't control either
 
208
            # whether our reads or writes may hang.
 
209
            client.sock.settimeout(0.1)
 
210
            client.write('ping\n')
 
211
            client.read()
 
212
        except socket.error:
 
213
            pass
 
214
        # Now the server has raised the exception in its own thread
 
215
        self.assertRaises(CantConnect, server.stop_server)
 
216
 
 
217
    def test_server_crash_while_responding(self):
 
218
        sync = threading.Event()
 
219
        sync.clear()
 
220
        class FailToRespond(Exception):
 
221
            pass
 
222
 
 
223
        class FailingDuringResponseHandler(TCPConnectionHandler):
 
224
 
 
225
            def handle_connection(self):
 
226
                req = self.rfile.readline()
 
227
                threading.currentThread().set_ready_event(sync)
 
228
                raise FailToRespond()
 
229
 
 
230
        server = self.get_server(
 
231
            connection_handler_class=FailingDuringResponseHandler)
 
232
        client = self.get_client()
 
233
        client.connect((server.host, server.port))
 
234
        client.write('ping\n')
 
235
        sync.wait()
 
236
        self.assertRaises(FailToRespond, server.pending_exception)
 
237
 
 
238
    def test_exception_swallowed_while_serving(self):
 
239
        sync = threading.Event()
 
240
        sync.clear()
 
241
        class CantServe(Exception):
 
242
            pass
 
243
 
 
244
        class FailingWhileServingConnectionHandler(TCPConnectionHandler):
 
245
 
 
246
            def handle(self):
 
247
                # We want to sync with the thread that is serving the
 
248
                # connection.
 
249
                threading.currentThread().set_ready_event(sync)
 
250
                raise CantServe()
 
251
 
 
252
        server = self.get_server(
 
253
            connection_handler_class=FailingWhileServingConnectionHandler)
 
254
        # Install the exception swallower
 
255
        server.set_ignored_exceptions(CantServe)
 
256
        client = self.get_client()
 
257
        # Connect to the server so the exception is raised there
 
258
        client.connect((server.host, server.port))
 
259
        # Wait for the exception to propagate.
 
260
        sync.wait()
 
261
        # The connection wasn't served properly but the exception should have
 
262
        # been swallowed.
 
263
        server.pending_exception()