~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/transport/ssh.py

  • Committer: Canonical.com Patch Queue Manager
  • Date: 2010-07-07 21:30:06 UTC
  • mfrom: (5333.1.2 better_pyqt_include)
  • Revision ID: pqm@pqm.ubuntu.com-20100707213006-lriphkkbzwwrl7ne
(jameinel) Use a better list of PyQt includes and excludes. (Gary van der
 Merwe)

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2005 Robey Pointer <robey@lag.net>
 
1
# Copyright (C) 2006-2010 Robey Pointer <robey@lag.net>
2
2
# Copyright (C) 2005, 2006, 2007 Canonical Ltd
3
3
#
4
4
# This program is free software; you can redistribute it and/or modify
13
13
#
14
14
# You should have received a copy of the GNU General Public License
15
15
# along with this program; if not, write to the Free Software
16
 
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 
16
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17
17
 
18
18
"""Foundation SSH support for SFTP and smart server."""
19
19
 
20
20
import errno
21
21
import getpass
 
22
import logging
22
23
import os
23
24
import socket
24
25
import subprocess
25
26
import sys
26
27
 
27
 
from bzrlib.config import config_dir, ensure_config_dir_exists
28
 
from bzrlib.errors import (ConnectionError,
29
 
                           ParamikoNotPresent,
30
 
                           SocketConnectionError,
31
 
                           SSHVendorNotFound,
32
 
                           TransportError,
33
 
                           UnknownSSH,
34
 
                           )
35
 
 
36
 
from bzrlib.osutils import pathjoin
37
 
from bzrlib.trace import mutter, warning
38
 
import bzrlib.ui
 
28
from bzrlib import (
 
29
    config,
 
30
    errors,
 
31
    osutils,
 
32
    trace,
 
33
    ui,
 
34
    )
39
35
 
40
36
try:
41
37
    import paramiko
98
94
            try:
99
95
                vendor = self._ssh_vendors[vendor_name]
100
96
            except KeyError:
101
 
                raise UnknownSSH(vendor_name)
 
97
                vendor = self._get_vendor_from_path(vendor_name)
 
98
                if vendor is None:
 
99
                    raise errors.UnknownSSH(vendor_name)
 
100
                vendor.executable_path = vendor_name
102
101
            return vendor
103
102
        return None
104
103
 
114
113
            stdout = stderr = ''
115
114
        return stdout + stderr
116
115
 
117
 
    def _get_vendor_by_version_string(self, version):
 
116
    def _get_vendor_by_version_string(self, version, progname):
118
117
        """Return the vendor or None based on output from the subprocess.
119
118
 
120
119
        :param version: The output of 'ssh -V' like command.
 
120
        :param args: Command line that was run.
121
121
        """
122
122
        vendor = None
123
123
        if 'OpenSSH' in version:
124
 
            mutter('ssh implementation is OpenSSH')
 
124
            trace.mutter('ssh implementation is OpenSSH')
125
125
            vendor = OpenSSHSubprocessVendor()
126
126
        elif 'SSH Secure Shell' in version:
127
 
            mutter('ssh implementation is SSH Corp.')
 
127
            trace.mutter('ssh implementation is SSH Corp.')
128
128
            vendor = SSHCorpSubprocessVendor()
129
 
        elif 'plink' in version:
130
 
            mutter("ssh implementation is Putty's plink.")
 
129
        # As plink user prompts are not handled currently, don't auto-detect
 
130
        # it by inspection below, but keep this vendor detection for if a path
 
131
        # is given in BZR_SSH. See https://bugs.launchpad.net/bugs/414743
 
132
        elif 'plink' in version and progname == 'plink':
 
133
            # Checking if "plink" was the executed argument as Windows
 
134
            # sometimes reports 'ssh -V' incorrectly with 'plink' in it's
 
135
            # version.  See https://bugs.launchpad.net/bzr/+bug/107155
 
136
            trace.mutter("ssh implementation is Putty's plink.")
131
137
            vendor = PLinkSubprocessVendor()
132
138
        return vendor
133
139
 
134
140
    def _get_vendor_by_inspection(self):
135
141
        """Return the vendor or None by checking for known SSH implementations."""
