~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/tests/test_smart_transport.py

  • Committer: Patch Queue Manager
  • Date: 2016-04-21 04:10:52 UTC
  • mfrom: (6616.1.1 fix-en-user-guide)
  • Revision ID: pqm@pqm.ubuntu.com-20160421041052-clcye7ns1qcl2n7w
(richard-wilbur) Ensure build of English user guide always uses English text
 even when user's locale specifies a different language. (Jelmer Vernooij)

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2006-2010 Canonical Ltd
 
1
# Copyright (C) 2006-2016 Canonical Ltd
2
2
#
3
3
# This program is free software; you can redistribute it and/or modify
4
4
# it under the terms of the GNU General Public License as published by
18
18
 
19
19
# all of this deals with byte strings so this is safe
20
20
from cStringIO import StringIO
 
21
import doctest
 
22
import errno
21
23
import os
22
24
import socket
 
25
import subprocess
 
26
import sys
23
27
import threading
 
28
import time
 
29
 
 
30
from testtools.matchers import DocTestMatches
24
31
 
25
32
import bzrlib
26
33
from bzrlib import (
27
34
        bzrdir,
 
35
        controldir,
 
36
        debug,
28
37
        errors,
29
38
        osutils,
30
39
        tests,
31
 
        transport,
 
40
        transport as _mod_transport,
32
41
        urlutils,
33
42
        )
34
43
from bzrlib.smart import (
37
46
        message,
38
47
        protocol,
39
48
        request as _mod_request,
40
 
        server,
 
49
        server as _mod_server,
41
50
        vfs,
42
51
)
43
52
from bzrlib.tests import (
 
53
    features,
44
54
    test_smart,
45
55
    test_server,
46
56
    )
53
63
        )
54
64
 
55
65
 
 
66
def create_file_pipes():
 
67
    r, w = os.pipe()
 
68
    # These must be opened without buffering, or we get undefined results
 
69
    rf = os.fdopen(r, 'rb', 0)
 
70
    wf = os.fdopen(w, 'wb', 0)
 
71
    return rf, wf
 
72
 
 
73
 
 
74
def portable_socket_pair():
 
75
    """Return a pair of TCP sockets connected to each other.
 
76
 
 
77
    Unlike socket.socketpair, this should work on Windows.
 
78
    """
 
79
    listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 
80
    listen_sock.bind(('127.0.0.1', 0))
 
81
    listen_sock.listen(1)
 
82
    client_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 
83
    client_sock.connect(listen_sock.getsockname())
 
84
    server_sock, addr = listen_sock.accept()
 
85
    listen_sock.close()
 
86
    return server_sock, client_sock
 
87
 
 
88
 
56
89
class StringIOSSHVendor(object):
57
90
    """A SSH vendor that uses StringIO to buffer writes and answer reads."""
58
91
 
67
100
        return StringIOSSHConnection(self)
68
101
 
69
102
 
 
103
class FirstRejectedStringIOSSHVendor(StringIOSSHVendor):
 
104
    """The first connection will be considered closed.
 
105
 
 
106
    The second connection will succeed normally.
 
107
    """
 
108
 
 
109
    def __init__(self, read_from, write_to, fail_at_write=True):
 
110
        super(FirstRejectedStringIOSSHVendor, self).__init__(read_from,
 
111
            write_to)
 
112
        self.fail_at_write = fail_at_write
 
113
        self._first = True
 
114
 
 
115
    def connect_ssh(self, username, password, host, port, command):
 
116
        self.calls.append(('connect_ssh', username, password, host, port,
 
117
            command))
 
118
        if self._first:
 
119
            self._first = False
 
120
            return ClosedSSHConnection(self)
 
121
        return StringIOSSHConnection(self)
 
122
 
 
123
 
70
124
class StringIOSSHConnection(ssh.SSHConnection):
71
125
    """A SSH connection that uses StringIO to buffer writes and answer reads."""
72
126
 
82
136
        return 'pipes', (self.vendor.read_from, self.vendor.write_to)
83
137
 
84
138
 
85
 
class _InvalidHostnameFeature(tests.Feature):
 
139
class ClosedSSHConnection(ssh.SSHConnection):
 
140
    """An SSH connection that just has closed channels."""
 
141
 
 
142
    def __init__(self, vendor):
 
143
        self.vendor = vendor
 
144
 
 
145
    def close(self):
 
146
        self.vendor.calls.append(('close', ))
 
147
 
 
148
    def get_sock_or_pipes(self):
 
149
        # We create matching pipes, and then close the ssh side
 
150
        bzr_read, ssh_write = create_file_pipes()
 
151
        # We always fail when bzr goes to read
 
152
        ssh_write.close()
 
153
        if self.vendor.fail_at_write:
 
154
            # If set, we'll also fail when bzr goes to write
 
155
            ssh_read, bzr_write = create_file_pipes()
 
156
            ssh_read.close()
 
157
        else:
 
158
            bzr_write = self.vendor.write_to
 
159
        return 'pipes', (bzr_read, bzr_write)
 
160
 
 
161
 
 
162
class _InvalidHostnameFeature(features.Feature):
86
163
    """Does 'non_existent.invalid' fail to resolve?
87
164
 
88
165
    RFC 2606 states that .invalid is reserved for invalid domain names, and
177
254
        client_medium._accept_bytes('abc')
178
255
        self.assertEqual('abc', output.getvalue())
179
256
 
 
257
    def test_simple_pipes__accept_bytes_subprocess_closed(self):
 
258
        # It is unfortunate that we have to use Popen for this. However,
 
259
        # os.pipe() does not behave the same as subprocess.Popen().
 
260
        # On Windows, if you use os.pipe() and close the write side,
 
261
        # read.read() hangs. On Linux, read.read() returns the empty string.
 
262
        p = subprocess.Popen([sys.executable, '-c',
 
263
            'import sys\n'
 
264
            'sys.stdout.write(sys.stdin.read(4))\n'
 
265
            'sys.stdout.close()\n'],
 
266
            stdout=subprocess.PIPE, stdin=subprocess.PIPE)
 
267
        client_medium = medium.SmartSimplePipesClientMedium(
 
268
            p.stdout, p.stdin, 'base')
 
269
        client_medium._accept_bytes('abc\n')
 
270
        self.assertEqual('abc', client_medium._read_bytes(3))
 
271
        p.wait()
 
272
        # While writing to the underlying pipe,
 
273
        #   Windows py2.6.6 we get IOError(EINVAL)
 
274
        #   Lucid py2.6.5, we get IOError(EPIPE)
 
275
        # In both cases, it should be wrapped to ConnectionReset
 
276
        self.assertRaises(errors.ConnectionReset,
 
277
                          client_medium._accept_bytes, 'more')
 
278
 
 
279
    def test_simple_pipes__accept_bytes_pipe_closed(self):
 
280
        child_read, client_write = create_file_pipes()
 
281
        client_medium = medium.SmartSimplePipesClientMedium(
 
282
            None, client_write, 'base')
 
283
        client_medium._accept_bytes('abc\n')
 
284
        self.assertEqual('abc\n', child_read.read(4))
 
285
        # While writing to the underlying pipe,
 
286
        #   Windows py2.6.6 we get IOError(EINVAL)
 
287
        #   Lucid py2.6.5, we get IOError(EPIPE)
 
288
        # In both cases, it should be wrapped to ConnectionReset
 
289
        child_read.close()
 
290
        self.assertRaises(errors.ConnectionReset,
 
291
                          client_medium._accept_bytes, 'more')
 
292
 
 
293
    def test_simple_pipes__flush_pipe_closed(self):
 
294
        child_read, client_write = create_file_pipes()
 
295
        client_medium = medium.SmartSimplePipesClientMedium(
 
296
            None, client_write, 'base')
 
297
        client_medium._accept_bytes('abc\n')
 
298
        child_read.close()
 
299
        # Even though the pipe is closed, flush on the write side seems to be a
 
300
        # no-op, rather than a failure.
 
301
        client_medium._flush()
 
302
 
 
303
    def test_simple_pipes__flush_subprocess_closed(self):
 
304
        p = subprocess.Popen([sys.executable, '-c',
 
305
            'import sys\n'
 
306
            'sys.stdout.write(sys.stdin.read(4))\n'
 
307
            'sys.stdout.close()\n'],
 
308
            stdout=subprocess.PIPE, stdin=subprocess.PIPE)
 
309
        client_medium = medium.SmartSimplePipesClientMedium(
 
310
            p.stdout, p.stdin, 'base')
 
311
        client_medium._accept_bytes('abc\n')
 
312
        p.wait()
 
313
        # Even though the child process is dead, flush seems to be a no-op.
 
314
        client_medium._flush()
 
315
 
 
316
    def test_simple_pipes__read_bytes_pipe_closed(self):
 
317
        child_read, client_write = create_file_pipes()
 
318
        client_medium = medium.SmartSimplePipesClientMedium(
 
319
            child_read, client_write, 'base')
 
320
        client_medium._accept_bytes('abc\n')
 
321
        client_write.close()
 
322
        self.assertEqual('abc\n', client_medium._read_bytes(4))
 
323
        self.assertEqual('', client_medium._read_bytes(4))
 
324
 
 
325
    def test_simple_pipes__read_bytes_subprocess_closed(self):
 
326
        p = subprocess.Popen([sys.executable, '-c',
 
327
            'import sys\n'
 
328
            'if sys.platform == "win32":\n'
 
329
            '    import msvcrt, os\n'
 
330
            '    msvcrt.setmode(sys.stdout.fileno(), os.O_BINARY)\n'
 
331
            '    msvcrt.setmode(sys.stdin.fileno(), os.O_BINARY)\n'
 
332
            'sys.stdout.write(sys.stdin.read(4))\n'
 
333
            'sys.stdout.close()\n'],
 
334
            stdout=subprocess.PIPE, stdin=subprocess.PIPE)
 
335
        client_medium = medium.SmartSimplePipesClientMedium(
 
336
            p.stdout, p.stdin, 'base')
 
337
        client_medium._accept_bytes('abc\n')
 
338
        p.wait()
 
339
        self.assertEqual('abc\n', client_medium._read_bytes(4))
 
340
        self.assertEqual('', client_medium._read_bytes(4))
 
341
 
180
342
    def test_simple_pipes_client_disconnect_does_nothing(self):
181
343
        # calling disconnect does nothing.
182
344
        input = StringIO()
338
500
            ],
339
501
            vendor.calls)
340
502
 
 
503
    def test_ssh_client_repr(self):
 
504
        client_medium = medium.SmartSSHClientMedium(
 
505
            'base', medium.SSHParams("example.com", "4242", "username"))
 
506
        self.assertEqual(
 
507
            "SmartSSHClientMedium(bzr+ssh://username@example.com:4242/)",
 
508
            repr(client_medium))
 
509
 
 
510
    def test_ssh_client_repr_no_port(self):
 
511
        client_medium = medium.SmartSSHClientMedium(
 
512
            'base', medium.SSHParams("example.com", None, "username"))
 
513
        self.assertEqual(
 
514
            "SmartSSHClientMedium(bzr+ssh://username@example.com/)",
 
515
            repr(client_medium))
 
516
 
 
517
    def test_ssh_client_repr_no_username(self):
 
518
        client_medium = medium.SmartSSHClientMedium(
 
519
            'base', medium.SSHParams("example.com", None, None))
 
520
        self.assertEqual(
 
521
            "SmartSSHClientMedium(bzr+ssh://example.com/)",
 
522
            repr(client_medium))
 
523
 
341
524
    def test_ssh_client_ignores_disconnect_when_not_connected(self):
342
525
        # Doing a disconnect on a new (and thus unconnected) SSH medium
343
526
        # does not fail.  It's ok to disconnect an unconnected medium.
564
747
        request.finished_reading()
565
748
        self.assertRaises(errors.ReadingCompleted, request.read_bytes, None)
566
749
 
 
750
    def test_reset(self):
 
751
        server_sock, client_sock = portable_socket_pair()
 
752
        # TODO: Use SmartClientAlreadyConnectedSocketMedium for the versions of
 
753
        #       bzr where it exists.
 
754
        client_medium = medium.SmartTCPClientMedium(None, None, None)
 
755
        client_medium._socket = client_sock
 
756
        client_medium._connected = True
 
757
        req = client_medium.get_request()
 
758
        self.assertRaises(errors.TooManyConcurrentRequests,
 
759
            client_medium.get_request)
 
760
        client_medium.reset()
 
761
        # The stream should be reset, marked as disconnected, though ready for
 
762
        # us to make a new request
 
763
        self.assertFalse(client_medium._connected)
 
764
        self.assertIs(None, client_medium._socket)
 
765
        try:
 
766
            self.assertEqual('', client_sock.recv(1))
 
767
        except socket.error, e:
 
768
            if e.errno not in (errno.EBADF,):
 
769
                raise
 
770
        req = client_medium.get_request()
 
771
 
567
772
 
568
773
class RemoteTransportTests(test_smart.TestCaseWithSmartMedium):
569
774
 
570
775
    def test_plausible_url(self):
571
 
        self.assert_(self.get_url().startswith('bzr://'))
 
776
        self.assertTrue(self.get_url().startswith('bzr://'))
572
777
 
573
778
    def test_probe_transport(self):
574
779
        t = self.get_transport()
615
820
 
616
821
    def setUp(self):
617
822
        super(TestSmartServerStreamMedium, self).setUp()
618
 
        self._captureVar('BZR_NO_SMART_VFS', None)
619
 
 
620
 
    def portable_socket_pair(self):
621
 
        """Return a pair of TCP sockets connected to each other.
