~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/tests/test_smart_transport.py

  • Committer: Canonical.com Patch Queue Manager
  • Date: 2007-04-26 05:42:38 UTC
  • mfrom: (2432.2.9 hpss-protocol2)
  • Revision ID: pqm@pqm.ubuntu.com-20070426054238-v6k5ge3z766vaafk
(Andrew Bennetts, Robert Collins) Smart server protocol versioning.

Show diffs side-by-side

added added

removed removed

Lines of Context:
598
598
        smart_protocol = protocol.SmartServerRequestProtocolOne(transport,
599
599
                from_server.write)
600
600
        server._serve_one_request(smart_protocol)
601
 
        self.assertEqual('ok\0011\n',
 
601
        self.assertEqual('ok\0012\n',
602
602
                         from_server.getvalue())
603
603
 
604
604
    def test_response_to_canned_get(self):
768
768
            KeyboardInterrupt, server._serve_one_request, fake_protocol)
769
769
        server_sock.close()
770
770
        self.assertEqual('', client_sock.recv(1))
 
771
 
 
772
    def build_protocol_pipe_like(self, bytes):
 
773
        to_server = StringIO(bytes)
 
774
        from_server = StringIO()
 
775
        server = medium.SmartServerPipeStreamMedium(
 
776
            to_server, from_server, None)
 
777
        return server._build_protocol()
 
778
 
 
779
    def build_protocol_socket(self, bytes):
 
780
        server_sock, client_sock = self.portable_socket_pair()
 
781
        server = medium.SmartServerSocketStreamMedium(
 
782
            server_sock, None)
 
783
        client_sock.sendall(bytes)
 
784
        client_sock.close()
 
785
        return server._build_protocol()
 
786
 
 
787
    def assertProtocolOne(self, server_protocol):
 
788
        # Use assertIs because assertIsInstance will wrongly pass
 
789
        # SmartServerRequestProtocolTwo (because it subclasses
 
790
        # SmartServerRequestProtocolOne).
 
791
        self.assertIs(
 
792
            type(server_protocol), protocol.SmartServerRequestProtocolOne)
 
793
 
 
794
    def assertProtocolTwo(self, server_protocol):
 
795
        self.assertIsInstance(
 
796
            server_protocol, protocol.SmartServerRequestProtocolTwo)
 
797
 
 
798
    def test_pipe_like_build_protocol_empty_bytes(self):
 
799
        # Any empty request (i.e. no bytes) is detected as protocol version one.
 
800
        server_protocol = self.build_protocol_pipe_like('')
 
801
        self.assertProtocolOne(server_protocol)
 
802
        
 
803
    def test_socket_like_build_protocol_empty_bytes(self):
 
804
        # Any empty request (i.e. no bytes) is detected as protocol version one.
 
805
        server_protocol = self.build_protocol_socket('')
 
806
        self.assertProtocolOne(server_protocol)
 
807
 
 
808
    def test_pipe_like_build_protocol_non_two(self):
 
809
        # A request that doesn't start with "bzr request 2\n" is version one.
 
810
        server_protocol = self.build_protocol_pipe_like('abc\n')
 
811
        self.assertProtocolOne(server_protocol)
 
812
 
 
813
    def test_socket_build_protocol_non_two(self):
 
814
        # A request that doesn't start with "bzr request 2\n" is version one.
 
815
        server_protocol = self.build_protocol_socket('abc\n')
 
816
        self.assertProtocolOne(server_protocol)
 
817
 
 
818
    def test_pipe_like_build_protocol_two(self):
 
819
        # A request that starts with "bzr request 2\n" is version two.
 
820
        server_protocol = self.build_protocol_pipe_like('bzr request 2\n')
 
821
        self.assertProtocolTwo(server_protocol)
 
822
 
 
823
    def test_socket_build_protocol_two(self):
 
824
        # A request that starts with "bzr request 2\n" is version two.
 
825
        server_protocol = self.build_protocol_socket('bzr request 2\n')
 
