~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/transport/ssh.py

  • Committer: Martin Pool
  • Date: 2006-11-03 01:52:12 UTC
  • mto: This revision was merged to the branch mainline in revision 2119.
  • Revision ID: mbp@sourcefrog.net-20061103015212-1e5f881c2152d79f
Review comments

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# Copyright (C) 2005 Robey Pointer <robey@lag.net>
 
2
# Copyright (C) 2005, 2006 Canonical Ltd
 
3
#
 
4
# This program is free software; you can redistribute it and/or modify
 
5
# it under the terms of the GNU General Public License as published by
 
6
# the Free Software Foundation; either version 2 of the License, or
 
7
# (at your option) any later version.
 
8
#
 
9
# This program is distributed in the hope that it will be useful,
 
10
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
11
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
12
# GNU General Public License for more details.
 
13
#
 
14
# You should have received a copy of the GNU General Public License
 
15
# along with this program; if not, write to the Free Software
 
16
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 
17
 
 
18
"""Foundation SSH support for SFTP and smart server."""
 
19
 
 
20
import errno
 
21
import getpass
 
22
import os
 
23
import socket
 
24
import subprocess
 
25
import sys
 
26
 
 
27
from bzrlib.config import config_dir, ensure_config_dir_exists
 
28
from bzrlib.errors import (ConnectionError,
 
29
                           ParamikoNotPresent,
 
30
                           SocketConnectionError,
 
31
                           TransportError,
 
32
                           UnknownSSH,
 
33
                           )
 
34
 
 
35
from bzrlib.osutils import pathjoin
 
36
from bzrlib.trace import mutter, warning
 
37
import bzrlib.ui
 
38
 
 
39
try:
 
40
    import paramiko
 
41
except ImportError, e:
 
42
    raise ParamikoNotPresent(e)
 
43
else:
 
44
    from paramiko.sftp_client import SFTPClient
 
45
 
 
46
 
 
47
SYSTEM_HOSTKEYS = {}
 
48
BZR_HOSTKEYS = {}
 
49
 
 
50
 
 
51
_paramiko_version = getattr(paramiko, '__version_info__', (0, 0, 0))
 
52
 
 
53
# Paramiko 1.5 tries to open a socket.AF_UNIX in order to connect
 
54
# to ssh-agent. That attribute doesn't exist on win32 (it does in cygwin)
 
55
# so we get an AttributeError exception. So we will not try to
 
56
# connect to an agent if we are on win32 and using Paramiko older than 1.6
 
57
_use_ssh_agent = (sys.platform != 'win32' or _paramiko_version >= (1, 6, 0))
 
58
 
 
59
_ssh_vendors = {}
 
60
 
 
61
def register_ssh_vendor(name, vendor):
 
62
    """Register SSH vendor."""
 
63
    _ssh_vendors[name] = vendor
 
64
 
 
65
    
 
66
_ssh_vendor = None
 
67
def _get_ssh_vendor():
 
68
    """Find out what version of SSH is on the system."""
 
69
    global _ssh_vendor
 
70
    if _ssh_vendor is not None:
 
71
        return _ssh_vendor
 
72
 
 
73
    if 'BZR_SSH' in os.environ:
 
74
        vendor_name = os.environ['BZR_SSH']
 
75
        try:
 
76
            _ssh_vendor = _ssh_vendors[vendor_name]
 
77
        except KeyError:
 
78
            raise UnknownSSH(vendor_name)
 
79
        return _ssh_vendor
 
80
 
 
81
    try:
 
82
        p = subprocess.Popen(['ssh', '-V'],
 
83
                             stdin=subprocess.PIPE,
 
84
                             stdout=subprocess.PIPE,
 
85
                             stderr=subprocess.PIPE,
 
86
                             **os_specific_subprocess_params())
 
87
        returncode = p.returncode
 
88
        stdout, stderr = p.communicate()
 
89
    except OSError:
 
90
        returncode = -1
 
