~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/tests/test_test_server.py

  • Committer: Vincent Ladeuil
  • Date: 2012-01-18 14:09:19 UTC
  • mto: This revision was merged to the branch mainline in revision 6468.
  • Revision ID: v.ladeuil+lp@free.fr-20120118140919-rlvdrhpc0nq1lbwi
Change set/remove to require a lock for the branch config files.

This means that tests (or any plugin for that matter) do not requires an
explicit lock on the branch anymore to change a single option. This also
means the optimisation becomes "opt-in" and as such won't be as
spectacular as it may be and/or harder to get right (nothing fails
anymore).

This reduces the diff by ~300 lines.

Code/tests that were updating more than one config option is still taking
a lock to at least avoid some IOs and demonstrate the benefits through
the decreased number of hpss calls.

The duplication between BranchStack and BranchOnlyStack will be removed
once the same sharing is in place for local config files, at which point
the Stack class itself may be able to host the changes.

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# Copyright (C) 2010, 2011 Canonical Ltd
 
2
#
 
3
# This program is free software; you can redistribute it and/or modify
 
4
# it under the terms of the GNU General Public License as published by
 
5
# the Free Software Foundation; either version 2 of the License, or
 
6
# (at your option) any later version.
 
7
#
 
8
# This program is distributed in the hope that it will be useful,
 
9
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
10
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
11
# GNU General Public License for more details.
 
12
#
 
13
# You should have received a copy of the GNU General Public License
 
14
# along with this program; if not, write to the Free Software
 
15
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 
16
 
 
17
import errno
 
18
import socket
 
19
import SocketServer
 
20
import threading
 
21
 
 
22
 
 
23
from bzrlib import (
 
24
    osutils,
 
25
    tests,
 
26
    )
 
27
from bzrlib.tests import test_server
 
28
from bzrlib.tests.scenarios import load_tests_apply_scenarios
 
29
 
 
30
 
 
31
load_tests = load_tests_apply_scenarios
 
32
 
 
33
 
 
34
def portable_socket_pair():
 
35
    """Return a pair of TCP sockets connected to each other.
 
36
 
 
37
    Unlike socket.socketpair, this should work on Windows.
 
38
    """
 
39
    listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 
40
    listen_sock.bind(('127.0.0.1', 0))
 
41
    listen_sock.listen(1)
 
42
    client_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 
43
    client_sock.connect(listen_sock.getsockname())
 
44
    server_sock, addr = listen_sock.accept()
 
45
    listen_sock.close()
 
46
    return server_sock, client_sock
 
47
 
 
48
 
 
49
class TCPClient(object):
 
50
 
 
51
    def __init__(self):
 
52
        self.sock = None
 
53
 
 
54
    def connect(self, addr):
 
55
        if self.sock is not None:
 
56
            raise AssertionError('Already connected to %r'
 
57
                                 % (self.sock.getsockname(),))
 
58
        self.sock = osutils.connect_socket(addr)
 
59
 
 
60
    def disconnect(self):
 
61
        if self.sock is not None:
 
62
            try:
 
63
                self.sock.shutdown(socket.SHUT_RDWR)
 
64
                self.sock.close()
 
65
            except socket.error, e:
 
66
                if e[0] in (errno.EBADF, errno.ENOTCONN):
 
67
                    # Right, the socket is already down
 
68
                    pass
 
69
                else:
 
70
                    raise
 
71
            self.sock = None
 
72
 
 
73
    def write(self, s):
 
74
        return self.sock.sendall(s)
 
75
 
 
76
    def read(self, bufsize=4096):
 
77
        return self.sock.recv(bufsize)
 
78
 
 
79
 
 
80
class TCPConnectionHandler(SocketServer.BaseRequestHandler):
 
81
 
 
82
    def handle(self):
 
83
        self.done = False
 
84
        self.handle_connection()
 
85
        while not self.done:
 
86
            self.handle_connection()
 
87
 
 
88
    def readline(self):
 
89
        # TODO: We should be buffering any extra data sent, etc. However, in
 
90
        #       practice, we don't send extra content, so we haven't bothered
 
91
        #       to implement it yet.
 
92
        req = self.request.recv(4096)
 
93
        # An empty string is allowed, to indicate the end of the connection
 
94
        if not req or (req.endswith('\n') and req.count('\n') == 1):
 
95
            return req
 
96
        raise ValueError('[%r] not a simple line' % (req,))
 
97
 
 
98
    def handle_connection(self):
 
99
        req = self.readline()
 
100
        if not req:
 
101
            self.done = True
 
102
        elif req == 'ping\n':
 
103
            self.request.sendall('pong\n')
 
104
        else:
 
105
            raise ValueError('[%s] not understood' % req)
 
106
 
 
107
 
 
108
class TestTCPServerInAThread(tests.TestCase):
 
109
 
 
110
    scenarios = [
 
111
        (name, {'server_class': getattr(test_server, name)})
 
112
        for name in
 
113
        ('TestingTCPServer', 'TestingThreadingTCPServer')]
 
114
 
 
115
    def get_server(self, server_class=None, connection_handler_class=None):
 