826
        self.assertProtocolTwo(server_protocol)
771
827
        
772
828
 
773
829
class TestSmartTCPServer(tests.TestCase):
981
1037
class SmartServerCommandTests(tests.TestCaseWithTransport):
982
1038
    """Tests that call directly into the command objects, bypassing the network
983
1039
    and the request dispatching.
 
1040
 
 
1041
    Note: these tests are rudimentary versions of the command object tests in
 
1042
    test_remote.py.
984
1043
    """
985
1044
        
986
1045
    def test_hello(self):
987
1046
        cmd = request.HelloRequest(None)
988
1047
        response = cmd.execute()
989
 
        self.assertEqual(('ok', '1'), response.args)
 
1048
        self.assertEqual(('ok', '2'), response.args)
990
1049
        self.assertEqual(None, response.body)
991
1050
        
992
1051
    def test_get_bundle(self):
1022
1081
    def test_hello(self):
1023
1082
        handler = self.build_handler(None)
1024
1083
        handler.dispatch_command('hello', ())
1025
 
        self.assertEqual(('ok', '1'), handler.response.args)
 
1084
        self.assertEqual(('ok', '2'), handler.response.args)
1026
1085
        self.assertEqual(None, handler.response.body)
1027
1086
        
1028
1087
    def test_disable_vfs_handler_classes_via_environment(self):
1148
1207
 
1149
1208
 
1150
1209
class TestSmartProtocol(tests.TestCase):
1151
 
    """Tests for the smart protocol.
 
1210
    """Base class for smart protocol tests.
1152
1211
 
1153
1212
    Each test case gets a smart_server and smart_client created during setUp().
1154
1213
 
1158
1217
    serialised client request. Output done by the client or server for these
1159
1218
    calls will be captured to self.to_server and self.to_client. Each element
1160
1219
    in the list is a write call from the client or server respectively.
 
1220
 
 
1221
    Subclasses can override client_protocol_class and server_protocol_class.
1161
1222
    """
1162
1223
 
 
1224
    client_protocol_class = None
 
1225
    server_protocol_class = None
 
1226
 
1163
1227
    def setUp(self):
1164
1228
        super(TestSmartProtocol, self).setUp()
1165
1229
        # XXX: self.server_to_client doesn't seem to be used.  If so,
1169
1233
        self.to_client = StringIO()
1170
1234
        self.client_medium = medium.SmartSimplePipesClientMedium(self.to_client,
1171
1235
            self.to_server)
1172
 
        self.client_protocol = protocol.SmartClientRequestProtocolOne(
1173
 
            self.client_medium)
 
1236
        self.client_protocol = self.client_protocol_class(self.client_medium)
1174
1237
        self.smart_server = InstrumentedServerProtocol(self.server_to_client)
1175
1238
        self.smart_server_request = request.SmartServerRequestHandler(
1176
1239
            None, request.request_handlers)
1196
1259
 
1197
1260
    def build_protocol_waiting_for_body(self):
1198
1261
        out_stream = StringIO()
1199
 
        smart_protocol = protocol.SmartServerRequestProtocolOne(None,
1200
 
                out_stream.write)
 
1262
        smart_protocol = self.server_protocol_class(None, out_stream.write)
1201
1263
        smart_protocol.has_dispatched = True
1202
1264
        smart_protocol.request = self.smart_server_request
1203
1265
        class FakeCommand(object):
1213
1275
        smart_protocol.accept_bytes('')
1214
1276
        return smart_protocol
1215
1277
 
 
1278
    def assertServerToClientEncoding(self, expected_bytes, expected_tuple,
 
1279
            input_tuples):
 
1280
        """Assert that each input_tuple serialises as expected_bytes, and the
 
1281
        bytes deserialise as expected_tuple.
 
