1
# Copyright (C) 2010 Canonical Ltd
3
# This program is free software; you can redistribute it and/or modify
4
# it under the terms of the GNU General Public License as published by
5
# the Free Software Foundation; either version 2 of the License, or
6
# (at your option) any later version.
8
# This program is distributed in the hope that it will be useful,
9
# but WITHOUT ANY WARRANTY; without even the implied warranty of
10
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11
# GNU General Public License for more details.
13
# You should have received a copy of the GNU General Public License
14
# along with this program; if not, write to the Free Software
15
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
26
from bzrlib.tests import test_server
29
# Build this module's test suite. Tests that are TestTCPServerInAThread
# instances are split out and multiplied over the scenarios below; each
# scenario presumably sets the test's `server_class` attribute to the
# test_server class of the same name. All other tests are added unchanged.
# NOTE(review): this block looks damaged in this copy:
#   - indentation is gone and stray numeric lines are interleaved;
#   - the list comprehension below has no `for ... in` clause;
#   - `suite` is built but never returned (upstream presumably ends with
#     `return suite`).
# Restore it from upstream before relying on it.
def load_tests(basic_tests, module, loader):
30
suite = loader.suiteClass()
31
# Separate the server tests (to be parametrized) from everything else.
server_tests, remaining_tests = tests.split_suite_by_condition(
32
basic_tests, tests.condition_isinstance(TestTCPServerInAThread))
33
# One scenario per server class: (scenario name, attributes to set).
server_scenarios = [ (name, {'server_class': getattr(test_server, name)})
35
('TestingTCPServer', 'TestingThreadingTCPServer')]
36
tests.multiply_tests(server_tests, server_scenarios, suite)
37
suite.addTest(remaining_tests)
41
# Tests for test_server.ThreadWithException:
#   - a target that returns normally can be started and joined;
#   - an exception raised by the target is re-raised by join();
#   - `exception` is still None while the target is blocked and has not
#     raised yet.
# NOTE(review): this class looks damaged in this copy:
#   - indentation is gone and stray numeric lines are interleaved;
#   - `do_nothing` is used but never defined;
#   - the MyException and raise_my_exception bodies are empty;
#   - no tt.start() call is visible;
#   - `resume` is created but never waited on or set.
# Restore it from upstream before relying on it.
class TestThreadWithException(tests.TestCase):
43
# Smoke test: a target that does nothing.
def test_start_and_join_smoke_test(self):
47
tt = test_server.ThreadWithException(target=do_nothing)
51
# join() must re-raise the exception raised by the thread's target.
def test_exception_is_re_raised(self):
52
class MyException(Exception):
55
def raise_my_exception():
58
tt = test_server.ThreadWithException(target=raise_my_exception)
60
self.assertRaises(MyException, tt.join)
62
# While the target is still waiting, no exception has been recorded yet.
# Once it raises, join() must re-raise it.
def test_join_when_no_exception(self):
63
# Used to hold the target until the test lets it raise.
resume = threading.Event()
64
class MyException(Exception):
67
def raise_my_exception():
68
# Wait for the test to tell us to resume
73
tt = test_server.ThreadWithException(target=raise_my_exception)
76
self.assertIs(None, tt.exception)
78
self.assertRaises(MyException, tt.join)
81
# Minimal blocking TCP client used by the server tests. The connected socket
# is kept in `self.sock`, which is None when the client is not connected.
# NOTE(review): this class looks damaged in this copy:
#   - indentation is gone and stray numeric lines are interleaved;
#   - no __init__ initialising `self.sock` is visible;
#   - the `def disconnect(self):` and `try:` lines before the shutdown call
#     are missing (the `except` below has no matching `try`);
#   - the `def write(self, s):` header before `return self.sock.sendall(s)`
#     is gone.
# Note that `except socket.error, e:` and `e[0]` are Python 2-only syntax.
class TCPClient(object):
86
# Connect to `addr`, a (host, port) pair. Connecting an already-connected
# client is a test bug, so it raises AssertionError.
def connect(self, addr):
87
if self.sock is not None:
88
raise AssertionError('Already connected to %r'
89
% (self.sock.getsockname(),))
90
self.sock = osutils.connect_socket(addr)
93
# Disconnect: shut the socket down in both directions. EBADF/ENOTCONN mean
# the socket is already down and are tolerated. Other errors are presumably
# re-raised, but the `else` branch is missing here.
if self.sock is not None:
95
self.sock.shutdown(socket.SHUT_RDWR)
97
except socket.error, e:
98
if e[0] in (errno.EBADF, errno.ENOTCONN):
99
# Right, the socket is already down
106
# Write: send all of `s`. sendall() returns None on success, which the
# tests assert on.
return self.sock.sendall(s)
108
# Return up to `bufsize` bytes from a single recv() call.
def read(self, bufsize=4096):
109
return self.sock.recv(bufsize)
112
# Request handler for the test server. It reads newline-terminated requests,
# answers 'ping\n' with 'pong\n', and raises ValueError for anything else.
# NOTE(review): this class looks damaged in this copy:
#   - indentation is gone and stray numeric lines are interleaved;
#   - the method header around the two handle_connection() calls is missing
#     (presumably handle(), looping until the client is done);
#   - the `if` branch preceding `elif` is missing (presumably end-of-stream
#     handling);
#   - the `else:` before the final raise is missing.
class TCPConnectionHandler(SocketServer.StreamRequestHandler):
116
self.handle_connection()
118
self.handle_connection()
120
# Serve one request line read from the client.
def handle_connection(self):
121
req = self.rfile.readline()
124
elif req == 'ping\n':
125
self.wfile.write('pong\n')
127
# Unknown requests are reported as errors raised in the server thread.
raise ValueError('[%s] not understood' % req)
130
# Tests for test_server.TestingTCPServerInAThread. load_tests() parametrizes
# this class through the `server_class` attribute (TestingTCPServer or
# TestingThreadingTCPServer).
# NOTE(review): this class looks damaged in this copy:
#   - indentation is gone and stray numeric lines are interleaved;
#   - several statements are missing: get_server/get_client return nothing
#     here, get_client never assigns `client`, `resp` is read without being
#     assigned, and several exception/handler bodies are empty.
# Restore it from upstream before relying on it.
class TestTCPServerInAThread(tests.TestCase):
132
# Set by load_tests()
135
# Start a TestingTCPServerInAThread on localhost, port 0 (the OS picks a free
# port), and register stop_server as a test cleanup.
#   server_class: when given, overrides the scenario's class.
#   connection_handler_class: defaults to TCPConnectionHandler.
# NOTE(review): no `return server` is visible, yet the tests use the result.
def get_server(self, server_class=None, connection_handler_class=None):
136
if server_class is not None:
137
self.server_class = server_class
138
if connection_handler_class is None:
139
connection_handler_class = TCPConnectionHandler
140
server = test_server.TestingTCPServerInAThread(
141
('localhost', 0), self.server_class, connection_handler_class)
142
server.start_server()
143
self.addCleanup(server.stop_server)
146
# Create a client whose disconnect() runs as a test cleanup.
# NOTE(review): the lines creating `client` (presumably TCPClient()) and
# returning it are missing here.
def get_client(self):
148
self.addCleanup(client.disconnect)
151
# Return what the underlying server recorded for its conn_rank-th
# connection. Presumably this is a (request, client_address) pair; confirm
# in test_server.
def get_server_connection(self, server, conn_rank):
152
return server.server.clients[conn_rank]
154
# Assert that the client's local address equals the address the server
# recorded for connection conn_rank (element [1] of the recorded entry).
def assertClientAddr(self, client, server, conn_rank):
155
conn = self.get_server_connection(server, conn_rank)
156
self.assertEquals(client.sock.getsockname(), conn[1])
158
# Once the server has been stopped, a new client must fail to connect.
# NOTE(review): the first client's connect/disconnect and the explicit
# server stop appear to be missing between the two get_client() calls.
def test_start_stop(self):
159
server = self.get_server()
160
client = self.get_client()
162
# since the server doesn't accept connections anymore attempting to
163
# connect should fail
164
client = self.get_client()
165
self.assertRaises(socket.error,
166
client.connect, (server.host, server.port))
168
# A 'ping' request gets a 'pong' reply, and the server records the client's
# address as its first connection.
# NOTE(review): `resp` is never assigned here (presumably from
# client.read()).
def test_client_talks_server_respond(self):
169
server = self.get_server()
170
client = self.get_client()
171
client.connect((server.host, server.port))
172
self.assertIs(None, client.write('ping\n'))
174
self.assertClientAddr(client, server, 0)
175
self.assertEquals('pong\n', resp)
177
# An exception raised while binding surfaces in the caller of get_server().
# NOTE(review): the CantStart body and the server_bind body (presumably
# `raise CantStart()`) are missing here.
def test_server_fails_to_start(self):
178
class CantStart(Exception):
181
class CantStartServer(test_server.TestingTCPServer):
183
def server_bind(self):
186
# The exception is raised in the main thread
187
self.assertRaises(CantStart,
188
self.get_server, server_class=CantStartServer)
190
# An exception raised by a connection handler in the server thread is
# re-raised by stop_server().
# NOTE(review): the CantConnect/FailingConnectionHandler bodies and the
# statements after write() (presumably a timeout-guarded read) are missing.
def test_server_fails_while_serving_or_stopping(self):
191
class CantConnect(Exception):
194
class FailingConnectionHandler(TCPConnectionHandler):
199
server = self.get_server(
200
connection_handler_class=FailingConnectionHandler)
201
# The server won't fail until a client connects
202
client = self.get_client()
203
client.connect((server.host, server.port))
205
# Now we must force the server to answer by sending the request and
206
# waiting for some answer. But since we don't control when the
207
# server thread will be given cycles, we don't control either
208
# whether our reads or writes may hang.
209
client.sock.settimeout(0.1)
210
client.write('ping\n')
214
# Now the server has raised the exception in its own thread
215
self.assertRaises(CantConnect, server.stop_server)
217
# The handler raises FailToRespond after reading the request, and the test
# expects pending_exception() to re-raise it. set_ready_event(sync) is a
# test_server thread hook, presumably used to signal `sync`; confirm there.
# NOTE(review): `sync` is never waited on in this copy, although upstream
# presumably waits on it before calling pending_exception().
def test_server_crash_while_responding(self):
218
sync = threading.Event()
220
class FailToRespond(Exception):
223
class FailingDuringResponseHandler(TCPConnectionHandler):
225
def handle_connection(self):
226
req = self.rfile.readline()
227
threading.currentThread().set_ready_event(sync)
228
raise FailToRespond()
230
server = self.get_server(
231
connection_handler_class=FailingDuringResponseHandler)
232
client = self.get_client()
233
client.connect((server.host, server.port))
234
client.write('ping\n')
236
self.assertRaises(FailToRespond, server.pending_exception)
238
# With CantServe registered through set_ignored_exceptions(), the final
# pending_exception() call is expected not to raise.
# NOTE(review): the handler's method header, its raise, and any sync.wait()
# are missing here. The class may also continue past this point.
def test_exception_swallowed_while_serving(self):
239
sync = threading.Event()
241
class CantServe(Exception):
244
class FailingWhileServingConnectionHandler(TCPConnectionHandler):
247
# We want to sync with the thread that is serving the
249
threading.currentThread().set_ready_event(sync)
252
server = self.get_server(
253
connection_handler_class=FailingWhileServingConnectionHandler)
254
# Install the exception swallower
255
server.set_ignored_exceptions(CantServe)
256
client = self.get_client()
257
# Connect to the server so the exception is raised there
258
client.connect((server.host, server.port))
259
# Wait for the exception to propagate.
261
# The connection wasn't served properly but the exception should have
263
server.pending_exception()