~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/transport/ssh.py

  • Committer: Martin Pool
  • Date: 2007-04-04 06:17:31 UTC
  • mto: This revision was merged to the branch mainline in revision 2397.
  • Revision ID: mbp@sourcefrog.net-20070404061731-tt2xrzllqhbodn83
Contents of TODO file moved into bug tracker

Show diffs side-by-side

added added

removed removed

Lines of Context:
13
13
#
14
14
# You should have received a copy of the GNU General Public License
15
15
# along with this program; if not, write to the Free Software
16
 
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 
16
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
17
17
 
18
18
"""Foundation SSH support for SFTP and smart server."""
19
19
 
24
24
import subprocess
25
25
import sys
26
26
 
27
 
from bzrlib import (
28
 
    config,
29
 
    errors,
30
 
    osutils,
31
 
    trace,
32
 
    ui,
33
 
    )
 
27
from bzrlib.config import config_dir, ensure_config_dir_exists
 
28
from bzrlib.errors import (ConnectionError,
 
29
                           ParamikoNotPresent,
 
30
                           SocketConnectionError,
 
31
                           SSHVendorNotFound,
 
32
                           TransportError,
 
33
                           UnknownSSH,
 
34
                           )
 
35
 
 
36
from bzrlib.osutils import pathjoin
 
37
from bzrlib.trace import mutter, warning
 
38
import bzrlib.ui
34
39
 
35
40
try:
36
41
    import paramiko
93
98
            try:
94
99
                vendor = self._ssh_vendors[vendor_name]
95
100
            except KeyError:
96
 
                raise errors.UnknownSSH(vendor_name)
 
101
                raise UnknownSSH(vendor_name)
97
102
            return vendor
98
103
        return None
99
104
 
109
114
            stdout = stderr = ''
110
115
        return stdout + stderr
111
116
 
112
 
    def _get_vendor_by_version_string(self, version, args):
 
117
    def _get_vendor_by_version_string(self, version):
113
118
        """Return the vendor or None based on output from the subprocess.
114
119
 
115
120
        :param version: The output of 'ssh -V' like command.
116
 
        :param args: Command line that was run.
117
121
        """
118
122
        vendor = None
119
123
        if 'OpenSSH' in version:
120
 
            trace.mutter('ssh implementation is OpenSSH')
 
124
            mutter('ssh implementation is OpenSSH')
121
125
            vendor = OpenSSHSubprocessVendor()
122
126
        elif 'SSH Secure Shell' in version:
123
 
            trace.mutter('ssh implementation is SSH Corp.')
 
127
            mutter('ssh implementation is SSH Corp.')
124
128
            vendor = SSHCorpSubprocessVendor()
125
 
        elif 'plink' in version and args[0] == 'plink':
126
 
            # Checking if "plink" was the executed argument as Windows
127
 
            # sometimes reports 'ssh -V' incorrectly with 'plink' in it's
128
 
            # version.  See https://bugs.launchpad.net/bzr/+bug/107155
129
 
            trace.mutter("ssh implementation is Putty's plink.")
 
129
        elif 'plink' in version:
 
130
            mutter("ssh implementation is Putty's plink.")
130
131
            vendor = PLinkSubprocessVendor()
131
132
        return vendor
132
133
 
133
134
    def _get_vendor_by_inspection(self):
134
135
        """Return the vendor or None by checking for known SSH implementations."""
135
 
        for args in (['ssh', '-V'], ['plink', '-V']):
 
136
        for args in [['ssh', '-V'], ['plink', '-V']]:
136
137
            version = self._get_ssh_version_string(args)
137
 
            vendor = self._get_vendor_by_version_string(version, args)
 
138
            vendor = self._get_vendor_by_version_string(version)
138
139
            if vendor is not None:
139
140
                return vendor
140
141
        return None
151
152
            if vendor is None:
152
153
                vendor = self._get_vendor_by_inspection()
153
154
                if vendor is None:
154
 
                    trace.mutter('falling back to default implementation')
 
155
                    mutter('falling back to default implementation')
155
156
                    vendor = self._default_ssh_vendor
156
157
                    if vendor is None:
157
 
                        raise errors.SSHVendorNotFound()
 
158
                        raise SSHVendorNotFound()
158
159
            self._cached_ssh_vendor = vendor
159
160
        return self._cached_ssh_vendor
160
161
 
172
173
    signal.signal(signal.SIGINT, signal.SIG_IGN)
173
174
 
174
175
 
175
 
