106
107
class FakeProtocol(object):
107
108
"""Lookalike SmartClientRequestProtocolOne allowing body reading tests."""
109
def __init__(self, body):
110
def __init__(self, body, fake_client):
110
111
self._body_buffer = StringIO(body)
112
self._fake_client = fake_client
112
114
def read_body_bytes(self, count=-1):
113
return self._body_buffer.read(count)
115
bytes = self._body_buffer.read(count)
116
if self._body_buffer.tell() == len(self._body_buffer.getvalue()):
117
self._fake_client.expecting_body = False
120
def cancel_read_body(self):
121
self._fake_client.expecting_body = False
116
124
class FakeClient(_SmartClient):
119
127
def __init__(self, responses):
120
128
# We don't call the super init because there is no medium.
121
"""create a FakeClient.
129
"""Create a FakeClient.
123
131
:param respones: A list of response-tuple, body-data pairs to be sent
126
134
self.responses = responses
136
self.expecting_body = False
129
138
def call(self, method, *args):
130
139
self._calls.append(('call', method, args))
133
142
def call_expecting_body(self, method, *args):
134
143
self._calls.append(('call_expecting_body', method, args))
135
144
result = self.responses.pop(0)
136
return result[0], FakeProtocol(result[1])
145
self.expecting_body = True
146
return result[0], FakeProtocol(result[1], self)
139
149
class TestBzrDirOpenBranch(tests.TestCase):
557
567
[('call_expecting_body', 'Repository.get_revision_graph',
558
568
('///sinhala/', ''))],
560
self.assertEqual({r1: [], r2: [r1]}, result)
570
self.assertEqual({r1: (), r2: (r1, )}, result)
562
572
def test_specific_revision(self):
563
573
# with a specific revision we want the graph for that
577
587
[('call_expecting_body', 'Repository.get_revision_graph',
578
588
('///sinhala/', r2))],
580
self.assertEqual({r11: [], r12: [], r2: [r11, r12], }, result)
590
self.assertEqual({r11: (), r12: (), r2: (r11, r12), }, result)
582
592
def test_no_such_revision(self):
734
744
tarball_file.close()
736
def test_sprout_uses_tarball(self):
737
# RemoteRepository.sprout should try to use the
738
# tarball command rather than accessing all the files
739
transport_path = 'srcrepo'
740
expected_responses = [(('ok',), self.tarball_content),
742
expected_calls = [('call2', 'Repository.tarball', ('///srcrepo/', 'bz2',),),
744
remote_repo, client = self.setup_fake_client_and_repository(
745
expected_responses, transport_path)
746
# make a regular local repository to receive the results
747
dest_transport = MemoryTransport()
748
dest_transport.mkdir('destrepo')
749
bzrdir_format = bzrdir.format_registry.make_bzrdir('default')
750
dest_bzrdir = bzrdir_format.initialize_on_transport(dest_transport)
752
remote_repo.sprout(dest_bzrdir)
755
747
class TestRemoteRepositoryCopyContent(tests.TestCaseWithTransport):
756
748
"""RemoteRepository.copy_content_into optimizations"""
769
761
self.assertFalse(isinstance(dest_repo, RemoteRepository))
770
762
self.assertTrue(isinstance(src_repo, RemoteRepository))
771
763
src_repo.copy_content_into(dest_repo)
766
class TestRepositoryStreamKnitData(TestRemoteRepository):
768
def make_pack_file(self, records):
769
pack_file = StringIO()
770
pack_writer = pack.ContainerWriter(pack_file.write)
772
for bytes, names in records:
773
pack_writer.add_bytes_record(bytes, names)
778
def test_bad_pack_from_server(self):
779
"""A response with invalid data (e.g. it has a record with multiple
780
names) triggers an exception.
782
Not all possible errors will be caught at this stage, but obviously
783
malformed data should be.
785
record = ('bytes', [('name1',), ('name2',)])
786
pack_file = self.make_pack_file([record])
787
responses = [(('ok',), pack_file.getvalue()), ]
788
transport_path = 'quack'
789
repo, client = self.setup_fake_client_and_repository(
790
responses, transport_path)
791
stream = repo.get_data_stream(['revid'])
792
self.assertRaises(errors.SmartProtocolError, list, stream)
794
def test_backwards_compatibility(self):
795
"""If the server doesn't recognise this request, fallback to VFS."""
797
"Generic bzr smart protocol error: "
798
"bad request 'Repository.stream_knit_data_for_revisions'")
800
(('error', error_msg), '')]
801
repo, client = self.setup_fake_client_and_repository(
803
self.mock_called = False
804
repo._real_repository = MockRealRepository(self)
805
repo.get_data_stream(['revid'])
806
self.assertTrue(self.mock_called)
807
self.failIf(client.expecting_body,
808
"The protocol has been left in an unclean state that will cause "
809
"TooManyConcurrentRequests errors.")
812
class MockRealRepository(object):
813
"""Helper class for TestRepositoryStreamKnitData.test_unknown_method."""
815
def __init__(self, test):
818
def get_data_stream(self, revision_ids):
819
self.test.assertEqual(['revid'], revision_ids)
820
self.test.mock_called = True