1282
        """
 
1283
        # check the encoding of the server for all input_tuples matches
 
1284
        # expected bytes
 
1285
        for input_tuple in input_tuples:
 
1286
            server_output = StringIO()
 
1287
            server_protocol = self.server_protocol_class(
 
1288
                None, server_output.write)
 
1289
            server_protocol._send_response(input_tuple)
 
1290
            self.assertEqual(expected_bytes, server_output.getvalue())
 
1291
        # check the decoding of the client smart_protocol from expected_bytes:
 
1292
        input = StringIO(expected_bytes)
 
1293
        output = StringIO()
 
1294
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
1295
        request = client_medium.get_request()
 
1296
        smart_protocol = self.client_protocol_class(request)
 
1297
        smart_protocol.call('foo')
 
1298
        self.assertEqual(expected_tuple, smart_protocol.read_response_tuple())
 
1299
 
 
1300
 
 
1301
class TestSmartProtocolOne(TestSmartProtocol):
 
1302
    """Tests for the smart protocol version one."""
 
1303
 
 
1304
    client_protocol_class = protocol.SmartClientRequestProtocolOne
 
1305
    server_protocol_class = protocol.SmartServerRequestProtocolOne
 
1306
 
1216
1307
    def test_construct_version_one_server_protocol(self):
1217
1308
        smart_protocol = protocol.SmartServerRequestProtocolOne(None, None)
1218
1309
        self.assertEqual('', smart_protocol.excess_buffer)
1282
1373
        smart_protocol = protocol.SmartServerRequestProtocolOne(
1283
1374
            None, out_stream.write)
1284
1375
        smart_protocol.accept_bytes('hello\nhello\n')
1285
 
        self.assertEqual("ok\x011\n", out_stream.getvalue())
 
1376
        self.assertEqual("ok\x012\n", out_stream.getvalue())
1286
1377
        self.assertEqual("hello\n", smart_protocol.excess_buffer)
1287
1378
        self.assertEqual("", smart_protocol.in_buffer)
1288
1379
 
1301
1392
        smart_protocol = protocol.SmartServerRequestProtocolOne(
1302
1393
            None, out_stream.write)
1303
1394
        smart_protocol.accept_bytes('hello\n')
1304
 
        self.assertEqual("ok\x011\n", out_stream.getvalue())
 
1395
        self.assertEqual("ok\x012\n", out_stream.getvalue())
1305
1396
        smart_protocol.accept_bytes('hel')
1306
1397
        self.assertEqual("hel", smart_protocol.excess_buffer)
1307
1398
        smart_protocol.accept_bytes('lo\n')
1325
1416
        # accept_bytes(tuple_based_encoding_of_hello) and reads and parses the
1326
1417
        # response of tuple-encoded (ok, 1).  Also, seperately we should test
1327
1418
        # the error if the response is a non-understood version.
1328
 
        input = StringIO('ok\x011\n')
1329
 
        output = StringIO()
1330
 
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
1331
 
        request = client_medium.get_request()
1332
 
        smart_protocol = protocol.SmartClientRequestProtocolOne(request)
1333
 
        self.assertEqual(1, smart_protocol.query_version())
1334
 
 
1335
 
    def assertServerToClientEncoding(self, expected_bytes, expected_tuple,
1336
 
            input_tuples):
1337
 
        """Assert that each input_tuple serialises as expected_bytes, and the
1338
 
        bytes deserialise as expected_tuple.
