~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/tests/test_test_server.py

  • Committer: Martin Pool
  • Date: 2010-02-23 07:43:11 UTC
  • mfrom: (4797.2.20 2.1)
  • mto: This revision was merged to the branch mainline in revision 5055.
  • Revision ID: mbp@sourcefrog.net-20100223074311-gnj55xdhrgz9l94e
Merge 2.1 back to trunk

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2010 Canonical Ltd
2
 
#
3
 
# This program is free software; you can redistribute it and/or modify
4
 
# it under the terms of the GNU General Public License as published by
5
 
# the Free Software Foundation; either version 2 of the License, or
6
 
# (at your option) any later version.
7
 
#
8
 
# This program is distributed in the hope that it will be useful,
9
 
# but WITHOUT ANY WARRANTY; without even the implied warranty of
10
 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11
 
# GNU General Public License for more details.
12
 
#
13
 
# You should have received a copy of the GNU General Public License
14
 
# along with this program; if not, write to the Free Software
15
 
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
16
 
 
17
 
import errno
18
 
import socket
19
 
import SocketServer
20
 
import threading
21
 
 
22
 
from bzrlib import (
23
 
    osutils,
24
 
    tests,
25
 
    )
26
 
from bzrlib.tests import test_server
27
 
 
28
 
 
29
 
def load_tests(basic_tests, module, loader):
    """Build this module's suite, parametrizing the TCP server tests.

    The tests that are instances of TestTCPServerInAThread are handed to
    tests.multiply_tests() together with one scenario per server class name
    ('TestingTCPServer' and 'TestingThreadingTCPServer', both looked up on
    test_server). All the other tests are added to the suite untouched.

    :param basic_tests: The tests collected for this module.
    :param module: Not used here.
    :param loader: Its ``suiteClass`` creates the returned suite.
    :return: The suite receiving both groups of tests.
    """
    suite = loader.suiteClass()
    is_server_test = tests.condition_isinstance(TestTCPServerInAThread)
    server_tests, remaining_tests = tests.split_suite_by_condition(
        basic_tests, is_server_test)
    # One scenario per server class, each setting the 'server_class'
    # attribute to the class of the same name in test_server.
    server_scenarios = []
    for name in ('TestingTCPServer', 'TestingThreadingTCPServer'):
        server_class = getattr(test_server, name)
        server_scenarios.append((name, {'server_class': server_class}))
    tests.multiply_tests(server_tests, server_scenarios, suite)
    suite.addTest(remaining_tests)
    return suite
39
 
 
40
 
 
41
 
class TestThreadWithException(tests.TestCase):
    """Check how test_server.ThreadWithException reports target failures."""

    def test_start_and_join_smoke_test(self):
        # A thread whose target does nothing can be started and joined.
        def noop():
            pass

        thread = test_server.ThreadWithException(target=noop)
        thread.start()
        thread.join()

    def test_exception_is_re_raised(self):
        # An exception raised by the target comes out of join().
        class MyException(Exception):
            pass

        def fail():
            raise MyException()

        thread = test_server.ThreadWithException(target=fail)
        thread.start()
        self.assertRaises(MyException, thread.join)

    def test_join_when_no_exception(self):
        # While the target is still blocked, join() has nothing to raise;
        # once the target is released and fails, join() raises.
        go_ahead = threading.Event()

        class MyException(Exception):
            pass

        def wait_then_fail():
            # Stay blocked until the test lets us go on...
            go_ahead.wait()
            # ... and only then fail.
            raise MyException()

        thread = test_server.ThreadWithException(target=wait_then_fail)
        thread.start()
        thread.join(timeout=0)
        self.assertIs(None, thread.exception)
        go_ahead.set()
        self.assertRaises(MyException, thread.join)
79
 
 
80
 
 
81
 
