~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/tests/test_smart_transport.py

  • Committer: Martin Pool
  • Date: 2006-11-02 10:20:19 UTC
  • mfrom: (2114 +trunk)
  • mto: This revision was merged to the branch mainline in revision 2119.
  • Revision ID: mbp@sourcefrog.net-20061102102019-9a5a02f485dff6f6
merge bzr.dev and reconcile several changes, also some test fixes

Show diffs side-by-side

added added

removed removed

Lines of Context:
18
18
 
19
19
# all of this deals with byte strings so this is safe
20
20
from cStringIO import StringIO
21
 
import subprocess
22
 
import sys
 
21
import os
 
22
import socket
 
23
import threading
 
24
import urllib2
23
25
 
24
 
import bzrlib
25
26
from bzrlib import (
26
27
        bzrdir,
27
28
        errors,
 
29
        osutils,
28
30
        tests,
29
31
        urlutils,
30
32
        )
34
36
        memory,
35
37
        smart,
36
38
        )
37
 
 
38
 
 
39
 
class SmartClientTests(tests.TestCase):
40
 
 
41
 
    def test_construct_smart_stream_client(self):
42
 
        # make a new client; this really wants a connector function that returns
43
 
        # two fifos or sockets but the constructor should not do any IO
44
 
        client = smart.SmartStreamClient(None)
45
 
 
46
 
 
47
 
class TCPClientTests(tests.TestCaseWithTransport):
 
39
from bzrlib.transport.http import (
 
40
        HTTPServerWithSmarts,
 
41
        SmartClientHTTPMediumRequest,
 
42
        SmartRequestHandler,
 
43
        )
 
44
 
 
45
 
 
46
class StringIOSSHVendor(object):
 
47
    """A SSH vendor that uses StringIO to buffer writes and answer reads."""
 
48
 
 
49
    def __init__(self, read_from, write_to):
 
50
        self.read_from = read_from
 
51
        self.write_to = write_to
 
52
        self.calls = []
 
53
 
 
54
    def connect_ssh(self, username, password, host, port, command):
 
55
        self.calls.append(('connect_ssh', username, password, host, port,
 
56
            command))
 
57
        return StringIOSSHConnection(self)
 
58
 
 
59
 
 
60
class StringIOSSHConnection(object):
 
61
    """A SSH connection that uses StringIO to buffer writes and answer reads."""
 
62
 
 
63
    def __init__(self, vendor):
 
64
        self.vendor = vendor
 
65
    
 
66
    def close(self):
 
67
        self.vendor.calls.append(('close', ))
 
68
        
 
69
    def get_filelike_channels(self):
 
70
        return self.vendor.read_from, self.vendor.write_to
 
71
 
 
72
 
 
73
 
 
74
class SmartClientMediumTests(tests.TestCase):
 
75
    """Tests for SmartClientMedium.
 
76
 
 
77
    We should create a test scenario for this: we need a server module that
 
78
    construct the test-servers (like make_loopsocket_and_medium), and the list
 
79
    of SmartClientMedium classes to test.
 
80
    """
 
81
 
 
82
    def make_loopsocket_and_medium(self):
 
83
        """Create a loopback socket for testing, and a medium aimed at it."""
 
84
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 
85
        sock.bind(('127.0.0.1', 0))
 
86
        sock.listen(1)
 
87
        port = sock.getsockname()[1]
 
88
        medium = smart.SmartTCPClientMedium('127.0.0.1', port)
 
89
        return sock, medium
 
90
 
 
91
    def receive_bytes_on_server(self, sock, bytes):
 
92
        """Accept a connection on sock and read 3 bytes.
 
93
 
 
94
        The bytes are appended to the list bytes.
 
95
 
 
96
        :return: a Thread which is running to do the accept and recv.
 
97
        """
 
98
        def _receive_bytes_on_server():
 
99
            connection, address = sock.accept()
 
100
            bytes.append(osutils.recv_all(connection, 3))
 
101
            connection.close()
 
102
        t = threading.Thread(target=_receive_bytes_on_server)
 
103
        t.start()
 
104
        return t
 
105
    
 
106
    def test_construct_smart_stream_medium_client(self):
 
107
        # make a new instance of the common base for Stream-like Mediums.
 
108
        # this just ensures that the constructor stays parameter-free which
 
109
        # is important for reuse : some subclasses will dynamically connect,
 
110
        # others are always on, etc.
 
111
        medium = smart.SmartClientStreamMedium()
 
112
 
 
113
    def test_construct_smart_client_medium(self):
 
114
        # the base client medium takes no parameters
 
115
        medium = smart.SmartClientMedium()
 
116
    
 
117
    def test_construct_smart_simple_pipes_client_medium(self):
 
118
        # the SimplePipes client medium takes two pipes:
 
119
        # readable pipe, writeable pipe.
 
120
        # Constructing one should just save these and do nothing.
 
121
        # We test this by passing in None.
 
122
        medium = smart.SmartSimplePipesClientMedium(None, None)
 
123
        
 
124
    def test_simple_pipes_client_request_type(self):
 
125
        # SimplePipesClient should use SmartClientStreamMediumRequest's.
 
126
        medium = smart.SmartSimplePipesClientMedium(None, None)
 
127
        request = medium.get_request()
 
128
        self.assertIsInstance(request, smart.SmartClientStreamMediumRequest)
 
129
 
 
130
    def test_simple_pipes_client_get_concurrent_requests(self):
 
131
        # the simple_pipes client does not support pipelined requests:
 
132
        # but it does support serial requests: we construct one after 
 
133
        # another is finished. This is a smoke test testing the integration
 
134
        # of the SmartClientStreamMediumRequest and the SmartClientStreamMedium
 
135
        # classes - as the sibling classes share this logic, they do not have
 
136
        # explicit tests for this.
 
137
        output = StringIO()
 
138
        medium = smart.SmartSimplePipesClientMedium(None, output)
 
139
        request = medium.get_request()
 
140
        request.finished_writing()
 
141
        request.finished_reading()
 
142
        request2 = medium.get_request()
 
143
        request2.finished_writing()
 
144
        request2.finished_reading()
 
145
 
 
146
    def test_simple_pipes_client__accept_bytes_writes_to_writable(self):
 
147
        # accept_bytes writes to the writeable pipe.
 
148
        output = StringIO()
 
149
        medium = smart.SmartSimplePipesClientMedium(None, output)
 
150
        medium._accept_bytes('abc')
 
151
        self.assertEqual('abc', output.getvalue())
 
152
    
 
153
    def test_simple_pipes_client_disconnect_does_nothing(self):
 
154
        # calling disconnect does nothing.
 
155
        input = StringIO()
 
156
        output = StringIO()
 
157
        medium = smart.SmartSimplePipesClientMedium(input, output)
 
158
        # send some bytes to ensure disconnecting after activity still does not
 
159
        # close.
 
160
        medium._accept_bytes('abc')
 
161
        medium.disconnect()
 
162
        self.assertFalse(input.closed)
 
163
        self.assertFalse(output.closed)
 
164
 
 
165
    def test_simple_pipes_client_accept_bytes_after_disconnect(self):
 
166
        # calling disconnect on the client does not alter the pipe that
 
167
        # accept_bytes writes to.
 
168
        input = StringIO()
 
169
        output = StringIO()
 
170
        medium = smart.SmartSimplePipesClientMedium(input, output)
 
171
        medium._accept_bytes('abc')
 
172
        medium.disconnect()
 
173
        medium._accept_bytes('abc')
 
174
        self.assertFalse(input.closed)
 
175
        self.assertFalse(output.closed)
 
176
        self.assertEqual('abcabc', output.getvalue())
 
177
    
 
178
    def test_simple_pipes_client_ignores_disconnect_when_not_connected(self):
 
179
        # Doing a disconnect on a new (and thus unconnected) SimplePipes medium
 
180
        # does nothing.
 
181
        medium = smart.SmartSimplePipesClientMedium(None, None)
 
182
        medium.disconnect()
 
183
 
 
184
    def test_simple_pipes_client_can_always_read(self):
 
185
        # SmartSimplePipesClientMedium is never disconnected, so read_bytes
 
186
        # always tries to read from the underlying pipe.
 
187
        input = StringIO('abcdef')
 
188
        medium = smart.SmartSimplePipesClientMedium(input, None)
 
189
        self.assertEqual('abc', medium.read_bytes(3))
 
190
        medium.disconnect()
 
191
        self.assertEqual('def', medium.read_bytes(3))
 
192
        
 
193
    def test_simple_pipes_client_supports__flush(self):
 
194
        # invoking _flush on a SimplePipesClient should flush the output 
 
195
        # pipe. We test this by creating an output pipe that records
 
196
        # flush calls made to it.
 
197
        from StringIO import StringIO # get regular StringIO
 
198
        input = StringIO()
 
199
        output = StringIO()
 
200
        flush_calls = []
 
201
        def logging_flush(): flush_calls.append('flush')
 
202
        output.flush = logging_flush
 
203
        medium = smart.SmartSimplePipesClientMedium(input, output)
 
204
        # this call is here to ensure we only flush once, not on every
 
205
        # _accept_bytes call.
 
206
        medium._accept_bytes('abc')
 
207
        medium._flush()
 
208
        medium.disconnect()
 
209
        self.assertEqual(['flush'], flush_calls)
 
210
 
 
211
    def test_construct_smart_ssh_client_medium(self):
 
