~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/tests/test_smart_transport.py

merge bzr.dev

Show diffs side-by-side

added added

removed removed

Lines of Context:
30
30
        tests,
31
31
        urlutils,
32
32
        )
 
33
from bzrlib.smart import (
 
34
        medium,
 
35
        protocol,
 
36
        request,
 
37
        server,
 
38
        vfs,
 
39
)
33
40
from bzrlib.tests.HTTPTestUtil import (
34
41
        HTTPServerWithSmarts,
35
42
        SmartRequestHandler,
38
45
        get_transport,
39
46
        local,
40
47
        memory,
41
 
        smart,
 
48
        remote,
42
49
        )
43
50
from bzrlib.transport.http import SmartClientHTTPMediumRequest
44
51
 
85
92
        sock.bind(('127.0.0.1', 0))
86
93
        sock.listen(1)
87
94
        port = sock.getsockname()[1]
88
 
        medium = smart.SmartTCPClientMedium('127.0.0.1', port)
89
 
        return sock, medium
 
95
        client_medium = medium.SmartTCPClientMedium('127.0.0.1', port)
 
96
        return sock, client_medium
90
97
 
91
98
    def receive_bytes_on_server(self, sock, bytes):
92
99
        """Accept a connection on sock and read 3 bytes.
108
115
        # this just ensures that the constructor stays parameter-free which
109
116
        # is important for reuse : some subclasses will dynamically connect,
110
117
        # others are always on, etc.
111
 
        medium = smart.SmartClientStreamMedium()
 
118
        client_medium = medium.SmartClientStreamMedium()
112
119
 
113
120
    def test_construct_smart_client_medium(self):
114
121
        # the base client medium takes no parameters
115
 
        medium = smart.SmartClientMedium()
 
122
        client_medium = medium.SmartClientMedium()
116
123
    
117
124
    def test_construct_smart_simple_pipes_client_medium(self):
118
125
        # the SimplePipes client medium takes two pipes:
119
126
        # readable pipe, writeable pipe.
120
127
        # Constructing one should just save these and do nothing.
121
128
        # We test this by passing in None.
122
 
        medium = smart.SmartSimplePipesClientMedium(None, None)
 
129
        client_medium = medium.SmartSimplePipesClientMedium(None, None)
123
130
        
124
131
    def test_simple_pipes_client_request_type(self):
125
132
        # SimplePipesClient should use SmartClientStreamMediumRequest's.
126
 
        medium = smart.SmartSimplePipesClientMedium(None, None)
127
 
        request = medium.get_request()
128
 
        self.assertIsInstance(request, smart.SmartClientStreamMediumRequest)
 
133
        client_medium = medium.SmartSimplePipesClientMedium(None, None)
 
134
        request = client_medium.get_request()
 
135
        self.assertIsInstance(request, medium.SmartClientStreamMediumRequest)
129
136
 
130
137
    def test_simple_pipes_client_get_concurrent_requests(self):
131
138
        # the simple_pipes client does not support pipelined requests:
135
142
        # classes - as the sibling classes share this logic, they do not have
136
143
        # explicit tests for this.
137
144
        output = StringIO()
138
 
        medium = smart.SmartSimplePipesClientMedium(None, output)
139
 
        request = medium.get_request()
 
145
        client_medium = medium.SmartSimplePipesClientMedium(None, output)
 
146
        request = client_medium.get_request()
140
147
        request.finished_writing()
141
148
        request.finished_reading()
142
 
        request2 = medium.get_request()
 
149
        request2 = client_medium.get_request()
143
150
        request2.finished_writing()
144
151
        request2.finished_reading()
145
152
 
146
153
    def test_simple_pipes_client__accept_bytes_writes_to_writable(self):
147
154
        # accept_bytes writes to the writeable pipe.
148
155
        output = StringIO()
149
 
        medium = smart.SmartSimplePipesClientMedium(None, output)
150
 
        medium._accept_bytes('abc')
 
156
        client_medium = medium.SmartSimplePipesClientMedium(None, output)
 
157
        client_medium._accept_bytes('abc')
151
158
        self.assertEqual('abc', output.getvalue())
152
159
    
153
160
    def test_simple_pipes_client_disconnect_does_nothing(self):
154
161
        # calling disconnect does nothing.
155
162
        input = StringIO()
156
163
        output = StringIO()
157
 
        medium = smart.SmartSimplePipesClientMedium(input, output)
 
164
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
158
165
        # send some bytes to ensure disconnecting after activity still does not
159
166
        # close.
160
 
        medium._accept_bytes('abc')
161
 
        medium.disconnect()
 
167
        client_medium._accept_bytes('abc')
 
168
        client_medium.disconnect()
162
169
        self.assertFalse(input.closed)
163
170
        self.assertFalse(output.closed)
164
171
 
167
174
        # accept_bytes writes to.
168
175
        input = StringIO()
169
176
        output = StringIO()
170
 
        medium = smart.SmartSimplePipesClientMedium(input, output)
171
 
        medium._accept_bytes('abc')
172
 
        medium.disconnect()
173
 
        medium._accept_bytes('abc')
 
177
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
178
        client_medium._accept_bytes('abc')
 
179
        client_medium.disconnect()
 
180
        client_medium._accept_bytes('abc')
174
181
        self.assertFalse(input.closed)
175
182
        self.assertFalse(output.closed)
176
183
        self.assertEqual('abcabc', output.getvalue())
178
185
    def test_simple_pipes_client_ignores_disconnect_when_not_connected(self):
179
186
        # Doing a disconnect on a new (and thus unconnected) SimplePipes medium
180
187
        # does nothing.
181
 
        medium = smart.SmartSimplePipesClientMedium(None, None)
182
 
        medium.disconnect()
 
188
        client_medium = medium.SmartSimplePipesClientMedium(None, None)
 
189
        client_medium.disconnect()
183
190
 
184
191
    def test_simple_pipes_client_can_always_read(self):
185
192
        # SmartSimplePipesClientMedium is never disconnected, so read_bytes
186
193
        # always tries to read from the underlying pipe.
187
194
        input = StringIO('abcdef')
188
 
        medium = smart.SmartSimplePipesClientMedium(input, None)
189
 
        self.assertEqual('abc', medium.read_bytes(3))
190
 
        medium.disconnect()
191
 
        self.assertEqual('def', medium.read_bytes(3))
 
195
        client_medium = medium.SmartSimplePipesClientMedium(input, None)
 
196
        self.assertEqual('abc', client_medium.read_bytes(3))
 
197
        client_medium.disconnect()
 
198
        self.assertEqual('def', client_medium.read_bytes(3))
192
199
        
193
200
    def test_simple_pipes_client_supports__flush(self):
194
201
        # invoking _flush on a SimplePipesClient should flush the output 
200
207
        flush_calls = []
201
208
        def logging_flush(): flush_calls.append('flush')
202
209
        output.flush = logging_flush
203
 
        medium = smart.SmartSimplePipesClientMedium(input, output)
 
210
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
204
211
        # this call is here to ensure we only flush once, not on every
205
212
        # _accept_bytes call.
206
 
        medium._accept_bytes('abc')
207
 
        medium._flush()
208
 
        medium.disconnect()
 
213
        client_medium._accept_bytes('abc')
 
214
        client_medium._flush()
 
215
        client_medium.disconnect()
209
216
        self.assertEqual(['flush'], flush_calls)
210
217
 
211
218
    def test_construct_smart_ssh_client_medium(self):
219
226
        unopened_port = sock.getsockname()[1]
220
227
        # having vendor be invalid means that if it tries to connect via the
221
228
        # vendor it will blow up.
222
 
        medium = smart.SmartSSHClientMedium('127.0.0.1', unopened_port,
 
229
        client_medium = medium.SmartSSHClientMedium('127.0.0.1', unopened_port,
223
230
            username=None, password=None, vendor="not a vendor")
224
231
        sock.close()
225
232
 
228
235
        # it bytes.
229
236
        output = StringIO()
230
237
        vendor = StringIOSSHVendor(StringIO(), output)
231
 
        medium = smart.SmartSSHClientMedium('a hostname', 'a port', 'a username',
232
 
            'a password', vendor)
233
 
        medium._accept_bytes('abc')
 
238
        client_medium = medium.SmartSSHClientMedium(
 
239
            'a hostname', 'a port', 'a username', 'a password', vendor)
 
240
        client_medium._accept_bytes('abc')
234
241
        self.assertEqual('abc', output.getvalue())
235
242
        self.assertEqual([('connect_ssh', 'a username', 'a password',
236
243
            'a hostname', 'a port',
247
254
            osutils.set_or_unset_env('BZR_REMOTE_PATH', orig_bzr_remote_path)
248
255
        self.addCleanup(cleanup_environ)
249
256
        os.environ['BZR_REMOTE_PATH'] = 'fugly'
250
 
        medium = smart.SmartSSHClientMedium('a hostname', 'a port', 'a username',
 
257
        client_medium = medium.SmartSSHClientMedium('a hostname', 'a port', 'a username',
251
258
            'a password', vendor)
252
 
        medium._accept_bytes('abc')
 
259
        client_medium._accept_bytes('abc')
253
260
        self.assertEqual('abc', output.getvalue())
254
261
        self.assertEqual([('connect_ssh', 'a username', 'a password',
255
262
            'a hostname', 'a port',
262
269
        input = StringIO()
263
270
        output = StringIO()
264
271
        vendor = StringIOSSHVendor(input, output)
265
 
        medium = smart.SmartSSHClientMedium('a hostname', vendor=vendor)
266
 
        medium._accept_bytes('abc')
267
 
        medium.disconnect()
 
272
        client_medium = medium.SmartSSHClientMedium('a hostname', vendor=vendor)
 
273
        client_medium._accept_bytes('abc')
 
274
        client_medium.disconnect()
268
275
        self.assertTrue(input.closed)
269
276
        self.assertTrue(output.closed)
270
277
        self.assertEqual([
282
289
        input = StringIO()
283
290
        output = StringIO()
284
291
        vendor = StringIOSSHVendor(input, output)
285
 
        medium = smart.SmartSSHClientMedium('a hostname', vendor=vendor)
286
 
        medium._accept_bytes('abc')
287
 
        medium.disconnect()
 
292
        client_medium = medium.SmartSSHClientMedium('a hostname', vendor=vendor)
 
293
        client_medium._accept_bytes('abc')
 
294
        client_medium.disconnect()
288
295
        # the disconnect has closed output, so we need a new output for the
289
296
        # new connection to write to.
290
297
        input2 = StringIO()
291
298
        output2 = StringIO()
292
299
        vendor.read_from = input2
293
300
        vendor.write_to = output2
294
 
        medium._accept_bytes('abc')
295
 
        medium.disconnect()
 
301
        client_medium._accept_bytes('abc')
 
302
        client_medium.disconnect()
296
303
        self.assertTrue(input.closed)
297
304
        self.assertTrue(output.closed)
298
305
        self.assertTrue(input2.closed)
310
317
    def test_ssh_client_ignores_disconnect_when_not_connected(self):
311
318
        # Doing a disconnect on a new (and thus unconnected) SSH medium
312
319
        # does not fail.  It's ok to disconnect an unconnected medium.
313
 
        medium = smart.SmartSSHClientMedium(None)
314
 
        medium.disconnect()
 
320
        client_medium = medium.SmartSSHClientMedium(None)
 
321
        client_medium.disconnect()
315
322
 
316
323
    def test_ssh_client_raises_on_read_when_not_connected(self):
317
324
        # Doing a read on a new (and thus unconnected) SSH medium raises
318
325
        # MediumNotConnected.
319
 
        medium = smart.SmartSSHClientMedium(None)
320
 
        self.assertRaises(errors.MediumNotConnected, medium.read_bytes, 0)
321
 
        self.assertRaises(errors.MediumNotConnected, medium.read_bytes, 1)
 
326
        client_medium = medium.SmartSSHClientMedium(None)
 
327
        self.assertRaises(errors.MediumNotConnected, client_medium.read_bytes, 0)
 
328
        self.assertRaises(errors.MediumNotConnected, client_medium.read_bytes, 1)
322
329
 
323
330
    def test_ssh_client_supports__flush(self):
324
331
        # invoking _flush on a SSHClientMedium should flush the output 
331
338
        def logging_flush(): flush_calls.append('flush')
332
339
        output.flush = logging_flush
333
340
        vendor = StringIOSSHVendor(input, output)
334
 
        medium = smart.SmartSSHClientMedium('a hostname', vendor=vendor)
 
341
        client_medium = medium.SmartSSHClientMedium('a hostname', vendor=vendor)
335
342
        # this call is here to ensure we only flush once, not on every
336
343
        # _accept_bytes call.
337
 
        medium._accept_bytes('abc')
338
 
        medium._flush()
339
 
        medium.disconnect()
 
344
        client_medium._accept_bytes('abc')
 
345
        client_medium._flush()
 
346
        client_medium.disconnect()
340
347
        self.assertEqual(['flush'], flush_calls)
341
348
        
342
349
    def test_construct_smart_tcp_client_medium(self):
345
352
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
346
353
        sock.bind(('127.0.0.1', 0))
347
354
        unopened_port = sock.getsockname()[1]
348
 
        medium = smart.SmartTCPClientMedium('127.0.0.1', unopened_port)
 
355
        client_medium = medium.SmartTCPClientMedium('127.0.0.1', unopened_port)
349
356
        sock.close()
350
357
 
351
358
    def test_tcp_client_connects_on_first_use(self):
378
385
    def test_tcp_client_ignores_disconnect_when_not_connected(self):
379
386
        # Doing a disconnect on a new (and thus unconnected) TCP medium
380
387
        # does not fail.  It's ok to disconnect an unconnected medium.
381
 
        medium = smart.SmartTCPClientMedium(None, None)
382
 
        medium.disconnect()
 
388
        client_medium = medium.SmartTCPClientMedium(None, None)
 
389
        client_medium.disconnect()
383
390
 
384
391
    def test_tcp_client_raises_on_read_when_not_connected(self):
385
392
        # Doing a read on a new (and thus unconnected) TCP medium raises
386
393
        # MediumNotConnected.
387
 
        medium = smart.SmartTCPClientMedium(None, None)
388
 
        self.assertRaises(errors.MediumNotConnected, medium.read_bytes, 0)
389
 
        self.assertRaises(errors.MediumNotConnected, medium.read_bytes, 1)
 
394
        client_medium = medium.SmartTCPClientMedium(None, None)
 
395
        self.assertRaises(errors.MediumNotConnected, client_medium.read_bytes, 0)
 
396
        self.assertRaises(errors.MediumNotConnected, client_medium.read_bytes, 1)
390
397
 
391
398
    def test_tcp_client_supports__flush(self):
392
399
        # invoking _flush on a TCPClientMedium should do something useful.
421
428
        # WritingCompleted to prevent bad assumptions on stream environments
422
429
        # breaking the needs of message-based environments.
423
430
        output = StringIO()
424
 
        medium = smart.SmartSimplePipesClientMedium(None, output)
425
 
        request = smart.SmartClientStreamMediumRequest(medium)
 
431
        client_medium = medium.SmartSimplePipesClientMedium(None, output)
 
432
        request = medium.SmartClientStreamMediumRequest(client_medium)
426
433
        request.finished_writing()
427
434
        self.assertRaises(errors.WritingCompleted, request.accept_bytes, None)
428
435
 
432
439
        # and checking that the pipes get the data.
433
440
        input = StringIO()
434
441
        output = StringIO()
435
 
        medium = smart.SmartSimplePipesClientMedium(input, output)
436
 
        request = smart.SmartClientStreamMediumRequest(medium)
 
442
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
443
        request = medium.SmartClientStreamMediumRequest(client_medium)
437
444
        request.accept_bytes('123')
438
445
        request.finished_writing()
439
446
        request.finished_reading()
444
451
        # constructing a SmartClientStreamMediumRequest on a StreamMedium sets
445
452
        # the current request to the new SmartClientStreamMediumRequest
446
453
        output = StringIO()
447
 
        medium = smart.SmartSimplePipesClientMedium(None, output)
448
 
        request = smart.SmartClientStreamMediumRequest(medium)
449
 
        self.assertIs(medium._current_request, request)
 
454
        client_medium = medium.SmartSimplePipesClientMedium(None, output)
 
455
        request = medium.SmartClientStreamMediumRequest(client_medium)
 
456
        self.assertIs(client_medium._current_request, request)
450
457
 
451
458
    def test_construct_while_another_request_active_throws(self):
452
459
        # constructing a SmartClientStreamMediumRequest on a StreamMedium with
453
460
        # a non-None _current_request raises TooManyConcurrentRequests.
454
461
        output = StringIO()
455
 
        medium = smart.SmartSimplePipesClientMedium(None, output)
456
 
        medium._current_request = "a"
 
462
        client_medium = medium.SmartSimplePipesClientMedium(None, output)
 
463
        client_medium._current_request = "a"
457
464
        self.assertRaises(errors.TooManyConcurrentRequests,
458
 
            smart.SmartClientStreamMediumRequest, medium)
 
465
            medium.SmartClientStreamMediumRequest, client_medium)
459
466
 
460
467
    def test_finished_read_clears_current_request(self):
461
468
        # calling finished_reading clears the current request from the requests
462
469
        # medium
463
470
        output = StringIO()
464
 
        medium = smart.SmartSimplePipesClientMedium(None, output)
465
 
        request = smart.SmartClientStreamMediumRequest(medium)
 
471
        client_medium = medium.SmartSimplePipesClientMedium(None, output)
 
472
        request = medium.SmartClientStreamMediumRequest(client_medium)
466
473
        request.finished_writing()
467
474
        request.finished_reading()
468
 
        self.assertEqual(None, medium._current_request)
 
475
        self.assertEqual(None, client_medium._current_request)
469
476
 
470
477
    def test_finished_read_before_finished_write_errors(self):
471
478
        # calling finished_reading before calling finished_writing triggers a
472
479
        # WritingNotComplete error.
473
 
        medium = smart.SmartSimplePipesClientMedium(None, None)
474
 
        request = smart.SmartClientStreamMediumRequest(medium)
 
480
        client_medium = medium.SmartSimplePipesClientMedium(None, None)
 
481
        request = medium.SmartClientStreamMediumRequest(client_medium)
475
482
        self.assertRaises(errors.WritingNotComplete, request.finished_reading)
476
483
        
477
484
    def test_read_bytes(self):
483
490
        # smoke tests.
484
491
        input = StringIO('321')
485
492
        output = StringIO()
486
 
        medium = smart.SmartSimplePipesClientMedium(input, output)
487
 
        request = smart.SmartClientStreamMediumRequest(medium)
 
493
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
494
        request = medium.SmartClientStreamMediumRequest(client_medium)
488
495
        request.finished_writing()
489
496
        self.assertEqual('321', request.read_bytes(3))
490
497
        request.finished_reading()
496
503
        # WritingNotComplete error because the Smart protocol is designed to be
497
504
        # compatible with strict message based protocols like HTTP where the
498
505
        # request cannot be submitted until the writing has completed.
499
 
        medium = smart.SmartSimplePipesClientMedium(None, None)
500
 
        request = smart.SmartClientStreamMediumRequest(medium)
 
506
        client_medium = medium.SmartSimplePipesClientMedium(None, None)
 
507
        request = medium.SmartClientStreamMediumRequest(client_medium)
501
508
        self.assertRaises(errors.WritingNotComplete, request.read_bytes, None)
502
509
 
503
510
    def test_read_bytes_after_finished_reading_errors(self):
505
512
        # ReadingCompleted to prevent bad assumptions on stream environments
506
513
        # breaking the needs of message-based environments.
507
514
        output = StringIO()
508
 
        medium = smart.SmartSimplePipesClientMedium(None, output)
509
 
        request = smart.SmartClientStreamMediumRequest(medium)
 
515
        client_medium = medium.SmartSimplePipesClientMedium(None, output)
 
516
        request = medium.SmartClientStreamMediumRequest(client_medium)
510
517
        request.finished_writing()
511
518
        request.finished_reading()
512
519
        self.assertRaises(errors.ReadingCompleted, request.read_bytes, None)
520
527
        # the default or a parameterized class, but rather use the
521
528
        # TestCaseWithTransport infrastructure to set up a smart server and
522
529
        # transport.
523
 
        self.transport_server = smart.SmartTCPServer_for_testing
 
530
        self.transport_server = server.SmartTCPServer_for_testing
524
531
 
525
532
    def test_plausible_url(self):
526
533
        self.assert_(self.get_url().startswith('bzr://'))
527
534
 
528
535
    def test_probe_transport(self):
529
536
        t = self.get_transport()
530
 
        self.assertIsInstance(t, smart.SmartTransport)
 
537
        self.assertIsInstance(t, remote.SmartTransport)
531
538
 
532
539
    def test_get_medium_from_transport(self):
533
540
        """Remote transport has a medium always, which it can return."""
534
541
        t = self.get_transport()
535
 
        medium = t.get_smart_medium()
536
 
        self.assertIsInstance(medium, smart.SmartClientMedium)
 
542
        smart_medium = t.get_smart_medium()
 
543
        self.assertIsInstance(smart_medium, medium.SmartClientMedium)
537
544
 
538
545
 
539
546
class ErrorRaisingProtocol(object):
568
575
 
569
576
class TestSmartServerStreamMedium(tests.TestCase):
570
577
 
 
578
    def setUp(self):
 
579
        super(TestSmartServerStreamMedium, self).setUp()
 
580
        self._captureVar('BZR_NO_SMART_VFS', None)
 
581
 
571
582
    def portable_socket_pair(self):
572
583
        """Return a pair of TCP sockets connected to each other.