622
 
 
623
 
        Unlike socket.socketpair, this should work on Windows.
624
 
        """
625
 
        listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
626
 
        listen_sock.bind(('127.0.0.1', 0))
627
 
        listen_sock.listen(1)
628
 
        client_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
629
 
        client_sock.connect(listen_sock.getsockname())
630
 
        server_sock, addr = listen_sock.accept()
631
 
        listen_sock.close()
632
 
        return server_sock, client_sock
 
823
        self.overrideEnv('BZR_NO_SMART_VFS', None)
 
824
 
 
825
    def create_pipe_medium(self, to_server, from_server, transport,
 
826
                           timeout=4.0):
 
827
        """Create a new SmartServerPipeStreamMedium."""
 
828
        return medium.SmartServerPipeStreamMedium(to_server, from_server,
 
829
            transport, timeout=timeout)
 
830
 
 
831
    def create_pipe_context(self, to_server_bytes, transport):
 
832
        """Create a SmartServerPipeStreamMedium.
 
833
 
 
834
        This differs from create_pipe_medium, in that we initialize the
 
835
        request that is sent to the server, and return the StringIO class that
 
836
        will hold the response.
 
837
        """
 
838
        to_server = StringIO(to_server_bytes)
 
839
        from_server = StringIO()
 
840
        m = self.create_pipe_medium(to_server, from_server, transport)
 
841
        return m, from_server
 
842
 
 
843
    def create_socket_medium(self, server_sock, transport, timeout=4.0):
 
844
        """Initialize a new medium.SmartServerSocketStreamMedium."""
 
845
        return medium.SmartServerSocketStreamMedium(server_sock, transport,
 
846
            timeout=timeout)
 
847
 
 
848
    def create_socket_context(self, transport, timeout=4.0):
 
849
        """Create a new SmartServerSocketStreamMedium with default context.
 
850
 
 
851
        This will call portable_socket_pair and pass the server side to
 
852
        create_socket_medium along with transport.
 
853
        It then returns the client_sock and the server.
 
