~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/tests/test_smart_transport.py

  • Committer: Martin Pool
  • Date: 2010-04-01 04:41:18 UTC
  • mto: This revision was merged to the branch mainline in revision 5128.
  • Revision ID: mbp@sourcefrog.net-20100401044118-shyctqc02ob08ngz
ignore .testrepository

Show diffs side-by-side

added added

removed removed

Lines of Context:
40
40
        server,
41
41
        vfs,
42
42
)
43
 
from bzrlib.tests import (
44
 
    test_smart,
45
 
    test_server,
46
 
    )
 
43
from bzrlib.tests import test_smart
47
44
from bzrlib.transport import (
48
45
        http,
49
46
        local,
50
47
        memory,
51
48
        remote,
52
 
        ssh,
53
49
        )
54
50
 
55
51
 
67
63
        return StringIOSSHConnection(self)
68
64
 
69
65
 
70
 
class StringIOSSHConnection(ssh.SSHConnection):
 
66
class StringIOSSHConnection(object):
71
67
    """A SSH connection that uses StringIO to buffer writes and answer reads."""
72
68
 
73
69
    def __init__(self, vendor):
75
71
 
76
72
    def close(self):
77
73
        self.vendor.calls.append(('close', ))
78
 
        self.vendor.read_from.close()
79
 
        self.vendor.write_to.close()
80
74
 
81
 
    def get_sock_or_pipes(self):
82
 
        return 'pipes', (self.vendor.read_from, self.vendor.write_to)
 
75
    def get_filelike_channels(self):
 
76
        return self.vendor.read_from, self.vendor.write_to
83
77
 
84
78
 
85
79
class _InvalidHostnameFeature(tests.Feature):
249
243
        unopened_port = sock.getsockname()[1]
250
244
        # having vendor be invalid means that if it tries to connect via the
251
245
        # vendor it will blow up.
252
 
        ssh_params = medium.SSHParams('127.0.0.1', unopened_port, None, None)
253
 
        client_medium = medium.SmartSSHClientMedium(
254
 
            'base', ssh_params, "not a vendor")
 
246
        client_medium = medium.SmartSSHClientMedium('127.0.0.1', unopened_port,
 
247
            username=None, password=None, base='base', vendor="not a vendor",
 
248
            bzr_remote_path='bzr')
255
249
        sock.close()
256
250
 
257
251
    def test_ssh_client_connects_on_first_use(self):
259
253
        # it bytes.
260
254
        output = StringIO()
261
255
        vendor = StringIOSSHVendor(StringIO(), output)
262
 
        ssh_params = medium.SSHParams(
263
 
            'a hostname', 'a port', 'a username', 'a password', 'bzr')
264
 
        client_medium = medium.SmartSSHClientMedium('base', ssh_params, vendor)
 
256
        client_medium = medium.SmartSSHClientMedium(
 
257
            'a hostname', 'a port', 'a username', 'a password', 'base', vendor,
 
258
            'bzr')
265
259
        client_medium._accept_bytes('abc')
266
260
        self.assertEqual('abc', output.getvalue())
