~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/transport/ssh.py

  • Committer: mbp at sourcefrog
  • Date: 2005-03-21 22:29:49 UTC
  • Revision ID: mbp@sourcefrog.net-20050321222949-232c2093a6eadd80
fixup doctest for new module structure

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2005 Robey Pointer <robey@lag.net>
2
 
# Copyright (C) 2005, 2006 Canonical Ltd
3
 
#
4
 
# This program is free software; you can redistribute it and/or modify
5
 
# it under the terms of the GNU General Public License as published by
6
 
# the Free Software Foundation; either version 2 of the License, or
7
 
# (at your option) any later version.
8
 
#
9
 
# This program is distributed in the hope that it will be useful,
10
 
# but WITHOUT ANY WARRANTY; without even the implied warranty of
11
 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12
 
# GNU General Public License for more details.
13
 
#
14
 
# You should have received a copy of the GNU General Public License
15
 
# along with this program; if not, write to the Free Software
16
 
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
17
 
 
18
 
"""Foundation SSH support for SFTP and smart server."""
19
 
 
20
 
import errno
21
 
import getpass
22
 
import os
23
 
import socket
24
 
import subprocess
25
 
import sys
26
 
 
27
 
from bzrlib.config import config_dir, ensure_config_dir_exists
28
 
from bzrlib.errors import (ConnectionError,
29
 
                           ParamikoNotPresent,
30
 
                           TransportError,
31
 
                           UnknownSSH,
32
 
                           )
33
 
 
34
 
from bzrlib.osutils import pathjoin
35
 
from bzrlib.trace import mutter, warning
36
 
import bzrlib.ui
37
 
 
38
 
try:
39
 
    import paramiko
40
 
except ImportError, e:
41
 
    raise ParamikoNotPresent(e)
42
 
else:
43
 
    from paramiko.sftp_client import SFTPClient
44
 
 
45
 
 
46
 
SYSTEM_HOSTKEYS = {}
47
 
BZR_HOSTKEYS = {}
48
 
 
49
 
 
50
 
_paramiko_version = getattr(paramiko, '__version_info__', (0, 0, 0))
51
 
 
52
 
# Paramiko 1.5 tries to open a socket.AF_UNIX in order to connect
53
 
# to ssh-agent. That attribute doesn't exist on win32 (it does in cygwin)
54
 
# so we get an AttributeError exception. So we will not try to
55
 
# connect to an agent if we are on win32 and using Paramiko older than 1.6
56
 
_use_ssh_agent = (sys.platform != 'win32' or _paramiko_version >= (1, 6, 0))
57
 
 
58
 
_ssh_vendors = {}
59
 
 
60
 
def register_ssh_vendor(name, vendor):
61
 
    """Register SSH vendor."""
62
 
    _ssh_vendors[name] = vendor
63
 
 
64
 
    
65
 
_ssh_vendor = None
66
 
def _get_ssh_vendor():
67
 
    """Find out what version of SSH is on the system."""
68
 
    global _ssh_vendor
69
 
    if _ssh_vendor is not None:
70
 
        return _ssh_vendor
71
 
 
72
 
    if 'BZR_SSH' in os.environ:
73
 
        vendor_name = os.environ['BZR_SSH']
74
 
        try:
75
 
            _ssh_vendor = _ssh_vendors[vendor_name]
76
 
        except KeyError:
77
 
            raise UnknownSSH(vendor_name)
78
 
        return _ssh_vendor
79
 
 
80
 
    try:
81
 
        p = subprocess.Popen(['ssh', '-V'],
82
 
                             stdin=subprocess.PIPE,
83
 
                             stdout=subprocess.PIPE,
84
 
                             stderr=subprocess.PIPE,
85
 
                             **os_specific_subprocess_params())
86
 
        returncode = p.returncode
87
 
        stdout, stderr = p.communicate()
88
 
    except OSError:
89
 
        returncode = -1