854
        """
 
855
        server_sock, client_sock = portable_socket_pair()
 
856
        server = self.create_socket_medium(server_sock, transport,
 
857
                                           timeout=timeout)
 
858
        return server, client_sock
633
859
 
634
860
    def test_smart_query_version(self):
635
861
        """Feed a canned query version to a server"""
636
862
        # wire-to-wire, using the whole stack
637
 
        to_server = StringIO('hello\n')
638
 
        from_server = StringIO()
639
863
        transport = local.LocalTransport(urlutils.local_path_to_url('/'))
640
 
        server = medium.SmartServerPipeStreamMedium(
641
 
            to_server, from_server, transport)
 
864
        server, from_server = self.create_pipe_context('hello\n', transport)
642
865
        smart_protocol = protocol.SmartServerRequestProtocolOne(transport,
643
866
                from_server.write)
644
867
        server._serve_one_request(smart_protocol)
648
871
    def test_response_to_canned_get(self):
649
872
        transport = memory.MemoryTransport('memory:///')
650
873
        transport.put_bytes('testfile', 'contents\nof\nfile\n')
651
 
        to_server = StringIO('get\001./testfile\n')
652
 
        from_server = StringIO()
653
 
        server = medium.SmartServerPipeStreamMedium(
654
 
            to_server, from_server, transport)
 
874
        server, from_server = self.create_pipe_context('get\001./testfile\n',
 
875
            transport)
655
876
        smart_protocol = protocol.SmartServerRequestProtocolOne(transport,
656
877
                from_server.write)
657
878
        server._serve_one_request(smart_protocol)
668
889
        # VFS requests use filenames, not raw UTF-8.
669
890
        hpss_path = urlutils.escape(utf8_filename)
670
891
        transport.put_bytes(utf8_filename, 'contents\nof\nfile\n')
671
 
        to_server = StringIO('get\001' + hpss_path + '\n')
672
 
        from_server = StringIO()
673
 
        server = medium.SmartServerPipeStreamMedium(
674
 
            to_server, from_server, transport)
 
892
        server, from_server = self.create_pipe_context(
 
893
                'get\001' + hpss_path + '\n', transport)
675
894
        smart_protocol = protocol.SmartServerRequestProtocolOne(transport,
676
895
                from_server.write)
677
896
        server._serve_one_request(smart_protocol)
683
902
 
684
903
    def test_pipe_like_stream_with_bulk_data(self):
685
904
        sample_request_bytes = 'command\n9\nbulk datadone\n'
686
 
        to_server = StringIO(sample_request_bytes)
687
 
        from_server = StringIO()
688
 
        server = medium.SmartServerPipeStreamMedium(
689
 
            to_server, from_server, None)
 
905
        server, from_server = self.create_pipe_context(
 
906
            sample_request_bytes, None)
690
907
        sample_protocol = SampleRequest(expected_bytes=sample_request_bytes)
691
908
        server._serve_one_request(sample_protocol)
692
909
        self.assertEqual('', from_server.getvalue())
695
912
 
696
913
    def test_socket_stream_with_bulk_data(self):
697
914
        sample_request_bytes = 'command\n9\nbulk datadone\n'
698
 
        server_sock, client_sock = self.portable_socket_pair()
699
 
        server = medium.SmartServerSocketStreamMedium(
700
 
            server_sock, None)
 
915
        server, client_sock = self.create_socket_context(None)
701
916
        sample_protocol = SampleRequest(expected_bytes=sample_request_bytes)
702
917
        client_sock.sendall(sample_request_bytes)
703
918
        server._serve_one_request(sample_protocol)
704
 
        server_sock.close()
 
919
        server._disconnect_client()
705
920
        self.assertEqual('', client_sock.recv(1))
706
921
        self.assertEqual(sample_request_bytes, sample_protocol.accepted_bytes)
707
922
        self.assertFalse(server.finished)
708
923
 
709
924
    def test_pipe_like_stream_shutdown_detection(self):
710
 
        to_server = StringIO('')
711
 
        from_server = StringIO()
712
 
        server = medium.SmartServerPipeStreamMedium(to_server, from_server, None)
 
925
        server, _ = self.create_pipe_context('', None)
713
926
        server._serve_one_request(SampleRequest('x'))
714
927
        self.assertTrue(server.finished)
715
928
 
716
929
    def test_socket_stream_shutdown_detection(self):
717
 
        server_sock, client_sock = self.portable_socket_pair()
 
930
        server, client_sock = self.create_socket_context(None)
718
931
        client_sock.close()
719
 
        server = medium.SmartServerSocketStreamMedium(
720
 
            server_sock, None)
721
932
        server._serve_one_request(SampleRequest('x'))
722
933
        self.assertTrue(server.finished)
723
934
 
734
945
        rest_of_request_bytes = 'lo\n'
735
946
        expected_response = (
736
947
            protocol.RESPONSE_VERSION_TWO + 'success\nok\x012\n')
737
 
        server_sock, client_sock = self.portable_socket_pair()
738
 
        server = medium.SmartServerSocketStreamMedium(
739
 
            server_sock, None)
 
948
        server, client_sock = self.create_socket_context(None)
740
949
        client_sock.sendall(incomplete_request_bytes)
741
950
        server_protocol = server._build_protocol()
742
951
        client_sock.sendall(rest_of_request_bytes)
743
952
        server._serve_one_request(server_protocol)
744
 
        server_sock.close()
 
953
        server._disconnect_client()
745
954
        self.assertEqual(expected_response, osutils.recv_all(client_sock, 50),
746
955
                         "Not a version 2 response to 'hello' request.")
747
956
        self.assertEqual('', client_sock.recv(1))
766
975
        to_server_w = os.fdopen(to_server_w, 'w', 0)
767
976
        from_server_r = os.fdopen(from_server_r, 'r', 0)
768
977
        from_server = os.fdopen(from_server, 'w', 0)
769
 
        server = medium.SmartServerPipeStreamMedium(
770
 
            to_server, from_server, None)
 
978
        server = self.create_pipe_medium(to_server, from_server, None)
771
979
        # Like test_socket_stream_incomplete_request, write an incomplete
772
980
        # request (that does not end in '\n') and build a protocol from it.
773
981
        to_server_w.write(incomplete_request_bytes)
788
996
        # _serve_one_request should still process both of them as if they had
789
997
        # been received separately.
790
998
        sample_request_bytes = 'command\n'
791
 
        to_server = StringIO(sample_request_bytes * 2)
792
 
        from_server = StringIO()
793
 
        server = medium.SmartServerPipeStreamMedium(
794
 
            to_server, from_server, None)
 
999
        server, from_server = self.create_pipe_context(
 
1000
            sample_request_bytes * 2, None)
795
1001
        first_protocol = SampleRequest(expected_bytes=sample_request_bytes)
796
1002
        server._serve_one_request(first_protocol)
797
1003
        self.assertEqual(0, first_protocol.next_read_size())
810
1016
        # _serve_one_request should still process both of them as if they had
811
1017
        # been received separately.
812
1018
        sample_request_bytes = 'command\n'
813
 
        server_sock, client_sock = self.portable_socket_pair()
814
 
        server = medium.SmartServerSocketStreamMedium(
815
 
            server_sock, None)
 
1019
        server, client_sock = self.create_socket_context(None)
816
1020
        first_protocol = SampleRequest(expected_bytes=sample_request_bytes)
817
1021
        # Put two whole requests on the wire.
818
1022
        client_sock.sendall(sample_request_bytes * 2)
825
1029
        stream_still_open = server._serve_one_request(second_protocol)
826
1030
        self.assertEqual(sample_request_bytes, second_protocol.accepted_bytes)
827
1031
        self.assertFalse(server.finished)
828
 
        server_sock.close()
 
1032
        server._disconnect_client()
829
1033
        self.assertEqual('', client_sock.recv(1))
830
1034
 
831
1035
    def test_pipe_like_stream_error_handling(self):
838
1042
        def close():
839
1043
            self.closed = True
840
1044
        from_server.close = close
841
 
        server = medium.SmartServerPipeStreamMedium(
 
1045
        server = self.create_pipe_medium(
842
1046
            to_server, from_server, None)
843
1047
        fake_protocol = ErrorRaisingProtocol(Exception('boom'))
844
1048
        server._serve_one_request(fake_protocol)
847
1051
        self.assertTrue(server.finished)
848
1052
 
849
1053
    def test_socket_stream_error_handling(self):
850
 
        server_sock, client_sock = self.portable_socket_pair()
851
 
        server = medium.SmartServerSocketStreamMedium(
852
 
            server_sock, None)
 
1054
        server, client_sock = self.create_socket_context(None)
853
1055
        fake_protocol = ErrorRaisingProtocol(Exception('boom'))
854
1056
        server._serve_one_request(fake_protocol)
855
1057
        # recv should not block, because the other end of the socket has been
858
1060
        self.assertTrue(server.finished)
859
1061
 
860
1062
    def test_pipe_like_stream_keyboard_interrupt_handling(self):
861
 
        to_server = StringIO('')
862
 
        from_server = StringIO()
863
 
        server = medium.SmartServerPipeStreamMedium(
864
 
            to_server, from_server, None)
 
1063
        server, from_server = self.create_pipe_context('', None)
865
1064
        fake_protocol = ErrorRaisingProtocol(KeyboardInterrupt('boom'))
866
1065
        self.assertRaises(
867
1066
            KeyboardInterrupt, server._serve_one_request, fake_protocol)
868
1067
        self.assertEqual('', from_server.getvalue())
869
1068
 
870
1069
    def test_socket_stream_keyboard_interrupt_handling(self):
871
 
        server_sock, client_sock = self.portable_socket_pair()
872
 
        server = medium.SmartServerSocketStreamMedium(
873
 
            server_sock, None)
 
1070
        server, client_sock = self.create_socket_context(None)
874
1071
        fake_protocol = ErrorRaisingProtocol(KeyboardInterrupt('boom'))
875
1072
        self.assertRaises(
876
1073
            KeyboardInterrupt, server._serve_one_request, fake_protocol)
877
 
        server_sock.close()
 
1074
        server._disconnect_client()
878
1075
        self.assertEqual('', client_sock.recv(1))
879
1076
 
880
1077
    def build_protocol_pipe_like(self, bytes):
881
 
        to_server = StringIO(bytes)
882
 
        from_server = StringIO()
883
 
        server = medium.SmartServerPipeStreamMedium(
884
 
            to_server, from_server, None)
 
1078
        server, _ = self.create_pipe_context(bytes, None)
885
1079
        return server._build_protocol()
886
1080
 
887
1081
    def build_protocol_socket(self, bytes):
888
 
        server_sock, client_sock = self.portable_socket_pair()
889
 
        server = medium.SmartServerSocketStreamMedium(
890
 
            server_sock, None)
 
1082
        server, client_sock = self.create_socket_context(None)
891
1083
        client_sock.sendall(bytes)
892
1084
        client_sock.close()
893
1085
        return server._build_protocol()
933
1125
        server_protocol = self.build_protocol_socket('bzr request 2\n')
934
1126
        self.assertProtocolTwo(server_protocol)
935
1127
 
 
1128
    def test__build_protocol_returns_if_stopping(self):
 
1129
        # _build_protocol should notice that we are stopping, and return
 
1130
        # without waiting for bytes from the client.
 
1131
        server, client_sock = self.create_socket_context(None)
 
1132
        server._stop_gracefully()
 
1133
        self.assertIs(None, server._build_protocol())
 
1134
 
 
1135
    def test_socket_set_timeout(self):
 
1136
        server, _ = self.create_socket_context(None, timeout=1.23)
 
1137
        self.assertEqual(1.23, server._client_timeout)
 
1138
 
 
1139
    def test_pipe_set_timeout(self):
 
1140
        server = self.create_pipe_medium(None, None, None,
 
1141
            timeout=1.23)
 
1142
        self.assertEqual(1.23, server._client_timeout)
 
1143
 
 
1144
    def test_socket_wait_for_bytes_with_timeout_with_data(self):
 
1145
        server, client_sock = self.create_socket_context(None)
 
1146
        client_sock.sendall('data\n')
 
1147
        # This should not block or consume any actual content
 
1148
        self.assertFalse(server._wait_for_bytes_with_timeout(0.1))
 
1149
        data = server.read_bytes(5)
 
1150
        self.assertEqual('data\n', data)
 
1151
 
 
1152
    def test_socket_wait_for_bytes_with_timeout_no_data(self):
 
1153
        server, client_sock = self.create_socket_context(None)
 
1154
        # This should timeout quickly, reporting that there wasn't any data
 
1155
        self.assertRaises(errors.ConnectionTimeout,
 
1156
                          server._wait_for_bytes_with_timeout, 0.01)
 
1157
        client_sock.close()
 
1158
        data = server.read_bytes(1)
 
1159
        self.assertEqual('', data)
 
1160
 
 
1161
    def test_socket_wait_for_bytes_with_timeout_closed(self):
 
1162
        server, client_sock = self.create_socket_context(None)
 
1163
        # With the socket closed, this should return right away.
 
1164
        # It seems select.select() returns that you *can* read on the socket,
 
1165
        # even though it closed. Presumably as a way to tell it is closed?
 
1166
        # Testing shows that without sock.close() this times-out failing the
 
1167
        # test, but with it, it returns False immediately.
 
1168
        client_sock.close()
 
1169
        self.assertFalse(server._wait_for_bytes_with_timeout(10))
 
1170
        data = server.read_bytes(1)
 
1171
        self.assertEqual('', data)
 
1172
 
 
1173
    def test_socket_wait_for_bytes_with_shutdown(self):
 
1174
        server, client_sock = self.create_socket_context(None)
 
1175
        t = time.time()
 
1176
        # Override the _timer functionality, so that time never increments,
 
1177
        # this way, we can be sure we stopped because of the flag, and not
 
1178
        # because of a timeout, etc.
 
1179
        server._timer = lambda: t
 
1180
        server._client_poll_timeout = 0.1
 
1181
        server._stop_gracefully()
 
1182
        server._wait_for_bytes_with_timeout(1.0)
 
1183
 
 
1184
    def test_socket_serve_timeout_closes_socket(self):
 
1185
        server, client_sock = self.create_socket_context(None, timeout=0.1)
 
1186
        # This should timeout quickly, and then close the connection so that
 
1187
        # client_sock recv doesn't block.
 
1188
        server.serve()
 
1189
        self.assertEqual('', client_sock.recv(1))
 
1190
 
 
1191
    def test_pipe_wait_for_bytes_with_timeout_with_data(self):
 
1192
        # We intentionally use a real pipe here, so that we can 'select' on it.
 
1193
        # You can't select() on a StringIO
 
1194
        (r_server, w_client) = os.pipe()
 
1195
        self.addCleanup(os.close, w_client)
 
1196
        with os.fdopen(r_server, 'rb') as rf_server:
 
1197
            server = self.create_pipe_medium(
 
1198
                rf_server, None, None)
 
1199
            os.write(w_client, 'data\n')
 
1200
            # This should not block or consume any actual content
 
1201
            server._wait_for_bytes_with_timeout(0.1)
 
1202
            data = server.read_bytes(5)
 
1203
            self.assertEqual('data\n', data)
 
1204
 
 
1205
    def test_pipe_wait_for_bytes_with_timeout_no_data(self):
 
1206
        # We intentionally use a real pipe here, so that we can 'select' on it.
 
1207
        # You can't select() on a StringIO
 
1208
        (r_server, w_client) = os.pipe()
 
1209
        # We can't add an os.close cleanup here, because we need to control
 
1210
        # when the file handle gets closed ourselves.
 
1211
        with os.fdopen(r_server, 'rb') as rf_server:
 
1212
            server = self.create_pipe_medium(
 
1213
                rf_server, None, None)
 
1214
            if sys.platform == 'win32':
 
1215
                # Windows cannot select() on a pipe, so we just always return
 
1216
                server._wait_for_bytes_with_timeout(0.01)
 
1217
            else:
 
1218
                self.assertRaises(errors.ConnectionTimeout,
 
1219
                                  server._wait_for_bytes_with_timeout, 0.01)
 
1220
            os.close(w_client)
 
1221
            data = server.read_bytes(5)
 
1222
            self.assertEqual('', data)
 
1223
 
 
1224
    def test_pipe_wait_for_bytes_no_fileno(self):
 
1225
        server, _ = self.create_pipe_context('', None)
 
1226
        # Our file doesn't support polling, so we should always just return
 
1227
        # 'you have data to consume'.
 
1228
        server._wait_for_bytes_with_timeout(0.01)
 
1229
 
936
1230
 
937
1231
class TestGetProtocolFactoryForBytes(tests.TestCase):
938
1232
    """_get_protocol_factory_for_bytes identifies the protocol factory a server
