~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/tests/test_smart_transport.py

  • Committer: Patch Queue Manager
  • Date: 2016-01-15 09:21:49 UTC
  • mfrom: (6606.2.1 autodoc-unicode)
  • Revision ID: pqm@pqm.ubuntu.com-20160115092149-z5f4sfq3jvaz0enb
(vila) Fix autodoc runner when LANG=C. (Jelmer Vernooij)

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2006-2010 Canonical Ltd
 
1
# Copyright (C) 2006-2011 Canonical Ltd
2
2
#
3
3
# This program is free software; you can redistribute it and/or modify
4
4
# it under the terms of the GNU General Public License as published by
18
18
 
19
19
# all of this deals with byte strings so this is safe
20
20
from cStringIO import StringIO
 
21
import doctest
 
22
import errno
21
23
import os
22
24
import socket
 
25
import subprocess
 
26
import sys
23
27
import threading
 
28
import time
 
29
 
 
30
from testtools.matchers import DocTestMatches
24
31
 
25
32
import bzrlib
26
33
from bzrlib import (
27
34
        bzrdir,
 
35
        controldir,
 
36
        debug,
28
37
        errors,
29
38
        osutils,
30
39
        tests,
31
 
        transport,
 
40
        transport as _mod_transport,
32
41
        urlutils,
33
42
        )
34
43
from bzrlib.smart import (
37
46
        message,
38
47
        protocol,
39
48
        request as _mod_request,
40
 
        server,
 
49
        server as _mod_server,
41
50
        vfs,
42
51
)
43
52
from bzrlib.tests import (
 
53
    features,
44
54
    test_smart,
45
55
    test_server,
46
56
    )
53
63
        )
54
64
 
55
65
 
 
66
def create_file_pipes():
 
67
    r, w = os.pipe()
 
68
    # These must be opened without buffering, or we get undefined results
 
69
    rf = os.fdopen(r, 'rb', 0)
 
70
    wf = os.fdopen(w, 'wb', 0)
 
71
    return rf, wf
 
72
 
 
73
 
 
74
def portable_socket_pair():
 
75
    """Return a pair of TCP sockets connected to each other.
 
76
 
 
77
    Unlike socket.socketpair, this should work on Windows.
 
78
    """
 
79
    listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 
80
    listen_sock.bind(('127.0.0.1', 0))
 
81
    listen_sock.listen(1)
 
82
    client_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 
83
    client_sock.connect(listen_sock.getsockname())
 
84
    server_sock, addr = listen_sock.accept()
 
85
    listen_sock.close()
 
86
    return server_sock, client_sock
 
87
 
 
88
 
56
89
class StringIOSSHVendor(object):
57
90
    """A SSH vendor that uses StringIO to buffer writes and answer reads."""
58
91
 
67
100
        return StringIOSSHConnection(self)
68
101
 
69
102
 
 
103
class FirstRejectedStringIOSSHVendor(StringIOSSHVendor):
 
104
    """The first connection will be considered closed.
 
105
 
 
106
    The second connection will succeed normally.
 
107
    """
 
108
 
 
109
    def __init__(self, read_from, write_to, fail_at_write=True):
 
110
        super(FirstRejectedStringIOSSHVendor, self).__init__(read_from,
 
111
            write_to)
 
112
        self.fail_at_write = fail_at_write
 
113
        self._first = True
 
114
 
 
115
    def connect_ssh(self, username, password, host, port, command):
 
116
        self.calls.append(('connect_ssh', username, password, host, port,
 
117
            command))
 
118
        if self._first:
 
119
            self._first = False
 
120
            return ClosedSSHConnection(self)
 
121
        return StringIOSSHConnection(self)
 
122
 
 
123
 
70
124
class StringIOSSHConnection(ssh.SSHConnection):
71
125
    """A SSH connection that uses StringIO to buffer writes and answer reads."""
72
126
 
82
136
        return 'pipes', (self.vendor.read_from, self.vendor.write_to)
83
137
 
84
138
 
85
 
class _InvalidHostnameFeature(tests.Feature):
 
139
class ClosedSSHConnection(ssh.SSHConnection):
 
140
    """An SSH connection that just has closed channels."""
 
141
 
 
142
    def __init__(self, vendor):
 
143
        self.vendor = vendor
 
144
 
 
145
    def close(self):
 
146
        self.vendor.calls.append(('close', ))
 
147
 
 
148
    def get_sock_or_pipes(self):
 
149
        # We create matching pipes, and then close the ssh side
 
150
        bzr_read, ssh_write = create_file_pipes()
 
151
        # We always fail when bzr goes to read
 
152
        ssh_write.close()
 
153
        if self.vendor.fail_at_write:
 
154
            # If set, we'll also fail when bzr goes to write
 
155
            ssh_read, bzr_write = create_file_pipes()
 
156
            ssh_read.close()
 
157
        else:
 
158
            bzr_write = self.vendor.write_to
 
159
        return 'pipes', (bzr_read, bzr_write)
 
