~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/tests/test_smart_transport.py

  • Committer: Patch Queue Manager
  • Date: 2016-02-01 19:13:13 UTC
  • mfrom: (6614.2.2 trunk)
  • Revision ID: pqm@pqm.ubuntu.com-20160201191313-wdfvmfff1djde6oq
(vila) Release 2.7.0 (Vincent Ladeuil)

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2006-2011 Canonical Ltd
 
1
# Copyright (C) 2006-2016 Canonical Ltd
2
2
#
3
3
# This program is free software; you can redistribute it and/or modify
4
4
# it under the terms of the GNU General Public License as published by
18
18
 
19
19
# all of this deals with byte strings so this is safe
20
20
from cStringIO import StringIO
 
21
import doctest
 
22
import errno
21
23
import os
22
24
import socket
 
25
import subprocess
 
26
import sys
23
27
import threading
 
28
import time
 
29
 
 
30
from testtools.matchers import DocTestMatches
24
31
 
25
32
import bzrlib
26
33
from bzrlib import (
27
34
        bzrdir,
 
35
        controldir,
 
36
        debug,
28
37
        errors,
29
38
        osutils,
30
39
        tests,
31
 
        transport,
 
40
        transport as _mod_transport,
32
41
        urlutils,
33
42
        )
34
43
from bzrlib.smart import (
37
46
        message,
38
47
        protocol,
39
48
        request as _mod_request,
40
 
        server,
 
49
        server as _mod_server,
41
50
        vfs,
42
51
)
43
52
from bzrlib.tests import (
54
63
        )
55
64
 
56
65
 
 
66
def create_file_pipes():
 
67
    r, w = os.pipe()
 
68
    # These must be opened without buffering, or we get undefined results
 
69
    rf = os.fdopen(r, 'rb', 0)
 
70
    wf = os.fdopen(w, 'wb', 0)
 
71
    return rf, wf
 
72
 
 
73
 
 
74
def portable_socket_pair():
 
75
    """Return a pair of TCP sockets connected to each other.
 
76
 
 
77
    Unlike socket.socketpair, this should work on Windows.
 
78
    """
 
79
    listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 
80
    listen_sock.bind(('127.0.0.1', 0))
 
81
    listen_sock.listen(1)
 
82
    client_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 
83
    client_sock.connect(listen_sock.getsockname())
 
84
    server_sock, addr = listen_sock.accept()
 
85
    listen_sock.close()
 
86
    return server_sock, client_sock
 
87
 
 
88
 
57
89
class StringIOSSHVendor(object):
58
90
    """An SSH vendor that uses StringIO to buffer writes and answer reads."""
59
91
 
68
100
        return StringIOSSHConnection(self)
69
101
 
70
102
 
 
103
class FirstRejectedStringIOSSHVendor(StringIOSSHVendor):
 
104
    """The first connection will be considered closed.
 
105
 
 
106
    The second connection will succeed normally.
 
107
    """
 
108
 
 
109
    def __init__(self, read_from, write_to, fail_at_write=True):
 
110
        super(FirstRejectedStringIOSSHVendor, self).__init__(read_from,
 
111
            write_to)
 
112
        self.fail_at_write = fail_at_write
 
113
        self._first = True
 
114
 
 
115
    def connect_ssh(self, username, password, host, port, command):
 
116
        self.calls.append(('connect_ssh', username, password, host, port,
 
117
            command))
 
118
        if self._first:
 
119
            self._first = False
 
120
            return ClosedSSHConnection(self)
 
121
        return StringIOSSHConnection(self)
 
122
 
 
123
 
71
124
class StringIOSSHConnection(ssh.SSHConnection):
72
125
    """An SSH connection that uses StringIO to buffer writes and answer reads."""
73
126
 
83
136
        return 'pipes', (self.vendor.read_from, self.vendor.write_to)
84
137
 
85
138
 
 
139
class ClosedSSHConnection(ssh.SSHConnection):
 
140
    """An SSH connection that just has closed channels."""
 
141
 
 
142
    def __init__(self, vendor):
 
143
        self.vendor = vendor
 
144
 
 
145
    def close(self):
 
146
        self.vendor.calls.append(('close', ))
 
147
 
 
148
    def get_sock_or_pipes(self):
 
149
        # We create matching pipes, and then close the ssh side
 
150
        bzr_read, ssh_write = create_file_pipes()
 
151
        # We always fail when bzr goes to read
 
152
        ssh_write.close()
 
153
        if self.vendor.fail_at_write:
 
154
            # If set, we'll also fail when bzr goes to write
 
155
            ssh_read, bzr_write = create_file_pipes()
 
156
            ssh_read.close()
 
157
        else:
 
158
            bzr_write = self.vendor.write_to
 
159
        return 'pipes', (bzr_read, bzr_write)
 
160
 
 
161
 