968
1262
 
969
1263
class TestSmartTCPServer(tests.TestCase):
970
1264
 
 
1265
    def make_server(self):
 
1266
        """Create a SmartTCPServer that we can exercise.
 
1267
 
 
1268
        Note: we don't use SmartTCPServer_for_testing because the testing
 
1269
        version overrides lots of functionality like 'serve', and we want to
 
1270
        test the raw service.
 
1271
 
 
1272
        This will start the server in another thread, and wait for it to
 
1273
        indicate it has finished starting up.
 
1274
 
 
1275
        :return: (server, server_thread)
 
1276
        """
 
1277
        t = _mod_transport.get_transport_from_url('memory:///')
 
1278
        server = _mod_server.SmartTCPServer(t, client_timeout=4.0)
 
1279
        server._ACCEPT_TIMEOUT = 0.1
 
1280
        # We don't use 'localhost' because that might be an IPv6 address.
 
1281
        server.start_server('127.0.0.1', 0)
 
1282
        server_thread = threading.Thread(target=server.serve,
 
1283
                                         args=(self.id(),))
 
1284
        server_thread.start()
 
1285
        # Ensure this gets called at some point
 
1286
        self.addCleanup(server._stop_gracefully)
 
1287
        server._started.wait()
 
1288
        return server, server_thread
 
1289
 
 
1290
    def ensure_client_disconnected(self, client_sock):
 
1291
        """Ensure that a socket is closed, discarding all errors."""
 
1292
        try:
 
1293
            client_sock.close()
 
1294
        except Exception:
 
1295
            pass
 
1296
 
 
1297
    def connect_to_server(self, server):
 
1298
        """Create a client socket that can talk to the server."""
 
1299
        client_sock = socket.socket()
 
1300
        server_info = server._server_socket.getsockname()
 
1301
        client_sock.connect(server_info)
 
1302
        self.addCleanup(self.ensure_client_disconnected, client_sock)
 
1303
        return client_sock
 
1304
 
 
1305
    def connect_to_server_and_hangup(self, server):
 
1306
        """Connect to the server, and then hang up.
 
1307
        That way it doesn't sit waiting for 'accept()' to timeout.
 
1308
        """
 
1309
        # If the server has already signaled that the socket is closed, we
 
1310
        # don't need to try to connect to it. Not being set, though, the server
 
1311
        # might still close the socket while we try to connect to it. So we
 
1312
        # still have to catch the exception.
 
1313
        if server._stopped.isSet():
 
1314
            return
 
1315
        try:
 
1316
            client_sock = self.connect_to_server(server)
 
1317
            client_sock.close()
 
1318
        except socket.error, e:
 
1319
            # If the server has hung up already, that is fine.
 
1320
            pass
 
1321
 
 
1322
    def say_hello(self, client_sock):
 
1323
        """Send the 'hello' smart RPC, and expect the response."""
 
1324
        client_sock.send('hello\n')
 
1325
        self.assertEqual('ok\x012\n', client_sock.recv(5))
 
1326
 
 
1327
    def shutdown_server_cleanly(self, server, server_thread):
 
1328
        server._stop_gracefully()
 
1329
        self.connect_to_server_and_hangup(server)
 
1330
        server._stopped.wait()
 
1331
        server._fully_stopped.wait()
 
1332
        server_thread.join()
 
1333
 
971
1334
    def test_get_error_unexpected(self):
972
1335
        """Error reported by server with no specific representation"""
973
 
        self._captureVar('BZR_NO_SMART_VFS', None)
 
1336
        self.overrideEnv('BZR_NO_SMART_VFS', None)
974
1337
        class FlakyTransport(object):
975
1338
            base = 'a_url'
976
1339
            def external_url(self):
991
1354
                                t.get, 'something')
992
1355
        self.assertContainsRe(str(err), 'some random exception')
993
1356
 
 
1357
    def test_propagates_timeout(self):
 
1358
        server = _mod_server.SmartTCPServer(None, client_timeout=1.23)
 
1359
        server_sock, client_sock = portable_socket_pair()
 
1360
        handler = server._make_handler(server_sock)
 
1361
        self.assertEqual(1.23, handler._client_timeout)
 
1362
 
 
1363
    def test_serve_conn_tracks_connections(self):
 
1364
        server = _mod_server.SmartTCPServer(None, client_timeout=4.0)
 
1365
        server_sock, client_sock = portable_socket_pair()
 
1366
        server.serve_conn(server_sock, '-%s' % (self.id(),))
 
1367
        self.assertEqual(1, len(server._active_connections))
 
1368
        # We still want to talk on the connection. Polling should indicate it
 
1369
        # is still active.
 
1370
        server._poll_active_connections()
 
1371
        self.assertEqual(1, len(server._active_connections))
 
1372
        # Closing the socket will end the active thread, and polling will
 
1373
        # notice and remove it from the active set.
 