573
584
        
588
599
        to_server = StringIO('hello\n')
589
600
        from_server = StringIO()
590
601
        transport = local.LocalTransport(urlutils.local_path_to_url('/'))
591
 
        server = smart.SmartServerPipeStreamMedium(
 
602
        server = medium.SmartServerPipeStreamMedium(
592
603
            to_server, from_server, transport)
593
 
        protocol = smart.SmartServerRequestProtocolOne(transport,
 
604
        smart_protocol = protocol.SmartServerRequestProtocolOne(transport,
594
605
                from_server.write)
595
 
        server._serve_one_request(protocol)
 
606
        server._serve_one_request(smart_protocol)
596
607
        self.assertEqual('ok\0011\n',
597
608
                         from_server.getvalue())
598
609
 
601
612
        transport.put_bytes('testfile', 'contents\nof\nfile\n')
602
613
        to_server = StringIO('get\001./testfile\n')
603
614
        from_server = StringIO()
604
 
        server = smart.SmartServerPipeStreamMedium(
 
615
        server = medium.SmartServerPipeStreamMedium(
605
616
            to_server, from_server, transport)
606
 
        protocol = smart.SmartServerRequestProtocolOne(transport,
 
617
        smart_protocol = protocol.SmartServerRequestProtocolOne(transport,
607
618
                from_server.write)
608
 
        server._serve_one_request(protocol)
 
619
        server._serve_one_request(smart_protocol)
609
620
        self.assertEqual('ok\n'
610
621
                         '17\n'
611
622
                         'contents\nof\nfile\n'
619
630
        transport.put_bytes(utf8_filename, 'contents\nof\nfile\n')
620
631
        to_server = StringIO('get\001' + utf8_filename + '\n')
621
632
        from_server = StringIO()
622
 
        server = smart.SmartServerPipeStreamMedium(
 
633
        server = medium.SmartServerPipeStreamMedium(
623
634
            to_server, from_server, transport)
624
 
        protocol = smart.SmartServerRequestProtocolOne(transport,
 
635
        smart_protocol = protocol.SmartServerRequestProtocolOne(transport,
625
636
                from_server.write)
626
 
        server._serve_one_request(protocol)
 
637
        server._serve_one_request(smart_protocol)
627
638
        self.assertEqual('ok\n'
628
639
                         '17\n'
629
640
                         'contents\nof\nfile\n'
634
645
        sample_request_bytes = 'command\n9\nbulk datadone\n'
635
646
        to_server = StringIO(sample_request_bytes)
636
647
        from_server = StringIO()
637
 
        server = smart.SmartServerPipeStreamMedium(to_server, from_server, None)
 
648
        server = medium.SmartServerPipeStreamMedium(
 
649
            to_server, from_server, None)
638
650
        sample_protocol = SampleRequest(expected_bytes=sample_request_bytes)
639
651
        server._serve_one_request(sample_protocol)
640
652
        self.assertEqual('', from_server.getvalue())
644
656
    def test_socket_stream_with_bulk_data(self):
645
657
        sample_request_bytes = 'command\n9\nbulk datadone\n'
646
658
        server_sock, client_sock = self.portable_socket_pair()
647
 
        server = smart.SmartServerSocketStreamMedium(
 
659
        server = medium.SmartServerSocketStreamMedium(
648
660
            server_sock, None)
649
661
        sample_protocol = SampleRequest(expected_bytes=sample_request_bytes)
650
662
        client_sock.sendall(sample_request_bytes)
657
669
    def test_pipe_like_stream_shutdown_detection(self):
658
670
        to_server = StringIO('')
659
671
        from_server = StringIO()
660
 
        server = smart.SmartServerPipeStreamMedium(to_server, from_server, None)
 
672
        server = medium.SmartServerPipeStreamMedium(to_server, from_server, None)
661
673
        server._serve_one_request(SampleRequest('x'))
662
674
        self.assertTrue(server.finished)
663
675
        
664
676
    def test_socket_stream_shutdown_detection(self):
665
677
        server_sock, client_sock = self.portable_socket_pair()
666
678
        client_sock.close()
667
 
        server = smart.SmartServerSocketStreamMedium(
 
679
        server = medium.SmartServerSocketStreamMedium(
668
680
            server_sock, None)
669
681
        server._serve_one_request(SampleRequest('x'))
670
682
        self.assertTrue(server.finished)
676
688
        sample_request_bytes = 'command\n'
677
689
        to_server = StringIO(sample_request_bytes * 2)
678
690
        from_server = StringIO()
679
 
        server = smart.SmartServerPipeStreamMedium(to_server, from_server, None)
 
691
        server = medium.SmartServerPipeStreamMedium(
 
692
            to_server, from_server, None)
680
693
        first_protocol = SampleRequest(expected_bytes=sample_request_bytes)
681
694
        server._serve_one_request(first_protocol)
682
695
        self.assertEqual(0, first_protocol.next_read_size())
696
709
        # been received seperately.
697
710
        sample_request_bytes = 'command\n'
698
711
        server_sock, client_sock = self.portable_socket_pair()
699
 
        server = smart.SmartServerSocketStreamMedium(
 
712
        server = medium.SmartServerSocketStreamMedium(
700
713
            server_sock, None)
701
714
        first_protocol = SampleRequest(expected_bytes=sample_request_bytes)
702
715
        # Put two whole requests on the wire.
723
736
        def close():
724
737
            self.closed = True
725
738
        from_server.close = close
726
 
        server = smart.SmartServerPipeStreamMedium(to_server, from_server, None)
 
739
        server = medium.SmartServerPipeStreamMedium(
 
740
            to_server, from_server, None)
727
741
        fake_protocol = ErrorRaisingProtocol(Exception('boom'))
728
742
        server._serve_one_request(fake_protocol)
729
743
        self.assertEqual('', from_server.getvalue())
735
749
        # not discard the contents.
736
750
        from StringIO import StringIO
737
751
        server_sock, client_sock = self.portable_socket_pair()
738
 
        server = smart.SmartServerSocketStreamMedium(
 
752
        server = medium.SmartServerSocketStreamMedium(
739
753
            server_sock, None)
740
754
        fake_protocol = ErrorRaisingProtocol(Exception('boom'))
741
755
        server._serve_one_request(fake_protocol)
749
763
        # not discard the contents.
750
764
        to_server = StringIO('')
751
765
        from_server = StringIO()
752
 
        server = smart.SmartServerPipeStreamMedium(to_server, from_server, None)
 
766
        server = medium.SmartServerPipeStreamMedium(
 
767
            to_server, from_server, None)
753
768
        fake_protocol = ErrorRaisingProtocol(KeyboardInterrupt('boom'))
754
769
        self.assertRaises(
755
770
            KeyboardInterrupt, server._serve_one_request, fake_protocol)
757
772
 
758
773
    def test_socket_stream_keyboard_interrupt_handling(self):
759
774
        server_sock, client_sock = self.portable_socket_pair()
760
 
        server = smart.SmartServerSocketStreamMedium(
 
775
        server = medium.SmartServerSocketStreamMedium(
761
776
            server_sock, None)
762
777
        fake_protocol = ErrorRaisingProtocol(KeyboardInterrupt('boom'))
763
778
        self.assertRaises(
770
785
 
771
786
    def test_get_error_unexpected(self):
772
787
        """Error reported by server with no specific representation"""
 
788
        self._captureVar('BZR_NO_SMART_VFS', None)
773
789
        class FlakyTransport(object):
774
790
            base = 'a_url'
775
791
            def get_bytes(self, path):
776
792
                raise Exception("some random exception from inside server")
777
 
        server = smart.SmartTCPServer(backing_transport=FlakyTransport())
778
 
        server.start_background_thread()
 
793
        smart_server = server.SmartTCPServer(backing_transport=FlakyTransport())
 
794
        smart_server.start_background_thread()
779
795
        try:
780
 
            transport = smart.SmartTCPTransport(server.get_url())
 
796
            transport = remote.SmartTCPTransport(smart_server.get_url())
781
797
            try:
782
798
                transport.get('something')
783
799
            except errors.TransportError, e:
785
801
            else:
786
802
                self.fail("get did not raise expected error")
787
803
        finally:
788
 
            server.stop_background_thread()
 
804
            smart_server.stop_background_thread()
789
805
 
790
806
 
791
807
class SmartTCPTests(tests.TestCase):
806
822
        if readonly:
807
823
            self.real_backing_transport = self.backing_transport
808
824
            self.backing_transport = get_transport("readonly+" + self.backing_transport.abspath('.'))
809
 
        self.server = smart.SmartTCPServer(self.backing_transport)
 
825
        self.server = server.SmartTCPServer(self.backing_transport)
810
826
        self.server.start_background_thread()
811
 
        self.transport = smart.SmartTCPTransport(self.server.get_url())
 
827
        self.transport = remote.SmartTCPTransport(self.server.get_url())
812
828
        self.addCleanup(self.tearDownServer)
813
829
 
814
830
    def tearDownServer(self):
826
842
        """It should be safe to teardown the server with no requests."""
827
843
        self.setUpServer()
828
844
        server = self.server
829
 
        transport = smart.SmartTCPTransport(self.server.get_url())
 
845
        transport = remote.SmartTCPTransport(self.server.get_url())
830
846
        self.tearDownServer()
831
847
        self.assertRaises(errors.ConnectionError, transport.has, '.')
832
848
 
838
854
        self.tearDownServer()
839
855
        # if the listening socket has closed, we should get a BADFD error
840
856
        # when connecting, rather than a hang.
841
 
        transport = smart.SmartTCPTransport(server.get_url())
 
857
        transport = remote.SmartTCPTransport(server.get_url())
842
858
        self.assertRaises(errors.ConnectionError, transport.has, '.')
843
859
 
844
860
 
855
871
 
856
872
    def test_smart_transport_has(self):
857
873
        """Checking for file existence over smart."""
 
874
        self._captureVar('BZR_NO_SMART_VFS', None)
858
875
        self.backing_transport.put_bytes("foo", "contents of foo\n")
859
876
        self.assertTrue(self.transport.has("foo"))
860
877
        self.assertFalse(self.transport.has("non-foo"))
861
878
 
862
879
    def test_smart_transport_get(self):
863
880
        """Read back a file over smart."""
 
881
        self._captureVar('BZR_NO_SMART_VFS', None)
864
882
        self.backing_transport.put_bytes("foo", "contents\nof\nfoo\n")
865
883
        fp = self.transport.get("foo")
866
884
        self.assertEqual('contents\nof\nfoo\n', fp.read())
870
888
        # The path in a raised NoSuchFile exception should be the precise path
871
889
        # asked for by the client. This gives meaningful and unsurprising errors
872
890
        # for users.
 
891
        self._captureVar('BZR_NO_SMART_VFS', None)
873
892
        try:
874
893
            self.transport.get('not%20a%20file')
875
894
        except errors.NoSuchFile, e:
896
915
 
897
916
    def test_open_dir(self):
898
917
        """Test changing directory"""
 
918
        self._captureVar('BZR_NO_SMART_VFS', None)
899
919
        transport = self.transport
900
920
        self.backing_transport.mkdir('toffee')
901
921
        self.backing_transport.mkdir('toffee/apple')
923
943
 
924
944
    def test_mkdir_error_readonly(self):
925
945
        """TransportNotPossible should be preserved from the backing transport."""
 
946
        self._captureVar('BZR_NO_SMART_VFS', None)
926
947
        self.setUpServer(readonly=True)
927
948
        self.assertRaises(errors.TransportNotPossible, self.transport.mkdir,
928
949
            'foo')
937
958
    def test_server_started_hook(self):
938
959
        """The server_started hook fires when the server is started."""
939
960
        self.hook_calls = []
940
 
        smart.SmartTCPServer.hooks.install_hook('server_started',
 
961
        server.SmartTCPServer.hooks.install_hook('server_started',
941
962
            self.capture_server_call)
942
963
        self.setUpServer()
943
964
        # at this point, the server will be starting a thread up.
949
970
    def test_server_stopped_hook_simple(self):
950
971
        """The server_stopped hook fires when the server is stopped."""
951
972
        self.hook_calls = []
952
 
        smart.SmartTCPServer.hooks.install_hook('server_stopped',
 
973
        server.SmartTCPServer.hooks.install_hook('server_stopped',
953
974
            self.capture_server_call)
954
975
        self.setUpServer()
955
976
        result = [(self.backing_transport.base, self.transport.base)]
959
980
        self.transport.has('.')
960
981
        self.assertEqual([], self.hook_calls)
961
982
        # clean up the server
962
 
        server = self.server
963
983
        self.tearDownServer()
964
984
        # now it should have fired.
965
985
        self.assertEqual(result, self.hook_calls)
968
988
# server-stopped hook.
969
989
 
970
990
 
971
 
class SmartServerRequestHandlerTests(tests.TestCaseWithTransport):
972
 
    """Test that call directly into the handler logic, bypassing the network."""
973
 
 
974
 
    def test_construct_request_handler(self):
975
 
        """Constructing a request handler should be easy and set defaults."""
976
 
        handler = smart.SmartServerRequestHandler(None)
977
 
        self.assertFalse(handler.finished_reading)
978
 
 
 
991
class SmartServerCommandTests(tests.TestCaseWithTransport):
 
992
    """Tests that call directly into the command objects, bypassing the network
 
993
    and the request dispatching.
 
994
    """
 
995
        
979
996
    def test_hello(self):
980
 
        handler = smart.SmartServerRequestHandler(None)
981
 
        handler.dispatch_command('hello', ())
982
 
        self.assertEqual(('ok', '1'), handler.response.args)
983
 
        self.assertEqual(None, handler.response.body)
 
997
        cmd = request.HelloRequest(None)
 
998
        response = cmd.execute()
 
999
        self.assertEqual(('ok', '1'), response.args)
 
1000
        self.assertEqual(None, response.body)
984
1001
        
985
1002
    def test_get_bundle(self):
986
1003
        from bzrlib.bundle import serializer
989
1006
        wt.add('hello')
990
1007
        rev_id = wt.commit('add hello')
991
1008
        
992
 
        handler = smart.SmartServerRequestHandler(self.get_transport())
993
 
        handler.dispatch_command('get_bundle', ('.', rev_id))
994
 
        bundle = serializer.read_bundle(StringIO(handler.response.body))
995
 
        self.assertEqual((), handler.response.args)
 
1009
        cmd = request.GetBundleRequest(self.get_transport())
 
1010
        response = cmd.execute('.', rev_id)
 
1011
        bundle = serializer.read_bundle(StringIO(response.body))
 
1012
        self.assertEqual((), response.args)
 
1013
 
 
1014
 
 
1015
class SmartServerRequestHandlerTests(tests.TestCaseWithTransport):
 
1016
    """Test that call directly into the handler logic, bypassing the network."""
 
1017
 
 
1018
    def setUp(self):
 
1019
        super(SmartServerRequestHandlerTests, self).setUp()
 
1020
        self._captureVar('BZR_NO_SMART_VFS', None)
 
1021
 
 
1022
    def build_handler(self, transport):
 
1023
        """Returns a handler for the commands in protocol version one."""
 
1024
        return request.SmartServerRequestHandler(transport, request.request_handlers)
 
1025
 
 
1026
    def test_construct_request_handler(self):
 
1027
        """Constructing a request handler should be easy and set defaults."""
 
1028
        handler = request.SmartServerRequestHandler(None, None)
 
1029
        self.assertFalse(handler.finished_reading)
 
1030
 
 
1031
    def test_hello(self):
 
1032
        handler = self.build_handler(None)
 
1033
        handler.dispatch_command('hello', ())
 
1034
        self.assertEqual(('ok', '1'), handler.response.args)
 
1035
        self.assertEqual(None, handler.response.body)
 
1036
        
 
1037
    def test_disable_vfs_handler_classes_via_environment(self):
 
1038
        # VFS handler classes will raise an error from "execute" if BZR_NO_SMART_VFS
 
1039
        # is set.
 
1040
        handler = vfs.HasRequest(None)
 
1041
        # set environment variable after construction to make sure it's
 
1042
        # examined.
 
1043
        # Note that we can safely clobber BZR_NO_SMART_VFS here, because setUp has
 
1044
        # called _captureVar, so it will be restored to the right state
 
1045
        # afterwards.
 
1046
        os.environ['BZR_NO_SMART_VFS'] = ''
 
1047
        self.assertRaises(errors.DisabledMethod, handler.execute)
996
1048
 
997
1049
    def test_readonly_exception_becomes_transport_not_possible(self):
998
1050
        """The response for a read-only error is ('ReadOnlyError')."""
999
 
        handler = smart.SmartServerRequestHandler(self.get_readonly_transport())
 
1051
        handler = self.build_handler(self.get_readonly_transport())
1000
1052
        # send a mkdir for foo, with no explicit mode - should fail.
1001
1053
        handler.dispatch_command('mkdir', ('foo', ''))
1002
1054
        # and the failure should be an explicit ReadOnlyError
1008
1060
 
1009
1061
    def test_hello_has_finished_body_on_dispatch(self):
1010
1062
        """The 'hello' command should set finished_reading."""
1011
 
        handler = smart.SmartServerRequestHandler(None)
 
1063
        handler = self.build_handler(None)
1012
1064
        handler.dispatch_command('hello', ())
1013
1065
        self.assertTrue(handler.finished_reading)
1014
1066
        self.assertNotEqual(None, handler.response)
1015
1067
 
1016
1068
    def test_put_bytes_non_atomic(self):
1017
1069
        """'put_...' should set finished_reading after reading the bytes."""
1018
 
        handler = smart.SmartServerRequestHandler(self.get_transport())
 
1070
        handler = self.build_handler(self.get_transport())
1019
1071
        handler.dispatch_command('put_non_atomic', ('a-file', '', 'F', ''))
1020
1072
        self.assertFalse(handler.finished_reading)
1021
1073
        handler.accept_body('1234')
1029
1081
    def test_readv_accept_body(self):
1030
1082
        """'readv' should set finished_reading after reading offsets."""
1031
1083
        self.build_tree(['a-file'])
1032
 
        handler = smart.SmartServerRequestHandler(self.get_readonly_transport())
 
1084
        handler = self.build_handler(self.get_readonly_transport())
1033
1085
        handler.dispatch_command('readv', ('a-file', ))
1034
1086
        self.assertFalse(handler.finished_reading)
1035
1087
        handler.accept_body('2,')
1044
1096
    def test_readv_short_read_response_contents(self):
1045
1097
        """'readv' when a short read occurs sets the response appropriately."""
1046
1098
        self.build_tree(['a-file'])
1047
 
        handler = smart.SmartServerRequestHandler(self.get_readonly_transport())
 
1099
        handler = self.build_handler(self.get_readonly_transport())
1048
1100
        handler.dispatch_command('readv', ('a-file', ))
1049
1101
        # read beyond the end of the file.
1050
1102
        handler.accept_body('100,1')
1055
1107
        self.assertEqual(None, handler.response.body)
1056
1108
 
1057
1109
 
1058
 
class SmartTransportRegistration(tests.TestCase):
 
1110
class RemoteTransportRegistration(tests.TestCase):
1059
1111
 
1060
1112
    def test_registration(self):
1061
1113
        t = get_transport('bzr+ssh://example.com/path')
1062
 
        self.assertIsInstance(t, smart.SmartSSHTransport)
 
1114
        self.assertIsInstance(t, remote.SmartSSHTransport)
1063
1115
        self.assertEqual('example.com', t._host)
1064
1116
 
1065
1117
 
1066
 
class TestSmartTransport(tests.TestCase):
 
1118
class TestRemoteTransport(tests.TestCase):
1067
1119
        
1068
1120
    def test_use_connection_factory(self):
1069
 
        # We want to be able to pass a client as a parameter to SmartTransport.
 
1121
        # We want to be able to pass a client as a parameter to RemoteTransport.
1070
1122
        input = StringIO("ok\n3\nbardone\n")
1071
1123
        output = StringIO()
1072
 
        medium = smart.SmartSimplePipesClientMedium(input, output)
1073
 
        transport = smart.SmartTransport('bzr://localhost/', medium=medium)
 
1124
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
1125
        transport = remote.SmartTransport(
 
1126
            'bzr://localhost/', medium=client_medium)
1074
1127
 
1075
1128
        # We want to make sure the client is used when the first remote
1076
1129
        # method is called.  No data should have been sent, or read.
1088
1141
 
1089
1142
    def test__translate_error_readonly(self):
1090
1143
        """Sending a ReadOnlyError to _translate_error raises TransportNotPossible."""
1091
 
        medium = smart.SmartClientMedium()
1092
 
        transport = smart.SmartTransport('bzr://localhost/', medium=medium)
 
1144
        client_medium = medium.SmartClientMedium()
 
1145
        transport = remote.SmartTransport(
 
1146
            'bzr://localhost/', medium=client_medium)
1093
1147
        self.assertRaises(errors.TransportNotPossible,
1094
1148
            transport._translate_error, ("ReadOnlyError", ))
1095
1149
 
1096
1150
 
1097
 
class InstrumentedServerProtocol(smart.SmartServerStreamMedium):
 
1151
class InstrumentedServerProtocol(medium.SmartServerStreamMedium):
1098
1152
    """A smart server which is backed by memory and saves its write requests."""
1099
1153
 
1100
1154
    def __init__(self, write_output_list):
1101
 
        smart.SmartServerStreamMedium.__init__(self, memory.MemoryTransport())
 
1155
        medium.SmartServerStreamMedium.__init__(self, memory.MemoryTransport())
1102
1156
        self._write_output_list = write_output_list
1103
1157
 
1104
1158
 
1117
1171
 
1118
1172
    def setUp(self):
1119
1173
        super(TestSmartProtocol, self).setUp()
 
1174
        # XXX: self.server_to_client doesn't seem to be used.  If so,
 
1175
        # InstrumentedServerProtocol is redundant too.
1120
1176
        self.server_to_client = []
1121
1177
        self.to_server = StringIO()
1122
1178
        self.to_client = StringIO()
1123
 
        self.client_medium = smart.SmartSimplePipesClientMedium(self.to_client,
 
1179
        self.client_medium = medium.SmartSimplePipesClientMedium(self.to_client,
1124
1180
            self.to_server)
1125
 
        self.client_protocol = smart.SmartClientRequestProtocolOne(
 
1181
        self.client_protocol = protocol.SmartClientRequestProtocolOne(
1126
1182
            self.client_medium)
1127
1183
        self.smart_server = InstrumentedServerProtocol(self.server_to_client)
1128
 
        self.smart_server_request = smart.SmartServerRequestHandler(None)
 
1184
        self.smart_server_request = request.SmartServerRequestHandler(
 
1185
            None, request.request_handlers)
1129
1186
 
1130
1187
    def assertOffsetSerialisation(self, expected_offsets, expected_serialised,
1131
 
        client, smart_server_request):
 
1188
        client):
1132
1189
        """Check that smart (de)serialises offsets as expected.
1133
1190
        
1134
1191
        We check both serialisation and deserialisation at the same time
1137
1194
        
1138
1195
        :param expected_offsets: a readv offset list.
1139
1196
        :param expected_serialised: an expected serial form of the offsets.
1140
 
        :param smart_server_request: a SmartServerRequestHandler instance.
1141
1197
        """
1142
 
        # XXX: 'smart_server_request' should be a SmartServerRequestProtocol in
1143
 
        # future.
1144
 
        offsets = smart_server_request._deserialise_offsets(expected_serialised)
 
1198
        # XXX: '_deserialise_offsets' should be a method of the
 
1199
        # SmartServerRequestProtocol in future.
 
1200
        readv_cmd = vfs.ReadvRequest(None)
 
1201
        offsets = readv_cmd._deserialise_offsets(expected_serialised)
1145
1202
        self.assertEqual(expected_offsets, offsets)
1146
1203
        serialised = client._serialise_offsets(offsets)
1147
1204
        self.assertEqual(expected_serialised, serialised)
1148
1205
 
1149
1206
    def build_protocol_waiting_for_body(self):
1150
1207
        out_stream = StringIO()
1151
 
        protocol = smart.SmartServerRequestProtocolOne(None, out_stream.write)
1152
 
        protocol.has_dispatched = True
1153
 
        protocol.request = smart.SmartServerRequestHandler(None)
1154
 
        def handle_end_of_bytes():
1155
 
            self.end_received = True
1156
 
            self.assertEqual('abcdefg', protocol.request._body_bytes)
1157
 
            protocol.request.response = smart.SmartServerResponse(('ok', ))
1158
 
        protocol.request._end_of_body_handler = handle_end_of_bytes
 
1208
        smart_protocol = protocol.SmartServerRequestProtocolOne(None,
 
1209
                out_stream.write)
 
1210
        smart_protocol.has_dispatched = True
 
1211
        smart_protocol.request = self.smart_server_request
 
1212
        class FakeCommand(object):
 
1213
            def do_body(cmd, body_bytes):
 
1214
                self.end_received = True
 
1215
                self.assertEqual('abcdefg', body_bytes)
 
1216
                return request.SmartServerResponse(('ok', ))
 
1217
        smart_protocol.request._command = FakeCommand()
1159
1218
        # Call accept_bytes to make sure that internal state like _body_decoder
1160
1219
        # is initialised.  This test should probably be given a clearer
1161
1220
        # interface to work with that will not cause this inconsistency.
1162
1221
        #   -- Andrew Bennetts, 2006-09-28
1163
 
        protocol.accept_bytes('')
1164
 
        return protocol
 
1222
        smart_protocol.accept_bytes('')
 
1223
        return smart_protocol
1165
1224
 
1166
1225
    def test_construct_version_one_server_protocol(self):
1167
 
        protocol = smart.SmartServerRequestProtocolOne(None, None)
1168
 
        self.assertEqual('', protocol.excess_buffer)
1169
 
        self.assertEqual('', protocol.in_buffer)
1170
 
        self.assertFalse(protocol.has_dispatched)
1171
 
        self.assertEqual(1, protocol.next_read_size())
 
1226
        smart_protocol = protocol.SmartServerRequestProtocolOne(None, None)
 
1227
        self.assertEqual('', smart_protocol.excess_buffer)
 
1228
        self.assertEqual('', smart_protocol.in_buffer)
 
1229
        self.assertFalse(smart_protocol.has_dispatched)
 
1230
        self.assertEqual(1, smart_protocol.next_read_size())
1172
1231
 
1173
1232
    def test_construct_version_one_client_protocol(self):
1174
1233
        # we can construct a client protocol from a client medium request
1175
1234
        output = StringIO()
1176
 
        medium = smart.SmartSimplePipesClientMedium(None, output)
1177
 
        request = medium.get_request()
1178
 
        client_protocol = smart.SmartClientRequestProtocolOne(request)
 
1235
        client_medium = medium.SmartSimplePipesClientMedium(None, output)
 
1236
        request = client_medium.get_request()
 
1237
        client_protocol = protocol.SmartClientRequestProtocolOne(request)
1179
1238
 
1180
1239
    def test_server_offset_serialisation(self):
1181
1240
        """The Smart protocol serialises offsets as a comma and \n string.
1184
1243
        one with the order of reads not increasing (an out of order read), and
1185
1244
        one that should coalesce.
1186
1245
        """
1187
 
        self.assertOffsetSerialisation([], '',
1188
 
            self.client_protocol, self.smart_server_request)
1189
 
        self.assertOffsetSerialisation([(1,2)], '1,2',
1190
 
            self.client_protocol, self.smart_server_request)
 
1246
        self.assertOffsetSerialisation([], '', self.client_protocol)
 
1247
        self.assertOffsetSerialisation([(1,2)], '1,2', self.client_protocol)
1191
1248
        self.assertOffsetSerialisation([(10,40), (0,5)], '10,40\n0,5',
1192
 
            self.client_protocol, self.smart_server_request)
 
1249
            self.client_protocol)
1193
1250
        self.assertOffsetSerialisation([(1,2), (3,4), (100, 200)],
1194
 
            '1,2\n3,4\n100,200', self.client_protocol, self.smart_server_request)
 
1251
            '1,2\n3,4\n100,200', self.client_protocol)
1195
1252
 
1196
1253
    def test_accept_bytes_of_bad_request_to_protocol(self):
1197
1254
        out_stream = StringIO()
1198
 
        protocol = smart.SmartServerRequestProtocolOne(None, out_stream.write)
1199
 
        protocol.accept_bytes('abc')
1200
 
        self.assertEqual('abc', protocol.in_buffer)
1201
 
        protocol.accept_bytes('\n')
1202
 
        self.assertEqual("error\x01Generic bzr smart protocol error: bad request"
1203
 
            " 'abc'\n", out_stream.getvalue())
1204
 
        self.assertTrue(protocol.has_dispatched)
1205
 
        self.assertEqual(0, protocol.next_read_size())
 
1255
        smart_protocol = protocol.SmartServerRequestProtocolOne(
 
1256
            None, out_stream.write)
 
1257
        smart_protocol.accept_bytes('abc')
 
1258
        self.assertEqual('abc', smart_protocol.in_buffer)
 
1259
        smart_protocol.accept_bytes('\n')
 
1260
        self.assertEqual(
 
1261
            "error\x01Generic bzr smart protocol error: bad request 'abc'\n",
 
1262
            out_stream.getvalue())
 
1263
        self.assertTrue(smart_protocol.has_dispatched)
 
1264
        self.assertEqual(0, smart_protocol.next_read_size())
1206
1265
 
1207
1266
    def test_accept_body_bytes_to_protocol(self):
1208
1267
        protocol = self.build_protocol_waiting_for_body()
1215
1274
        self.assertTrue(self.end_received)
1216
1275
 
1217
1276
    def test_accept_request_and_body_all_at_once(self):
 
1277
        self._captureVar('BZR_NO_SMART_VFS', None)
1218
1278
        mem_transport = memory.MemoryTransport()
1219
1279
        mem_transport.put_bytes('foo', 'abcdefghij')
1220
1280
        out_stream = StringIO()
1221
 
        protocol = smart.SmartServerRequestProtocolOne(mem_transport,
 
1281
        smart_protocol = protocol.SmartServerRequestProtocolOne(mem_transport,
1222
1282
                out_stream.write)
1223
 
        protocol.accept_bytes('readv\x01foo\n3\n3,3done\n')
1224
 
        self.assertEqual(0, protocol.next_read_size())
 
1283
        smart_protocol.accept_bytes('readv\x01foo\n3\n3,3done\n')
 
1284
        self.assertEqual(0, smart_protocol.next_read_size())
1225
1285
        self.assertEqual('readv\n3\ndefdone\n', out_stream.getvalue())
1226
 
        self.assertEqual('', protocol.excess_buffer)
1227
 
        self.assertEqual('', protocol.in_buffer)
 
1286
        self.assertEqual('', smart_protocol.excess_buffer)
 
1287
        self.assertEqual('', smart_protocol.in_buffer)
1228
1288
 
1229
1289
    def test_accept_excess_bytes_are_preserved(self):
1230
1290
        out_stream = StringIO()
1231
 
        protocol = smart.SmartServerRequestProtocolOne(None, out_stream.write)
1232
 
        protocol.accept_bytes('hello\nhello\n')
 
1291
        smart_protocol = protocol.SmartServerRequestProtocolOne(
 
1292
            None, out_stream.write)
 
1293
        smart_protocol.accept_bytes('hello\nhello\n')
1233
1294
        self.assertEqual("ok\x011\n", out_stream.getvalue())
1234
 
        self.assertEqual("hello\n", protocol.excess_buffer)
1235
 
        self.assertEqual("", protocol.in_buffer)
 
1295
        self.assertEqual("hello\n", smart_protocol.excess_buffer)
 
1296
        self.assertEqual("", smart_protocol.in_buffer)
1236
1297
 
1237
1298
    def test_accept_excess_bytes_after_body(self):
1238
1299
        protocol = self.build_protocol_waiting_for_body()
1246
1307
 
1247
1308
    def test_accept_excess_bytes_after_dispatch(self):
1248
1309
        out_stream = StringIO()
1249
 
        protocol = smart.SmartServerRequestProtocolOne(None, out_stream.write)
1250
 
        protocol.accept_bytes('hello\n')
 
1310
        smart_protocol = protocol.SmartServerRequestProtocolOne(
 
1311
            None, out_stream.write)
 
1312
        smart_protocol.accept_bytes('hello\n')
1251
1313
        self.assertEqual("ok\x011\n", out_stream.getvalue())
1252
 
        protocol.accept_bytes('hel')
1253
 
        self.assertEqual("hel", protocol.excess_buffer)
1254
 
        protocol.accept_bytes('lo\n')
1255
 
        self.assertEqual("hello\n", protocol.excess_buffer)
1256
 
        self.assertEqual("", protocol.in_buffer)
 
1314
        smart_protocol.accept_bytes('hel')
 
1315
        self.assertEqual("hel", smart_protocol.excess_buffer)
 
1316
        smart_protocol.accept_bytes('lo\n')
 
1317
        self.assertEqual("hello\n", smart_protocol.excess_buffer)
 
1318
        self.assertEqual("", smart_protocol.in_buffer)
1257
1319
 
1258
1320
    def test__send_response_sets_finished_reading(self):
1259
 
        protocol = smart.SmartServerRequestProtocolOne(None, lambda x: None)
1260
 
        self.assertEqual(1, protocol.next_read_size())
1261
 
        protocol._send_response(('x',))
1262
 
        self.assertEqual(0, protocol.next_read_size())
 
1321
        smart_protocol = protocol.SmartServerRequestProtocolOne(
 
1322
            None, lambda x: None)
 
1323
        self.assertEqual(1, smart_protocol.next_read_size())
 
1324
        smart_protocol._send_response(('x',))
 
1325
        self.assertEqual(0, smart_protocol.next_read_size())
1263
1326
 
1264
1327
    def test_query_version(self):
1265
1328
        """query_version on a SmartClientProtocolOne should return a number.
1273
1336
        # the error if the response is a non-understood version.
1274
1337
        input = StringIO('ok\x011\n')
1275
1338
        output = StringIO()
1276
 
        medium = smart.SmartSimplePipesClientMedium(input, output)
1277
 
        protocol = smart.SmartClientRequestProtocolOne(medium.get_request())
1278
 
        self.assertEqual(1, protocol.query_version())
 
1339
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
1340
        request = client_medium.get_request()
 
1341
        smart_protocol = protocol.SmartClientRequestProtocolOne(request)
 
1342
        self.assertEqual(1, smart_protocol.query_version())
1279
1343
 
1280
1344
    def assertServerToClientEncoding(self, expected_bytes, expected_tuple,
1281
1345
            input_tuples):
1286
1350
        # expected bytes
1287
1351
        for input_tuple in input_tuples:
1288
1352
            server_output = StringIO()
1289
 
            server_protocol = smart.SmartServerRequestProtocolOne(
 
1353
            server_protocol = protocol.SmartServerRequestProtocolOne(
1290
1354
                None, server_output.write)
1291
1355
            server_protocol._send_response(input_tuple)
1292
1356
            self.assertEqual(expected_bytes, server_output.getvalue())
1293
 
        # check the decoding of the client protocol from expected_bytes:
 
1357
        # check the decoding of the client protocol from expected_bytes:
1294
1358
        input = StringIO(expected_bytes)
1295
1359
        output = StringIO()
1296
 
        medium = smart.SmartSimplePipesClientMedium(input, output)
1297
 
        protocol = smart.SmartClientRequestProtocolOne(medium.get_request())
1298
 
        protocol.call('foo')
1299
 
        self.assertEqual(expected_tuple, protocol.read_response_tuple())
 
1360
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
1361
        request = client_medium.get_request()
 
1362
        smart_protocol = protocol.SmartClientRequestProtocolOne(request)
 
1363
        smart_protocol.call('foo')
 
1364
        self.assertEqual(expected_tuple, smart_protocol.read_response_tuple())
1300
1365
 
1301
1366
    def test_client_call_empty_response(self):
1302
1367
        # protocol.call() can get back an empty tuple as a response. This occurs
1311
1376
            [('a', 'b', '34')])
1312
1377
 
1313
1378
    def test_client_call_with_body_bytes_uploads(self):
1314
 
        # protocol.call_with_upload should length-prefix the bytes onto the 
 
1379
        # protocol.call_with_body_bytes should length-prefix the bytes onto the
1315
1380
        # wire.
1316
1381
        expected_bytes = "foo\n7\nabcdefgdone\n"
1317
1382
        input = StringIO("\n")
1318
1383
        output = StringIO()
1319
 
        medium = smart.SmartSimplePipesClientMedium(input, output)
1320
 
        protocol = smart.SmartClientRequestProtocolOne(medium.get_request())
1321
 
        protocol.call_with_body_bytes(('foo', ), "abcdefg")
 
1384
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
1385
        request = client_medium.get_request()
 
1386
        smart_protocol = protocol.SmartClientRequestProtocolOne(request)
 
1387
        smart_protocol.call_with_body_bytes(('foo', ), "abcdefg")
1322
1388
        self.assertEqual(expected_bytes, output.getvalue())
1323
1389
 
1324
1390
    def test_client_call_with_body_readv_array(self):
1327
1393
        expected_bytes = "foo\n7\n1,2\n5,6done\n"
1328
1394
        input = StringIO("\n")
1329
1395
        output = StringIO()
1330
 
        medium = smart.SmartSimplePipesClientMedium(input, output)
1331
 
        protocol = smart.SmartClientRequestProtocolOne(medium.get_request())
1332
 
        protocol.call_with_body_readv_array(('foo', ), [(1,2),(5,6)])
 
1396
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
1397
        request = client_medium.get_request()
 
1398
        smart_protocol = protocol.SmartClientRequestProtocolOne(request)
 
1399
        smart_protocol.call_with_body_readv_array(('foo', ), [(1,2),(5,6)])
1333
1400
        self.assertEqual(expected_bytes, output.getvalue())
1334
1401
 
1335
1402
    def test_client_read_body_bytes_all(self):
1339
1406
        server_bytes = "ok\n7\n1234567done\n"
1340
1407
        input = StringIO(server_bytes)
1341
1408
        output = StringIO()
1342
 
        medium = smart.SmartSimplePipesClientMedium(input, output)
1343
 
        protocol = smart.SmartClientRequestProtocolOne(medium.get_request())
1344
 
        protocol.call('foo')
1345
 
        protocol.read_response_tuple(True)
1346
 
        self.assertEqual(expected_bytes, protocol.read_body_bytes())
 
1409
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
1410
        request = client_medium.get_request()
 
1411
        smart_protocol = protocol.SmartClientRequestProtocolOne(request)
 
1412
        smart_protocol.call('foo')
 
1413
        smart_protocol.read_response_tuple(True)
 
1414
        self.assertEqual(expected_bytes, smart_protocol.read_body_bytes())
1347
1415
 
1348
1416
    def test_client_read_body_bytes_incremental(self):
1349
1417
        # test reading a few bytes at a time from the body
1355
1423
        server_bytes = "ok\n7\n1234567done\n"
1356
1424
        input = StringIO(server_bytes)
1357
1425
        output = StringIO()
1358
 
        medium = smart.SmartSimplePipesClientMedium(input, output)
1359
 
        protocol = smart.SmartClientRequestProtocolOne(medium.get_request())
1360
 
        protocol.call('foo')
1361
 
        protocol.read_response_tuple(True)
1362
 
        self.assertEqual(expected_bytes[0:2], protocol.read_body_bytes(2))
1363
 
        self.assertEqual(expected_bytes[2:4], protocol.read_body_bytes(2))
1364
 
        self.assertEqual(expected_bytes[4:6], protocol.read_body_bytes(2))
1365
 
        self.assertEqual(expected_bytes[6], protocol.read_body_bytes())
 
1426
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
1427
        request = client_medium.get_request()
 
1428
        smart_protocol = protocol.SmartClientRequestProtocolOne(request)
 
1429
        smart_protocol.call('foo')
 
1430
        smart_protocol.read_response_tuple(True)
 
1431
        self.assertEqual(expected_bytes[0:2], smart_protocol.read_body_bytes(2))
 
1432
        self.assertEqual(expected_bytes[2:4], smart_protocol.read_body_bytes(2))
 
1433
        self.assertEqual(expected_bytes[4:6], smart_protocol.read_body_bytes(2))
 
1434
        self.assertEqual(expected_bytes[6], smart_protocol.read_body_bytes())
1366
1435
 
1367
1436
    def test_client_cancel_read_body_does_not_eat_body_bytes(self):
1368
1437
        # cancelling the expected body needs to finish the request, but not
1371
1440
        server_bytes = "ok\n7\n1234567done\n"
1372
1441
        input = StringIO(server_bytes)
1373
1442
        output = StringIO()
1374
 
        medium = smart.SmartSimplePipesClientMedium(input, output)
1375
 
        protocol = smart.SmartClientRequestProtocolOne(medium.get_request())
1376
 
        protocol.call('foo')
1377
 
        protocol.read_response_tuple(True)
1378
 
        protocol.cancel_read_body()
 
1443
        client_medium = medium.SmartSimplePipesClientMedium(input, output)
 
1444
        request = client_medium.get_request()
 
1445
        smart_protocol = protocol.SmartClientRequestProtocolOne(request)
 
1446
        smart_protocol.call('foo')
 
1447
        smart_protocol.read_response_tuple(True)
 
1448
        smart_protocol.cancel_read_body()
1379
1449
        self.assertEqual(3, input.tell())
1380
 
        self.assertRaises(errors.ReadingCompleted, protocol.read_body_bytes)
 
1450
        self.assertRaises(
 
1451
            errors.ReadingCompleted, smart_protocol.read_body_bytes)
1381
1452
 
1382
1453
 
1383
1454
class LengthPrefixedBodyDecoder(tests.TestCase):
1386
1457
    # something similar to the ProtocolBase method.
1387
1458
 
1388
1459
    def test_construct(self):
1389
 
        decoder = smart.LengthPrefixedBodyDecoder()
 
1460
        decoder = protocol.LengthPrefixedBodyDecoder()
1390
1461
        self.assertFalse(decoder.finished_reading)
1391
1462
        self.assertEqual(6, decoder.next_read_size())
1392
1463
        self.assertEqual('', decoder.read_pending_data())
1393
1464
        self.assertEqual('', decoder.unused_data)
1394
1465
 
1395
1466
    def test_accept_bytes(self):
1396
 
        decoder = smart.LengthPrefixedBodyDecoder()
 
1467
        decoder = protocol.LengthPrefixedBodyDecoder()
1397
1468
        decoder.accept_bytes('')
1398
1469
        self.assertFalse(decoder.finished_reading)
1399
1470
        self.assertEqual(6, decoder.next_read_size())
1426
1497
        self.assertEqual('blarg', decoder.unused_data)
1427
1498
        
1428
1499
    def test_accept_bytes_all_at_once_with_excess(self):
1429
 
        decoder = smart.LengthPrefixedBodyDecoder()
 
1500
        decoder = protocol.LengthPrefixedBodyDecoder()
1430
1501
        decoder.accept_bytes('1\nadone\nunused')
1431
1502
        self.assertTrue(decoder.finished_reading)
1432
1503
        self.assertEqual(1, decoder.next_read_size())
1434
1505
        self.assertEqual('unused', decoder.unused_data)
1435
1506
 
1436
1507
    def test_accept_bytes_exact_end_of_body(self):
1437
 
        decoder = smart.LengthPrefixedBodyDecoder()
 
1508
        decoder = protocol.LengthPrefixedBodyDecoder()
1438
1509
        decoder.accept_bytes('1\na')
1439
1510
        self.assertFalse(decoder.finished_reading)
1440
1511
        self.assertEqual(5, decoder.next_read_size())
1458
1529
 
1459
1530
class HTTPTunnellingSmokeTest(tests.TestCaseWithTransport):
1460
1531
    
 
1532
    def setUp(self):
 
1533
        super(HTTPTunnellingSmokeTest, self).setUp()
 
1534
        # We use the VFS layer as part of HTTP tunnelling tests.
 
1535
        self._captureVar('BZR_NO_SMART_VFS', None)
 
1536
 
1461
1537
    def _test_bulk_data(self, url_protocol):
1462
1538
        # We should be able to send and receive bulk data in a single message.
1463
1539
        # The 'readv' command in the smart protocol both sends and receives bulk
1464
1540
        # data, so we use that.
1465
1541
        self.build_tree(['data-file'])
1466
 
        http_server = HTTPServerWithSmarts()
1467
 
        http_server._url_protocol = url_protocol
1468
 
        http_server.setUp()
1469
 
        self.addCleanup(http_server.tearDown)
1470
 
 
1471
 
        http_transport = get_transport(http_server.get_url())
1472
 
 
 
1542
        self.transport_readonly_server = HTTPServerWithSmarts
 
1543
 
 
1544
        http_transport = self.get_readonly_transport()
1473
1545
        medium = http_transport.get_smart_medium()
1474
1546
        #remote_transport = RemoteTransport('fake_url', medium)
1475
 
        remote_transport = smart.SmartTransport('/', medium=medium)
 
1547
        remote_transport = remote.SmartTransport('/', medium=medium)
1476
1548
        self.assertEqual(
1477
1549
            [(0, "c")], list(remote_transport.readv("data-file", [(0,1)])))
1478
1550
 
1497
1569
    def _test_http_send_smart_request(self, url_protocol):
1498
1570
        http_server = HTTPServerWithSmarts()
1499
1571
        http_server._url_protocol = url_protocol
1500
 
        http_server.setUp()
 
1572
        http_server.setUp(self.get_vfs_only_server())
1501
1573
        self.addCleanup(http_server.tearDown)
1502
1574
 
1503
1575
        post_body = 'hello\n'
1519
1591
        self._test_http_send_smart_request('http+urllib')
1520
1592
 
1521
1593
    def test_http_server_with_smarts(self):
1522
 
        http_server = HTTPServerWithSmarts()
1523
 
        http_server.setUp()
1524
 
        self.addCleanup(http_server.tearDown)
 
1594
        self.transport_readonly_server = HTTPServerWithSmarts
1525
1595
 
1526
1596
        post_body = 'hello\n'
1527
1597
        expected_reply_body = 'ok\x011\n'
1528
1598
 
1529
 
        smart_server_url = http_server.get_url() + '.bzr/smart'
 
1599
        smart_server_url = self.get_readonly_url('.bzr/smart')
1530
1600
        reply = urllib2.urlopen(smart_server_url, post_body).read()
1531
1601
 
1532
1602
        self.assertEqual(expected_reply_body, reply)
1533
1603
 
1534
1604
    def test_smart_http_server_post_request_handler(self):
1535
 
        http_server = HTTPServerWithSmarts()
1536
 
        http_server.setUp()
1537
 
        self.addCleanup(http_server.tearDown)
1538
 
        httpd = http_server._get_httpd()
 
1605
        self.transport_readonly_server = HTTPServerWithSmarts
 
1606
        httpd = self.get_readonly_server()._get_httpd()
1539
1607
 
1540
1608
        socket = SampleSocket(
1541
1609
            'POST /.bzr/smart HTTP/1.0\r\n'
1577
1645
        else:
1578
1646
            return self.writefile
1579
1647
 
1580
 
        
 
1648
 
1581
1649
# TODO: Client feature that does get_bundle and then installs that into a
1582
1650
# branch; this can be used in place of the regular pull/fetch operation when
1583
1651
# coming from a smart server.