~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/tests/test_smart_transport.py

  • Committer: John Arbash Meinel
  • Date: 2007-04-28 15:04:17 UTC
  • mfrom: (2466 +trunk)
  • mto: This revision was merged to the branch mainline in revision 2566.
  • Revision ID: john@arbash-meinel.com-20070428150417-trp3pi0pzd411pu4
[merge] bzr.dev 2466

Show diffs side-by-side

added added

removed removed

Lines of Context:
31
31
        urlutils,
32
32
        )
33
33
from bzrlib.smart import (
 
34
        client,
34
35
        medium,
35
36
        protocol,
36
37
        request,
 
38
        request as _mod_request,
37
39
        server,
38
40
        vfs,
39
41
)
41
43
        HTTPServerWithSmarts,
42
44
        SmartRequestHandler,
43
45
        )
 
46
from bzrlib.tests.test_smart import TestCaseWithSmartMedium
44
47
from bzrlib.transport import (
45
48
        get_transport,
46
49
        local,
519
522
        self.assertRaises(errors.ReadingCompleted, request.read_bytes, None)
520
523
 
521
524
 
522
 
class RemoteTransportTests(tests.TestCaseWithTransport):
523
 
 
524
 
    def setUp(self):
525
 
        super(RemoteTransportTests, self).setUp()
526
 
        # We're allowed to set  the transport class here, so that we don't use
527
 
        # the default or a parameterized class, but rather use the
528
 
        # TestCaseWithTransport infrastructure to set up a smart server and
529
 
        # transport.
530
 
        self.transport_server = server.SmartTCPServer_for_testing
 
525
class RemoteTransportTests(TestCaseWithSmartMedium):
531
526
 
532
527
    def test_plausible_url(self):
533
528
        self.assert_(self.get_url().startswith('bzr://'))
534
529
 
535
530
    def test_probe_transport(self):
536
531
        t = self.get_transport()
537
 
        self.assertIsInstance(t, remote.SmartTransport)
 
532
        self.assertIsInstance(t, remote.RemoteTransport)
538
533
 
539
534
    def test_get_medium_from_transport(self):
540
535
        """Remote transport has a medium always, which it can return."""
541
536
        t = self.get_transport()
542
 
        smart_medium = t.get_smart_medium()
543
 
        self.assertIsInstance(smart_medium, medium.SmartClientMedium)
 
537
        client_medium = t.get_smart_medium()
 
538
        self.assertIsInstance(client_medium, medium.SmartClientMedium)
544
539
 
545
540
 
546
541
class ErrorRaisingProtocol(object):
604
599
        smart_protocol = protocol.SmartServerRequestProtocolOne(transport,
605
600
                from_server.write)
606
601
        server._serve_one_request(smart_protocol)
607
 
        self.assertEqual('ok\0011\n',
 
602
        self.assertEqual('ok\0012\n',
608
603
                         from_server.getvalue())
609
604
 
610
605
    def test_response_to_canned_get(self):
745
740
        self.assertTrue(server.finished)
746
741
        
747
742
    def test_socket_stream_error_handling(self):
748
 
        # Use plain python StringIO so we can monkey-patch the close method to
749
 
        # not discard the contents.
750
 
        from StringIO import StringIO
751
743
        server_sock, client_sock = self.portable_socket_pair()
752
744
        server = medium.SmartServerSocketStreamMedium(
753
745
            server_sock, None)
759
751
        self.assertTrue(server.finished)
760
752
        
761
753
    def test_pipe_like_stream_keyboard_interrupt_handling(self):
762
 
        # Use plain python StringIO so we can monkey-patch the close method to
763
 
        # not discard the contents.
764
754
        to_server = StringIO('')
765
755
        from_server = StringIO()
766
756
        server = medium.SmartServerPipeStreamMedium(
779
769
            KeyboardInterrupt, server._serve_one_request, fake_protocol)
780
770
        server_sock.close()
781
771
        self.assertEqual('', client_sock.recv(1))
 
772
 
 
773
    def build_protocol_pipe_like(self, bytes):
 
774
        to_server = StringIO(bytes)
 
775
        from_server = StringIO()
 
776
        server = medium.SmartServerPipeStreamMedium(
 
777
            to_server, from_server, None)
 
778
        return server._build_protocol()
 
779
 
 
780
    def build_protocol_socket(self, bytes):
 
781
        server_sock, client_sock = self.portable_socket_pair()
 
782
        server = medium.SmartServerSocketStreamMedium(
 
783
            server_sock, None)
 
784
        client_sock.sendall(bytes)
 
785
        client_sock.close()
 
786
        return server._build_protocol()
 
787
 
 
788
    def assertProtocolOne(self, server_protocol):
 
789
        # Use assertIs because assertIsInstance will wrongly pass
 
790
        # SmartServerRequestProtocolTwo (because it subclasses
 
791
        # SmartServerRequestProtocolOne).
 
792
        self.assertIs(
 
793
            type(server_protocol), protocol.SmartServerRequestProtocolOne)
 
794
 
 
795
    def assertProtocolTwo(self, server_protocol):
 
796
        self.assertIsInstance(
 
797
            server_protocol, protocol.SmartServerRequestProtocolTwo)
 
798
 
 
799
    def test_pipe_like_build_protocol_empty_bytes(self):
 
800
        # Any empty request (i.e. no bytes) is detected as protocol version one.
 
801
        server_protocol = self.build_protocol_pipe_like('')
 
802
        self.assertProtocolOne(server_protocol)
 
803
        
 
804
    def test_socket_like_build_protocol_empty_bytes(self):
 
805
        # Any empty request (i.e. no bytes) is detected as protocol version one.
 
806
        server_protocol = self.build_protocol_socket('')
 
807
        self.assertProtocolOne(server_protocol)
 
808
 
 
809
    def test_pipe_like_build_protocol_non_two(self):
 
810
        # A request that doesn't start with "bzr request 2\n" is version one.
 
811
        server_protocol = self.build_protocol_pipe_like('abc\n')
 
812
        self.assertProtocolOne(server_protocol)
 
813
 
 
814
    def test_socket_build_protocol_non_two(self):
 
815
        # A request that doesn't start with "bzr request 2\n" is version one.
 
816
        server_protocol = self.build_protocol_socket('abc\n')
 
817
        self.assertProtocolOne(server_protocol)
 
818
 
 
819
    def test_pipe_like_build_protocol_two(self):
 
820
        # A request that starts with "bzr request 2\n" is version two.
 
821
        server_protocol = self.build_protocol_pipe_like('bzr request 2\n')
 
822
        self.assertProtocolTwo(server_protocol)
 
823
 
 
824
    def test_socket_build_protocol_two(self):
 
825
        # A request that starts with "bzr request 2\n" is version two.
 
826
        server_protocol = self.build_protocol_socket('bzr request 2\n')
 
827
        self.assertProtocolTwo(server_protocol)
782
828
        
783
829
 
784
830
class TestSmartTCPServer(tests.TestCase):
793
839
        smart_server = server.SmartTCPServer(backing_transport=FlakyTransport())
794
840
        smart_server.start_background_thread()
795
841
        try:
796
 
            transport = remote.SmartTCPTransport(smart_server.get_url())
 
842
            transport = remote.RemoteTCPTransport(smart_server.get_url())
797
843
            try:
798
844
                transport.get('something')
799
845
            except errors.TransportError, e:
800
846
                self.assertContainsRe(str(e), 'some random exception')
801
847
            else:
802
848
                self.fail("get did not raise expected error")
 
849
            transport.disconnect()
803
850
        finally:
804
851
            smart_server.stop_background_thread()
805
852
 
824
871
            self.backing_transport = get_transport("readonly+" + self.backing_transport.abspath('.'))
825
872
        self.server = server.SmartTCPServer(self.backing_transport)
826
873
        self.server.start_background_thread()
827
 
        self.transport = remote.SmartTCPTransport(self.server.get_url())
 
874
        self.transport = remote.RemoteTCPTransport(self.server.get_url())
828
875
        self.addCleanup(self.tearDownServer)
829
876
 
830
877
    def tearDownServer(self):
842
889
        """It should be safe to teardown the server with no requests."""
843
890
        self.setUpServer()
844
891
        server = self.server
845
 
        transport = remote.SmartTCPTransport(self.server.get_url())
 
892
        transport = remote.RemoteTCPTransport(self.server.get_url())
846
893
        self.tearDownServer()
847
894
        self.assertRaises(errors.ConnectionError, transport.has, '.')
848
895
 
854
901
        self.tearDownServer()
855
902
        # if the listening socket has closed, we should get a BADFD error
856
903
        # when connecting, rather than a hang.
857
 
        transport = remote.SmartTCPTransport(server.get_url())
 
904
        transport = remote.RemoteTCPTransport(server.get_url())
858
905
        self.assertRaises(errors.ConnectionError, transport.has, '.')
859
906
 
860
907
 
991
1038
class SmartServerCommandTests(tests.TestCaseWithTransport):
992
1039
    """Tests that call directly into the command objects, bypassing the network
993
1040
    and the request dispatching.
 
1041
 
 
1042
    Note: these tests are rudimentary versions of the command object tests in
 
1043
    test_remote.py.
994
1044
    """
995
1045
        
996
1046
    def test_hello(self):
997
1047
        cmd = request.HelloRequest(None)
998
1048
        response = cmd.execute()
999
 
        self.assertEqual(('ok', '1'), response.args)
 
1049
        self.assertEqual(('ok', '2'), response.args)
1000
1050
        self.assertEqual(None, response.body)
1001
1051
        
1002
1052
    def test_get_bundle(self):
1021
1071
 
1022
1072
    def build_handler(self, transport):
1023
1073
        """Returns a handler for the commands in protocol version one."""
1024
 
        return request.SmartServerRequestHandler(transport, request.request_handlers)
 
1074
        return request.SmartServerRequestHandler(transport,
 
1075
                                                 request.request_handlers)
1025
1076
 
1026
1077
    def test_construct_request_handler(self):
1027
1078
        """Constructing a request handler should be easy and set defaults."""
1031
1082
    def test_hello(self):
1032
1083
        handler = self.build_handler(None)
1033
1084
        handler.dispatch_command('hello', ())
1034
 
        self.assertEqual(('ok', '1'), handler.response.args)
 
1085
        self.assertEqual(('ok', '2'), handler.response.args)
1035
1086
        self.assertEqual(None, handler.response.body)
1036
1087
        
1037
1088
    def test_disable_vfs_handler_classes_via_environment(self):
1038
 
        # VFS handler classes will raise an error from "execute" if BZR_NO_SMART_VFS
1039
 
        # is set.
 
1089
        # VFS handler classes will raise an error from "execute" if
 
1090
        # BZR_NO_SMART_VFS is set.
1040
1091
        handler = vfs.HasRequest(None)
1041
1092
        # set environment variable after construction to make sure it's
1042
1093
        # examined.
1043
 
        # Note that we can safely clobber BZR_NO_SMART_VFS here, because setUp has
1044
 
        # called _captureVar, so it will be restored to the right state
 
1094
        # Note that we can safely clobber BZR_NO_SMART_VFS here, because setUp
 
1095
        # has called _captureVar, so it will be restored to the right state
1045
1096
        # afterwards.
1046
1097
        os.environ['BZR_NO_SMART_VFS'] = ''
1047
1098
        self.assertRaises(errors.DisabledMethod, handler.execute)
1111
1162
 
1112
1163
    def test_registration(self):
1113
1164
        t = get_transport('bzr+ssh://example.com/path')
1114
 
        self.assertIsInstance(t, remote.SmartSSHTransport)
 
1165
        self.assertIsInstance(t, remote.RemoteSSHTransport)
1115
1166
        self.assertEqual('example.com', t._host)
1116
1167
 
1117
1168
 
1122
1173
        input = StringIO("ok\n3\nbardone\n")
1123
1174
        output = StringIO()
1124
1175
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
1125
 
        transport = remote.SmartTransport(
 
1176
        transport = remote.RemoteTransport(
1126
1177
            'bzr://localhost/', medium=client_medium)
1127
1178
 
1128
1179
        # We want to make sure the client is used when the first remote
1142
1193
    def test__translate_error_readonly(self):
1143
1194
        """Sending a ReadOnlyError to _translate_error raises TransportNotPossible."""
1144
1195
        client_medium = medium.SmartClientMedium()
1145
 
        transport = remote.SmartTransport(
 
1196
        transport = remote.RemoteTransport(
1146
1197
            'bzr://localhost/', medium=client_medium)
1147
1198
        self.assertRaises(errors.TransportNotPossible,
1148
1199
            transport._translate_error, ("ReadOnlyError", ))
1157
1208
 
1158
1209
 
1159
1210
class TestSmartProtocol(tests.TestCase):
1160
 
    """Tests for the smart protocol.
 
1211
    """Base class for smart protocol tests.
1161
1212
 
1162
1213
    Each test case gets a smart_server and smart_client created during setUp().
1163
1214
 
1167
1218
    serialised client request. Output done by the client or server for these
1168
1219
    calls will be captured to self.to_server and self.to_client. Each element
1169
1220
    in the list is a write call from the client or server respectively.
 
1221
 
 
1222
    Subclasses can override client_protocol_class and server_protocol_class.
1170
1223
    """
1171
1224
 
 
1225
    client_protocol_class = None
 
1226
    server_protocol_class = None
 
1227
 
1172
1228
    def setUp(self):
1173
1229
        super(TestSmartProtocol, self).setUp()
1174
1230
        # XXX: self.server_to_client doesn't seem to be used.  If so,
1178
1234
        self.to_client = StringIO()
1179
1235
        self.client_medium = medium.SmartSimplePipesClientMedium(self.to_client,
1180
1236
            self.to_server)
1181
 
        self.client_protocol = protocol.SmartClientRequestProtocolOne(
1182
 
            self.client_medium)
 
1237
        self.client_protocol = self.client_protocol_class(self.client_medium)
1183
1238
        self.smart_server = InstrumentedServerProtocol(self.server_to_client)
1184
1239
        self.smart_server_request = request.SmartServerRequestHandler(
1185
1240
            None, request.request_handlers)
1205
1260
 
1206
1261
    def build_protocol_waiting_for_body(self):
1207
1262
        out_stream = StringIO()
1208
 
        smart_protocol = protocol.SmartServerRequestProtocolOne(None,
1209
 
                out_stream.write)
 
1263
        smart_protocol = self.server_protocol_class(None, out_stream.write)
1210
1264
        smart_protocol.has_dispatched = True
1211
1265
        smart_protocol.request = self.smart_server_request
1212
1266
        class FakeCommand(object):
1213
1267
            def do_body(cmd, body_bytes):
1214
1268
                self.end_received = True
1215
1269
                self.assertEqual('abcdefg', body_bytes)
1216
 
                return request.SmartServerResponse(('ok', ))
 
1270
                return request.SuccessfulSmartServerResponse(('ok', ))
1217
1271
        smart_protocol.request._command = FakeCommand()
1218
1272
        # Call accept_bytes to make sure that internal state like _body_decoder
1219
1273
        # is initialised.  This test should probably be given a clearer
1222
1276
        smart_protocol.accept_bytes('')
1223
1277
        return smart_protocol
1224
1278
 
 
1279
    def assertServerToClientEncoding(self, expected_bytes, expected_tuple,
 
1280
            input_tuples):
 
1281
        """Assert that each input_tuple serialises as expected_bytes, and the
 
1282
        bytes deserialise as expected_tuple.
 
1283
        """
 
1284
        # check the encoding of the server for all input_tuples matches
 
1285
        # expected bytes
 
1286
        for input_tuple in input_tuples:
 
1287
            server_output = StringIO()
 
1288
            server_protocol = self.server_protocol_class(
 
1289
                None, server_output.write)
 
1290
            server_protocol._send_response(
 
1291
                _mod_request.SuccessfulSmartServerResponse(input_tuple))
 
1292
            self.assertEqual(expected_bytes, server_output.getvalue())
 
1293
        # check the decoding of the client smart_protocol from expected_bytes:
 
1294
        input = StringIO(expected_bytes)
 
1295
        output = StringIO()
 
1296
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
1297
        request = client_medium.get_request()
 
1298
        smart_protocol = self.client_protocol_class(request)
 
1299
        smart_protocol.call('foo')
 
1300
        self.assertEqual(expected_tuple, smart_protocol.read_response_tuple())
 
1301
 
 
1302
 
 
1303
class TestSmartProtocolOne(TestSmartProtocol):
 
1304
    """Tests for the smart protocol version one."""
 
1305
 
 
1306
    client_protocol_class = protocol.SmartClientRequestProtocolOne
 
1307
    server_protocol_class = protocol.SmartServerRequestProtocolOne
 
1308
 
1225
1309
    def test_construct_version_one_server_protocol(self):
1226
1310
        smart_protocol = protocol.SmartServerRequestProtocolOne(None, None)
1227
1311
        self.assertEqual('', smart_protocol.excess_buffer)
1291
1375
        smart_protocol = protocol.SmartServerRequestProtocolOne(
1292
1376
            None, out_stream.write)
1293
1377
        smart_protocol.accept_bytes('hello\nhello\n')
1294
 
        self.assertEqual("ok\x011\n", out_stream.getvalue())
 
1378
        self.assertEqual("ok\x012\n", out_stream.getvalue())
1295
1379
        self.assertEqual("hello\n", smart_protocol.excess_buffer)
1296
1380
        self.assertEqual("", smart_protocol.in_buffer)
1297
1381
 
1310
1394
        smart_protocol = protocol.SmartServerRequestProtocolOne(
1311
1395
            None, out_stream.write)
1312
1396
        smart_protocol.accept_bytes('hello\n')
1313
 
        self.assertEqual("ok\x011\n", out_stream.getvalue())
 
1397
        self.assertEqual("ok\x012\n", out_stream.getvalue())
1314
1398
        smart_protocol.accept_bytes('hel')
1315
1399
        self.assertEqual("hel", smart_protocol.excess_buffer)
1316
1400
        smart_protocol.accept_bytes('lo\n')
1321
1405
        smart_protocol = protocol.SmartServerRequestProtocolOne(
1322
1406
            None, lambda x: None)
1323
1407
        self.assertEqual(1, smart_protocol.next_read_size())
1324
 
        smart_protocol._send_response(('x',))
 
1408
        smart_protocol._send_response(
 
1409
            request.SuccessfulSmartServerResponse(('x',)))
1325
1410
        self.assertEqual(0, smart_protocol.next_read_size())
1326
1411
 
 
1412
    def test__send_response_errors_with_base_response(self):
 
1413
        """Ensure that only the Successful/Failed subclasses are used."""
 
1414
        smart_protocol = protocol.SmartServerRequestProtocolOne(
 
1415
            None, lambda x: None)
 
1416
        self.assertRaises(AttributeError, smart_protocol._send_response,
 
1417
            request.SmartServerResponse(('x',)))
 
1418
 
1327
1419
    def test_query_version(self):
1328
1420
        """query_version on a SmartClientProtocolOne should return a number.
1329
1421
        
1334
1426
        # accept_bytes(tuple_based_encoding_of_hello) and reads and parses the
1335
1427
        # response of tuple-encoded (ok, 1).  Also, seperately we should test
1336
1428
        # the error if the response is a non-understood version.
1337
 
        input = StringIO('ok\x011\n')
1338
 
        output = StringIO()
1339
 
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
1340
 
        request = client_medium.get_request()
1341
 
        smart_protocol = protocol.SmartClientRequestProtocolOne(request)
1342
 
        self.assertEqual(1, smart_protocol.query_version())
1343
 
 
1344
 
    def assertServerToClientEncoding(self, expected_bytes, expected_tuple,
1345
 
            input_tuples):
1346
 
        """Assert that each input_tuple serialises as expected_bytes, and the
1347
 
        bytes deserialise as expected_tuple.
1348
 
        """
1349
 
        # check the encoding of the server for all input_tuples matches
1350
 
        # expected bytes
1351
 
        for input_tuple in input_tuples:
1352
 
            server_output = StringIO()
1353
 
            server_protocol = protocol.SmartServerRequestProtocolOne(
1354
 
                None, server_output.write)
1355
 
            server_protocol._send_response(input_tuple)
1356
 
            self.assertEqual(expected_bytes, server_output.getvalue())
1357
 
        # check the decoding of the client smart_protocol from expected_bytes:
1358
 
        input = StringIO(expected_bytes)
1359
 
        output = StringIO()
1360
 
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
1361
 
        request = client_medium.get_request()
1362
 
        smart_protocol = protocol.SmartClientRequestProtocolOne(request)
1363
 
        smart_protocol.call('foo')
1364
 
        self.assertEqual(expected_tuple, smart_protocol.read_response_tuple())
 
1429
        input = StringIO('ok\x012\n')
 
1430
        output = StringIO()
 
1431
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
1432
        request = client_medium.get_request()
 
1433
        smart_protocol = protocol.SmartClientRequestProtocolOne(request)
 
1434
        self.assertEqual(2, smart_protocol.query_version())
1365
1435
 
1366
1436
    def test_client_call_empty_response(self):
1367
1437
        # protocol.call() can get back an empty tuple as a response. This occurs
1451
1521
            errors.ReadingCompleted, smart_protocol.read_body_bytes)
1452
1522
 
1453
1523
 
 
1524
class TestSmartProtocolTwo(TestSmartProtocol):
 
1525
    """Tests for the smart protocol version two.
 
1526
 
 
1527
    This test case is mostly the same as TestSmartProtocolOne.
 
1528
    """
 
1529
 
 
1530
    client_protocol_class = protocol.SmartClientRequestProtocolTwo
 
1531
    server_protocol_class = protocol.SmartServerRequestProtocolTwo
 
1532
 
 
1533
    def test_construct_version_two_server_protocol(self):
 
1534
        smart_protocol = protocol.SmartServerRequestProtocolTwo(None, None)
 
1535
        self.assertEqual('', smart_protocol.excess_buffer)
 
1536
        self.assertEqual('', smart_protocol.in_buffer)
 
1537
        self.assertFalse(smart_protocol.has_dispatched)
 
1538
        self.assertEqual(1, smart_protocol.next_read_size())
 
1539
 
 
1540
    def test_construct_version_two_client_protocol(self):
 
1541
        # we can construct a client protocol from a client medium request
 
1542
        output = StringIO()
 
1543
        client_medium = medium.SmartSimplePipesClientMedium(None, output)
 
1544
        request = client_medium.get_request()
 
1545
        client_protocol = protocol.SmartClientRequestProtocolTwo(request)
 
1546
 
 
1547
    def test_server_offset_serialisation(self):
 
1548
        """The Smart protocol serialises offsets as a comma and \n string.
 
1549
 
 
1550
        We check a number of boundary cases are as expected: empty, one offset,
 
1551
        one with the order of reads not increasing (an out of order read), and
 
1552
        one that should coalesce.
 
1553
        """
 
1554
        self.assertOffsetSerialisation([], '', self.client_protocol)
 
1555
        self.assertOffsetSerialisation([(1,2)], '1,2', self.client_protocol)
 
1556
        self.assertOffsetSerialisation([(10,40), (0,5)], '10,40\n0,5',
 
1557
            self.client_protocol)
 
1558
        self.assertOffsetSerialisation([(1,2), (3,4), (100, 200)],
 
1559
            '1,2\n3,4\n100,200', self.client_protocol)
 
1560
 
 
1561
    def test_accept_bytes_of_bad_request_to_protocol(self):
 
1562
        out_stream = StringIO()
 
1563
        smart_protocol = protocol.SmartServerRequestProtocolTwo(
 
1564
            None, out_stream.write)
 
1565
        smart_protocol.accept_bytes('abc')
 
1566
        self.assertEqual('abc', smart_protocol.in_buffer)
 
1567
        smart_protocol.accept_bytes('\n')
 
1568
        self.assertEqual(
 
1569
            protocol.RESPONSE_VERSION_TWO +
 
1570
            "failed\nerror\x01Generic bzr smart protocol error: bad request 'abc'\n",
 
1571
            out_stream.getvalue())
 
1572
        self.assertTrue(smart_protocol.has_dispatched)
 
1573
        self.assertEqual(0, smart_protocol.next_read_size())
 
1574
 
 
1575
    def test_accept_body_bytes_to_protocol(self):
 
1576
        protocol = self.build_protocol_waiting_for_body()
 
1577
        self.assertEqual(6, protocol.next_read_size())
 
1578
        protocol.accept_bytes('7\nabc')
 
1579
        self.assertEqual(9, protocol.next_read_size())
 
1580
        protocol.accept_bytes('defgd')
 
1581
        protocol.accept_bytes('one\n')
 
1582
        self.assertEqual(0, protocol.next_read_size())
 
1583
        self.assertTrue(self.end_received)
 
1584
 
 
1585
    def test_accept_request_and_body_all_at_once(self):
 
1586
        self._captureVar('BZR_NO_SMART_VFS', None)
 
1587
        mem_transport = memory.MemoryTransport()
 
1588
        mem_transport.put_bytes('foo', 'abcdefghij')
 
1589
        out_stream = StringIO()
 
1590
        smart_protocol = protocol.SmartServerRequestProtocolTwo(mem_transport,
 
1591
                out_stream.write)
 
1592
        smart_protocol.accept_bytes('readv\x01foo\n3\n3,3done\n')
 
1593
        self.assertEqual(0, smart_protocol.next_read_size())
 
1594
        self.assertEqual(protocol.RESPONSE_VERSION_TWO +
 
1595
                         'success\nreadv\n3\ndefdone\n',
 
1596
                         out_stream.getvalue())
 
1597
        self.assertEqual('', smart_protocol.excess_buffer)
 
1598
        self.assertEqual('', smart_protocol.in_buffer)
 
1599
 
 
1600
    def test_accept_excess_bytes_are_preserved(self):
 
1601
        out_stream = StringIO()
 
1602
        smart_protocol = protocol.SmartServerRequestProtocolTwo(
 
1603
            None, out_stream.write)
 
1604
        smart_protocol.accept_bytes('hello\nhello\n')
 
1605
        self.assertEqual(protocol.RESPONSE_VERSION_TWO + "success\nok\x012\n",
 
1606
                         out_stream.getvalue())
 
1607
        self.assertEqual("hello\n", smart_protocol.excess_buffer)
 
1608
        self.assertEqual("", smart_protocol.in_buffer)
 
1609
 
 
1610
    def test_accept_excess_bytes_after_body(self):
 
1611
        # The excess bytes look like the start of another request.
 
1612
        server_protocol = self.build_protocol_waiting_for_body()
 
1613
        server_protocol.accept_bytes(
 
1614
            '7\nabcdefgdone\n' + protocol.RESPONSE_VERSION_TWO)
 
1615
        self.assertTrue(self.end_received)
 
1616
        self.assertEqual(protocol.RESPONSE_VERSION_TWO,
 
1617
                         server_protocol.excess_buffer)
 
1618
        self.assertEqual("", server_protocol.in_buffer)
 
1619
        server_protocol.accept_bytes('Y')
 
1620
        self.assertEqual(protocol.RESPONSE_VERSION_TWO + "Y",
 
1621
                         server_protocol.excess_buffer)
 
1622
        self.assertEqual("", server_protocol.in_buffer)
 
1623
 
 
1624
    def test_accept_excess_bytes_after_dispatch(self):
 
1625
        out_stream = StringIO()
 
1626
        smart_protocol = protocol.SmartServerRequestProtocolTwo(
 
1627
            None, out_stream.write)
 
1628
        smart_protocol.accept_bytes('hello\n')
 
1629
        self.assertEqual(protocol.RESPONSE_VERSION_TWO + "success\nok\x012\n",
 
1630
                         out_stream.getvalue())
 
1631
        smart_protocol.accept_bytes(protocol.REQUEST_VERSION_TWO + 'hel')
 
1632
        self.assertEqual(protocol.REQUEST_VERSION_TWO + "hel",
 
1633
                         smart_protocol.excess_buffer)
 
1634
        smart_protocol.accept_bytes('lo\n')
 
1635
        self.assertEqual(protocol.REQUEST_VERSION_TWO + "hello\n",
 
1636
                         smart_protocol.excess_buffer)
 
1637
        self.assertEqual("", smart_protocol.in_buffer)
 
1638
 
 
1639
    def test__send_response_sets_finished_reading(self):
 
1640
        smart_protocol = protocol.SmartServerRequestProtocolTwo(
 
1641
            None, lambda x: None)
 
1642
        self.assertEqual(1, smart_protocol.next_read_size())
 
1643
        smart_protocol._send_response(
 
1644
            request.SuccessfulSmartServerResponse(('x',)))
 
1645
        self.assertEqual(0, smart_protocol.next_read_size())
 
1646
 
 
1647
    def test__send_response_errors_with_base_response(self):
 
1648
        """Ensure that only the Successful/Failed subclasses are used."""
 
1649
        smart_protocol = protocol.SmartServerRequestProtocolTwo(
 
1650
            None, lambda x: None)
 
1651
        self.assertRaises(AttributeError, smart_protocol._send_response,
 
1652
            request.SmartServerResponse(('x',)))
 
1653
 
 
1654
    def test__send_response_includes_failure_marker(self):
 
1655
        """FailedSmartServerResponse have 'failed\n' after the version."""
 
1656
        out_stream = StringIO()
 
1657
        smart_protocol = protocol.SmartServerRequestProtocolTwo(
 
1658
            None, out_stream.write)
 
1659
        smart_protocol._send_response(
 
1660
            request.FailedSmartServerResponse(('x',)))
 
1661
        self.assertEqual(protocol.RESPONSE_VERSION_TWO + 'failed\nx\n',
 
1662
                         out_stream.getvalue())
 
1663
 
 
1664
    def test__send_response_includes_success_marker(self):
 
1665
        """SuccessfulSmartServerResponse have 'success\n' after the version."""
 
1666
        out_stream = StringIO()
 
1667
        smart_protocol = protocol.SmartServerRequestProtocolTwo(
 
1668
            None, out_stream.write)
 
1669
        smart_protocol._send_response(
 
1670
            request.SuccessfulSmartServerResponse(('x',)))
 
1671
        self.assertEqual(protocol.RESPONSE_VERSION_TWO + 'success\nx\n',
 
1672
                         out_stream.getvalue())
 
1673
 
 
1674
    def test_query_version(self):
 
1675
        """query_version on a SmartClientProtocolTwo should return a number.
 
1676
        
 
1677
        The protocol provides the query_version because the domain level clients
 
1678
        may all need to be able to probe for capabilities.
 
1679
        """
 
1680
        # What we really want to test here is that SmartClientProtocolTwo calls
 
1681
        # accept_bytes(tuple_based_encoding_of_hello) and reads and parses the
 
1682
        # response of tuple-encoded (ok, 1).  Also, seperately we should test
 
1683
        # the error if the response is a non-understood version.
 
1684
        input = StringIO(protocol.RESPONSE_VERSION_TWO + 'success\nok\x012\n')
 
1685
        output = StringIO()
 
1686
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
1687
        request = client_medium.get_request()
 
1688
        smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
 
1689
        self.assertEqual(2, smart_protocol.query_version())
 
1690
 
 
1691
    def test_client_call_empty_response(self):
 
1692
        # protocol.call() can get back an empty tuple as a response. This occurs
 
1693
        # when the parsed line is an empty line, and results in a tuple with
 
1694
        # one element - an empty string.
 
1695
        self.assertServerToClientEncoding(
 
1696
            protocol.RESPONSE_VERSION_TWO + 'success\n\n', ('', ), [(), ('', )])
 
1697
 
 
1698
    def test_client_call_three_element_response(self):
 
1699
        # protocol.call() can get back tuples of other lengths. A three element
 
1700
        # tuple should be unpacked as three strings.
 
1701
        self.assertServerToClientEncoding(
 
1702
            protocol.RESPONSE_VERSION_TWO + 'success\na\x01b\x0134\n',
 
1703
            ('a', 'b', '34'),
 
1704
            [('a', 'b', '34')])
 
1705
 
 
1706
    def test_client_call_with_body_bytes_uploads(self):
 
1707
        # protocol.call_with_body_bytes should length-prefix the bytes onto the
 
1708
        # wire.
 
1709
        expected_bytes = protocol.REQUEST_VERSION_TWO + "foo\n7\nabcdefgdone\n"
 
1710
        input = StringIO("\n")
 
1711
        output = StringIO()
 
1712
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
1713
        request = client_medium.get_request()
 
1714
        smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
 
1715
        smart_protocol.call_with_body_bytes(('foo', ), "abcdefg")
 
1716
        self.assertEqual(expected_bytes, output.getvalue())
 
1717
 
 
1718
    def test_client_call_with_body_readv_array(self):
 
1719
        # protocol.call_with_upload should encode the readv array and then
 
1720
        # length-prefix the bytes onto the wire.
 
1721
        expected_bytes = protocol.REQUEST_VERSION_TWO+"foo\n7\n1,2\n5,6done\n"
 
1722
        input = StringIO("\n")
 
1723
        output = StringIO()
 
1724
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
1725
        request = client_medium.get_request()
 
1726
        smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
 
1727
        smart_protocol.call_with_body_readv_array(('foo', ), [(1,2),(5,6)])
 
1728
        self.assertEqual(expected_bytes, output.getvalue())
 
1729
 
 
1730
    def test_client_read_response_tuple_sets_response_status(self):
 
1731
        server_bytes = protocol.RESPONSE_VERSION_TWO + "success\nok\n"
 
1732
        input = StringIO(server_bytes)
 
1733
        output = StringIO()
 
1734
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
1735
        request = client_medium.get_request()
 
1736
        smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
 
1737
        smart_protocol.call('foo')
 
1738
        smart_protocol.read_response_tuple(False)
 
1739
        self.assertEqual(True, smart_protocol.response_status)
 
1740
 
 
1741
    def test_client_read_body_bytes_all(self):
 
1742
        # read_body_bytes should decode the body bytes from the wire into
 
1743
        # a response.
 
1744
        expected_bytes = "1234567"
 
1745
        server_bytes = (protocol.RESPONSE_VERSION_TWO +
 
1746
                        "success\nok\n7\n1234567done\n")
 
1747
        input = StringIO(server_bytes)
 
1748
        output = StringIO()
 
1749
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
1750
        request = client_medium.get_request()
 
1751
        smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
 
1752
        smart_protocol.call('foo')
 
1753
        smart_protocol.read_response_tuple(True)
 
1754
        self.assertEqual(expected_bytes, smart_protocol.read_body_bytes())
 
1755
 
 
1756
    def test_client_read_body_bytes_incremental(self):
 
1757
        # test reading a few bytes at a time from the body
 
1758
        # XXX: possibly we should test dribbling the bytes into the stringio
 
1759
        # to make the state machine work harder: however, as we use the
 
1760
        # LengthPrefixedBodyDecoder that is already well tested - we can skip
 
1761
        # that.
 
1762
        expected_bytes = "1234567"
 
1763
        server_bytes = (protocol.RESPONSE_VERSION_TWO +
 
1764
                        "success\nok\n7\n1234567done\n")
 
1765
        input = StringIO(server_bytes)
 
1766
        output = StringIO()
 
1767
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
1768
        request = client_medium.get_request()
 
1769
        smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
 
1770
        smart_protocol.call('foo')
 
1771
        smart_protocol.read_response_tuple(True)
 
1772
        self.assertEqual(expected_bytes[0:2], smart_protocol.read_body_bytes(2))
 
1773
        self.assertEqual(expected_bytes[2:4], smart_protocol.read_body_bytes(2))
 
1774
        self.assertEqual(expected_bytes[4:6], smart_protocol.read_body_bytes(2))
 
1775
        self.assertEqual(expected_bytes[6], smart_protocol.read_body_bytes())
 
1776
 
 
1777
    def test_client_cancel_read_body_does_not_eat_body_bytes(self):
 
1778
        # cancelling the expected body needs to finish the request, but not
 
1779
        # read any more bytes.
 
1780
        expected_bytes = "1234567"
 
1781
        server_bytes = (protocol.RESPONSE_VERSION_TWO +
 
1782
                        "success\nok\n7\n1234567done\n")
 
1783
        input = StringIO(server_bytes)
 
1784
        output = StringIO()
 
1785
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
1786
        request = client_medium.get_request()
 
1787
        smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
 
1788
        smart_protocol.call('foo')
 
1789
        smart_protocol.read_response_tuple(True)
 
1790
        smart_protocol.cancel_read_body()
 
1791
        self.assertEqual(len(protocol.RESPONSE_VERSION_TWO + 'success\nok\n'),
 
1792
                         input.tell())
 
1793
        self.assertRaises(
 
1794
            errors.ReadingCompleted, smart_protocol.read_body_bytes)
 
1795
 
 
1796
 
 
1797
class TestSmartClientUnicode(tests.TestCase):
    """Check that _SmartClient rejects unicode arguments cleanly.

    Remote method names, arguments and bodies passed to call_with_body_bytes
    must all be byte strings, so unicode values are an error.  The error must
    not leave the medium wedged, though: after a rejected call it should still
    be possible to issue further requests, instead of every later call
    failing with TooManyConcurrentRequests.
    """

    def assertCallDoesNotBreakMedium(self, method, args, body):
        """Make a bad call_with_body_bytes call and check the medium survives.

        The call must raise TypeError, nothing may be written to the
        medium's output stream, and the medium must be left with no
        current request.
        """
        server_stream = StringIO("\n")
        client_stream = StringIO()
        client_medium = medium.SmartSimplePipesClientMedium(
            server_stream, client_stream)
        smart_client = client._SmartClient(client_medium)
        self.assertRaises(TypeError, smart_client.call_with_body_bytes,
                          method, args, body)
        self.assertEqual("", client_stream.getvalue())
        self.assertEqual(None, client_medium._current_request)

    def test_call_with_body_bytes_unicode_method(self):
        self.assertCallDoesNotBreakMedium(u'method', ('args',), 'body')

    def test_call_with_body_bytes_unicode_args(self):
        # Both a lone unicode argument and one mixed in with byte strings.
        for bad_args in [(u'args',), ('arg1', u'arg2')]:
            self.assertCallDoesNotBreakMedium('method', bad_args, 'body')

    def test_call_with_body_bytes_unicode_body(self):
        self.assertCallDoesNotBreakMedium('method', ('args',), u'body')
 
1832
 
 
1833
 
1454
1834
class LengthPrefixedBodyDecoder(tests.TestCase):
1455
1835
 
1456
1836
    # XXX: TODO: make accept_reading_trailer invoke translate_response or 
1518
1898
        self.assertEqual('', decoder.unused_data)
1519
1899
 
1520
1900
 
 
1901
class TestSuccessfulSmartServerResponse(tests.TestCase):
    """Tests for request.SuccessfulSmartServerResponse."""

    def test_construct(self):
        # When no body is given, the body defaults to None.
        without_body = request.SuccessfulSmartServerResponse(('foo', 'bar'))
        self.assertEqual(('foo', 'bar'), without_body.args)
        self.assertEqual(None, without_body.body)
        # Given as the second argument, the body is stored as-is.
        with_body = request.SuccessfulSmartServerResponse(
            ('foo', 'bar'), 'bytes')
        self.assertEqual(('foo', 'bar'), with_body.args)
        self.assertEqual('bytes', with_body.body)

    def test_is_successful(self):
        """SuccessfulSmartServerResponse.is_successful() returns True."""
        # The args look like an error on purpose: here success comes from
        # the response class, not from the args.
        successful = request.SuccessfulSmartServerResponse(('error',))
        self.assertEqual(True, successful.is_successful())
 
1915
 
 
1916
 
 
1917
class TestFailedSmartServerResponse(tests.TestCase):
    """Tests for request.FailedSmartServerResponse."""

    def test_construct(self):
        # When no body is given, the body defaults to None.
        without_body = request.FailedSmartServerResponse(('foo', 'bar'))
        self.assertEqual(('foo', 'bar'), without_body.args)
        self.assertEqual(None, without_body.body)
        # Given as the second argument, the body is stored as-is.
        with_body = request.FailedSmartServerResponse(
            ('foo', 'bar'), 'bytes')
        self.assertEqual(('foo', 'bar'), with_body.args)
        self.assertEqual('bytes', with_body.body)

    def test_is_successful(self):
        """FailedSmartServerResponse.is_successful() returns False."""
        failed = request.FailedSmartServerResponse(('error',))
        self.assertEqual(False, failed.is_successful())
 
1931
 
 
1932
 
1521
1933
class FakeHTTPMedium(object):
1522
1934
    def __init__(self):
1523
1935
        self.written_request = None
1544
1956
        http_transport = self.get_readonly_transport()
1545
1957
        medium = http_transport.get_smart_medium()
1546
1958
        #remote_transport = RemoteTransport('fake_url', medium)
1547
 
        remote_transport = remote.SmartTransport('/', medium=medium)
 
1959
        remote_transport = remote.RemoteTransport('/', medium=medium)
1548
1960
        self.assertEqual(
1549
1961
            [(0, "c")], list(remote_transport.readv("data-file", [(0,1)])))
1550
1962
 
1573
1985
        self.addCleanup(http_server.tearDown)
1574
1986
 
1575
1987
        post_body = 'hello\n'
1576
 
        expected_reply_body = 'ok\x011\n'
 
1988
        expected_reply_body = 'ok\x012\n'
1577
1989
 
1578
1990
        http_transport = get_transport(http_server.get_url())
1579
1991
        medium = http_transport.get_smart_medium()
1594
2006
        self.transport_readonly_server = HTTPServerWithSmarts
1595
2007
 
1596
2008
        post_body = 'hello\n'
1597
 
        expected_reply_body = 'ok\x011\n'
 
2009
        expected_reply_body = 'ok\x012\n'
1598
2010
 
1599
2011
        smart_server_url = self.get_readonly_url('.bzr/smart')
1600
2012
        reply = urllib2.urlopen(smart_server_url, post_body).read()
1620
2032
        response = socket.writefile.getvalue()
1621
2033
        self.assertStartsWith(response, 'HTTP/1.0 200 ')
1622
2034
        # This includes the end of the HTTP headers, and all the body.
1623
 
        expected_end_of_response = '\r\n\r\nok\x011\n'
 
2035
        expected_end_of_response = '\r\n\r\nok\x012\n'
1624
2036
        self.assertEndsWith(response, expected_end_of_response)
1625
2037
 
1626
2038
 
1646
2058
            return self.writefile
1647
2059
 
1648
2060
 
 
2061
class RemoteHTTPTransportTestCase(tests.TestCase):
    """Tests for the relpaths a cloned RemoteHTTPTransport sends."""

    def test_remote_path_after_clone_child(self):
        # If a user enters "bzr+http://host/foo", we want to send all smart
        # requests for child URLs of that to the original URL.  i.e., we want to
        # POST to "bzr+http://host/foo/.bzr/smart" and never something like
        # "bzr+http://host/foo/.bzr/branch/.bzr/smart".  So, a cloned
        # RemoteHTTPTransport remembers the initial URL, and adjusts the relpaths
        # it sends in smart requests accordingly.
        base_transport = remote.RemoteHTTPTransport('bzr+http://host/path')
        new_transport = base_transport.clone('child_dir')
        self.assertEqual(base_transport._http_transport,
                         new_transport._http_transport)
        self.assertEqual('child_dir/foo', new_transport._remote_path('foo'))

    def test_remote_path_after_clone_parent(self):
        # However, accessing a parent directory should go direct to the parent's
        # URL.  We don't send relpaths like "../foo" in smart requests.
        base_transport = remote.RemoteHTTPTransport('bzr+http://host/path1/path2')
        new_transport = base_transport.clone('..')
        self.assertEqual('foo', new_transport._remote_path('foo'))
        new_transport = base_transport.clone('../')
        self.assertEqual('foo', new_transport._remote_path('foo'))
        new_transport = base_transport.clone('../abc')
        self.assertEqual('foo', new_transport._remote_path('foo'))
        # "abc/../.." should be equivalent to ".."
        new_transport = base_transport.clone('abc/../..')
        self.assertEqual('foo', new_transport._remote_path('foo'))
 
2089
 
 
2090
        
1649
2091
# TODO: Client feature that does get_bundle and then installs that into a
1650
2092
# branch; this can be used in place of the regular pull/fetch operation when
1651
2093
# coming from a smart server.