class SocketAsChannelAdapter(object):
 
176
class LoopbackSFTP(object):
176
177
    """Simple wrapper for a socket that pretends to be a paramiko Channel."""
177
178
 
178
179
    def __init__(self, sock):
179
180
        self.__socket = sock
180
 
 
181
 
    def get_name(self):
182
 
        return "bzr SocketAsChannelAdapter"
183
 
 
 
181
 
184
182
    def send(self, data):
185
183
        return self.__socket.send(data)
186
184
 
187
185
    def recv(self, n):
188
 
        try:
189
 
            return self.__socket.recv(n)
190
 
        except socket.error, e:
191
 
            if e.args[0] in (errno.EPIPE, errno.ECONNRESET, errno.ECONNABORTED,
192
 
                             errno.EBADF):
193
 
                # Connection has closed.  Paramiko expects an empty string in
194
 
                # this case, not an exception.
195
 
                return ''
196
 
            raise
 
186
        return self.__socket.recv(n)
197
187
 
198
188
    def recv_ready(self):
199
 
        # TODO: jam 20051215 this function is necessary to support the
200
 
        # pipelined() function. In reality, it probably should use
201
 
        # poll() or select() to actually return if there is data
202
 
        # available, otherwise we probably don't get any benefit
203
189
        return True
204
190
 
205
191
    def close(self):
208
194
 
209
195
class SSHVendor(object):
210
196
    """Abstract base class for SSH vendor implementations."""
211
 
 
 
197
    
212
198
    def connect_sftp(self, username, password, host, port):
213
199
        """Make an SSH connection, and return an SFTPClient.
214
 
 
 
200
        
215
201
        :param username: an ascii string
216
202
        :param password: an ascii string
217
203
        :param host: a host name as an ascii string
226
212
 
227
213
    def connect_ssh(self, username, password, host, port, command):
228
214
        """Make an SSH connection.
229
 
 
 
215
        
230
216
        :returns: something with a `close` method, and a `get_filelike_channels`
231
217
            method that returns a pair of (read, write) filelike objects.
232
218
        """
233
219
        raise NotImplementedError(self.connect_ssh)
234
 
 
 
220
        
235
221
    def _raise_connection_error(self, host, port=None, orig_error=None,
236
222
                                msg='Unable to connect to SSH host'):
237
223
        """Raise a SocketConnectionError with properly formatted host.
239
225
        This just unifies all the locations that try to raise ConnectionError,
240
226
        so that they format things properly.