1374
        client_sock.close()
 
1375
        server._poll_active_connections(0.1)
 
1376
        self.assertEqual(0, len(server._active_connections))
 
1377
 
 
1378
    def test_serve_closes_out_finished_connections(self):
 
1379
        server, server_thread = self.make_server()
 
1380
        # The server is started, connect to it.
 
1381
        client_sock = self.connect_to_server(server)
 
1382
        # We send and receive on the connection, so that we know the
 
1383
        # server-side has seen the connect, and started handling the
 
1384
        # results.
 
1385
        self.say_hello(client_sock)
 
1386
        self.assertEqual(1, len(server._active_connections))
 
1387
        # Grab a handle to the thread that is processing our request
 
1388
        _, server_side_thread = server._active_connections[0]
 
1389
        # Close the connection, ask the server to stop, and wait for the
 
1390
        # server to stop, as well as the thread that was servicing the
 
1391
        # client request.
 
1392
        client_sock.close()
 
1393
        # Wait for the server-side request thread to notice we are closed.
 
1394
        server_side_thread.join()
 
1395
        # Stop the server, it should notice the connection has finished.
 
1396
        self.shutdown_server_cleanly(server, server_thread)
 
1397
        # The server should have noticed that all clients are gone before
 
1398
        # exiting.
 
1399
        self.assertEqual(0, len(server._active_connections))
 
1400
 
 
1401
    def test_serve_reaps_finished_connections(self):
 
1402
        server, server_thread = self.make_server()
 
1403
        client_sock1 = self.connect_to_server(server)
 
1404
        # We send and receive on the connection, so that we know the
 
1405
        # server-side has seen the connect, and started handling the
 
1406
        # results.
 
1407
        self.say_hello(client_sock1)
 
1408
        server_handler1, server_side_thread1 = server._active_connections[0]
 
1409
        client_sock1.close()
 
1410
        server_side_thread1.join()
 
1411
        # By waiting until the first connection is fully done, the server
 
1412
        # should notice after another connection that the first has finished.
 
1413
        client_sock2 = self.connect_to_server(server)
 
1414
        self.say_hello(client_sock2)
 
1415
        server_handler2, server_side_thread2 = server._active_connections[-1]
 
1416
        # There is a race condition. We know that client_sock2 has been
 
1417
        # registered, but not that _poll_active_connections has been called. We
 
1418
        # know that it will be called before the server will accept a new
 
1419
        # connection, however. So connect one more time, and assert that we
 
1420
        # either have 1 or 2 active connections (never 3), and that the 'first'
 
1421
        # connection is not connection 1
 
1422
        client_sock3 = self.connect_to_server(server)
 
1423
        self.say_hello(client_sock3)
 
1424
        # Copy the list, so we don't have it mutating behind our back
 
1425
        conns = list(server._active_connections)
 
1426
        self.assertEqual(2, len(conns))
 
1427
        self.assertNotEqual((server_handler1, server_side_thread1), conns[0])
 
1428
        self.assertEqual((server_handler2, server_side_thread2), conns[0])
 
1429
        client_sock2.close()
 
1430
        client_sock3.close()
 
1431
        self.shutdown_server_cleanly(server, server_thread)
 
1432
 
 
1433
    def test_graceful_shutdown_waits_for_clients_to_stop(self):
 
1434
        server, server_thread = self.make_server()
 
1435
        # We need something big enough that it won't fit in a single recv. So
 
1436
        # the server thread gets blocked writing content to the client until we
 
1437
        # finish reading on the client.
 
1438
        server.backing_transport.put_bytes('bigfile',
 
1439
            'a'*1024*1024)
 
1440
        client_sock = self.connect_to_server(server)
 
1441
        self.say_hello(client_sock)
 
1442
        _, server_side_thread = server._active_connections[0]
 
1443
        # Start the RPC, but don't finish reading the response
 
1444
        client_medium = medium.SmartClientAlreadyConnectedSocketMedium(
 
1445
            'base', client_sock)
 
1446
        client_client = client._SmartClient(client_medium)
 
1447
        resp, response_handler = client_client.call_expecting_body('get',
 
1448
            'bigfile')
 
1449
        self.assertEqual(('ok',), resp)
 
1450
        # Ask the server to stop gracefully, and wait for it.
 
1451
        server._stop_gracefully()
 
1452
        self.connect_to_server_and_hangup(server)
 
1453
        server._stopped.wait()
 
1454
        # It should not be accepting another connection.
 
1455
        self.assertRaises(socket.error, self.connect_to_server, server)
 
1456
        response_handler.read_body_bytes()
 
1457
        client_sock.close()
 
1458
        server_side_thread.join()
 
1459
        server_thread.join()
 
1460
        self.assertTrue(server._fully_stopped.isSet())
 
1461
        log = self.get_log()
 
1462
        self.assertThat(log, DocTestMatches("""\
 
1463
    INFO  Requested to stop gracefully
 
1464
... Stopping SmartServerSocketStreamMedium(client=('127.0.0.1', ...
 
1465
    INFO  Waiting for 1 client(s) to finish
 
1466
""", flags=doctest.ELLIPSIS|doctest.REPORT_UDIFF))
 
1467
 
 
1468
    def test_stop_gracefully_tells_handlers_to_stop(self):
 
1469
        server, server_thread = self.make_server()
 
1470
        client_sock = self.connect_to_server(server)
 
1471
        self.say_hello(client_sock)
 
1472
        server_handler, server_side_thread = server._active_connections[0]
 
1473
        self.assertFalse(server_handler.finished)
 
1474
        server._stop_gracefully()
 
1475
        self.assertTrue(server_handler.finished)
 
1476
        client_sock.close()
 
1477
        self.connect_to_server_and_hangup(server)
 
1478
        server_thread.join()
 
1479
 
994
1480
 