91
        stdout = stderr = ''
 
92
    if 'OpenSSH' in stderr:
 
93
        mutter('ssh implementation is OpenSSH')
 
94
        _ssh_vendor = OpenSSHSubprocessVendor()
 
95
    elif 'SSH Secure Shell' in stderr:
 
96
        mutter('ssh implementation is SSH Corp.')
 
97
        _ssh_vendor = SSHCorpSubprocessVendor()
 
98
 
 
99
    if _ssh_vendor is not None:
 
100
        return _ssh_vendor
 
101
 
 
102
    # XXX: 20051123 jamesh
 
103
    # A check for putty's plink or lsh would go here.
 
104
 
 
105
    mutter('falling back to paramiko implementation')
 
106
    _ssh_vendor = ParamikoVendor()
 
107
    return _ssh_vendor
 
108
 
 
109
 
 
110
def _ignore_sigint():
 
111
    # TODO: This should possibly ignore SIGHUP as well, but bzr currently
 
112
    # doesn't handle it itself.
 
113
    # <https://launchpad.net/products/bzr/+bug/41433/+index>
 
114
    import signal
 
115
    signal.signal(signal.SIGINT, signal.SIG_IGN)
 
116
    
 
117
 
 
118
 
 
119
class LoopbackSFTP(object):
 
120
    """Simple wrapper for a socket that pretends to be a paramiko Channel."""
 
121
 
 
122
    def __init__(self, sock):
 
123
        self.__socket = sock
 
124
 
 
125
    def send(self, data):
 
126
        return self.__socket.send(data)
 
127
 
 
128
    def recv(self, n):
 
129
        return self.__socket.recv(n)
 
130
 
 
131
    def recv_ready(self):
 
132
        return True
 
133
 
 
134
    def close(self):
 
135
        self.__socket.close()
 
136
 
 
137
 
 
138
class SSHVendor(object):
 
139
    """Abstract base class for SSH vendor implementations."""
 
140
    
 
141
    def connect_sftp(self, username, password, host, port):
 
142
        """Make an SSH connection, and return an SFTPClient.
 
143
        
 
144
        :param username: an ascii string
 
145
        :param password: an ascii string
 
146
        :param host: a host name as an ascii string
 
147
        :param port: a port number
 
148
        :type port: int
 
149
 
 
150
        :raises: ConnectionError if it cannot connect.
 
151
 
 
152
        :rtype: paramiko.sftp_client.SFTPClient
 
153
        """
 
154
        raise NotImplementedError(self.connect_sftp)
 
155
 
 
156
    def connect_ssh(self, username, password, host, port, command):
 
157
        """Make an SSH connection.
 
158
        
 
159
        :returns: something with a `close` method, and a `get_filelike_channels`
 
160
            method that returns a pair of (read, write) filelike objects.
 
161
        """
 
162
        raise NotImplementedError(self.connect_ssh)
 
163
        
 
164
    def _raise_connection_error(self, host, port=None, orig_error=None,
 
165
                                msg='Unable to connect to SSH host'):
 
166
        """Raise a SocketConnectionError with properly formatted host.
 
167
 
 
168
        This just unifies all the locations that try to raise ConnectionError,
 
169
        so that they format things properly.
 
170
        """
 
171
        raise SocketConnectionError(host=host, port=port, msg=msg,
 
172
                                    orig_error=orig_error)
 
173
 
 
174
 
 
175
class LoopbackVendor(SSHVendor):
 
176
    """SSH "vendor" that connects over a plain TCP socket, not SSH."""
 
177
    
 
178
    def connect_sftp(self, username, password, host, port):
 
179
        sock = socket.socket()
 
180
        try:
 
181
            sock.connect((host, port))
 
182
        except socket.error, e:
 
183
            self._raise_connection_error(host, port=port, orig_error=e)
 
184
        return SFTPClient(LoopbackSFTP(sock))
 
185
 
 
186
register_ssh_vendor('loopback', LoopbackVendor())
 
