1381
1382
Subclasses can override client_protocol_class and server_protocol_class.
1385
request_encoder = None
1386
response_decoder = None
1387
server_protocol_class = None
1384
1388
client_protocol_class = None
1385
server_protocol_class = None
1387
def make_client_protocol(self):
1388
client_medium = medium.SmartSimplePipesClientMedium(
1389
StringIO(), StringIO())
1390
return self.client_protocol_class(client_medium.get_request())
1390
def make_client_protocol_and_output(self, input_bytes=None):
1391
# This is very similar to
1392
# bzrlib.smart.client._SmartClient._build_client_protocol
1393
if input_bytes is None:
1396
input = StringIO(input_bytes)
1398
client_medium = medium.SmartSimplePipesClientMedium(input, output)
1399
request = client_medium.get_request()
1400
if self.client_protocol_class is not None:
1401
client_protocol = self.client_protocol_class(request)
1402
return client_protocol, client_protocol, output
1404
assert self.request_encoder is not None
1405
assert self.response_decoder is not None
1406
requester = self.request_encoder(request)
1407
response_handler = message.ConventionalResponseHandler()
1408
response_protocol = self.response_decoder(response_handler)
1409
response_handler.setProtoAndMedium(response_protocol, request)
1410
return requester, response_handler, output
1412
def make_client_protocol(self, input_bytes=None):
1413
result = self.make_client_protocol_and_output(input_bytes=input_bytes)
1414
requester, response_handler, output = result
1415
return requester, response_handler
1392
1417
def make_server_protocol(self):
1393
1418
out_stream = StringIO()
1403
1428
self.client_protocol_class, 'request_marker', None)
1405
1430
def assertOffsetSerialisation(self, expected_offsets, expected_serialised,
1407
1432
"""Check that smart (de)serialises offsets as expected.
1409
1434
We check both serialisation and deserialisation at the same time
1418
1443
readv_cmd = vfs.ReadvRequest(None, '/')
1419
1444
offsets = readv_cmd._deserialise_offsets(expected_serialised)
1420
1445
self.assertEqual(expected_offsets, offsets)
1421
serialised = client._serialise_offsets(offsets)
1446
serialised = requester._serialise_offsets(offsets)
1422
1447
self.assertEqual(expected_serialised, serialised)
1424
1449
def build_protocol_waiting_for_body(self):
1452
1477
_mod_request.SuccessfulSmartServerResponse(input_tuple))
1453
1478
self.assertEqual(expected_bytes, server_output.getvalue())
1454
1479
# check the decoding of the client smart_protocol from expected_bytes:
1455
input = StringIO(expected_bytes)
1457
client_medium = medium.SmartSimplePipesClientMedium(input, output)
1458
request = client_medium.get_request()
1459
smart_protocol = self.client_protocol_class(request)
1460
smart_protocol.call('foo')
1461
self.assertEqual(expected_tuple, smart_protocol.read_response_tuple())
1480
requester, response_handler = self.make_client_protocol(expected_bytes)
1481
requester.call('foo')
1482
self.assertEqual(expected_tuple, response_handler.read_response_tuple())
1464
1485
class CommonSmartProtocolTestMixin(object):
1475
1496
self.assertContainsRe(test_log, 'SmartProtocolError')
1477
1498
def test_connection_closed_reporting(self):
1478
smart_protocol = self.make_client_protocol()
1479
smart_protocol.call('hello')
1499
requester, response_handler = self.make_client_protocol()
1500
requester.call('hello')
1480
1501
ex = self.assertRaises(errors.ConnectionReset,
1481
smart_protocol.read_response_tuple)
1502
response_handler.read_response_tuple)
1482
1503
self.assertEqual("Connection closed: "
1483
1504
"please check connectivity and permissions "
1484
1505
"(and try -Dhpss if further diagnosis is required)", str(ex))
1490
1511
one with the order of reads not increasing (an out of order read), and
1491
1512
one that should coalesce.
1493
client_protocol = self.make_client_protocol()
1494
self.assertOffsetSerialisation([], '', client_protocol)
1495
self.assertOffsetSerialisation([(1,2)], '1,2', client_protocol)
1514
requester, response_handler = self.make_client_protocol()
1515
self.assertOffsetSerialisation([], '', requester)
1516
self.assertOffsetSerialisation([(1,2)], '1,2', requester)
1496
1517
self.assertOffsetSerialisation([(10,40), (0,5)], '10,40\n0,5',
1498
1519
self.assertOffsetSerialisation([(1,2), (3,4), (100, 200)],
1499
'1,2\n3,4\n100,200', client_protocol)
1520
'1,2\n3,4\n100,200', requester)
1502
1523
class TestVersionOneFeaturesInProtocolOne(
2117
client_protocol_class = protocol.SmartClientRequestProtocolThree
2138
request_encoder = protocol.ProtocolThreeRequester
2139
response_decoder = protocol.ProtocolThreeDecoder
2118
2140
# build_server_protocol_three is a function, so we can't set it as a class
2119
2141
# attribute directly, because then Python will assume it is actually a
2120
2142
# method. So we make server_protocol_class be a static method, rather than
2122
2144
# "server_protocol_class = protocol.build_server_protocol_three".
2123
2145
server_protocol_class = staticmethod(protocol.build_server_protocol_three)
2148
super(TestVersionOneFeaturesInProtocolThree, self).setUp()
2149
self.response_marker = protocol.MESSAGE_VERSION_THREE
2150
self.request_marker = protocol.MESSAGE_VERSION_THREE
2125
2152
def test_construct_version_three_server_protocol(self):
2126
2153
smart_protocol = protocol.ProtocolThreeDecoder(None)
2127
2154
self.assertEqual('', smart_protocol.excess_buffer)
2170
2197
class TestProtocolThree(TestSmartProtocol):
2171
2198
"""Tests for v3 of the server-side protocol."""
2173
client_protocol_class = protocol.SmartClientRequestProtocolThree
2200
request_encoder = protocol.ProtocolThreeRequester
2201
response_decoder = protocol.ProtocolThreeDecoder
2174
2202
server_protocol_class = protocol.ProtocolThreeDecoder
2176
2204
def test_trivial_request(self):
2383
2411
#class TestClientDecodingProtocolThree(TestSmartProtocol):
2384
2412
# """Tests for v3 of the client-side protocol decoding."""
2386
# client_protocol_class = protocol.SmartClientRequestProtocolThree
2414
# request_encoder = protocol.ProtocolThreeRequester
2415
# response_decoder = protocol.ProtocolThreeDecoder
2387
2416
# server_protocol_class = protocol.SmartServerRequestProtocolThree
2389
2418
# def test_trivial_response_decoding(self):
2406
2435
class TestClientEncodingProtocolThree(TestSmartProtocol):
2408
client_protocol_class = protocol.SmartClientRequestProtocolThree
2437
request_encoder = protocol.ProtocolThreeRequester
2438
response_decoder = protocol.ProtocolThreeDecoder
2409
2439
server_protocol_class = protocol.ProtocolThreeDecoder
2411
2441
def make_client_encoder_and_output(self):
2414
client_medium = medium.SmartSimplePipesClientMedium(input, output)
2415
request = client_medium.get_request()
2416
smart_protocol = self.client_protocol_class(request)
2417
return smart_protocol, output
2442
result = self.make_client_protocol_and_output()
2443
requester, response_handler, output = result
2444
return requester, output
2419
2446
def test_call_smoke_test(self):
2420
"""A smoke test SmartClientRequestProtocolThree.call.
2447
"""A smoke test for ProtocolThreeRequester.call.
2422
2449
This test checks that a particular simple invocation of call emits the
2423
2450
correct bytes for that invocation.
2425
smart_protocol, output = self.make_client_encoder_and_output()
2426
smart_protocol.call('one arg', headers={'header name': 'header value'})
2452
requester, output = self.make_client_encoder_and_output()
2453
requester.call('one arg', headers={'header name': 'header value'})
2427
2454
self.assertEquals(
2428
2455
'bzr message 3 (bzr 1.3)\n' # protocol version
2429
2456
'\x00\x00\x00\x1fd11:header name12:header valuee' # headers
2432
2459
output.getvalue())
2434
2461
def test_call_default_headers(self):
2435
"""SmartClientRequestProtocolThree.call by default sends a 'Software
2462
"""ProtocolThreeRequester.call by default sends a 'Software
2436
2463
version' header.
2438
smart_protocol, output = self.make_client_encoder_and_output()
2439
smart_protocol.call('foo')
2465
requester, output = self.make_client_encoder_and_output()
2466
requester.call('foo')
2440
2467
# XXX: using assertContainsRe is a pretty poor way to assert this.
2441
2468
self.assertContainsRe(output.getvalue(), 'Software version')
2443
2470
def test_call_with_body_bytes_smoke_test(self):
2444
"""A smoke test SmartClientRequestProtocolThree.call_with_body_bytes.
2471
"""A smoke test for ProtocolThreeRequester.call_with_body_bytes.
2446
2473
This test checks that a particular simple invocation of
2447
2474
call_with_body_bytes emits the correct bytes for that invocation.
2449
smart_protocol, output = self.make_client_encoder_and_output()
2450
smart_protocol.call_with_body_bytes(
2476
requester, output = self.make_client_encoder_and_output()
2477
requester.call_with_body_bytes(
2451
2478
('one arg',), 'body bytes',
2452
2479
headers={'header name': 'header value'})
2453
2480
self.assertEquals(