86
162
class _InvalidHostnameFeature(features.Feature):
87
163
    """Does 'non_existent.invalid' fail to resolve?
88
164
 
178
254
        client_medium._accept_bytes('abc')
179
255
        self.assertEqual('abc', output.getvalue())
180
256
 
 
257
    def test_simple_pipes__accept_bytes_subprocess_closed(self):
 
258
        # It is unfortunate that we have to use Popen for this. However,
 
259
        # os.pipe() does not behave the same as subprocess.Popen().
 
260
        # On Windows, if you use os.pipe() and close the write side,
 
261
        # read.read() hangs. On Linux, read.read() returns the empty string.
 
262
        p = subprocess.Popen([sys.executable, '-c',
 
263
            'import sys\n'
 
264
            'sys.stdout.write(sys.stdin.read(4))\n'
 
265
            'sys.stdout.close()\n'],
 
266
            stdout=subprocess.PIPE, stdin=subprocess.PIPE)
 
267
        client_medium = medium.SmartSimplePipesClientMedium(
 
268
            p.stdout, p.stdin, 'base')
 
269
        client_medium._accept_bytes('abc\n')
 
270
        self.assertEqual('abc', client_medium._read_bytes(3))
 
271
        p.wait()
 
272
        # While writing to the underlying pipe,
 
273
        #   Windows py2.6.6 we get IOError(EINVAL)
 
274
        #   Lucid py2.6.5, we get IOError(EPIPE)
 
275
        # In both cases, it should be wrapped to ConnectionReset
 
276
        self.assertRaises(errors.ConnectionReset,
 
277
                          client_medium._accept_bytes, 'more')
 
278
 
 
279
    def test_simple_pipes__accept_bytes_pipe_closed(self):
 
280
        child_read, client_write = create_file_pipes()
 
281
        client_medium = medium.SmartSimplePipesClientMedium(
 
282
            None, client_write, 'base')
 
283
        client_medium._accept_bytes('abc\n')
 
284
        self.assertEqual('abc\n', child_read.read(4))
 
285
        # While writing to the underlying pipe,
 
286
        #   Windows py2.6.6 we get IOError(EINVAL)
 
287
        #   Lucid py2.6.5, we get IOError(EPIPE)
 
288
        # In both cases, it should be wrapped to ConnectionReset
 
289
        child_read.close()
 
290
        self.assertRaises(errors.ConnectionReset,
 
291
                          client_medium._accept_bytes, 'more')
 
292
 
 
293
    def test_simple_pipes__flush_pipe_closed(self):
 
294
        child_read, client_write = create_file_pipes()
 
295
        client_medium = medium.SmartSimplePipesClientMedium(
 
296
            None, client_write, 'base')
 
297
        client_medium._accept_bytes('abc\n')
 
298
        child_read.close()
 
299
        # Even though the pipe is closed, flush on the write side seems to be a
 
300
        # no-op, rather than a failure.
 
301
        client_medium._flush()
 
302
 
 
303
    def test_simple_pipes__flush_subprocess_closed(self):
 
304
        p = subprocess.Popen([sys.executable, '-c',
 
305
            'import sys\n'
 
306
            'sys.stdout.write(sys.stdin.read(4))\n'
 
307
            'sys.stdout.close()\n'],
 
308
            stdout=subprocess.PIPE, stdin=subprocess.PIPE)
 
309
        client_medium = medium.SmartSimplePipesClientMedium(
 
310
            p.stdout, p.stdin, 'base')
 
311
        client_medium._accept_bytes('abc\n')
 
312
        p.wait()
 
313
        # Even though the child process is dead, flush seems to be a no-op.
 
314
        client_medium._flush()
 
315
 
 
316
    def test_simple_pipes__read_bytes_pipe_closed(self):
 
317
        child_read, client_write = create_file_pipes()
 
318
        client_medium = medium.SmartSimplePipesClientMedium(
 
319
            child_read, client_write, 'base')
 
320
        client_medium._accept_bytes('abc\n')
 
321
        client_write.close()
 
322
        self.assertEqual('abc\n', client_medium._read_bytes(4))
 
323
        self.assertEqual('', client_medium._read_bytes(4))
 
324
 
 
325
    def test_simple_pipes__read_bytes_subprocess_closed(self):
 
326
        p = subprocess.Popen([sys.executable, '-c',
 
327
            'import sys\n'
 
328
            'if sys.platform == "win32":\n'
 
329
            '    import msvcrt, os\n'
 
330
            '    msvcrt.setmode(sys.stdout.fileno(), os.O_BINARY)\n'
 
331
            '    msvcrt.setmode(sys.stdin.fileno(), os.O_BINARY)\n'
 
332
            'sys.stdout.write(sys.stdin.read(4))\n'
 
333
            'sys.stdout.close()\n'],
 
334
            stdout=subprocess.PIPE, stdin=subprocess.PIPE)
 
335
        client_medium = medium.SmartSimplePipesClientMedium(
 
336
            p.stdout, p.stdin, 'base')
 
337
        client_medium._accept_bytes('abc\n')
 
338
        p.wait()
 
339
        self.assertEqual('abc\n', client_medium._read_bytes(4))
 
340
        self.assertEqual('', client_medium._read_bytes(4))
 
341
 
181
342
    def test_simple_pipes_client_disconnect_does_nothing(self):
182
343
        # calling disconnect does nothing.
183
344
        input = StringIO()
339
500
            ],
340
501
            vendor.calls)
341
502
 
 
503
    def test_ssh_client_repr(self):
 
504
        client_medium = medium.SmartSSHClientMedium(
 
505
            'base', medium.SSHParams("example.com", "4242", "username"))
 
506
        self.assertEqual(
 
507
            "SmartSSHClientMedium(bzr+ssh://username@example.com:4242/)",
 
508
            repr(client_medium))
 
509
 
 
510
    def test_ssh_client_repr_no_port(self):
 
511
        client_medium = medium.SmartSSHClientMedium(
 
512
            'base', medium.SSHParams("example.com", None, "username"))
 
513
        self.assertEqual(
 
514
            "SmartSSHClientMedium(bzr+ssh://username@example.com/)",
 
515
            repr(client_medium))
 
516
 
 
517
    def test_ssh_client_repr_no_username(self):
 
518
        client_medium = medium.SmartSSHClientMedium(
 
519
            'base', medium.SSHParams("example.com", None, None))
 
520
        self.assertEqual(
 
521
            "SmartSSHClientMedium(bzr+ssh://example.com/)",
 
522
            repr(client_medium))
 
523
 
342
524
    def test_ssh_client_ignores_disconnect_when_not_connected(self):
343
525
        # Doing a disconnect on a new (and thus unconnected) SSH medium
344
526
        # does not fail.  It's ok to disconnect an unconnected medium.
565
747
        request.finished_reading()
566
748
        self.assertRaises(errors.ReadingCompleted, request.read_bytes, None)
567
749
 
 
750
    def test_reset(self):
 
751
        server_sock, client_sock = portable_socket_pair()
 
752
        # TODO: Use SmartClientAlreadyConnectedSocketMedium for the versions of
 
753
        #       bzr where it exists.
 
754
        client_medium = medium.SmartTCPClientMedium(None, None, None)
 
755
        client_medium._socket = client_sock
 
756
        client_medium._connected = True
 
757
        req = client_medium.get_request()
 
758
        self.assertRaises(errors.TooManyConcurrentRequests,
 
759
            client_medium.get_request)
 
760
        client_medium.reset()
 
761
        # The stream should be reset, marked as disconnected, though ready for
 
762
        # us to make a new request
 
763
        self.assertFalse(client_medium._connected)
 
764
        self.assertIs(None, client_medium._socket)
 
765
        try:
 
766
            self.assertEqual('', client_sock.recv(1))
 
767
        except socket.error, e:
 
768
            if e.errno not in (errno.EBADF,):
 
769
                raise
 
770
        req = client_medium.get_request()
 
771
 
568
772
 
569
773
class RemoteTransportTests(test_smart.TestCaseWithSmartMedium):
570
774
 
571
775
    def test_plausible_url(self):
572
 
        self.assert_(self.get_url().startswith('bzr://'))
 
776
        self.assertTrue(self.get_url().startswith('bzr://'))
573
777
 
574
778
    def test_probe_transport(self):
575
779
        t = self.get_transport()
618
822
        super(TestSmartServerStreamMedium, self).setUp()
619
823
        self.overrideEnv('BZR_NO_SMART_VFS', None)
620
824
 
621
 
    def portable_socket_pair(self):
622
 
        """Return a pair of TCP sockets connected to each other.
623
 
 
624
 
        Unlike socket.socketpair, this should work on Windows.
625
 
        """
626
 
        listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
627
 
        listen_sock.bind(('127.0.0.1', 0))
628
 
        listen_sock.listen(1)
629
 
        client_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
630
 
        client_sock.connect(listen_sock.getsockname())
631
 
        server_sock, addr = listen_sock.accept()
632
 
        listen_sock.close()
633
 
        return server_sock, client_sock
 
825
    def create_pipe_medium(self, to_server, from_server, transport,
 
826
                           timeout=4.0):
 
827
        """Create a new SmartServerPipeStreamMedium."""
 
828
        return medium.SmartServerPipeStreamMedium(to_server, from_server,
 
829
            transport, timeout=timeout)
 
830
 
 
831
    def create_pipe_context(self, to_server_bytes, transport):
 
832
        """Create a SmartServerPipeStreamMedium.
 
833
 
 
834
        This differs from create_pipe_medium, in that we initialize the
 
835
        request that is sent to the server, and return the StringIO class that
 
836
        will hold the response.
 
837
        """
 
838
        to_server = StringIO(to_server_bytes)
 
839
        from_server = StringIO()
 
840
        m = self.create_pipe_medium(to_server, from_server, transport)
 
841
        return m, from_server
 
842
 
 
843
    def create_socket_medium(self, server_sock, transport, timeout=4.0):
 
844
        """Initialize a new medium.SmartServerSocketStreamMedium."""
 
845
        return medium.SmartServerSocketStreamMedium(server_sock, transport,
 
846
            timeout=timeout)
 
847
 
 
848
    def create_socket_context(self, transport, timeout=4.0):
 
849
        """Create a new SmartServerSocketStreamMedium with default context.
 
850
 
 
851
        This will call portable_socket_pair and pass the server side to
 
852
        create_socket_medium along with transport.
 
853
        It then returns the client_sock and the server.
 