187
 
 
188
 
 
189
class _ParamikoSSHConnection(object):
 
190
    def __init__(self, channel):
 
191
        self.channel = channel
 
192
 
 
193
    def get_filelike_channels(self):
 
194
        return self.channel.makefile('rb'), self.channel.makefile('wb')
 
195
 
 
196
    def close(self):
 
197
        return self.channel.close()
 
198
 
 
199
 
 
200
class ParamikoVendor(SSHVendor):
 
201
    """Vendor that uses paramiko."""
 
202
 
 
203
    def _connect(self, username, password, host, port):
 
204
        global SYSTEM_HOSTKEYS, BZR_HOSTKEYS
 
205
        
 
206
        load_host_keys()
 
207
 
 
208
        try:
 
209
            t = paramiko.Transport((host, port or 22))
 
210
            t.set_log_channel('bzr.paramiko')
 
211
            t.start_client()
 
212
        except (paramiko.SSHException, socket.error), e:
 
213
            self._raise_connection_error(host, port=port, orig_error=e)
 
214
            
 
215
        server_key = t.get_remote_server_key()
 
216
        server_key_hex = paramiko.util.hexify(server_key.get_fingerprint())
 
217
        keytype = server_key.get_name()
 
218
        if host in SYSTEM_HOSTKEYS and keytype in SYSTEM_HOSTKEYS[host]:
 
219
            our_server_key = SYSTEM_HOSTKEYS[host][keytype]
 
220
            our_server_key_hex = paramiko.util.hexify(our_server_key.get_fingerprint())
 
221
        elif host in BZR_HOSTKEYS and keytype in BZR_HOSTKEYS[host]:
 
222
            our_server_key = BZR_HOSTKEYS[host][keytype]
 
223
            our_server_key_hex = paramiko.util.hexify(our_server_key.get_fingerprint())
 
224
        else:
 
225
            warning('Adding %s host key for %s: %s' % (keytype, host, server_key_hex))
 
226
            if host not in BZR_HOSTKEYS:
 
227
                BZR_HOSTKEYS[host] = {}
 
228
            BZR_HOSTKEYS[host][keytype] = server_key
 
229
            our_server_key = server_key
 
230
            our_server_key_hex = paramiko.util.hexify(our_server_key.get_fingerprint())
 
231
            save_host_keys()
 
232
        if server_key != our_server_key:
 
233
            filename1 = os.path.expanduser('~/.ssh/known_hosts')
 
234
            filename2 = pathjoin(config_dir(), 'ssh_host_keys')
 
235
            raise TransportError('Host keys for %s do not match!  %s != %s' % \
 
236
                (host, our_server_key_hex, server_key_hex),
 
237
                ['Try editing %s or %s' % (filename1, filename2)])
 
238
 
 
239
        _paramiko_auth(username, password, host, t)
 
240
        return t
 
241
        
 
242
    def connect_sftp(self, username, password, host, port):
 
243
        t = self._connect(username, password, host, port)
 
244
        try:
 
245
            return t.open_sftp_client()
 
246
        except paramiko.SSHException, e:
 
247
            self._raise_connection_error(host, port=port, orig_error=e,
 
248
                                         msg='Unable to start sftp client')
 
249
 
 
250
    def connect_ssh(self, username, password, host, port, command):
 
251
        t = self._connect(username, password, host, port)
 
252
        try:
 
253
            channel = t.open_session()
 
254
            cmdline = ' '.join(command)
 
255
            channel.exec_command(cmdline)
 
256
            return _ParamikoSSHConnection(channel)
 
257
        except paramiko.SSHException, e:
 
258
            self._raise_connection_error(host, port=port, orig_error=e,
 
259
                                         msg='Unable to invoke remote bzr')
 
260
 
 
261
register_ssh_vendor('paramiko', ParamikoVendor())
 
262
 
 
263
 
 
264
class SubprocessVendor(SSHVendor):
 