1339
 
        """
1340
 
        # check the encoding of the server for all input_tuples matches
1341
 
        # expected bytes
1342
 
        for input_tuple in input_tuples:
1343
 
            server_output = StringIO()
1344
 
            server_protocol = protocol.SmartServerRequestProtocolOne(
1345
 
                None, server_output.write)
1346
 
            server_protocol._send_response(input_tuple)
1347
 
            self.assertEqual(expected_bytes, server_output.getvalue())
1348
 
        # check the decoding of the client smart_protocol from expected_bytes:
1349
 
        input = StringIO(expected_bytes)
1350
 
        output = StringIO()
1351
 
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
1352
 
        request = client_medium.get_request()
1353
 
        smart_protocol = protocol.SmartClientRequestProtocolOne(request)
1354
 
        smart_protocol.call('foo')
1355
 
        self.assertEqual(expected_tuple, smart_protocol.read_response_tuple())
 
1419
        input = StringIO('ok\x012\n')
 
1420
        output = StringIO()
 
1421
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
1422
        request = client_medium.get_request()
 
1423
        smart_protocol = protocol.SmartClientRequestProtocolOne(request)
 
1424
        self.assertEqual(2, smart_protocol.query_version())
1356
1425
 
1357
1426
    def test_client_call_empty_response(self):
1358
1427
        # protocol.call() can get back an empty tuple as a response. This occurs
1442
1511
            errors.ReadingCompleted, smart_protocol.read_body_bytes)
1443
1512
 
1444
1513
 
 
1514
class TestSmartProtocolTwo(TestSmartProtocol):
 
1515
    """Tests for the smart protocol version two.
 
1516
 
 
1517
    This test case is mostly the same as TestSmartProtocolOne.
 
1518
    """
 
1519
 
 
1520
    client_protocol_class = protocol.SmartClientRequestProtocolTwo
 
1521
    server_protocol_class = protocol.SmartServerRequestProtocolTwo
 
1522
 
 
1523
    def test_construct_version_two_server_protocol(self):
 
1524
        smart_protocol = protocol.SmartServerRequestProtocolTwo(None, None)
 
1525
        self.assertEqual('', smart_protocol.excess_buffer)
 
1526
        self.assertEqual('', smart_protocol.in_buffer)
 
1527
        self.assertFalse(smart_protocol.has_dispatched)
 
1528
        self.assertEqual(1, smart_protocol.next_read_size())
 
1529
 
 
1530
    def test_construct_version_two_client_protocol(self):
 
1531
        # we can construct a client protocol from a client medium request
 
1532
        output = StringIO()
 
1533
        client_medium = medium.SmartSimplePipesClientMedium(None, output)
 
1534
        request = client_medium.get_request()
 
1535
        client_protocol = protocol.SmartClientRequestProtocolTwo(request)
 
1536
 
 
1537
    def test_server_offset_serialisation(self):
 
1538
        """The Smart protocol serialises offsets as a comma and \n string.
 
1539
 
 
1540
        We check a number of boundary cases are as expected: empty, one offset,
 
1541
        one with the order of reads not increasing (an out of order read), and
 
1542
        one that should coalesce.
 