854
        """
 
855
        server_sock, client_sock = portable_socket_pair()
 
856
        server = self.create_socket_medium(server_sock, transport,
 
857
                                           timeout=timeout)
 
858
        return server, client_sock
634
859
 
635
860
    def test_smart_query_version(self):
636
861
        """Feed a canned query version to a server"""
637
862
        # wire-to-wire, using the whole stack
638
 
        to_server = StringIO('hello\n')
639
 
        from_server = StringIO()
640
863
        transport = local.LocalTransport(urlutils.local_path_to_url('/'))
641
 
        server = medium.SmartServerPipeStreamMedium(
642
 
            to_server, from_server, transport)
 
864
        server, from_server = self.create_pipe_context('hello\n', transport)
643
865
        smart_protocol = protocol.SmartServerRequestProtocolOne(transport,
644
866
                from_server.write)
645
867
        server._serve_one_request(smart_protocol)
649
871
    def test_response_to_canned_get(self):
650
872
        transport = memory.MemoryTransport('memory:///')
651
873
        transport.put_bytes('testfile', 'contents\nof\nfile\n')
652
 
        to_server = StringIO('get\001./testfile\n')
653
 
        from_server = StringIO()
654
 
        server = medium.SmartServerPipeStreamMedium(
655
 
            to_server, from_server, transport)
 
874
        server, from_server = self.create_pipe_context('get\001./testfile\n',
 
875
            transport)
656
876
        smart_protocol = protocol.SmartServerRequestProtocolOne(transport,
657
877
                from_server.write)
658
878
        server._serve_one_request(smart_protocol)
669
889
        # VFS requests use filenames, not raw UTF-8.
670
890
        hpss_path = urlutils.escape(utf8_filename)
671
891
        transport.put_bytes(utf8_filename, 'contents\nof\nfile\n')
672
 
        to_server = StringIO('get\001' + hpss_path + '\n')
673
 
        from_server = StringIO()
674
 
        server = medium.SmartServerPipeStreamMedium(
675
 
            to_server, from_server, transport)
 
892
        server, from_server = self.create_pipe_context(
 
893
                'get\001' + hpss_path + '\n', transport)
676
894
        smart_protocol = protocol.SmartServerRequestProtocolOne(transport,
677
895
                from_server.write)
678
896
        server._serve_one_request(smart_protocol)
684
902
 
685
903
    def test_pipe_like_stream_with_bulk_data(self):
686
904
        sample_request_bytes = 'command\n9\nbulk datadone\n'
687
 
        to_server = StringIO(sample_request_bytes)
688
 
        from_server = StringIO()
689
 
        server = medium.SmartServerPipeStreamMedium(
690
 
            to_server, from_server, None)
 
905
        server, from_server = self.create_pipe_context(
 
906
            sample_request_bytes, None)
691
907
        sample_protocol = SampleRequest(expected_bytes=sample_request_bytes)
692
908
        server._serve_one_request(sample_protocol)
693
909
        self.assertEqual('', from_server.getvalue())
696
912
 
697
913
    def test_socket_stream_with_bulk_data(self):
698
914
        sample_request_bytes = 'command\n9\nbulk datadone\n'
699
 
        server_sock, client_sock = self.portable_socket_pair()
700
 
        server = medium.SmartServerSocketStreamMedium(
701
 
            server_sock, None)
 
915
        server, client_sock = self.create_socket_context(None)
702
916
        sample_protocol = SampleRequest(expected_bytes=sample_request_bytes)
703
917
        client_sock.sendall(sample_request_bytes)
704
918
        server._serve_one_request(sample_protocol)
705
 
        server_sock.close()
 
919
        server._disconnect_client()
706
920
        self.assertEqual('', client_sock.recv(1))
707
921
        self.assertEqual(sample_request_bytes, sample_protocol.accepted_bytes)
708
922
        self.assertFalse(server.finished)
709
923
 
710
924
    def test_pipe_like_stream_shutdown_detection(self):
711
 
        to_server = StringIO('')
712
 
        from_server = StringIO()
713
 
        server = medium.SmartServerPipeStreamMedium(to_server, from_server, None)
 
925
        server, _ = self.create_pipe_context('', None)
714
926
        server._serve_one_request(SampleRequest('x'))
715
927
        self.assertTrue(server.finished)
716
928
 
717
929
    def test_socket_stream_shutdown_detection(self):
718
 
        server_sock, client_sock = self.portable_socket_pair()
 
930
        server, client_sock = self.create_socket_context(None)
719
931
        client_sock.close()
720
 
        server = medium.SmartServerSocketStreamMedium(
721
 
            server_sock, None)
722
932
        server._serve_one_request(SampleRequest('x'))
723
933
        self.assertTrue(server.finished)
724
934
 
735
945
        rest_of_request_bytes = 'lo\n'
736
946
        expected_response = (
737
947
            protocol.RESPONSE_VERSION_TWO + 'success\nok\x012\n')
738
 
        server_sock, client_sock = self.portable_socket_pair()
739
 
        server = medium.SmartServerSocketStreamMedium(
740
 
            server_sock, None)
 
948
        server, client_sock = self.create_socket_context(None)
741
949
        client_sock.sendall(incomplete_request_bytes)
742
950
        server_protocol = server._build_protocol()
743
951
        client_sock.sendall(rest_of_request_bytes)
744
952
        server._serve_one_request(server_protocol)
745
 
        server_sock.close()
 
953
        server._disconnect_client()
746
954
        self.assertEqual(expected_response, osutils.recv_all(client_sock, 50),
747
955
                         "Not a version 2 response to 'hello' request.")
748
956
        self.assertEqual('', client_sock.recv(1))
767
975
        to_server_w = os.fdopen(to_server_w, 'w', 0)
768
976
        from_server_r = os.fdopen(from_server_r, 'r', 0)
769
977
        from_server = os.fdopen(from_server, 'w', 0)
770
 
        server = medium.SmartServerPipeStreamMedium(
771
 
            to_server, from_server, None)
 
978
        server = self.create_pipe_medium(to_server, from_server, None)
772
979
        # Like test_socket_stream_incomplete_request, write an incomplete
773
980
        # request (that does not end in '\n') and build a protocol from it.
774
981
        to_server_w.write(incomplete_request_bytes)
789
996
        # _serve_one_request should still process both of them as if they had
790
997
        # been received separately.
791
998
        sample_request_bytes = 'command\n'
792
 
        to_server = StringIO(sample_request_bytes * 2)
793
 
        from_server = StringIO()
794
 
        server = medium.SmartServerPipeStreamMedium(
795
 
            to_server, from_server, None)
 
999
        server, from_server = self.create_pipe_context(
 
1000
            sample_request_bytes * 2, None)
796
1001
        first_protocol = SampleRequest(expected_bytes=sample_request_bytes)
797
1002
        server._serve_one_request(first_protocol)
798
1003
        self.assertEqual(0, first_protocol.next_read_size())
811
1016
        # _serve_one_request should still process both of them as if they had
812
1017
        # been received separately.
813
1018
        sample_request_bytes = 'command\n'
814
 
        server_sock, client_sock = self.portable_socket_pair()
815
 
        server = medium.SmartServerSocketStreamMedium(
816
 
            server_sock, None)
 
1019
        server, client_sock = self.create_socket_context(None)
817
1020
        first_protocol = SampleRequest(expected_bytes=sample_request_bytes)
818
1021
        # Put two whole requests on the wire.
819
1022
        client_sock.sendall(sample_request_bytes * 2)
826
1029
        stream_still_open = server._serve_one_request(second_protocol)
827
1030
        self.assertEqual(sample_request_bytes, second_protocol.accepted_bytes)
828
1031
        self.assertFalse(server.finished)
829
 
        server_sock.close()
 
1032
        server._disconnect_client()
830
1033
        self.assertEqual('', client_sock.recv(1))
831
1034
 
832
1035
    def test_pipe_like_stream_error_handling(self):
839
1042
        def close():
840
1043
            self.closed = True
841
1044
        from_server.close = close
842
 
        server = medium.SmartServerPipeStreamMedium(
 
1045
        server = self.create_pipe_medium(
843
1046
            to_server, from_server, None)
844
1047
        fake_protocol = ErrorRaisingProtocol(Exception('boom'))
845
1048
        server._serve_one_request(fake_protocol)
848
1051
        self.assertTrue(server.finished)
849
1052
 
850
1053
    def test_socket_stream_error_handling(self):
851
 
        server_sock, client_sock = self.portable_socket_pair()
852
 
        server = medium.SmartServerSocketStreamMedium(
853
 
            server_sock, None)
 
1054
        server, client_sock = self.create_socket_context(None)
854
1055
        fake_protocol = ErrorRaisingProtocol(Exception('boom'))
855
1056
        server._serve_one_request(fake_protocol)
856
1057
        # recv should not block, because the other end of the socket has been
859
1060
        self.assertTrue(server.finished)
860
1061
 
861
1062
    def test_pipe_like_stream_keyboard_interrupt_handling(self):
862
 
        to_server = StringIO('')
863
 
        from_server = StringIO()
864
 
        server = medium.SmartServerPipeStreamMedium(
865
 
            to_server, from_server, None)
 
1063
        server, from_server = self.create_pipe_context('', None)
866
1064
        fake_protocol = ErrorRaisingProtocol(KeyboardInterrupt('boom'))
867
1065
        self.assertRaises(
868
1066
            KeyboardInterrupt, server._serve_one_request, fake_protocol)
869
1067
        self.assertEqual('', from_server.getvalue())
870
1068
 
871
1069
    def test_socket_stream_keyboard_interrupt_handling(self):
872
 
        server_sock, client_sock = self.portable_socket_pair()
873
 
        server = medium.SmartServerSocketStreamMedium(
874
 
            server_sock, None)
 
1070
        server, client_sock = self.create_socket_context(None)
875
1071
        fake_protocol = ErrorRaisingProtocol(KeyboardInterrupt('boom'))
876
1072
        self.assertRaises(
877
1073
            KeyboardInterrupt, server._serve_one_request, fake_protocol)
878
 
        server_sock.close()
 
1074
        server._disconnect_client()
879
1075
        self.assertEqual('', client_sock.recv(1))
880
1076
 
881
1077
    def build_protocol_pipe_like(self, bytes):
882
 
        to_server = StringIO(bytes)
883
 
        from_server = StringIO()
884
 
        server = medium.SmartServerPipeStreamMedium(
885
 
            to_server, from_server, None)
 
1078
        server, _ = self.create_pipe_context(bytes, None)
886
1079
        return server._build_protocol()
887
1080
 
888
1081
    def build_protocol_socket(self, bytes):
889
 
        server_sock, client_sock = self.portable_socket_pair()
890
 
        server = medium.SmartServerSocketStreamMedium(
891
 
            server_sock, None)
 
1082
        server, client_sock = self.create_socket_context(None)
892
1083
        client_sock.sendall(bytes)
893
1084
        client_sock.close()
894
1085
        return server._build_protocol()
934
1125
        server_protocol = self.build_protocol_socket('bzr request 2\n')
935
1126
        self.assertProtocolTwo(server_protocol)
936
1127
 
 
1128
    def test__build_protocol_returns_if_stopping(self):
 
1129
        # _build_protocol should notice that we are stopping, and return
 
1130
        # without waiting for bytes from the client.
 
1131
        server, client_sock = self.create_socket_context(None)
 
1132
        server._stop_gracefully()
 
1133
        self.assertIs(None, server._build_protocol())
 
1134
 
 
1135
    def test_socket_set_timeout(self):
 
1136
        server, _ = self.create_socket_context(None, timeout=1.23)
 
1137
        self.assertEqual(1.23, server._client_timeout)
 
1138
 
 
1139
    def test_pipe_set_timeout(self):
 
1140
        server = self.create_pipe_medium(None, None, None,
 
1141
            timeout=1.23)
 
1142
        self.assertEqual(1.23, server._client_timeout)
 
1143
 
 
1144
    def test_socket_wait_for_bytes_with_timeout_with_data(self):
 
1145
        server, client_sock = self.create_socket_context(None)
 
1146
        client_sock.sendall('data\n')
 
1147
        # This should not block or consume any actual content
 
1148
        self.assertFalse(server._wait_for_bytes_with_timeout(0.1))
 
1149
        data = server.read_bytes(5)
 
1150
        self.assertEqual('data\n', data)
 
1151
 
 
1152
    def test_socket_wait_for_bytes_with_timeout_no_data(self):
 
1153
        server, client_sock = self.create_socket_context(None)
 
1154
        # This should timeout quickly, reporting that there wasn't any data
 
1155
        self.assertRaises(errors.ConnectionTimeout,
 
1156
                          server._wait_for_bytes_with_timeout, 0.01)
 
1157
        client_sock.close()
 
1158
        data = server.read_bytes(1)
 
1159
        self.assertEqual('', data)
 
1160
 
 
1161
    def test_socket_wait_for_bytes_with_timeout_closed(self):
 
1162
        server, client_sock = self.create_socket_context(None)
 
1163
        # With the socket closed, this should return right away.
 
1164
        # It seems select.select() returns that you *can* read on the socket,
 
1165
        # even though it closed. Presumably as a way to tell it is closed?
 
1166
        # Testing shows that without sock.close() this times out, failing the
 
1167
        # test, but with it, it returns False immediately.
 
1168
        client_sock.close()
 
1169
        self.assertFalse(server._wait_for_bytes_with_timeout(10))
 
1170
        data = server.read_bytes(1)
 
1171
        self.assertEqual('', data)
 
1172
 
 
1173
    def test_socket_wait_for_bytes_with_shutdown(self):
 
1174
        server, client_sock = self.create_socket_context(None)
 
1175
        t = time.time()
 
1176
        # Override the _timer functionality, so that time never increments,
 
1177
        # this way, we can be sure we stopped because of the flag, and not
 
1178
        # because of a timeout, etc.
 
1179
        server._timer = lambda: t
 
1180
        server._client_poll_timeout = 0.1
 
1181
        server._stop_gracefully()
 
1182
        server._wait_for_bytes_with_timeout(1.0)
 
1183
 
 
1184
    def test_socket_serve_timeout_closes_socket(self):
 
1185
        server, client_sock = self.create_socket_context(None, timeout=0.1)
 
1186
        # This should timeout quickly, and then close the connection so that
 
1187
        # client_sock recv doesn't block.
 
1188
        server.serve()
 
1189
        self.assertEqual('', client_sock.recv(1))
 
1190
 
 
1191
    def test_pipe_wait_for_bytes_with_timeout_with_data(self):
 
1192
        # We intentionally use a real pipe here, so that we can 'select' on it.
 
1193
        # You can't select() on a StringIO
 
1194
        (r_server, w_client) = os.pipe()
 
1195
        self.addCleanup(os.close, w_client)
 
1196
        with os.fdopen(r_server, 'rb') as rf_server:
 
1197
            server = self.create_pipe_medium(
 
1198
                rf_server, None, None)
 
1199
            os.write(w_client, 'data\n')
 
1200
            # This should not block or consume any actual content
 
1201
            server._wait_for_bytes_with_timeout(0.1)
 
1202
            data = server.read_bytes(5)
 
1203
            self.assertEqual('data\n', data)
 
1204
 
 
1205
    def test_pipe_wait_for_bytes_with_timeout_no_data(self):
 
1206
        # We intentionally use a real pipe here, so that we can 'select' on it.
 
1207
        # You can't select() on a StringIO
 
1208
        (r_server, w_client) = os.pipe()
 
1209
        # We can't add an os.close cleanup here, because we need to control
 
1210
        # when the file handle gets closed ourselves.
 
1211
        with os.fdopen(r_server, 'rb') as rf_server:
 
1212
            server = self.create_pipe_medium(
 
1213
                rf_server, None, None)
 
1214
            if sys.platform == 'win32':
 
1215
                # Windows cannot select() on a pipe, so we just always return
 
1216
                server._wait_for_bytes_with_timeout(0.01)
 
1217
            else:
 
1218
                self.assertRaises(errors.ConnectionTimeout,
 
1219
                                  server._wait_for_bytes_with_timeout, 0.01)
 
1220
            os.close(w_client)
 
1221
            data = server.read_bytes(5)
 
1222
            self.assertEqual('', data)
 
1223
 
 
1224
    def test_pipe_wait_for_bytes_no_fileno(self):
 
1225
        server, _ = self.create_pipe_context('', None)
 
1226
        # Our file doesn't support polling, so we should always just return
 
1227
        # 'you have data to consume'.
 
1228
        server._wait_for_bytes_with_timeout(0.01)
 
1229
 
937
1230
 
938
1231
class TestGetProtocolFactoryForBytes(tests.TestCase):
939
1232
    """_get_protocol_factory_for_bytes identifies the protocol factory a server