160
 
 
161
 
 
162
class _InvalidHostnameFeature(features.Feature):
86
163
    """Does 'non_existent.invalid' fail to resolve?
87
164
 
88
165
    RFC 2606 states that .invalid is reserved for invalid domain names, and
177
254
        client_medium._accept_bytes('abc')
178
255
        self.assertEqual('abc', output.getvalue())
179
256
 
 
257
    def test_simple_pipes__accept_bytes_subprocess_closed(self):
 
258
        # It is unfortunate that we have to use Popen for this. However,
 
259
        # os.pipe() does not behave the same as subprocess.Popen().
 
260
        # On Windows, if you use os.pipe() and close the write side,
 
261
        # read.read() hangs. On Linux, read.read() returns the empty string.
 
262
        p = subprocess.Popen([sys.executable, '-c',
 
263
            'import sys\n'
 
264
            'sys.stdout.write(sys.stdin.read(4))\n'
 
265
            'sys.stdout.close()\n'],
 
266
            stdout=subprocess.PIPE, stdin=subprocess.PIPE)
 
267
        client_medium = medium.SmartSimplePipesClientMedium(
 
268
            p.stdout, p.stdin, 'base')
 
269
        client_medium._accept_bytes('abc\n')
 
270
        self.assertEqual('abc', client_medium._read_bytes(3))
 
271
        p.wait()
 
272
        # While writing to the underlying pipe,
 
273
        #   Windows py2.6.6 we get IOError(EINVAL)
 
274
        #   Lucid py2.6.5, we get IOError(EPIPE)
 
275
        # In both cases, it should be wrapped to ConnectionReset
 
276
        self.assertRaises(errors.ConnectionReset,
 
277
                          client_medium._accept_bytes, 'more')
 
278
 
 
279
    def test_simple_pipes__accept_bytes_pipe_closed(self):
 
280
        child_read, client_write = create_file_pipes()
 
281
        client_medium = medium.SmartSimplePipesClientMedium(
 
282
            None, client_write, 'base')
 
283
        client_medium._accept_bytes('abc\n')
 
284
        self.assertEqual('abc\n', child_read.read(4))
 
285
        # While writing to the underlying pipe,
 
286
        #   Windows py2.6.6 we get IOError(EINVAL)
 
287
        #   Lucid py2.6.5, we get IOError(EPIPE)
 
288
        # In both cases, it should be wrapped to ConnectionReset
 
289
        child_read.close()
 
290
        self.assertRaises(errors.ConnectionReset,
 
291
                          client_medium._accept_bytes, 'more')
 
292
 
 
293
    def test_simple_pipes__flush_pipe_closed(self):
 
294
        child_read, client_write = create_file_pipes()
 
295
        client_medium = medium.SmartSimplePipesClientMedium(
 
296
            None, client_write, 'base')
 
297
        client_medium._accept_bytes('abc\n')
 
298
        child_read.close()
 
299
        # Even though the pipe is closed, flush on the write side seems to be a
 
300
        # no-op, rather than a failure.
 
301
        client_medium._flush()
 
302
 
 
303
    def test_simple_pipes__flush_subprocess_closed(self):
 
304
        p = subprocess.Popen([sys.executable, '-c',
 
305
            'import sys\n'
 
306
            'sys.stdout.write(sys.stdin.read(4))\n'
 
307
            'sys.stdout.close()\n'],
 
308
            stdout=subprocess.PIPE, stdin=subprocess.PIPE)
 
309
        client_medium = medium.SmartSimplePipesClientMedium(
 
310
            p.stdout, p.stdin, 'base')
 
311
        client_medium._accept_bytes('abc\n')
 
312
        p.wait()
 
313
        # Even though the child process is dead, flush seems to be a no-op.
 
314
        client_medium._flush()
 
315
 
 
316
    def test_simple_pipes__read_bytes_pipe_closed(self):
 
317
        child_read, client_write = create_file_pipes()
 
318
        client_medium = medium.SmartSimplePipesClientMedium(
 
319
            child_read, client_write, 'base')
 
320
        client_medium._accept_bytes('abc\n')
 
321
        client_write.close()
 
322
        self.assertEqual('abc\n', client_medium._read_bytes(4))
 
323
        self.assertEqual('', client_medium._read_bytes(4))
 
324
 
 
325
    def test_simple_pipes__read_bytes_subprocess_closed(self):
 
326
        p = subprocess.Popen([sys.executable, '-c',
 
327
            'import sys\n'
 
328
            'if sys.platform == "win32":\n'
 
329
            '    import msvcrt, os\n'
 
330
            '    msvcrt.setmode(sys.stdout.fileno(), os.O_BINARY)\n'
 
331
            '    msvcrt.setmode(sys.stdin.fileno(), os.O_BINARY)\n'
 
332
            'sys.stdout.write(sys.stdin.read(4))\n'
 
333
            'sys.stdout.close()\n'],
 
334
            stdout=subprocess.PIPE, stdin=subprocess.PIPE)
 
335
        client_medium = medium.SmartSimplePipesClientMedium(
 
336
            p.stdout, p.stdin, 'base')
 
337
        client_medium._accept_bytes('abc\n')
 
338
        p.wait()
 
339
        self.assertEqual('abc\n', client_medium._read_bytes(4))
 
340
        self.assertEqual('', client_medium._read_bytes(4))
 
341
 
180
342
    def test_simple_pipes_client_disconnect_does_nothing(self):
181
343
        # calling disconnect does nothing.
182
344
        input = StringIO()
338
500
            ],
339
501
            vendor.calls)
340
502
 
 
503
    def test_ssh_client_repr(self):
 
504
        client_medium = medium.SmartSSHClientMedium(
 
505
            'base', medium.SSHParams("example.com", "4242", "username"))
 
506
        self.assertEquals(
 
507
            "SmartSSHClientMedium(bzr+ssh://username@example.com:4242/)",
 
508
            repr(client_medium))
 
509
 
 
510
    def test_ssh_client_repr_no_port(self):
 
511
        client_medium = medium.SmartSSHClientMedium(
 
512
            'base', medium.SSHParams("example.com", None, "username"))
 
513
        self.assertEquals(
 
514
            "SmartSSHClientMedium(bzr+ssh://username@example.com/)",
 
515
            repr(client_medium))
 
516
 
 
517
    def test_ssh_client_repr_no_username(self):
 
518
        client_medium = medium.SmartSSHClientMedium(
 
519
            'base', medium.SSHParams("example.com", None, None))
 
520
        self.assertEquals(
 
521
            "SmartSSHClientMedium(bzr+ssh://example.com/)",
 
522
            repr(client_medium))
 
523
 
341
524
    def test_ssh_client_ignores_disconnect_when_not_connected(self):
342
525
        # Doing a disconnect on a new (and thus unconnected) SSH medium
343
526
        # does not fail.  It's ok to disconnect an unconnected medium.
564
747
        request.finished_reading()
565
748
        self.assertRaises(errors.ReadingCompleted, request.read_bytes, None)
566
749
 
 
750
    def test_reset(self):
 
751
        server_sock, client_sock = portable_socket_pair()
 
752
        # TODO: Use SmartClientAlreadyConnectedSocketMedium for the versions of
 
753
        #       bzr where it exists.
 
754
        client_medium = medium.SmartTCPClientMedium(None, None, None)
 
755
        client_medium._socket = client_sock
 
756
        client_medium._connected = True
 
757
        req = client_medium.get_request()
 
758
        self.assertRaises(errors.TooManyConcurrentRequests,
 
759
            client_medium.get_request)
 
760
        client_medium.reset()
 
761
        # The stream should be reset, marked as disconnected, though ready for
 
762
        # us to make a new request
 
763
        self.assertFalse(client_medium._connected)
 
764
        self.assertIs(None, client_medium._socket)
 
765
        try:
 
766
            self.assertEqual('', client_sock.recv(1))
 
767
        except socket.error, e:
 
768
            if e.errno not in (errno.EBADF,):
 
769
                raise
 
770
        req = client_medium.get_request()
 
771
 
567
772
 
568
773
class RemoteTransportTests(test_smart.TestCaseWithSmartMedium):
569
774
 
615
820
 
616
821
    def setUp(self):
617
822
        super(TestSmartServerStreamMedium, self).setUp()
618
 
        self._captureVar('BZR_NO_SMART_VFS', None)
619
 
 
620
 
    def portable_socket_pair(self):
621
 
        """Return a pair of TCP sockets connected to each other.
622
 
 
623
 
        Unlike socket.socketpair, this should work on Windows.
624
 
        """
625
 
        listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
626
 
        listen_sock.bind(('127.0.0.1', 0))
627
 
        listen_sock.listen(1)
628
 
        client_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
629
 
        client_sock.connect(listen_sock.getsockname())
630
 
        server_sock, addr = listen_sock.accept()
631
 
        listen_sock.close()
632
 
        return server_sock, client_sock
 
823
        self.overrideEnv('BZR_NO_SMART_VFS', None)
 
824
 
 
825
    def create_pipe_medium(self, to_server, from_server, transport,
 
826
                           timeout=4.0):
 
827
        """Create a new SmartServerPipeStreamMedium."""
 
828
        return medium.SmartServerPipeStreamMedium(to_server, from_server,
 
829
            transport, timeout=timeout)
 
830
 
 
831
    def create_pipe_context(self, to_server_bytes, transport):
 
832
        """Create a SmartServerPipeStreamMedium.
 
833
 
 
834
        This differs from create_pipe_medium, in that we initialize the
 
835
        request that is sent to the server, and return the StringIO class that
 
836
        will hold the response.
 
837
        """
 
838
        to_server = StringIO(to_server_bytes)
 
839
        from_server = StringIO()
 
840
        m = self.create_pipe_medium(to_server, from_server, transport)
 
841
        return m, from_server
 
842
 
 
843
    def create_socket_medium(self, server_sock, transport, timeout=4.0):
 
844
        """Initialize a new medium.SmartServerSocketStreamMedium."""
 
845
        return medium.SmartServerSocketStreamMedium(server_sock, transport,
 
846
            timeout=timeout)
 
847
 
 
848
    def create_socket_context(self, transport, timeout=4.0):
 
849
        """Create a new SmartServerSocketStreamMedium with default context.
 
850
 
 
851
        This will call portable_socket_pair and pass the server side to
 
852
        create_socket_medium along with transport.
 
853
        It then returns the client_sock and the server.
 