90
 
        stdout = stderr = ''
91
 
    if 'OpenSSH' in stderr:
92
 
        mutter('ssh implementation is OpenSSH')
93
 
        _ssh_vendor = OpenSSHSubprocessVendor()
94
 
    elif 'SSH Secure Shell' in stderr:
95
 
        mutter('ssh implementation is SSH Corp.')
96
 
        _ssh_vendor = SSHCorpSubprocessVendor()
97
 
 
98
 
    if _ssh_vendor is not None:
99
 
        return _ssh_vendor
100
 
 
101
 
    # XXX: 20051123 jamesh
102
 
    # A check for putty's plink or lsh would go here.
103
 
 
104
 
    mutter('falling back to paramiko implementation')
105
 
    _ssh_vendor = ParamikoVendor()
106
 
    return _ssh_vendor
107
 
 
108
 
 
109
 
 
110
 
def _ignore_sigint():
111
 
    # TODO: This should possibly ignore SIGHUP as well, but bzr currently
112
 
    # doesn't handle it itself.
113
 
    # <https://launchpad.net/products/bzr/+bug/41433/+index>
114
 
    import signal
115
 
    signal.signal(signal.SIGINT, signal.SIG_IGN)
116
 
    
117
 
 
118
 
 
119
 
class LoopbackSFTP(object):
120
 
    """Simple wrapper for a socket that pretends to be a paramiko Channel."""
121
 
 
122
 
    def __init__(self, sock):
123
 
        self.__socket = sock
124
 
 
125
 
    def send(self, data):
126
 
        return self.__socket.send(data)
127
 
 
128
 
    def recv(self, n):
129
 
        return self.__socket.recv(n)
130
 
 
131
 
    def recv_ready(self):
132
 
        return True
133
 
 
134
 
    def close(self):
135
 
        self.__socket.close()
136
 
 
137
 
 
138
 
class SSHVendor(object):
139
 
    """Abstract base class for SSH vendor implementations."""
140
 
    
141
 
    def connect_sftp(self, username, password, host, port):
142
 
        """Make an SSH connection, and return an SFTPClient.
143
 
        
144
 
        :param username: an ascii string
145
 
        :param password: an ascii string
146
 
        :param host: a host name as an ascii string
147
 
        :param port: a port number
148
 
        :type port: int
149
 
 
150
 
        :raises: ConnectionError if it cannot connect.
151
 
 
152
 
        :rtype: paramiko.sftp_client.SFTPClient
153
 
        """
154
 
        raise NotImplementedError(self.connect_sftp)
155
 
 
156
 
    def connect_ssh(self, username, password, host, port, command):
157
 
        """Make an SSH connection.
158
 
        
159
 
        :returns: something with a `close` method, and a `get_filelike_channels`
160
 
            method that returns a pair of (read, write) filelike objects.
161
 
        """
162
 
        raise NotImplementedError(self.connect_ssh)
163
 
        
164
 
 
165
 
class LoopbackVendor(SSHVendor):
166
 
    """SSH "vendor" that connects over a plain TCP socket, not SSH."""
167
 
    
168
 
    def connect_sftp(self, username, password, host, port):
169
 
        sock = socket.socket()
170
 
        try:
171
 
            sock.connect((host, port))
172
 
        except socket.error, e:
173
 
            raise ConnectionError('Unable to connect to SSH host %s:%s: %s'
174
 
                                  % (host, port, e))
175
 
        return SFTPClient(LoopbackSFTP(sock))
176
 
 
177
 
register_ssh_vendor('loopback', LoopbackVendor())
178
 
 
179
 
 
180
 
class _ParamikoSSHConnection(object):
181
 
    def __init__(self, channel):
182
 
        self.channel = channel
183
 
 
184
 
    def get_filelike_channels(self):
185
 
        return self.channel.makefile('rb'), self.channel.makefile('wb')
186
 
 
187
 
    def close(self):