969
1262
 
970
1263
class TestSmartTCPServer(tests.TestCase):
971
1264
 
 
1265
    def make_server(self):
 
1266
        """Create a SmartTCPServer that we can exercise.
 
1267
 
 
1268
        Note: we don't use SmartTCPServer_for_testing because the testing
 
1269
        version overrides lots of functionality like 'serve', and we want to
 
1270
        test the raw service.
 
1271
 
 
1272
        This will start the server in another thread, and wait for it to
 
1273
        indicate it has finished starting up.
 
1274
 
 
1275
        :return: (server, server_thread)
 
1276
        """
 
1277
        t = _mod_transport.get_transport_from_url('memory:///')
 
1278
        server = _mod_server.SmartTCPServer(t, client_timeout=4.0)
 
1279
        server._ACCEPT_TIMEOUT = 0.1
 
1280
        # We don't use 'localhost' because that might be an IPv6 address.
 
1281
        server.start_server('127.0.0.1', 0)
 
1282
        server_thread = threading.Thread(target=server.serve,
 
1283
                                         args=(self.id(),))
 
1284
        server_thread.start()
 
1285
        # Ensure this gets called at some point
 
1286
        self.addCleanup(server._stop_gracefully)
 
1287
        server._started.wait()
 
1288
        return server, server_thread
 
