~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/transport/sftp.py

  • Committer: Martin Pool
  • Date: 2006-06-20 07:55:43 UTC
  • mfrom: (1798 +trunk)
  • mto: This revision was merged to the branch mainline in revision 1799.
  • Revision ID: mbp@sourcefrog.net-20060620075543-b10f6575d4a4fa32
[merge] bzr.dev

Show diffs side-by-side

added added

removed removed

Lines of Context:
22
22
import os
23
23
import random
24
24
import re
 
25
import select
25
26
import stat
26
27
import subprocess
27
28
import sys
46
47
    Server,
47
48
    split_url,
48
49
    Transport,
49
 
    urlescape,
50
50
    )
51
51
import bzrlib.ui
 
52
import bzrlib.urlutils as urlutils
52
53
 
53
54
try:
54
55
    import paramiko
302
303
            # What specific errors should we catch here?
303
304
            pass
304
305
 
 
306
 
305
307
class SFTPTransport (Transport):
306
308
    """
307
309
    Transport implementation for SFTP access.
355
357
        """
356
358
        # FIXME: share the common code across transports
357
359
        assert isinstance(relpath, basestring)
358
 
        relpath = urllib.unquote(relpath).split('/')
 
360
        relpath = urlutils.unescape(relpath).split('/')
359
361
        basepath = self._path.split('/')
360
362
        if len(basepath) > 0 and basepath[-1] == '':
361
363
            basepath = basepath[:-1]
702
704
        vendor = _get_ssh_vendor()
703
705
        if vendor == 'loopback':
704
706
            sock = socket.socket()
705
 
            sock.connect((self._host, self._port))
 
707
            try:
 
708
                sock.connect((self._host, self._port))
 
709
            except socket.error, e:
 
710
                raise ConnectionError('Unable to connect to SSH host %s:%s: %s'
 
711
                                      % (self._host, self._port, e))
706
712
            self._sftp = SFTPClient(LoopbackSFTP(sock))
707
713
        elif vendor != 'none':