241
227
        """
242
 
        raise errors.SocketConnectionError(host=host, port=port, msg=msg,
243
 
                                           orig_error=orig_error)
 
228
        raise SocketConnectionError(host=host, port=port, msg=msg,
 
229
                                    orig_error=orig_error)
244
230
 
245
231
 
246
232
class LoopbackVendor(SSHVendor):
247
233
    """SSH "vendor" that connects over a plain TCP socket, not SSH."""
248
 
 
 
234
    
249
235
    def connect_sftp(self, username, password, host, port):
250
236
        sock = socket.socket()
251
237
        try:
252
238
            sock.connect((host, port))
253
239
        except socket.error, e:
254
240
            self._raise_connection_error(host, port=port, orig_error=e)
255
 
        return SFTPClient(SocketAsChannelAdapter(sock))
 
241
        return SFTPClient(LoopbackSFTP(sock))
256
242
 
257
243
register_ssh_vendor('loopback', LoopbackVendor())
258
244
 
273
259
 
274
260
    def _connect(self, username, password, host, port):
275
261
        global SYSTEM_HOSTKEYS, BZR_HOSTKEYS
276
 
 
 
262
        
277
263
        load_host_keys()
278
264
 
279
265
        try:
282
268
            t.start_client()
283
269
        except (paramiko.SSHException, socket.error), e:
284
270
            self._raise_connection_error(host, port=port, orig_error=e)
285
 
 
 
271
            
286
272
        server_key = t.get_remote_server_key()
287
273
        server_key_hex = paramiko.util.hexify(server_key.get_fingerprint())
288
274
        keytype = server_key.get_name()
289
275
        if host in SYSTEM_HOSTKEYS and keytype in SYSTEM_HOSTKEYS[host]:
290
276
            our_server_key = SYSTEM_HOSTKEYS[host][keytype]
291
 
            our_server_key_hex = paramiko.util.hexify(
292
 
                our_server_key.get_fingerprint())
 
277
            our_server_key_hex = paramiko.util.hexify(our_server_key.get_fingerprint())
293
278
        elif host in BZR_HOSTKEYS and keytype in BZR_HOSTKEYS[host]:
294
279
            our_server_key = BZR_HOSTKEYS[host][keytype]
295
 
            our_server_key_hex = paramiko.util.hexify(
296
 
                our_server_key.get_fingerprint())
 
280
            our_server_key_hex = paramiko.util.hexify(our_server_key.get_fingerprint())
297
281
        else:
298
 
            trace.warning('Adding %s host key for %s: %s'
299
 
                          % (keytype, host, server_key_hex))
 
282
            warning('Adding %s host key for %s: %s' % (keytype, host, server_key_hex))
300
283
            add = getattr(BZR_HOSTKEYS, 'add', None)
301
284
            if add is not None: # paramiko >= 1.X.X
302
285
                BZR_HOSTKEYS.add(host, keytype, server_key)
303
286
            else:
304
287
                BZR_HOSTKEYS.setdefault(host, {})[keytype] = server_key
305
288
            our_server_key = server_key
306
 
            our_server_key_hex = paramiko.util.hexify(
307
 
                our_server_key.get_fingerprint())
 
289
            our_server_key_hex = paramiko.util.hexify(our_server_key.get_fingerprint())
308
290
            save_host_keys()
309
291
        if server_key != our_server_key:
310
292
            filename1 = os.path.expanduser('~/.ssh/known_hosts')
311
 
            filename2 = osutils.pathjoin(config.config_dir(), 'ssh_host_keys')
312
 
            raise errors.TransportError(
313
 
                'Host keys for %s do not match!  %s != %s' %
 
293
            filename2 = pathjoin(config_dir(), 'ssh_host_keys')
 
294
            raise TransportError('Host keys for %s do not match!  %s != %s' % \
314
295
                (host, our_server_key_hex, server_key_hex),
315
296
                ['Try editing %s or %s' % (filename1, filename2)])
316
297
 
317
 
        _paramiko_auth(username, password, host, port, t)
 
298
        _paramiko_auth(username, password, host, t)
318
299
        return t
319
 
 
 
300
        
320
301
    def connect_sftp(self, username, password, host, port):
321
302
        t = self._connect(username, password, host, port)
322
303
        try:
341
322
    register_ssh_vendor('paramiko', vendor)
342
323
    register_ssh_vendor('none', vendor)
343
324
    register_default_ssh_vendor(vendor)
344
 
    _sftp_connection_errors = (EOFError, paramiko.SSHException)
345
325
    del vendor
346
 
else:
347
 
    _sftp_connection_errors = (EOFError,)
348
326
 
349
327
 
350
328
class SubprocessVendor(SSHVendor):
351
329
    """Abstract base class for vendors that use pipes to a subprocess."""
352
 
 
 
330
    
353
331
    def _connect(self, argv):
354
332
        proc = subprocess.Popen(argv,
355
333
                                stdin=subprocess.PIPE,
362
340
            argv = self._get_vendor_specific_argv(username, host, port,
363
341
                                                  subsystem='sftp')
364
342
            sock = self._connect(argv)
365
 
            return SFTPClient(SocketAsChannelAdapter(sock))
366
 
        except _sftp_connection_errors, e:
 
343
            return SFTPClient(sock)
 
344
        except (EOFError, paramiko.SSHException), e:
367
345
            self._raise_connection_error(host, port=port, orig_error=e)
368
346
        except (OSError, IOError), e:
369
347
            # If the machine is fast enough, ssh can actually exit
391
369
    def _get_vendor_specific_argv(self, username, host, port, subsystem=None,
392
370
                                  command=None):
393
371
        """Returns the argument list to run the subprocess with.
394
 
 
 
372
        
395
373
        Exactly one of 'subsystem' and 'command' must be specified.