265
    """Abstract base class for vendors that use pipes to a subprocess."""
 
266
    
 
267
    def _connect(self, argv):
 
268
        proc = subprocess.Popen(argv,
 
269
                                stdin=subprocess.PIPE,
 
270
                                stdout=subprocess.PIPE,
 
271
                                **os_specific_subprocess_params())
 
272
        return SSHSubprocess(proc)
 
273
 
 
274
    def connect_sftp(self, username, password, host, port):
 
275
        try:
 
276
            argv = self._get_vendor_specific_argv(username, host, port,
 
277
                                                  subsystem='sftp')
 
278
            sock = self._connect(argv)
 
279
            return SFTPClient(sock)
 
280
        except (EOFError, paramiko.SSHException), e:
 
281
            self._raise_connection_error(host, port=port, orig_error=e)
 
282
        except (OSError, IOError), e:
 
283
            # If the machine is fast enough, ssh can actually exit
 
284
            # before we try and send it the sftp request, which
 
285
            # raises a Broken Pipe
 
286
            if e.errno not in (errno.EPIPE,):
 
287
                raise
 
288
            self._raise_connection_error(host, port=port, orig_error=e)
 
289
 
 
290
    def connect_ssh(self, username, password, host, port, command):
 
291
        try:
 
292
            argv = self._get_vendor_specific_argv(username, host, port,
 
293
                                                  command=command)
 
294
            return self._connect(argv)
 
295
        except (EOFError), e:
 
296
            self._raise_connection_error(host, port=port, orig_error=e)
 
297
        except (OSError, IOError), e:
 
298
            # If the machine is fast enough, ssh can actually exit
 
299
            # before we try and send it the sftp request, which
 
300
            # raises a Broken Pipe
 
301
            if e.errno not in (errno.EPIPE,):
 
302
                raise
 
303
            self._raise_connection_error(host, port=port, orig_error=e)
 
304
 
 
305
    def _get_vendor_specific_argv(self, username, host, port, subsystem=None,
 
306
                                  command=None):
 
307
        """Returns the argument list to run the subprocess with.
 
308
        
 
309
        Exactly one of 'subsystem' and 'command' must be specified.
 
310
        """
 
311
        raise NotImplementedError(self._get_vendor_specific_argv)
 
312
 
 
313
register_ssh_vendor('none', ParamikoVendor())
 
314
 
 
315
 
 
316
class OpenSSHSubprocessVendor(SubprocessVendor):
 
317
    """SSH vendor that uses the 'ssh' executable from OpenSSH."""
 
318
    
 
319
    def _get_vendor_specific_argv(self, username, host, port, subsystem=None,
 
320
                                  command=None):
 
321
        assert subsystem is not None or command is not None, (
 
322
            'Must specify a command or subsystem')
 
323
        if subsystem is not None:
 
324
            assert command is None, (
 
325
                'subsystem and command are mutually exclusive')
 
326
        args = ['ssh',
 
327
                '-oForwardX11=no', '-oForwardAgent=no',
 
328
                '-oClearAllForwardings=yes', '-oProtocol=2',
 
329
                '-oNoHostAuthenticationForLocalhost=yes']
 
330
        if port is not None:
 
331
            args.extend(['-p', str(port)])
 
332
        if username is not None:
 
333
            args.extend(['-l', username])
 
334
        if subsystem is not None:
 
335
            args.extend(['-s', host, subsystem])
 
336
        else:
 
337
            args.extend([host] + command)
 
338
        return args
 
339
 
 
340
register_ssh_vendor('openssh', OpenSSHSubprocessVendor())
 
341
 
 
342
 
 
343
class SSHCorpSubprocessVendor(SubprocessVendor):
 
344
    """SSH vendor that uses the 'ssh' executable from SSH Corporation."""
 
345
 
 
346
    def _get_vendor_specific_argv(self, username, host, port, subsystem=None,
 
347
                                  command=None):
 