854
        """
 
855
        server_sock, client_sock = portable_socket_pair()
 
856
        server = self.create_socket_medium(server_sock, transport,
 
857
                                           timeout=timeout)
 
858
        return server, client_sock
633
859
 
634
860
    def test_smart_query_version(self):
635
861
        """Feed a canned query version to a server"""
636
862
        # wire-to-wire, using the whole stack
637
 
        to_server = StringIO('hello\n')
638
 
        from_server = StringIO()
639
863
        transport = local.LocalTransport(urlutils.local_path_to_url('/'))
640
 
        server = medium.SmartServerPipeStreamMedium(
641
 
            to_server, from_server, transport)
 
864
        server, from_server = self.create_pipe_context('hello\n', transport)
642
865
        smart_protocol = protocol.SmartServerRequestProtocolOne(transport,
643
866
                from_server.write)
644
867
        server._serve_one_request(smart_protocol)
648
871
    def test_response_to_canned_get(self):
649
872
        transport = memory.MemoryTransport('memory:///')
650
873
        transport.put_bytes('testfile', 'contents\nof\nfile\n')
651
 
        to_server = StringIO('get\001./testfile\n')
652
 
        from_server = StringIO()
653
 
        server = medium.SmartServerPipeStreamMedium(
654
 
            to_server, from_server, transport)
 
874
        server, from_server = self.create_pipe_context('get\001./testfile\n',
 
875
            transport)
655
876
        smart_protocol = protocol.SmartServerRequestProtocolOne(transport,
656
877
                from_server.write)
657
878
        server._serve_one_request(smart_protocol)
668
889
        # VFS requests use filenames, not raw UTF-8.
669
890
        hpss_path = urlutils.escape(utf8_filename)
670
891
        transport.put_bytes(utf8_filename, 'contents\nof\nfile\n')
671
 
        to_server = StringIO('get\001' + hpss_path + '\n')
672
 
        from_server = StringIO()
673
 
        server = medium.SmartServerPipeStreamMedium(
674
 
            to_server, from_server, transport)
 
892
        server, from_server = self.create_pipe_context(
 
893
                'get\001' + hpss_path + '\n', transport)
675
894
        smart_protocol = protocol.SmartServerRequestProtocolOne(transport,
676
895
                from_server.write)
677
896
        server._serve_one_request(smart_protocol)
683
902
 
684
903
    def test_pipe_like_stream_with_bulk_data(self):
685
904
        sample_request_bytes = 'command\n9\nbulk datadone\n'
686
 
        to_server = StringIO(sample_request_bytes)
687
 
        from_server = StringIO()
688
 
        server = medium.SmartServerPipeStreamMedium(
689
 
            to_server, from_server, None)
 
905
        server, from_server = self.create_pipe_context(
 
906
            sample_request_bytes, None)
690
907
        sample_protocol = SampleRequest(expected_bytes=sample_request_bytes)
691
908
        server._serve_one_request(sample_protocol)
692
909
        self.assertEqual('', from_server.getvalue())
695
912
 
696
913
    def test_socket_stream_with_bulk_data(self):
697
914
        sample_request_bytes = 'command\n9\nbulk datadone\n'
698
 
        server_sock, client_sock = self.portable_socket_pair()
699
 
        server = medium.SmartServerSocketStreamMedium(
700
 
            server_sock, None)
 
915
        server, client_sock = self.create_socket_context(None)
701
916
        sample_protocol = SampleRequest(expected_bytes=sample_request_bytes)
702
917
        client_sock.sendall(sample_request_bytes)
703
918
        server._serve_one_request(sample_protocol)
704
 
        server_sock.close()
 
919
        server._disconnect_client()
705
920
        self.assertEqual('', client_sock.recv(1))
706
921
        self.assertEqual(sample_request_bytes, sample_protocol.accepted_bytes)
707
922
        self.assertFalse(server.finished)
708
923
 
709
924
    def test_pipe_like_stream_shutdown_detection(self):
710
 
        to_server = StringIO('')
711
 
        from_server = StringIO()
712
 
        server = medium.SmartServerPipeStreamMedium(to_server, from_server, None)
 
925
        server, _ = self.create_pipe_context('', None)
713
926
        server._serve_one_request(SampleRequest('x'))
714
927
        self.assertTrue(server.finished)
715
928
 
716
929
    def test_socket_stream_shutdown_detection(self):
717
 
        server_sock, client_sock = self.portable_socket_pair()
 
930
        server, client_sock = self.create_socket_context(None)
718
931
        client_sock.close()
719
 
        server = medium.SmartServerSocketStreamMedium(
720
 
            server_sock, None)
721
932
        server._serve_one_request(SampleRequest('x'))
722
933
        self.assertTrue(server.finished)
723
934
 
734
945
        rest_of_request_bytes = 'lo\n'
735
946
        expected_response = (
736
947
            protocol.RESPONSE_VERSION_TWO + 'success\nok\x012\n')
737
 
        server_sock, client_sock = self.portable_socket_pair()
738
 
        server = medium.SmartServerSocketStreamMedium(
739
 
            server_sock, None)
 
948
        server, client_sock = self.create_socket_context(None)
740
949
        client_sock.sendall(incomplete_request_bytes)
741
950
        server_protocol = server._build_protocol()
742
951
        client_sock.sendall(rest_of_request_bytes)
743
952
        server._serve_one_request(server_protocol)
744
 
        server_sock.close()
 
953
        server._disconnect_client()
745
954
        self.assertEqual(expected_response, osutils.recv_all(client_sock, 50),
746
955
                         "Not a version 2 response to 'hello' request.")
747
956
        self.assertEqual('', client_sock.recv(1))
766
975
        to_server_w = os.fdopen(to_server_w, 'w', 0)
767
976
        from_server_r = os.fdopen(from_server_r, 'r', 0)
768
977
        from_server = os.fdopen(from_server, 'w', 0)
769
 
        server = medium.SmartServerPipeStreamMedium(
770
 
            to_server, from_server, None)
 
978
        server = self.create_pipe_medium(to_server, from_server, None)
771
979
        # Like test_socket_stream_incomplete_request, write an incomplete
772
980
        # request (that does not end in '\n') and build a protocol from it.
773
981
        to_server_w.write(incomplete_request_bytes)
788
996
        # _serve_one_request should still process both of them as if they had
789
997
        # been received separately.
790
998
        sample_request_bytes = 'command\n'
791
 
        to_server = StringIO(sample_request_bytes * 2)
792
 
        from_server = StringIO()
793
 
        server = medium.SmartServerPipeStreamMedium(
794
 
            to_server, from_server, None)
 
999
        server, from_server = self.create_pipe_context(
 
1000
            sample_request_bytes * 2, None)
795
1001
        first_protocol = SampleRequest(expected_bytes=sample_request_bytes)
796
1002
        server._serve_one_request(first_protocol)
797
1003
        self.assertEqual(0, first_protocol.next_read_size())
810
1016
        # _serve_one_request should still process both of them as if they had
811
1017
        # been received separately.
812
1018
        sample_request_bytes = 'command\n'
813
 
        server_sock, client_sock = self.portable_socket_pair()
814
 
        server = medium.SmartServerSocketStreamMedium(
815
 
            server_sock, None)
 
1019
        server, client_sock = self.create_socket_context(None)
816
1020
        first_protocol = SampleRequest(expected_bytes=sample_request_bytes)
817
1021
        # Put two whole requests on the wire.
818
1022
        client_sock.sendall(sample_request_bytes * 2)
825
1029
        stream_still_open = server._serve_one_request(second_protocol)
826
1030
        self.assertEqual(sample_request_bytes, second_protocol.accepted_bytes)
827
1031
        self.assertFalse(server.finished)
828
 
        server_sock.close()
 
1032
        server._disconnect_client()
829
1033
        self.assertEqual('', client_sock.recv(1))
830
1034
 
831
1035
    def test_pipe_like_stream_error_handling(self):
838
1042
        def close():
839
1043
            self.closed = True
840
1044
        from_server.close = close
841
 
        server = medium.SmartServerPipeStreamMedium(
 
1045
        server = self.create_pipe_medium(
842
1046
            to_server, from_server, None)
843
1047
        fake_protocol = ErrorRaisingProtocol(Exception('boom'))
844
1048
        server._serve_one_request(fake_protocol)
847
1051
        self.assertTrue(server.finished)
848
1052
 
849
1053
    def test_socket_stream_error_handling(self):
850
 
        server_sock, client_sock = self.portable_socket_pair()
851
 
        server = medium.SmartServerSocketStreamMedium(
852
 
            server_sock, None)
 
1054
        server, client_sock = self.create_socket_context(None)
853
1055
        fake_protocol = ErrorRaisingProtocol(Exception('boom'))
854
1056
        server._serve_one_request(fake_protocol)
855
1057
        # recv should not block, because the other end of the socket has been
858
1060
        self.assertTrue(server.finished)
859
1061
 
860
1062
    def test_pipe_like_stream_keyboard_interrupt_handling(self):
861
 
        to_server = StringIO('')
862
 
        from_server = StringIO()
863
 
        server = medium.SmartServerPipeStreamMedium(
864
 
            to_server, from_server, None)
 
1063
        server, from_server = self.create_pipe_context('', None)
865
1064
        fake_protocol = ErrorRaisingProtocol(KeyboardInterrupt('boom'))
866
1065
        self.assertRaises(
867
1066
            KeyboardInterrupt, server._serve_one_request, fake_protocol)
868
1067
        self.assertEqual('', from_server.getvalue())
869
1068
 
870
1069
    def test_socket_stream_keyboard_interrupt_handling(self):
871
 
        server_sock, client_sock = self.portable_socket_pair()
872
 
        server = medium.SmartServerSocketStreamMedium(
873
 
            server_sock, None)
 
1070
        server, client_sock = self.create_socket_context(None)
874
1071
        fake_protocol = ErrorRaisingProtocol(KeyboardInterrupt('boom'))
875
1072
        self.assertRaises(
876
1073
            KeyboardInterrupt, server._serve_one_request, fake_protocol)
877
 
        server_sock.close()
 
1074
        server._disconnect_client()
878
1075
        self.assertEqual('', client_sock.recv(1))
879
1076
 
880
1077
    def build_protocol_pipe_like(self, bytes):
881
 
        to_server = StringIO(bytes)
882
 
        from_server = StringIO()
883
 
        server = medium.SmartServerPipeStreamMedium(
884
 
            to_server, from_server, None)
 
1078
        server, _ = self.create_pipe_context(bytes, None)
885
1079
        return server._build_protocol()
886
1080
 
887
1081
    def build_protocol_socket(self, bytes):
888
 
        server_sock, client_sock = self.portable_socket_pair()
889
 
        server = medium.SmartServerSocketStreamMedium(
890
 
            server_sock, None)
 
1082
        server, client_sock = self.create_socket_context(None)
891
1083
        client_sock.sendall(bytes)
892
1084
        client_sock.close()
893
1085
        return server._build_protocol()
933
1125
        server_protocol = self.build_protocol_socket('bzr request 2\n')
934
1126
        self.assertProtocolTwo(server_protocol)
935
1127
 
 
1128
    def test__build_protocol_returns_if_stopping(self):
 
1129
        # _build_protocol should notice that we are stopping, and return
 
1130
        # without waiting for bytes from the client.
 
1131
        server, client_sock = self.create_socket_context(None)
 
1132
        server._stop_gracefully()
 
1133
        self.assertIs(None, server._build_protocol())
 
1134
 
 
1135
    def test_socket_set_timeout(self):
 
1136
        server, _ = self.create_socket_context(None, timeout=1.23)
 
1137
        self.assertEqual(1.23, server._client_timeout)
 
1138
 
 
1139
    def test_pipe_set_timeout(self):
 
1140
        server = self.create_pipe_medium(None, None, None,
 
1141
            timeout=1.23)
 
1142
        self.assertEqual(1.23, server._client_timeout)
 
1143
 
 
1144
    def test_socket_wait_for_bytes_with_timeout_with_data(self):
 
1145
        server, client_sock = self.create_socket_context(None)
 
1146
        client_sock.sendall('data\n')
 
1147
        # This should not block or consume any actual content
 
1148
        self.assertFalse(server._wait_for_bytes_with_timeout(0.1))
 
1149
        data = server.read_bytes(5)
 
1150
        self.assertEqual('data\n', data)
 
1151
 
 
1152
    def test_socket_wait_for_bytes_with_timeout_no_data(self):
 
1153
        server, client_sock = self.create_socket_context(None)
 
1154
        # This should timeout quickly, reporting that there wasn't any data
 
1155
        self.assertRaises(errors.ConnectionTimeout,
 
1156
                          server._wait_for_bytes_with_timeout, 0.01)
 
1157
        client_sock.close()
 
1158
        data = server.read_bytes(1)
 
1159
        self.assertEqual('', data)
 
1160
 
 
1161
    def test_socket_wait_for_bytes_with_timeout_closed(self):
 
1162
        server, client_sock = self.create_socket_context(None)
 
1163
        # With the socket closed, this should return right away.
 
1164
        # It seems select.select() returns that you *can* read on the socket,
 
1165
        # even though it closed. Presumably as a way to tell it is closed?
 
1166
        # Testing shows that without sock.close() this times-out failing the
 
1167
        # test, but with it, it returns False immediately.
 
1168
        client_sock.close()
 
1169
        self.assertFalse(server._wait_for_bytes_with_timeout(10))
 
1170
        data = server.read_bytes(1)
 
1171
        self.assertEqual('', data)
 
1172
 
 
1173
    def test_socket_wait_for_bytes_with_shutdown(self):
 
1174
        server, client_sock = self.create_socket_context(None)
 
1175
        t = time.time()
 
1176
        # Override the _timer functionality, so that time never increments,
 
1177
        # this way, we can be sure we stopped because of the flag, and not
 
1178
        # because of a timeout, etc.
 
1179
        server._timer = lambda: t
 
1180
        server._client_poll_timeout = 0.1
 
1181
        server._stop_gracefully()
 
1182
        server._wait_for_bytes_with_timeout(1.0)
 
1183
 
 
1184
    def test_socket_serve_timeout_closes_socket(self):
 
1185
        server, client_sock = self.create_socket_context(None, timeout=0.1)
 
1186
        # This should timeout quickly, and then close the connection so that
 
1187
        # client_sock recv doesn't block.
 
1188
        server.serve()
 
1189
        self.assertEqual('', client_sock.recv(1))
 
1190
 
 
1191
    def test_pipe_wait_for_bytes_with_timeout_with_data(self):
 
1192
        # We intentionally use a real pipe here, so that we can 'select' on it.
 
1193
        # You can't select() on a StringIO
 
1194
        (r_server, w_client) = os.pipe()
 
1195
        self.addCleanup(os.close, w_client)
 
1196
        with os.fdopen(r_server, 'rb') as rf_server:
 
1197
            server = self.create_pipe_medium(
 
1198
                rf_server, None, None)
 
1199
            os.write(w_client, 'data\n')
 
1200
            # This should not block or consume any actual content
 
1201
            server._wait_for_bytes_with_timeout(0.1)
 
1202
            data = server.read_bytes(5)
 
1203
            self.assertEqual('data\n', data)
 
1204
 
 
1205
    def test_pipe_wait_for_bytes_with_timeout_no_data(self):
 
1206
        # We intentionally use a real pipe here, so that we can 'select' on it.
 
1207
        # You can't select() on a StringIO
 
1208
        (r_server, w_client) = os.pipe()
 
1209
        # We can't add an os.close cleanup here, because we need to control
 
1210
        # when the file handle gets closed ourselves.
 
1211
        with os.fdopen(r_server, 'rb') as rf_server:
 
1212
            server = self.create_pipe_medium(
 
1213
                rf_server, None, None)
 
1214
            if sys.platform == 'win32':
 
1215
                # Windows cannot select() on a pipe, so we just always return
 
1216
                server._wait_for_bytes_with_timeout(0.01)
 
1217
            else:
 
1218
                self.assertRaises(errors.ConnectionTimeout,
 
1219
                                  server._wait_for_bytes_with_timeout, 0.01)
 
1220
            os.close(w_client)
 
1221
            data = server.read_bytes(5)
 
1222
            self.assertEqual('', data)
 
1223
 
 
1224
    def test_pipe_wait_for_bytes_no_fileno(self):
 
1225
        server, _ = self.create_pipe_context('', None)
 
1226
        # Our file doesn't support polling, so we should always just return
 
1227
        # 'you have data to consume'.
 
1228
        server._wait_for_bytes_with_timeout(0.01)
 
1229
 
936
1230
 
937
1231
class TestGetProtocolFactoryForBytes(tests.TestCase):
938
1232
    """_get_protocol_factory_for_bytes identifies the protocol factory a server
