~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/tests/test_test_server.py

Merge bzr.dev, update to use new hooks.

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2010 Canonical Ltd
 
1
# Copyright (C) 2010, 2011 Canonical Ltd
2
2
#
3
3
# This program is free software; you can redistribute it and/or modify
4
4
# it under the terms of the GNU General Public License as published by
19
19
import SocketServer
20
20
import threading
21
21
 
 
22
 
22
23
from bzrlib import (
23
24
    osutils,
24
25
    tests,
25
26
    )
26
27
from bzrlib.tests import test_server
27
 
 
28
 
 
29
 
def load_tests(basic_tests, module, loader):
30
 
    suite = loader.suiteClass()
31
 
    server_tests, remaining_tests = tests.split_suite_by_condition(
32
 
        basic_tests, tests.condition_isinstance(TestTCPServerInAThread))
33
 
    server_scenarios = [ (name, {'server_class': getattr(test_server, name)})
34
 
                         for name in
35
 
                         ('TestingTCPServer', 'TestingThreadingTCPServer')]
36
 
    tests.multiply_tests(server_tests, server_scenarios, suite)
37
 
    suite.addTest(remaining_tests)
38
 
    return suite
39
 
 
40
 
 
41
 
class TestThreadWithException(tests.TestCase):
42
 
 
43
 
    def test_start_and_join_smoke_test(self):
44
 
        def do_nothing():
45
 
            pass
46
 
 
47
 
        tt = test_server.ThreadWithException(target=do_nothing)
48
 
        tt.start()
49
 
        tt.join()
50
 
 
51
 
    def test_exception_is_re_raised(self):
52
 
        class MyException(Exception):
53
 
            pass
54
 
 
55
 
        def raise_my_exception():
56
 
            raise MyException()
57
 
 
58
 
        tt = test_server.ThreadWithException(target=raise_my_exception)
59
 
        tt.start()
60
 
        self.assertRaises(MyException, tt.join)
61
 
 
62
 
    def test_join_when_no_exception(self):
63
 
        resume = threading.Event()
64
 
        class MyException(Exception):
65
 
            pass
66
 
 
67
 
        def raise_my_exception():
68
 
            # Wait for the test to tell us to resume
69
 
            resume.wait()
70
 
            # Now we can raise
71
 
            raise MyException()
72
 
 
73
 
        tt = test_server.ThreadWithException(target=raise_my_exception)
74
 
        tt.start()
75
 
        tt.join(timeout=0)
76
 
        self.assertIs(None, tt.exception)
77
 
        resume.set()
78
 
        self.assertRaises(MyException, tt.join)
 
28
from bzrlib.tests.scenarios import load_tests_apply_scenarios
 
29
 
 
30
 
 
31
load_tests = load_tests_apply_scenarios
 
32
 
 
33
 
 
34
def portable_socket_pair():
 
35
    """Return a pair of TCP sockets connected to each other.
 
36
 
 
37
    Unlike socket.socketpair, this should work on Windows.
 
38
    """
 
39
    listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 
40
    listen_sock.bind(('127.0.0.1', 0))
 
41
    listen_sock.listen(1)
 
42
    client_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 
43
    client_sock.connect(listen_sock.getsockname())
 
44
    server_sock, addr = listen_sock.accept()
 
45
    listen_sock.close()
 
46
    return server_sock, client_sock
79
47
 
80
48
 
81
49
class TCPClient(object):
109
77
        return self.sock.recv(bufsize)
110
78
 
111
79
 
112
 
class TCPConnectionHandler(SocketServer.StreamRequestHandler):
 
80
class TCPConnectionHandler(SocketServer.BaseRequestHandler):
113
81
 
114
82
    def handle(self):
115
83
        self.done = False
117
85
        while not self.done:
118
86
            self.handle_connection()
119
87
 
 
88
    def readline(self):
 
89
        # TODO: We should be buffering any extra data sent, etc. However, in
 
90
        #       practice, we don't send extra content, so we haven't bothered
 
91
        #       to implement it yet.
 
92
        req = self.request.recv(4096)
 
93
        # An empty string is allowed, to indicate the end of the connection
 
94
        if not req or (req.endswith('\n') and req.count('\n') == 1):
 
95
            return req
 
96
        raise ValueError('[%r] not a simple line' % (req,))
 
97
 
120
98
    def handle_connection(self):
121
 
        req = self.rfile.readline()
 
99
        req = self.readline()
122
100
        if not req:
123
101
            self.done = True
124
102
        elif req == 'ping\n':
125
 
            self.wfile.write('pong\n')
 
103
            self.request.sendall('pong\n')
126
104
        else:
127
105
            raise ValueError('[%s] not understood' % req)
128
106
 
129
107
 