1289
 
 
1290
    def ensure_client_disconnected(self, client_sock):
 
1291
        """Ensure that a socket is closed, discarding all errors."""
 
1292
        try:
 
1293
            client_sock.close()
 
1294
        except Exception:
 
1295
            pass
 
1296
 
 
1297
    def connect_to_server(self, server):
 
1298
        """Create a client socket that can talk to the server."""
 
1299
        client_sock = socket.socket()
 
1300
        server_info = server._server_socket.getsockname()
 
1301
        client_sock.connect(server_info)
 
1302
        self.addCleanup(self.ensure_client_disconnected, client_sock)
 
1303
        return client_sock
 
1304
 
 
1305
    def connect_to_server_and_hangup(self, server):
 
1306
        """Connect to the server, and then hang up.
 
1307
        That way it doesn't sit waiting for 'accept()' to timeout.
 
1308
        """
 
1309
        # If the server has already signaled that the socket is closed, we
 
1310
        # don't need to try to connect to it. Not being set, though, the server
 
1311
        # might still close the socket while we try to connect to it. So we
 
1312
        # still have to catch the exception.
 
1313
        if server._stopped.isSet():
 
1314
            return
 
1315
        try:
 
1316
            client_sock = self.connect_to_server(server)
 
1317
            client_sock.close()
 
1318
        except socket.error, e:
 
1319
            # If the server has hung up already, that is fine.
 
1320
            pass
 
1321
 
 
1322
    def say_hello(self, client_sock):
 
1323
        """Send the 'hello' smart RPC, and expect the response."""
 
1324
        client_sock.send('hello\n')
 
1325
        self.assertEqual('ok\x012\n', client_sock.recv(5))
 
1326
 
 
1327
    def shutdown_server_cleanly(self, server, server_thread):
 
1328
        server._stop_gracefully()
 
1329
        self.connect_to_server_and_hangup(server)
 
1330
        server._stopped.wait()
 
1331
        server._fully_stopped.wait()
 
1332
        server_thread.join()
 
1333
 
972
1334
    def test_get_error_unexpected(self):
973
1335
        """Error reported by server with no specific representation"""
974
1336
        self.overrideEnv('BZR_NO_SMART_VFS', None)
992
1354
                                t.get, 'something')
993
1355
        self.assertContainsRe(str(err), 'some random exception')
994
1356
 
 
1357
    def test_propagates_timeout(self):
 
1358
        server = _mod_server.SmartTCPServer(None, client_timeout=1.23)
 
1359
        server_sock, client_sock = portable_socket_pair()
 
1360
        handler = server._make_handler(server_sock)
 
1361
        self.assertEqual(1.23, handler._client_timeout)
 
1362
 
 
1363
    def test_serve_conn_tracks_connections(self):
 
1364
        server = _mod_server.SmartTCPServer(None, client_timeout=4.0)
 
1365
        server_sock, client_sock = portable_socket_pair()
 
1366
        server.serve_conn(server_sock, '-%s' % (self.id(),))
 
1367
        self.assertEqual(1, len(server._active_connections))
 
1368
        # We still want to talk on the connection. Polling should indicate it
 
1369
        # is still active.
 
1370
        server._poll_active_connections()
 
1371
        self.assertEqual(1, len(server._active_connections))
 
1372
        # Closing the socket will end the active thread, and polling will
 
1373
        # notice and remove it from the active set.
 
1374
        client_sock.close()
 
1375
        server._poll_active_connections(0.1)
 
1376
        self.assertEqual(0, len(server._active_connections))
 
1377
 
 
1378
    def test_serve_closes_out_finished_connections(self):
 
1379
        server, server_thread = self.make_server()
 
1380
        # The server is started, connect to it.
 
1381
        client_sock = self.connect_to_server(server)
 
1382
        # We send and receive on the connection, so that we know the
 
1383
        # server-side has seen the connect, and started handling the
 
1384
        # results.
 
1385
        self.say_hello(client_sock)
 
1386
        self.assertEqual(1, len(server._active_connections))
 
1387
        # Grab a handle to the thread that is processing our request
 
1388
        _, server_side_thread = server._active_connections[0]
 
1389
        # Close the connection, ask the server to stop, and wait for the
 
1390
        # server to stop, as well as the thread that was servicing the
 
1391
        # client request.
 
1392
        client_sock.close()
 
1393
        # Wait for the server-side request thread to notice we are closed.
 
1394
        server_side_thread.join()
 
1395
        # Stop the server, it should notice the connection has finished.
 
1396
        self.shutdown_server_cleanly(server, server_thread)
 
1397
        # The server should have noticed that all clients are gone before
 
1398
        # exiting.
 
1399
        self.assertEqual(0, len(server._active_connections))
 
1400
 
 
1401
    def test_serve_reaps_finished_connections(self):
 
1402
        server, server_thread = self.make_server()
 
1403
        client_sock1 = self.connect_to_server(server)
 
1404
        # We send and receive on the connection, so that we know the
 
1405
        # server-side has seen the connect, and started handling the
 
1406
        # results.
 
1407
        self.say_hello(client_sock1)
 
1408
        server_handler1, server_side_thread1 = server._active_connections[0]
 
1409
        client_sock1.close()
 
1410
        server_side_thread1.join()
 
1411
        # By waiting until the first connection is fully done, the server
 
1412
        # should notice after another connection that the first has finished.
 
1413
        client_sock2 = self.connect_to_server(server)
 
1414
        self.say_hello(client_sock2)
 
1415
        server_handler2, server_side_thread2 = server._active_connections[-1]
 
1416
        # There is a race condition. We know that client_sock2 has been
 
1417
        # registered, but not that _poll_active_connections has been called. We
 
1418
        # know that it will be called before the server will accept a new
 
1419
        # connection, however. So connect one more time, and assert that we
 
1420
        # either have 1 or 2 active connections (never 3), and that the 'first'
 
1421
        # connection is not connection 1
 
1422
        client_sock3 = self.connect_to_server(server)
 
1423
        self.say_hello(client_sock3)
 
1424
        # Copy the list, so we don't have it mutating behind our back
 
1425
        conns = list(server._active_connections)
 
1426
        self.assertEqual(2, len(conns))
 
1427
        self.assertNotEqual((server_handler1, server_side_thread1), conns[0])
 
1428
        self.assertEqual((server_handler2, server_side_thread2), conns[0])
 
1429
        client_sock2.close()
 
1430
        client_sock3.close()
 
1431
        self.shutdown_server_cleanly(server, server_thread)
 
1432
 
 
1433
    def test_graceful_shutdown_waits_for_clients_to_stop(self):
 
1434
        server, server_thread = self.make_server()
 
1435
        # We need something big enough that it won't fit in a single recv. So
 
1436
        # the server thread gets blocked writing content to the client until we
 
1437
        # finish reading on the client.
 
1438
        server.backing_transport.put_bytes('bigfile',
 
1439
            'a'*1024*1024)
 
1440
        client_sock = self.connect_to_server(server)
 
1441
        self.say_hello(client_sock)
 
1442
        _, server_side_thread = server._active_connections[0]
 
1443
        # Start the RPC, but don't finish reading the response
 
1444
        client_medium = medium.SmartClientAlreadyConnectedSocketMedium(
 
1445
            'base', client_sock)
 
1446
        client_client = client._SmartClient(client_medium)
 
1447
        resp, response_handler = client_client.call_expecting_body('get',
 
1448
            'bigfile')
 
1449
        self.assertEqual(('ok',), resp)
 
1450
        # Ask the server to stop gracefully, and wait for it.
 
1451
        server._stop_gracefully()
 
1452
        self.connect_to_server_and_hangup(server)
 
1453
        server._stopped.wait()
 
1454
        # It should not be accepting another connection.
 
1455
        self.assertRaises(socket.error, self.connect_to_server, server)
 
1456
        response_handler.read_body_bytes()
 