136
 
        for args in [['ssh', '-V'], ['plink', '-V']]:
137
 
            version = self._get_ssh_version_string(args)
138
 
            vendor = self._get_vendor_by_version_string(version)
139
 
            if vendor is not None:
140
 
                return vendor
141
 
        return None
 
142
        version = self._get_ssh_version_string(['ssh', '-V'])
 
143
        return self._get_vendor_by_version_string(version, "ssh")
 
144
 
 
145
    def _get_vendor_from_path(self, path):
 
146
        """Return the vendor or None using the program at the given path"""
 
147
        version = self._get_ssh_version_string([path, '-V'])
 
148
        return self._get_vendor_by_version_string(version, 
 
149
            os.path.splitext(os.path.basename(path))[0])
142
150
 
143
151
    def get_vendor(self, environment=None):
144
152
        """Find out what version of SSH is on the system.
152
160
            if vendor is None:
153
161
                vendor = self._get_vendor_by_inspection()
154
162
                if vendor is None:
155
 
                    mutter('falling back to default implementation')
 
163
                    trace.mutter('falling back to default implementation')
156
164
                    vendor = self._default_ssh_vendor
157
165
                    if vendor is None:
158
 
                        raise SSHVendorNotFound()
 
166
                        raise errors.SSHVendorNotFound()
159
167
            self._cached_ssh_vendor = vendor
160
168
        return self._cached_ssh_vendor
161
169
 
165
173
register_ssh_vendor = _ssh_vendor_manager.register_vendor
166
174
 
167
175
 
168
 
def _ignore_sigint():
 
176
def _ignore_signals():
169
177
    # TODO: This should possibly ignore SIGHUP as well, but bzr currently
170
178
    # doesn't handle it itself.
171
179
    # <https://launchpad.net/products/bzr/+bug/41433/+index>
172
180
    import signal
173
181
    signal.signal(signal.SIGINT, signal.SIG_IGN)
174
 
 
175
 
 
176
 
class LoopbackSFTP(object):
 
182
    # GZ 2010-02-19: Perhaps make this check if breakin is installed instead
 
183
    if signal.getsignal(signal.SIGQUIT) != signal.SIG_DFL:
 
184
        signal.signal(signal.SIGQUIT, signal.SIG_IGN)
 
185
 
 
186
 
 
187
class SocketAsChannelAdapter(object):
177
188
    """Simple wrapper for a socket that pretends to be a paramiko Channel."""
178
189
 
179
190
    def __init__(self, sock):
180
191
        self.__socket = sock
181
 
 
 
192
 
 
193
    def get_name(self):
 
194
        return "bzr SocketAsChannelAdapter"
 
195
 
182
196
    def send(self, data):
183
197
        return self.__socket.send(data)
184
198
 
185
199
    def recv(self, n):
186
 
        return self.__socket.recv(n)
 
200
        try:
 
201
            return self.__socket.recv(n)
 
202
        except socket.error, e:
 
203
            if e.args[0] in (errno.EPIPE, errno.ECONNRESET, errno.ECONNABORTED,
 
204
                             errno.EBADF):
 
205
                # Connection has closed.  Paramiko expects an empty string in
 
206
                # this case, not an exception.
 
207
                return ''
 
208
            raise
187
209
 
188
210
    def recv_ready(self):
 
211
        # TODO: jam 20051215 this function is necessary to support the
 
212
        # pipelined() function. In reality, it probably should use
 
213
        # poll() or select() to actually return if there is data
 
214
        # available, otherwise we probably don't get any benefit
189
215
        return True
190
216
 
191
217
    def close(self):
194
220
 
195
221
class SSHVendor(object):
196
222
    """Abstract base class for SSH vendor implementations."""
197
 
    
 
223
 
198
224
    def connect_sftp(self, username, password, host, port):
199
225
        """Make an SSH connection, and return an SFTPClient.
200
 
        
 
226
 
201
227
        :param username: an ascii string
202
228
        :param password: an ascii string
203
229
        :param host: a host name as an ascii string
212
238
 
213
239
    def connect_ssh(self, username, password, host, port, command):