1543
        """
 
1544
        self.assertOffsetSerialisation([], '', self.client_protocol)
 
1545
        self.assertOffsetSerialisation([(1,2)], '1,2', self.client_protocol)
 
1546
        self.assertOffsetSerialisation([(10,40), (0,5)], '10,40\n0,5',
 
1547
            self.client_protocol)
 
1548
        self.assertOffsetSerialisation([(1,2), (3,4), (100, 200)],
 
1549
            '1,2\n3,4\n100,200', self.client_protocol)
 
1550
 
 
1551
    def test_accept_bytes_of_bad_request_to_protocol(self):
 
1552
        out_stream = StringIO()
 
1553
        smart_protocol = protocol.SmartServerRequestProtocolTwo(
 
1554
            None, out_stream.write)
 
1555
        smart_protocol.accept_bytes('abc')
 
1556
        self.assertEqual('abc', smart_protocol.in_buffer)
 
1557
        smart_protocol.accept_bytes('\n')
 
1558
        self.assertEqual(
 
1559
            protocol.RESPONSE_VERSION_TWO +
 
1560
            "error\x01Generic bzr smart protocol error: bad request 'abc'\n",
 
1561
            out_stream.getvalue())
 
1562
        self.assertTrue(smart_protocol.has_dispatched)
 
1563
        self.assertEqual(0, smart_protocol.next_read_size())
 
1564
 
 
1565
    def test_accept_body_bytes_to_protocol(self):
 
1566
        protocol = self.build_protocol_waiting_for_body()
 
1567
        self.assertEqual(6, protocol.next_read_size())
 
1568
        protocol.accept_bytes('7\nabc')
 
1569
        self.assertEqual(9, protocol.next_read_size())
 
1570
        protocol.accept_bytes('defgd')
 
1571
        protocol.accept_bytes('one\n')
 
1572
        self.assertEqual(0, protocol.next_read_size())
 
1573
        self.assertTrue(self.end_received)
 
1574
 
 
1575
    def test_accept_request_and_body_all_at_once(self):
 
1576
        self._captureVar('BZR_NO_SMART_VFS', None)
 
1577
        mem_transport = memory.MemoryTransport()
 
1578
        mem_transport.put_bytes('foo', 'abcdefghij')
 
1579
        out_stream = StringIO()
 
1580
        smart_protocol = protocol.SmartServerRequestProtocolTwo(mem_transport,
 
1581
                out_stream.write)
 
1582
        smart_protocol.accept_bytes('readv\x01foo\n3\n3,3done\n')
 
1583
        self.assertEqual(0, smart_protocol.next_read_size())
 
1584
        self.assertEqual(protocol.RESPONSE_VERSION_TWO + 'readv\n3\ndefdone\n',
 
1585
                         out_stream.getvalue())
 
1586
        self.assertEqual('', smart_protocol.excess_buffer)
 
1587
        self.assertEqual('', smart_protocol.in_buffer)
 
1588
 
 
1589
    def test_accept_excess_bytes_are_preserved(self):
 
1590
        out_stream = StringIO()
 
1591
        smart_protocol = protocol.SmartServerRequestProtocolTwo(
 
1592
            None, out_stream.write)
 
1593
        smart_protocol.accept_bytes('hello\nhello\n')
 
1594
        self.assertEqual(protocol.RESPONSE_VERSION_TWO + "ok\x012\n",
 
1595
                         out_stream.getvalue())
 
1596
        self.assertEqual("hello\n", smart_protocol.excess_buffer)
 
1597
        self.assertEqual("", smart_protocol.in_buffer)
 
1598
 
 
1599
    def test_accept_excess_bytes_after_body(self):
 
1600
        # The excess bytes look like the start of another request.
 
1601
        server_protocol = self.build_protocol_waiting_for_body()
 
1602
        server_protocol.accept_bytes(
 
1603
            '7\nabcdefgdone\n' + protocol.RESPONSE_VERSION_TWO)
 
1604
        self.assertTrue(self.end_received)
 
1605
        self.assertEqual(protocol.RESPONSE_VERSION_TWO,
 
1606
                         server_protocol.excess_buffer)
 
1607
        self.assertEqual("", server_protocol.in_buffer)
 
1608
        server_protocol.accept_bytes('Y')
 
1609
        self.assertEqual(protocol.RESPONSE_VERSION_TWO + "Y",
 
1610
                         server_protocol.excess_buffer)
 
1611
        self.assertEqual("", server_protocol.in_buffer)
 
1612
 
 
1613
    def test_accept_excess_bytes_after_dispatch(self):
 
1614
        out_stream = StringIO()
 
1615
        smart_protocol = protocol.SmartServerRequestProtocolTwo(
 
1616
            None, out_stream.write)
 
1617
        smart_protocol.accept_bytes('hello\n')
 
1618
        self.assertEqual(protocol.RESPONSE_VERSION_TWO + "ok\x012\n",
 
1619
                         out_stream.getvalue())
 
1620
        smart_protocol.accept_bytes(protocol.REQUEST_VERSION_TWO + 'hel')
 
1621
        self.assertEqual(protocol.REQUEST_VERSION_TWO + "hel",
 
1622
                         smart_protocol.excess_buffer)
 
1623
        smart_protocol.accept_bytes('lo\n')
 
1624
        self.assertEqual(protocol.REQUEST_VERSION_TWO + "hello\n",
 
1625
                         smart_protocol.excess_buffer)
 
1626
        self.assertEqual("", smart_protocol.in_buffer)
 
1627
 
 
1628
    def test__send_response_sets_finished_reading(self):
 
1629
        smart_protocol = protocol.SmartServerRequestProtocolTwo(
 
1630
            None, lambda x: None)
 
1631
        self.assertEqual(1, smart_protocol.next_read_size())
 
1632
        smart_protocol._send_response(('x',))
 
1633
        self.assertEqual(0, smart_protocol.next_read_size())
 
1634
 
 
1635
    def test_query_version(self):
 
1636
        """query_version on a SmartClientProtocolTwo should return a number.
 