class TCPClient(object):
    """Minimal TCP client used to talk to the servers under test.

    ``sock`` is the attached socket, or None while disconnected.
    """

    def __init__(self):
        self.sock = None

    def connect(self, addr):
        """Attach a socket connected to ``addr``.

        :param addr: Passed as is to osutils.connect_socket().
        :raises AssertionError: If a socket is already attached.
        """
        if self.sock is not None:
            raise AssertionError('Already connected to %r'
                                 % (self.sock.getsockname(),))
        self.sock = osutils.connect_socket(addr)

    def disconnect(self):
        """Shut down and close the attached socket, then forget it.

        Does nothing when no socket is attached. The socket is always closed
        and ``self.sock`` is always reset to None, including when shutdown()
        fails: a socket that is already down (EBADF, ENOTCONN) is not an
        error, and any other failure is propagated only after the close has
        been attempted.

        :raises socket.error: For errors other than EBADF and ENOTCONN.
        """
        sock = self.sock
        if sock is None:
            return
        # Detach before touching the socket so that a failure below cannot
        # leave a dead socket attached to the client.
        self.sock = None
        try:
            self._ignore_socket_down(sock.shutdown, socket.SHUT_RDWR)
        finally:
            # A failing shutdown() must not prevent the socket from being
            # closed explicitly.
            self._ignore_socket_down(sock.close)

    def _ignore_socket_down(self, func, *args):
        """Call func(*args), swallowing 'socket is already down' errors.

        :raises socket.error: When the error code is neither EBADF nor
            ENOTCONN.
        """
        import sys
        try:
            func(*args)
        except socket.error:
            # Fetch the exception from sys.exc_info() rather than binding it
            # in the except clause: the binding syntax differs between Python
            # versions whereas this spelling is accepted by all of them.
            e = sys.exc_info()[1]
            if e.args[0] not in (errno.EBADF, errno.ENOTCONN):
                raise
            # Right, the socket is already down

    def write(self, s):
        """Send all of ``s`` and return the result of socket.sendall()."""
        return self.sock.sendall(s)

    def read(self, bufsize=4096):
        """Return what a single recv() of at most ``bufsize`` bytes yields."""
        return self.sock.recv(bufsize)
110
 
 
111
 
 
112
 
class TCPConnectionHandler(SocketServer.StreamRequestHandler):
    """Line-based request handler answering each ping line with a pong line.

    Subclasses can override handle_connection() to change how a single
    request is processed; handle() keeps calling it until ``done`` is set.
    """

    def handle(self):
        # 'done' starts out False so at least one request is always
        # processed.
        self.done = False
        while not self.done:
            self.handle_connection()

    def handle_connection(self):
        """Process one request line read from ``rfile``.

        An empty read sets ``done``, a ping is answered with a pong on
        ``wfile`` and anything else is rejected.

        :raises ValueError: If the request is not understood.
        """
        request = self.rfile.readline()
        if request == 'ping\n':
            self.wfile.write('pong\n')
        elif request:
            raise ValueError('[%s] not understood' % request)
        else:
            # Nothing left to read
            self.done = True
128
 
 
129
 
 
130
 