188
 
        return self.channel.close()
189
 
 
190
 
 
191
 
class ParamikoVendor(SSHVendor):
192
 
    """Vendor that uses paramiko."""
193
 
 
194
 
    def _connect(self, username, password, host, port):
195
 
        global SYSTEM_HOSTKEYS, BZR_HOSTKEYS
196
 
        
197
 
        load_host_keys()
198
 
 
199
 
        try:
200
 
            t = paramiko.Transport((host, port or 22))
201
 
            t.set_log_channel('bzr.paramiko')
202
 
            t.start_client()
203
 
        except (paramiko.SSHException, socket.error), e:
204
 
            raise ConnectionError('Unable to reach SSH host %s:%s: %s' 
205
 
                                  % (host, port, e))
206
 
            
207
 
        server_key = t.get_remote_server_key()
208
 
        server_key_hex = paramiko.util.hexify(server_key.get_fingerprint())
209
 
        keytype = server_key.get_name()
210
 
        if host in SYSTEM_HOSTKEYS and keytype in SYSTEM_HOSTKEYS[host]:
211
 
            our_server_key = SYSTEM_HOSTKEYS[host][keytype]
212
 
            our_server_key_hex = paramiko.util.hexify(our_server_key.get_fingerprint())
213
 
        elif host in BZR_HOSTKEYS and keytype in BZR_HOSTKEYS[host]:
214
 
            our_server_key = BZR_HOSTKEYS[host][keytype]
215
 
            our_server_key_hex = paramiko.util.hexify(our_server_key.get_fingerprint())
216
 
        else:
217
 
            warning('Adding %s host key for %s: %s' % (keytype, host, server_key_hex))
218
 
            if host not in BZR_HOSTKEYS:
219
 
                BZR_HOSTKEYS[host] = {}
220
 
            BZR_HOSTKEYS[host][keytype] = server_key
221
 
            our_server_key = server_key
222
 
            our_server_key_hex = paramiko.util.hexify(our_server_key.get_fingerprint())
223
 
            save_host_keys()
224
 
        if server_key != our_server_key:
225
 
            filename1 = os.path.expanduser('~/.ssh/known_hosts')
226
 
            filename2 = pathjoin(config_dir(), 'ssh_host_keys')
227
 
            raise TransportError('Host keys for %s do not match!  %s != %s' % \
228
 
                (host, our_server_key_hex, server_key_hex),
229
 
                ['Try editing %s or %s' % (filename1, filename2)])
230
 
 
231
 
        _paramiko_auth(username, password, host, t)
232
 
        return t
233
 
        
234
 
    def connect_sftp(self, username, password, host, port):
235
 
        t = self._connect(username, password, host, port)
236
 
        try:
237
 
            return t.open_sftp_client()
238
 
        except paramiko.SSHException, e:
239
 
            raise ConnectionError('Unable to start sftp client %s:%d' %
240
 
                                  (host, port), e)
241
 
 
242
 
    def connect_ssh(self, username, password, host, port, command):
243
 
        t = self._connect(username, password, host, port)
244
 
        try:
245
 
            channel = t.open_session()
246
 
            cmdline = ' '.join(command)
247
 
            channel.exec_command(cmdline)
248
 
            return _ParamikoSSHConnection(channel)
249
 
        except paramiko.SSHException, e:
250
 
            raise ConnectionError('Unable to invoke remote bzr %s:%d' %
251
 
                                  (host, port), e)
252
 
 
253
 
register_ssh_vendor('paramiko', ParamikoVendor())
254
 
 
255
 
 
256
 
class SubprocessVendor(SSHVendor):
257
 
    """Abstract base class for vendors that use pipes to a subprocess."""
258
 
    
259
 
    def _connect(self, argv):
260
 
        proc = subprocess.Popen(argv,
261
 
                                stdin=subprocess.PIPE,
262
 
                                stdout=subprocess.PIPE,
263
 
                                **os_specific_subprocess_params())
