618
638
super(TestSmartServerStreamMedium, self).setUp()
619
639
self.overrideEnv('BZR_NO_SMART_VFS', None)
621
def portable_socket_pair(self):
622
"""Return a pair of TCP sockets connected to each other.
624
Unlike socket.socketpair, this should work on Windows.
626
listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
627
listen_sock.bind(('127.0.0.1', 0))
628
listen_sock.listen(1)
629
client_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
630
client_sock.connect(listen_sock.getsockname())
631
server_sock, addr = listen_sock.accept()
633
return server_sock, client_sock
641
def create_pipe_medium(self, to_server, from_server, transport,
643
"""Create a new SmartServerPipeStreamMedium."""
644
return medium.SmartServerPipeStreamMedium(to_server, from_server,
645
transport, timeout=timeout)
647
def create_pipe_context(self, to_server_bytes, transport):
648
"""Create a SmartServerSocketStreamMedium.
650
This differes from create_pipe_medium, in that we initialize the
651
request that is sent to the server, and return the StringIO class that
652
will hold the response.
654
to_server = StringIO(to_server_bytes)
655
from_server = StringIO()
656
m = self.create_pipe_medium(to_server, from_server, transport)
657
return m, from_server
659
def create_socket_medium(self, server_sock, transport, timeout=4.0):
660
"""Initialize a new medium.SmartServerSocketStreamMedium."""
661
return medium.SmartServerSocketStreamMedium(server_sock, transport,
664
def create_socket_context(self, transport, timeout=4.0):
665
"""Create a new SmartServerSocketStreamMedium with default context.
667
This will call portable_socket_pair and pass the server side to
668
create_socket_medium along with transport.
669
It then returns the client_sock and the server.
671
server_sock, client_sock = portable_socket_pair()
672
server = self.create_socket_medium(server_sock, transport,
674
return server, client_sock
635
676
def test_smart_query_version(self):
636
677
"""Feed a canned query version to a server"""
637
678
# wire-to-wire, using the whole stack
638
to_server = StringIO('hello\n')
639
from_server = StringIO()
640
679
transport = local.LocalTransport(urlutils.local_path_to_url('/'))
641
server = medium.SmartServerPipeStreamMedium(
642
to_server, from_server, transport)
680
server, from_server = self.create_pipe_context('hello\n', transport)
643
681
smart_protocol = protocol.SmartServerRequestProtocolOne(transport,
644
682
from_server.write)
645
683
server._serve_one_request(smart_protocol)
697
729
def test_socket_stream_with_bulk_data(self):
698
730
sample_request_bytes = 'command\n9\nbulk datadone\n'
699
server_sock, client_sock = self.portable_socket_pair()
700
server = medium.SmartServerSocketStreamMedium(
731
server, client_sock = self.create_socket_context(None)
702
732
sample_protocol = SampleRequest(expected_bytes=sample_request_bytes)
703
733
client_sock.sendall(sample_request_bytes)
704
734
server._serve_one_request(sample_protocol)
735
server._disconnect_client()
706
736
self.assertEqual('', client_sock.recv(1))
707
737
self.assertEqual(sample_request_bytes, sample_protocol.accepted_bytes)
708
738
self.assertFalse(server.finished)
710
740
def test_pipe_like_stream_shutdown_detection(self):
711
to_server = StringIO('')
712
from_server = StringIO()
713
server = medium.SmartServerPipeStreamMedium(to_server, from_server, None)
741
server, _ = self.create_pipe_context('', None)
714
742
server._serve_one_request(SampleRequest('x'))
715
743
self.assertTrue(server.finished)
717
745
def test_socket_stream_shutdown_detection(self):
718
server_sock, client_sock = self.portable_socket_pair()
746
server, client_sock = self.create_socket_context(None)
719
747
client_sock.close()
720
server = medium.SmartServerSocketStreamMedium(
722
748
server._serve_one_request(SampleRequest('x'))
723
749
self.assertTrue(server.finished)
859
876
self.assertTrue(server.finished)
861
878
def test_pipe_like_stream_keyboard_interrupt_handling(self):
862
to_server = StringIO('')
863
from_server = StringIO()
864
server = medium.SmartServerPipeStreamMedium(
865
to_server, from_server, None)
879
server, from_server = self.create_pipe_context('', None)
866
880
fake_protocol = ErrorRaisingProtocol(KeyboardInterrupt('boom'))
867
881
self.assertRaises(
868
882
KeyboardInterrupt, server._serve_one_request, fake_protocol)
869
883
self.assertEqual('', from_server.getvalue())
871
885
def test_socket_stream_keyboard_interrupt_handling(self):
872
server_sock, client_sock = self.portable_socket_pair()
873
server = medium.SmartServerSocketStreamMedium(
886
server, client_sock = self.create_socket_context(None)
875
887
fake_protocol = ErrorRaisingProtocol(KeyboardInterrupt('boom'))
876
888
self.assertRaises(
877
889
KeyboardInterrupt, server._serve_one_request, fake_protocol)
890
server._disconnect_client()
879
891
self.assertEqual('', client_sock.recv(1))
881
893
def build_protocol_pipe_like(self, bytes):
882
to_server = StringIO(bytes)
883
from_server = StringIO()
884
server = medium.SmartServerPipeStreamMedium(
885
to_server, from_server, None)
894
server, _ = self.create_pipe_context(bytes, None)
886
895
return server._build_protocol()
888
897
def build_protocol_socket(self, bytes):
889
server_sock, client_sock = self.portable_socket_pair()
890
server = medium.SmartServerSocketStreamMedium(
898
server, client_sock = self.create_socket_context(None)
892
899
client_sock.sendall(bytes)
893
900
client_sock.close()
894
901
return server._build_protocol()
934
941
server_protocol = self.build_protocol_socket('bzr request 2\n')
935
942
self.assertProtocolTwo(server_protocol)
944
def test__build_protocol_returns_if_stopping(self):
945
# _build_protocol should notice that we are stopping, and return
946
# without waiting for bytes from the client.
947
server, client_sock = self.create_socket_context(None)
948
server._stop_gracefully()
949
self.assertIs(None, server._build_protocol())
951
def test_socket_set_timeout(self):
952
server, _ = self.create_socket_context(None, timeout=1.23)
953
self.assertEqual(1.23, server._client_timeout)
955
def test_pipe_set_timeout(self):
956
server = self.create_pipe_medium(None, None, None,
958
self.assertEqual(1.23, server._client_timeout)
960
def test_socket_wait_for_bytes_with_timeout_with_data(self):
961
server, client_sock = self.create_socket_context(None)
962
client_sock.sendall('data\n')
963
# This should not block or consume any actual content
964
self.assertFalse(server._wait_for_bytes_with_timeout(0.1))
965
data = server.read_bytes(5)
966
self.assertEqual('data\n', data)
968
def test_socket_wait_for_bytes_with_timeout_no_data(self):
969
server, client_sock = self.create_socket_context(None)
970
# This should timeout quickly, reporting that there wasn't any data
971
self.assertRaises(errors.ConnectionTimeout,
972
server._wait_for_bytes_with_timeout, 0.01)
974
data = server.read_bytes(1)
975
self.assertEqual('', data)
977
def test_socket_wait_for_bytes_with_timeout_closed(self):
978
server, client_sock = self.create_socket_context(None)
979
# With the socket closed, this should return right away.
980
# It seems select.select() returns that you *can* read on the socket,
981
# even though it closed. Presumably as a way to tell it is closed?
982
# Testing shows that without sock.close() this times-out failing the
983
# test, but with it, it returns False immediately.
985
self.assertFalse(server._wait_for_bytes_with_timeout(10))
986
data = server.read_bytes(1)
987
self.assertEqual('', data)
989
def test_socket_wait_for_bytes_with_shutdown(self):
990
server, client_sock = self.create_socket_context(None)
992
# Override the _timer functionality, so that time never increments,
993
# this way, we can be sure we stopped because of the flag, and not
994
# because of a timeout, etc.
995
server._timer = lambda: t
996
server._client_poll_timeout = 0.1
997
server._stop_gracefully()
998
server._wait_for_bytes_with_timeout(1.0)
1000
def test_socket_serve_timeout_closes_socket(self):
1001
server, client_sock = self.create_socket_context(None, timeout=0.1)
1002
# This should timeout quickly, and then close the connection so that
1003
# client_sock recv doesn't block.
1005
self.assertEqual('', client_sock.recv(1))
1007
def test_pipe_wait_for_bytes_with_timeout_with_data(self):
1008
# We intentionally use a real pipe here, so that we can 'select' on it.
1009
# You can't select() on a StringIO
1010
(r_server, w_client) = os.pipe()
1011
self.addCleanup(os.close, w_client)
1012
with os.fdopen(r_server, 'rb') as rf_server:
1013
server = self.create_pipe_medium(
1014
rf_server, None, None)
1015
os.write(w_client, 'data\n')
1016
# This should not block or consume any actual content
1017
server._wait_for_bytes_with_timeout(0.1)
1018
data = server.read_bytes(5)
1019
self.assertEqual('data\n', data)
1021
def test_pipe_wait_for_bytes_with_timeout_no_data(self):
1022
# We intentionally use a real pipe here, so that we can 'select' on it.
1023
# You can't select() on a StringIO
1024
(r_server, w_client) = os.pipe()
1025
# We can't add an os.close cleanup here, because we need to control
1026
# when the file handle gets closed ourselves.
1027
with os.fdopen(r_server, 'rb') as rf_server:
1028
server = self.create_pipe_medium(
1029
rf_server, None, None)
1030
if sys.platform == 'win32':
1031
# Windows cannot select() on a pipe, so we just always return
1032
server._wait_for_bytes_with_timeout(0.01)
1034
self.assertRaises(errors.ConnectionTimeout,
1035
server._wait_for_bytes_with_timeout, 0.01)
1037
data = server.read_bytes(5)
1038
self.assertEqual('', data)
1040
def test_pipe_wait_for_bytes_no_fileno(self):
1041
server, _ = self.create_pipe_context('', None)
1042
# Our file doesn't support polling, so we should always just return
1043
# 'you have data to consume.
1044
server._wait_for_bytes_with_timeout(0.01)
938
1047
class TestGetProtocolFactoryForBytes(tests.TestCase):
939
1048
"""_get_protocol_factory_for_bytes identifies the protocol factory a server
970
1079
class TestSmartTCPServer(tests.TestCase):
1081
def make_server(self):
1082
"""Create a SmartTCPServer that we can exercise.
1084
Note: we don't use SmartTCPServer_for_testing because the testing
1085
version overrides lots of functionality like 'serve', and we want to
1086
test the raw service.
1088
This will start the server in another thread, and wait for it to
1089
indicate it has finished starting up.
1091
:return: (server, server_thread)
1093
t = _mod_transport.get_transport_from_url('memory:///')
1094
server = _mod_server.SmartTCPServer(t, client_timeout=4.0)
1095
server._ACCEPT_TIMEOUT = 0.1
1096
# We don't use 'localhost' because that might be an IPv6 address.
1097
server.start_server('127.0.0.1', 0)
1098
server_thread = threading.Thread(target=server.serve,
1100
server_thread.start()
1101
# Ensure this gets called at some point
1102
self.addCleanup(server._stop_gracefully)
1103
server._started.wait()
1104
return server, server_thread
1106
def ensure_client_disconnected(self, client_sock):
1107
"""Ensure that a socket is closed, discarding all errors."""
1113
def connect_to_server(self, server):
1114
"""Create a client socket that can talk to the server."""
1115
client_sock = socket.socket()
1116
server_info = server._server_socket.getsockname()
1117
client_sock.connect(server_info)
1118
self.addCleanup(self.ensure_client_disconnected, client_sock)
1121
def connect_to_server_and_hangup(self, server):
1122
"""Connect to the server, and then hang up.
1123
That way it doesn't sit waiting for 'accept()' to timeout.
1125
# If the server has already signaled that the socket is closed, we
1126
# don't need to try to connect to it. Not being set, though, the server
1127
# might still close the socket while we try to connect to it. So we
1128
# still have to catch the exception.
1129
if server._stopped.isSet():
1132
client_sock = self.connect_to_server(server)
1134
except socket.error, e:
1135
# If the server has hung up already, that is fine.
1138
def say_hello(self, client_sock):
1139
"""Send the 'hello' smart RPC, and expect the response."""
1140
client_sock.send('hello\n')
1141
self.assertEqual('ok\x012\n', client_sock.recv(5))
1143
def shutdown_server_cleanly(self, server, server_thread):
1144
server._stop_gracefully()
1145
self.connect_to_server_and_hangup(server)
1146
server._stopped.wait()
1147
server._fully_stopped.wait()
1148
server_thread.join()
972
1150
def test_get_error_unexpected(self):
973
1151
"""Error reported by server with no specific representation"""
974
1152
self.overrideEnv('BZR_NO_SMART_VFS', None)
992
1170
t.get, 'something')
993
1171
self.assertContainsRe(str(err), 'some random exception')
1173
def test_propagates_timeout(self):
1174
server = _mod_server.SmartTCPServer(None, client_timeout=1.23)
1175
server_sock, client_sock = portable_socket_pair()
1176
handler = server._make_handler(server_sock)
1177
self.assertEqual(1.23, handler._client_timeout)
1179
def test_serve_conn_tracks_connections(self):
1180
server = _mod_server.SmartTCPServer(None, client_timeout=4.0)
1181
server_sock, client_sock = portable_socket_pair()
1182
server.serve_conn(server_sock, '-%s' % (self.id(),))
1183
self.assertEqual(1, len(server._active_connections))
1184
# We still want to talk on the connection. Polling should indicate it
1186
server._poll_active_connections()
1187
self.assertEqual(1, len(server._active_connections))
1188
# Closing the socket will end the active thread, and polling will
1189
# notice and remove it from the active set.
1191
server._poll_active_connections(0.1)
1192
self.assertEqual(0, len(server._active_connections))
1194
def test_serve_closes_out_finished_connections(self):
1195
server, server_thread = self.make_server()
1196
# The server is started, connect to it.
1197
client_sock = self.connect_to_server(server)
1198
# We send and receive on the connection, so that we know the
1199
# server-side has seen the connect, and started handling the
1201
self.say_hello(client_sock)
1202
self.assertEqual(1, len(server._active_connections))
1203
# Grab a handle to the thread that is processing our request
1204
_, server_side_thread = server._active_connections[0]
1205
# Close the connection, ask the server to stop, and wait for the
1206
# server to stop, as well as the thread that was servicing the
1209
# Wait for the server-side request thread to notice we are closed.
1210
server_side_thread.join()
1211
# Stop the server, it should notice the connection has finished.
1212
self.shutdown_server_cleanly(server, server_thread)
1213
# The server should have noticed that all clients are gone before
1215
self.assertEqual(0, len(server._active_connections))
1217
def test_serve_reaps_finished_connections(self):
1218
server, server_thread = self.make_server()
1219
client_sock1 = self.connect_to_server(server)
1220
# We send and receive on the connection, so that we know the
1221
# server-side has seen the connect, and started handling the
1223
self.say_hello(client_sock1)
1224
server_handler1, server_side_thread1 = server._active_connections[0]
1225
client_sock1.close()
1226
server_side_thread1.join()
1227
# By waiting until the first connection is fully done, the server
1228
# should notice after another connection that the first has finished.
1229
client_sock2 = self.connect_to_server(server)
1230
self.say_hello(client_sock2)
1231
server_handler2, server_side_thread2 = server._active_connections[-1]
1232
# There is a race condition. We know that client_sock2 has been
1233
# registered, but not that _poll_active_connections has been called. We
1234
# know that it will be called before the server will accept a new
1235
# connection, however. So connect one more time, and assert that we
1236
# either have 1 or 2 active connections (never 3), and that the 'first'
1237
# connection is not connection 1
1238
client_sock3 = self.connect_to_server(server)
1239
self.say_hello(client_sock3)
1240
# Copy the list, so we don't have it mutating behind our back
1241
conns = list(server._active_connections)
1242
self.assertEqual(2, len(conns))
1243
self.assertNotEqual((server_handler1, server_side_thread1), conns[0])
1244
self.assertEqual((server_handler2, server_side_thread2), conns[0])
1245
client_sock2.close()
1246
client_sock3.close()
1247
self.shutdown_server_cleanly(server, server_thread)
1249
def test_graceful_shutdown_waits_for_clients_to_stop(self):
1250
server, server_thread = self.make_server()
1251
# We need something big enough that it won't fit in a single recv. So
1252
# the server thread gets blocked writing content to the client until we
1253
# finish reading on the client.
1254
server.backing_transport.put_bytes('bigfile',
1256
client_sock = self.connect_to_server(server)
1257
self.say_hello(client_sock)
1258
_, server_side_thread = server._active_connections[0]
1259
# Start the RPC, but don't finish reading the response
1260
client_medium = medium.SmartClientAlreadyConnectedSocketMedium(
1261
'base', client_sock)
1262
client_client = client._SmartClient(client_medium)
1263
resp, response_handler = client_client.call_expecting_body('get',
1265
self.assertEqual(('ok',), resp)
1266
# Ask the server to stop gracefully, and wait for it.
1267
server._stop_gracefully()
1268
self.connect_to_server_and_hangup(server)
1269
server._stopped.wait()
1270
# It should not be accepting another connection.
1271
self.assertRaises(socket.error, self.connect_to_server, server)
1272
# It should also not be fully stopped
1273
server._fully_stopped.wait(0.01)
1274
self.assertFalse(server._fully_stopped.isSet())
1275
response_handler.read_body_bytes()
1277
server_side_thread.join()
1278
server_thread.join()
1279
self.assertTrue(server._fully_stopped.isSet())
1280
log = self.get_log()
1281
self.assertThat(log, DocTestMatches("""\
1282
INFO Requested to stop gracefully
1283
... Stopping SmartServerSocketStreamMedium(client=('127.0.0.1', ...
1284
INFO Waiting for 1 client(s) to finish
1285
""", flags=doctest.ELLIPSIS|doctest.REPORT_UDIFF))
1287
def test_stop_gracefully_tells_handlers_to_stop(self):
1288
server, server_thread = self.make_server()
1289
client_sock = self.connect_to_server(server)
1290
self.say_hello(client_sock)
1291
server_handler, server_side_thread = server._active_connections[0]
1292
self.assertFalse(server_handler.finished)
1293
server._stop_gracefully()
1294
self.assertTrue(server_handler.finished)
1296
self.connect_to_server_and_hangup(server)
1297
server_thread.join()
996
1300
class SmartTCPTests(tests.TestCase):
997
1301
"""Tests for connection/end to end behaviour using the TCP server.