1457
        client_sock.close()
 
1458
        server_side_thread.join()
 
1459
        server_thread.join()
 
1460
        self.assertTrue(server._fully_stopped.isSet())
 
1461
        log = self.get_log()
 
1462
        self.assertThat(log, DocTestMatches("""\
 
1463
    INFO  Requested to stop gracefully
 
1464
... Stopping SmartServerSocketStreamMedium(client=('127.0.0.1', ...
 
1465
    INFO  Waiting for 1 client(s) to finish
 
1466
""", flags=doctest.ELLIPSIS|doctest.REPORT_UDIFF))
 
1467
 
 
1468
    def test_stop_gracefully_tells_handlers_to_stop(self):
 
1469
        server, server_thread = self.make_server()
 
1470
        client_sock = self.connect_to_server(server)
 
1471
        self.say_hello(client_sock)
 
1472
        server_handler, server_side_thread = server._active_connections[0]
 
1473
        self.assertFalse(server_handler.finished)
 
1474
        server._stop_gracefully()
 
1475
        self.assertTrue(server_handler.finished)
 
1476
        client_sock.close()
 
1477
        self.connect_to_server_and_hangup(server)
 
1478
        server_thread.join()
 
1479
 
995
1480
 
996
1481
class SmartTCPTests(tests.TestCase):
997
1482
    """Tests for connection/end to end behaviour using the TCP server.
1015
1500
            mem_server.start_server()
1016
1501
            self.addCleanup(mem_server.stop_server)
1017
1502
            self.permit_url(mem_server.get_url())
1018
 
            self.backing_transport = transport.get_transport(
 
1503
            self.backing_transport = _mod_transport.get_transport_from_url(
1019
1504
                mem_server.get_url())
1020
1505
        else:
1021
1506
            self.backing_transport = backing_transport
1022
1507
        if readonly:
1023
1508
            self.real_backing_transport = self.backing_transport
1024
 
            self.backing_transport = transport.get_transport(
 
1509
            self.backing_transport = _mod_transport.get_transport_from_url(
1025
1510
                "readonly+" + self.backing_transport.abspath('.'))
1026
 
        self.server = server.SmartTCPServer(self.backing_transport)
 
1511
        self.server = _mod_server.SmartTCPServer(self.backing_transport,
 
1512
                                                 client_timeout=4.0)
1027
1513
        self.server.start_server('127.0.0.1', 0)
1028
1514
        self.server.start_background_thread('-' + self.id())
1029
1515
        self.transport = remote.RemoteTCPTransport(self.server.get_url())
1109
1595
                      conn2.get_smart_medium())
1110
1596
 
1111
1597
    def test__remote_path(self):
1112
 
        self.assertEquals('/foo/bar',
 
1598
        self.assertEqual('/foo/bar',
1113
1599
                          self.transport._remote_path('foo/bar'))
1114
1600
 
1115
1601
    def test_clone_changes_base(self):
1116
1602
        """Cloning transport produces one with a new base location"""
1117
1603
        conn2 = self.transport.clone('subdir')
1118
 
        self.assertEquals(self.transport.base + 'subdir/',
 
1604
        self.assertEqual(self.transport.base + 'subdir/',
1119
1605
                          conn2.base)
1120
1606
 
1121
1607
    def test_open_dir(self):
1124
1610
        transport = self.transport
1125
1611
        self.backing_transport.mkdir('toffee')
1126
1612
        self.backing_transport.mkdir('toffee/apple')
1127
 
        self.assertEquals('/toffee', transport._remote_path('toffee'))
 
1613
        self.assertEqual('/toffee', transport._remote_path('toffee'))
1128
1614
        toffee_trans = transport.clone('toffee')
1129
1615
        # Check that each transport has only the contents of its directory
1130
1616
        # directly visible. If state was being held in the wrong object, it's
1140
1626
        transport = self.transport
1141
1627
        t = self.backing_transport
1142
1628
        bzrdir.BzrDirFormat.get_default_format().initialize_on_transport(t)
1143
 
        result_dir = bzrdir.BzrDir.open_containing_from_transport(transport)
 
1629
        result_dir = controldir.ControlDir.open_containing_from_transport(
 
1630
            transport)
1144
1631
 
1145
1632
 
1146
1633
class ReadOnlyEndToEndTests(SmartTCPTests):
1153
1640
        self.assertRaises(errors.TransportNotPossible, self.transport.mkdir,
1154
1641
            'foo')
1155
1642
 
 
1643
    def test_rename_error_readonly(self):
 
1644
        """TransportNotPossible should be preserved from the backing transport."""
 
1645
        self.overrideEnv('BZR_NO_SMART_VFS', None)
 
1646
        self.start_server(readonly=True)
 
1647
        self.assertRaises(errors.TransportNotPossible, self.transport.rename,
 
1648
                          'foo', 'bar')
 
1649
 
 
1650
    def test_open_write_stream_error_readonly(self):
 
1651
        """TransportNotPossible should be preserved from the backing transport."""
 
1652
        self.overrideEnv('BZR_NO_SMART_VFS', None)
 
1653
        self.start_server(readonly=True)
 
1654
        self.assertRaises(
 
1655
            errors.TransportNotPossible, self.transport.open_write_stream,
 
1656
            'foo')
 
1657
 
1156
1658
 
1157
1659
class TestServerHooks(SmartTCPTests):
1158
1660
 
1163
1665
    def test_server_started_hook_memory(self):
1164
1666
        """The server_started hook fires when the server is started."""
1165
1667
        self.hook_calls = []
1166
 
        server.SmartTCPServer.hooks.install_named_hook('server_started',
 
1668
        _mod_server.SmartTCPServer.hooks.install_named_hook('server_started',
1167
1669
            self.capture_server_call, None)
1168
1670
        self.start_server()
1169
1671
        # at this point, the server will be starting a thread up.
1177
1679
    def test_server_started_hook_file(self):
1178
1680
        """The server_started hook fires when the server is started."""
1179
1681
        self.hook_calls = []
1180
 
        server.SmartTCPServer.hooks.install_named_hook('server_started',
 
1682
        _mod_server.SmartTCPServer.hooks.install_named_hook('server_started',
1181
1683
            self.capture_server_call, None)
1182
 
        self.start_server(backing_transport=transport.get_transport("."))
 
1684
        self.start_server(
 
1685
            backing_transport=_mod_transport.get_transport_from_path("."))
1183
1686
        # at this point, the server will be starting a thread up.
1184
1687
        # there is no indicator at the moment, so bodge it by doing a request.
1185
1688
        self.transport.has('.')
1193
1696
    def test_server_stopped_hook_simple_memory(self):
1194
1697
        """The server_stopped hook fires when the server is stopped."""
1195
1698
        self.hook_calls = []
1196
 
        server.SmartTCPServer.hooks.install_named_hook('server_stopped',
 
1699
        _mod_server.SmartTCPServer.hooks.install_named_hook('server_stopped',
1197
1700
            self.capture_server_call, None)
1198
1701
        self.start_server()
1199
1702
        result = [([self.backing_transport.base], self.transport.base)]
1210
1713
    def test_server_stopped_hook_simple_file(self):
1211
1714
        """The server_stopped hook fires when the server is stopped."""
1212
1715
        self.hook_calls = []
1213
 
        server.SmartTCPServer.hooks.install_named_hook('server_stopped',
 
1716
        _mod_server.SmartTCPServer.hooks.install_named_hook('server_stopped',
1214
1717
            self.capture_server_call, None)
1215
 
        self.start_server(backing_transport=transport.get_transport("."))
 
1718
        self.start_server(
 
1719
            backing_transport=_mod_transport.get_transport_from_path("."))
1216
1720
        result = [(
1217
1721
            [self.backing_transport.base, self.backing_transport.external_url()]
1218
1722
            , self.transport.base)]
1354
1858
class RemoteTransportRegistration(tests.TestCase):
1355
1859
 
1356
1860
    def test_registration(self):
1357
 
        t = transport.get_transport('bzr+ssh://example.com/path')
 
1861
        t = _mod_transport.get_transport_from_url('bzr+ssh://example.com/path')
1358
1862
        self.assertIsInstance(t, remote.RemoteSSHTransport)
1359
1863
        self.assertEqual('example.com', t._parsed_url.host)
1360
1864
 
1361
1865
    def test_bzr_https(self):
1362
1866
        # https://bugs.launchpad.net/bzr/+bug/128456
1363
 
        t = transport.get_transport('bzr+https://example.com/path')
 
1867
        t = _mod_transport.get_transport_from_url('bzr+https://example.com/path')
1364
1868
        self.assertIsInstance(t, remote.RemoteHTTPTransport)
1365
1869
        self.assertStartsWith(
1366
1870
            t._http_transport.base,
2543
3047
        from_server = StringIO()
2544
3048
        transport = memory.MemoryTransport('memory:///')
2545
3049
        server = medium.SmartServerPipeStreamMedium(
2546
 
            to_server, from_server, transport)
 
3050
            to_server, from_server, transport, timeout=4.0)
2547
3051
        proto = server._build_protocol()
2548
3052
        message_handler = proto.message_handler
2549
3053
        server._serve_one_request(proto)
2695
3199
        requester, output = self.make_client_encoder_and_output()
2696
3200
        requester.set_headers({'header name': 'header value'})
2697
3201
        requester.call('one arg')
2698
 
        self.assertEquals(
 
3202
        self.assertEqual(
2699
3203
            'bzr message 3 (bzr 1.6)\n' # protocol version
2700
3204
            '\x00\x00\x00\x1fd11:header name12:header valuee' # headers
2701
3205
            's\x00\x00\x00\x0bl7:one arge' # args
2711
3215
        requester, output = self.make_client_encoder_and_output()
2712
3216
        requester.set_headers({'header name': 'header value'})
2713
3217
        requester.call_with_body_bytes(('one arg',), 'body bytes')
2714
 
        self.assertEquals(
 
3218
        self.assertEqual(
2715
3219
            'bzr message 3 (bzr 1.6)\n' # protocol version
2716
3220
            '\x00\x00\x00\x1fd11:header name12:header valuee' # headers
2717
3221
            's\x00\x00\x00\x0bl7:one arge' # args
2746
3250
        requester.set_headers({'header name': 'header value'})
2747
3251
        stream = ['chunk 1', 'chunk two']
2748
3252
        requester.call_with_body_stream(('one arg',), stream)
2749
 
        self.assertEquals(
 
3253
        self.assertEqual(
2750
3254
            'bzr message 3 (bzr 1.6)\n' # protocol version
2751
3255
            '\x00\x00\x00\x1fd11:header name12:header valuee' # headers
2752
3256
            's\x00\x00\x00\x0bl7:one arge' # args
2761
3265
        requester.set_headers({})
2762
3266
        stream = []
2763
3267
        requester.call_with_body_stream(('one arg',), stream)
2764
 
        self.assertEquals(
 
3268
        self.assertEqual(
2765
3269
            'bzr message 3 (bzr 1.6)\n' # protocol version
2766
3270
            '\x00\x00\x00\x02de' # headers
2767
3271
            's\x00\x00\x00\x0bl7:one arge' # args
2783
3287
            raise Exception('Boom!')
2784
3288
        self.assertRaises(Exception, requester.call_with_body_stream,
2785
3289
            ('one arg',), stream_that_fails())
2786
 
        self.assertEquals(
 
3290
        self.assertEqual(
2787
3291
            'bzr message 3 (bzr 1.6)\n' # protocol version
2788
3292
            '\x00\x00\x00\x02de' # headers
2789
3293
            's\x00\x00\x00\x0bl7:one arge' # args
2794
3298
            'e', # end
2795
3299
            output.getvalue())
2796
3300
 
 
3301
    def test_records_start_of_body_stream(self):
 
3302
        requester, output = self.make_client_encoder_and_output()
 
3303
        requester.set_headers({})
 
3304
        in_stream = [False]
 
3305
        def stream_checker():
 
3306
            self.assertTrue(requester.body_stream_started)
 
3307
            in_stream[0] = True
 
3308
            yield 'content'
 
3309
        flush_called = []
 
3310
        orig_flush = requester.flush
 
3311
        def tracked_flush():
 
3312
            flush_called.append(in_stream[0])
 
3313
            if in_stream[0]:
 
3314
                self.assertTrue(requester.body_stream_started)
 
3315
            else:
 
3316
                self.assertFalse(requester.body_stream_started)
 
3317
            return orig_flush()
 
3318
        requester.flush = tracked_flush
 
3319
        requester.call_with_body_stream(('one arg',), stream_checker())
 
3320
        self.assertEqual(
 
3321
            'bzr message 3 (bzr 1.6)\n' # protocol version
 
3322
            '\x00\x00\x00\x02de' # headers
 
3323
            's\x00\x00\x00\x0bl7:one arge' # args
 
3324
            'b\x00\x00\x00\x07content' # body
 
3325
            'e', output.getvalue())
 
3326
        self.assertEqual([False, True, True], flush_called)
 
3327
 
2797
3328
 
2798
3329
class StubMediumRequest(object):
2799
3330
    """A stub medium request that tracks the number of times accept_bytes is
