~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/tests/test_smart_transport.py

  • Committer: Canonical.com Patch Queue Manager
  • Date: 2007-04-10 07:43:02 UTC
  • mfrom: (2400.1.9 split-smart-part-1-rename)
  • Revision ID: pqm@pqm.ubuntu.com-20070410074302-cf6b95587a1058cd
(Andrew Bennetts) Split bzrlib/transport/smart.py into several smaller modules.

Show diffs side-by-side

added

removed

Lines of Context:
30
30
        tests,
31
31
        urlutils,
32
32
        )
 
33
from bzrlib.smart import (
 
34
        medium,
 
35
        protocol,
 
36
        request,
 
37
        server,
 
38
)
33
39
from bzrlib.tests.HTTPTestUtil import (
34
40
        HTTPServerWithSmarts,
35
41
        SmartRequestHandler,
38
44
        get_transport,
39
45
        local,
40
46
        memory,
41
 
        smart,
 
47
        remote,
42
48
        )
43
49
from bzrlib.transport.http import SmartClientHTTPMediumRequest
44
50
 
85
91
        sock.bind(('127.0.0.1', 0))
86
92
        sock.listen(1)
87
93
        port = sock.getsockname()[1]
88
 
        medium = smart.SmartTCPClientMedium('127.0.0.1', port)
89
 
        return sock, medium
 
94
        client_medium = medium.SmartTCPClientMedium('127.0.0.1', port)
 
95
        return sock, client_medium
90
96
 
91
97
    def receive_bytes_on_server(self, sock, bytes):