212
        # the SSH client medium takes:
 
213
        # host, port, username, password, vendor
 
214
        # Constructing one should just save these and do nothing.
 
215
        # we test this by creating a empty bound socket and constructing
 
216
        # a medium.
 
217
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 
218
        sock.bind(('127.0.0.1', 0))
 
219
        unopened_port = sock.getsockname()[1]
 
220
        # having vendor be invalid means that if it tries to connect via the
 
221
        # vendor it will blow up.
 
222
        medium = smart.SmartSSHClientMedium('127.0.0.1', unopened_port,
 
223
            username=None, password=None, vendor="not a vendor")
 
224
        sock.close()
 
225
 
 
226
    def test_ssh_client_connects_on_first_use(self):
 
227
        # The only thing that initiates a connection from the medium is giving
 
228
        # it bytes.
 
229
        output = StringIO()
 
230
        vendor = StringIOSSHVendor(StringIO(), output)
 
231
        medium = smart.SmartSSHClientMedium('a hostname', 'a port', 'a username',
 
232
            'a password', vendor)
 
233
        medium._accept_bytes('abc')
 
234
        self.assertEqual('abc', output.getvalue())
 
235
        self.assertEqual([('connect_ssh', 'a username', 'a password',
 
236
            'a hostname', 'a port',
 
237
            ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes'])],
 
238
            vendor.calls)
 
239
    
 
240
    def test_ssh_client_changes_command_when_BZR_REMOTE_PATH_is_set(self):
 
241
        # The only thing that initiates a connection from the medium is giving
 
242
        # it bytes.
 
243
        output = StringIO()
 
244
        vendor = StringIOSSHVendor(StringIO(), output)
 
245
        orig_bzr_remote_path = os.environ.get('BZR_REMOTE_PATH')
 
246
        def cleanup_environ():
 
247
            osutils.set_or_unset_env('BZR_REMOTE_PATH', orig_bzr_remote_path)
 
248
        self.addCleanup(cleanup_environ)
 
249
        os.environ['BZR_REMOTE_PATH'] = 'fugly'
 
250
        medium = smart.SmartSSHClientMedium('a hostname', 'a port', 'a username',
 
251
            'a password', vendor)
 
252
        medium._accept_bytes('abc')
 
253
        self.assertEqual('abc', output.getvalue())
 
254
        self.assertEqual([('connect_ssh', 'a username', 'a password',
 
255
            'a hostname', 'a port',
 
256
            ['fugly', 'serve', '--inet', '--directory=/', '--allow-writes'])],
 
257
            vendor.calls)
 
258
    
 
259
    def test_ssh_client_disconnect_does_so(self):
 
260
        # calling disconnect should disconnect both the read_from and write_to
 
261
        # file-like object it from the ssh connection.
 
262
        input = StringIO()
 
263
        output = StringIO()
 
264
        vendor = StringIOSSHVendor(input, output)
 
265
        medium = smart.SmartSSHClientMedium('a hostname', vendor=vendor)
 
266
        medium._accept_bytes('abc')
 
267
        medium.disconnect()
 
268
        self.assertTrue(input.closed)
 
269
        self.assertTrue(output.closed)
 
270
        self.assertEqual([
 
271
            ('connect_ssh', None, None, 'a hostname', None,
 
272
            ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
 
273
            ('close', ),
 
274
            ],
 
275
            vendor.calls)
 
276
 
 
277
    def test_ssh_client_disconnect_allows_reconnection(self):
 
278
        # calling disconnect on the client terminates the connection, but should
 
279
        # not prevent additional connections occuring.
 
280
        # we test this by initiating a second connection after doing a
 
281
        # disconnect.
 
282
        input = StringIO()
 
283
        output = StringIO()
 
284
        vendor = StringIOSSHVendor(input, output)
 
285
        medium = smart.SmartSSHClientMedium('a hostname', vendor=vendor)
 
286
        medium._accept_bytes('abc')
 
287
        medium.disconnect()
 
288
        # the disconnect has closed output, so we need a new output for the
 
289
        # new connection to write to.
 
290
        input2 = StringIO()
 
291
        output2 = StringIO()
 
292
        vendor.read_from = input2
 
293
        vendor.write_to = output2
 
294
        medium._accept_bytes('abc')
 
295
        medium.disconnect()
 
296
        self.assertTrue(input.closed)
 
297
        self.assertTrue(output.closed)
 
298
        self.assertTrue(input2.closed)
 
299
        self.assertTrue(output2.closed)
 
300
        self.assertEqual([
 
301
            ('connect_ssh', None, None, 'a hostname', None,
 
302
            ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
 
303
            ('close', ),
 
304
            ('connect_ssh', None, None, 'a hostname', None,
 
305
            ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
 
306
            ('close', ),
 
307
            ],
 
308
            vendor.calls)
 
309
    
 
310
    def test_ssh_client_ignores_disconnect_when_not_connected(self):
 
311
        # Doing a disconnect on a new (and thus unconnected) SSH medium
 
312
        # does not fail.  It's ok to disconnect an unconnected medium.
 
313
        medium = smart.SmartSSHClientMedium(None)
 
314
        medium.disconnect()
 
315
 
 
316
    def test_ssh_client_raises_on_read_when_not_connected(self):
 
317
        # Doing a read on a new (and thus unconnected) SSH medium raises
 
318
        # MediumNotConnected.
 
319
        medium = smart.SmartSSHClientMedium(None)
 
320
        self.assertRaises(errors.MediumNotConnected, medium.read_bytes, 0)
 
321
        self.assertRaises(errors.MediumNotConnected, medium.read_bytes, 1)
 
322
 
 
323
    def test_ssh_client_supports__flush(self):
 
324
        # invoking _flush on a SSHClientMedium should flush the output 
 
325
        # pipe. We test this by creating an output pipe that records
 
326
        # flush calls made to it.
 
327
        from StringIO import StringIO # get regular StringIO
 
328
        input = StringIO()
 
329
        output = StringIO()
 
330
        flush_calls = []
 
331
        def logging_flush(): flush_calls.append('flush')
 
332
        output.flush = logging_flush
 
333
        vendor = StringIOSSHVendor(input, output)
 
334
        medium = smart.SmartSSHClientMedium('a hostname', vendor=vendor)
 
335
        # this call is here to ensure we only flush once, not on every
 
336
        # _accept_bytes call.
 
337
        medium._accept_bytes('abc')
 
338
        medium._flush()
 
339
        medium.disconnect()
 
340
        self.assertEqual(['flush'], flush_calls)
 
341
        
 
342
    def test_construct_smart_tcp_client_medium(self):
 
343
        # the TCP client medium takes a host and a port.  Constructing it won't
 
344
        # connect to anything.
 
345
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 
346
        sock.bind(('127.0.0.1', 0))
 
347
        unopened_port = sock.getsockname()[1]
 
348
        medium = smart.SmartTCPClientMedium('127.0.0.1', unopened_port)
 
349
        sock.close()
 
350
 
 
351
    def test_tcp_client_connects_on_first_use(self):
 
352
        # The only thing that initiates a connection from the medium is giving
 
353
        # it bytes.
 
354
        sock, medium = self.make_loopsocket_and_medium()
 
355
        bytes = []
 
356
        t = self.receive_bytes_on_server(sock, bytes)
 
357
        medium.accept_bytes('abc')
 
358
        t.join()
 
359
        sock.close()
 
360
        self.assertEqual(['abc'], bytes)
 
361
    
 
362
    def test_tcp_client_disconnect_does_so(self):
 
363
        # calling disconnect on the client terminates the connection.
 
364
        # we test this by forcing a short read during a socket.MSG_WAITALL
 
365
        # call: write 2 bytes, try to read 3, and then the client disconnects.
 
366
        sock, medium = self.make_loopsocket_and_medium()
 
367
        bytes = []
 
368
        t = self.receive_bytes_on_server(sock, bytes)
 
369
        medium.accept_bytes('ab')
 
370
        medium.disconnect()
 
371
        t.join()
 
372
        sock.close()
 
373
        self.assertEqual(['ab'], bytes)
 
374
        # now disconnect again: this should not do anything, if disconnection
 
375
        # really did disconnect.
 
376
        medium.disconnect()
 
377
    
 
378
    def test_tcp_client_ignores_disconnect_when_not_connected(self):
 
379
        # Doing a disconnect on a new (and thus unconnected) TCP medium
 
380
        # does not fail.  It's ok to disconnect an unconnected medium.
 
381
        medium = smart.SmartTCPClientMedium(None, None)
 
382
        medium.disconnect()
 
383
 
 
384
    def test_tcp_client_raises_on_read_when_not_connected(self):
 
385
        # Doing a read on a new (and thus unconnected) TCP medium raises
 
386
        # MediumNotConnected.
 
387
        medium = smart.SmartTCPClientMedium(None, None)
 
388
        self.assertRaises(errors.MediumNotConnected, medium.read_bytes, 0)
 
389
        self.assertRaises(errors.MediumNotConnected, medium.read_bytes, 1)
 
390
 
 
391
    def test_tcp_client_supports__flush(self):
 
392
        # invoking _flush on a TCPClientMedium should do something useful.
 
393
        # RBC 20060922 not sure how to test/tell in this case.
 
394
        sock, medium = self.make_loopsocket_and_medium()
 
395
        bytes = []
 
396
        t = self.receive_bytes_on_server(sock, bytes)
 
397
        # try with nothing buffered
 
398
        medium._flush()
 
399
        medium._accept_bytes('ab')
 
400
        # and with something sent.
 
401
        medium._flush()
 
402
        medium.disconnect()
 
403
        t.join()
 
404
        sock.close()
 
405
        self.assertEqual(['ab'], bytes)
 
406
        # now disconnect again : this should not do anything, if disconnection
 
407
        # really did disconnect.
 
408
        medium.disconnect()
 
409
 
 
410
 
 
411
class TestSmartClientStreamMediumRequest(tests.TestCase):
 
412
    """Tests the for SmartClientStreamMediumRequest.
 
413
    
 
414
    SmartClientStreamMediumRequest is a helper for the three stream based 
 
415
    mediums: TCP, SSH, SimplePipes, so we only test it once, and then test that
 
416
    those three mediums implement the interface it expects.
 
417
    """
 
418
 
 
419
    def test_accept_bytes_after_finished_writing_errors(self):
 
420
        # calling accept_bytes after calling finished_writing raises 
 
421
        # WritingCompleted to prevent bad assumptions on stream environments
 
422
        # breaking the needs of message-based environments.
 
423
        output = StringIO()
 
424
        medium = smart.SmartSimplePipesClientMedium(None, output)
 
425
        request = smart.SmartClientStreamMediumRequest(medium)
 
426
        request.finished_writing()
 
427
        self.assertRaises(errors.WritingCompleted, request.accept_bytes, None)
 
428
 
 
429
    def test_accept_bytes(self):
 
430
        # accept bytes should invoke _accept_bytes on the stream medium.
 
431
        # we test this by using the SimplePipes medium - the most trivial one
 
432
        # and checking that the pipes get the data.
 
433
        input = StringIO()
 
434
        output = StringIO()
 
435
        medium = smart.SmartSimplePipesClientMedium(input, output)
 
436
        request = smart.SmartClientStreamMediumRequest(medium)
 
437
        request.accept_bytes('123')
 
438
        request.finished_writing()
 
439
        request.finished_reading()
 
440
        self.assertEqual('', input.getvalue())
 
441
        self.assertEqual('123', output.getvalue())
 
442
 
 
443
    def test_construct_sets_stream_request(self):
 
444
        # constructing a SmartClientStreamMediumRequest on a StreamMedium sets
 
445
        # the current request to the new SmartClientStreamMediumRequest
 
446
        output = StringIO()
 
447
        medium = smart.SmartSimplePipesClientMedium(None, output)
 
448
        request = smart.SmartClientStreamMediumRequest(medium)
 
449
        self.assertIs(medium._current_request, request)
 
450
 
 
451
    def test_construct_while_another_request_active_throws(self):
 
452
        # constructing a SmartClientStreamMediumRequest on a StreamMedium with
 
453
        # a non-None _current_request raises TooManyConcurrentRequests.
 
454
        output = StringIO()
 
455
        medium = smart.SmartSimplePipesClientMedium(None, output)
 
456
        medium._current_request = "a"
 
457
        self.assertRaises(errors.TooManyConcurrentRequests,
 
458
            smart.SmartClientStreamMediumRequest, medium)
 
459
 
 
460
    def test_finished_read_clears_current_request(self):
 
461
        # calling finished_reading clears the current request from the requests
 
462
        # medium
 
463
        output = StringIO()
 
464
        medium = smart.SmartSimplePipesClientMedium(None, output)
 
465
        request = smart.SmartClientStreamMediumRequest(medium)
 
466
        request.finished_writing()
 
467
        request.finished_reading()
 
468
        self.assertEqual(None, medium._current_request)
 
469
 
 
470
    def test_finished_read_before_finished_write_errors(self):
 
471
        # calling finished_reading before calling finished_writing triggers a
 
472
        # WritingNotComplete error.
 
473
        medium = smart.SmartSimplePipesClientMedium(None, None)
 
474
        request = smart.SmartClientStreamMediumRequest(medium)
 
475
        self.assertRaises(errors.WritingNotComplete, request.finished_reading)
 
476
        
 
477
    def test_read_bytes(self):
 
478
        # read bytes should invoke _read_bytes on the stream medium.
 
479
        # we test this by using the SimplePipes medium - the most trivial one
 
480
        # and checking that the data is supplied. Its possible that a 
 
481
        # faulty implementation could poke at the pipe variables them selves,
 
482
        # but we trust that this will be caught as it will break the integration
 
483
        # smoke tests.
 
484
        input = StringIO('321')
 
485
        output = StringIO()
 
486
        medium = smart.SmartSimplePipesClientMedium(input, output)
 
487
        request = smart.SmartClientStreamMediumRequest(medium)
 
488
        request.finished_writing()
 
489
        self.assertEqual('321', request.read_bytes(3))
 
490
        request.finished_reading()
 
491
        self.assertEqual('', input.read())
 
492
        self.assertEqual('', output.getvalue())
 
493
 
 
494
    def test_read_bytes_before_finished_write_errors(self):
 
495
        # calling read_bytes before calling finished_writing triggers a
 
496
        # WritingNotComplete error because the Smart protocol is designed to be
 
497
        # compatible with strict message based protocols like HTTP where the
 
498
        # request cannot be submitted until the writing has completed.
 
499
        medium = smart.SmartSimplePipesClientMedium(None, None)
 
500
        request = smart.SmartClientStreamMediumRequest(medium)
 
501
        self.assertRaises(errors.WritingNotComplete, request.read_bytes, None)
 
502
 
 
503
    def test_read_bytes_after_finished_reading_errors(self):
 
504
        # calling read_bytes after calling finished_reading raises 
 
505
        # ReadingCompleted to prevent bad assumptions on stream environments
 
506
        # breaking the needs of message-based environments.
 
507
        output = StringIO()
 
508
        medium = smart.SmartSimplePipesClientMedium(None, output)
 
509
        request = smart.SmartClientStreamMediumRequest(medium)
 
510
        request.finished_writing()
 
511
        request.finished_reading()
 
512
        self.assertRaises(errors.ReadingCompleted, request.read_bytes, None)
 
513
 
 
514
 
 
515
class RemoteTransportTests(tests.TestCaseWithTransport):
48
516
 
49
517
    def setUp(self):
50
 
        super(TCPClientTests, self).setUp()
 
518
        super(RemoteTransportTests, self).setUp()
51
519
        # We're allowed to set  the transport class here, so that we don't use
52
520
        # the default or a parameterized class, but rather use the
53
521
        # TestCaseWithTransport infrastructure to set up a smart server and
61
529
        t = self.get_transport()
62
530
        self.assertIsInstance(t, smart.SmartTransport)
63
531
 
64
 
    def test_get_client_from_transport(self):
 
532
    def test_get_medium_from_transport(self):
 
533
        """Remote transport has a medium always, which it can return."""
65
534
        t = self.get_transport()
66
 
        client = t.get_smart_client()
67
 
        self.assertIsInstance(client, smart.SmartStreamClient)
68
 
 
69
 
 
70
 
class BasicSmartTests(tests.TestCase):
 
535
        medium = t.get_smart_medium()
 
536
        self.assertIsInstance(medium, smart.SmartClientMedium)
 
537
 
 
538
 
 
539
class ErrorRaisingProtocol(object):
 
540
 
 
541
    def __init__(self, exception):
 
542
        self.exception = exception
 
543
 
 
544
    def next_read_size(self):
 
545
        raise self.exception
 
546
 
 
547
 
 
548
class SampleRequest(object):
 
549
    
 
550
    def __init__(self, expected_bytes):
 
551
        self.accepted_bytes = ''
 
552
        self._finished_reading = False
 
553
        self.expected_bytes = expected_bytes
 
554
        self.excess_buffer = ''
 
555
 
 
556
    def accept_bytes(self, bytes):
 
557
        self.accepted_bytes += bytes
 
558
        if self.accepted_bytes.startswith(self.expected_bytes):
 
559
            self._finished_reading = True
 
560
            self.excess_buffer = self.accepted_bytes[len(self.expected_bytes):]
 
561
 
 
562
    def next_read_size(self):
 
563
        if self._finished_reading:
 
564
            return 0
 
565
        else:
 
566
            return 1
 
567
 
 
568
 
 
569
class TestSmartServerStreamMedium(tests.TestCase):
 
570
 
 
571
    def portable_socket_pair(self):
 
572
        """Return a pair of TCP sockets connected to each other.
 
573
        
 
574
        Unlike socket.socketpair, this should work on Windows.
 
575
        """
 
576
        listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 
577
        listen_sock.bind(('127.0.0.1', 0))
 
578
        listen_sock.listen(1)
 
579
        client_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 
580
        client_sock.connect(listen_sock.getsockname())
 
581
        server_sock, addr = listen_sock.accept()
 
582
        listen_sock.close()
 
583
        return server_sock, client_sock
71
584
    
72
585
    def test_smart_query_version(self):
73
586
        """Feed a canned query version to a server"""
 
587
        # wire-to-wire, using the whole stack
74
588
        to_server = StringIO('hello\n')
75
589
        from_server = StringIO()
76
 
        server = smart.SmartStreamServer(to_server, from_server, 
77
 
            local.LocalTransport(urlutils.local_path_to_url('/')))
78
 
        server._serve_one_request()
 
590
        transport = local.LocalTransport(urlutils.local_path_to_url('/'))
 
591
        server = smart.SmartServerPipeStreamMedium(
 
592
            to_server, from_server, transport)
 
593
        protocol = smart.SmartServerRequestProtocolOne(transport,
 
594
                from_server.write)
 
595
        server._serve_one_request(protocol)
79
596
        self.assertEqual('ok\0011\n',
80
597
                         from_server.getvalue())
81
598
 
84
601
        transport.put_bytes('testfile', 'contents\nof\nfile\n')
85
602
        to_server = StringIO('get\001./testfile\n')
86
603
        from_server = StringIO()
87
 
        server = smart.SmartStreamServer(to_server, from_server, transport)
88
 
        server._serve_one_request()
 
604
        server = smart.SmartServerPipeStreamMedium(
 
605
            to_server, from_server, transport)
 
606
        protocol = smart.SmartServerRequestProtocolOne(transport,
 
607
                from_server.write)
 
608
        server._serve_one_request(protocol)
89
609
        self.assertEqual('ok\n'
90
610
                         '17\n'
91
611
                         'contents\nof\nfile\n'
92
612
                         'done\n',
93
613
                         from_server.getvalue())
94
614
 
 
615
    def test_pipe_like_stream_with_bulk_data(self):
 
616
        sample_request_bytes = 'command\n9\nbulk datadone\n'
 
617
        to_server = StringIO(sample_request_bytes)
 
618
        from_server = StringIO()
 
619
        server = smart.SmartServerPipeStreamMedium(to_server, from_server, None)
 
620
        sample_protocol = SampleRequest(expected_bytes=sample_request_bytes)
 
621
        server._serve_one_request(sample_protocol)
 
622
        self.assertEqual('', from_server.getvalue())
 
623
        self.assertEqual(sample_request_bytes, sample_protocol.accepted_bytes)
 
624
        self.assertFalse(server.finished)
 
625
 
 
626
    def test_socket_stream_with_bulk_data(self):
 
627
        sample_request_bytes = 'command\n9\nbulk datadone\n'
 
628
        server_sock, client_sock = self.portable_socket_pair()
 
629
        server = smart.SmartServerSocketStreamMedium(
 
630
            server_sock, None)
 
631
        sample_protocol = SampleRequest(expected_bytes=sample_request_bytes)
 
632
        client_sock.sendall(sample_request_bytes)
 
633
        server._serve_one_request(sample_protocol)
 
634
        server_sock.close()
 
635
        self.assertEqual('', client_sock.recv(1))
 
636
        self.assertEqual(sample_request_bytes, sample_protocol.accepted_bytes)
 
637
        self.assertFalse(server.finished)
 
638
 
 
639
    def test_pipe_like_stream_shutdown_detection(self):
 
640
        to_server = StringIO('')
 
641
        from_server = StringIO()
 
642
        server = smart.SmartServerPipeStreamMedium(to_server, from_server, None)
 
643
        server._serve_one_request(SampleRequest('x'))
 
644
        self.assertTrue(server.finished)
 
645
        
 
646
    def test_socket_stream_shutdown_detection(self):
 
647
        server_sock, client_sock = self.portable_socket_pair()
 
648
        client_sock.close()
 
649
        server = smart.SmartServerSocketStreamMedium(
 
650
            server_sock, None)
 
651
        server._serve_one_request(SampleRequest('x'))
 
652
        self.assertTrue(server.finished)
 
653
        
 
654
    def test_pipe_like_stream_with_two_requests(self):
 
655
        # If two requests are read in one go, then two calls to
 
656
        # _serve_one_request should still process both of them as if they had
 
657
        # been received seperately.
 
658
        sample_request_bytes = 'command\n'
 
659
        to_server = StringIO(sample_request_bytes * 2)
 
660
        from_server = StringIO()
 
661
        server = smart.SmartServerPipeStreamMedium(to_server, from_server, None)
 
662
        first_protocol = SampleRequest(expected_bytes=sample_request_bytes)
 
663
        server._serve_one_request(first_protocol)
 
664
        self.assertEqual(0, first_protocol.next_read_size())
 
665
        self.assertEqual('', from_server.getvalue())
 
666
        self.assertFalse(server.finished)
 
667
        # Make a new protocol, call _serve_one_request with it to collect the
 
668
        # second request.
 
669
        second_protocol = SampleRequest(expected_bytes=sample_request_bytes)
 
670
        server._serve_one_request(second_protocol)
 
671
        self.assertEqual('', from_server.getvalue())
 
672
        self.assertEqual(sample_request_bytes, second_protocol.accepted_bytes)
 
673
        self.assertFalse(server.finished)
 
674
        
 
675
    def test_socket_stream_with_two_requests(self):
 
676
        # If two requests are read in one go, then two calls to
 
677
        # _serve_one_request should still process both of them as if they had
 
678
        # been received seperately.
 
679
        sample_request_bytes = 'command\n'
 
680
        server_sock, client_sock = self.portable_socket_pair()
 
681
        server = smart.SmartServerSocketStreamMedium(
 
682
            server_sock, None)
 
683
        first_protocol = SampleRequest(expected_bytes=sample_request_bytes)
 
684
        # Put two whole requests on the wire.
 
685
        client_sock.sendall(sample_request_bytes * 2)
 
686
        server._serve_one_request(first_protocol)
 
687
        self.assertEqual(0, first_protocol.next_read_size())
 
688
        self.assertFalse(server.finished)
 
689
        # Make a new protocol, call _serve_one_request with it to collect the
 
690
        # second request.
 
691
        second_protocol = SampleRequest(expected_bytes=sample_request_bytes)
 
692
        stream_still_open = server._serve_one_request(second_protocol)
 
693
        self.assertEqual(sample_request_bytes, second_protocol.accepted_bytes)
 
694
        self.assertFalse(server.finished)
 
695
        server_sock.close()
 
696
        self.assertEqual('', client_sock.recv(1))
 
697
 
 
698
    def test_pipe_like_stream_error_handling(self):
 
699
        # Use plain python StringIO so we can monkey-patch the close method to
 
700
        # not discard the contents.
 
701
        from StringIO import StringIO
 
702
        to_server = StringIO('')
 
703
        from_server = StringIO()
 
704
        self.closed = False
 
705
        def close():
 
706
            self.closed = True
 
707
        from_server.close = close
 
708
        server = smart.SmartServerPipeStreamMedium(to_server, from_server, None)
 
709
        fake_protocol = ErrorRaisingProtocol(Exception('boom'))
 
710
        server._serve_one_request(fake_protocol)
 
711
        self.assertEqual('', from_server.getvalue())
 
712
        self.assertTrue(self.closed)
 
713
        self.assertTrue(server.finished)
 
714
        
 
715
    def test_socket_stream_error_handling(self):
 
716
        # Use plain python StringIO so we can monkey-patch the close method to
 
717
        # not discard the contents.
 
718
        from StringIO import StringIO
 
719
        server_sock, client_sock = self.portable_socket_pair()
 
720
        server = smart.SmartServerSocketStreamMedium(
 
721
            server_sock, None)
 
722
        fake_protocol = ErrorRaisingProtocol(Exception('boom'))
 
723
        server._serve_one_request(fake_protocol)
 
724
        # recv should not block, because the other end of the socket has been
 
725
        # closed.
 
726
        self.assertEqual('', client_sock.recv(1))
 
727
        self.assertTrue(server.finished)
 
728
        
 
729
    def test_pipe_like_stream_keyboard_interrupt_handling(self):
 
730
        # Use plain python StringIO so we can monkey-patch the close method to
 
731
        # not discard the contents.
 
732
        to_server = StringIO('')
 
733
        from_server = StringIO()
 
734
        server = smart.SmartServerPipeStreamMedium(to_server, from_server, None)
 
735
        fake_protocol = ErrorRaisingProtocol(KeyboardInterrupt('boom'))
 
736
        self.assertRaises(
 
737
            KeyboardInterrupt, server._serve_one_request, fake_protocol)
 
738
        self.assertEqual('', from_server.getvalue())
 
739
 
 
740
    def test_socket_stream_keyboard_interrupt_handling(self):
 
741
        server_sock, client_sock = self.portable_socket_pair()
 
742
        server = smart.SmartServerSocketStreamMedium(
 
743
            server_sock, None)
 
744
        fake_protocol = ErrorRaisingProtocol(KeyboardInterrupt('boom'))
 
745
        self.assertRaises(
 
746
            KeyboardInterrupt, server._serve_one_request, fake_protocol)
 
747
        server_sock.close()
 
748
        self.assertEqual('', client_sock.recv(1))
 
749
        
 
750
 
 
751
class TestSmartTCPServer(tests.TestCase):
 
752
 
95
753
    def test_get_error_unexpected(self):
96
754
        """Error reported by server with no specific representation"""
97
755
        class FlakyTransport(object):
110
768
        finally:
111
769
            server.stop_background_thread()
112
770
 
113
 
    def test_server_subprocess(self):
114
 
        """Talk to a server started as a subprocess
115
 
        
116
 
        This is similar to running it over ssh, except that it runs in the same machine 
117
 
        without ssh intermediating.
118
 
        """
119
 
        args = [sys.executable, sys.argv[0], 'serve', '--inet']
120
 
        do_close_fds = True
121
 
        if sys.platform == 'win32':
122
 
            do_close_fds = False
123
 
        child = subprocess.Popen(args, stdin=subprocess.PIPE, stdout=subprocess.PIPE,
124
 
                                 close_fds=do_close_fds, universal_newlines=True)
125
 
        conn = smart.SmartStreamClient(lambda: (child.stdout, child.stdin))
126
 
        conn.query_version()
127
 
        conn.query_version()
128
 
        conn.disconnect()
129
 
        returncode = child.wait()
130
 
        self.assertEquals(0, returncode)
131
 
 
132
771
 
133
772
class SmartTCPTests(tests.TestCase):
134
773
    """Tests for connection/end to end behaviour using the TCP server.
200
839
        # we create a real connection not a loopback one, but it will use the
201
840
        # same server and pipes
202
841
        conn2 = self.transport.clone('.')
203
 
        self.assertTrue(self.transport._client is conn2._client)
 
842
        self.assertIs(self.transport._medium, conn2._medium)
204
843
 
205
844
    def test__remote_path(self):
206
845
        self.assertEquals('/foo/bar',
246
885
            'foo')
247
886
        
248
887
 
249
 
class SmartServerTests(tests.TestCaseWithTransport):
250
 
    """Test that call directly into the server logic, bypassing the network."""
 
888
class SmartServerRequestHandlerTests(tests.TestCaseWithTransport):
 
889
    """Test that call directly into the handler logic, bypassing the network."""
 
890
 
 
891
    def test_construct_request_handler(self):
 
892
        """Constructing a request handler should be easy and set defaults."""
 
893
        handler = smart.SmartServerRequestHandler(None)
 
894
        self.assertFalse(handler.finished_reading)
251
895
 
252
896
    def test_hello(self):
253
 
        server = smart.SmartServer(None)
254
 
        response = server.dispatch_command('hello', ())
255
 
        self.assertEqual(('ok', '1'), response.args)
256
 
        self.assertEqual(None, response.body)
 
897
        handler = smart.SmartServerRequestHandler(None)
 
898
        handler.dispatch_command('hello', ())
 
899
        self.assertEqual(('ok', '1'), handler.response.args)
 
900
        self.assertEqual(None, handler.response.body)
257
901
        
258
902
    def test_get_bundle(self):
259
903
        from bzrlib.bundle import serializer
262
906
        wt.add('hello')
263
907
        rev_id = wt.commit('add hello')
264
908
        
265
 
        server = smart.SmartServer(self.get_transport())
266
 
        response = server.dispatch_command('get_bundle', ('.', rev_id))
267
 
        bundle = serializer.read_bundle(StringIO(response.body))
 
909
        handler = smart.SmartServerRequestHandler(self.get_transport())
 
910
        handler.dispatch_command('get_bundle', ('.', rev_id))
 
911
        bundle = serializer.read_bundle(StringIO(handler.response.body))
 
912
        self.assertEqual((), handler.response.args)
268
913
 
269
914
    def test_readonly_exception_becomes_transport_not_possible(self):
270
915
        """The response for a read-only error is ('ReadOnlyError')."""
271
 
        server = smart.SmartServer(self.get_readonly_transport())
 
916
        handler = smart.SmartServerRequestHandler(self.get_readonly_transport())
272
917
        # send a mkdir for foo, with no explicit mode - should fail.
273
 
        response = server.dispatch_command('mkdir', ('foo', ''))
 
918
        handler.dispatch_command('mkdir', ('foo', ''))
274
919
        # and the failure should be an explicit ReadOnlyError
275
 
        self.assertEqual(("ReadOnlyError", ), response.args)
 
920
        self.assertEqual(("ReadOnlyError", ), handler.response.args)
276
921
        # XXX: TODO: test that other TransportNotPossible errors are
277
922
        # presented as TransportNotPossible - not possible to do that
278
923
        # until I figure out how to trigger that relatively cleanly via
279
924
        # the api. RBC 20060918
280
925
 
 
926
    def test_hello_has_finished_body_on_dispatch(self):
 
927
        """The 'hello' command should set finished_reading."""
 
928
        handler = smart.SmartServerRequestHandler(None)
 
929
        handler.dispatch_command('hello', ())
 
930
        self.assertTrue(handler.finished_reading)
 
931
        self.assertNotEqual(None, handler.response)
 
932
 
 
933
    def test_put_bytes_non_atomic(self):
 
934
        """'put_...' should set finished_reading after reading the bytes."""
 
935
        handler = smart.SmartServerRequestHandler(self.get_transport())
 
936
        handler.dispatch_command('put_non_atomic', ('a-file', '', 'F', ''))
 
937
        self.assertFalse(handler.finished_reading)
 
938
        handler.accept_body('1234')
 
939
        self.assertFalse(handler.finished_reading)
 
940
        handler.accept_body('5678')
 
941
        handler.end_of_body()
 
942
        self.assertTrue(handler.finished_reading)
 
943
        self.assertEqual(('ok', ), handler.response.args)
 
944
        self.assertEqual(None, handler.response.body)
 
945
        
 
946
    def test_readv_accept_body(self):
 
947
        """'readv' should set finished_reading after reading offsets."""
 
948
        self.build_tree(['a-file'])
 
949
        handler = smart.SmartServerRequestHandler(self.get_readonly_transport())
 
950
        handler.dispatch_command('readv', ('a-file', ))
 
951
        self.assertFalse(handler.finished_reading)
 
952
        handler.accept_body('2,')
 
953
        self.assertFalse(handler.finished_reading)
 
954
        handler.accept_body('3')
 
955
        handler.end_of_body()
 
956
        self.assertTrue(handler.finished_reading)
 
957
        self.assertEqual(('readv', ), handler.response.args)
 
958
        # co - nte - nt of a-file is the file contents we are extracting from.
 
959
        self.assertEqual('nte', handler.response.body)
 
960
 
 
961
    def test_readv_short_read_response_contents(self):
 
962
        """'readv' when a short read occurs sets the response appropriately."""
 
963
        self.build_tree(['a-file'])
 
964
        handler = smart.SmartServerRequestHandler(self.get_readonly_transport())
 
965
        handler.dispatch_command('readv', ('a-file', ))
 
966
        # read beyond the end of the file.
 
967
        handler.accept_body('100,1')
 
968
        handler.end_of_body()
 
969
        self.assertTrue(handler.finished_reading)
 
970
        self.assertEqual(('ShortReadvError', 'a-file', '100', '1', '0'),
 
971
            handler.response.args)
 
972
        self.assertEqual(None, handler.response.body)
 
973
 
281
974
 
282
975
class SmartTransportRegistration(tests.TestCase):
283
976
 
287
980
        self.assertEqual('example.com', t._host)
288
981
 
289
982
 
290
 
class FakeClient(smart.SmartStreamClient):
291
 
    """Emulate a client for testing a transport's use of the client."""
292
 
 
293
 
    def __init__(self):
294
 
        smart.SmartStreamClient.__init__(self, None)
295
 
        self._calls = []
296
 
 
297
 
    def _call(self, *args):
298
 
        self._calls.append(('_call', args))
299
 
        return ('ok', )
300
 
 
301
 
    def _recv_bulk(self):
302
 
        return 'bar'
303
 
 
304
 
 
305
983
class TestSmartTransport(tests.TestCase):
306
984
        
307
985
    def test_use_connection_factory(self):
308
986
        # We want to be able to pass a client as a parameter to SmartTransport.
309
 
        client = FakeClient()
310
 
        transport = smart.SmartTransport('bzr://localhost/', client=client)
 
987
        input = StringIO("ok\n3\nbardone\n")
 
988
        output = StringIO()
 
989
        medium = smart.SmartSimplePipesClientMedium(input, output)
 
990
        transport = smart.SmartTransport('bzr://localhost/', medium=medium)
311
991
 
312
992
        # We want to make sure the client is used when the first remote
313
 
        # method is called.  No method should have been called yet.
314
 
        self.assertEqual([], client._calls)
 
993
        # method is called.  No data should have been sent, or read.
 
994
        self.assertEqual(0, input.tell())
 
995
        self.assertEqual('', output.getvalue())
315
996
 
316
 
        # Now call a method that should result in a single request.
 
997
        # Now call a method that should result in a single request : as the
 
998
        # transport makes its own protocol instances, we check on the wire.
 
999
        # XXX: TODO: give the transport a protocol factory, which can make
 
1000
        # an instrumented protocol for us.
317
1001
        self.assertEqual('bar', transport.get_bytes('foo'))
318
 
        # The only call to _call should have been to get /foo.
319
 
        self.assertEqual([('_call', ('get', '/foo'))], client._calls)
 
1002
        # only the needed data should have been sent/received.
 
1003
        self.assertEqual(13, input.tell())
 
1004
        self.assertEqual('get\x01/foo\n', output.getvalue())
320
1005
 
321
1006
    def test__translate_error_readonly(self):
322
1007
        """Sending a ReadOnlyError to _translate_error raises TransportNotPossible."""
323
 
        client = FakeClient()
324
 
        transport = smart.SmartTransport('bzr://localhost/', client=client)
 
1008
        medium = smart.SmartClientMedium()
 
1009
        transport = smart.SmartTransport('bzr://localhost/', medium=medium)
325
1010
        self.assertRaises(errors.TransportNotPossible,
326
1011
            transport._translate_error, ("ReadOnlyError", ))
327
1012
 
328
1013
 
329
 
class InstrumentedClient(smart.SmartStreamClient):
330
 
    """A smart client whose writes are stored to a supplied list."""
331
 
 
332
 
    def __init__(self, write_output_list):
333
 
        smart.SmartStreamClient.__init__(self, None)
334
 
        self._write_output_list = write_output_list
335
 
 
336
 
    def _ensure_connection(self):
337
 
        """We are never strictly connected."""
338
 
 
339
 
    def _write_and_flush(self, bytes):
340
 
        self._write_output_list.append(bytes)
341
 
 
342
 
 
343
 
class InstrumentedServerProtocol(smart.SmartStreamServer):
 
1014
class InstrumentedServerProtocol(smart.SmartServerStreamMedium):
344
1015
    """A smart server which is backed by memory and saves its write requests."""
345
1016
 
346
1017
    def __init__(self, write_output_list):
347
 
        smart.SmartStreamServer.__init__(self, None, None,
348
 
            memory.MemoryTransport())
 
1018
        smart.SmartServerStreamMedium.__init__(self, memory.MemoryTransport())
349
1019
        self._write_output_list = write_output_list
350
1020
 
351
 
    def _write_and_flush(self, bytes):
352
 
        self._write_output_list.append(bytes)
353
 
 
354
1021
 
355
1022
class TestSmartProtocol(tests.TestCase):
356
1023
    """Tests for the smart protocol.
367
1034
 
368
1035
    def setUp(self):
369
1036
        super(TestSmartProtocol, self).setUp()
370
 
        self.to_server = []
371
 
        self.to_client = []
372
 
        self.smart_client = InstrumentedClient(self.to_server)
373
 
        self.smart_server = InstrumentedServerProtocol(self.to_client)
 
1037
        self.server_to_client = []
 
1038
        self.to_server = StringIO()
 
1039
        self.to_client = StringIO()
 
1040
        self.client_medium = smart.SmartSimplePipesClientMedium(self.to_client,
 
1041
            self.to_server)
 
1042
        self.client_protocol = smart.SmartClientRequestProtocolOne(
 
1043
            self.client_medium)
 
1044
        self.smart_server = InstrumentedServerProtocol(self.server_to_client)
 
1045
        self.smart_server_request = smart.SmartServerRequestHandler(None)
374
1046
 
375
1047
    def assertOffsetSerialisation(self, expected_offsets, expected_serialised,
376
 
        client, server_protocol):
 
1048
        client, smart_server_request):
377
1049
        """Check that smart (de)serialises offsets as expected.
378
1050
        
379
1051
        We check both serialisation and deserialisation at the same time
382
1054
        
383
1055
        :param expected_offsets: a readv offset list.
384
1056
        :param expected_seralised: an expected serial form of the offsets.
385
 
        :param server: a SmartServer instance.
 
1057
        :param smart_server_request: a SmartServerRequestHandler instance.
386
1058
        """
387
 
        offsets = server_protocol.smart_server._deserialise_offsets(
388
 
            expected_serialised)
 
1059
        # XXX: 'smart_server_request' should be a SmartServerRequestProtocol in
 
1060
        # future.
 
1061
        offsets = smart_server_request._deserialise_offsets(expected_serialised)
389
1062
        self.assertEqual(expected_offsets, offsets)
390
1063
        serialised = client._serialise_offsets(offsets)
391
1064
        self.assertEqual(expected_serialised, serialised)
392
1065
 
 
1066
    def build_protocol_waiting_for_body(self):
 
1067
        out_stream = StringIO()
 
1068
        protocol = smart.SmartServerRequestProtocolOne(None, out_stream.write)
 
1069
        protocol.has_dispatched = True
 
1070
        protocol.request = smart.SmartServerRequestHandler(None)
 
1071
        def handle_end_of_bytes():
 
1072
            self.end_received = True
 
1073
            self.assertEqual('abcdefg', protocol.request._body_bytes)
 
1074
            protocol.request.response = smart.SmartServerResponse(('ok', ))
 
1075
        protocol.request._end_of_body_handler = handle_end_of_bytes
 
1076
        # Call accept_bytes to make sure that internal state like _body_decoder
 
1077
        # is initialised.  This test should probably be given a clearer
 
1078
        # interface to work with that will not cause this inconsistency.
 
1079
        #   -- Andrew Bennetts, 2006-09-28
 
1080
        protocol.accept_bytes('')
 
1081
        return protocol
 
1082
 
 
1083
    def test_construct_version_one_server_protocol(self):
 
1084
        protocol = smart.SmartServerRequestProtocolOne(None, None)
 
1085
        self.assertEqual('', protocol.excess_buffer)
 
1086
        self.assertEqual('', protocol.in_buffer)
 
1087
        self.assertFalse(protocol.has_dispatched)
 
1088
        self.assertEqual(1, protocol.next_read_size())
 
1089
 
 
1090
    def test_construct_version_one_client_protocol(self):
 
1091
        # we can construct a client protocol from a client medium request
 
1092
        output = StringIO()
 
1093
        medium = smart.SmartSimplePipesClientMedium(None, output)
 
1094
        request = medium.get_request()
 
1095
        client_protocol = smart.SmartClientRequestProtocolOne(request)
 
1096
 
393
1097
    def test_server_offset_serialisation(self):
394
1098
        """The Smart protocol serialises offsets as a comma and \n string.
395
1099
 
398
1102
        one that should coalesce.
399
1103
        """
400
1104
        self.assertOffsetSerialisation([], '',
401
 
            self.smart_client, self.smart_server)
 
1105
            self.client_protocol, self.smart_server_request)
402
1106
        self.assertOffsetSerialisation([(1,2)], '1,2',
403
 
            self.smart_client, self.smart_server)
 
1107
            self.client_protocol, self.smart_server_request)
404
1108
        self.assertOffsetSerialisation([(10,40), (0,5)], '10,40\n0,5',
405
 
            self.smart_client, self.smart_server)
 
1109
            self.client_protocol, self.smart_server_request)
406
1110
        self.assertOffsetSerialisation([(1,2), (3,4), (100, 200)],
407
 
            '1,2\n3,4\n100,200', self.smart_client, self.smart_server)
408
 
 
409
 
 
 
1111
            '1,2\n3,4\n100,200', self.client_protocol, self.smart_server_request)
 
1112
 
 
1113
    def test_accept_bytes_to_protocol(self):
 
1114
        out_stream = StringIO()
 
1115
        protocol = smart.SmartServerRequestProtocolOne(None, out_stream.write)
 
1116
        protocol.accept_bytes('abc')
 
1117
        self.assertEqual('abc', protocol.in_buffer)
 
1118
        protocol.accept_bytes('\n')
 
1119
        self.assertEqual("error\x01Generic bzr smart protocol error: bad request"
 
1120
            " u'abc'\n", out_stream.getvalue())
 
1121
        self.assertTrue(protocol.has_dispatched)
 
1122
        self.assertEqual(1, protocol.next_read_size())
 
1123
 
 
1124
    def test_accept_bytes_with_invalid_utf8_to_protocol(self):
 
1125
        out_stream = StringIO()
 
1126
        protocol = smart.SmartServerRequestProtocolOne(None, out_stream.write)
 
1127
        # the byte 0xdd is not a valid UTF-8 string.
 
1128
        protocol.accept_bytes('\xdd\n')
 
1129
        self.assertEqual(
 
1130
            "error\x01Generic bzr smart protocol error: "
 
1131
            "one or more arguments of request '\\xdd\\n' are not valid UTF-8\n",
 
1132
            out_stream.getvalue())
 
1133
        self.assertTrue(protocol.has_dispatched)
 
1134
        self.assertEqual(1, protocol.next_read_size())
 
1135
 
 
1136
    def test_accept_body_bytes_to_protocol(self):
 
1137
        protocol = self.build_protocol_waiting_for_body()
 
1138
        self.assertEqual(6, protocol.next_read_size())
 
1139
        protocol.accept_bytes('7\nabc')
 
1140
        self.assertEqual(9, protocol.next_read_size())
 
1141
        protocol.accept_bytes('defgd')
 
1142
        protocol.accept_bytes('one\n')
 
1143
        self.assertEqual(0, protocol.next_read_size())
 
1144
        self.assertTrue(self.end_received)
 
1145
 
 
1146
    def test_accept_request_and_body_all_at_once(self):
 
1147
        mem_transport = memory.MemoryTransport()
 
1148
        mem_transport.put_bytes('foo', 'abcdefghij')
 
1149
        out_stream = StringIO()
 
1150
        protocol = smart.SmartServerRequestProtocolOne(mem_transport,
 
1151
                out_stream.write)
 
1152
        protocol.accept_bytes('readv\x01foo\n3\n3,3done\n')
 
1153
        self.assertEqual(0, protocol.next_read_size())
 
1154
        self.assertEqual('readv\n3\ndefdone\n', out_stream.getvalue())
 
1155
        self.assertEqual('', protocol.excess_buffer)
 
1156
        self.assertEqual('', protocol.in_buffer)
 
1157
 
 
1158
    def test_accept_excess_bytes_are_preserved(self):
 
1159
        out_stream = StringIO()
 
1160
        protocol = smart.SmartServerRequestProtocolOne(None, out_stream.write)
 
1161
        protocol.accept_bytes('hello\nhello\n')
 
1162
        self.assertEqual("ok\x011\n", out_stream.getvalue())
 
1163
        self.assertEqual("hello\n", protocol.excess_buffer)
 
1164
        self.assertEqual("", protocol.in_buffer)
 
1165
 
 
1166
    def test_accept_excess_bytes_after_body(self):
 
1167
        protocol = self.build_protocol_waiting_for_body()
 
1168
        protocol.accept_bytes('7\nabcdefgdone\nX')
 
1169
        self.assertTrue(self.end_received)
 
1170
        self.assertEqual("X", protocol.excess_buffer)
 
1171
        self.assertEqual("", protocol.in_buffer)
 
1172
        protocol.accept_bytes('Y')
 
1173
        self.assertEqual("XY", protocol.excess_buffer)
 
1174
        self.assertEqual("", protocol.in_buffer)
 
1175
 
 
1176
    def test_accept_excess_bytes_after_dispatch(self):
 
1177
        out_stream = StringIO()
 
1178
        protocol = smart.SmartServerRequestProtocolOne(None, out_stream.write)
 
1179
        protocol.accept_bytes('hello\n')
 
1180
        self.assertEqual("ok\x011\n", out_stream.getvalue())
 
1181
        protocol.accept_bytes('hel')
 
1182
        self.assertEqual("hel", protocol.excess_buffer)
 
1183
        protocol.accept_bytes('lo\n')
 
1184
        self.assertEqual("hello\n", protocol.excess_buffer)
 
1185
        self.assertEqual("", protocol.in_buffer)
 
1186
 
 
1187
    def test_sync_with_request_sets_finished_reading(self):
 
1188
        protocol = smart.SmartServerRequestProtocolOne(None, None)
 
1189
        request = smart.SmartServerRequestHandler(None)
 
1190
        protocol.sync_with_request(request)
 
1191
        self.assertEqual(1, protocol.next_read_size())
 
1192
        request.finished_reading = True
 
1193
        protocol.sync_with_request(request)
 
1194
        self.assertEqual(0, protocol.next_read_size())
 
1195
 
 
1196
    def test_query_version(self):
 
1197
        """query_version on a SmartClientProtocolOne should return a number.
 
1198
        
 
1199
        The protocol provides the query_version because the domain level clients
 
1200
        may all need to be able to probe for capabilities.
 
1201
        """
 
1202
        # What we really want to test here is that SmartClientProtocolOne calls
 
1203
        # accept_bytes(tuple_based_encoding_of_hello) and reads and parses the
 
1204
        # response of tuple-encoded (ok, 1).  Also, seperately we should test
 
1205
        # the error if the response is a non-understood version.
 
1206
        input = StringIO('ok\x011\n')
 
1207
        output = StringIO()
 
1208
        medium = smart.SmartSimplePipesClientMedium(input, output)
 
1209
        protocol = smart.SmartClientRequestProtocolOne(medium.get_request())
 
1210
        self.assertEqual(1, protocol.query_version())
 
1211
 
 
1212
    def assertServerToClientEncoding(self, expected_bytes, expected_tuple,
 
1213
            input_tuples):
 
1214
        """Assert that each input_tuple serialises as expected_bytes, and the
 
1215
        bytes deserialise as expected_tuple.
 
1216
        """
 
1217
        # check the encoding of the server for all input_tuples matches
 
1218
        # expected bytes
 
1219
        for input_tuple in input_tuples:
 
1220
            server_output = StringIO()
 
1221
            server_protocol = smart.SmartServerRequestProtocolOne(
 
1222
                None, server_output.write)
 
1223
            server_protocol._send_response(input_tuple)
 
1224
            self.assertEqual(expected_bytes, server_output.getvalue())
 
1225
        # check the decoding of the client protocol from expected_bytes:
 
1226
        input = StringIO(expected_bytes)
 
1227
        output = StringIO()
 
1228
        medium = smart.SmartSimplePipesClientMedium(input, output)
 
1229
        protocol = smart.SmartClientRequestProtocolOne(medium.get_request())
 
1230
        protocol.call('foo')
 
1231
        self.assertEqual(expected_tuple, protocol.read_response_tuple())
 
1232
 
 
1233
    def test_client_call_empty_response(self):
 
1234
        # protocol.call() can get back an empty tuple as a response. This occurs
 
1235
        # when the parsed line is an empty line, and results in a tuple with
 
1236
        # one element - an empty string.
 
1237
        self.assertServerToClientEncoding('\n', ('', ), [(), ('', )])
 
1238
 
 
1239
    def test_client_call_three_element_response(self):
 
1240
        # protocol.call() can get back tuples of other lengths. A three element
 
1241
        # tuple should be unpacked as three strings.
 
1242
        self.assertServerToClientEncoding('a\x01b\x0134\n', ('a', 'b', '34'),
 
1243
            [('a', 'b', '34')])
 
1244
 
 
1245
    def test_client_call_with_body_bytes_uploads(self):
 
1246
        # protocol.call_with_upload should length-prefix the bytes onto the 
 
1247
        # wire.
 
1248
        expected_bytes = "foo\n7\nabcdefgdone\n"
 
1249
        input = StringIO("\n")
 
1250
        output = StringIO()
 
1251
        medium = smart.SmartSimplePipesClientMedium(input, output)
 
1252
        protocol = smart.SmartClientRequestProtocolOne(medium.get_request())
 
1253
        protocol.call_with_body_bytes(('foo', ), "abcdefg")
 
1254
        self.assertEqual(expected_bytes, output.getvalue())
 
1255
 
 
1256
    def test_client_call_with_body_readv_array(self):
 
1257
        # protocol.call_with_upload should encode the readv array and then
 
1258
        # length-prefix the bytes onto the wire.
 
1259
        expected_bytes = "foo\n7\n1,2\n5,6done\n"
 
1260
        input = StringIO("\n")
 
1261
        output = StringIO()
 
1262
        medium = smart.SmartSimplePipesClientMedium(input, output)
 
1263
        protocol = smart.SmartClientRequestProtocolOne(medium.get_request())
 
1264
        protocol.call_with_body_readv_array(('foo', ), [(1,2),(5,6)])
 
1265
        self.assertEqual(expected_bytes, output.getvalue())
 
1266
 
 
1267
    def test_client_read_body_bytes_all(self):
 
1268
        # read_body_bytes should decode the body bytes from the wire into
 
1269
        # a response.
 
1270
        expected_bytes = "1234567"
 
1271
        server_bytes = "ok\n7\n1234567done\n"
 
1272
        input = StringIO(server_bytes)
 
1273
        output = StringIO()
 
1274
        medium = smart.SmartSimplePipesClientMedium(input, output)
 
1275
        protocol = smart.SmartClientRequestProtocolOne(medium.get_request())
 
1276
        protocol.call('foo')
 
1277
        protocol.read_response_tuple(True)
 
1278
        self.assertEqual(expected_bytes, protocol.read_body_bytes())
 
1279
 
 
1280
    def test_client_read_body_bytes_incremental(self):
 
1281
        # test reading a few bytes at a time from the body
 
1282
        # XXX: possibly we should test dribbling the bytes into the stringio
 
1283
        # to make the state machine work harder: however, as we use the
 
1284
        # LengthPrefixedBodyDecoder that is already well tested - we can skip
 
1285
        # that.
 
1286
        expected_bytes = "1234567"
 
1287
        server_bytes = "ok\n7\n1234567done\n"
 
1288
        input = StringIO(server_bytes)
 
1289
        output = StringIO()
 
1290
        medium = smart.SmartSimplePipesClientMedium(input, output)
 
1291
        protocol = smart.SmartClientRequestProtocolOne(medium.get_request())
 
1292
        protocol.call('foo')
 
1293
        protocol.read_response_tuple(True)
 
1294
        self.assertEqual(expected_bytes[0:2], protocol.read_body_bytes(2))
 
1295
        self.assertEqual(expected_bytes[2:4], protocol.read_body_bytes(2))
 
1296
        self.assertEqual(expected_bytes[4:6], protocol.read_body_bytes(2))
 
1297
        self.assertEqual(expected_bytes[6], protocol.read_body_bytes())
 
1298
 
 
1299
    def test_client_cancel_read_body_does_not_eat_body_bytes(self):
 
1300
        # cancelling the expected body needs to finish the request, but not
 
1301
        # read any more bytes.
 
1302
        expected_bytes = "1234567"
 
1303
        server_bytes = "ok\n7\n1234567done\n"
 
1304
        input = StringIO(server_bytes)
 
1305
        output = StringIO()
 
1306
        medium = smart.SmartSimplePipesClientMedium(input, output)
 
1307
        protocol = smart.SmartClientRequestProtocolOne(medium.get_request())
 
1308
        protocol.call('foo')
 
1309
        protocol.read_response_tuple(True)
 
1310
        protocol.cancel_read_body()
 
1311
        self.assertEqual(3, input.tell())
 
1312
        self.assertRaises(errors.ReadingCompleted, protocol.read_body_bytes)
 
1313
 
 
1314
 
 
1315
class LengthPrefixedBodyDecoder(tests.TestCase):
 
1316
 
 
1317
    # XXX: TODO: make accept_reading_trailer invoke translate_response or 
 
1318
    # something similar to the ProtocolBase method.
 
1319
 
 
1320
    def test_construct(self):
 
1321
        decoder = smart.LengthPrefixedBodyDecoder()
 
1322
        self.assertFalse(decoder.finished_reading)
 
1323
        self.assertEqual(6, decoder.next_read_size())
 
1324
        self.assertEqual('', decoder.read_pending_data())
 
1325
        self.assertEqual('', decoder.unused_data)
 
1326
 
 
1327
    def test_accept_bytes(self):
 
1328
        decoder = smart.LengthPrefixedBodyDecoder()
 
1329
        decoder.accept_bytes('')
 
1330
        self.assertFalse(decoder.finished_reading)
 
1331
        self.assertEqual(6, decoder.next_read_size())
 
1332
        self.assertEqual('', decoder.read_pending_data())
 
1333
        self.assertEqual('', decoder.unused_data)
 
1334
        decoder.accept_bytes('7')
 
1335
        self.assertFalse(decoder.finished_reading)
 
1336
        self.assertEqual(6, decoder.next_read_size())
 
1337
        self.assertEqual('', decoder.read_pending_data())
 
1338
        self.assertEqual('', decoder.unused_data)
 
1339
        decoder.accept_bytes('\na')
 
1340
        self.assertFalse(decoder.finished_reading)
 
1341
        self.assertEqual(11, decoder.next_read_size())
 
1342
        self.assertEqual('a', decoder.read_pending_data())
 
1343
        self.assertEqual('', decoder.unused_data)
 
1344
        decoder.accept_bytes('bcdefgd')
 
1345
        self.assertFalse(decoder.finished_reading)
 
1346
        self.assertEqual(4, decoder.next_read_size())
 
1347
        self.assertEqual('bcdefg', decoder.read_pending_data())
 
1348
        self.assertEqual('', decoder.unused_data)
 
1349
        decoder.accept_bytes('one')
 
1350
        self.assertFalse(decoder.finished_reading)
 
1351
        self.assertEqual(1, decoder.next_read_size())
 
1352
        self.assertEqual('', decoder.read_pending_data())
 
1353
        self.assertEqual('', decoder.unused_data)
 
1354
        decoder.accept_bytes('\nblarg')
 
1355
        self.assertTrue(decoder.finished_reading)
 
1356
        self.assertEqual(1, decoder.next_read_size())
 
1357
        self.assertEqual('', decoder.read_pending_data())
 
1358
        self.assertEqual('blarg', decoder.unused_data)
 
1359
        
 
1360
    def test_accept_bytes_all_at_once_with_excess(self):
 
1361
        decoder = smart.LengthPrefixedBodyDecoder()
 
1362
        decoder.accept_bytes('1\nadone\nunused')
 
1363
        self.assertTrue(decoder.finished_reading)
 
1364
        self.assertEqual(1, decoder.next_read_size())
 
1365
        self.assertEqual('a', decoder.read_pending_data())
 
1366
        self.assertEqual('unused', decoder.unused_data)
 
1367
 
 
1368
    def test_accept_bytes_exact_end_of_body(self):
 
1369
        decoder = smart.LengthPrefixedBodyDecoder()
 
1370
        decoder.accept_bytes('1\na')
 
1371
        self.assertFalse(decoder.finished_reading)
 
1372
        self.assertEqual(5, decoder.next_read_size())
 
1373
        self.assertEqual('a', decoder.read_pending_data())
 
1374
        self.assertEqual('', decoder.unused_data)
 
1375
        decoder.accept_bytes('done\n')
 
1376
        self.assertTrue(decoder.finished_reading)
 
1377
        self.assertEqual(1, decoder.next_read_size())
 
1378
        self.assertEqual('', decoder.read_pending_data())
 
1379
        self.assertEqual('', decoder.unused_data)
 
1380
 
 
1381
 
 
1382
class FakeHTTPMedium(object):
 
1383
    def __init__(self):
 
1384
        self.written_request = None
 
1385
        self._current_request = None
 
1386
    def send_http_smart_request(self, bytes):
 
1387
        self.written_request = bytes
 
1388
        return None
 
1389
 
 
1390
 
 
1391
class HTTPTunnellingSmokeTest(tests.TestCaseWithTransport):
 
1392
    
 
1393
    def _test_bulk_data(self, url_protocol):
 
1394
        # We should be able to send and receive bulk data in a single message.
 
1395
        # The 'readv' command in the smart protocol both sends and receives bulk
 
1396
        # data, so we use that.
 
1397
        self.build_tree(['data-file'])
 
1398
        http_server = HTTPServerWithSmarts()
 
1399
        http_server._url_protocol = url_protocol
 
1400
        http_server.setUp()
 
1401
        self.addCleanup(http_server.tearDown)
 
1402
 
 
1403
        http_transport = get_transport(http_server.get_url())
 
1404
 
 
1405
        medium = http_transport.get_smart_medium()
 
1406
        #remote_transport = RemoteTransport('fake_url', medium)
 
1407
        remote_transport = smart.SmartTransport('/', medium=medium)
 
1408
        self.assertEqual(
 
1409
            [(0, "c")], list(remote_transport.readv("data-file", [(0,1)])))
 
1410
 
 
1411
    def test_bulk_data_pycurl(self):
 
1412
        try:
 
1413
            self._test_bulk_data('http+pycurl')
 
1414
        except errors.UnsupportedProtocol, e:
 
1415
            raise tests.TestSkipped(str(e))
 
1416
    
 
1417
    def test_bulk_data_urllib(self):
 
1418
        self._test_bulk_data('http+urllib')
 
1419
 
 
1420
    def test_smart_http_medium_request_accept_bytes(self):
 
1421
        medium = FakeHTTPMedium()
 
1422
        request = SmartClientHTTPMediumRequest(medium)
 
1423
        request.accept_bytes('abc')
 
1424
        request.accept_bytes('def')
 
1425
        self.assertEqual(None, medium.written_request)
 
1426
        request.finished_writing()
 
1427
        self.assertEqual('abcdef', medium.written_request)
 
1428
 
 
1429
    def _test_http_send_smart_request(self, url_protocol):
 
1430
        http_server = HTTPServerWithSmarts()
 
1431
        http_server._url_protocol = url_protocol
 
1432
        http_server.setUp()
 
1433
        self.addCleanup(http_server.tearDown)
 
1434
 
 
1435
        post_body = 'hello\n'
 
1436
        expected_reply_body = 'ok\x011\n'
 
1437
 
 
1438
        http_transport = get_transport(http_server.get_url())
 
1439
        medium = http_transport.get_smart_medium()
 
1440
        response = medium.send_http_smart_request(post_body)
 
1441
        reply_body = response.read()
 
1442
        self.assertEqual(expected_reply_body, reply_body)
 
1443
 
 
1444
    def test_http_send_smart_request_pycurl(self):
 
1445
        try:
 
1446
            self._test_http_send_smart_request('http+pycurl')
 
1447
        except errors.UnsupportedProtocol, e:
 
1448
            raise tests.TestSkipped(str(e))
 
1449
 
 
1450
    def test_http_send_smart_request_urllib(self):
 
1451
        self._test_http_send_smart_request('http+urllib')
 
1452
 
 
1453
    def test_http_server_with_smarts(self):
 
1454
        http_server = HTTPServerWithSmarts()
 
1455
        http_server.setUp()
 
1456
        self.addCleanup(http_server.tearDown)
 
1457
 
 
1458
        post_body = 'hello\n'
 
1459
        expected_reply_body = 'ok\x011\n'
 
1460
 
 
1461
        smart_server_url = http_server.get_url() + '.bzr/smart'
 
1462
        reply = urllib2.urlopen(smart_server_url, post_body).read()
 
1463
 
 
1464
        self.assertEqual(expected_reply_body, reply)
 
1465
 
 
1466
    def test_smart_http_server_post_request_handler(self):
 
1467
        http_server = HTTPServerWithSmarts()
 
1468
        http_server.setUp()
 
1469
        self.addCleanup(http_server.tearDown)
 
1470
        httpd = http_server._get_httpd()
 
1471
 
 
1472
        socket = SampleSocket(
 
1473
            'POST /.bzr/smart HTTP/1.0\r\n'
 
1474
            # HTTP/1.0 posts must have a Content-Length.
 
1475
            'Content-Length: 6\r\n'
 
1476
            '\r\n'
 
1477
            'hello\n')
 
1478
        request_handler = SmartRequestHandler(
 
1479
            socket, ('localhost', 80), httpd)
 
1480
        response = socket.writefile.getvalue()
 
1481
        self.assertStartsWith(response, 'HTTP/1.0 200 ')
 
1482
        # This includes the end of the HTTP headers, and all the body.
 
1483
        expected_end_of_response = '\r\n\r\nok\x011\n'
 
1484
        self.assertEndsWith(response, expected_end_of_response)
 
1485
 
 
1486
 
 
1487
class SampleSocket(object):
 
1488
    """A socket-like object for use in testing the HTTP request handler."""
 
1489
    
 
1490
    def __init__(self, socket_read_content):
 
1491
        """Constructs a sample socket.
 
1492
 
 
1493
        :param socket_read_content: a byte sequence
 
1494
        """
 
1495
        # Use plain python StringIO so we can monkey-patch the close method to
 
1496
        # not discard the contents.
 
1497
        from StringIO import StringIO
 
1498
        self.readfile = StringIO(socket_read_content)
 
1499
        self.writefile = StringIO()
 
1500
        self.writefile.close = lambda: None
 
1501
        
 
1502
    def makefile(self, mode='r', bufsize=None):
 
1503
        if 'r' in mode:
 
1504
            return self.readfile
 
1505
        else:
 
1506
            return self.writefile
 
1507
 
 
1508
        
410
1509
# TODO: Client feature that does get_bundle and then installs that into a
411
1510
# branch; this can be used in place of the regular pull/fetch operation when
412
1511
# coming from a smart server.