214
240
        """Make an SSH connection.
215
 
        
216
 
        :returns: something with a `close` method, and a `get_filelike_channels`
217
 
            method that returns a pair of (read, write) filelike objects.
 
241
 
 
242
        :returns: an SSHConnection.
218
243
        """
219
244
        raise NotImplementedError(self.connect_ssh)
220
 
        
 
245
 
221
246
    def _raise_connection_error(self, host, port=None, orig_error=None,
222
247
                                msg='Unable to connect to SSH host'):
223
248
        """Raise a SocketConnectionError with properly formatted host.
225
250
        This just unifies all the locations that try to raise ConnectionError,
226
251
        so that they format things properly.
227
252
        """
228
 
        raise SocketConnectionError(host=host, port=port, msg=msg,
229
 
                                    orig_error=orig_error)
 
253
        raise errors.SocketConnectionError(host=host, port=port, msg=msg,
 
254
                                           orig_error=orig_error)
230
255
 
231
256
 
232
257
class LoopbackVendor(SSHVendor):
233
258
    """SSH "vendor" that connects over a plain TCP socket, not SSH."""
234
 
    
 
259
 
235
260
    def connect_sftp(self, username, password, host, port):
236
261
        sock = socket.socket()
237
262
        try:
238
263
            sock.connect((host, port))
239
264
        except socket.error, e:
240
265
            self._raise_connection_error(host, port=port, orig_error=e)
241
 
        return SFTPClient(LoopbackSFTP(sock))
 
266
        return SFTPClient(SocketAsChannelAdapter(sock))
242
267
 
243
268
register_ssh_vendor('loopback', LoopbackVendor())
244
269
 
245
270
 
246
 
class _ParamikoSSHConnection(object):
247
 
    def __init__(self, channel):
248
 
        self.channel = channel
249
 
 
250
 
    def get_filelike_channels(self):
251
 
        return self.channel.makefile('rb'), self.channel.makefile('wb')
252
 
 
253
 
    def close(self):
254
 
        return self.channel.close()
255
 
 
256
 
 
257
271
class ParamikoVendor(SSHVendor):
258
272
    """Vendor that uses paramiko."""
259
273
 
260
274
    def _connect(self, username, password, host, port):
261
275
        global SYSTEM_HOSTKEYS, BZR_HOSTKEYS
262
 
        
 
276
 
263
277
        load_host_keys()
264
278
 
265
279
        try:
268
282
            t.start_client()
269
283
        except (paramiko.SSHException, socket.error), e:
270
284
            self._raise_connection_error(host, port=port, orig_error=e)
271
 
            
 
285
 
272
286
        server_key = t.get_remote_server_key()
273
287
        server_key_hex = paramiko.util.hexify(server_key.get_fingerprint())
274
288
        keytype = server_key.get_name()
275
289
        if host in SYSTEM_HOSTKEYS and keytype in SYSTEM_HOSTKEYS[host]:
276
290
            our_server_key = SYSTEM_HOSTKEYS[host][keytype]
277
 
            our_server_key_hex = paramiko.util.hexify(our_server_key.get_fingerprint())
 
291
            our_server_key_hex = paramiko.util.hexify(
 
292
                our_server_key.get_fingerprint())
278
293
        elif host in BZR_HOSTKEYS and keytype in BZR_HOSTKEYS[host]:
279
294
            our_server_key = BZR_HOSTKEYS[host][keytype]
280
 
            our_server_key_hex = paramiko.util.hexify(our_server_key.get_fingerprint())
 
295
            our_server_key_hex = paramiko.util.hexify(
 
296
                our_server_key.get_fingerprint())
281
297
        else:
282
 
            warning('Adding %s host key for %s: %s' % (keytype, host, server_key_hex))
 
298
            trace.warning('Adding %s host key for %s: %s'
 
299
                          % (keytype, host, server_key_hex))
283
300
            add = getattr(BZR_HOSTKEYS, 'add', None)
284
301
            if add is not None: # paramiko >= 1.X.X
285
302
                BZR_HOSTKEYS.add(host, keytype, server_key)
286
303
            else:
287
304
                BZR_HOSTKEYS.setdefault(host, {})[keytype] = server_key