968
1262
 
969
1263
class TestSmartTCPServer(tests.TestCase):
970
1264
 
 
1265
    def make_server(self):
 
1266
        """Create a SmartTCPServer that we can exercise.
 
1267
 
 
1268
        Note: we don't use SmartTCPServer_for_testing because the testing
 
1269
        version overrides lots of functionality like 'serve', and we want to
 
1270
        test the raw service.
 
1271
 
 
1272
        This will start the server in another thread, and wait for it to
 
1273
        indicate it has finished starting up.
 
1274
 
 
1275
        :return: (server, server_thread)
 
1276
        """
 
1277
        t = _mod_transport.get_transport_from_url('memory:///')
 
1278
        server = _mod_server.SmartTCPServer(t, client_timeout=4.0)
 
1279
        server._ACCEPT_TIMEOUT = 0.1
 
1280
        # We don't use 'localhost' because that might be an IPv6 address.
 
1281
        server.start_server('127.0.0.1', 0)
 
1282
        server_thread = threading.Thread(target=server.serve,
 
1283
                                         args=(self.id(),))
 
1284
        server_thread.start()
 
1285
        # Ensure this gets called at some point
 
1286
        self.addCleanup(server._stop_gracefully)
 
1287
        server._started.wait()
 
1288
        return server, server_thread
 
1289
 
 
1290
    def ensure_client_disconnected(self, client_sock):
 
1291
        """Ensure that a socket is closed, discarding all errors."""
 
1292
        try:
 
1293
            client_sock.close()
 
1294
        except Exception:
 
1295
            pass
 
1296
 
 
1297
    def connect_to_server(self, server):
 
1298
        """Create a client socket that can talk to the server."""
 
1299
        client_sock = socket.socket()
 
1300
        server_info = server._server_socket.getsockname()
 
1301
        client_sock.connect(server_info)
 
1302
        self.addCleanup(self.ensure_client_disconnected, client_sock)
 
1303
        return client_sock
 