class TestTCPServerInAThread(tests.TestCase):
    """Drive a TCP server running in its own thread from a TCPClient.

    Most tests check where an exception raised on the server side ends up
    being reported to the test.
    """

    # Set by load_tests()
    server_class = None

    def get_server(self, server_class=None, connection_handler_class=None):
        # Start a test_server.TestingTCPServerInAThread on localhost and make
        # sure it is stopped when the test ends. An explicit server_class
        # replaces self.server_class; the handler class defaults to
        # TCPConnectionHandler.
        if server_class is not None:
            self.server_class = server_class
        handler_class = connection_handler_class
        if handler_class is None:
            handler_class = TCPConnectionHandler
        address = ('localhost', 0)
        server = test_server.TestingTCPServerInAThread(
            address, self.server_class, handler_class)
        server.start_server()
        self.addCleanup(server.stop_server)
        return server

    def get_client(self):
        # The client is handed back unconnected; it is disconnected when the
        # test ends.
        tcp_client = TCPClient()
        self.addCleanup(tcp_client.disconnect)
        return tcp_client

    def get_server_connection(self, server, conn_rank):
        # Connections are looked up by rank in the 'clients' sequence of the
        # wrapped server object.
        wrapped_server = server.server
        return wrapped_server.clients[conn_rank]

    def assertClientAddr(self, client, server, conn_rank):
        # The second item of the server side connection must be the address
        # the client socket is bound to.
        conn = self.get_server_connection(server, conn_rank)
        client_addr = client.sock.getsockname()
        self.assertEquals(client_addr, conn[1])

    def test_start_stop(self):
        server = self.get_server()
        client = self.get_client()
        server.stop_server()
        # since the server doesn't accept connections anymore attempting to
        # connect should fail
        client = self.get_client()
        address = (server.host, server.port)
        self.assertRaises(socket.error, client.connect, address)

    def test_client_talks_server_respond(self):
        server = self.get_server()
        client = self.get_client()
        client.connect((server.host, server.port))
        write_result = client.write('ping\n')
        self.assertIs(None, write_result)
        resp = client.read()
        self.assertClientAddr(client, server, 0)
        self.assertEquals('pong\n', resp)

    def test_server_fails_to_start(self):
        class CantStart(Exception):
            pass

        class CantStartServer(test_server.TestingTCPServer):

            def server_bind(self):
                raise CantStart()

        # The exception is raised in the main thread
        self.assertRaises(CantStart,
                          self.get_server, server_class=CantStartServer)

    def test_server_fails_while_serving_or_stopping(self):
        class CantConnect(Exception):
            pass

        class FailingConnectionHandler(TCPConnectionHandler):

            def handle(self):
                raise CantConnect()

        server = self.get_server(
            connection_handler_class=FailingConnectionHandler)
        # The server won't fail until a client connect
        client = self.get_client()
        client.connect((server.host, server.port))
        try:
            # Now we must force the server to answer by sending the request and
            # waiting for some answer. But since we don't control when the
            # server thread will be given cycles, we don't control either
            # whether our reads or writes may hang.
            client.sock.settimeout(0.1)
            client.write('ping\n')
            client.read()
        except socket.error:
            # Whether the exchange succeeds or not is irrelevant here
            pass
        # Now the server has raised the exception in its own thread
        self.assertRaises(CantConnect, server.stop_server)

    def test_server_crash_while_responding(self):
        sync = threading.Event()
        sync.clear()

        class FailToRespond(Exception):
            pass

        class FailingDuringResponseHandler(TCPConnectionHandler):

            def handle_connection(self):
                # Consume the request, hand our event to the serving thread,
                # then fail instead of answering.
                self.rfile.readline()
                threading.currentThread().set_ready_event(sync)
                raise FailToRespond()

        server = self.get_server(
            connection_handler_class=FailingDuringResponseHandler)
        client = self.get_client()
        client.connect((server.host, server.port))
        client.write('ping\n')
        sync.wait()
        self.assertRaises(FailToRespond, server.pending_exception)

    def test_exception_swallowed_while_serving(self):
        sync = threading.Event()
        sync.clear()

        class CantServe(Exception):
            pass

        class FailingWhileServingConnectionHandler(TCPConnectionHandler):

            def handle(self):
                # We want to sync with the thread that is serving the
                # connection.
                threading.currentThread().set_ready_event(sync)
                raise CantServe()

        server = self.get_server(
            connection_handler_class=FailingWhileServingConnectionHandler)
        # Install the exception swallower
        server.set_ignored_exceptions(CantServe)
        client = self.get_client()
        # Connect to the server so the exception is raised there
        client.connect((server.host, server.port))
        # Wait for the exception to propagate.
        sync.wait()
        # The connection wasn't served properly but the exception should have
        # been swallowed.
        server.pending_exception()