264
 
        return SSHSubprocess(proc)
265
 
 
266
 
    def connect_sftp(self, username, password, host, port):
267
 
        try:
268
 
            argv = self._get_vendor_specific_argv(username, host, port,
269
 
                                                  subsystem='sftp')
270
 
            sock = self._connect(argv)
271
 
            return SFTPClient(sock)
272
 
        except (EOFError, paramiko.SSHException), e:
273
 
            raise ConnectionError('Unable to connect to SSH host %s:%s: %s'
274
 
                                  % (host, port, e))
275
 
        except (OSError, IOError), e:
276
 
            # If the machine is fast enough, ssh can actually exit
277
 
            # before we try and send it the sftp request, which
278
 
            # raises a Broken Pipe
279
 
            if e.errno not in (errno.EPIPE,):
280
 
                raise
281
 
            raise ConnectionError('Unable to connect to SSH host %s:%s: %s'
282
 
                                  % (host, port, e))
283
 
 
284
 
    def connect_ssh(self, username, password, host, port, command):
285
 
        try:
286
 
            argv = self._get_vendor_specific_argv(username, host, port,
287
 
                                                  command=command)
288
 
            return self._connect(argv)
289
 
        except (EOFError), e:
290
 
            raise ConnectionError('Unable to connect to SSH host %s:%s: %s'
291
 
                                  % (host, port, e))
292
 
        except (OSError, IOError), e:
293
 
            # If the machine is fast enough, ssh can actually exit
294
 
            # before we try and send it the sftp request, which
295
 
            # raises a Broken Pipe
296
 
            if e.errno not in (errno.EPIPE,):
297
 
                raise
298
 
            raise ConnectionError('Unable to connect to SSH host %s:%s: %s'
299
 
                                  % (host, port, e))
300
 
 
301
 
    def _get_vendor_specific_argv(self, username, host, port, subsystem=None,
302
 
                                  command=None):
303
 
        """Returns the argument list to run the subprocess with.
304
 
        
305
 
        Exactly one of 'subsystem' and 'command' must be specified.
306
 
        """
307
 
        raise NotImplementedError(self._get_vendor_specific_argv)
308
 
 
309
 
register_ssh_vendor('none', ParamikoVendor())
310
 
 
311
 
 
312
 
class OpenSSHSubprocessVendor(SubprocessVendor):
313
 
    """SSH vendor that uses the 'ssh' executable from OpenSSH."""
314
 
    
315
 
    def _get_vendor_specific_argv(self, username, host, port, subsystem=None,
316
 
                                  command=None):
317
 
        assert subsystem is not None or command is not None, (
318
 
            'Must specify a command or subsystem')
319
 
        if subsystem is not None:
320
 
            assert command is None, (
321
 
                'subsystem and command are mutually exclusive')
322
 
        args = ['ssh',
323
 
                '-oForwardX11=no', '-oForwardAgent=no',
324
 
                '-oClearAllForwardings=yes', '-oProtocol=2',
325
 
                '-oNoHostAuthenticationForLocalhost=yes']
326
 
        if port is not None:
327
 
            args.extend(['-p', str(port)])
328
 
        if username is not None:
329
 
            args.extend(['-l', username])
330
 
        if subsystem is not None:
331
 
            args.extend(['-s', host, subsystem])
332
 
        else:
333
 
            args.extend([host] + command)
334
 
        return args
335
 
 
336
 
register_ssh_vendor('openssh', OpenSSHSubprocessVendor())
337
 
 
338
 
 
339
 
class SSHCorpSubprocessVendor(SubprocessVendor):
340
 
    """SSH vendor that uses the 'ssh' executable from SSH Corporation."""
341
 
 
342
 
    def _get_vendor_specific_argv(self, username, host, port, subsystem=None,
343
 
                                  command=None):
344
 
        assert subsystem is not None or command is not None, (
345
 
            'Must specify a command or subsystem')