1637
        
 
1638
        The protocol provides the query_version because the domain level clients
 
1639
        may all need to be able to probe for capabilities.
 
1640
        """
 
1641
        # What we really want to test here is that SmartClientProtocolTwo calls
 
1642
        # accept_bytes(tuple_based_encoding_of_hello) and reads and parses the
 
1643
        # response of tuple-encoded (ok, 1).  Also, seperately we should test
 
1644
        # the error if the response is a non-understood version.
 
1645
        input = StringIO(protocol.RESPONSE_VERSION_TWO + 'ok\x012\n')
 
1646
        output = StringIO()
 
1647
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
1648
        request = client_medium.get_request()
 
1649
        smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
 
1650
        self.assertEqual(2, smart_protocol.query_version())
 
1651
 
 
1652
    def test_client_call_empty_response(self):
 
1653
        # protocol.call() can get back an empty tuple as a response. This occurs
 
1654
        # when the parsed line is an empty line, and results in a tuple with
 
1655
        # one element - an empty string.
 
1656
        self.assertServerToClientEncoding(
 
1657
            protocol.RESPONSE_VERSION_TWO + '\n', ('', ), [(), ('', )])
 
1658
 
 
1659
    def test_client_call_three_element_response(self):
 
1660
        # protocol.call() can get back tuples of other lengths. A three element
 
1661
        # tuple should be unpacked as three strings.
 
1662
        self.assertServerToClientEncoding(
 
1663
            protocol.RESPONSE_VERSION_TWO + 'a\x01b\x0134\n', ('a', 'b', '34'),
 
1664
            [('a', 'b', '34')])
 
1665
 
 
1666
    def test_client_call_with_body_bytes_uploads(self):
 
1667
        # protocol.call_with_body_bytes should length-prefix the bytes onto the
 
1668
        # wire.
 
1669
        expected_bytes = protocol.REQUEST_VERSION_TWO + "foo\n7\nabcdefgdone\n"
 
1670
        input = StringIO("\n")
 
1671
        output = StringIO()
 
1672
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
1673
        request = client_medium.get_request()
 
1674
        smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
 
1675
        smart_protocol.call_with_body_bytes(('foo', ), "abcdefg")
 
1676
        self.assertEqual(expected_bytes, output.getvalue())
 
1677
 
 
1678
    def test_client_call_with_body_readv_array(self):
 
1679
        # protocol.call_with_upload should encode the readv array and then
 
1680
        # length-prefix the bytes onto the wire.
 
1681
        expected_bytes = protocol.REQUEST_VERSION_TWO+"foo\n7\n1,2\n5,6done\n"
 
1682
        input = StringIO("\n")
 
1683
        output = StringIO()
 
1684
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
1685
        request = client_medium.get_request()
 
1686
        smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
 
1687
        smart_protocol.call_with_body_readv_array(('foo', ), [(1,2),(5,6)])
 
1688
        self.assertEqual(expected_bytes, output.getvalue())
 
1689
 
 
1690
    def test_client_read_body_bytes_all(self):
 
1691
        # read_body_bytes should decode the body bytes from the wire into
 
1692
        # a response.
 
1693
        expected_bytes = "1234567"
 
1694
        server_bytes = protocol.RESPONSE_VERSION_TWO + "ok\n7\n1234567done\n"
 
1695
        input = StringIO(server_bytes)
 
1696
        output = StringIO()
 
1697
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
1698
        request = client_medium.get_request()
 
1699
        smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
 
1700
        smart_protocol.call('foo')
 
1701
        smart_protocol.read_response_tuple(True)
 
1702
        self.assertEqual(expected_bytes, smart_protocol.read_body_bytes())
 
1703
 
 
1704
    def test_client_read_body_bytes_incremental(self):
 
1705
        # test reading a few bytes at a time from the body
 
1706
        # XXX: possibly we should test dribbling the bytes into the stringio
 
1707
        # to make the state machine work harder: however, as we use the
 
1708
        # LengthPrefixedBodyDecoder that is already well tested - we can skip
 
1709
        # that.
 
1710
        expected_bytes = "1234567"
 
1711
        server_bytes = protocol.RESPONSE_VERSION_TWO + "ok\n7\n1234567done\n"
 
1712
        input = StringIO(server_bytes)
 
1713
        output = StringIO()
 
1714
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
1715
        request = client_medium.get_request()
 
1716
        smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
 
1717
        smart_protocol.call('foo')
 
1718
        smart_protocol.read_response_tuple(True)
 
1719
        self.assertEqual(expected_bytes[0:2], smart_protocol.read_body_bytes(2))
 
1720
        self.assertEqual(expected_bytes[2:4], smart_protocol.read_body_bytes(2))
 
1721
        self.assertEqual(expected_bytes[4:6], smart_protocol.read_body_bytes(2))
 
1722
        self.assertEqual(expected_bytes[6], smart_protocol.read_body_bytes())
 
1723
 
 
1724
    def test_client_cancel_read_body_does_not_eat_body_bytes(self):
 
1725
        # cancelling the expected body needs to finish the request, but not
 
1726
        # read any more bytes.
 
1727
        expected_bytes = "1234567"
 
1728
        server_bytes = protocol.RESPONSE_VERSION_TWO + "ok\n7\n1234567done\n"
 
1729
        input = StringIO(server_bytes)
 
1730
        output = StringIO()
 
1731
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
1732
        request = client_medium.get_request()
 
1733
        smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
 
1734
        smart_protocol.call('foo')
 
1735
        smart_protocol.read_response_tuple(True)
 
1736
        smart_protocol.cancel_read_body()
 
1737
        self.assertEqual(len(protocol.RESPONSE_VERSION_TWO + 'ok\n'),
 
1738
                         input.tell())
 
1739
        self.assertRaises(
 
1740
            errors.ReadingCompleted, smart_protocol.read_body_bytes)
 
1741
 
 
1742
 
1445
1743
class TestSmartClientUnicode(tests.TestCase):
1446
1744
    """_SmartClient tests for unicode arguments.