348
        assert subsystem is not None or command is not None, (
 
349
            'Must specify a command or subsystem')
 
350
        if subsystem is not None:
 
351
            assert command is None, (
 
352
                'subsystem and command are mutually exclusive')
 
353
        args = ['ssh', '-x']
 
354
        if port is not None:
 
355
            args.extend(['-p', str(port)])
 
356
        if username is not None:
 
357
            args.extend(['-l', username])
 
358
        if subsystem is not None:
 
359
            args.extend(['-s', subsystem, host])
 
360
        else:
 
361
            args.extend([host] + command)
 
362
        return args
 
363
    
 
364
register_ssh_vendor('ssh', SSHCorpSubprocessVendor())
 
365
 
 
366
 
 
367
def _paramiko_auth(username, password, host, paramiko_transport):
 
368
    # paramiko requires a username, but it might be none if nothing was supplied
 
369
    # use the local username, just in case.
 
370
    # We don't override username, because if we aren't using paramiko,
 
371
    # the username might be specified in ~/.ssh/config and we don't want to
 
372
    # force it to something else
 
373
    # Also, it would mess up the self.relpath() functionality
 
374
    username = username or getpass.getuser()
 
375
 
 
376
    if _use_ssh_agent:
 
377
        agent = paramiko.Agent()
 
378
        for key in agent.get_keys():
 
379
            mutter('Trying SSH agent key %s' % paramiko.util.hexify(key.get_fingerprint()))
 
380
            try:
 
381
                paramiko_transport.auth_publickey(username, key)
 
382
                return
 
383
            except paramiko.SSHException, e:
 
384
                pass
 
385
    
 
386
    # okay, try finding id_rsa or id_dss?  (posix only)
 
387
    if _try_pkey_auth(paramiko_transport, paramiko.RSAKey, username, 'id_rsa'):
 
388
        return
 
389
    if _try_pkey_auth(paramiko_transport, paramiko.DSSKey, username, 'id_dsa'):
 
390
        return
 
391
 
 
392
    if password:
 
393
        try:
 
394
            paramiko_transport.auth_password(username, password)
 
395
            return
 
396
        except paramiko.SSHException, e:
 
397
            pass
 
398
 
 
399
    # give up and ask for a password
 
400
    password = bzrlib.ui.ui_factory.get_password(
 
401
            prompt='SSH %(user)s@%(host)s password',
 
402
            user=username, host=host)
 
403
    try:
 
404
        paramiko_transport.auth_password(username, password)
 
405
    except paramiko.SSHException, e:
 
406
        raise ConnectionError('Unable to authenticate to SSH host as %s@%s' %
 
407
                              (username, host), e)
 
408
 
 
409
 
 
410
def _try_pkey_auth(paramiko_transport, pkey_class, username, filename):
 
411
    filename = os.path.expanduser('~/.ssh/' + filename)
 
412
    try:
 
413
        key = pkey_class.from_private_key_file(filename)
 
414
        paramiko_transport.auth_publickey(username, key)
 
415
        return True
 
416
    except paramiko.PasswordRequiredException:
 
417
        password = bzrlib.ui.ui_factory.get_password(
 
418
                prompt='SSH %(filename)s password',
 
419
                filename=filename)
 
420
        try:
 
421
            key = pkey_class.from_private_key_file(filename, password)
 
422
            paramiko_transport.auth_publickey(username, key)
 
423
            return True
 
424
        except paramiko.SSHException:
 
425
            mutter('SSH authentication via %s key failed.' % (os.path.basename(filename),))
 
426
    except paramiko.SSHException:
 
427
        mutter('SSH authentication via %s key failed.' % (os.path.basename(filename),))
 
428
    except IOError:
 
429
        pass
 
430
    return False
 
431
 
 
432
 
 
433
def load_host_keys():
 
434
    """
 
435
    Load system host keys (probably doesn't work on windows) and any
 
436
    "discovered" keys from previous sessions.
 
437
    """
 