288
305
            our_server_key = server_key
289
 
            our_server_key_hex = paramiko.util.hexify(our_server_key.get_fingerprint())
 
306
            our_server_key_hex = paramiko.util.hexify(
 
307
                our_server_key.get_fingerprint())
290
308
            save_host_keys()
291
309
        if server_key != our_server_key:
292
310
            filename1 = os.path.expanduser('~/.ssh/known_hosts')
293
 
            filename2 = pathjoin(config_dir(), 'ssh_host_keys')
294
 
            raise TransportError('Host keys for %s do not match!  %s != %s' % \
 
311
            filename2 = osutils.pathjoin(config.config_dir(), 'ssh_host_keys')
 
312
            raise errors.TransportError(
 
313
                'Host keys for %s do not match!  %s != %s' %
295
314
                (host, our_server_key_hex, server_key_hex),
296
315
                ['Try editing %s or %s' % (filename1, filename2)])
297
316
 
298
 
        _paramiko_auth(username, password, host, t)
 
317
        _paramiko_auth(username, password, host, port, t)
299
318
        return t
300
 
        
 
319
 
301
320
    def connect_sftp(self, username, password, host, port):
302
321
        t = self._connect(username, password, host, port)
303
322
        try:
322
341
    register_ssh_vendor('paramiko', vendor)
323
342
    register_ssh_vendor('none', vendor)
324
343
    register_default_ssh_vendor(vendor)
 
344
    _sftp_connection_errors = (EOFError, paramiko.SSHException)
325
345
    del vendor
 
346
else:
 
347
    _sftp_connection_errors = (EOFError,)
326
348
 
327
349
 
328
350
class SubprocessVendor(SSHVendor):
329
351
    """Abstract base class for vendors that use pipes to a subprocess."""
330
 
    
 
352
 
331
353
    def _connect(self, argv):
332
 
        proc = subprocess.Popen(argv,
333
 
                                stdin=subprocess.PIPE,
334
 
                                stdout=subprocess.PIPE,
 
354
        # Attempt to make a socketpair to use as stdin/stdout for the SSH
 
355
        # subprocess.  We prefer sockets to pipes because they support
 
356
        # non-blocking short reads, allowing us to optimistically read 64k (or
 
357
        # whatever) chunks.
 
358
        try:
 
359
            my_sock, subproc_sock = socket.socketpair()
 
360
        except (AttributeError, socket.error):
 
361
            # This platform doesn't support socketpair(), so just use ordinary
 
362
            # pipes instead.
 
363
            stdin = stdout = subprocess.PIPE
 
364
            sock = None
 
365
        else:
 
366
            stdin = stdout = subproc_sock
 
367
            sock = my_sock
 
368
        proc = subprocess.Popen(argv, stdin=stdin, stdout=stdout,
335
369
                                **os_specific_subprocess_params())
336
 
        return SSHSubprocess(proc)
 
370
        return SSHSubprocessConnection(proc, sock=sock)
337
371
 
338
372
    def connect_sftp(self, username, password, host, port):
339
373
        try:
340
374
            argv = self._get_vendor_specific_argv(username, host, port,
341
375
                                                  subsystem='sftp')
342
376
            sock = self._connect(argv)
343
 
            return SFTPClient(sock)
344
 
        except (EOFError, paramiko.SSHException), e:
 
377
            return SFTPClient(SocketAsChannelAdapter(sock))
 
378
        except _sftp_connection_errors, e:
345
379
            self._raise_connection_error(host, port=port, orig_error=e)
346
380
        except (OSError, IOError), e:
347
381
            # If the machine is fast enough, ssh can actually exit
369
403
    def _get_vendor_specific_argv(self, username, host, port, subsystem=None,
370
404
                                  command=None):
371
405
        """Returns the argument list to run the subprocess with.
372
 
        
 
406
 
373
407
        Exactly one of 'subsystem' and 'command' must be specified.
374
408
        """
375
409
        raise NotImplementedError(self._get_vendor_specific_argv)
377
411
 
378
412
class OpenSSHSubprocessVendor(SubprocessVendor):
379
413
    """SSH vendor that uses the 'ssh' executable from OpenSSH."""
380
 
    
 
414
 
 
415
    executable_path = 'ssh'
 
416
 
381
417
    def _get_vendor_specific_argv(self, username, host, port, subsystem=None,
382
418
                                  command=None):
383
 
        assert subsystem is not None or command is not None, (
384
 
            'Must specify a command or subsystem')
385
 
        if subsystem is not None:
386
 
            assert command is None, (
387
 
                'subsystem and command are mutually exclusive')
388
 
        args = ['ssh',
 
419
        args = [self.executable_path,
389
420
                '-oForwardX11=no', '-oForwardAgent=no',
390
421
                '-oClearAllForwardings=yes', '-oProtocol=2',
391
422
                '-oNoHostAuthenticationForLocalhost=yes']
405
436
class SSHCorpSubprocessVendor(SubprocessVendor):
406
437
    """SSH vendor that uses the 'ssh' executable from SSH Corporation."""
407
438
 
 
439
    executable_path = 'ssh'
 
440
 
408
441
    def _get_vendor_specific_argv(self, username, host, port, subsystem=None,
409
442
                                  command=None):
410
 
        assert subsystem is not None or command is not None, (
411
 
            'Must specify a command or subsystem')
412
 
        if subsystem is not None:
413
 
            assert command is None, (
414
 
                'subsystem and command are mutually exclusive')
415
 
        args = ['ssh', '-x']
 
443
        args = [self.executable_path, '-x']
416
444
        if port is not None:
417
445
            args.extend(['-p', str(port)])
418
446
        if username is not None:
422
450
        else:
423
451
            args.extend([host] + command)
424
452
        return args
425
 
    
426
 
register_ssh_vendor('ssh', SSHCorpSubprocessVendor())
 
453
 
 
454
register_ssh_vendor('sshcorp', SSHCorpSubprocessVendor())
427
455
 
428
456
 
429
457
class PLinkSubprocessVendor(SubprocessVendor):
430
458
    """SSH vendor that uses the 'plink' executable from Putty."""
431
459
 
 
460
    executable_path = 'plink'
 
461
 
432
462
    def _get_vendor_specific_argv(self, username, host, port, subsystem=None,
433
463
                                  command=None):
434
 
        assert subsystem is not None or command is not None, (
435
 
            'Must specify a command or subsystem')
436
 
        if subsystem is not None:
437
 
            assert command is None, (
438
 
                'subsystem and command are mutually exclusive')
439
 
        args = ['plink', '-x', '-a', '-ssh', '-2']
 
464
        args = [self.executable_path, '-x', '-a', '-ssh', '-2', '-batch']
440
465
        if port is not None:
441
466
            args.extend(['-P', str(port)])
442
467
        if username is not None:
450
475
register_ssh_vendor('plink', PLinkSubprocessVendor())
451
476
 
452
477
 
453
 
def _paramiko_auth(username, password, host, paramiko_transport):
454
 
    # paramiko requires a username, but it might be none if nothing was supplied
455
 
    # use the local username, just in case.
456
 
    # We don't override username, because if we aren't using paramiko,
457
 
    # the username might be specified in ~/.ssh/config and we don't want to
458
 
    # force it to something else
459
 
    # Also, it would mess up the self.relpath() functionality
460
 
    username = username or getpass.getuser()
461
 
 
 
478
def _paramiko_auth(username, password, host, port, paramiko_transport):
 
479
    auth = config.AuthenticationConfig()
 
480
    # paramiko requires a username, but it might be none if nothing was
 
481
    # supplied.  If so, use the local username.
 
482
    if username is None:
 
483
        username = auth.get_user('ssh', host, port=port,
 
484
                                 default=getpass.getuser())
462
485
    if _use_ssh_agent:
463
486
        agent = paramiko.Agent()
464
487
        for key in agent.get_keys():
465
 
            mutter('Trying SSH agent key %s' % paramiko.util.hexify(key.get_fingerprint()))
 
488
            trace.mutter('Trying SSH agent key %s'
 
489
                         % paramiko.util.hexify(key.get_fingerprint()))
466
490
            try:
467
491
                paramiko_transport.auth_publickey(username, key)
468
492
                return
469
493
            except paramiko.SSHException, e:
470
494
                pass
471
 
    
 
495
 
472
496
    # okay, try finding id_rsa or id_dss?  (posix only)
473
497
    if _try_pkey_auth(paramiko_transport, paramiko.RSAKey, username, 'id_rsa'):
474
498
        return
475
499
    if _try_pkey_auth(paramiko_transport, paramiko.DSSKey, username, 'id_dsa'):
476
500
        return
477
501
 
 
502
    # If we have gotten this far, we are about to try for passwords, do an
 
503
    # auth_none check to see if it is even supported.
 
504
    supported_auth_types = []
 
505
    try:
 
506
        # Note that with paramiko <1.7.5 this logs an INFO message:
 
507
        #    Authentication type (none) not permitted.
 
508
        # So we explicitly disable the logging level for this action
 
509
        old_level = paramiko_transport.logger.level
 
510
        paramiko_transport.logger.setLevel(logging.WARNING)
 
511
        try:
 
512
            paramiko_transport.auth_none(username)
 
513
        finally:
 
514
            paramiko_transport.logger.setLevel(old_level)
 
515
    except paramiko.BadAuthenticationType, e:
 
516
        # Supported methods are in the exception
 
517
        supported_auth_types = e.allowed_types
 
518
    except paramiko.SSHException, e:
 
519
        # Don't know what happened, but just ignore it
 
520
        pass
 
521
    # We treat 'keyboard-interactive' and 'password' auth methods identically,
 
522
    # because Paramiko's auth_password method will automatically try
 
523
    # 'keyboard-interactive' auth (using the password as the response) if
 
524
    # 'password' auth is not available.  Apparently some Debian and Gentoo
 
525
    # OpenSSH servers require this.
 
526
    # XXX: It's possible for a server to require keyboard-interactive auth that
 
527
    # requires something other than a single password, but we currently don't
 
528
    # support that.
 
529
    if ('password' not in supported_auth_types and
 
530
        'keyboard-interactive' not in supported_auth_types):
 
531
        raise errors.ConnectionError('Unable to authenticate to SSH host as'
 
532
            '\n  %s@%s\nsupported auth types: %s'
 
533
            % (username, host, supported_auth_types))
 
534
 
478
535
    if password:
479
536
        try:
480
537
            paramiko_transport.auth_password(username, password)
483
540
            pass
484
541
 
485
542
    # give up and ask for a password
486
 
    password = bzrlib.ui.ui_factory.get_password(
487
 
            prompt='SSH %(user)s@%(host)s password',
488
 
            user=username, host=host)
489
 
    try:
490
 
        paramiko_transport.auth_password(username, password)
491
 
    except paramiko.SSHException, e:
492
 
        raise ConnectionError('Unable to authenticate to SSH host as %s@%s' %
493
 
                              (username, host), e)
 
543
    password = auth.get_password('ssh', host, username, port=port)
 
544
    # get_password can still return None, which means we should not prompt
 
545
    if password is not None:
 
546
        try:
 
547
            paramiko_transport.auth_password(username, password)
 
548
        except paramiko.SSHException, e:
 
549
            raise errors.ConnectionError(
 
550
                'Unable to authenticate to SSH host as'
 
551
                '\n  %s@%s\n' % (username, host), e)
 
552
    else:
 
553
        raise errors.ConnectionError('Unable to authenticate to SSH host as'
 
554
                                     '  %s@%s' % (username, host))
494
555
 
495
556
 
496
557
def _try_pkey_auth(paramiko_transport, pkey_class, username, filename):
500
561
        paramiko_transport.auth_publickey(username, key)
501
562
        return True
502
563
    except paramiko.PasswordRequiredException:
503
 
        password = bzrlib.ui.ui_factory.get_password(
504
 
                prompt='SSH %(filename)s password',
505
 
                filename=filename)
 
564
        password = ui.ui_factory.get_password(
 
565
            prompt='SSH %(filename)s password', filename=filename)
506
566
        try:
507
567
            key = pkey_class.from_private_key_file(filename, password)
508
568
            paramiko_transport.auth_publickey(username, key)
509
569
            return True
510
570
        except paramiko.SSHException:
511
 
            mutter('SSH authentication via %s key failed.' % (os.path.basename(filename),))
 
571
            trace.mutter('SSH authentication via %s key failed.'
 
572
                         % (os.path.basename(filename),))
512
573
    except paramiko.SSHException:
513
 
        mutter('SSH authentication via %s key failed.' % (os.path.basename(filename),))
 
574
        trace.mutter('SSH authentication via %s key failed.'
 
575
                     % (os.path.basename(filename),))
514
576
    except IOError:
515
577
        pass
516
578
    return False
523
585
    """
524
586
    global SYSTEM_HOSTKEYS, BZR_HOSTKEYS
525
587
    try:
526
 
        SYSTEM_HOSTKEYS = paramiko.util.load_host_keys(os.path.expanduser('~/.ssh/known_hosts'))
527
 
    except Exception, e:
528
 
        mutter('failed to load system host keys: ' + str(e))
529
 
    bzr_hostkey_path = pathjoin(config_dir(), 'ssh_host_keys')
 
588
        SYSTEM_HOSTKEYS = paramiko.util.load_host_keys(
 
589
            os.path.expanduser('~/.ssh/known_hosts'))
 
590
    except IOError, e:
 
591
        trace.mutter('failed to load system host keys: ' + str(e))
 
592
    bzr_hostkey_path = osutils.pathjoin(config.config_dir(), 'ssh_host_keys')
530
593
    try:
531
594
        BZR_HOSTKEYS = paramiko.util.load_host_keys(bzr_hostkey_path)
532
 
    except Exception, e:
533
 
        mutter('failed to load bzr host keys: ' + str(e))
 
595
    except IOError, e:
 
596
        trace.mutter('failed to load bzr host keys: ' + str(e))
534
597
        save_host_keys()
535
598
 
536
599
 
539
602
    Save "discovered" host keys in $(config)/ssh_host_keys/.
540
603
    """
541
604
    global SYSTEM_HOSTKEYS, BZR_HOSTKEYS
542
 
    bzr_hostkey_path = pathjoin(config_dir(), 'ssh_host_keys')
543
 
    ensure_config_dir_exists()
 
605
    bzr_hostkey_path = osutils.pathjoin(config.config_dir(), 'ssh_host_keys')
 
606
    config.ensure_config_dir_exists()
544
607
 
545
608
    try:
546
609
        f = open(bzr_hostkey_path, 'w')
550
613
                f.write('%s %s %s\n' % (hostname, keytype, key.get_base64()))
551
614
        f.close()
552
615
    except IOError, e:
553
 
        mutter('failed to save bzr host keys: ' + str(e))
 
616
        trace.mutter('failed to save bzr host keys: ' + str(e))
554
617
 
555
618
 
556
619
def os_specific_subprocess_params():
557
620
    """Get O/S specific subprocess parameters."""
558
621
    if sys.platform == 'win32':
559
 
        # setting the process group and closing fds is not supported on 
 
622
        # setting the process group and closing fds is not supported on
560
623
        # win32
561
624
        return {}
562
625
    else:
563
 
        # We close fds other than the pipes as the child process does not need 
 
626
        # We close fds other than the pipes as the child process does not need
564
627
        # them to be open.
565
628
        #
566
629
        # We also set the child process to ignore SIGINT.  Normally the signal
568
631
        # this causes it to be seen only by bzr and not by ssh.  Python will
569
632
        # generate a KeyboardInterrupt in bzr, and we will then have a chance
570
633
        # to release locks or do other cleanup over ssh before the connection
571
 
        # goes away.  
 
634
        # goes away.
572
635
        # <https://launchpad.net/products/bzr/+bug/5987>
573
636
        #
574
637
        # Running it in a separate process group is not good because then it
575
638
        # can't get non-echoed input of a password or passphrase.
576
639
        # <https://launchpad.net/products/bzr/+bug/40508>
577
 
        return {'preexec_fn': _ignore_sigint,
 
640
        return {'preexec_fn': _ignore_signals,
578
641
                'close_fds': True,
579
642
                }
580
643
 
581
 
 
582
 
class SSHSubprocess(object):
583
 
    """A socket-like object that talks to an ssh subprocess via pipes."""
584
 
 
585
 
    def __init__(self, proc):
 
644
import weakref
 
645
_subproc_weakrefs = set()
 
646
 
 
647
def _close_ssh_proc(proc):
 
648
    for func in [proc.stdin.close, proc.stdout.close, proc.wait]:
 
649
        try:
 
650
            func()
 
651
        except OSError:
 
652
            pass
 
653
 
 
654
 
 
655
class SSHConnection(object):
 
656
    """Abstract base class for SSH connections."""
 
657
 
 
658
    def get_sock_or_pipes(self):
 
659
        """Returns a (kind, io_object) pair.
 
660
 
 
661
        If kind == 'socket', then io_object is a socket.
 
662
 
 
663
        If kind == 'pipes', then io_object is a pair of file-like objects
 
664
        (read_from, write_to).
 
665
        """
 
666
        raise NotImplementedError(self.get_sock_or_pipes)
 
667
 
 
668
    def close(self):
 
669
        raise NotImplementedError(self.close)
 
670
 
 
671
 
 
672
class SSHSubprocessConnection(SSHConnection):
 
673
    """A connection to an ssh subprocess via pipes or a socket.
 
674
 
 
675
    This class is also socket-like enough to be used with
 
676
    SocketAsChannelAdapter (it has 'send' and 'recv' methods).
 
677
    """
 
678
 
 
679
    def __init__(self, proc, sock=None):
 
680
        """Constructor.
 
681
 
 
682
        :param proc: a subprocess.Popen
 
683
        :param sock: if proc.stdin/out is a socket from a socketpair, then sock
 
684
            should bzrlib's half of that socketpair.  If not passed, proc's
 
685
            stdin/out is assumed to be ordinary pipes.
 
686
        """
586
687
        self.proc = proc
 
688
        self._sock = sock
 
689
        # Add a weakref to proc that will attempt to do the same as self.close
 
690
        # to avoid leaving processes lingering indefinitely.
 
691
        def terminate(ref):
 
692
            _subproc_weakrefs.remove(ref)
 
693
            _close_ssh_proc(proc)
 
694
        _subproc_weakrefs.add(weakref.ref(self, terminate))
587
695
 
588
696
    def send(self, data):
589
 
        return os.write(self.proc.stdin.fileno(), data)
590
 
 
591
 
    def recv_ready(self):
592
 
        # TODO: jam 20051215 this function is necessary to support the
593
 
        # pipelined() function. In reality, it probably should use
594
 
        # poll() or select() to actually return if there is data
595
 
        # available, otherwise we probably don't get any benefit
596
 
        return True
 
697
        if self._sock is not None:
 
698
            return self._sock.send(data)
 
699
        else:
 
700
            return os.write(self.proc.stdin.fileno(), data)
597
701
 
598
702
    def recv(self, count):
599
 
        return os.read(self.proc.stdout.fileno(), count)
600
 
 
601
 
    def close(self):
602
 
        self.proc.stdin.close()
603
 
        self.proc.stdout.close()
604
 
        self.proc.wait()
605
 
 
606
 
    def get_filelike_channels(self):
607
 
        return (self.proc.stdout, self.proc.stdin)
 
703
        if self._sock is not None:
 
704
            return self._sock.recv(count)
 
705
        else:
 
706
            return os.read(self.proc.stdout.fileno(), count)
 
707
 
 
708
    def close(self):
 
709
        _close_ssh_proc(self.proc)
 
710
 
 
711
    def get_sock_or_pipes(self):
 
712
        if self._sock is not None:
 
713
            return 'socket', self._sock
 
714
        else:
 
715
            return 'pipes', (self.proc.stdout, self.proc.stdin)
 
716
 
 
717
 
 
718
class _ParamikoSSHConnection(SSHConnection):
 
719
    """An SSH connection via paramiko."""
 
720
 
 
721
    def __init__(self, channel):
 
722
        self.channel = channel
 
723
 
 
724
    def get_sock_or_pipes(self):
 
725
        return ('socket', self.channel)
 
726
 
 
727
    def close(self):
 
728
        return self.channel.close()
 
729
 
608
730