346
 
        if subsystem is not None:
347
 
            assert command is None, (
348
 
                'subsystem and command are mutually exclusive')
349
 
        args = ['ssh', '-x']
350
 
        if port is not None:
351
 
            args.extend(['-p', str(port)])
352
 
        if username is not None:
353
 
            args.extend(['-l', username])
354
 
        if subsystem is not None:
355
 
            args.extend(['-s', subsystem, host])
356
 
        else:
357
 
            args.extend([host] + command)
358
 
        return args
359
 
    
360
 
register_ssh_vendor('ssh', SSHCorpSubprocessVendor())
361
 
 
362
 
 
363
 
def _paramiko_auth(username, password, host, paramiko_transport):
364
 
    # paramiko requires a username, but it might be none if nothing was supplied
365
 
    # use the local username, just in case.
366
 
    # We don't override username, because if we aren't using paramiko,
367
 
    # the username might be specified in ~/.ssh/config and we don't want to
368
 
    # force it to something else
369
 
    # Also, it would mess up the self.relpath() functionality
370
 
    username = username or getpass.getuser()
371
 
 
372
 
    if _use_ssh_agent:
373
 
        agent = paramiko.Agent()
374
 
        for key in agent.get_keys():
375
 
            mutter('Trying SSH agent key %s' % paramiko.util.hexify(key.get_fingerprint()))
376
 
            try:
377
 
                paramiko_transport.auth_publickey(username, key)
378
 
                return
379
 
            except paramiko.SSHException, e:
380
 
                pass
381
 
    
382
 
    # okay, try finding id_rsa or id_dss?  (posix only)
383
 
    if _try_pkey_auth(paramiko_transport, paramiko.RSAKey, username, 'id_rsa'):
384
 
        return
385
 
    if _try_pkey_auth(paramiko_transport, paramiko.DSSKey, username, 'id_dsa'):
386
 
        return
387
 
 
388
 
    if password:
389
 
        try:
390
 
            paramiko_transport.auth_password(username, password)
391
 
            return
392
 
        except paramiko.SSHException, e:
393
 
            pass
394
 
 
395
 
    # give up and ask for a password
396
 
    password = bzrlib.ui.ui_factory.get_password(
397
 
            prompt='SSH %(user)s@%(host)s password',
398
 
            user=username, host=host)
399
 
    try:
400
 
        paramiko_transport.auth_password(username, password)
401
 
    except paramiko.SSHException, e:
402
 
        raise ConnectionError('Unable to authenticate to SSH host as %s@%s' %
403
 
                              (username, host), e)
404
 
 
405
 
 
406
 
def _try_pkey_auth(paramiko_transport, pkey_class, username, filename):
407
 
    filename = os.path.expanduser('~/.ssh/' + filename)
408
 
    try:
409
 
        key = pkey_class.from_private_key_file(filename)
410
 
        paramiko_transport.auth_publickey(username, key)
411
 
        return True
412
 
    except paramiko.PasswordRequiredException:
413
 
        password = bzrlib.ui.ui_factory.get_password(
414
 
                prompt='SSH %(filename)s password',
415
 
                filename=filename)
416
 
        try:
417
 
            key = pkey_class.from_private_key_file(filename, password)
418
 
            paramiko_transport.auth_publickey(username, key)
419
 
            return True
420
 
        except paramiko.SSHException:
421
 
            mutter('SSH authentication via %s key failed.' % (os.path.basename(filename),))
422
 
    except paramiko.SSHException:
423
 
        mutter('SSH authentication via %s key failed.' % (os.path.basename(filename),))
424
 
    except IOError:
425
 
        pass
426
 
    return False
427
 
 
428
 
 
429
 
def load_host_keys():
430
 
    """
431
 
    Load system host keys (probably doesn't work on windows) and any
432
 
    "discovered" keys from previous sessions.
433
 
    """