116
        if server_class is not None:
 
117
            self.server_class = server_class
 
118
        if connection_handler_class is None:
 
119
            connection_handler_class = TCPConnectionHandler
 
120
        server = test_server.TestingTCPServerInAThread(
 
121
            ('localhost', 0), self.server_class, connection_handler_class)
 
122
        server.start_server()
 
123
        self.addCleanup(server.stop_server)
 
124
        return server
 
125
 
 
126
    def get_client(self):
 
127
        client = TCPClient()
 
128
        self.addCleanup(client.disconnect)
 
129
        return client
 
130
 
 
131
    def get_server_connection(self, server, conn_rank):
 
132
        return server.server.clients[conn_rank]
 
133
 
 
134
    def assertClientAddr(self, client, server, conn_rank):
 
135
        conn = self.get_server_connection(server, conn_rank)
 
136
        self.assertEquals(client.sock.getsockname(), conn[1])
 
137
 
 
138
    def test_start_stop(self):
 
139
        server = self.get_server()
 
140
        client = self.get_client()
 
141
        server.stop_server()
 
142
        # since the server doesn't accept connections anymore attempting to
 
143
        # connect should fail
 
144
        client = self.get_client()
 
145
        self.assertRaises(socket.error,
 
146
                          client.connect, (server.host, server.port))
 
147
 
 
148
    def test_client_talks_server_respond(self):
 
149
        server = self.get_server()
 
150
        client = self.get_client()
 
151
        client.connect((server.host, server.port))
 
152
        self.assertIs(None, client.write('ping\n'))
 
153
        resp = client.read()
 
154
        self.assertClientAddr(client, server, 0)
 
155
        self.assertEquals('pong\n', resp)
 
156
 
 
157
    def test_server_fails_to_start(self):
 
158
        class CantStart(Exception):
 
159
            pass
 
160
 
 
161
        class CantStartServer(test_server.TestingTCPServer):
 
162
 
 
163
            def server_bind(self):
 
164
                raise CantStart()
 
165
 
 
166
        # The exception is raised in the main thread
 
167
        self.assertRaises(CantStart,
 
168
                          self.get_server, server_class=CantStartServer)
 
169
 
 
170
    def test_server_fails_while_serving_or_stopping(self):
 
171
        class CantConnect(Exception):
 
172
            pass
 
173
 
 
174
        class FailingConnectionHandler(TCPConnectionHandler):
 
175
 
 
176
            def handle(self):
 
177
                raise CantConnect()
 
178
 
 
179
        server = self.get_server(
 
180
            connection_handler_class=FailingConnectionHandler)
 
181
        # The server won't fail until a client connect
 
182
        client = self.get_client()
 
183
        client.connect((server.host, server.port))
 
184
        # We make sure the server wants to handle a request, but the request is
 
185
        # guaranteed to fail. However, the server should make sure that the
 
186
        # connection gets closed, and stop_server should then raise the
 
187
        # original exception.
 
188
        client.write('ping\n')
 
189
        try:
 
190
            self.assertEqual('', client.read())
 
191
        except socket.error, e:
 
192
            # On Windows, failing during 'handle' means we get
 
193
            # 'forced-close-of-connection'. Possibly because we haven't
 
194
            # processed the write request before we close the socket.
 
195
            WSAECONNRESET = 10054
 
196
            if e.errno in (WSAECONNRESET,):
 
197
                pass
 
198
        # Now the server has raised the exception in its own thread
 
199
        self.assertRaises(CantConnect, server.stop_server)
 
200
 
 
201
    def test_server_crash_while_responding(self):
 
202
        # We want to ensure the exception has been caught
 
203
        caught = threading.Event()
 
204
        caught.clear()
 
205
        # The thread that will serve the client, this needs to be an attribute
 
206
        # so the handler below can modify it when it's executed (it's
 
207
        # instantiated when the request is processed)
 
208
        self.connection_thread = None
 
209
 
 
210
        class FailToRespond(Exception):
 
211
            pass
 
212
 
 
213
        class FailingDuringResponseHandler(TCPConnectionHandler):
 
214
 
 
215
            # We use 'request' instead of 'self' below because the test matters
 
216
            # more and we need a container to properly set connection_thread.
 
217
            def handle_connection(request):
 
218
                req = request.readline()
 
219
                # Capture the thread and make it use 'caught' so we can wait on
 
220
                # the event that will be set when the exception is caught. We
 
221
                # also capture the thread to know where to look.
 
222
                self.connection_thread = threading.currentThread()
 
223
                self.connection_thread.set_sync_event(caught)
 
224
                raise FailToRespond()
 
225
 
 
226
        server = self.get_server(
 
227
            connection_handler_class=FailingDuringResponseHandler)
 
228
        client = self.get_client()
 
229
        client.connect((server.host, server.port))
 
230
        client.write('ping\n')
 
231
        # Wait for the exception to be caught
 
232
        caught.wait()
 
233
        self.assertEqual('', client.read()) # connection closed
 
234
        # Check that the connection thread did catch the exception,
 