1447
1745
 
1601
1899
        self.addCleanup(http_server.tearDown)
1602
1900
 
1603
1901
        post_body = 'hello\n'
1604
 
        expected_reply_body = 'ok\x011\n'
 
1902
        expected_reply_body = 'ok\x012\n'
1605
1903
 
1606
1904
        http_transport = get_transport(http_server.get_url())
1607
1905
        medium = http_transport.get_smart_medium()
1622
1920
        self.transport_readonly_server = HTTPServerWithSmarts
1623
1921
 
1624
1922
        post_body = 'hello\n'
1625
 
        expected_reply_body = 'ok\x011\n'
 
1923
        expected_reply_body = 'ok\x012\n'
1626
1924
 
1627
1925
        smart_server_url = self.get_readonly_url('.bzr/smart')
1628
1926
        reply = urllib2.urlopen(smart_server_url, post_body).read()
1648
1946
        response = socket.writefile.getvalue()
1649
1947
        self.assertStartsWith(response, 'HTTP/1.0 200 ')
1650
1948
        # This includes the end of the HTTP headers, and all the body.
1651
 
        expected_end_of_response = '\r\n\r\nok\x011\n'
 
1949
        expected_end_of_response = '\r\n\r\nok\x012\n'
1652
1950
        self.assertEndsWith(response, expected_end_of_response)
1653
1951
 
1654
1952