434
 
    global SYSTEM_HOSTKEYS, BZR_HOSTKEYS
435
 
    try:
436
 
        SYSTEM_HOSTKEYS = paramiko.util.load_host_keys(os.path.expanduser('~/.ssh/known_hosts'))
437
 
    except Exception, e:
438
 
        mutter('failed to load system host keys: ' + str(e))
439
 
    bzr_hostkey_path = pathjoin(config_dir(), 'ssh_host_keys')
440
 
    try:
441
 
        BZR_HOSTKEYS = paramiko.util.load_host_keys(bzr_hostkey_path)
442
 
    except Exception, e:
443
 
        mutter('failed to load bzr host keys: ' + str(e))
444
 
        save_host_keys()
445
 
 
446
 
 
447
 
def save_host_keys():
448
 
    """
449
 
    Save "discovered" host keys in $(config)/ssh_host_keys/.
450
 
    """
451
 
    global SYSTEM_HOSTKEYS, BZR_HOSTKEYS
452
 
    bzr_hostkey_path = pathjoin(config_dir(), 'ssh_host_keys')
453
 
    ensure_config_dir_exists()
454
 
 
455
 
    try:
456
 
        f = open(bzr_hostkey_path, 'w')
457
 
        f.write('# SSH host keys collected by bzr\n')
458
 
        for hostname, keys in BZR_HOSTKEYS.iteritems():
459
 
            for keytype, key in keys.iteritems():
460
 
                f.write('%s %s %s\n' % (hostname, keytype, key.get_base64()))
461
 
        f.close()
462
 
    except IOError, e:
463
 
        mutter('failed to save bzr host keys: ' + str(e))
464
 
 
465
 
 
466
 
def os_specific_subprocess_params():
467
 
    """Get O/S specific subprocess parameters."""
468
 
    if sys.platform == 'win32':
469
 
        # setting the process group and closing fds is not supported on 
470
 
        # win32
471
 
        return {}
472
 
    else:
473
 
        # We close fds other than the pipes as the child process does not need 
474
 
        # them to be open.
475
 
        #
476
 
        # We also set the child process to ignore SIGINT.  Normally the signal
477
 
        # would be sent to every process in the foreground process group, but
478
 
        # this causes it to be seen only by bzr and not by ssh.  Python will
479
 
        # generate a KeyboardInterrupt in bzr, and we will then have a chance
480
 
        # to release locks or do other cleanup over ssh before the connection
481
 
        # goes away.  
482
 
        # <https://launchpad.net/products/bzr/+bug/5987>
483
 
        #
484
 
        # Running it in a separate process group is not good because then it
485
 
        # can't get non-echoed input of a password or passphrase.
486
 
        # <https://launchpad.net/products/bzr/+bug/40508>
487
 
        return {'preexec_fn': _ignore_sigint,
488
 
                'close_fds': True,
489
 
                }
490
 
 
491
 
 
492
 
class SSHSubprocess(object):
493
 
    """A socket-like object that talks to an ssh subprocess via pipes."""
494
 
 
495
 
    def __init__(self, proc):
496
 
        self.proc = proc
497
 
 
498
 
    def send(self, data):
499
 
        return os.write(self.proc.stdin.fileno(), data)
500
 
 
501
 
    def recv_ready(self):
502
 
        # TODO: jam 20051215 this function is necessary to support the
503
 
        # pipelined() function. In reality, it probably should use
504
 
        # poll() or select() to actually return if there is data
505
 
        # available, otherwise we probably don't get any benefit
506
 
        return True
507
 
 
508
 
    def recv(self, count):
509
 
        return os.read(self.proc.stdout.fileno(), count)
510
 
 
511
 
    def close(self):
512
 
        self.proc.stdin.close()
513
 
        self.proc.stdout.close()
514
 
        self.proc.wait()
515
 
 
516
 
    def get_filelike_channels(self):
517
 
        return (self.proc.stdout, self.proc.stdin)
518