2869
3400
    """
2870
3401
 
2871
3402
    def setUp(self):
2872
 
        tests.TestCase.setUp(self)
 
3403
        super(TestResponseEncoderBufferingProtocolThree, self).setUp()
2873
3404
        self.writes = []
2874
3405
        self.responder = protocol.ProtocolThreeResponder(self.writes.append)
2875
3406
 
3219
3750
        # encoder.
3220
3751
 
3221
3752
 
 
3753
class Test_SmartClientRequest(tests.TestCase):
    """Tests for the send and retry behaviour of client._SmartClientRequest."""

    # The call the SSH vendor is expected to record for each connection
    # attempt made by the client medium built in these tests.
    _CONNECT_SSH_CALL = (
        'connect_ssh', 'a user', 'a pass', 'a host', 'a port',
        ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes'])

    def _make_ssh_smart_client(self, vendor):
        """Return a _SmartClient over an SSH client medium using ``vendor``."""
        ssh_params = medium.SSHParams('a host', 'a port', 'a user', 'a pass')
        client_medium = medium.SmartSSHClientMedium('base', ssh_params, vendor)
        return client._SmartClient(client_medium, headers={})

    def make_client_with_failing_medium(self, fail_at_write=True, response=''):
        """Build a smart client on top of a FirstRejectedStringIOSSHVendor.

        :param fail_at_write: Passed straight to the vendor.
            NOTE(review): presumably selects whether the first connection
            fails on write or on read -- confirm against the vendor class.
        :param response: Bytes the vendor serves as the server's response.
        :return: (output, vendor, smart_client), where ``output`` is the
            StringIO handed to the vendor for the client's writes.
        """
        response_io = StringIO(response)
        output = StringIO()
        vendor = FirstRejectedStringIOSSHVendor(response_io, output,
                    fail_at_write=fail_at_write)
        smart_client = self._make_ssh_smart_client(vendor)
        return output, vendor, smart_client

    def make_response(self, args, body=None, body_stream=None):
        """Return the protocol-3 bytes of a successful response."""
        response_io = StringIO()
        response = _mod_request.SuccessfulSmartServerResponse(args, body=body,
            body_stream=body_stream)
        responder = protocol.ProtocolThreeResponder(response_io.write)
        responder.send_response(response)
        return response_io.getvalue()

    def test__call_doesnt_retry_append(self):
        response = self.make_response(('appended', '8'))
        output, vendor, smart_client = self.make_client_with_failing_medium(
            fail_at_write=False, response=response)
        smart_request = client._SmartClientRequest(smart_client, 'append',
            ('foo', ''), body='content\n')
        self.assertRaises(errors.ConnectionReset, smart_request._call, 3)

    def test__call_retries_get_bytes(self):
        response = self.make_response(('ok',), 'content\n')
        output, vendor, smart_client = self.make_client_with_failing_medium(
            fail_at_write=False, response=response)
        smart_request = client._SmartClientRequest(smart_client, 'get',
            ('foo',))
        response, response_handler = smart_request._call(3)
        self.assertEqual(('ok',), response)
        self.assertEqual('content\n', response_handler.read_body_bytes())

    def test__call_noretry_get_bytes(self):
        debug.debug_flags.add('noretry')
        response = self.make_response(('ok',), 'content\n')
        output, vendor, smart_client = self.make_client_with_failing_medium(
            fail_at_write=False, response=response)
        smart_request = client._SmartClientRequest(smart_client, 'get',
            ('foo',))
        self.assertRaises(errors.ConnectionReset, smart_request._call, 3)

    def test__send_no_retry_pipes(self):
        client_read, server_write = create_file_pipes()
        server_read, client_write = create_file_pipes()
        client_medium = medium.SmartSimplePipesClientMedium(client_read,
            client_write, base='/')
        smart_client = client._SmartClient(client_medium)
        smart_request = client._SmartClientRequest(smart_client,
            'hello', ())
        # Close the server side
        server_read.close()
        encoder, response_handler = smart_request._construct_protocol(3)
        self.assertRaises(errors.ConnectionReset,
            smart_request._send_no_retry, encoder)

    def test__send_read_response_sockets(self):
        listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # Nothing else closes the listening socket, so do it at cleanup.
        self.addCleanup(listen_sock.close)
        listen_sock.bind(('127.0.0.1', 0))
        listen_sock.listen(1)
        host, port = listen_sock.getsockname()
        client_medium = medium.SmartTCPClientMedium(host, port, '/')
        client_medium._ensure_connection()
        smart_client = client._SmartClient(client_medium)
        smart_request = client._SmartClientRequest(smart_client, 'hello', ())
        # Accept the connection, but don't actually talk to the client.
        server_sock, _ = listen_sock.accept()
        server_sock.close()
        # Sockets buffer and don't really notice that the server has closed the
        # connection until we try to read again.
        handler = smart_request._send(3)
        self.assertRaises(errors.ConnectionReset,
            handler.read_response_tuple, expect_body=False)

    def test__send_retries_on_write(self):
        output, vendor, smart_client = self.make_client_with_failing_medium()
        smart_request = client._SmartClientRequest(smart_client, 'hello', ())
        handler = smart_request._send(3)
        self.assertEqual('bzr message 3 (bzr 1.6)\n' # protocol
                         '\x00\x00\x00\x02de'   # empty headers
                         's\x00\x00\x00\tl5:helloee',
                         output.getvalue())
        self.assertEqual(
            [self._CONNECT_SSH_CALL,
             ('close',),
             self._CONNECT_SSH_CALL,
            ],
            vendor.calls)

    def test__send_doesnt_retry_read_failure(self):
        output, vendor, smart_client = self.make_client_with_failing_medium(
            fail_at_write=False)
        smart_request = client._SmartClientRequest(smart_client, 'hello', ())
        handler = smart_request._send(3)
        self.assertEqual('bzr message 3 (bzr 1.6)\n' # protocol
                         '\x00\x00\x00\x02de'   # empty headers
                         's\x00\x00\x00\tl5:helloee',
                         output.getvalue())
        self.assertEqual([self._CONNECT_SSH_CALL], vendor.calls)
        self.assertRaises(errors.ConnectionReset, handler.read_response_tuple)

    def test__send_request_retries_body_stream_if_not_started(self):
        output, vendor, smart_client = self.make_client_with_failing_medium()
        smart_request = client._SmartClientRequest(smart_client, 'hello', (),
            body_stream=['a', 'b'])
        response_handler = smart_request._send(3)
        # We connect, get disconnected, and notice before consuming the stream,
        # so we try again one time and succeed.
        self.assertEqual(
            [self._CONNECT_SSH_CALL,
             ('close',),
             self._CONNECT_SSH_CALL,
            ],
            vendor.calls)
        self.assertEqual('bzr message 3 (bzr 1.6)\n' # protocol
                         '\x00\x00\x00\x02de'   # empty headers
                         's\x00\x00\x00\tl5:helloe'
                         'b\x00\x00\x00\x01a'
                         'b\x00\x00\x00\x01b'
                         'e',
                         output.getvalue())

    def test__send_request_stops_if_body_started(self):
        # We intentionally use the python StringIO so that we can subclass it.
        from StringIO import StringIO
        response = StringIO()

        class FailAfterFirstWrite(StringIO):
            """Allow one 'write' call to pass, fail the rest"""

            def __init__(self):
                StringIO.__init__(self)
                self._first = True

            def write(self, s):
                if self._first:
                    self._first = False
                    return StringIO.write(self, s)
                raise IOError(errno.EINVAL, 'invalid file handle')

        output = FailAfterFirstWrite()
        vendor = FirstRejectedStringIOSSHVendor(response, output,
            fail_at_write=False)
        smart_client = self._make_ssh_smart_client(vendor)
        smart_request = client._SmartClientRequest(smart_client, 'hello', (),
            body_stream=['a', 'b'])
        self.assertRaises(errors.ConnectionReset, smart_request._send, 3)
        # We connect, and manage to get to the point that we start consuming
        # the body stream. The next write fails, so we just stop.
        self.assertEqual(
            [self._CONNECT_SSH_CALL,
             ('close',),
            ],
            vendor.calls)
        self.assertEqual('bzr message 3 (bzr 1.6)\n' # protocol
                         '\x00\x00\x00\x02de'   # empty headers
                         's\x00\x00\x00\tl5:helloe',
                         output.getvalue())

    def test__send_disabled_retry(self):
        debug.debug_flags.add('noretry')
        output, vendor, smart_client = self.make_client_with_failing_medium()
        smart_request = client._SmartClientRequest(smart_client, 'hello', ())
        self.assertRaises(errors.ConnectionReset, smart_request._send, 3)
        self.assertEqual(
            [self._CONNECT_SSH_CALL,
             ('close',),
            ],
            vendor.calls)
 
3938
 
 
3939
 
3222
3940
class LengthPrefixedBodyDecoder(tests.TestCase):
3223
3941
 
3224
3942
    # XXX: TODO: make accept_reading_trailer invoke translate_response or
3558
4276
        # still work correctly.
3559
4277
        base_transport = remote.RemoteHTTPTransport('bzr+http://host/%7Ea/b')
3560
4278
        new_transport = base_transport.clone('c')
3561
 
        self.assertEqual('bzr+http://host/%7Ea/b/c/', new_transport.base)
 
4279
        self.assertEqual(base_transport.base + 'c/', new_transport.base)
3562
4280
        self.assertEqual(
3563
4281
            'c/',
3564
4282
            new_transport._client.remote_path_from_transport(new_transport))
3567
4285
        t = remote.RemoteHTTPTransport('bzr+http://www.example.com/foo')
3568
4286
        r = t._redirected_to('http://www.example.com/foo',
3569
4287
                             'http://www.example.com/bar')
3570
 
        self.assertEquals(type(r), type(t))
 
4288
        self.assertEqual(type(r), type(t))
3571
4289
 
3572
4290
    def test__redirect_sibling_protocol(self):
3573
4291
        t = remote.RemoteHTTPTransport('bzr+http://www.example.com/foo')
3574
4292
        r = t._redirected_to('http://www.example.com/foo',
3575
4293
                             'https://www.example.com/bar')
3576
 
        self.assertEquals(type(r), type(t))
 
4294
        self.assertEqual(type(r), type(t))
3577
4295
        self.assertStartsWith(r.base, 'bzr+https')
3578
4296
 
3579
4297
    def test__redirect_to_with_user(self):
3580
4298
        t = remote.RemoteHTTPTransport('bzr+http://joe@www.example.com/foo')
3581
4299
        r = t._redirected_to('http://www.example.com/foo',
3582
4300
                             'http://www.example.com/bar')
3583
 
        self.assertEquals(type(r), type(t))
3584
 
        self.assertEquals('joe', t._parsed_url.user)
3585
 
        self.assertEquals(t._parsed_url.user, r._parsed_url.user)
 
4301
        self.assertEqual(type(r), type(t))
 
4302
        self.assertEqual('joe', t._parsed_url.user)
 
4303
        self.assertEqual(t._parsed_url.user, r._parsed_url.user)
3586
4304
 
3587
4305
    def test_redirected_to_same_host_different_protocol(self):
3588
4306
        t = remote.RemoteHTTPTransport('bzr+http://joe@www.example.com/foo')
3589
4307
        r = t._redirected_to('http://www.example.com/foo',
3590
4308
                             'ftp://www.example.com/foo')
3591
 
        self.assertNotEquals(type(r), type(t))
 
4309
        self.assertNotEqual(type(r), type(t))
3592
4310
 
3593
4311