130
108
class TestTCPServerInAThread(tests.TestCase):
131
109
 
132
 
    # Set by load_tests()
133
 
    server_class = None
 
110
    scenarios = [
 
111
        (name, {'server_class': getattr(test_server, name)})
 
112
        for name in
 
113
        ('TestingTCPServer', 'TestingThreadingTCPServer')]
134
114
 
135
115
    def get_server(self, server_class=None, connection_handler_class=None):
136
116
        if server_class is not None:
137
117
            self.server_class = server_class
138
118
        if connection_handler_class is None:
139
119
            connection_handler_class = TCPConnectionHandler
140
 
        server =  test_server.TestingTCPServerInAThread(
 
120
        server = test_server.TestingTCPServerInAThread(
141
121
            ('localhost', 0), self.server_class, connection_handler_class)
142
122
        server.start_server()
143
123
        self.addCleanup(server.stop_server)
201
181
        # The server won't fail until a client connect
202
182
        client = self.get_client()
203
183
        client.connect((server.host, server.port))
 
184
        # We make sure the server wants to handle a request, but the request is
 
185
        # guaranteed to fail. However, the server should make sure that the
 
186
        # connection gets closed, and stop_server should then raise the
 
187
        # original exception.
 
188
        client.write('ping\n')
204
189
        try:
205
 
            # Now we must force the server to answer by sending the request and
206
 
            # waiting for some answer. But since we don't control when the
207
 
            # server thread will be given cycles, we don't control either
208
 
            # whether our reads or writes may hang.
209
 
            client.sock.settimeout(0.1)
210
 
            client.write('ping\n')
211
 
            client.read()
212
 
        except socket.error:
213
 
            pass
 
190
            self.assertEqual('', client.read())
 
191
        except socket.error, e:
 
192
            # On Windows, failing during 'handle' means we get
 
193
            # 'forced-close-of-connection'. Possibly because we haven't
 
194
            # processed the write request before we close the socket.
 
195
            WSAECONNRESET = 10054
 
196
            if e.errno in (WSAECONNRESET,):
 
197
                pass
214
198
        # Now the server has raised the exception in its own thread
215
199
        self.assertRaises(CantConnect, server.stop_server)
216
200
 
217
201
    def test_server_crash_while_responding(self):
218
 
        sync = threading.Event()
219
 
        sync.clear()
 
202
        # We want to ensure the exception has been caught
 
203
        caught = threading.Event()
 
204
        caught.clear()
 
205
        # The thread that will serve the client, this needs to be an attribute
 
206
        # so the handler below can modify it when it's executed (it's
 
207
        # instantiated when the request is processed)
 
208
        self.connection_thread = None
 
209
 
220
210
        class FailToRespond(Exception):
221
211
            pass
222
212
 
223
213
        class FailingDuringResponseHandler(TCPConnectionHandler):
224
214
 
225
 
            def handle_connection(self):
226
 
                req = self.rfile.readline()
227
 
                threading.currentThread().set_ready_event(sync)
 
215
            # We use 'request' instead of 'self' below because the test matters
 
216
            # more and we need a container to properly set connection_thread.
 
217
            def handle_connection(request):
 
218
                req = request.readline()
 
219
                # Capture the thread and make it use 'caught' so we can wait on
 
220
                # the event that will be set when the exception is caught. We
 
221
                # also capture the thread to know where to look.
 
222
                self.connection_thread = threading.currentThread()
 
223
                self.connection_thread.set_sync_event(caught)
228
224
                raise FailToRespond()
229
225
 
230
226
        server = self.get_server(
232
228
        client = self.get_client()
233
229
        client.connect((server.host, server.port))
234
230
        client.write('ping\n')
235
 
        sync.wait()
236
 
        self.assertRaises(FailToRespond, server.pending_exception)
 
231
        # Wait for the exception to be caught
 
232
        caught.wait()
 
233
        self.assertEqual('', client.read()) # connection closed
 
234
        # Check that the connection thread did catch the exception,
 
235
        # http://pad.lv/869366 was wrongly checking the server thread which
 
236
        # works for TestingTCPServer where the connection is handled in the
 
237
        # same thread than the server one but was racy for
 
238
        # TestingThreadingTCPServer. Since the connection thread detaches
 
239
        # itself before handling the request, we are guaranteed that the
 
240
        # exception won't leak into the server thread anymore.
 
241
        self.assertRaises(FailToRespond,
 
242
                          self.connection_thread.pending_exception)
237
243
 
238
244
    def test_exception_swallowed_while_serving(self):
239
 
        sync = threading.Event()
240
 
        sync.clear()
 
245
        # We need to ensure the exception has been caught
 
246
        caught = threading.Event()
 
247
        caught.clear()
 
248
        # The thread that will serve the client, this needs to be an attribute
 
249
        # so the handler below can access it when it's executed (it's
 
250
        # instantiated when the request is processed)
 
251
        self.connection_thread = None
241
252
        class CantServe(Exception):
242
253
            pass
243
254
 
244
255
        class FailingWhileServingConnectionHandler(TCPConnectionHandler):
245
256
 
246
 
            def handle(self):
247
 
                # We want to sync with the thread that is serving the
248
 
                # connection.
249
 
                threading.currentThread().set_ready_event(sync)
 
257
            # We use 'request' instead of 'self' below because the test matters
 
258
            # more and we need a container to properly set connection_thread.
 
259
            def handle(request):
 
260
                # Capture the thread and make it use 'caught' so we can wait on
 
261
                # the event that will be set when the exception is caught. We
 
262
                # also capture the thread to know where to look.
 
263
                self.connection_thread = threading.currentThread()
 
264
                self.connection_thread.set_sync_event(caught)
250
265
                raise CantServe()
251
266
 
252
267
        server = self.get_server(
253
268
            connection_handler_class=FailingWhileServingConnectionHandler)
 
269
        self.assertEquals(True, server.server.serving)
254
270
        # Install the exception swallower
255
271
        server.set_ignored_exceptions(CantServe)
256
272
        client = self.get_client()
257
273
        # Connect to the server so the exception is raised there
258
274
        client.connect((server.host, server.port))
259
 
        # Wait for the exception to propagate.
260
 
        sync.wait()
 
275
        # Wait for the exception to be caught
 
276
        caught.wait()
 
277
        self.assertEqual('', client.read()) # connection closed
261
278
        # The connection wasn't served properly but the exception should have
262
 
        # been swallowed.
263
 
        server.pending_exception()
 
279
        # been swallowed (see test_server_crash_while_responding remark about
 
280
        # http://pad.lv/869366 explaining why we can't check the server thread
 
281
        # here). More precisely, the exception *has* been caught and captured
 
282
        # but it is cleared when joining the thread (or trying to acquire the
 
283
        # exception) and as such won't propagate to the server thread.
 
284
        self.assertIs(None, self.connection_thread.pending_exception())
 
285
        self.assertIs(None, server.pending_exception())
 
286
 
 
287
    def test_handle_request_closes_if_it_doesnt_process(self):
 
288
        server = self.get_server()
 
289
        client = self.get_client()
 
290
        server.server.serving = False
 
291
        client.connect((server.host, server.port))
 
292
        self.assertEqual('', client.read())
 
293
 
 
294
 
 
295
class TestTestingSmartServer(tests.TestCase):
 
296
 
 
297
    def test_sets_client_timeout(self):
 
298
        server = test_server.TestingSmartServer(('localhost', 0), None, None,
 
299
            root_client_path='/no-such-client/path')
 
300
        self.assertEqual(test_server._DEFAULT_TESTING_CLIENT_TIMEOUT,
 
301
                         server._client_timeout)
 
302
        sock = socket.socket()
 
303
        h = server._make_handler(sock)
 
304
        self.assertEqual(test_server._DEFAULT_TESTING_CLIENT_TIMEOUT,
 
305
                         h._client_timeout)
 
306
 
 
307
 
 
308
class FakeServer(object):
 
309
    """Minimal implementation to pass to TestingSmartConnectionHandler"""
 
310
    backing_transport = None
 
311
    root_client_path = '/'
 
312
 
 
313
 
 
314
class TestTestingSmartConnectionHandler(tests.TestCase):
 
315
 
 
316
    def test_connection_timeout_suppressed(self):
 
317
        self.overrideAttr(test_server, '_DEFAULT_TESTING_CLIENT_TIMEOUT', 0.01)
 
318
        s = FakeServer()
 
319
        server_sock, client_sock = portable_socket_pair()
 
320
        # This should timeout quickly, but not generate an exception.
 
321
        handler = test_server.TestingSmartConnectionHandler(server_sock,
 
322
            server_sock.getpeername(), s)
 
323
 
 
324
    def test_connection_shutdown_while_serving_no_error(self):
 
325
        s = FakeServer()
 
326
        server_sock, client_sock = portable_socket_pair()
 
327
        class ShutdownConnectionHandler(
 
328
            test_server.TestingSmartConnectionHandler):
 
329
 
 
330
            def _build_protocol(self):
 
331
                self.finished = True
 
332
                return super(ShutdownConnectionHandler, self)._build_protocol()
 
333
        # This should trigger shutdown after the entering _build_protocol, and
 
334
        # we should exit cleanly, without raising an exception.
 
335
        handler = ShutdownConnectionHandler(server_sock,
 
336
            server_sock.getpeername(), s)