~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/transport/sftp.py

  • Committer: Martin Pool
  • Date: 2010-02-23 07:43:11 UTC
  • mfrom: (4797.2.20 2.1)
  • mto: This revision was merged to the branch mainline in revision 5055.
  • Revision ID: mbp@sourcefrog.net-20100223074311-gnj55xdhrgz9l94e
Merge 2.1 back to trunk

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2005, 2006, 2007, 2008, 2009 Canonical Ltd
 
1
# Copyright (C) 2005-2010 Canonical Ltd
2
2
#
3
3
# This program is free software; you can redistribute it and/or modify
4
4
# it under the terms of the GNU General Public License as published by
28
28
import itertools
29
29
import os
30
30
import random
31
 
import select
32
 
import socket
33
31
import stat
34
32
import sys
35
33
import time
884
882
        else:
885
883
            return True
886
884
 
887
 
# ------------- server test implementation --------------
888
 
import threading
889
 
 
890
 
from bzrlib.tests.stub_sftp import StubServer, StubSFTPServer
891
 
 
892
 
STUB_SERVER_KEY = """
893
 
-----BEGIN RSA PRIVATE KEY-----
894
 
MIICWgIBAAKBgQDTj1bqB4WmayWNPB+8jVSYpZYk80Ujvj680pOTh2bORBjbIAyz
895
 
oWGW+GUjzKxTiiPvVmxFgx5wdsFvF03v34lEVVhMpouqPAYQ15N37K/ir5XY+9m/
896
 
d8ufMCkjeXsQkKqFbAlQcnWMCRnOoPHS3I4vi6hmnDDeeYTSRvfLbW0fhwIBIwKB
897
 
gBIiOqZYaoqbeD9OS9z2K9KR2atlTxGxOJPXiP4ESqP3NVScWNwyZ3NXHpyrJLa0
898
 
EbVtzsQhLn6rF+TzXnOlcipFvjsem3iYzCpuChfGQ6SovTcOjHV9z+hnpXvQ/fon
899
 
soVRZY65wKnF7IAoUwTmJS9opqgrN6kRgCd3DASAMd1bAkEA96SBVWFt/fJBNJ9H
900
 
tYnBKZGw0VeHOYmVYbvMSstssn8un+pQpUm9vlG/bp7Oxd/m+b9KWEh2xPfv6zqU
901
 
avNwHwJBANqzGZa/EpzF4J8pGti7oIAPUIDGMtfIcmqNXVMckrmzQ2vTfqtkEZsA
902
 
4rE1IERRyiJQx6EJsz21wJmGV9WJQ5kCQQDwkS0uXqVdFzgHO6S++tjmjYcxwr3g
903
 
H0CoFYSgbddOT6miqRskOQF3DZVkJT3kyuBgU2zKygz52ukQZMqxCb1fAkASvuTv
904
 
qfpH87Qq5kQhNKdbbwbmd2NxlNabazPijWuphGTdW0VfJdWfklyS2Kr+iqrs/5wV
905
 
HhathJt636Eg7oIjAkA8ht3MQ+XSl9yIJIS8gVpbPxSw5OMfw0PjVE7tBdQruiSc
906
 
nvuQES5C9BMHjF39LZiGH1iLQy7FgdHyoP+eodI7
907
 
-----END RSA PRIVATE KEY-----
908
 
"""
909
 
 
910
 
 
911
 
class SocketListener(threading.Thread):
912
 
 
913
 
    def __init__(self, callback):
914
 
        threading.Thread.__init__(self)
915
 
        self._callback = callback
916
 
        self._socket = socket.socket()
917
 
        self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
918
 
        self._socket.bind(('localhost', 0))
919
 
        self._socket.listen(1)
920
 
        self.host, self.port = self._socket.getsockname()[:2]
921
 
        self._stop_event = threading.Event()
922
 
 
923
 
    def stop(self):
924
 
        # called from outside this thread
925
 
        self._stop_event.set()
926
 
        # use a timeout here, because if the test fails, the server thread may
927
 
        # never notice the stop_event.
928
 
        self.join(5.0)
929
 
        self._socket.close()
930
 
 
931
 
    def run(self):
932
 
        while True:
933
 
            readable, writable_unused, exception_unused = \
934
 
                select.select([self._socket], [], [], 0.1)
935
 
            if self._stop_event.isSet():
936
 
                return
937
 
            if len(readable) == 0:
938
 
                continue
939
 
            try:
940
 
                s, addr_unused = self._socket.accept()