92
98
        """Accept a connection on sock and read 3 bytes.
108
114
        # this just ensures that the constructor stays parameter-free which
109
115
        # is important for reuse : some subclasses will dynamically connect,
110
116
        # others are always on, etc.
111
 
        medium = smart.SmartClientStreamMedium()
 
117
        client_medium = medium.SmartClientStreamMedium()
112
118
 
113
119
    def test_construct_smart_client_medium(self):
114
120
        # the base client medium takes no parameters
115
 
        medium = smart.SmartClientMedium()
 
121
        client_medium = medium.SmartClientMedium()
116
122
    
117
123
    def test_construct_smart_simple_pipes_client_medium(self):
118
124
        # the SimplePipes client medium takes two pipes:
119
125
        # readable pipe, writeable pipe.
120
126
        # Constructing one should just save these and do nothing.
121
127
        # We test this by passing in None.
122
 
        medium = smart.SmartSimplePipesClientMedium(None, None)
 
128
        client_medium = medium.SmartSimplePipesClientMedium(None, None)
123
129
        
124
130
    def test_simple_pipes_client_request_type(self):
125
131
        # SimplePipesClient should use SmartClientStreamMediumRequest's.
126
 
        medium = smart.SmartSimplePipesClientMedium(None, None)
127
 
        request = medium.get_request()
128
 
        self.assertIsInstance(request, smart.SmartClientStreamMediumRequest)
 
132
        client_medium = medium.SmartSimplePipesClientMedium(None, None)
 
133
        request = client_medium.get_request()
 
134
        self.assertIsInstance(request, medium.SmartClientStreamMediumRequest)
129
135
 
130
136
    def test_simple_pipes_client_get_concurrent_requests(self):
131
137
        # the simple_pipes client does not support pipelined requests:
135
141
        # classes - as the sibling classes share this logic, they do not have
136
142
        # explicit tests for this.
137
143
        output = StringIO()
138
 
        medium = smart.SmartSimplePipesClientMedium(None, output)
139
 
        request = medium.get_request()
 
144
        client_medium = medium.SmartSimplePipesClientMedium(None, output)
 
145
        request = client_medium.get_request()
140
146
        request.finished_writing()
141
147
        request.finished_reading()
142
 
        request2 = medium.get_request()
 
148
        request2 = client_medium.get_request()
143
149
        request2.finished_writing()
144
150
        request2.finished_reading()
145
151
 
146
152
    def test_simple_pipes_client__accept_bytes_writes_to_writable(self):
147
153
        # accept_bytes writes to the writeable pipe.
148
154
        output = StringIO()
149
 
        medium = smart.SmartSimplePipesClientMedium(None, output)
150
 
        medium._accept_bytes('abc')
 
155
        client_medium = medium.SmartSimplePipesClientMedium(None, output)
 
156
        client_medium._accept_bytes('abc')
151
157
        self.assertEqual('abc', output.getvalue())
152
158
    
153
159
    def test_simple_pipes_client_disconnect_does_nothing(self):
154
160
        # calling disconnect does nothing.
155
161
        input = StringIO()
156
162
        output = StringIO()
157
 
        medium = smart.SmartSimplePipesClientMedium(input, output)
 
163
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
158
164
        # send some bytes to ensure disconnecting after activity still does not
159
165
        # close.
160
 
        medium._accept_bytes('abc')
161
 
        medium.disconnect()
 
166
        client_medium._accept_bytes('abc')
 
167
        client_medium.disconnect()
162
168
        self.assertFalse(input.closed)
163
169
        self.assertFalse(output.closed)
164
170
 
167
173
        # accept_bytes writes to.
168
174
        input = StringIO()
169
175
        output = StringIO()
170
 
        medium = smart.SmartSimplePipesClientMedium(input, output)
171
 
        medium._accept_bytes('abc')
172
 
        medium.disconnect()
173
 
        medium._accept_bytes('abc')
 
176
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
177
        client_medium._accept_bytes('abc')
 
178
        client_medium.disconnect()
 
179
        client_medium._accept_bytes('abc')
174
180
        self.assertFalse(input.closed)
175
181
        self.assertFalse(output.closed)
176
182
        self.assertEqual('abcabc', output.getvalue())
178
184
    def test_simple_pipes_client_ignores_disconnect_when_not_connected(self):
179
185
        # Doing a disconnect on a new (and thus unconnected) SimplePipes medium
180
186
        # does nothing.
181
 
        medium = smart.SmartSimplePipesClientMedium(None, None)
182
 
        medium.disconnect()
 
187
        client_medium = medium.SmartSimplePipesClientMedium(None, None)
 
188
        client_medium.disconnect()
183
189
 
184
190
    def test_simple_pipes_client_can_always_read(self):
185
191
        # SmartSimplePipesClientMedium is never disconnected, so read_bytes
186
192
        # always tries to read from the underlying pipe.
187
193
        input = StringIO('abcdef')
188
 
        medium = smart.SmartSimplePipesClientMedium(input, None)
189
 
        self.assertEqual('abc', medium.read_bytes(3))
190
 
        medium.disconnect()
191
 
        self.assertEqual('def', medium.read_bytes(3))
 
194
        client_medium = medium.SmartSimplePipesClientMedium(input, None)
 
195
        self.assertEqual('abc', client_medium.read_bytes(3))
 
196
        client_medium.disconnect()
 
197
        self.assertEqual('def', client_medium.read_bytes(3))
192
198
        
193
199
    def test_simple_pipes_client_supports__flush(self):
194
200
        # invoking _flush on a SimplePipesClient should flush the output 
200
206
        flush_calls = []
201
207
        def logging_flush(): flush_calls.append('flush')
202
208
        output.flush = logging_flush
203
 
        medium = smart.SmartSimplePipesClientMedium(input, output)
 
209
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
204
210
        # this call is here to ensure we only flush once, not on every
205
211
        # _accept_bytes call.
206
 
        medium._accept_bytes('abc')
207
 
        medium._flush()
208
 
        medium.disconnect()
 
212
        client_medium._accept_bytes('abc')
 
213
        client_medium._flush()
 
214
        client_medium.disconnect()
209
215
        self.assertEqual(['flush'], flush_calls)
210
216
 
211
217
    def test_construct_smart_ssh_client_medium(self):
219
225
        unopened_port = sock.getsockname()[1]
220
226
        # having vendor be invalid means that if it tries to connect via the
221
227
        # vendor it will blow up.
222
 
        medium = smart.SmartSSHClientMedium('127.0.0.1', unopened_port,
 
228
        client_medium = medium.SmartSSHClientMedium('127.0.0.1', unopened_port,
223
229
            username=None, password=None, vendor="not a vendor")
224
230
        sock.close()
225
231
 
228
234
        # it bytes.
229
235
        output = StringIO()
230
236
        vendor = StringIOSSHVendor(StringIO(), output)
231
 
        medium = smart.SmartSSHClientMedium('a hostname', 'a port', 'a username',
232
 
            'a password', vendor)
233
 
        medium._accept_bytes('abc')
 
237
        client_medium = medium.SmartSSHClientMedium(
 
238
            'a hostname', 'a port', 'a username', 'a password', vendor)
 
239
        client_medium._accept_bytes('abc')
234
240
        self.assertEqual('abc', output.getvalue())
235
241
        self.assertEqual([('connect_ssh', 'a username', 'a password',
236
242
            'a hostname', 'a port',
247
253
            osutils.set_or_unset_env('BZR_REMOTE_PATH', orig_bzr_remote_path)
248
254
        self.addCleanup(cleanup_environ)
249
255
        os.environ['BZR_REMOTE_PATH'] = 'fugly'
250
 
        medium = smart.SmartSSHClientMedium('a hostname', 'a port', 'a username',
 
256
        client_medium = medium.SmartSSHClientMedium('a hostname', 'a port', 'a username',
251
257
            'a password', vendor)
252
 
        medium._accept_bytes('abc')
 
258
        client_medium._accept_bytes('abc')
253
259
        self.assertEqual('abc', output.getvalue())
254
260
        self.assertEqual([('connect_ssh', 'a username', 'a password',
255
261
            'a hostname', 'a port',
262
268
        input = StringIO()
263
269
        output = StringIO()
264
270
        vendor = StringIOSSHVendor(input, output)
265
 
        medium = smart.SmartSSHClientMedium('a hostname', vendor=vendor)
266
 
        medium._accept_bytes('abc')
267
 
        medium.disconnect()
 
271
        client_medium = medium.SmartSSHClientMedium('a hostname', vendor=vendor)
 
272
        client_medium._accept_bytes('abc')
 
273
        client_medium.disconnect()
268
274
        self.assertTrue(input.closed)
269
275
        self.assertTrue(output.closed)
270
276
        self.assertEqual([
282
288
        input = StringIO()
283
289
        output = StringIO()
284
290
        vendor = StringIOSSHVendor(input, output)
285
 
        medium = smart.SmartSSHClientMedium('a hostname', vendor=vendor)
286
 
        medium._accept_bytes('abc')
287
 
        medium.disconnect()
 
291
        client_medium = medium.SmartSSHClientMedium('a hostname', vendor=vendor)
 
292
        client_medium._accept_bytes('abc')
 
293
        client_medium.disconnect()
288
294
        # the disconnect has closed output, so we need a new output for the
289
295
        # new connection to write to.
290
296
        input2 = StringIO()
291
297
        output2 = StringIO()
292
298
        vendor.read_from = input2
293
299
        vendor.write_to = output2
294
 
        medium._accept_bytes('abc')
295
 
        medium.disconnect()
 
300
        client_medium._accept_bytes('abc')
 
301
        client_medium.disconnect()
296
302
        self.assertTrue(input.closed)
297
303
        self.assertTrue(output.closed)
298
304
        self.assertTrue(input2.closed)
310
316
    def test_ssh_client_ignores_disconnect_when_not_connected(self):
311
317
        # Doing a disconnect on a new (and thus unconnected) SSH medium
312
318
        # does not fail.  It's ok to disconnect an unconnected medium.
313
 
        medium = smart.SmartSSHClientMedium(None)
314
 
        medium.disconnect()
 
319
        client_medium = medium.SmartSSHClientMedium(None)
 
320
        client_medium.disconnect()
315
321
 
316
322
    def test_ssh_client_raises_on_read_when_not_connected(self):
317
323
        # Doing a read on a new (and thus unconnected) SSH medium raises
318
324
        # MediumNotConnected.
319
 
        medium = smart.SmartSSHClientMedium(None)
320
 
        self.assertRaises(errors.MediumNotConnected, medium.read_bytes, 0)
321
 
        self.assertRaises(errors.MediumNotConnected, medium.read_bytes, 1)
 
325
        client_medium = medium.SmartSSHClientMedium(None)
 
326
        self.assertRaises(errors.MediumNotConnected, client_medium.read_bytes, 0)
 
327
        self.assertRaises(errors.MediumNotConnected, client_medium.read_bytes, 1)
322
328
 
323
329
    def test_ssh_client_supports__flush(self):
324
330
        # invoking _flush on a SSHClientMedium should flush the output 
331
337
        def logging_flush(): flush_calls.append('flush')
332
338
        output.flush = logging_flush
333
339
        vendor = StringIOSSHVendor(input, output)
334
 
        medium = smart.SmartSSHClientMedium('a hostname', vendor=vendor)
 
340
        client_medium = medium.SmartSSHClientMedium('a hostname', vendor=vendor)
335
341
        # this call is here to ensure we only flush once, not on every
336
342
        # _accept_bytes call.
337
 
        medium._accept_bytes('abc')
338
 
        medium._flush()
339
 
        medium.disconnect()
 
343
        client_medium._accept_bytes('abc')
 
344
        client_medium._flush()
 
345
        client_medium.disconnect()
340
346
        self.assertEqual(['flush'], flush_calls)
341
347
        
342
348
    def test_construct_smart_tcp_client_medium(self):
345
351
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
346
352
        sock.bind(('127.0.0.1', 0))
347
353
        unopened_port = sock.getsockname()[1]
348
 
        medium = smart.SmartTCPClientMedium('127.0.0.1', unopened_port)
 
354
        client_medium = medium.SmartTCPClientMedium('127.0.0.1', unopened_port)
349
355
        sock.close()
350
356
 
351
357
    def test_tcp_client_connects_on_first_use(self):
378
384
    def test_tcp_client_ignores_disconnect_when_not_connected(self):
379
385
        # Doing a disconnect on a new (and thus unconnected) TCP medium
380
386
        # does not fail.  It's ok to disconnect an unconnected medium.
381
 
        medium = smart.SmartTCPClientMedium(None, None)
382
 
        medium.disconnect()
 
387
        client_medium = medium.SmartTCPClientMedium(None, None)
 
388
        client_medium.disconnect()
383
389
 
384
390
    def test_tcp_client_raises_on_read_when_not_connected(self):
385
391
        # Doing a read on a new (and thus unconnected) TCP medium raises
386
392
        # MediumNotConnected.
387
 
        medium = smart.SmartTCPClientMedium(None, None)
388
 
        self.assertRaises(errors.MediumNotConnected, medium.read_bytes, 0)
389
 
        self.assertRaises(errors.MediumNotConnected, medium.read_bytes, 1)
 
393
        client_medium = medium.SmartTCPClientMedium(None, None)
 
394
        self.assertRaises(errors.MediumNotConnected, client_medium.read_bytes, 0)
 
395
        self.assertRaises(errors.MediumNotConnected, client_medium.read_bytes, 1)
390
396
 
391
397
    def test_tcp_client_supports__flush(self):
392
398
        # invoking _flush on a TCPClientMedium should do something useful.
421
427
        # WritingCompleted to prevent bad assumptions on stream environments
422
428
        # breaking the needs of message-based environments.
423
429
        output = StringIO()
424
 
        medium = smart.SmartSimplePipesClientMedium(None, output)
425
 
        request = smart.SmartClientStreamMediumRequest(medium)
 
430
        client_medium = medium.SmartSimplePipesClientMedium(None, output)
 
431
        request = medium.SmartClientStreamMediumRequest(client_medium)
426
432
        request.finished_writing()
427
433
        self.assertRaises(errors.WritingCompleted, request.accept_bytes, None)
428
434
 
432
438
        # and checking that the pipes get the data.
433
439
        input = StringIO()
434
440
        output = StringIO()
435
 
        medium = smart.SmartSimplePipesClientMedium(input, output)
436
 
        request = smart.SmartClientStreamMediumRequest(medium)
 
441
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
442
        request = medium.SmartClientStreamMediumRequest(client_medium)
437
443
        request.accept_bytes('123')
438
444
        request.finished_writing()
439
445
        request.finished_reading()
444
450
        # constructing a SmartClientStreamMediumRequest on a StreamMedium sets
445
451
        # the current request to the new SmartClientStreamMediumRequest
446
452
        output = StringIO()
447
 
        medium = smart.SmartSimplePipesClientMedium(None, output)
448
 
        request = smart.SmartClientStreamMediumRequest(medium)
449
 
        self.assertIs(medium._current_request, request)
 
453
        client_medium = medium.SmartSimplePipesClientMedium(None, output)
 
454
        request = medium.SmartClientStreamMediumRequest(client_medium)
 
455
        self.assertIs(client_medium._current_request, request)
450
456
 
451
457
    def test_construct_while_another_request_active_throws(self):
452
458
        # constructing a SmartClientStreamMediumRequest on a StreamMedium with
453
459
        # a non-None _current_request raises TooManyConcurrentRequests.
454
460
        output = StringIO()
455
 
        medium = smart.SmartSimplePipesClientMedium(None, output)
456
 
        medium._current_request = "a"
 
461
        client_medium = medium.SmartSimplePipesClientMedium(None, output)
 
462
        client_medium._current_request = "a"
457
463
        self.assertRaises(errors.TooManyConcurrentRequests,
458
 
            smart.SmartClientStreamMediumRequest, medium)
 
464
            medium.SmartClientStreamMediumRequest, client_medium)
459
465
 
460
466
    def test_finished_read_clears_current_request(self):
461
467
        # calling finished_reading clears the current request from the requests
462
468
        # medium
463
469
        output = StringIO()
464
 
        medium = smart.SmartSimplePipesClientMedium(None, output)
465
 
        request = smart.SmartClientStreamMediumRequest(medium)
 
470
        client_medium = medium.SmartSimplePipesClientMedium(None, output)
 
471
        request = medium.SmartClientStreamMediumRequest(client_medium)
466
472
        request.finished_writing()
467
473
        request.finished_reading()
468
 
        self.assertEqual(None, medium._current_request)
 
474
        self.assertEqual(None, client_medium._current_request)
469
475
 
470
476
    def test_finished_read_before_finished_write_errors(self):
471
477
        # calling finished_reading before calling finished_writing triggers a
472
478
        # WritingNotComplete error.
473
 
        medium = smart.SmartSimplePipesClientMedium(None, None)
474
 
        request = smart.SmartClientStreamMediumRequest(medium)
 
479
        client_medium = medium.SmartSimplePipesClientMedium(None, None)
 
480
        request = medium.SmartClientStreamMediumRequest(client_medium)
475
481
        self.assertRaises(errors.WritingNotComplete, request.finished_reading)
476
482
        
477
483
    def test_read_bytes(self):
483
489
        # smoke tests.
484
490
        input = StringIO('321')
485
491
        output = StringIO()
486
 
        medium = smart.SmartSimplePipesClientMedium(input, output)
487
 
        request = smart.SmartClientStreamMediumRequest(medium)
 
492
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
493
        request = medium.SmartClientStreamMediumRequest(client_medium)
488
494
        request.finished_writing()
489
495
        self.assertEqual('321', request.read_bytes(3))
490
496
        request.finished_reading()
496
502
        # WritingNotComplete error because the Smart protocol is designed to be
497
503
        # compatible with strict message based protocols like HTTP where the
498
504
        # request cannot be submitted until the writing has completed.
499
 
        medium = smart.SmartSimplePipesClientMedium(None, None)
500
 
        request = smart.SmartClientStreamMediumRequest(medium)
 
505
        client_medium = medium.SmartSimplePipesClientMedium(None, None)
 
506
        request = medium.SmartClientStreamMediumRequest(client_medium)
501
507
        self.assertRaises(errors.WritingNotComplete, request.read_bytes, None)
502
508
 
503
509
    def test_read_bytes_after_finished_reading_errors(self):
505
511
        # ReadingCompleted to prevent bad assumptions on stream environments
506
512
        # breaking the needs of message-based environments.
507
513
        output = StringIO()
508
 
        medium = smart.SmartSimplePipesClientMedium(None, output)
509
 
        request = smart.SmartClientStreamMediumRequest(medium)
 
514
        client_medium = medium.SmartSimplePipesClientMedium(None, output)
 
515
        request = medium.SmartClientStreamMediumRequest(client_medium)
510
516
        request.finished_writing()
511
517
        request.finished_reading()
512
518
        self.assertRaises(errors.ReadingCompleted, request.read_bytes, None)
520
526
        # the default or a parameterized class, but rather use the
521
527
        # TestCaseWithTransport infrastructure to set up a smart server and
522
528
        # transport.
523
 
        self.transport_server = smart.SmartTCPServer_for_testing
 
529
        self.transport_server = server.SmartTCPServer_for_testing
524
530
 
525
531
    def test_plausible_url(self):
526
532
        self.assert_(self.get_url().startswith('bzr://'))
527
533
 
528
534
    def test_probe_transport(self):
529
535
        t = self.get_transport()
530
 
        self.assertIsInstance(t, smart.SmartTransport)
 
536
        self.assertIsInstance(t, remote.SmartTransport)
531
537
 
532
538
    def test_get_medium_from_transport(self):
533
539
        """Remote transport has a medium always, which it can return."""
534
540
        t = self.get_transport()
535
 
        medium = t.get_smart_medium()
536
 
        self.assertIsInstance(medium, smart.SmartClientMedium)
 
541
        smart_medium = t.get_smart_medium()
 
542
        self.assertIsInstance(smart_medium, medium.SmartClientMedium)
537
543
 
538
544
 
539
545
class ErrorRaisingProtocol(object):
588
594
        to_server = StringIO('hello\n')
589
595
        from_server = StringIO()
590
596
        transport = local.LocalTransport(urlutils.local_path_to_url('/'))
591
 
        server = smart.SmartServerPipeStreamMedium(
 
597
        server = medium.SmartServerPipeStreamMedium(
592
598
            to_server, from_server, transport)
593
 
        protocol = smart.SmartServerRequestProtocolOne(transport,
 
599
        smart_protocol = protocol.SmartServerRequestProtocolOne(transport,
594
600
                from_server.write)
595
 
        server._serve_one_request(protocol)
 
601
        server._serve_one_request(smart_protocol)
596
602
        self.assertEqual('ok\0011\n',
597
603
                         from_server.getvalue())
598
604
 
601
607
        transport.put_bytes('testfile', 'contents\nof\nfile\n')
602
608
        to_server = StringIO('get\001./testfile\n')
603
609
        from_server = StringIO()
604
 
        server = smart.SmartServerPipeStreamMedium(
 
610
        server = medium.SmartServerPipeStreamMedium(
605
611
            to_server, from_server, transport)
606
 
        protocol = smart.SmartServerRequestProtocolOne(transport,
 
612
        smart_protocol = protocol.SmartServerRequestProtocolOne(transport,
607
613
                from_server.write)
608
 
        server._serve_one_request(protocol)
 
614
        server._serve_one_request(smart_protocol)
609
615
        self.assertEqual('ok\n'
610
616
                         '17\n'
611
617
                         'contents\nof\nfile\n'
619
625
        transport.put_bytes(utf8_filename, 'contents\nof\nfile\n')
620
626
        to_server = StringIO('get\001' + utf8_filename + '\n')
621
627
        from_server = StringIO()
622
 
        server = smart.SmartServerPipeStreamMedium(
 
628
        server = medium.SmartServerPipeStreamMedium(
623
629
            to_server, from_server, transport)
624
 
        protocol = smart.SmartServerRequestProtocolOne(transport,
 
630
        smart_protocol = protocol.SmartServerRequestProtocolOne(transport,
625
631
                from_server.write)
626
 
        server._serve_one_request(protocol)
 
632
        server._serve_one_request(smart_protocol)
627
633
        self.assertEqual('ok\n'
628
634
                         '17\n'
629
635
                         'contents\nof\nfile\n'
634
640
        sample_request_bytes = 'command\n9\nbulk datadone\n'
635
641
        to_server = StringIO(sample_request_bytes)
636
642
        from_server = StringIO()
637
 
        server = smart.SmartServerPipeStreamMedium(to_server, from_server, None)
 
643
        server = medium.SmartServerPipeStreamMedium(
 
644
            to_server, from_server, None)
638
645
        sample_protocol = SampleRequest(expected_bytes=sample_request_bytes)
639
646
        server._serve_one_request(sample_protocol)
640
647
        self.assertEqual('', from_server.getvalue())
644
651
    def test_socket_stream_with_bulk_data(self):
645
652
        sample_request_bytes = 'command\n9\nbulk datadone\n'
646
653
        server_sock, client_sock = self.portable_socket_pair()
647
 
        server = smart.SmartServerSocketStreamMedium(
 
654
        server = medium.SmartServerSocketStreamMedium(
648
655
            server_sock, None)
649
656
        sample_protocol = SampleRequest(expected_bytes=sample_request_bytes)
650
657
        client_sock.sendall(sample_request_bytes)
657
664
    def test_pipe_like_stream_shutdown_detection(self):
658
665
        to_server = StringIO('')
659
666
        from_server = StringIO()
660
 
        server = smart.SmartServerPipeStreamMedium(to_server, from_server, None)
 
667
        server = medium.SmartServerPipeStreamMedium(to_server, from_server, None)
661
668
        server._serve_one_request(SampleRequest('x'))
662
669
        self.assertTrue(server.finished)
663
670
        
664
671
    def test_socket_stream_shutdown_detection(self):
665
672
        server_sock, client_sock = self.portable_socket_pair()
666
673
        client_sock.close()
667
 
        server = smart.SmartServerSocketStreamMedium(
 
674
        server = medium.SmartServerSocketStreamMedium(
668
675
            server_sock, None)
669
676
        server._serve_one_request(SampleRequest('x'))
670
677
        self.assertTrue(server.finished)
676
683
        sample_request_bytes = 'command\n'
677
684
        to_server = StringIO(sample_request_bytes * 2)
678
685
        from_server = StringIO()
679
 
        server = smart.SmartServerPipeStreamMedium(to_server, from_server, None)
 
686
        server = medium.SmartServerPipeStreamMedium(
 
687
            to_server, from_server, None)
680
688
        first_protocol = SampleRequest(expected_bytes=sample_request_bytes)
681
689
        server._serve_one_request(first_protocol)
682
690
        self.assertEqual(0, first_protocol.next_read_size())
696
704
        # been received seperately.
697
705
        sample_request_bytes = 'command\n'
698
706
        server_sock, client_sock = self.portable_socket_pair()
699
 
        server = smart.SmartServerSocketStreamMedium(
 
707
        server = medium.SmartServerSocketStreamMedium(
700
708
            server_sock, None)
701
709
        first_protocol = SampleRequest(expected_bytes=sample_request_bytes)
702
710
        # Put two whole requests on the wire.
723
731
        def close():
724
732
            self.closed = True
725
733
        from_server.close = close
726
 
        server = smart.SmartServerPipeStreamMedium(to_server, from_server, None)
 
734
        server = medium.SmartServerPipeStreamMedium(
 
735
            to_server, from_server, None)
727
736
        fake_protocol = ErrorRaisingProtocol(Exception('boom'))
728
737
        server._serve_one_request(fake_protocol)
729
738
        self.assertEqual('', from_server.getvalue())
735
744
        # not discard the contents.
736
745
        from StringIO import StringIO
737
746
        server_sock, client_sock = self.portable_socket_pair()
738
 
        server = smart.SmartServerSocketStreamMedium(
 
747
        server = medium.SmartServerSocketStreamMedium(
739
748
            server_sock, None)
740
749
        fake_protocol = ErrorRaisingProtocol(Exception('boom'))
741
750
        server._serve_one_request(fake_protocol)
749
758
        # not discard the contents.
750
759
        to_server = StringIO('')
751
760
        from_server = StringIO()
752
 
        server = smart.SmartServerPipeStreamMedium(to_server, from_server, None)
 
761
        server = medium.SmartServerPipeStreamMedium(
 
762
            to_server, from_server, None)
753
763
        fake_protocol = ErrorRaisingProtocol(KeyboardInterrupt('boom'))
754
764
        self.assertRaises(
755
765
            KeyboardInterrupt, server._serve_one_request, fake_protocol)
757
767
 
758
768
    def test_socket_stream_keyboard_interrupt_handling(self):
759
769
        server_sock, client_sock = self.portable_socket_pair()
760
 
        server = smart.SmartServerSocketStreamMedium(
 
770
        server = medium.SmartServerSocketStreamMedium(
761
771
            server_sock, None)
762
772
        fake_protocol = ErrorRaisingProtocol(KeyboardInterrupt('boom'))
763
773
        self.assertRaises(
774
784
            base = 'a_url'
775
785
            def get_bytes(self, path):
776
786
                raise Exception("some random exception from inside server")
777
 
        server = smart.SmartTCPServer(backing_transport=FlakyTransport())
778
 
        server.start_background_thread()
 
787
        smart_server = server.SmartTCPServer(backing_transport=FlakyTransport())
 
788
        smart_server.start_background_thread()
779
789
        try:
780
 
            transport = smart.SmartTCPTransport(server.get_url())
 
790
            transport = remote.SmartTCPTransport(smart_server.get_url())
781
791
            try:
782
792
                transport.get('something')
783
793
            except errors.TransportError, e:
785
795
            else:
786
796
                self.fail("get did not raise expected error")
787
797
        finally:
788
 
            server.stop_background_thread()
 
798
            smart_server.stop_background_thread()
789
799
 
790
800
 
791
801
class SmartTCPTests(tests.TestCase):
806
816
        if readonly:
807
817
            self.real_backing_transport = self.backing_transport
808
818
            self.backing_transport = get_transport("readonly+" + self.backing_transport.abspath('.'))
809
 
        self.server = smart.SmartTCPServer(self.backing_transport)
 
819
        self.server = server.SmartTCPServer(self.backing_transport)
810
820
        self.server.start_background_thread()
811
 
        self.transport = smart.SmartTCPTransport(self.server.get_url())
 
821
        self.transport = remote.SmartTCPTransport(self.server.get_url())
812
822
        self.addCleanup(self.tearDownServer)
813
823
 
814
824
    def tearDownServer(self):
826
836
        """It should be safe to teardown the server with no requests."""
827
837
        self.setUpServer()
828
838
        server = self.server
829
 
        transport = smart.SmartTCPTransport(self.server.get_url())
 
839
        transport = remote.SmartTCPTransport(self.server.get_url())
830
840
        self.tearDownServer()
831
841
        self.assertRaises(errors.ConnectionError, transport.has, '.')
832
842
 
838
848
        self.tearDownServer()
839
849
        # if the listening socket has closed, we should get a BADFD error
840
850
        # when connecting, rather than a hang.
841
 
        transport = smart.SmartTCPTransport(server.get_url())
 
851
        transport = remote.SmartTCPTransport(server.get_url())
842
852
        self.assertRaises(errors.ConnectionError, transport.has, '.')
843
853
 
844
854
 
937
947
    def test_server_started_hook(self):
938
948
        """The server_started hook fires when the server is started."""
939
949
        self.hook_calls = []
940
 
        smart.SmartTCPServer.hooks.install_hook('server_started',
 
950
        server.SmartTCPServer.hooks.install_hook('server_started',
941
951
            self.capture_server_call)
942
952
        self.setUpServer()
943
953
        # at this point, the server will be starting a thread up.
949
959
    def test_server_stopped_hook_simple(self):
950
960
        """The server_stopped hook fires when the server is stopped."""
951
961
        self.hook_calls = []
952
 
        smart.SmartTCPServer.hooks.install_hook('server_stopped',
 
962
        server.SmartTCPServer.hooks.install_hook('server_stopped',
953
963
            self.capture_server_call)
954
964
        self.setUpServer()
955
965
        result = [(self.backing_transport.base, self.transport.base)]
959
969
        self.transport.has('.')
960
970
        self.assertEqual([], self.hook_calls)
961
971
        # clean up the server
962
 
        server = self.server
963
972
        self.tearDownServer()
964
973
        # now it should have fired.
965
974
        self.assertEqual(result, self.hook_calls)
973
982
 
974
983
    def test_construct_request_handler(self):
975
984
        """Constructing a request handler should be easy and set defaults."""
976
 
        handler = smart.SmartServerRequestHandler(None)
 
985
        handler = request.SmartServerRequestHandler(None)
977
986
        self.assertFalse(handler.finished_reading)
978
987
 
979
988
    def test_hello(self):
980
 
        handler = smart.SmartServerRequestHandler(None)
 
989
        handler = request.SmartServerRequestHandler(None)
981
990
        handler.dispatch_command('hello', ())
982
991
        self.assertEqual(('ok', '1'), handler.response.args)
983
992
        self.assertEqual(None, handler.response.body)
989
998
        wt.add('hello')
990
999
        rev_id = wt.commit('add hello')
991
1000
        
992
 
        handler = smart.SmartServerRequestHandler(self.get_transport())
 
1001
        handler = request.SmartServerRequestHandler(self.get_transport())
993
1002
        handler.dispatch_command('get_bundle', ('.', rev_id))
994
1003
        bundle = serializer.read_bundle(StringIO(handler.response.body))
995
1004
        self.assertEqual((), handler.response.args)
996
1005
 
997
1006
    def test_readonly_exception_becomes_transport_not_possible(self):
998
1007
        """The response for a read-only error is ('ReadOnlyError')."""
999
 
        handler = smart.SmartServerRequestHandler(self.get_readonly_transport())
 
1008
        handler = request.SmartServerRequestHandler(self.get_readonly_transport())
1000
1009
        # send a mkdir for foo, with no explicit mode - should fail.
1001
1010
        handler.dispatch_command('mkdir', ('foo', ''))
1002
1011
        # and the failure should be an explicit ReadOnlyError
1008
1017
 
1009
1018
    def test_hello_has_finished_body_on_dispatch(self):
1010
1019
        """The 'hello' command should set finished_reading."""
1011
 
        handler = smart.SmartServerRequestHandler(None)
 
1020
        handler = request.SmartServerRequestHandler(None)
1012
1021
        handler.dispatch_command('hello', ())
1013
1022
        self.assertTrue(handler.finished_reading)
1014
1023
        self.assertNotEqual(None, handler.response)
1015
1024
 
1016
1025
    def test_put_bytes_non_atomic(self):
1017
1026
        """'put_...' should set finished_reading after reading the bytes."""
1018
 
        handler = smart.SmartServerRequestHandler(self.get_transport())
 
1027
        handler = request.SmartServerRequestHandler(self.get_transport())
1019
1028
        handler.dispatch_command('put_non_atomic', ('a-file', '', 'F', ''))
1020
1029
        self.assertFalse(handler.finished_reading)
1021
1030
        handler.accept_body('1234')
1029
1038
    def test_readv_accept_body(self):
1030
1039
        """'readv' should set finished_reading after reading offsets."""
1031
1040
        self.build_tree(['a-file'])
1032
 
        handler = smart.SmartServerRequestHandler(self.get_readonly_transport())
 
1041
        handler = request.SmartServerRequestHandler(self.get_readonly_transport())
1033
1042
        handler.dispatch_command('readv', ('a-file', ))
1034
1043
        self.assertFalse(handler.finished_reading)
1035
1044
        handler.accept_body('2,')
1044
1053
    def test_readv_short_read_response_contents(self):
1045
1054
        """'readv' when a short read occurs sets the response appropriately."""
1046
1055
        self.build_tree(['a-file'])
1047
 
        handler = smart.SmartServerRequestHandler(self.get_readonly_transport())
 
1056
        handler = request.SmartServerRequestHandler(self.get_readonly_transport())
1048
1057
        handler.dispatch_command('readv', ('a-file', ))
1049
1058
        # read beyond the end of the file.
1050
1059
        handler.accept_body('100,1')
1055
1064
        self.assertEqual(None, handler.response.body)
1056
1065
 
1057
1066
 
1058
 
class SmartTransportRegistration(tests.TestCase):
 
1067
class RemoteTransportRegistration(tests.TestCase):
1059
1068
 
1060
1069
    def test_registration(self):
1061
1070
        t = get_transport('bzr+ssh://example.com/path')
1062
 
        self.assertIsInstance(t, smart.SmartSSHTransport)
 
1071
        self.assertIsInstance(t, remote.SmartSSHTransport)
1063
1072
        self.assertEqual('example.com', t._host)
1064
1073
 
1065
1074
 
1066
 
class TestSmartTransport(tests.TestCase):
 
1075
class TestRemoteTransport(tests.TestCase):
1067
1076
        
1068
1077
    def test_use_connection_factory(self):
1069
 
        # We want to be able to pass a client as a parameter to SmartTransport.
 
1078
        # We want to be able to pass a client as a parameter to RemoteTransport.
1070
1079
        input = StringIO("ok\n3\nbardone\n")
1071
1080
        output = StringIO()
1072
 
        medium = smart.SmartSimplePipesClientMedium(input, output)
1073
 
        transport = smart.SmartTransport('bzr://localhost/', medium=medium)
 
1081
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
1082
        transport = remote.SmartTransport(
 
1083
            'bzr://localhost/', medium=client_medium)
1074
1084
 
1075
1085
        # We want to make sure the client is used when the first remote
1076
1086
        # method is called.  No data should have been sent, or read.
1088
1098
 
1089
1099
    def test__translate_error_readonly(self):
1090
1100
        """Sending a ReadOnlyError to _translate_error raises TransportNotPossible."""
1091
 
        medium = smart.SmartClientMedium()
1092
 
        transport = smart.SmartTransport('bzr://localhost/', medium=medium)
 
1101
        client_medium = medium.SmartClientMedium()
 
1102
        transport = remote.SmartTransport(
 
1103
            'bzr://localhost/', medium=client_medium)
1093
1104
        self.assertRaises(errors.TransportNotPossible,
1094
1105
            transport._translate_error, ("ReadOnlyError", ))
1095
1106
 
1096
1107
 
1097
 
class InstrumentedServerProtocol(smart.SmartServerStreamMedium):
 
1108
class InstrumentedServerProtocol(medium.SmartServerStreamMedium):
1098
1109
    """A smart server which is backed by memory and saves its write requests."""
1099
1110
 
1100
1111
    def __init__(self, write_output_list):
1101
 
        smart.SmartServerStreamMedium.__init__(self, memory.MemoryTransport())
 
1112
        medium.SmartServerStreamMedium.__init__(self, memory.MemoryTransport())
1102
1113
        self._write_output_list = write_output_list
1103
1114
 
1104
1115
 
1117
1128
 
1118
1129
    def setUp(self):
1119
1130
        super(TestSmartProtocol, self).setUp()
 
1131
        # XXX: self.server_to_client doesn't seem to be used.  If so,
 
1132
        # InstrumentedServerProtocol is redundant too.
1120
1133
        self.server_to_client = []
1121
1134
        self.to_server = StringIO()
1122
1135
        self.to_client = StringIO()
1123
 
        self.client_medium = smart.SmartSimplePipesClientMedium(self.to_client,
 
1136
        self.client_medium = medium.SmartSimplePipesClientMedium(self.to_client,
1124
1137
            self.to_server)
1125
 
        self.client_protocol = smart.SmartClientRequestProtocolOne(
 
1138
        self.client_protocol = protocol.SmartClientRequestProtocolOne(
1126
1139
            self.client_medium)
1127
1140
        self.smart_server = InstrumentedServerProtocol(self.server_to_client)
1128
 
        self.smart_server_request = smart.SmartServerRequestHandler(None)
 
1141
        self.smart_server_request = request.SmartServerRequestHandler(None)
1129
1142
 
1130
1143
    def assertOffsetSerialisation(self, expected_offsets, expected_serialised,
1131
1144
        client, smart_server_request):
1148
1161
 
1149
1162
    def build_protocol_waiting_for_body(self):
1150
1163
        out_stream = StringIO()
1151
 
        protocol = smart.SmartServerRequestProtocolOne(None, out_stream.write)
1152
 
        protocol.has_dispatched = True
1153
 
        protocol.request = smart.SmartServerRequestHandler(None)
 
1164
        smart_protocol = protocol.SmartServerRequestProtocolOne(None, out_stream.write)
 
1165
        smart_protocol.has_dispatched = True
 
1166
        smart_protocol.request = request.SmartServerRequestHandler(None)
1154
1167
        def handle_end_of_bytes():
1155
1168
            self.end_received = True
1156
 
            self.assertEqual('abcdefg', protocol.request._body_bytes)
1157
 
            protocol.request.response = smart.SmartServerResponse(('ok', ))
1158
 
        protocol.request._end_of_body_handler = handle_end_of_bytes
 
1169
            self.assertEqual('abcdefg', smart_protocol.request._body_bytes)
 
1170
            smart_protocol.request.response = request.SmartServerResponse(('ok', ))
 
1171
        smart_protocol.request._end_of_body_handler = handle_end_of_bytes
1159
1172
        # Call accept_bytes to make sure that internal state like _body_decoder
1160
1173
        # is initialised.  This test should probably be given a clearer
1161
1174
        # interface to work with that will not cause this inconsistency.
1162
1175
        #   -- Andrew Bennetts, 2006-09-28
1163
 
        protocol.accept_bytes('')
1164
 
        return protocol
 
1176
        smart_protocol.accept_bytes('')
 
1177
        return smart_protocol
1165
1178
 
1166
1179
    def test_construct_version_one_server_protocol(self):
1167
 
        protocol = smart.SmartServerRequestProtocolOne(None, None)
1168
 
        self.assertEqual('', protocol.excess_buffer)
1169
 
        self.assertEqual('', protocol.in_buffer)
1170
 
        self.assertFalse(protocol.has_dispatched)
1171
 
        self.assertEqual(1, protocol.next_read_size())
 
1180
        smart_protocol = protocol.SmartServerRequestProtocolOne(None, None)
 
1181
        self.assertEqual('', smart_protocol.excess_buffer)
 
1182
        self.assertEqual('', smart_protocol.in_buffer)
 
1183
        self.assertFalse(smart_protocol.has_dispatched)
 
1184
        self.assertEqual(1, smart_protocol.next_read_size())
1172
1185
 
1173
1186
    def test_construct_version_one_client_protocol(self):
1174
1187
        # we can construct a client protocol from a client medium request
1175
1188
        output = StringIO()
1176
 
        medium = smart.SmartSimplePipesClientMedium(None, output)
1177
 
        request = medium.get_request()
1178
 
        client_protocol = smart.SmartClientRequestProtocolOne(request)
 
1189
        client_medium = medium.SmartSimplePipesClientMedium(None, output)
 
1190
        request = client_medium.get_request()
 
1191
        client_protocol = protocol.SmartClientRequestProtocolOne(request)
1179
1192
 
1180
1193
    def test_server_offset_serialisation(self):
1181
1194
        """The Smart protocol serialises offsets as a comma and \n string.