235
        # http://pad.lv/869366 was wrongly checking the server thread which
 
236
        # works for TestingTCPServer where the connection is handled in the
 
237
        # same thread than the server one but was racy for
 
238
        # TestingThreadingTCPServer. Since the connection thread detaches
 
239
        # itself before handling the request, we are guaranteed that the
 
240
        # exception won't leak into the server thread anymore.
 
241
        self.assertRaises(FailToRespond,
 
242
                          self.connection_thread.pending_exception)
 
243
 
 
244
    def test_exception_swallowed_while_serving(self):
 
245
        # We need to ensure the exception has been caught
 
246
        caught = threading.Event()
 
247
        caught.clear()
 
248
        # The thread that will serve the client, this needs to be an attribute
 
249
        # so the handler below can access it when it's executed (it's
 
250
        # instantiated when the request is processed)
 
251
        self.connection_thread = None
 
252
        class CantServe(Exception):
 
253
            pass
 
254
 
 
255
        class FailingWhileServingConnectionHandler(TCPConnectionHandler):
 
256
 
 
257
            # We use 'request' instead of 'self' below because the test matters
 
258
            # more and we need a container to properly set connection_thread.
 
259
            def handle(request):
 
260
                # Capture the thread and make it use 'caught' so we can wait on
 
261
                # the event that will be set when the exception is caught. We
 
262
                # also capture the thread to know where to look.
 
263
                self.connection_thread = threading.currentThread()
 
264
                self.connection_thread.set_sync_event(caught)
 
265
                raise CantServe()
 
266
 
 
267
        server = self.get_server(
 
268
            connection_handler_class=FailingWhileServingConnectionHandler)
 
269
        self.assertEquals(True, server.server.serving)
 
270
        # Install the exception swallower
 
271
        server.set_ignored_exceptions(CantServe)
 
272
        client = self.get_client()
 
273
        # Connect to the server so the exception is raised there
 
274
        client.connect((server.host, server.port))
 
275
        # Wait for the exception to be caught
 
276
        caught.wait()
 
277
        self.assertEqual('', client.read()) # connection closed
 
278
        # The connection wasn't served properly but the exception should have
 
279
        # been swallowed (see test_server_crash_while_responding remark about
 
280
        # http://pad.lv/869366 explaining why we can't check the server thread
 
281
        # here). More precisely, the exception *has* been caught and captured
 
282
        # but it is cleared when joining the thread (or trying to acquire the
 
283
        # exception) and as such won't propagate to the server thread.
 
284
        self.assertIs(None, self.connection_thread.pending_exception())
 
285
        self.assertIs(None, server.pending_exception())
 
286
 
 
287
    def test_handle_request_closes_if_it_doesnt_process(self):
 
288
        server = self.get_server()
 
289
        client = self.get_client()
 
290
        server.server.serving = False
 
291
        client.connect((server.host, server.port))
 
292
        self.assertEqual('', client.read())
 
293
 
 
294
 
 
295
class TestTestingSmartServer(tests.TestCase):
 
296
 
 
297
    def test_sets_client_timeout(self):
 
298
        server = test_server.TestingSmartServer(('localhost', 0), None, None,
 
299
            root_client_path='/no-such-client/path')
 
300
        self.assertEqual(test_server._DEFAULT_TESTING_CLIENT_TIMEOUT,
 
301
                         server._client_timeout)
 
302
        sock = socket.socket()
 
303
        h = server._make_handler(sock)
 
304
        self.assertEqual(test_server._DEFAULT_TESTING_CLIENT_TIMEOUT,
 
305
                         h._client_timeout)
 
306
 
 
307
 
 
308
class FakeServer(object):
 
309
    """Minimal implementation to pass to TestingSmartConnectionHandler"""
 
310
    backing_transport = None
 
311
    root_client_path = '/'
 
312
 
 
313
 
 
314
class TestTestingSmartConnectionHandler(tests.TestCase):
 
315
 
 
316
    def test_connection_timeout_suppressed(self):
 
317
        self.overrideAttr(test_server, '_DEFAULT_TESTING_CLIENT_TIMEOUT', 0.01)
 
318
        s = FakeServer()
 
319
        server_sock, client_sock = portable_socket_pair()
 
320
        # This should timeout quickly, but not generate an exception.
 
321
        handler = test_server.TestingSmartConnectionHandler(server_sock,
 
322
            server_sock.getpeername(), s)
 
323
 
 
324
    def test_connection_shutdown_while_serving_no_error(self):
 
325
        s = FakeServer()
 
326
        server_sock, client_sock = portable_socket_pair()
 
327
        class ShutdownConnectionHandler(
 
328
            test_server.TestingSmartConnectionHandler):
 
329
 
 
330
            def _build_protocol(self):
 
331
                self.finished = True
 
332
                return super(ShutdownConnectionHandler, self)._build_protocol()
 
333
        # This should trigger shutdown after the entering _build_protocol, and
 
334
        # we should exit cleanly, without raising an exception.
 
335
        handler = ShutdownConnectionHandler(server_sock,
 
336
            server_sock.getpeername(), s)