995
1481
class SmartTCPTests(tests.TestCase):
996
1482
    """Tests for connection/end to end behaviour using the TCP server.
1014
1500
            mem_server.start_server()
1015
1501
            self.addCleanup(mem_server.stop_server)
1016
1502
            self.permit_url(mem_server.get_url())
1017
 
            self.backing_transport = transport.get_transport(
 
1503
            self.backing_transport = _mod_transport.get_transport_from_url(
1018
1504
                mem_server.get_url())
1019
1505
        else:
1020
1506
            self.backing_transport = backing_transport
1021
1507
        if readonly:
1022
1508
            self.real_backing_transport = self.backing_transport
1023
 
            self.backing_transport = transport.get_transport(
 
1509
            self.backing_transport = _mod_transport.get_transport_from_url(
1024
1510
                "readonly+" + self.backing_transport.abspath('.'))
1025
 
        self.server = server.SmartTCPServer(self.backing_transport)
 
1511
        self.server = _mod_server.SmartTCPServer(self.backing_transport,
 
1512
                                                 client_timeout=4.0)
1026
1513
        self.server.start_server('127.0.0.1', 0)
1027
1514
        self.server.start_background_thread('-' + self.id())
1028
1515
        self.transport = remote.RemoteTCPTransport(self.server.get_url())
1077
1564
 
1078
1565
    def test_smart_transport_has(self):
1079
1566
        """Checking for file existence over smart."""
1080
 
        self._captureVar('BZR_NO_SMART_VFS', None)
 
1567
        self.overrideEnv('BZR_NO_SMART_VFS', None)
1081
1568
        self.backing_transport.put_bytes("foo", "contents of foo\n")
1082
1569
        self.assertTrue(self.transport.has("foo"))
1083
1570
        self.assertFalse(self.transport.has("non-foo"))
1084
1571
 
1085
1572
    def test_smart_transport_get(self):
1086
1573
        """Read back a file over smart."""
1087
 
        self._captureVar('BZR_NO_SMART_VFS', None)
 
1574
        self.overrideEnv('BZR_NO_SMART_VFS', None)
1088
1575
        self.backing_transport.put_bytes("foo", "contents\nof\nfoo\n")
1089
1576
        fp = self.transport.get("foo")
1090
1577
        self.assertEqual('contents\nof\nfoo\n', fp.read())
1094
1581
        # The path in a raised NoSuchFile exception should be the precise path
1095
1582
        # asked for by the client. This gives meaningful and unsurprising errors
1096
1583
        # for users.
1097
 
        self._captureVar('BZR_NO_SMART_VFS', None)
 
1584
        self.overrideEnv('BZR_NO_SMART_VFS', None)
1098
1585
        err = self.assertRaises(
1099
1586
            errors.NoSuchFile, self.transport.get, 'not%20a%20file')
1100
1587
        self.assertSubset([err.path], ['not%20a%20file', './not%20a%20file'])
1108
1595
                      conn2.get_smart_medium())
1109
1596
 
1110
1597
    def test__remote_path(self):
1111
 
        self.assertEquals('/foo/bar',
 
1598
        self.assertEqual('/foo/bar',
1112
1599
                          self.transport._remote_path('foo/bar'))
1113
1600
 
1114
1601
    def test_clone_changes_base(self):
1115
1602
        """Cloning transport produces one with a new base location"""
1116
1603
        conn2 = self.transport.clone('subdir')
1117
 
        self.assertEquals(self.transport.base + 'subdir/',
 
1604
        self.assertEqual(self.transport.base + 'subdir/',
1118
1605
                          conn2.base)
1119
1606
 
1120
1607
    def test_open_dir(self):
1121
1608
        """Test changing directory"""
1122
 
        self._captureVar('BZR_NO_SMART_VFS', None)
 
1609
        self.overrideEnv('BZR_NO_SMART_VFS', None)
1123
1610
        transport = self.transport
1124
1611
        self.backing_transport.mkdir('toffee')
1125
1612
        self.backing_transport.mkdir('toffee/apple')
1126
 
        self.assertEquals('/toffee', transport._remote_path('toffee'))
 
1613
        self.assertEqual('/toffee', transport._remote_path('toffee'))
1127
1614
        toffee_trans = transport.clone('toffee')
1128
1615
        # Check that each transport has only the contents of its directory
1129
1616
        # directly visible. If state was being held in the wrong object, it's
1139
1626
        transport = self.transport
1140
1627
        t = self.backing_transport
1141
1628
        bzrdir.BzrDirFormat.get_default_format().initialize_on_transport(t)
1142
 
        result_dir = bzrdir.BzrDir.open_containing_from_transport(transport)
 
1629
        result_dir = controldir.ControlDir.open_containing_from_transport(
 
1630
            transport)
1143
1631
 
1144
1632
 
1145
1633
class ReadOnlyEndToEndTests(SmartTCPTests):
1147
1635
 
1148
1636
    def test_mkdir_error_readonly(self):
1149
1637
        """TransportNotPossible should be preserved from the backing transport."""
1150
 
        self._captureVar('BZR_NO_SMART_VFS', None)
 
1638
        self.overrideEnv('BZR_NO_SMART_VFS', None)
1151
1639
        self.start_server(readonly=True)
1152
1640
        self.assertRaises(errors.TransportNotPossible, self.transport.mkdir,
1153
1641
            'foo')
1154
1642
 
 
1643
    def test_rename_error_readonly(self):
 
1644
        """TransportNotPossible should be preserved from the backing transport."""
 
1645
        self.overrideEnv('BZR_NO_SMART_VFS', None)
 
1646
        self.start_server(readonly=True)
 
1647
        self.assertRaises(errors.TransportNotPossible, self.transport.rename,
 
1648
                          'foo', 'bar')
 
1649
 
 
1650
    def test_open_write_stream_error_readonly(self):
 
1651
        """TransportNotPossible should be preserved from the backing transport."""
 
1652
        self.overrideEnv('BZR_NO_SMART_VFS', None)
 
1653
        self.start_server(readonly=True)
 
1654
        self.assertRaises(
 
1655
            errors.TransportNotPossible, self.transport.open_write_stream,
 
1656
            'foo')
 
1657
 
1155
1658
 
1156
1659
class TestServerHooks(SmartTCPTests):
1157
1660
 
1162
1665
    def test_server_started_hook_memory(self):
1163
1666
        """The server_started hook fires when the server is started."""
1164
1667
        self.hook_calls = []
1165
 
        server.SmartTCPServer.hooks.install_named_hook('server_started',
 
1668
        _mod_server.SmartTCPServer.hooks.install_named_hook('server_started',
1166
1669
            self.capture_server_call, None)
1167
1670
        self.start_server()
1168
1671
        # at this point, the server will be starting a thread up.
1176
1679
    def test_server_started_hook_file(self):
1177
1680
        """The server_started hook fires when the server is started."""
1178
1681
        self.hook_calls = []
1179
 
        server.SmartTCPServer.hooks.install_named_hook('server_started',
 
1682
        _mod_server.SmartTCPServer.hooks.install_named_hook('server_started',
1180
1683
            self.capture_server_call, None)
1181
 
        self.start_server(backing_transport=transport.get_transport("."))
 
1684
        self.start_server(
 
1685
            backing_transport=_mod_transport.get_transport_from_path("."))
1182
1686
        # at this point, the server will be starting a thread up.
1183
1687
        # there is no indicator at the moment, so bodge it by doing a request.
1184
1688
        self.transport.has('.')
1192
1696
    def test_server_stopped_hook_simple_memory(self):
1193
1697
        """The server_stopped hook fires when the server is stopped."""
1194
1698
        self.hook_calls = []
1195
 
        server.SmartTCPServer.hooks.install_named_hook('server_stopped',
 
1699
        _mod_server.SmartTCPServer.hooks.install_named_hook('server_stopped',
1196
1700
            self.capture_server_call, None)
1197
1701
        self.start_server()
1198
1702
        result = [([self.backing_transport.base], self.transport.base)]
1209
1713
    def test_server_stopped_hook_simple_file(self):
1210
1714
        """The server_stopped hook fires when the server is stopped."""
1211
1715
        self.hook_calls = []
1212
 
        server.SmartTCPServer.hooks.install_named_hook('server_stopped',
 
1716
        _mod_server.SmartTCPServer.hooks.install_named_hook('server_stopped',
1213
1717
            self.capture_server_call, None)
1214
 
        self.start_server(backing_transport=transport.get_transport("."))
 
1718
        self.start_server(
 
1719
            backing_transport=_mod_transport.get_transport_from_path("."))
1215
1720
        result = [(
1216
1721
            [self.backing_transport.base, self.backing_transport.external_url()]
1217
1722
            , self.transport.base)]
1261
1766
 
1262
1767
    def setUp(self):
1263
1768
        super(SmartServerRequestHandlerTests, self).setUp()
1264
 
        self._captureVar('BZR_NO_SMART_VFS', None)
 
1769
        self.overrideEnv('BZR_NO_SMART_VFS', None)
1265
1770
 
1266
1771
    def build_handler(self, transport):
1267
1772
        """Returns a handler for the commands in protocol version one."""
1286
1791
        handler = vfs.HasRequest(None, '/')
1287
1792
        # set environment variable after construction to make sure it's
1288
1793
        # examined.
1289
 
        # Note that we can safely clobber BZR_NO_SMART_VFS here, because setUp
1290
 
        # has called _captureVar, so it will be restored to the right state
1291
 
        # afterwards.
1292
 
        os.environ['BZR_NO_SMART_VFS'] = ''
 
1794
        self.overrideEnv('BZR_NO_SMART_VFS', '')
1293
1795
        self.assertRaises(errors.DisabledMethod, handler.execute)
1294
1796
 
1295
1797
    def test_readonly_exception_becomes_transport_not_possible(self):
1356
1858
class RemoteTransportRegistration(tests.TestCase):
1357
1859
 
1358
1860
    def test_registration(self):
1359
 
        t = transport.get_transport('bzr+ssh://example.com/path')
 
1861
        t = _mod_transport.get_transport_from_url('bzr+ssh://example.com/path')
1360
1862
        self.assertIsInstance(t, remote.RemoteSSHTransport)
1361
 
        self.assertEqual('example.com', t._host)
 
1863
        self.assertEqual('example.com', t._parsed_url.host)
1362
1864
 
1363
1865
    def test_bzr_https(self):
1364
1866
        # https://bugs.launchpad.net/bzr/+bug/128456
1365
 
        t = transport.get_transport('bzr+https://example.com/path')
 
1867
        t = _mod_transport.get_transport_from_url('bzr+https://example.com/path')
1366
1868
        self.assertIsInstance(t, remote.RemoteHTTPTransport)
1367
1869
        self.assertStartsWith(
1368
1870
            t._http_transport.base,
1496
1998
        smart_protocol._has_dispatched = True
1497
1999
        smart_protocol.request = _mod_request.SmartServerRequestHandler(
1498
2000
            None, _mod_request.request_handlers, '/')
 
2001
        # GZ 2010-08-10: Cycle with closure affects 4 tests
1499
2002
        class FakeCommand(_mod_request.SmartServerRequest):
1500
2003
            def do_body(self_cmd, body_bytes):
1501
2004
                self.end_received = True
1602
2105
        self.assertTrue(self.end_received)
1603
2106
 
1604
2107
    def test_accept_request_and_body_all_at_once(self):
1605
 
        self._captureVar('BZR_NO_SMART_VFS', None)
 
2108
        self.overrideEnv('BZR_NO_SMART_VFS', None)
1606
2109
        mem_transport = memory.MemoryTransport()
1607
2110
        mem_transport.put_bytes('foo', 'abcdefghij')
1608
2111
        out_stream = StringIO()
1868
2371
        self.assertTrue(self.end_received)
1869
2372
 
1870
2373
    def test_accept_request_and_body_all_at_once(self):
1871
 
        self._captureVar('BZR_NO_SMART_VFS', None)
 
2374
        self.overrideEnv('BZR_NO_SMART_VFS', None)
1872
2375
        mem_transport = memory.MemoryTransport()
1873
2376
        mem_transport.put_bytes('foo', 'abcdefghij')
1874
2377
        out_stream = StringIO()
2415
2918
        self.assertEqual('aaa', stream.next())
2416
2919
        self.assertEqual('bbb', stream.next())
2417
2920
        exc = self.assertRaises(errors.ErrorFromSmartServer, stream.next)
2418
 
        self.assertEqual(('error', 'Boom!'), exc.error_tuple)
 
2921
        self.assertEqual(('error', 'Exception', 'Boom!'), exc.error_tuple)
2419
2922
 
2420
2923
    def test_interrupted_by_connection_lost(self):
2421
2924
        interrupted_body_stream = (
2544
3047
        from_server = StringIO()
2545
3048
        transport = memory.MemoryTransport('memory:///')
2546
3049
        server = medium.SmartServerPipeStreamMedium(
2547
 
            to_server, from_server, transport)
 
3050
            to_server, from_server, transport, timeout=4.0)
2548
3051
        proto = server._build_protocol()
2549
3052
        message_handler = proto.message_handler
2550
3053
        server._serve_one_request(proto)
2696
3199
        requester, output = self.make_client_encoder_and_output()
2697
3200
        requester.set_headers({'header name': 'header value'})
2698
3201
        requester.call('one arg')
2699
 
        self.assertEquals(
 
3202
        self.assertEqual(
2700
3203
            'bzr message 3 (bzr 1.6)\n' # protocol version
2701
3204
            '\x00\x00\x00\x1fd11:header name12:header valuee' # headers
2702
3205
            's\x00\x00\x00\x0bl7:one arge' # args
2712
3215
        requester, output = self.make_client_encoder_and_output()
2713
3216
        requester.set_headers({'header name': 'header value'})
2714
3217
        requester.call_with_body_bytes(('one arg',), 'body bytes')
2715
 
        self.assertEquals(
 
3218
        self.assertEqual(
2716
3219
            'bzr message 3 (bzr 1.6)\n' # protocol version
2717
3220
            '\x00\x00\x00\x1fd11:header name12:header valuee' # headers
2718
3221
            's\x00\x00\x00\x0bl7:one arge' # args
2747
3250
        requester.set_headers({'header name': 'header value'})
2748
3251
        stream = ['chunk 1', 'chunk two']
2749
3252
        requester.call_with_body_stream(('one arg',), stream)
2750
 
        self.assertEquals(
 
3253
        self.assertEqual(
2751
3254
            'bzr message 3 (bzr 1.6)\n' # protocol version
2752
3255
            '\x00\x00\x00\x1fd11:header name12:header valuee' # headers
2753
3256
            's\x00\x00\x00\x0bl7:one arge' # args
2762
3265
        requester.set_headers({})
2763
3266
        stream = []
2764
3267
        requester.call_with_body_stream(('one arg',), stream)
2765
 
        self.assertEquals(
 
3268
        self.assertEqual(
2766
3269
            'bzr message 3 (bzr 1.6)\n' # protocol version
2767
3270
            '\x00\x00\x00\x02de' # headers
2768
3271
            's\x00\x00\x00\x0bl7:one arge' # args
2784
3287
            raise Exception('Boom!')
2785
3288
        self.assertRaises(Exception, requester.call_with_body_stream,
2786
3289
            ('one arg',), stream_that_fails())
2787
 
        self.assertEquals(
 
3290
        self.assertEqual(
2788
3291
            'bzr message 3 (bzr 1.6)\n' # protocol version
2789
3292
            '\x00\x00\x00\x02de' # headers
2790
3293
            's\x00\x00\x00\x0bl7:one arge' # args
2795
3298
            'e', # end
2796
3299
            output.getvalue())
2797
3300
 
 
3301
    def test_records_start_of_body_stream(self):
 
3302
        requester, output = self.make_client_encoder_and_output()
 
3303
        requester.set_headers({})
 
3304
        in_stream = [False]
 
3305
        def stream_checker():
 
3306
            self.assertTrue(requester.body_stream_started)
 
3307
            in_stream[0] = True
 
3308
            yield 'content'
 
3309
        flush_called = []
 
3310
        orig_flush = requester.flush
 
3311
        def tracked_flush():
 
3312
            flush_called.append(in_stream[0])
 
3313
            if in_stream[0]:
 
3314
                self.assertTrue(requester.body_stream_started)
 
3315
            else:
 
3316
                self.assertFalse(requester.body_stream_started)
 
3317
            return orig_flush()
 
3318
        requester.flush = tracked_flush
 
3319
        requester.call_with_body_stream(('one arg',), stream_checker())
 
3320
        self.assertEqual(
 
3321
            'bzr message 3 (bzr 1.6)\n' # protocol version
 
3322
            '\x00\x00\x00\x02de' # headers
 
3323
            's\x00\x00\x00\x0bl7:one arge' # args
 
3324
            'b\x00\x00\x00\x07content' # body
 
3325
            'e', output.getvalue())
 
3326
        self.assertEqual([False, True, True], flush_called)
 
3327
 
2798
3328
 
2799
3329
class StubMediumRequest(object):
2800
3330
    """A stub medium request that tracks the number of times accept_bytes is