396
374
        """
397
375
        raise NotImplementedError(self._get_vendor_specific_argv)
399
377
 
400
378
class OpenSSHSubprocessVendor(SubprocessVendor):
401
379
    """SSH vendor that uses the 'ssh' executable from OpenSSH."""
402
 
 
 
380
    
403
381
    def _get_vendor_specific_argv(self, username, host, port, subsystem=None,
404
382
                                  command=None):
 
383
        assert subsystem is not None or command is not None, (
 
384
            'Must specify a command or subsystem')
 
385
        if subsystem is not None:
 
386
            assert command is None, (
 
387
                'subsystem and command are mutually exclusive')
405
388
        args = ['ssh',
406
389
                '-oForwardX11=no', '-oForwardAgent=no',
407
390
                '-oClearAllForwardings=yes', '-oProtocol=2',
424
407
 
425
408
    def _get_vendor_specific_argv(self, username, host, port, subsystem=None,
426
409
                                  command=None):
 
410
        assert subsystem is not None or command is not None, (
 
411
            'Must specify a command or subsystem')
 
412
        if subsystem is not None:
 
413
            assert command is None, (
 
414
                'subsystem and command are mutually exclusive')
427
415
        args = ['ssh', '-x']
428
416
        if port is not None:
429
417
            args.extend(['-p', str(port)])
434
422
        else:
435
423
            args.extend([host] + command)
436
424
        return args
437
 
 
 
425
    
438
426
register_ssh_vendor('ssh', SSHCorpSubprocessVendor())
439
427
 
440
428
 
443
431
 
444
432
    def _get_vendor_specific_argv(self, username, host, port, subsystem=None,
445
433
                                  command=None):
446
 
        args = ['plink', '-x', '-a', '-ssh', '-2', '-batch']
 
434
        assert subsystem is not None or command is not None, (
 
435
            'Must specify a command or subsystem')
 
436
        if subsystem is not None:
 
437
            assert command is None, (
 
438
                'subsystem and command are mutually exclusive')
 
439
        args = ['plink', '-x', '-a', '-ssh', '-2']
447
440
        if port is not None:
448
441
            args.extend(['-P', str(port)])
449
442
        if username is not None:
457
450
register_ssh_vendor('plink', PLinkSubprocessVendor())
458
451
 
459
452
 
460
 
def _paramiko_auth(username, password, host, port, paramiko_transport):
461
 
    auth = config.AuthenticationConfig()
462
 
    # paramiko requires a username, but it might be none if nothing was
463
 
    # supplied.  If so, use the local username.
464
 
    if username is None:
465
 
        username = auth.get_user('ssh', host, port=port,
466
 
                                 default=getpass.getuser())
 
453
def _paramiko_auth(username, password, host, paramiko_transport):
 
454
    # paramiko requires a username, but it might be none if nothing was supplied
 
455
    # use the local username, just in case.
 
456
    # We don't override username, because if we aren't using paramiko,
 
457
    # the username might be specified in ~/.ssh/config and we don't want to
 
458
    # force it to something else
 
459
    # Also, it would mess up the self.relpath() functionality
 
460
    username = username or getpass.getuser()
 
461
 
467
462
    if _use_ssh_agent:
468
463
        agent = paramiko.Agent()
469
464
        for key in agent.get_keys():
470
 
            trace.mutter('Trying SSH agent key %s'
471
 
                         % paramiko.util.hexify(key.get_fingerprint()))
 
465
            mutter('Trying SSH agent key %s' % paramiko.util.hexify(key.get_fingerprint()))
472
466
            try:
473
467
                paramiko_transport.auth_publickey(username, key)
474
468
                return
475
469
            except paramiko.SSHException, e:
476
470
                pass
477
 
 
 
471
    
478
472
    # okay, try finding id_rsa or id_dss?  (posix only)
479
473
    if _try_pkey_auth(paramiko_transport, paramiko.RSAKey, username, 'id_rsa'):
480
474
        return
489
483
            pass
490
484
 
491
485
    # give up and ask for a password
492
 
    password = auth.get_password('ssh', host, username, port=port)
 
486
    password = bzrlib.ui.ui_factory.get_password(
 
487
            prompt='SSH %(user)s@%(host)s password',
 
488
            user=username, host=host)
493
489
    try:
494
490
        paramiko_transport.auth_password(username, password)
495
491
    except paramiko.SSHException, e:
496
 
        raise errors.ConnectionError(
497
 
            'Unable to authenticate to SSH host as %s@%s' % (username, host), e)
 
492
        raise ConnectionError('Unable to authenticate to SSH host as %s@%s' %
 
493
                              (username, host), e)
498
494
 
499
495
 
500
496
def _try_pkey_auth(paramiko_transport, pkey_class, username, filename):
504
500
        paramiko_transport.auth_publickey(username, key)
505
501
        return True
506
502
    except paramiko.PasswordRequiredException:
507
 
        password = ui.ui_factory.get_password(
508
 
            prompt='SSH %(filename)s password', filename=filename)
 
503
        password = bzrlib.ui.ui_factory.get_password(
 
504
                prompt='SSH %(filename)s password',
 
505
                filename=filename)
509
506
        try:
510
507
            key = pkey_class.from_private_key_file(filename, password)
511
508
            paramiko_transport.auth_publickey(username, key)
512
509
            return True
513
510
        except paramiko.SSHException:
514
 
            trace.mutter('SSH authentication via %s key failed.'
515
 
                         % (os.path.basename(filename),))
 
511
            mutter('SSH authentication via %s key failed.' % (os.path.basename(filename),))
516
512
    except paramiko.SSHException:
517
 
        trace.mutter('SSH authentication via %s key failed.'
518
 
                     % (os.path.basename(filename),))
 
513
        mutter('SSH authentication via %s key failed.' % (os.path.basename(filename),))
519
514
    except IOError:
520
515
        pass
521
516
    return False
528
523
    """