1195
1208
 
1196
1209
    def test_accept_bytes_of_bad_request_to_protocol(self):
1197
1210
        out_stream = StringIO()
1198
 
        protocol = smart.SmartServerRequestProtocolOne(None, out_stream.write)
1199
 
        protocol.accept_bytes('abc')
1200
 
        self.assertEqual('abc', protocol.in_buffer)
1201
 
        protocol.accept_bytes('\n')
1202
 
        self.assertEqual("error\x01Generic bzr smart protocol error: bad request"
1203
 
            " 'abc'\n", out_stream.getvalue())
1204
 
        self.assertTrue(protocol.has_dispatched)
1205
 
        self.assertEqual(0, protocol.next_read_size())
 
1211
        smart_protocol = protocol.SmartServerRequestProtocolOne(
 
1212
            None, out_stream.write)
 
1213
        smart_protocol.accept_bytes('abc')
 
1214
        self.assertEqual('abc', smart_protocol.in_buffer)
 
1215
        smart_protocol.accept_bytes('\n')
 
1216
        self.assertEqual(
 
1217
            "error\x01Generic bzr smart protocol error: bad request 'abc'\n",
 
1218
            out_stream.getvalue())
 
1219
        self.assertTrue(smart_protocol.has_dispatched)
 
1220
        self.assertEqual(0, smart_protocol.next_read_size())