941
 
                # because the loopback socket is inline, and transports are
942
 
                # never explicitly closed, best to launch a new thread.
943
 
                threading.Thread(target=self._callback, args=(s,)).start()
944
 
            except socket.error, x:
945
 
                sys.excepthook(*sys.exc_info())
946
 
                warning('Socket error during accept() within unit test server'
947
 
                        ' thread: %r' % x)
948
 
            except Exception, x:
949
 
                # probably a failed test; unit test thread will log the
950
 
                # failure/error
951
 
                sys.excepthook(*sys.exc_info())
952
 
                warning('Exception from within unit test server thread: %r' %
953
 
                        x)
954
 
 
955
 
 
956
 
class SocketDelay(object):
957
 
    """A socket decorator to make TCP appear slower.
958
 
 
959
 
    This changes recv, send, and sendall to add a fixed latency to each python
960
 
    call if a new roundtrip is detected. That is, when a recv is called and the
961
 
    flag new_roundtrip is set, latency is charged. Every send and send_all
962
 
    sets this flag.
963
 
 
964
 
    In addition every send, sendall and recv sleeps a bit per character send to
965
 
    simulate bandwidth.
966
 
 
967
 
    Not all methods are implemented, this is deliberate as this class is not a
968
 
    replacement for the builtin sockets layer. fileno is not implemented to
969
 
    prevent the proxy being bypassed.
970
 
    """
971
 
 
972
 
    simulated_time = 0
973
 
    _proxied_arguments = dict.fromkeys([
974
 
        "close", "getpeername", "getsockname", "getsockopt", "gettimeout",
975
 
        "setblocking", "setsockopt", "settimeout", "shutdown"])
976
 
 
977
 
    def __init__(self, sock, latency, bandwidth=1.0,
978
 
                 really_sleep=True):
979
 
        """
980
 
        :param bandwith: simulated bandwith (MegaBit)
981
 
        :param really_sleep: If set to false, the SocketDelay will just
982
 
        increase a counter, instead of calling time.sleep. This is useful for
983
 
        unittesting the SocketDelay.
984
 
        """
985
 
        self.sock = sock
986
 
        self.latency = latency
987
 
        self.really_sleep = really_sleep
988
 
        self.time_per_byte = 1 / (bandwidth / 8.0 * 1024 * 1024)
989
 
        self.new_roundtrip = False
990
 
 
991
 
    def sleep(self, s):
992
 
        if self.really_sleep:
993
 
            time.sleep(s)
994
 
        else:
995
 
            SocketDelay.simulated_time += s
996
 
 
997
 
    def __getattr__(self, attr):
998
 
        if attr in SocketDelay._proxied_arguments:
999
 
            return getattr(self.sock, attr)
1000
 
        raise AttributeError("'SocketDelay' object has no attribute %r" %
1001
 
                             attr)
1002
 
 
1003
 
    def dup(self):
1004
 
        return SocketDelay(self.sock.dup(), self.latency, self.time_per_byte,
1005
 
                           self._sleep)
1006
 
 
1007
 
    def recv(self, *args):
1008
 
        data = self.sock.recv(*args)
1009
 
        if data and self.new_roundtrip:
1010
 
            self.new_roundtrip = False
1011
 
            self.sleep(self.latency)
1012
 
        self.sleep(len(data) * self.time_per_byte)
1013
 
        return data
1014
 
 
1015
 
    def sendall(self, data, flags=0):
1016
 
        if not self.new_roundtrip:
1017
 
            self.new_roundtrip = True
1018
 
            self.sleep(self.latency)
1019
 
        self.sleep(len(data) * self.time_per_byte)
1020
 
        return self.sock.sendall(data, flags)
1021
 
 
1022
 
    def send(self, data, flags=0):
1023
 
        if not self.new_roundtrip:
1024
 
            self.new_roundtrip = True
1025
 
            self.sleep(self.latency)
1026
 
        bytes_sent = self.sock.send(data, flags)
1027
 
        self.sleep(bytes_sent * self.time_per_byte)
1028
 
        return bytes_sent
1029
 
 
1030
 
 
1031
 
class SFTPServer(Server):
1032
 
    """Common code for SFTP server facilities."""
1033
 
 
1034
 
    def __init__(self, server_interface=StubServer):
1035
 
        self._original_vendor = None
1036
 
        self._homedir = None
1037
 
        self._server_homedir = None
1038
 
        self._listener = None
1039
 
        self._root = None