1304
 
 
1305
    def connect_to_server_and_hangup(self, server):
 
1306
        """Connect to the server, and then hang up.
 
1307
        That way it doesn't sit waiting for 'accept()' to timeout.
 
1308
        """
 
1309
        # If the server has already signaled that the socket is closed, we
 
1310
        # don't need to try to connect to it. Not being set, though, the server
 
1311
        # might still close the socket while we try to connect to it. So we
 
1312
        # still have to catch the exception.
 
1313
        if server._stopped.isSet():
 
1314
            return
 
1315
        try:
 
1316
            client_sock = self.connect_to_server(server)
 
1317
            client_sock.close()
 
1318
        except socket.error, e:
 
1319
            # If the server has hung up already, that is fine.
 
1320
            pass
 
1321
 
 
1322
    def say_hello(self, client_sock):
 
1323
        """Send the 'hello' smart RPC, and expect the response."""
 
1324
        client_sock.send('hello\n')
 
1325
        self.assertEqual('ok\x012\n', client_sock.recv(5))
 
1326
 
 
1327
    def shutdown_server_cleanly(self, server, server_thread):
 
1328
        server._stop_gracefully()
 
1329
        self.connect_to_server_and_hangup(server)
 
1330
        server._stopped.wait()
 
1331
        server._fully_stopped.wait()
 
1332
        server_thread.join()
 
1333
 
971
1334
    def test_get_error_unexpected(self):
972
1335
        """Error reported by server with no specific representation"""
973
 
        self._captureVar('BZR_NO_SMART_VFS', None)
 
1336
        self.overrideEnv('BZR_NO_SMART_VFS', None)
974
1337
        class FlakyTransport(object):
975
1338
            base = 'a_url'
976
1339
            def external_url(self):
991
1354
                                t.get, 'something')
992
1355
        self.assertContainsRe(str(err), 'some random exception')
993
1356
 
 
1357
    def test_propagates_timeout(self):
 
1358
        server = _mod_server.SmartTCPServer(None, client_timeout=1.23)
 
1359
        server_sock, client_sock = portable_socket_pair()
 
1360
        handler = server._make_handler(server_sock)
 
1361
        self.assertEqual(1.23, handler._client_timeout)
 
1362
 
 
1363
    def test_serve_conn_tracks_connections(self):
 
1364
        server = _mod_server.SmartTCPServer(None, client_timeout=4.0)
 
1365
        server_sock, client_sock = portable_socket_pair()
 
1366
        server.serve_conn(server_sock, '-%s' % (self.id(),))
 
1367
        self.assertEqual(1, len(server._active_connections))
 
1368
        # We still want to talk on the connection. Polling should indicate it
 
1369
        # is still active.
 
1370
        server._poll_active_connections()
 
1371
        self.assertEqual(1, len(server._active_connections))
 
1372
        # Closing the socket will end the active thread, and polling will
 
1373
        # notice and remove it from the active set.
 
1374
        client_sock.close()
 
1375
        server._poll_active_connections(0.1)
 
1376
        self.assertEqual(0, len(server._active_connections))
 
1377
 
 
1378
    def test_serve_closes_out_finished_connections(self):
 
1379
        server, server_thread = self.make_server()
 
1380
        # The server is started, connect to it.
 
1381
        client_sock = self.connect_to_server(server)
 
1382
        # We send and receive on the connection, so that we know the
 
1383
        # server-side has seen the connect, and started handling the
 
1384
        # results.
 
1385
        self.say_hello(client_sock)
 
1386
        self.assertEqual(1, len(server._active_connections))
 
1387
        # Grab a handle to the thread that is processing our request
 
1388
        _, server_side_thread = server._active_connections[0]
 
1389
        # Close the connection, ask the server to stop, and wait for the
 
1390
        # server to stop, as well as the thread that was servicing the
 
1391
        # client request.
 
1392
        client_sock.close()
 
1393
        # Wait for the server-side request thread to notice we are closed.
 
1394
        server_side_thread.join()
 
1395
        # Stop the server, it should notice the connection has finished.
 
1396
        self.shutdown_server_cleanly(server, server_thread)
 
1397
        # The server should have noticed that all clients are gone before
 
1398
        # exiting.
 
1399
        self.assertEqual(0, len(server._active_connections))
 
1400
 
 
1401
    def test_serve_reaps_finished_connections(self):
 
1402
        server, server_thread = self.make_server()
 
1403
        client_sock1 = self.connect_to_server(server)
 
1404
        # We send and receive on the connection, so that we know the
 
1405
        # server-side has seen the connect, and started handling the
 
1406
        # results.
 
1407
        self.say_hello(client_sock1)
 
1408
        server_handler1, server_side_thread1 = server._active_connections[0]
 
1409
        client_sock1.close()
 
1410
        server_side_thread1.join()
 
1411
        # By waiting until the first connection is fully done, the server
 
1412
        # should notice after another connection that the first has finished.
 
1413
        client_sock2 = self.connect_to_server(server)
 
1414
        self.say_hello(client_sock2)
 
1415
        server_handler2, server_side_thread2 = server._active_connections[-1]
 
1416
        # There is a race condition. We know that client_sock2 has been
 
1417
        # registered, but not that _poll_active_connections has been called. We
 
1418
        # know that it will be called before the server will accept a new
 
1419
        # connection, however. So connect one more time, and assert that we
 
1420
        # either have 1 or 2 active connections (never 3), and that the 'first'
 
1421
        # connection is not connection 1
 
1422
        client_sock3 = self.connect_to_server(server)
 
1423
        self.say_hello(client_sock3)
 
1424
        # Copy the list, so we don't have it mutating behind our back
 
1425
        conns = list(server._active_connections)
 
1426
        self.assertEqual(2, len(conns))
 
1427
        self.assertNotEqual((server_handler1, server_side_thread1), conns[0])
 
1428
        self.assertEqual((server_handler2, server_side_thread2), conns[0])
 
1429
        client_sock2.close()
 
1430
        client_sock3.close()
 
1431
        self.shutdown_server_cleanly(server, server_thread)
 
1432
 
 
1433
    def test_graceful_shutdown_waits_for_clients_to_stop(self):
 
1434
        server, server_thread = self.make_server()
 
1435
        # We need something big enough that it won't fit in a single recv. So
 
1436
        # the server thread gets blocked writing content to the client until we
 
1437
        # finish reading on the client.
 
1438
        server.backing_transport.put_bytes('bigfile',
 
1439
            'a'*1024*1024)
 
1440
        client_sock = self.connect_to_server(server)
 
1441
        self.say_hello(client_sock)
 
1442
        _, server_side_thread = server._active_connections[0]
 
1443
        # Start the RPC, but don't finish reading the response
 
1444
        client_medium = medium.SmartClientAlreadyConnectedSocketMedium(
 
1445
            'base', client_sock)
 
1446
        client_client = client._SmartClient(client_medium)
 
1447
        resp, response_handler = client_client.call_expecting_body('get',
 
1448
            'bigfile')
 
1449
        self.assertEqual(('ok',), resp)
 
1450
        # Ask the server to stop gracefully, and wait for it.
 
1451
        server._stop_gracefully()
 
1452
        self.connect_to_server_and_hangup(server)
 
1453
        server._stopped.wait()
 
1454
        # It should not be accepting another connection.
 
1455
        self.assertRaises(socket.error, self.connect_to_server, server)
 
1456
        response_handler.read_body_bytes()
 
1457
        client_sock.close()
 
1458
        server_side_thread.join()
 
1459
        server_thread.join()
 
1460
        self.assertTrue(server._fully_stopped.isSet())
 
1461
        log = self.get_log()
 
1462
        self.assertThat(log, DocTestMatches("""\
 
1463
    INFO  Requested to stop gracefully
 
1464
... Stopping SmartServerSocketStreamMedium(client=('127.0.0.1', ...
 
1465
    INFO  Waiting for 1 client(s) to finish
 
1466
""", flags=doctest.ELLIPSIS|doctest.REPORT_UDIFF))
 
1467
 
 
1468
    def test_stop_gracefully_tells_handlers_to_stop(self):
 
1469
        server, server_thread = self.make_server()
 
1470
        client_sock = self.connect_to_server(server)
 
1471
        self.say_hello(client_sock)
 
1472
        server_handler, server_side_thread = server._active_connections[0]
 
1473
        self.assertFalse(server_handler.finished)
 
1474
        server._stop_gracefully()
 
1475
        self.assertTrue(server_handler.finished)
 
1476
        client_sock.close()
 
1477
        self.connect_to_server_and_hangup(server)
 
1478
        server_thread.join()
 