1206
1221
 
1207
1222
    def test_accept_body_bytes_to_protocol(self):
1208
1223
        protocol = self.build_protocol_waiting_for_body()
1215
1230
        self.assertTrue(self.end_received)
1216
1231
 
1217
1232
    def test_accept_request_and_body_all_at_once(self):
 
1233
        self._captureVar('NO_SMART_VFS', None)
1218
1234
        mem_transport = memory.MemoryTransport()
1219
1235
        mem_transport.put_bytes('foo', 'abcdefghij')
1220
1236
        out_stream = StringIO()
1221
 
        protocol = smart.SmartServerRequestProtocolOne(mem_transport,
 
1237
        smart_protocol = protocol.SmartServerRequestProtocolOne(mem_transport,
1222
1238
                out_stream.write)
1223
 
        protocol.accept_bytes('readv\x01foo\n3\n3,3done\n')
1224
 
        self.assertEqual(0, protocol.next_read_size())
 
1239
        smart_protocol.accept_bytes('readv\x01foo\n3\n3,3done\n')
 
1240
        self.assertEqual(0, smart_protocol.next_read_size())
1225
1241
        self.assertEqual('readv\n3\ndefdone\n', out_stream.getvalue())
1226
 
        self.assertEqual('', protocol.excess_buffer)
1227
 
        self.assertEqual('', protocol.in_buffer)
 
1242
        self.assertEqual('', smart_protocol.excess_buffer)
 
1243
        self.assertEqual('', smart_protocol.in_buffer)
1228
1244
 
1229
1245
    def test_accept_excess_bytes_are_preserved(self):
1230
1246
        out_stream = StringIO()
1231
 
        protocol = smart.SmartServerRequestProtocolOne(None, out_stream.write)
1232
 
        protocol.accept_bytes('hello\nhello\n')
 
1247
        smart_protocol = protocol.SmartServerRequestProtocolOne(
 
1248
            None, out_stream.write)
 
1249
        smart_protocol.accept_bytes('hello\nhello\n')
1233
1250
        self.assertEqual("ok\x011\n", out_stream.getvalue())
1234
 
        self.assertEqual("hello\n", protocol.excess_buffer)
1235
 
        self.assertEqual("", protocol.in_buffer)
 
1251
        self.assertEqual("hello\n", smart_protocol.excess_buffer)
 
1252
        self.assertEqual("", smart_protocol.in_buffer)
1236
1253
 
1237
1254
    def test_accept_excess_bytes_after_body(self):
1238
1255
        protocol = self.build_protocol_waiting_for_body()
1246
1263
 
1247
1264
    def test_accept_excess_bytes_after_dispatch(self):
1248
1265
        out_stream = StringIO()
1249
 
        protocol = smart.SmartServerRequestProtocolOne(None, out_stream.write)
1250
 
        protocol.accept_bytes('hello\n')
 
1266
        smart_protocol = protocol.SmartServerRequestProtocolOne(
 
1267
            None, out_stream.write)
 
1268
        smart_protocol.accept_bytes('hello\n')
1251
1269
        self.assertEqual("ok\x011\n", out_stream.getvalue())
1252
 
        protocol.accept_bytes('hel')
1253
 
        self.assertEqual("hel", protocol.excess_buffer)
1254
 
        protocol.accept_bytes('lo\n')
1255
 
        self.assertEqual("hello\n", protocol.excess_buffer)
1256
 
        self.assertEqual("", protocol.in_buffer)
 
1270
        smart_protocol.accept_bytes('hel')
 
1271
        self.assertEqual("hel", smart_protocol.excess_buffer)
 
1272
        smart_protocol.accept_bytes('lo\n')
 
1273
        self.assertEqual("hello\n", smart_protocol.excess_buffer)
 
1274
        self.assertEqual("", smart_protocol.in_buffer)
1257
1275
 
1258
1276
    def test__send_response_sets_finished_reading(self):
1259
 
        protocol = smart.SmartServerRequestProtocolOne(None, lambda x: None)
1260
 
        self.assertEqual(1, protocol.next_read_size())
1261
 
        protocol._send_response(('x',))
1262
 
        self.assertEqual(0, protocol.next_read_size())
 
1277
        smart_protocol = protocol.SmartServerRequestProtocolOne(
 
1278
            None, lambda x: None)
 
1279
        self.assertEqual(1, smart_protocol.next_read_size())
 
1280
        smart_protocol._send_response(('x',))
 
1281
        self.assertEqual(0, smart_protocol.next_read_size())
1263
1282
 
1264
1283
    def test_query_version(self):
1265
1284
        """query_version on a SmartClientProtocolOne should return a number.
1273
1292
        # the error if the response is a non-understood version.
1274
1293
        input = StringIO('ok\x011\n')
1275
1294
        output = StringIO()
1276
 
        medium = smart.SmartSimplePipesClientMedium(input, output)
1277
 
        protocol = smart.SmartClientRequestProtocolOne(medium.get_request())
1278
 
        self.assertEqual(1, protocol.query_version())
 
1295
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
1296
        request = client_medium.get_request()
 
1297
        smart_protocol = protocol.SmartClientRequestProtocolOne(request)
 
1298
        self.assertEqual(1, smart_protocol.query_version())
1279
1299
 
1280
1300
    def assertServerToClientEncoding(self, expected_bytes, expected_tuple,
1281
1301
            input_tuples):
1286
1306
        # expected bytes
1287
1307
        for input_tuple in input_tuples:
1288
1308
            server_output = StringIO()
1289
 
            server_protocol = smart.SmartServerRequestProtocolOne(
 
1309
            server_protocol = protocol.SmartServerRequestProtocolOne(
1290
1310
                None, server_output.write)
1291
1311
            server_protocol._send_response(input_tuple)
1292
1312
            self.assertEqual(expected_bytes, server_output.getvalue())
1293
 
        # check the decoding of the client protocol from expected_bytes:
 
1313
        # check the decoding of the client smart_protocol from expected_bytes:
1294
1314
        input = StringIO(expected_bytes)
1295
1315
        output = StringIO()
1296
 
        medium = smart.SmartSimplePipesClientMedium(input, output)
1297
 
        protocol = smart.SmartClientRequestProtocolOne(medium.get_request())
1298
 
        protocol.call('foo')
1299
 
        self.assertEqual(expected_tuple, protocol.read_response_tuple())
 
1316
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
1317
        request = client_medium.get_request()
 
1318
        smart_protocol = protocol.SmartClientRequestProtocolOne(request)
 
1319
        smart_protocol.call('foo')
 
1320
        self.assertEqual(expected_tuple, smart_protocol.read_response_tuple())
1300
1321
 
1301
1322
    def test_client_call_empty_response(self):
1302
1323
        # protocol.call() can get back an empty tuple as a response. This occurs
1311
1332
            [('a', 'b', '34')])
1312
1333
 
1313
1334
    def test_client_call_with_body_bytes_uploads(self):
1314
 
        # protocol.call_with_upload should length-prefix the bytes onto the 
 
1335
        # protocol.call_with_body_bytes should length-prefix the bytes onto the
1315
1336
        # wire.
1316
1337
        expected_bytes = "foo\n7\nabcdefgdone\n"
1317
1338
        input = StringIO("\n")
1318
1339
        output = StringIO()
1319
 
        medium = smart.SmartSimplePipesClientMedium(input, output)
1320
 
        protocol = smart.SmartClientRequestProtocolOne(medium.get_request())
1321
 
        protocol.call_with_body_bytes(('foo', ), "abcdefg")
 
1340
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
1341
        request = client_medium.get_request()
 
1342
        smart_protocol = protocol.SmartClientRequestProtocolOne(request)
 
1343
        smart_protocol.call_with_body_bytes(('foo', ), "abcdefg")
1322
1344
        self.assertEqual(expected_bytes, output.getvalue())
1323
1345
 
1324
1346
    def test_client_call_with_body_readv_array(self):
1327
1349
        expected_bytes = "foo\n7\n1,2\n5,6done\n"
1328
1350
        input = StringIO("\n")
1329
1351
        output = StringIO()
1330
 
        medium = smart.SmartSimplePipesClientMedium(input, output)
1331
 
        protocol = smart.SmartClientRequestProtocolOne(medium.get_request())
1332
 
        protocol.call_with_body_readv_array(('foo', ), [(1,2),(5,6)])
 
1352
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
1353
        request = client_medium.get_request()
 
1354
        smart_protocol = protocol.SmartClientRequestProtocolOne(request)
 
1355
        smart_protocol.call_with_body_readv_array(('foo', ), [(1,2),(5,6)])
1333
1356
        self.assertEqual(expected_bytes, output.getvalue())
1334
1357
 
1335
1358
    def test_client_read_body_bytes_all(self):
1339
1362
        server_bytes = "ok\n7\n1234567done\n"
1340
1363
        input = StringIO(server_bytes)
1341
1364
        output = StringIO()
1342
 
        medium = smart.SmartSimplePipesClientMedium(input, output)
1343
 
        protocol = smart.SmartClientRequestProtocolOne(medium.get_request())
1344
 
        protocol.call('foo')
1345
 
        protocol.read_response_tuple(True)
1346
 
        self.assertEqual(expected_bytes, protocol.read_body_bytes())
 
1365
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
1366
        request = client_medium.get_request()
 
1367
        smart_protocol = protocol.SmartClientRequestProtocolOne(request)
 
1368
        smart_protocol.call('foo')
 
1369
        smart_protocol.read_response_tuple(True)
 
1370
        self.assertEqual(expected_bytes, smart_protocol.read_body_bytes())
1347
1371
 
1348
1372
    def test_client_read_body_bytes_incremental(self):
1349
1373
        # test reading a few bytes at a time from the body
1355
1379
        server_bytes = "ok\n7\n1234567done\n"
1356
1380
        input = StringIO(server_bytes)
1357
1381
        output = StringIO()
1358
 
        medium = smart.SmartSimplePipesClientMedium(input, output)
1359
 
        protocol = smart.SmartClientRequestProtocolOne(medium.get_request())
1360
 
        protocol.call('foo')
1361
 
        protocol.read_response_tuple(True)
1362
 
        self.assertEqual(expected_bytes[0:2], protocol.read_body_bytes(2))
1363
 
        self.assertEqual(expected_bytes[2:4], protocol.read_body_bytes(2))
1364
 
        self.assertEqual(expected_bytes[4:6], protocol.read_body_bytes(2))
1365
 
        self.assertEqual(expected_bytes[6], protocol.read_body_bytes())
 
1382
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
1383
        request = client_medium.get_request()
 
1384
        smart_protocol = protocol.SmartClientRequestProtocolOne(request)
 
1385
        smart_protocol.call('foo')
 
1386
        smart_protocol.read_response_tuple(True)
 
1387
        self.assertEqual(expected_bytes[0:2], smart_protocol.read_body_bytes(2))
 
1388
        self.assertEqual(expected_bytes[2:4], smart_protocol.read_body_bytes(2))
 
1389
        self.assertEqual(expected_bytes[4:6], smart_protocol.read_body_bytes(2))
 
1390
        self.assertEqual(expected_bytes[6], smart_protocol.read_body_bytes())
1366
1391
 
1367
1392
    def test_client_cancel_read_body_does_not_eat_body_bytes(self):
1368
1393
        # cancelling the expected body needs to finish the request, but not
1371
1396
        server_bytes = "ok\n7\n1234567done\n"
1372
1397
        input = StringIO(server_bytes)
1373
1398
        output = StringIO()
1374
 
        medium = smart.SmartSimplePipesClientMedium(input, output)
1375
 
        protocol = smart.SmartClientRequestProtocolOne(medium.get_request())
1376
 
        protocol.call('foo')
1377
 
        protocol.read_response_tuple(True)
1378
 
        protocol.cancel_read_body()
 
1399
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
1400
        request = client_medium.get_request()
 
1401
        smart_protocol = protocol.SmartClientRequestProtocolOne(request)
 
1402
        smart_protocol.call('foo')
 
1403
        smart_protocol.read_response_tuple(True)
 
1404
        smart_protocol.cancel_read_body()
1379
1405
        self.assertEqual(3, input.tell())
1380
 
        self.assertRaises(errors.ReadingCompleted, protocol.read_body_bytes)
 
1406
        self.assertRaises(
 
1407
            errors.ReadingCompleted, smart_protocol.read_body_bytes)
1381
1408
 
1382
1409
 
1383
1410
class LengthPrefixedBodyDecoder(tests.TestCase):
1386
1413
    # something similar to the ProtocolBase method.
1387
1414
 
1388
1415
    def test_construct(self):
1389
 
        decoder = smart.LengthPrefixedBodyDecoder()
 
1416
        decoder = protocol.LengthPrefixedBodyDecoder()
1390
1417
        self.assertFalse(decoder.finished_reading)
1391
1418
        self.assertEqual(6, decoder.next_read_size())
1392
1419
        self.assertEqual('', decoder.read_pending_data())
1393
1420
        self.assertEqual('', decoder.unused_data)
1394
1421
 
1395
1422
    def test_accept_bytes(self):
1396
 
        decoder = smart.LengthPrefixedBodyDecoder()
 
1423
        decoder = protocol.LengthPrefixedBodyDecoder()
1397
1424
        decoder.accept_bytes('')
1398
1425
        self.assertFalse(decoder.finished_reading)
1399
1426
        self.assertEqual(6, decoder.next_read_size())
1426
1453
        self.assertEqual('blarg', decoder.unused_data)
1427
1454
        
1428
1455
    def test_accept_bytes_all_at_once_with_excess(self):
1429
 
        decoder = smart.LengthPrefixedBodyDecoder()
 
1456
        decoder = protocol.LengthPrefixedBodyDecoder()
1430
1457
        decoder.accept_bytes('1\nadone\nunused')
1431
1458
        self.assertTrue(decoder.finished_reading)
1432
1459
        self.assertEqual(1, decoder.next_read_size())
1434
1461
        self.assertEqual('unused', decoder.unused_data)
1435
1462
 
1436
1463
    def test_accept_bytes_exact_end_of_body(self):
1437
 
        decoder = smart.LengthPrefixedBodyDecoder()
 
1464
        decoder = protocol.LengthPrefixedBodyDecoder()
1438
1465
        decoder.accept_bytes('1\na')
1439
1466
        self.assertFalse(decoder.finished_reading)
1440
1467
        self.assertEqual(5, decoder.next_read_size())
1458
1485
 
1459
1486
class HTTPTunnellingSmokeTest(tests.TestCaseWithTransport):
1460
1487
    
 
1488
    def setUp(self):
 
1489
        super(HTTPTunnellingSmokeTest, self).setUp()
 
1490
        # We use the VFS layer as part of HTTP tunnelling tests.
 
1491
        self._captureVar('NO_SMART_VFS', None)
 
1492
 
1461
1493
    def _test_bulk_data(self, url_protocol):
1462
1494
        # We should be able to send and receive bulk data in a single message.
1463
1495
        # The 'readv' command in the smart protocol both sends and receives bulk
1464
1496
        # data, so we use that.
1465
1497
        self.build_tree(['data-file'])
1466
 
        http_server = HTTPServerWithSmarts()
1467
 
        http_server._url_protocol = url_protocol
1468
 
        http_server.setUp()
1469
 
        self.addCleanup(http_server.tearDown)
1470
 
 
1471
 
        http_transport = get_transport(http_server.get_url())
1472
 
 
 
1498
        self.transport_readonly_server = HTTPServerWithSmarts
 
1499
 
 
1500
        http_transport = self.get_readonly_transport()
1473
1501
        medium = http_transport.get_smart_medium()
1474
1502
        #remote_transport = RemoteTransport('fake_url', medium)
1475
 
        remote_transport = smart.SmartTransport('/', medium=medium)
 
1503
        remote_transport = remote.SmartTransport('/', medium=medium)
1476
1504
        self.assertEqual(
1477
1505
            [(0, "c")], list(remote_transport.readv("data-file", [(0,1)])))
1478
1506
 
1497
1525
    def _test_http_send_smart_request(self, url_protocol):
1498
1526
        http_server = HTTPServerWithSmarts()
1499
1527
        http_server._url_protocol = url_protocol
1500
 
        http_server.setUp()
 
1528
        http_server.setUp(self.get_vfs_only_server())
1501
1529
        self.addCleanup(http_server.tearDown)
1502
1530
 
1503
1531
        post_body = 'hello\n'
1519
1547
        self._test_http_send_smart_request('http+urllib')
1520
1548
 
1521
1549
    def test_http_server_with_smarts(self):
1522
 
        http_server = HTTPServerWithSmarts()
1523
 
        http_server.setUp()
1524
 
        self.addCleanup(http_server.tearDown)
 
1550
        self.transport_readonly_server = HTTPServerWithSmarts
1525
1551
 
1526
1552
        post_body = 'hello\n'
1527
1553
        expected_reply_body = 'ok\x011\n'
1528
1554
 
1529
 
        smart_server_url = http_server.get_url() + '.bzr/smart'
 
1555
        smart_server_url = self.get_readonly_url('.bzr/smart')
1530
1556
        reply = urllib2.urlopen(smart_server_url, post_body).read()
1531
1557
 
1532
1558
        self.assertEqual(expected_reply_body, reply)
1533
1559
 
1534
1560
    def test_smart_http_server_post_request_handler(self):
1535
 
        http_server = HTTPServerWithSmarts()
1536
 
        http_server.setUp()
1537
 
        self.addCleanup(http_server.tearDown)
1538
 
        httpd = http_server._get_httpd()
 
1561
        self.transport_readonly_server = HTTPServerWithSmarts
 
1562
        httpd = self.get_readonly_server()._get_httpd()
1539
1563
 
1540
1564
        socket = SampleSocket(
1541
1565
            'POST /.bzr/smart HTTP/1.0\r\n'
1577
1601
        else:
1578
1602
            return self.writefile
1579
1603
 
1580
 
        
 
1604
 
1581
1605
# TODO: Client feature that does get_bundle and then installs that into a
1582
1606
# branch; this can be used in place of the regular pull/fetch operation when
1583
1607
# coming from a smart server.