267
261
        self.assertEqual([('connect_ssh', 'a username', 'a password',
274
268
        # it bytes.
275
269
        output = StringIO()
276
270
        vendor = StringIOSSHVendor(StringIO(), output)
277
 
        ssh_params = medium.SSHParams(
278
 
            'a hostname', 'a port', 'a username', 'a password',
279
 
            bzr_remote_path='fugly')
280
 
        client_medium = medium.SmartSSHClientMedium('base', ssh_params, vendor)
 
271
        client_medium = medium.SmartSSHClientMedium('a hostname', 'a port',
 
272
            'a username', 'a password', 'base', vendor, bzr_remote_path='fugly')
281
273
        client_medium._accept_bytes('abc')
282
274
        self.assertEqual('abc', output.getvalue())
283
275
        self.assertEqual([('connect_ssh', 'a username', 'a password',
292
284
        output = StringIO()
293
285
        vendor = StringIOSSHVendor(input, output)
294
286
        client_medium = medium.SmartSSHClientMedium(
295
 
            'base', medium.SSHParams('a hostname'), vendor)
 
287
            'a hostname', base='base', vendor=vendor, bzr_remote_path='bzr')
296
288
        client_medium._accept_bytes('abc')
297
289
        client_medium.disconnect()
298
290
        self.assertTrue(input.closed)
313
305
        output = StringIO()
314
306
        vendor = StringIOSSHVendor(input, output)
315
307
        client_medium = medium.SmartSSHClientMedium(
316
 
            'base', medium.SSHParams('a hostname'), vendor)
 
308
            'a hostname', base='base', vendor=vendor, bzr_remote_path='bzr')
317
309
        client_medium._accept_bytes('abc')
318
310
        client_medium.disconnect()
319
311
        # the disconnect has closed output, so we need a new output for the
342
334
        # Doing a disconnect on a new (and thus unconnected) SSH medium
343
335
        # does not fail.  It's ok to disconnect an unconnected medium.
344
336
        client_medium = medium.SmartSSHClientMedium(
345
 
            'base', medium.SSHParams(None))
 
337
            None, base='base', bzr_remote_path='bzr')
346
338
        client_medium.disconnect()
347
339
 
348
340
    def test_ssh_client_raises_on_read_when_not_connected(self):
349
341
        # Doing a read on a new (and thus unconnected) SSH medium raises
350
342
        # MediumNotConnected.
351
343
        client_medium = medium.SmartSSHClientMedium(
352
 
            'base', medium.SSHParams(None))
 
344
            None, base='base', bzr_remote_path='bzr')
353
345
        self.assertRaises(errors.MediumNotConnected, client_medium.read_bytes,
354
346
                          0)
355
347
        self.assertRaises(errors.MediumNotConnected, client_medium.read_bytes,
367
359
        output.flush = logging_flush
368
360
        vendor = StringIOSSHVendor(input, output)
369
361
        client_medium = medium.SmartSSHClientMedium(
370
 
            'base', medium.SSHParams('a hostname'), vendor=vendor)
 
362
            'a hostname', base='base', vendor=vendor, bzr_remote_path='bzr')
371
363
        # this call is here to ensure we only flush once, not on every
372
364
        # _accept_bytes call.
373
365
        client_medium._accept_bytes('abc')
975
967
            base = 'a_url'
976
968
            def external_url(self):
977
969
                return self.base
978
 
            def get(self, path):
 
970
            def get_bytes(self, path):
979
971
                raise Exception("some random exception from inside server")
980
 
 
981
 
        class FlakyServer(test_server.SmartTCPServer_for_testing):
982
 
            def get_backing_transport(self, backing_transport_server):
983
 
                return FlakyTransport()
984
 
 
985
 
        smart_server = FlakyServer()
986
 
        smart_server.start_server()
987
 
        self.addCleanup(smart_server.stop_server)
988
 
        t = remote.RemoteTCPTransport(smart_server.get_url())
989
 
        self.addCleanup(t.disconnect)
990
 
        err = self.assertRaises(errors.UnknownErrorFromSmartServer,
991
 
                                t.get, 'something')
992
 
        self.assertContainsRe(str(err), 'some random exception')
 
972
        smart_server = server.SmartTCPServer(backing_transport=FlakyTransport())
 
973
        smart_server.start_background_thread('-' + self.id())
 
974
        try:
 
975
            transport = remote.RemoteTCPTransport(smart_server.get_url())
 
976
            err = self.assertRaises(errors.UnknownErrorFromSmartServer,
 
977
                transport.get, 'something')
 
978
            self.assertContainsRe(str(err), 'some random exception')
 
979
            transport.disconnect()
 
980
        finally:
 
981
            smart_server.stop_background_thread()
993
982
 
994
983
 
995
984
class SmartTCPTests(tests.TestCase):
996
985
    """Tests for connection/end to end behaviour using the TCP server.
997
986
 
998
 
    All of these tests are run with a server running in another thread serving
 
987
    All of these tests are run with a server running on another thread serving
999
988
    a MemoryTransport, and a connection to it already open.
1000
989
 
1001
990
    the server is obtained by calling self.start_server(readonly=False).
1009
998
        # NB: Tests using this fall into two categories: tests of the server,
1010
999
        # tests wanting a server. The latter should be updated to use
1011
1000
        # self.vfs_transport_factory etc.
1012
 
        if backing_transport is None:
 
1001
        if not backing_transport:
1013
1002
            mem_server = memory.MemoryServer()
1014
1003
            mem_server.start_server()
1015
1004
            self.addCleanup(mem_server.stop_server)
1023
1012
            self.backing_transport = transport.get_transport(
1024
1013
                "readonly+" + self.backing_transport.abspath('.'))
1025
1014
        self.server = server.SmartTCPServer(self.backing_transport)
1026
 
        self.server.start_server('127.0.0.1', 0)
1027
1015
        self.server.start_background_thread('-' + self.id())
1028
1016
        self.transport = remote.RemoteTCPTransport(self.server.get_url())
1029
 
        self.addCleanup(self.stop_server)
 
1017
        self.addCleanup(self.tearDownServer)
1030
1018
        self.permit_url(self.server.get_url())
1031
1019
 
1032
 
    def stop_server(self):
1033
 
        """Disconnect the client and stop the server.
1034
 
 
1035
 
        This must be re-entrant as some tests will call it explicitly in
1036
 
        addition to the normal cleanup.
1037
 
        """
 
1020
    def tearDownServer(self):
1038
1021
        if getattr(self, 'transport', None):
1039
1022
            self.transport.disconnect()
1040
1023
            del self.transport
1041
1024
        if getattr(self, 'server', None):
1042
1025
            self.server.stop_background_thread()
 
1026
            # XXX: why not .stop_server() -- mbp 20100106
1043
1027
            del self.server
1044
1028
 
1045
1029
 
1046
1030
class TestServerSocketUsage(SmartTCPTests):
1047
1031
 
1048
 
    def test_server_start_stop(self):
1049
 
        """It should be safe to stop the server with no requests."""
 
1032
    def test_server_setup_teardown(self):
 
1033
        """It should be safe to teardown the server with no requests."""
1050
1034
        self.start_server()
1051
 
        t = remote.RemoteTCPTransport(self.server.get_url())
1052
 
        self.stop_server()
1053
 
        self.assertRaises(errors.ConnectionError, t.has, '.')
 
1035
        server = self.server
 
1036
        transport = remote.RemoteTCPTransport(self.server.get_url())
 
1037
        self.tearDownServer()
 
1038
        self.assertRaises(errors.ConnectionError, transport.has, '.')
1054
1039
 
1055
1040
    def test_server_closes_listening_sock_on_shutdown_after_request(self):
1056
1041
        """The server should close its listening socket when it's stopped."""
1057
1042
        self.start_server()
1058
 
        server_url = self.server.get_url()
 
1043
        server = self.server
1059
1044
        self.transport.has('.')
1060
 
        self.stop_server()
 
1045
        self.tearDownServer()
1061
1046
        # if the listening socket has closed, we should get a BADFD error
1062
1047
        # when connecting, rather than a hang.
1063
 
        t = remote.RemoteTCPTransport(server_url)
1064
 
        self.assertRaises(errors.ConnectionError, t.has, '.')
 
1048
        transport = remote.RemoteTCPTransport(server.get_url())
 
1049
        self.assertRaises(errors.ConnectionError, transport.has, '.')
1065
1050
 
1066
1051
 
1067
1052
class WritableEndToEndTests(SmartTCPTests):
1202
1187
        self.transport.has('.')
1203
1188
        self.assertEqual([], self.hook_calls)
1204
1189
        # clean up the server
1205
 
        self.stop_server()
 
1190
        self.tearDownServer()
1206
1191
        # now it should have fired.
1207
1192
        self.assertEqual(result, self.hook_calls)
1208
1193
 
1221
1206
        self.transport.has('.')
1222
1207
        self.assertEqual([], self.hook_calls)
1223
1208
        # clean up the server
1224
 
        self.stop_server()
 
1209
        self.tearDownServer()
1225
1210
        # now it should have fired.
1226
1211
        self.assertEqual(result, self.hook_calls)
1227
1212
 
2874
2859
        self.responder = protocol.ProtocolThreeResponder(self.writes.append)
2875
2860
 
2876
2861
    def assertWriteCount(self, expected_count):
2877
 
        # self.writes can be quite large; don't show the whole thing
2878
2862
        self.assertEqual(
2879
2863
            expected_count, len(self.writes),
2880
 
            "Too many writes: %d, expected %d" % (len(self.writes), expected_count))
 
2864
            "Too many writes: %r" % (self.writes,))
2881
2865
 
2882
2866
    def test_send_error_writes_just_once(self):
2883
2867
        """An error response is written to the medium all at once."""
2906
2890
        response = _mod_request.SuccessfulSmartServerResponse(
2907
2891
            ('arg', 'arg'), body_stream=['chunk1', 'chunk2'])
2908
2892
        self.responder.send_response(response)
2909
 
        # Per the discussion in bug 590638 we flush once after the header and
2910
 
        # then once after each chunk
2911
 
        self.assertWriteCount(3)
 
2893
        # We will write just once, despite the multiple chunks, due to
 
2894
        # buffering.
 
2895
        self.assertWriteCount(1)
 
2896
 
 
2897
    def test_send_response_with_body_stream_flushes_buffers_sometimes(self):
 
2898
        """When there are many bytes (>1MB), multiple writes will occur rather
 
2899
        than buffering indefinitely.
 
2900
        """
 
2901
        # Construct a response with stream with ~1.5MB in it. This should
 
2902
        # trigger 2 writes, but not 3
 
2903
        onekib = '12345678' * 128
 
2904
        body_stream = [onekib] * (1024 + 512)
 
2905
        response = _mod_request.SuccessfulSmartServerResponse(
 
2906
            ('arg', 'arg'), body_stream=body_stream)
 
2907
        self.responder.send_response(response)
 
2908
        self.assertWriteCount(2)
2912
2909
 
2913
2910
 
2914
2911
class TestSmartClientUnicode(tests.TestCase):