1040
 
        self._vendor = ssh.ParamikoVendor()
1041
 
        self._server_interface = server_interface
1042
 
        # sftp server logs
1043
 
        self.logs = []
1044
 
        self.add_latency = 0
1045
 
 
1046
 
    def _get_sftp_url(self, path):
1047
 
        """Calculate an sftp url to this server for path."""
1048
 
        return 'sftp://foo:bar@%s:%d/%s' % (self._listener.host,
1049
 
                                            self._listener.port, path)
1050
 
 
1051
 
    def log(self, message):
1052
 
        """StubServer uses this to log when a new server is created."""
1053
 
        self.logs.append(message)
1054
 
 
1055
 
    def _run_server_entry(self, sock):
1056
 
        """Entry point for all implementations of _run_server.
1057
 
 
1058
 
        If self.add_latency is > 0.000001 then sock is given a latency adding
1059
 
        decorator.
1060
 
        """
1061
 
        if self.add_latency > 0.000001:
1062
 
            sock = SocketDelay(sock, self.add_latency)
1063
 
        return self._run_server(sock)
1064
 
 
1065
 
    def _run_server(self, s):
1066
 
        ssh_server = paramiko.Transport(s)
1067
 
        key_file = pathjoin(self._homedir, 'test_rsa.key')
1068
 
        f = open(key_file, 'w')
1069
 
        f.write(STUB_SERVER_KEY)
1070
 
        f.close()
1071
 
        host_key = paramiko.RSAKey.from_private_key_file(key_file)
1072
 
        ssh_server.add_server_key(host_key)
1073
 
        server = self._server_interface(self)
1074
 
        ssh_server.set_subsystem_handler('sftp', paramiko.SFTPServer,
1075
 
                                         StubSFTPServer, root=self._root,
1076
 
                                         home=self._server_homedir)
1077
 
        event = threading.Event()
1078
 
        ssh_server.start_server(event, server)
1079
 
        event.wait(5.0)
1080
 
 
1081
 
    def start_server(self, backing_server=None):
1082
 
        # XXX: TODO: make sftpserver back onto backing_server rather than local
1083
 
        # disk.
1084
 
        if not (backing_server is None or
1085
 
                isinstance(backing_server, local.LocalURLServer)):
1086
 
            raise AssertionError(
1087
 
                "backing_server should not be %r, because this can only serve the "
1088
 
                "local current working directory." % (backing_server,))
1089
 
        self._original_vendor = ssh._ssh_vendor_manager._cached_ssh_vendor
1090
 
        ssh._ssh_vendor_manager._cached_ssh_vendor = self._vendor
1091
 
        if sys.platform == 'win32':
1092
 
            # Win32 needs to use the UNICODE api
1093
 
            self._homedir = getcwd()
1094
 
        else:
1095
 
            # But Linux SFTP servers should just deal in bytestreams
1096
 
            self._homedir = os.getcwd()
1097
 
        if self._server_homedir is None:
1098
 
            self._server_homedir = self._homedir
1099
 
        self._root = '/'
1100
 
        if sys.platform == 'win32':
1101
 
            self._root = ''
1102
 
        self._listener = SocketListener(self._run_server_entry)
1103
 
        self._listener.setDaemon(True)
1104
 
        self._listener.start()
1105
 
 
1106
 
    def stop_server(self):
1107
 
        self._listener.stop()
1108
 
        ssh._ssh_vendor_manager._cached_ssh_vendor = self._original_vendor
1109
 
 
1110
 
    def get_bogus_url(self):
1111
 
        """See bzrlib.transport.Server.get_bogus_url."""
1112
 
        # this is chosen to try to prevent trouble with proxies, wierd dns, etc
1113
 
        # we bind a random socket, so that we get a guaranteed unused port
1114
 
        # we just never listen on that port
1115
 
        s = socket.socket()
1116
 
        s.bind(('localhost', 0))
1117
 
        return 'sftp://%s:%s/' % s.getsockname()
1118
 
 
1119
 
 
1120
 
class SFTPFullAbsoluteServer(SFTPServer):
1121
 
    """A test server for sftp transports, using absolute urls and ssh."""
1122
 
 
1123
 
    def get_url(self):
1124
 
        """See bzrlib.transport.Server.get_url."""
1125
 
        homedir = self._homedir
1126
 
        if sys.platform != 'win32':
1127
 
            # Remove the initial '/' on all platforms but win32
1128
 
            homedir = homedir[1:]
1129
 
        return self._get_sftp_url(urlutils.escape(homedir))