438
    global SYSTEM_HOSTKEYS, BZR_HOSTKEYS
 
439
    try:
 
440
        SYSTEM_HOSTKEYS = paramiko.util.load_host_keys(os.path.expanduser('~/.ssh/known_hosts'))
 
441
    except Exception, e:
 
442
        mutter('failed to load system host keys: ' + str(e))
 
443
    bzr_hostkey_path = pathjoin(config_dir(), 'ssh_host_keys')
 
444
    try:
 
445
        BZR_HOSTKEYS = paramiko.util.load_host_keys(bzr_hostkey_path)
 
446
    except Exception, e:
 
447
        mutter('failed to load bzr host keys: ' + str(e))
 
448
        save_host_keys()
 
449
 
 
450
 
 
451
def save_host_keys():
 
452
    """
 
453
    Save "discovered" host keys in $(config)/ssh_host_keys/.
 
454
    """
 
455
    global SYSTEM_HOSTKEYS, BZR_HOSTKEYS
 
456
    bzr_hostkey_path = pathjoin(config_dir(), 'ssh_host_keys')
 
457
    ensure_config_dir_exists()
 
458
 
 
459
    try:
 
460
        f = open(bzr_hostkey_path, 'w')
 
461
        f.write('# SSH host keys collected by bzr\n')
 
462
        for hostname, keys in BZR_HOSTKEYS.iteritems():
 
463
            for keytype, key in keys.iteritems():
 
464
                f.write('%s %s %s\n' % (hostname, keytype, key.get_base64()))
 
465
        f.close()
 
466
    except IOError, e:
 
467
        mutter('failed to save bzr host keys: ' + str(e))
 
468
 
 
469
 
 
470
def os_specific_subprocess_params():
 
471
    """Get O/S specific subprocess parameters."""
 
472
    if sys.platform == 'win32':
 
473
        # setting the process group and closing fds is not supported on 
 
474
        # win32
 
475
        return {}
 
476
    else:
 
477
        # We close fds other than the pipes as the child process does not need 
 
478
        # them to be open.
 
479
        #
 
480
        # We also set the child process to ignore SIGINT.  Normally the signal
 
481
        # would be sent to every process in the foreground process group, but
 
482
        # this causes it to be seen only by bzr and not by ssh.  Python will
 
483
        # generate a KeyboardInterrupt in bzr, and we will then have a chance
 
484
        # to release locks or do other cleanup over ssh before the connection
 
485
        # goes away.  
 
486
        # <https://launchpad.net/products/bzr/+bug/5987>
 
487
        #
 
488
        # Running it in a separate process group is not good because then it
 
489
        # can't get non-echoed input of a password or passphrase.
 
490
        # <https://launchpad.net/products/bzr/+bug/40508>
 
491
        return {'preexec_fn': _ignore_sigint,
 
492
                'close_fds': True,
 
493
                }
 
494
 
 
495
 
 
496
class SSHSubprocess(object):
 
497
    """A socket-like object that talks to an ssh subprocess via pipes."""
 
498
 
 
499
    def __init__(self, proc):
 
500
        self.proc = proc
 
501
 
 
502
    def send(self, data):
 
503
        return os.write(self.proc.stdin.fileno(), data)
 
504
 
 
505
    def recv_ready(self):
 
506
        # TODO: jam 20051215 this function is necessary to support the
 
507
        # pipelined() function. In reality, it probably should use
 
508
        # poll() or select() to actually return if there is data
 
509
        # available, otherwise we probably don't get any benefit
 
510
        return True
 
511
 
 
512
    def recv(self, count):
 
513
        return os.read(self.proc.stdout.fileno(), count)
 
514
 
 
515
    def close(self):
 
516
        self.proc.stdin.close()
 
517
        self.proc.stdout.close()
 
518
        self.proc.wait()
 
519
 
 
520
    def get_filelike_channels(self):
 
521
        return (self.proc.stdout, self.proc.stdin)
 
522