708
714
            sock = SFTPSubprocess(self._host, vendor, self._port,
723
729
            t.set_log_channel('bzr.paramiko')
724
730
            t.start_client()
725
731
        except paramiko.SSHException, e:
726
 
            raise ConnectionError('Unable to reach SSH host %s:%d' %
727
 
                                  (self._host, self._port), e)
 
732
            raise ConnectionError('Unable to reach SSH host %s:%s: %s' 
 
733
                                  % (self._host, self._port, e))
728
734
            
729
735
        server_key = t.get_remote_server_key()
730
736
        server_key_hex = paramiko.util.hexify(server_key.get_fingerprint())
883
889
nvuQES5C9BMHjF39LZiGH1iLQy7FgdHyoP+eodI7
884
890
-----END RSA PRIVATE KEY-----
885
891
"""
886
 
    
887
 
 
888
 
class SingleListener(threading.Thread):
 
892
 
 
893
 
 
894
class SocketListener(threading.Thread):
889
895
 
890
896
    def __init__(self, callback):
891
897
        threading.Thread.__init__(self)
895
901
        self._socket.bind(('localhost', 0))
896
902
        self._socket.listen(1)
897
903
        self.port = self._socket.getsockname()[1]
898
 
        self.stop_event = threading.Event()
899
 
 
900
 
    def run(self):
901
 
        s, _ = self._socket.accept()
902
 
        # now close the listen socket
903
 
        self._socket.close()
904
 
        try:
905
 
            self._callback(s, self.stop_event)
906
 
        except socket.error:
907
 
            pass #Ignore socket errors
908
 
        except Exception, x:
909
 
            # probably a failed test
910
 
            warning('Exception from within unit test server thread: %r' % x)
 
904
        self._stop_event = threading.Event()
911
905
 
912
906
    def stop(self):
913
 
        self.stop_event.set()
 
907
        # called from outside this thread
 
908
        self._stop_event.set()
914
909
        # use a timeout here, because if the test fails, the server thread may
915
910
        # never notice the stop_event.
916
911
        self.join(5.0)
 
912
        self._socket.close()
 
913
 
 
914
    def run(self):
 
915
        while True:
 
916
            readable, writable_unused, exception_unused = \
 
917
                select.select([self._socket], [], [], 0.1)
 
918
            if self._stop_event.isSet():
 
919
                return
 
920
            if len(readable) == 0:
 
921
                continue
 
922
            try:
 
923
                s, addr_unused = self._socket.accept()
 
924
                # because the loopback socket is inline, and transports are
 
925
                # never explicitly closed, best to launch a new thread.
 
926
                threading.Thread(target=self._callback, args=(s,)).start()
 
927
            except socket.error, x:
 
928
                sys.excepthook(*sys.exc_info())
 
929
                warning('Socket error during accept() within unit test server'
 
930
                        ' thread: %r' % x)
 
931
            except Exception, x:
 
932
                # probably a failed test; unit test thread will log the
 
933
                # failure/error
 
934
                sys.excepthook(*sys.exc_info())
 
935
                warning('Exception from within unit test server thread: %r' % 
 
936
                        x)
917
937
 
918
938
 
919
939
class SFTPServer(Server):
937
957
        """StubServer uses this to log when a new server is created."""
938
958
        self.logs.append(message)
939
959
 
940
 
    def _run_server(self, s, stop_event):
 
960
    def _run_server(self, s):
941
961
        ssh_server = paramiko.Transport(s)
942
962
        key_file = os.path.join(self._homedir, 'test_rsa.key')
943
 
        file(key_file, 'w').write(STUB_SERVER_KEY)
 
963
        f = open(key_file, 'w')
 
964
        f.write(STUB_SERVER_KEY)
 
965
        f.close()
944
966
        host_key = paramiko.RSAKey.from_private_key_file(key_file)
945
967
        ssh_server.add_server_key(host_key)
946
968
        server = StubServer(self)
950
972
        event = threading.Event()
951
973
        ssh_server.start_server(event, server)
952
974
        event.wait(5.0)
953
 
        stop_event.wait(30.0)
954
975
    
955
976
    def setUp(self):
956
977
        global _ssh_vendor
957
978
        self._original_vendor = _ssh_vendor
958
979
        _ssh_vendor = self._vendor
959
 
        self._homedir = os.getcwdu()
 
980
        self._homedir = os.getcwd()
960
981
        if self._server_homedir is None:
961
982
            self._server_homedir = self._homedir
962
983
        self._root = '/'
963
984
        # FIXME WINDOWS: _root should be _server_homedir[0]:/
964
 
        self._listener = SingleListener(self._run_server)
 
985
        self._listener = SocketListener(self._run_server)
965
986
        self._listener.setDaemon(True)
966
987
        self._listener.start()
967
988
 
971
992
        self._listener.stop()
972
993
        _ssh_vendor = self._original_vendor
973
994
 
 
995
    def get_bogus_url(self):
 
996
        """See bzrlib.transport.Server.get_bogus_url."""
 
997
        # this is chosen to try to prevent trouble with proxies, wierd dns,
 
998
        # etc
 
999
        return 'sftp://127.0.0.1:1/'
 
1000
 
 
1001
 
974
1002
 
975
1003
class SFTPFullAbsoluteServer(SFTPServer):
976
1004
    """A test server for sftp transports, using absolute urls and ssh."""
977
1005
 
978
1006
    def get_url(self):
979
1007
        """See bzrlib.transport.Server.get_url."""
980
 
        return self._get_sftp_url(urlescape(self._homedir[1:]))
 
1008
        return self._get_sftp_url(urlutils.escape(self._homedir[1:]))
981
1009
 
982
1010
 
983
1011
class SFTPServerWithoutSSH(SFTPServer):
987
1015
        super(SFTPServerWithoutSSH, self).__init__()
988
1016
        self._vendor = 'loopback'
989
1017
 
990
 
    def _run_server(self, sock, stop_event):
 
1018
    def _run_server(self, sock):
991
1019
        class FakeChannel(object):
992
1020
            def get_transport(self):
993
1021
                return self
1011
1039
 
1012
1040
    def get_url(self):
1013
1041
        """See bzrlib.transport.Server.get_url."""
1014
 
        return self._get_sftp_url(urlescape(self._homedir[1:]))
 
1042
        return self._get_sftp_url(urlutils.escape(self._homedir[1:]))
1015
1043
 
1016
1044
 
1017
1045
class SFTPHomeDirServer(SFTPServerWithoutSSH):