1479
 
994
1480
 
995
1481
class SmartTCPTests(tests.TestCase):
996
1482
    """Tests for connection/end to end behaviour using the TCP server.
1014
1500
            mem_server.start_server()
1015
1501
            self.addCleanup(mem_server.stop_server)
1016
1502
            self.permit_url(mem_server.get_url())
1017
 
            self.backing_transport = transport.get_transport(
 
1503
            self.backing_transport = _mod_transport.get_transport_from_url(
1018
1504
                mem_server.get_url())
1019
1505
        else:
1020
1506
            self.backing_transport = backing_transport
1021
1507
        if readonly:
1022
1508
            self.real_backing_transport = self.backing_transport
1023
 
            self.backing_transport = transport.get_transport(
 
1509
            self.backing_transport = _mod_transport.get_transport_from_url(
1024
1510
                "readonly+" + self.backing_transport.abspath('.'))
1025
 
        self.server = server.SmartTCPServer(self.backing_transport)
 
1511
        self.server = _mod_server.SmartTCPServer(self.backing_transport,
 
1512
                                                 client_timeout=4.0)
1026
1513
        self.server.start_server('127.0.0.1', 0)
1027
1514
        self.server.start_background_thread('-' + self.id())
1028
1515
        self.transport = remote.RemoteTCPTransport(self.server.get_url())
1077
1564
 
1078
1565
    def test_smart_transport_has(self):
1079
1566
        """Checking for file existence over smart."""
1080
 
        self._captureVar('BZR_NO_SMART_VFS', None)
 
1567
        self.overrideEnv('BZR_NO_SMART_VFS', None)
1081
1568
        self.backing_transport.put_bytes("foo", "contents of foo\n")
1082
1569
        self.assertTrue(self.transport.has("foo"))
1083
1570
        self.assertFalse(self.transport.has("non-foo"))
1084
1571
 
1085
1572
    def test_smart_transport_get(self):
1086
1573
        """Read back a file over smart."""
1087
 
        self._captureVar('BZR_NO_SMART_VFS', None)
 
1574
        self.overrideEnv('BZR_NO_SMART_VFS', None)
1088
1575
        self.backing_transport.put_bytes("foo", "contents\nof\nfoo\n")
1089
1576
        fp = self.transport.get("foo")
1090
1577
        self.assertEqual('contents\nof\nfoo\n', fp.read())
1094
1581
        # The path in a raised NoSuchFile exception should be the precise path
1095
1582
        # asked for by the client. This gives meaningful and unsurprising errors
1096
1583
        # for users.
1097
 
        self._captureVar('BZR_NO_SMART_VFS', None)
 
1584
        self.overrideEnv('BZR_NO_SMART_VFS', None)
1098
1585
        err = self.assertRaises(
1099
1586
            errors.NoSuchFile, self.transport.get, 'not%20a%20file')
1100
1587
        self.assertSubset([err.path], ['not%20a%20file', './not%20a%20file'])
1119
1606
 
1120
1607
    def test_open_dir(self):
1121
1608
        """Test changing directory"""
1122
 
        self._captureVar('BZR_NO_SMART_VFS', None)
 
1609
        self.overrideEnv('BZR_NO_SMART_VFS', None)
1123
1610
        transport = self.transport
1124
1611
        self.backing_transport.mkdir('toffee')
1125
1612
        self.backing_transport.mkdir('toffee/apple')
1139
1626
        transport = self.transport
1140
1627
        t = self.backing_transport
1141
1628
        bzrdir.BzrDirFormat.get_default_format().initialize_on_transport(t)
1142
 
        result_dir = bzrdir.BzrDir.open_containing_from_transport(transport)
 
1629
        result_dir = controldir.ControlDir.open_containing_from_transport(
 
1630
            transport)
1143
1631
 
1144
1632
 
1145
1633
class ReadOnlyEndToEndTests(SmartTCPTests):
1147
1635
 
1148
1636
    def test_mkdir_error_readonly(self):
1149
1637
        """TransportNotPossible should be preserved from the backing transport."""
1150
 
        self._captureVar('BZR_NO_SMART_VFS', None)
 
1638
        self.overrideEnv('BZR_NO_SMART_VFS', None)
1151
1639
        self.start_server(readonly=True)
1152
1640
        self.assertRaises(errors.TransportNotPossible, self.transport.mkdir,
1153
1641
            'foo')
1162
1650
    def test_server_started_hook_memory(self):
1163
1651
        """The server_started hook fires when the server is started."""
1164
1652
        self.hook_calls = []
1165
 
        server.SmartTCPServer.hooks.install_named_hook('server_started',
 
1653
        _mod_server.SmartTCPServer.hooks.install_named_hook('server_started',
1166
1654
            self.capture_server_call, None)
1167
1655
        self.start_server()
1168
1656
        # at this point, the server will be starting a thread up.
1176
1664
    def test_server_started_hook_file(self):
1177
1665
        """The server_started hook fires when the server is started."""
1178
1666
        self.hook_calls = []
1179
 
        server.SmartTCPServer.hooks.install_named_hook('server_started',
 
1667
        _mod_server.SmartTCPServer.hooks.install_named_hook('server_started',
1180
1668
            self.capture_server_call, None)
1181
 
        self.start_server(backing_transport=transport.get_transport("."))
 
1669
        self.start_server(
 
1670
            backing_transport=_mod_transport.get_transport_from_path("."))
1182
1671
        # at this point, the server will be starting a thread up.
1183
1672
        # there is no indicator at the moment, so bodge it by doing a request.
1184
1673
        self.transport.has('.')
1192
1681
    def test_server_stopped_hook_simple_memory(self):
1193
1682
        """The server_stopped hook fires when the server is stopped."""
1194
1683
        self.hook_calls = []
1195
 
        server.SmartTCPServer.hooks.install_named_hook('server_stopped',
 
1684
        _mod_server.SmartTCPServer.hooks.install_named_hook('server_stopped',
1196
1685
            self.capture_server_call, None)
1197
1686
        self.start_server()
1198
1687
        result = [([self.backing_transport.base], self.transport.base)]
1209
1698
    def test_server_stopped_hook_simple_file(self):
1210
1699
        """The server_stopped hook fires when the server is stopped."""
1211
1700
        self.hook_calls = []
1212
 
        server.SmartTCPServer.hooks.install_named_hook('server_stopped',
 
1701
        _mod_server.SmartTCPServer.hooks.install_named_hook('server_stopped',
1213
1702
            self.capture_server_call, None)
1214
 
        self.start_server(backing_transport=transport.get_transport("."))
 
1703
        self.start_server(
 
1704
            backing_transport=_mod_transport.get_transport_from_path("."))
1215
1705
        result = [(
1216
1706
            [self.backing_transport.base, self.backing_transport.external_url()]
1217
1707
            , self.transport.base)]
1261
1751
 
1262
1752
    def setUp(self):
1263
1753
        super(SmartServerRequestHandlerTests, self).setUp()
1264
 
        self._captureVar('BZR_NO_SMART_VFS', None)
 
1754
        self.overrideEnv('BZR_NO_SMART_VFS', None)
1265
1755
 
1266
1756
    def build_handler(self, transport):
1267
1757
        """Returns a handler for the commands in protocol version one."""
1286
1776
        handler = vfs.HasRequest(None, '/')
1287
1777
        # set environment variable after construction to make sure it's
1288
1778
        # examined.
1289
 
        # Note that we can safely clobber BZR_NO_SMART_VFS here, because setUp
1290
 
        # has called _captureVar, so it will be restored to the right state
1291
 
        # afterwards.
1292
 
        os.environ['BZR_NO_SMART_VFS'] = ''
 
1779
        self.overrideEnv('BZR_NO_SMART_VFS', '')
1293
1780
        self.assertRaises(errors.DisabledMethod, handler.execute)
1294
1781
 
1295
1782
    def test_readonly_exception_becomes_transport_not_possible(self):
1356
1843
class RemoteTransportRegistration(tests.TestCase):
1357
1844
 
1358
1845
    def test_registration(self):
1359
 
        t = transport.get_transport('bzr+ssh://example.com/path')
 
1846
        t = _mod_transport.get_transport_from_url('bzr+ssh://example.com/path')
1360
1847
        self.assertIsInstance(t, remote.RemoteSSHTransport)
1361
 
        self.assertEqual('example.com', t._host)
 
1848
        self.assertEqual('example.com', t._parsed_url.host)
1362
1849
 
1363
1850
    def test_bzr_https(self):
1364
1851
        # https://bugs.launchpad.net/bzr/+bug/128456
1365
 
        t = transport.get_transport('bzr+https://example.com/path')
 
1852
        t = _mod_transport.get_transport_from_url('bzr+https://example.com/path')
1366
1853
        self.assertIsInstance(t, remote.RemoteHTTPTransport)
1367
1854
        self.assertStartsWith(
1368
1855
            t._http_transport.base,
1496
1983
        smart_protocol._has_dispatched = True
1497
1984
        smart_protocol.request = _mod_request.SmartServerRequestHandler(
1498
1985
            None, _mod_request.request_handlers, '/')
 
1986
        # GZ 2010-08-10: Cycle with closure affects 4 tests
1499
1987
        class FakeCommand(_mod_request.SmartServerRequest):
1500
1988
            def do_body(self_cmd, body_bytes):
1501
1989
                self.end_received = True
1602
2090
        self.assertTrue(self.end_received)
1603
2091
 
1604
2092
    def test_accept_request_and_body_all_at_once(self):
1605
 
        self._captureVar('BZR_NO_SMART_VFS', None)
 
2093
        self.overrideEnv('BZR_NO_SMART_VFS', None)
1606
2094
        mem_transport = memory.MemoryTransport()
1607
2095
        mem_transport.put_bytes('foo', 'abcdefghij')
1608
2096
        out_stream = StringIO()
1868
2356
        self.assertTrue(self.end_received)
1869
2357
 
1870
2358
    def test_accept_request_and_body_all_at_once(self):
1871
 
        self._captureVar('BZR_NO_SMART_VFS', None)
 
2359
        self.overrideEnv('BZR_NO_SMART_VFS', None)
1872
2360
        mem_transport = memory.MemoryTransport()
1873
2361
        mem_transport.put_bytes('foo', 'abcdefghij')
1874
2362
        out_stream = StringIO()
2415
2903
        self.assertEqual('aaa', stream.next())
2416
2904
        self.assertEqual('bbb', stream.next())
2417
2905
        exc = self.assertRaises(errors.ErrorFromSmartServer, stream.next)
2418
 
        self.assertEqual(('error', 'Boom!'), exc.error_tuple)
 
2906
        self.assertEqual(('error', 'Exception', 'Boom!'), exc.error_tuple)
2419
2907
 
2420
2908
    def test_interrupted_by_connection_lost(self):
2421
2909
        interrupted_body_stream = (
2544
3032
        from_server = StringIO()
2545
3033
        transport = memory.MemoryTransport('memory:///')
2546
3034
        server = medium.SmartServerPipeStreamMedium(
2547
 
            to_server, from_server, transport)
 
3035
            to_server, from_server, transport, timeout=4.0)
2548
3036
        proto = server._build_protocol()
2549
3037
        message_handler = proto.message_handler
2550
3038
        server._serve_one_request(proto)
2795
3283
            'e', # end
2796
3284
            output.getvalue())
2797
3285
 
 
3286
    def test_records_start_of_body_stream(self):
 
3287
        requester, output = self.make_client_encoder_and_output()
 
3288
        requester.set_headers({})
 
3289
        in_stream = [False]
 
3290
        def stream_checker():
 
3291
            self.assertTrue(requester.body_stream_started)
 
3292
            in_stream[0] = True
 
3293
            yield 'content'
 
3294
        flush_called = []
 
3295
        orig_flush = requester.flush
 
3296
        def tracked_flush():
 
3297
            flush_called.append(in_stream[0])
 
3298
            if in_stream[0]:
 
3299
                self.assertTrue(requester.body_stream_started)
 
3300
            else:
 
3301
                self.assertFalse(requester.body_stream_started)
 
3302
            return orig_flush()
 
3303
        requester.flush = tracked_flush
 
3304
        requester.call_with_body_stream(('one arg',), stream_checker())
 
3305
        self.assertEqual(
 
3306
            'bzr message 3 (bzr 1.6)\n' # protocol version
 
3307
            '\x00\x00\x00\x02de' # headers
 
3308
            's\x00\x00\x00\x0bl7:one arge' # args
 
3309
            'b\x00\x00\x00\x07content' # body
 
3310
            'e', output.getvalue())
 
3311
        self.assertEqual([False, True, True], flush_called)
 
3312
 
2798
3313
 
2799
3314
class StubMediumRequest(object):
2800
3315
    """A stub medium request that tracks the number of times accept_bytes is