2818
3348
    'b\x00\x00\x00\x03aaa' # body part ('aaa')
2819
3349
    'b\x00\x00\x00\x03bbb' # body part ('bbb')
2820
3350
    'oE' # status flag (error)
2821
 
    's\x00\x00\x00\x10l5:error5:Boom!e' # err struct ('error', 'Boom!')
 
3351
    # err struct ('error', 'Exception', 'Boom!')
 
3352
    's\x00\x00\x00\x1bl5:error9:Exception5:Boom!e'
2822
3353
    'e' # EOM
2823
3354
    )
2824
3355
 
2869
3400
    """
2870
3401
 
2871
3402
    def setUp(self):
2872
 
        tests.TestCase.setUp(self)
 
3403
        super(TestResponseEncoderBufferingProtocolThree, self).setUp()
2873
3404
        self.writes = []
2874
3405
        self.responder = protocol.ProtocolThreeResponder(self.writes.append)
2875
3406
 
3219
3750
        # encoder.
3220
3751
 
3221
3752
 
 
3753
class Test_SmartClientRequest(tests.TestCase):
 
3754
 
 
3755
    def make_client_with_failing_medium(self, fail_at_write=True, response=''):
 
3756
        response_io = StringIO(response)
 
3757
        output = StringIO()
 
3758
        vendor = FirstRejectedStringIOSSHVendor(response_io, output,
 
3759
                    fail_at_write=fail_at_write)
 
3760
        ssh_params = medium.SSHParams('a host', 'a port', 'a user', 'a pass')
 
3761
        client_medium = medium.SmartSSHClientMedium('base', ssh_params, vendor)
 
3762
        smart_client = client._SmartClient(client_medium, headers={})
 
3763
        return output, vendor, smart_client
 
3764
 
 
3765
    def make_response(self, args, body=None, body_stream=None):
 
3766
        response_io = StringIO()
 
3767
        response = _mod_request.SuccessfulSmartServerResponse(args, body=body,
 
3768
            body_stream=body_stream)
 
3769
        responder = protocol.ProtocolThreeResponder(response_io.write)
 
3770
        responder.send_response(response)
 
3771
        return response_io.getvalue()
 
3772
 
 
3773
    def test__call_doesnt_retry_append(self):
 
3774
        response = self.make_response(('appended', '8'))
 
3775
        output, vendor, smart_client = self.make_client_with_failing_medium(
 
3776
            fail_at_write=False, response=response)
 
3777
        smart_request = client._SmartClientRequest(smart_client, 'append',
 
3778
            ('foo', ''), body='content\n')
 
3779
        self.assertRaises(errors.ConnectionReset, smart_request._call, 3)
 
3780
 
 
3781
    def test__call_retries_get_bytes(self):
 
3782
        response = self.make_response(('ok',), 'content\n')
 
3783
        output, vendor, smart_client = self.make_client_with_failing_medium(
 
3784
            fail_at_write=False, response=response)
 
3785
        smart_request = client._SmartClientRequest(smart_client, 'get',
 
3786
            ('foo',))
 
3787
        response, response_handler = smart_request._call(3)
 
3788
        self.assertEqual(('ok',), response)
 
3789
        self.assertEqual('content\n', response_handler.read_body_bytes())
 
3790
 
 
3791
    def test__call_noretry_get_bytes(self):
 
3792
        debug.debug_flags.add('noretry')
 
3793
        response = self.make_response(('ok',), 'content\n')
 
3794
        output, vendor, smart_client = self.make_client_with_failing_medium(
 
3795
            fail_at_write=False, response=response)
 
3796
        smart_request = client._SmartClientRequest(smart_client, 'get',
 
3797
            ('foo',))
 
3798
        self.assertRaises(errors.ConnectionReset, smart_request._call, 3)
 
3799
 
 
3800
    def test__send_no_retry_pipes(self):
 
3801
        client_read, server_write = create_file_pipes()
 
3802
        server_read, client_write = create_file_pipes()
 
3803
        client_medium = medium.SmartSimplePipesClientMedium(client_read,
 
3804
            client_write, base='/')
 
3805
        smart_client = client._SmartClient(client_medium)
 
3806
        smart_request = client._SmartClientRequest(smart_client,
 
3807
            'hello', ())
 
3808
        # Close the server side
 
3809
        server_read.close()
 
3810
        encoder, response_handler = smart_request._construct_protocol(3)
 
3811
        self.assertRaises(errors.ConnectionReset,
 
3812
            smart_request._send_no_retry, encoder)
 
3813
 
 
3814
    def test__send_read_response_sockets(self):
 
3815
        listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 
3816
        listen_sock.bind(('127.0.0.1', 0))
 
3817
        listen_sock.listen(1)
 
3818
        host, port = listen_sock.getsockname()
 
3819
        client_medium = medium.SmartTCPClientMedium(host, port, '/')
 
3820
        client_medium._ensure_connection()
 
3821
        smart_client = client._SmartClient(client_medium)
 
3822
        smart_request = client._SmartClientRequest(smart_client, 'hello', ())
 
3823
        # Accept the connection, but don't actually talk to the client.
 
3824
        server_sock, _ = listen_sock.accept()
 
3825
        server_sock.close()
 
3826
        # Sockets buffer and don't really notice that the server has closed the
 
3827
        # connection until we try to read again.
 
3828
        handler = smart_request._send(3)
 
3829
        self.assertRaises(errors.ConnectionReset,
 
3830
            handler.read_response_tuple, expect_body=False)
 
3831
 
 
3832
    def test__send_retries_on_write(self):
 
3833
        output, vendor, smart_client = self.make_client_with_failing_medium()
 
3834
        smart_request = client._SmartClientRequest(smart_client, 'hello', ())
 
3835
        handler = smart_request._send(3)
 
3836
        self.assertEqual('bzr message 3 (bzr 1.6)\n' # protocol
 
3837
                         '\x00\x00\x00\x02de'   # empty headers
 
3838
                         's\x00\x00\x00\tl5:helloee',
 
3839
                         output.getvalue())
 
3840
        self.assertEqual(
 
3841
            [('connect_ssh', 'a user', 'a pass', 'a host', 'a port',
 
3842
              ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
 
3843
             ('close',),
 
3844
             ('connect_ssh', 'a user', 'a pass', 'a host', 'a port',
 
3845
              ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
 
3846
            ],
 
3847
            vendor.calls)
 
3848
 
 
3849
    def test__send_doesnt_retry_read_failure(self):
 
3850
        output, vendor, smart_client = self.make_client_with_failing_medium(
 
3851
            fail_at_write=False)
 
3852
        smart_request = client._SmartClientRequest(smart_client, 'hello', ())
 
3853
        handler = smart_request._send(3)
 
3854
        self.assertEqual('bzr message 3 (bzr 1.6)\n' # protocol
 
3855
                         '\x00\x00\x00\x02de'   # empty headers
 
3856
                         's\x00\x00\x00\tl5:helloee',
 
3857
                         output.getvalue())
 
3858
        self.assertEqual(
 
3859
            [('connect_ssh', 'a user', 'a pass', 'a host', 'a port',
 
3860
              ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
 
3861
            ],
 
3862
            vendor.calls)
 
3863
        self.assertRaises(errors.ConnectionReset, handler.read_response_tuple)
 
3864
 
 
3865
    def test__send_request_retries_body_stream_if_not_started(self):
 
3866
        output, vendor, smart_client = self.make_client_with_failing_medium()
 
3867
        smart_request = client._SmartClientRequest(smart_client, 'hello', (),
 
3868
            body_stream=['a', 'b'])
 
3869
        response_handler = smart_request._send(3)
 
3870
        # We connect, get disconnected, and notice before consuming the stream,
 
3871
        # so we try again one time and succeed.
 
3872
        self.assertEqual(
 
3873
            [('connect_ssh', 'a user', 'a pass', 'a host', 'a port',
 
3874
              ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
 
3875
             ('close',),
 
3876
             ('connect_ssh', 'a user', 'a pass', 'a host', 'a port',
 
3877
              ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
 
3878
            ],
 
3879
            vendor.calls)
 
3880
        self.assertEqual('bzr message 3 (bzr 1.6)\n' # protocol
 
3881
                         '\x00\x00\x00\x02de'   # empty headers
 
3882
                         's\x00\x00\x00\tl5:helloe'
 
3883
                         'b\x00\x00\x00\x01a'
 
3884
                         'b\x00\x00\x00\x01b'
 
3885
                         'e',
 
3886
                         output.getvalue())
 
3887
 
 
3888
    def test__send_request_stops_if_body_started(self):
 
3889
        # We intentionally use the python StringIO so that we can subclass it.
 
3890
        from StringIO import StringIO
 
3891
        response = StringIO()
 
3892
 
 
3893
        class FailAfterFirstWrite(StringIO):
 
3894
            """Allow one 'write' call to pass, fail the rest"""
 
3895
            def __init__(self):
 
3896
                StringIO.__init__(self)
 
3897
                self._first = True
 
3898
 
 
3899
            def write(self, s):
 
3900
                if self._first:
 
3901
                    self._first = False
 
3902
                    return StringIO.write(self, s)
 
3903
                raise IOError(errno.EINVAL, 'invalid file handle')
 
3904
        output = FailAfterFirstWrite()
 
3905
 
 
3906
        vendor = FirstRejectedStringIOSSHVendor(response, output,
 
3907
            fail_at_write=False)
 
3908
        ssh_params = medium.SSHParams('a host', 'a port', 'a user', 'a pass')
 
3909
        client_medium = medium.SmartSSHClientMedium('base', ssh_params, vendor)
 
3910
        smart_client = client._SmartClient(client_medium, headers={})
 
3911
        smart_request = client._SmartClientRequest(smart_client, 'hello', (),
 
3912
            body_stream=['a', 'b'])
 
3913
        self.assertRaises(errors.ConnectionReset, smart_request._send, 3)
 
3914
        # We connect, and manage to get to the point that we start consuming
 
3915
        # the body stream. The next write fails, so we just stop.
 
3916
        self.assertEqual(
 
3917
            [('connect_ssh', 'a user', 'a pass', 'a host', 'a port',
 
3918
              ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
 
3919
             ('close',),
 
3920
            ],
 
3921
            vendor.calls)
 
3922
        self.assertEqual('bzr message 3 (bzr 1.6)\n' # protocol
 
3923
                         '\x00\x00\x00\x02de'   # empty headers
 
3924
                         's\x00\x00\x00\tl5:helloe',
 
3925
                         output.getvalue())
 
3926
 
 
3927
    def test__send_disabled_retry(self):
 
3928
        debug.debug_flags.add('noretry')
 
3929
        output, vendor, smart_client = self.make_client_with_failing_medium()
 
3930
        smart_request = client._SmartClientRequest(smart_client, 'hello', ())
 
3931
        self.assertRaises(errors.ConnectionReset, smart_request._send, 3)
 
3932
        self.assertEqual(
 
3933
            [('connect_ssh', 'a user', 'a pass', 'a host', 'a port',
 
3934
              ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
 
3935
             ('close',),
 
3936
            ],
 
3937
            vendor.calls)
 
3938
 
 
3939
 
3222
3940
class LengthPrefixedBodyDecoder(tests.TestCase):
3223
3941
 
3224
3942
    # XXX: TODO: make accept_reading_trailer invoke translate_response or
3517
4235
    def setUp(self):
3518
4236
        super(HTTPTunnellingSmokeTest, self).setUp()
3519
4237
        # We use the VFS layer as part of HTTP tunnelling tests.
3520
 
        self._captureVar('BZR_NO_SMART_VFS', None)
 
4238
        self.overrideEnv('BZR_NO_SMART_VFS', None)
3521
4239
 
3522
4240
    def test_smart_http_medium_request_accept_bytes(self):
3523
4241
        medium = FakeHTTPMedium()
3558
4276
        # still work correctly.
3559
4277
        base_transport = remote.RemoteHTTPTransport('bzr+http://host/%7Ea/b')
3560
4278
        new_transport = base_transport.clone('c')
3561
 
        self.assertEqual('bzr+http://host/~a/b/c/', new_transport.base)
 
4279
        self.assertEqual(base_transport.base + 'c/', new_transport.base)
3562
4280
        self.assertEqual(
3563
4281
            'c/',
3564
4282
            new_transport._client.remote_path_from_transport(new_transport))
3567
4285
        t = remote.RemoteHTTPTransport('bzr+http://www.example.com/foo')
3568
4286
        r = t._redirected_to('http://www.example.com/foo',
3569
4287
                             'http://www.example.com/bar')
3570
 
        self.assertEquals(type(r), type(t))
 
4288
        self.assertEqual(type(r), type(t))
3571
4289
 
3572
4290
    def test__redirect_sibling_protocol(self):
3573
4291
        t = remote.RemoteHTTPTransport('bzr+http://www.example.com/foo')
3574
4292
        r = t._redirected_to('http://www.example.com/foo',
3575
4293
                             'https://www.example.com/bar')
3576
 
        self.assertEquals(type(r), type(t))
 
4294
        self.assertEqual(type(r), type(t))
3577
4295
        self.assertStartsWith(r.base, 'bzr+https')
3578
4296
 
3579
4297
    def test__redirect_to_with_user(self):
3580
4298
        t = remote.RemoteHTTPTransport('bzr+http://joe@www.example.com/foo')
3581
4299
        r = t._redirected_to('http://www.example.com/foo',
3582
4300
                             'http://www.example.com/bar')
3583
 
        self.assertEquals(type(r), type(t))
3584
 
        self.assertEquals('joe', t._user)
3585
 
        self.assertEquals(t._user, r._user)
 
4301
        self.assertEqual(type(r), type(t))
 
4302
        self.assertEqual('joe', t._parsed_url.user)
 
4303
        self.assertEqual(t._parsed_url.user, r._parsed_url.user)
3586
4304
 
3587
4305
    def test_redirected_to_same_host_different_protocol(self):
3588
4306
        t = remote.RemoteHTTPTransport('bzr+http://joe@www.example.com/foo')
3589
4307
        r = t._redirected_to('http://www.example.com/foo',
3590
4308
                             'ftp://www.example.com/foo')
3591
 
        self.assertNotEquals(type(r), type(t))
 
4309
        self.assertNotEqual(type(r), type(t))
3592
4310
 
3593
4311