77
73
self.vendor.calls.append(('close', ))
78
self.vendor.read_from.close()
79
self.vendor.write_to.close()
81
def get_sock_or_pipes(self):
82
return 'pipes', (self.vendor.read_from, self.vendor.write_to)
75
def get_filelike_channels(self):
76
return self.vendor.read_from, self.vendor.write_to
85
79
class _InvalidHostnameFeature(tests.Feature):
249
243
unopened_port = sock.getsockname()[1]
250
244
# having vendor be invalid means that if it tries to connect via the
251
245
# vendor it will blow up.
252
ssh_params = medium.SSHParams('127.0.0.1', unopened_port, None, None)
253
client_medium = medium.SmartSSHClientMedium(
254
'base', ssh_params, "not a vendor")
246
client_medium = medium.SmartSSHClientMedium('127.0.0.1', unopened_port,
247
username=None, password=None, base='base', vendor="not a vendor",
248
bzr_remote_path='bzr')
257
251
def test_ssh_client_connects_on_first_use(self):
260
254
output = StringIO()
261
255
vendor = StringIOSSHVendor(StringIO(), output)
262
ssh_params = medium.SSHParams(
263
'a hostname', 'a port', 'a username', 'a password', 'bzr')
264
client_medium = medium.SmartSSHClientMedium('base', ssh_params, vendor)
256
client_medium = medium.SmartSSHClientMedium(
257
'a hostname', 'a port', 'a username', 'a password', 'base', vendor,
265
259
client_medium._accept_bytes('abc')
266
260
self.assertEqual('abc', output.getvalue())
267
261
self.assertEqual([('connect_ssh', 'a username', 'a password',
275
269
output = StringIO()
276
270
vendor = StringIOSSHVendor(StringIO(), output)
277
ssh_params = medium.SSHParams(
278
'a hostname', 'a port', 'a username', 'a password',
279
bzr_remote_path='fugly')
280
client_medium = medium.SmartSSHClientMedium('base', ssh_params, vendor)
271
client_medium = medium.SmartSSHClientMedium('a hostname', 'a port',
272
'a username', 'a password', 'base', vendor, bzr_remote_path='fugly')
281
273
client_medium._accept_bytes('abc')
282
274
self.assertEqual('abc', output.getvalue())
283
275
self.assertEqual([('connect_ssh', 'a username', 'a password',
292
284
output = StringIO()
293
285
vendor = StringIOSSHVendor(input, output)
294
286
client_medium = medium.SmartSSHClientMedium(
295
'base', medium.SSHParams('a hostname'), vendor)
287
'a hostname', base='base', vendor=vendor, bzr_remote_path='bzr')
296
288
client_medium._accept_bytes('abc')
297
289
client_medium.disconnect()
298
290
self.assertTrue(input.closed)
313
305
output = StringIO()
314
306
vendor = StringIOSSHVendor(input, output)
315
307
client_medium = medium.SmartSSHClientMedium(
316
'base', medium.SSHParams('a hostname'), vendor)
308
'a hostname', base='base', vendor=vendor, bzr_remote_path='bzr')
317
309
client_medium._accept_bytes('abc')
318
310
client_medium.disconnect()
319
311
# the disconnect has closed output, so we need a new output for the
342
334
# Doing a disconnect on a new (and thus unconnected) SSH medium
343
335
# does not fail. It's ok to disconnect an unconnected medium.
344
336
client_medium = medium.SmartSSHClientMedium(
345
'base', medium.SSHParams(None))
337
None, base='base', bzr_remote_path='bzr')
346
338
client_medium.disconnect()
348
340
def test_ssh_client_raises_on_read_when_not_connected(self):
349
341
# Doing a read on a new (and thus unconnected) SSH medium raises
350
342
# MediumNotConnected.
351
343
client_medium = medium.SmartSSHClientMedium(
352
'base', medium.SSHParams(None))
344
None, base='base', bzr_remote_path='bzr')
353
345
self.assertRaises(errors.MediumNotConnected, client_medium.read_bytes,
355
347
self.assertRaises(errors.MediumNotConnected, client_medium.read_bytes,
367
359
output.flush = logging_flush
368
360
vendor = StringIOSSHVendor(input, output)
369
361
client_medium = medium.SmartSSHClientMedium(
370
'base', medium.SSHParams('a hostname'), vendor=vendor)
362
'a hostname', base='base', vendor=vendor, bzr_remote_path='bzr')
371
363
# this call is here to ensure we only flush once, not on every
372
364
# _accept_bytes call.
373
365
client_medium._accept_bytes('abc')
976
968
def external_url(self):
970
def get_bytes(self, path):
979
971
raise Exception("some random exception from inside server")
981
class FlakyServer(test_server.SmartTCPServer_for_testing):
982
def get_backing_transport(self, backing_transport_server):
983
return FlakyTransport()
985
smart_server = FlakyServer()
986
smart_server.start_server()
987
self.addCleanup(smart_server.stop_server)
988
t = remote.RemoteTCPTransport(smart_server.get_url())
989
self.addCleanup(t.disconnect)
990
err = self.assertRaises(errors.UnknownErrorFromSmartServer,
992
self.assertContainsRe(str(err), 'some random exception')
972
smart_server = server.SmartTCPServer(backing_transport=FlakyTransport())
973
smart_server.start_background_thread('-' + self.id())
975
transport = remote.RemoteTCPTransport(smart_server.get_url())
976
err = self.assertRaises(errors.UnknownErrorFromSmartServer,
977
transport.get, 'something')
978
self.assertContainsRe(str(err), 'some random exception')
979
transport.disconnect()
981
smart_server.stop_background_thread()
995
984
class SmartTCPTests(tests.TestCase):
996
985
"""Tests for connection/end to end behaviour using the TCP server.
998
All of these tests are run with a server running in another thread serving
987
All of these tests are run with a server running on another thread serving
999
988
a MemoryTransport, and a connection to it already open.
1001
990
the server is obtained by calling self.start_server(readonly=False).
1009
998
# NB: Tests using this fall into two categories: tests of the server,
1010
999
# tests wanting a server. The latter should be updated to use
1011
1000
# self.vfs_transport_factory etc.
1012
if backing_transport is None:
1001
if not backing_transport:
1013
1002
mem_server = memory.MemoryServer()
1014
1003
mem_server.start_server()
1015
1004
self.addCleanup(mem_server.stop_server)
1023
1012
self.backing_transport = transport.get_transport(
1024
1013
"readonly+" + self.backing_transport.abspath('.'))
1025
1014
self.server = server.SmartTCPServer(self.backing_transport)
1026
self.server.start_server('127.0.0.1', 0)
1027
1015
self.server.start_background_thread('-' + self.id())
1028
1016
self.transport = remote.RemoteTCPTransport(self.server.get_url())
1029
self.addCleanup(self.stop_server)
1017
self.addCleanup(self.tearDownServer)
1030
1018
self.permit_url(self.server.get_url())
1032
def stop_server(self):
1033
"""Disconnect the client and stop the server.
1035
This must be re-entrant as some tests will call it explicitly in
1036
addition to the normal cleanup.
1020
def tearDownServer(self):
1038
1021
if getattr(self, 'transport', None):
1039
1022
self.transport.disconnect()
1040
1023
del self.transport
1041
1024
if getattr(self, 'server', None):
1042
1025
self.server.stop_background_thread()
1026
# XXX: why not .stop_server() -- mbp 20100106
1043
1027
del self.server
1046
1030
class TestServerSocketUsage(SmartTCPTests):
1048
def test_server_start_stop(self):
1049
"""It should be safe to stop the server with no requests."""
1032
def test_server_setup_teardown(self):
1033
"""It should be safe to teardown the server with no requests."""
1050
1034
self.start_server()
1051
t = remote.RemoteTCPTransport(self.server.get_url())
1053
self.assertRaises(errors.ConnectionError, t.has, '.')
1035
server = self.server
1036
transport = remote.RemoteTCPTransport(self.server.get_url())
1037
self.tearDownServer()
1038
self.assertRaises(errors.ConnectionError, transport.has, '.')
1055
1040
def test_server_closes_listening_sock_on_shutdown_after_request(self):
1056
1041
"""The server should close its listening socket when it's stopped."""
1057
1042
self.start_server()
1058
server_url = self.server.get_url()
1043
server = self.server
1059
1044
self.transport.has('.')
1045
self.tearDownServer()
1061
1046
# if the listening socket has closed, we should get a BADFD error
1062
1047
# when connecting, rather than a hang.
1063
t = remote.RemoteTCPTransport(server_url)
1064
self.assertRaises(errors.ConnectionError, t.has, '.')
1048
transport = remote.RemoteTCPTransport(server.get_url())
1049
self.assertRaises(errors.ConnectionError, transport.has, '.')
1067
1052
class WritableEndToEndTests(SmartTCPTests):
2874
2859
self.responder = protocol.ProtocolThreeResponder(self.writes.append)
2876
2861
def assertWriteCount(self, expected_count):
2877
# self.writes can be quite large; don't show the whole thing
2878
2862
self.assertEqual(
2879
2863
expected_count, len(self.writes),
2880
"Too many writes: %d, expected %d" % (len(self.writes), expected_count))
2864
"Too many writes: %r" % (self.writes,))
2882
2866
def test_send_error_writes_just_once(self):
2883
2867
"""An error response is written to the medium all at once."""
2906
2890
response = _mod_request.SuccessfulSmartServerResponse(
2907
2891
('arg', 'arg'), body_stream=['chunk1', 'chunk2'])
2908
2892
self.responder.send_response(response)
2909
# Per the discussion in bug 590638 we flush once after the header and
2910
# then once after each chunk
2911
self.assertWriteCount(3)
2893
# We will write just once, despite the multiple chunks, due to
2895
self.assertWriteCount(1)
2897
def test_send_response_with_body_stream_flushes_buffers_sometimes(self):
2898
"""When there are many bytes (>1MB), multiple writes will occur rather
2899
than buffering indefinitely.
2901
# Construct a response with stream with ~1.5MB in it. This should
2902
# trigger 2 writes, but not 3
2903
onekib = '12345678' * 128
2904
body_stream = [onekib] * (1024 + 512)
2905
response = _mod_request.SuccessfulSmartServerResponse(
2906
('arg', 'arg'), body_stream=body_stream)
2907
self.responder.send_response(response)
2908
self.assertWriteCount(2)
2914
2911
class TestSmartClientUnicode(tests.TestCase):