1130
 
 
1131
 
 
1132
 
class SFTPServerWithoutSSH(SFTPServer):
1133
 
    """An SFTP server that uses a simple TCP socket pair rather than SSH."""
1134
 
 
1135
 
    def __init__(self):
1136
 
        super(SFTPServerWithoutSSH, self).__init__()
1137
 
        self._vendor = ssh.LoopbackVendor()
1138
 
 
1139
 
    def _run_server(self, sock):
1140
 
        # Re-import these as locals, so that they're still accessible during
1141
 
        # interpreter shutdown (when all module globals get set to None, leading
1142
 
        # to confusing errors like "'NoneType' object has no attribute 'error'".
1143
 
        class FakeChannel(object):
1144
 
            def get_transport(self):
1145
 
                return self
1146
 
            def get_log_channel(self):
1147
 
                return 'paramiko'
1148
 
            def get_name(self):
1149
 
                return '1'
1150
 
            def get_hexdump(self):
1151
 
                return False
1152
 
            def close(self):
1153
 
                pass
1154
 
 
1155
 
        server = paramiko.SFTPServer(
1156
 
            FakeChannel(), 'sftp', StubServer(self), StubSFTPServer,
1157
 
            root=self._root, home=self._server_homedir)
1158
 
        try:
1159
 
            server.start_subsystem(
1160
 
                'sftp', None, ssh.SocketAsChannelAdapter(sock))
1161
 
        except socket.error, e:
1162
 
            if (len(e.args) > 0) and (e.args[0] == errno.EPIPE):
1163
 
                # it's okay for the client to disconnect abruptly
1164
 
                # (bug in paramiko 1.6: it should absorb this exception)
1165
 
                pass
1166
 
            else:
1167
 
                raise
1168
 
        except Exception, e:
1169
 
            # This typically seems to happen during interpreter shutdown, so
1170
 
            # most of the useful ways to report this error are won't work.
1171
 
            # Writing the exception type, and then the text of the exception,
1172
 
            # seems to be the best we can do.
1173
 
            import sys
1174
 
            sys.stderr.write('\nEXCEPTION %r: ' % (e.__class__,))
1175
 
            sys.stderr.write('%s\n\n' % (e,))
1176
 
        server.finish_subsystem()
1177
 
 
1178
 
 
1179
 
class SFTPAbsoluteServer(SFTPServerWithoutSSH):
1180
 
    """A test server for sftp transports, using absolute urls."""
1181
 
 
1182
 
    def get_url(self):
1183
 
        """See bzrlib.transport.Server.get_url."""
1184
 
        homedir = self._homedir
1185
 
        if sys.platform != 'win32':
1186
 
            # Remove the initial '/' on all platforms but win32
1187
 
            homedir = homedir[1:]
1188
 
        return self._get_sftp_url(urlutils.escape(homedir))
1189
 
 
1190
 
 
1191
 
class SFTPHomeDirServer(SFTPServerWithoutSSH):
1192
 
    """A test server for sftp transports, using homedir relative urls."""
1193
 
 
1194
 
    def get_url(self):
1195
 
        """See bzrlib.transport.Server.get_url."""
1196
 
        return self._get_sftp_url("~/")
1197
 
 
1198
 
 
1199
 
class SFTPSiblingAbsoluteServer(SFTPAbsoluteServer):
1200
 
    """A test server for sftp transports where only absolute paths will work.
1201
 
 
1202
 
    It does this by serving from a deeply-nested directory that doesn't exist.
1203
 
    """
1204
 
 
1205
 
    def start_server(self, backing_server=None):
1206
 
        self._server_homedir = '/dev/noone/runs/tests/here'
1207
 
        super(SFTPSiblingAbsoluteServer, self).start_server(backing_server)
1208
 
 
1209
885
 
1210
886
def get_test_permutations():
1211
887
    """Return the permutations to be used in testing."""
1212
 
    return [(SFTPTransport, SFTPAbsoluteServer),
1213
 
            (SFTPTransport, SFTPHomeDirServer),
1214
 
            (SFTPTransport, SFTPSiblingAbsoluteServer),
 
888
    from bzrlib.tests import stub_sftp
 
889
    return [(SFTPTransport, stub_sftp.SFTPAbsoluteServer),
 
890
            (SFTPTransport, stub_sftp.SFTPHomeDirServer),
 
891
            (SFTPTransport, stub_sftp.SFTPSiblingAbsoluteServer),
1215
892
            ]