2818
3333
    'b\x00\x00\x00\x03aaa' # body part ('aaa')
2819
3334
    'b\x00\x00\x00\x03bbb' # body part ('bbb')
2820
3335
    'oE' # status flag (error)
2821
 
    's\x00\x00\x00\x10l5:error5:Boom!e' # err struct ('error', 'Boom!')
 
3336
    # err struct ('error', 'Exception', 'Boom!')
 
3337
    's\x00\x00\x00\x1bl5:error9:Exception5:Boom!e'
2822
3338
    'e' # EOM
2823
3339
    )
2824
3340
 
2869
3385
    """
2870
3386
 
2871
3387
    def setUp(self):
2872
 
        tests.TestCase.setUp(self)
 
3388
        super(TestResponseEncoderBufferingProtocolThree, self).setUp()
2873
3389
        self.writes = []
2874
3390
        self.responder = protocol.ProtocolThreeResponder(self.writes.append)
2875
3391
 
3219
3735
        # encoder.
3220
3736
 
3221
3737
 
 
3738
class Test_SmartClientRequest(tests.TestCase):
    """Tests for the _call and _send behaviour of client._SmartClientRequest.

    Most tests drive the request through an SSH client medium whose vendor is
    a FirstRejectedStringIOSSHVendor, then inspect the calls recorded on the
    vendor and the bytes captured in its output buffer.
    """

    def make_client_with_failing_medium(self, fail_at_write=True, response=''):
        """Build a _SmartClient over an SSH medium with a failing vendor.

        :param fail_at_write: passed through to
            FirstRejectedStringIOSSHVendor.
        :param response: bytes used to fill the StringIO given to the vendor
            as the response stream.
        :return: (output, vendor, smart_client), where output is the StringIO
            given to the vendor to collect what the client writes.
        """
        response_io = StringIO(response)
        output = StringIO()
        vendor = FirstRejectedStringIOSSHVendor(response_io, output,
                    fail_at_write=fail_at_write)
        ssh_params = medium.SSHParams('a host', 'a port', 'a user', 'a pass')
        client_medium = medium.SmartSSHClientMedium('base', ssh_params, vendor)
        smart_client = client._SmartClient(client_medium, headers={})
        return output, vendor, smart_client

    def make_response(self, args, body=None, body_stream=None):
        """Return the protocol-3 bytes of a successful response.

        :param args: the response argument tuple.
        :param body: optional body for the response.
        :param body_stream: optional body stream for the response.
        :return: the bytes a ProtocolThreeResponder writes for the response.
        """
        response_io = StringIO()
        response = _mod_request.SuccessfulSmartServerResponse(args, body=body,
            body_stream=body_stream)
        responder = protocol.ProtocolThreeResponder(response_io.write)
        responder.send_response(response)
        return response_io.getvalue()

    def test__call_doesnt_retry_append(self):
        """An 'append' whose read fails raises instead of being retried."""
        response = self.make_response(('appended', '8'))
        output, vendor, smart_client = self.make_client_with_failing_medium(
            fail_at_write=False, response=response)
        smart_request = client._SmartClientRequest(smart_client, 'append',
            ('foo', ''), body='content\n')
        self.assertRaises(errors.ConnectionReset, smart_request._call, 3)

    def test__call_retries_get_bytes(self):
        """A 'get' whose read fails is retried and then succeeds."""
        response = self.make_response(('ok',), 'content\n')
        output, vendor, smart_client = self.make_client_with_failing_medium(
            fail_at_write=False, response=response)
        smart_request = client._SmartClientRequest(smart_client, 'get',
            ('foo',))
        response, response_handler = smart_request._call(3)
        self.assertEqual(('ok',), response)
        self.assertEqual('content\n', response_handler.read_body_bytes())

    def test__call_noretry_get_bytes(self):
        """With the 'noretry' debug flag, even a 'get' is not retried."""
        # NOTE(review): this mutates the global debug.debug_flags set;
        # presumably the tests.TestCase fixture restores it -- confirm.
        debug.debug_flags.add('noretry')
        response = self.make_response(('ok',), 'content\n')
        output, vendor, smart_client = self.make_client_with_failing_medium(
            fail_at_write=False, response=response)
        smart_request = client._SmartClientRequest(smart_client, 'get',
            ('foo',))
        self.assertRaises(errors.ConnectionReset, smart_request._call, 3)

    def test__send_no_retry_pipes(self):
        """Writing to a pipe whose reader is closed raises ConnectionReset."""
        client_read, server_write = create_file_pipes()
        server_read, client_write = create_file_pipes()
        # server_read is closed explicitly below; nothing else in this test
        # closes the other three pipe ends, so release them at the end.
        for pipe_end in (client_read, server_write, client_write):
            self.addCleanup(pipe_end.close)
        client_medium = medium.SmartSimplePipesClientMedium(client_read,
            client_write, base='/')
        smart_client = client._SmartClient(client_medium)
        smart_request = client._SmartClientRequest(smart_client,
            'hello', ())
        # Close the server side
        server_read.close()
        encoder, response_handler = smart_request._construct_protocol(3)
        self.assertRaises(errors.ConnectionReset,
            smart_request._send_no_retry, encoder)

    def test__send_read_response_sockets(self):
        """A socket closed by the server is only noticed when reading."""
        listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # The listening socket is not closed anywhere else in this test.
        self.addCleanup(listen_sock.close)
        listen_sock.bind(('127.0.0.1', 0))
        listen_sock.listen(1)
        host, port = listen_sock.getsockname()
        client_medium = medium.SmartTCPClientMedium(host, port, '/')
        client_medium._ensure_connection()
        smart_client = client._SmartClient(client_medium)
        smart_request = client._SmartClientRequest(smart_client, 'hello', ())
        # Accept the connection, but don't actually talk to the client.
        server_sock, _ = listen_sock.accept()
        server_sock.close()
        # Sockets buffer and don't really notice that the server has closed the
        # connection until we try to read again.
        handler = smart_request._send(3)
        self.assertRaises(errors.ConnectionReset,
            handler.read_response_tuple, expect_body=False)

    def test__send_retries_on_write(self):
        """A failed write makes _send reconnect and send the request again."""
        output, vendor, smart_client = self.make_client_with_failing_medium()
        smart_request = client._SmartClientRequest(smart_client, 'hello', ())
        smart_request._send(3)
        self.assertEqual('bzr message 3 (bzr 1.6)\n' # protocol
                         '\x00\x00\x00\x02de'   # empty headers
                         's\x00\x00\x00\tl5:helloee',
                         output.getvalue())
        # Connect, close after the failure, then connect a second time.
        self.assertEqual(
            [('connect_ssh', 'a user', 'a pass', 'a host', 'a port',
              ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
             ('close',),
             ('connect_ssh', 'a user', 'a pass', 'a host', 'a port',
              ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
            ],
            vendor.calls)

    def test__send_doesnt_retry_read_failure(self):
        """_send itself does not reconnect when only the read will fail."""
        output, vendor, smart_client = self.make_client_with_failing_medium(
            fail_at_write=False)
        smart_request = client._SmartClientRequest(smart_client, 'hello', ())
        handler = smart_request._send(3)
        self.assertEqual('bzr message 3 (bzr 1.6)\n' # protocol
                         '\x00\x00\x00\x02de'   # empty headers
                         's\x00\x00\x00\tl5:helloee',
                         output.getvalue())
        # Only a single connection attempt was made.
        self.assertEqual(
            [('connect_ssh', 'a user', 'a pass', 'a host', 'a port',
              ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
            ],
            vendor.calls)
        self.assertRaises(errors.ConnectionReset, handler.read_response_tuple)

    def test__send_request_retries_body_stream_if_not_started(self):
        """A body-stream request is retried if the stream was not consumed."""
        output, vendor, smart_client = self.make_client_with_failing_medium()
        smart_request = client._SmartClientRequest(smart_client, 'hello', (),
            body_stream=['a', 'b'])
        smart_request._send(3)
        # We connect, get disconnected, and notice before consuming the stream,
        # so we try again one time and succeed.
        self.assertEqual(
            [('connect_ssh', 'a user', 'a pass', 'a host', 'a port',
              ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
             ('close',),
             ('connect_ssh', 'a user', 'a pass', 'a host', 'a port',
              ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
            ],
            vendor.calls)
        self.assertEqual('bzr message 3 (bzr 1.6)\n' # protocol
                         '\x00\x00\x00\x02de'   # empty headers
                         's\x00\x00\x00\tl5:helloe'
                         'b\x00\x00\x00\x01a'
                         'b\x00\x00\x00\x01b'
                         'e',
                         output.getvalue())

    def test__send_request_stops_if_body_started(self):
        """Once the body stream has been started, a failure is not retried."""
        # We intentionally use the python StringIO so that we can subclass it.
        from StringIO import StringIO
        response = StringIO()

        class FailAfterFirstWrite(StringIO):
            """Allow one 'write' call to pass, fail the rest"""

            def __init__(self):
                StringIO.__init__(self)
                self._first = True

            def write(self, s):
                if self._first:
                    self._first = False
                    return StringIO.write(self, s)
                raise IOError(errno.EINVAL, 'invalid file handle')

        output = FailAfterFirstWrite()

        vendor = FirstRejectedStringIOSSHVendor(response, output,
            fail_at_write=False)
        ssh_params = medium.SSHParams('a host', 'a port', 'a user', 'a pass')
        client_medium = medium.SmartSSHClientMedium('base', ssh_params, vendor)
        smart_client = client._SmartClient(client_medium, headers={})
        smart_request = client._SmartClientRequest(smart_client, 'hello', (),
            body_stream=['a', 'b'])
        self.assertRaises(errors.ConnectionReset, smart_request._send, 3)
        # We connect, and manage to get to the point that we start consuming
        # the body stream. The next write fails, so we just stop.
        self.assertEqual(
            [('connect_ssh', 'a user', 'a pass', 'a host', 'a port',
              ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
             ('close',),
            ],
            vendor.calls)
        self.assertEqual('bzr message 3 (bzr 1.6)\n' # protocol
                         '\x00\x00\x00\x02de'   # empty headers
                         's\x00\x00\x00\tl5:helloe',
                         output.getvalue())

    def test__send_disabled_retry(self):
        """With the 'noretry' debug flag, a failed write is not retried."""
        # NOTE(review): this mutates the global debug.debug_flags set;
        # presumably the tests.TestCase fixture restores it -- confirm.
        debug.debug_flags.add('noretry')
        output, vendor, smart_client = self.make_client_with_failing_medium()
        smart_request = client._SmartClientRequest(smart_client, 'hello', ())
        self.assertRaises(errors.ConnectionReset, smart_request._send, 3)
        # One connection attempt, closed after the failure, and no second try.
        self.assertEqual(
            [('connect_ssh', 'a user', 'a pass', 'a host', 'a port',
              ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
             ('close',),
            ],
            vendor.calls)
 
3923
 
 
3924
 
3222
3925
class LengthPrefixedBodyDecoder(tests.TestCase):
3223
3926
 
3224
3927
    # XXX: TODO: make accept_reading_trailer invoke translate_response or
3517
4220
    def setUp(self):
3518
4221
        super(HTTPTunnellingSmokeTest, self).setUp()
3519
4222
        # We use the VFS layer as part of HTTP tunnelling tests.
3520
 
        self._captureVar('BZR_NO_SMART_VFS', None)
 
4223
        self.overrideEnv('BZR_NO_SMART_VFS', None)
3521
4224
 
3522
4225
    def test_smart_http_medium_request_accept_bytes(self):
3523
4226
        medium = FakeHTTPMedium()
3558
4261
        # still work correctly.
3559
4262
        base_transport = remote.RemoteHTTPTransport('bzr+http://host/%7Ea/b')
3560
4263
        new_transport = base_transport.clone('c')
3561
 
        self.assertEqual('bzr+http://host/~a/b/c/', new_transport.base)
 
4264
        self.assertEqual(base_transport.base + 'c/', new_transport.base)
3562
4265
        self.assertEqual(
3563
4266
            'c/',
3564
4267
            new_transport._client.remote_path_from_transport(new_transport))
3581
4284
        r = t._redirected_to('http://www.example.com/foo',
3582
4285
                             'http://www.example.com/bar')
3583
4286
        self.assertEquals(type(r), type(t))
3584
 
        self.assertEquals('joe', t._user)
3585
 
        self.assertEquals(t._user, r._user)
 
4287
        self.assertEquals('joe', t._parsed_url.user)
 
4288
        self.assertEquals(t._parsed_url.user, r._parsed_url.user)
3586
4289
 
3587
4290
    def test_redirected_to_same_host_different_protocol(self):
3588
4291
        t = remote.RemoteHTTPTransport('bzr+http://joe@www.example.com/foo')