529
524
    global SYSTEM_HOSTKEYS, BZR_HOSTKEYS
530
525
    try:
531
 
        SYSTEM_HOSTKEYS = paramiko.util.load_host_keys(
532
 
            os.path.expanduser('~/.ssh/known_hosts'))
533
 
    except IOError, e:
534
 
        trace.mutter('failed to load system host keys: ' + str(e))
535
 
    bzr_hostkey_path = osutils.pathjoin(config.config_dir(), 'ssh_host_keys')
 
526
        SYSTEM_HOSTKEYS = paramiko.util.load_host_keys(os.path.expanduser('~/.ssh/known_hosts'))
 
527
    except Exception, e:
 
528
        mutter('failed to load system host keys: ' + str(e))
 
529
    bzr_hostkey_path = pathjoin(config_dir(), 'ssh_host_keys')
536
530
    try:
537
531
        BZR_HOSTKEYS = paramiko.util.load_host_keys(bzr_hostkey_path)
538
 
    except IOError, e:
539
 
        trace.mutter('failed to load bzr host keys: ' + str(e))
 
532
    except Exception, e:
 
533
        mutter('failed to load bzr host keys: ' + str(e))
540
534
        save_host_keys()
541
535
 
542
536
 
545
539
    Save "discovered" host keys in $(config)/ssh_host_keys/.
546
540
    """
547
541
    global SYSTEM_HOSTKEYS, BZR_HOSTKEYS
548
 
    bzr_hostkey_path = osutils.pathjoin(config.config_dir(), 'ssh_host_keys')
549
 
    config.ensure_config_dir_exists()
 
542
    bzr_hostkey_path = pathjoin(config_dir(), 'ssh_host_keys')
 
543
    ensure_config_dir_exists()
550
544
 
551
545
    try:
552
546
        f = open(bzr_hostkey_path, 'w')
556
550
                f.write('%s %s %s\n' % (hostname, keytype, key.get_base64()))
557
551
        f.close()
558
552
    except IOError, e:
559
 
        trace.mutter('failed to save bzr host keys: ' + str(e))
 
553
        mutter('failed to save bzr host keys: ' + str(e))
560
554
 
561
555
 
562
556
def os_specific_subprocess_params():
563
557
    """Get O/S specific subprocess parameters."""
564
558
    if sys.platform == 'win32':
565
 
        # setting the process group and closing fds is not supported on
 
559
        # setting the process group and closing fds is not supported on 
566
560
        # win32
567
561
        return {}
568
562
    else:
569
 
        # We close fds other than the pipes as the child process does not need
 
563
        # We close fds other than the pipes as the child process does not need 
570
564
        # them to be open.
571
565
        #
572
566
        # We also set the child process to ignore SIGINT.  Normally the signal
574
568
        # this causes it to be seen only by bzr and not by ssh.  Python will
575
569
        # generate a KeyboardInterrupt in bzr, and we will then have a chance
576
570
        # to release locks or do other cleanup over ssh before the connection
577
 
        # goes away.
 
571
        # goes away.  
578
572
        # <https://launchpad.net/products/bzr/+bug/5987>
579
573
        #
580
574
        # Running it in a separate process group is not good because then it
594
588
    def send(self, data):
595
589
        return os.write(self.proc.stdin.fileno(), data)
596
590
 
 
591
    def recv_ready(self):
 
592
        # TODO: jam 20051215 this function is necessary to support the
 
593
        # pipelined() function. In reality, it probably should use
 
594
        # poll() or select() to actually return if there is data
 
595
        # available, otherwise we probably don't get any benefit
 
596
        return True
 
597
 
597
598
    def recv(self, count):
598
599
        return os.read(self.proc.stdout.fileno(), count)
599
600