~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/transport/sftp.py

  • Committer: Canonical.com Patch Queue Manager
  • Date: 2008-09-11 06:10:59 UTC
  • mfrom: (3702.1.1 trivial)
  • Revision ID: pqm@pqm.ubuntu.com-20080911061059-svzqfejar17ui4zw
(mbp) KnitVersionedFiles repr

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2005-2010 Canonical Ltd
 
1
# Copyright (C) 2005 Robey Pointer <robey@lag.net>
 
2
# Copyright (C) 2005, 2006, 2007 Canonical Ltd
2
3
#
3
4
# This program is free software; you can redistribute it and/or modify
4
5
# it under the terms of the GNU General Public License as published by
12
13
#
13
14
# You should have received a copy of the GNU General Public License
14
15
# along with this program; if not, write to the Free Software
15
 
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 
16
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
16
17
 
17
18
"""Implementation of Transport over SFTP, using paramiko."""
18
19
 
28
29
import itertools
29
30
import os
30
31
import random
 
32
import select
 
33
import socket
31
34
import stat
32
35
import sys
33
36
import time
 
37
import urllib
 
38
import urlparse
34
39
import warnings
35
40
 
36
41
from bzrlib import (
37
 
    config,
38
 
    debug,
39
42
    errors,
40
43
    urlutils,
41
44
    )
42
45
from bzrlib.errors import (FileExists,
43
 
                           NoSuchFile,
 
46
                           NoSuchFile, PathNotChild,
44
47
                           TransportError,
45
48
                           LockError,
46
49
                           PathError,
47
50
                           ParamikoNotPresent,
48
51
                           )
49
 
from bzrlib.osutils import fancy_rename
 
52
from bzrlib.osutils import pathjoin, fancy_rename, getcwd
 
53
from bzrlib.symbol_versioning import (
 
54
        deprecated_function,
 
55
        )
50
56
from bzrlib.trace import mutter, warning
51
57
from bzrlib.transport import (
52
58
    FileFileStream,
53
59
    _file_streams,
 
60
    local,
 
61
    Server,
54
62
    ssh,
55
63
    ConnectedTransport,
56
64
    )
75
83
else:
76
84
    from paramiko.sftp import (SFTP_FLAG_WRITE, SFTP_FLAG_CREATE,
77
85
                               SFTP_FLAG_EXCL, SFTP_FLAG_TRUNC,
78
 
                               SFTP_OK, CMD_HANDLE, CMD_OPEN)
 
86
                               CMD_HANDLE, CMD_OPEN)
79
87
    from paramiko.sftp_attr import SFTPAttributes
80
88
    from paramiko.sftp_file import SFTPFile
81
89
 
87
95
 
88
96
class SFTPLock(object):
89
97
    """This fakes a lock in a remote location.
90
 
 
 
98
    
91
99
    A present lock is indicated just by the existence of a file.  This
92
 
    doesn't work well on all transports and they are only used in
 
100
    doesn't work well on all transports and they are only used in 
93
101
    deprecated storage formats.
94
102
    """
95
 
 
 
103
    
96
104
    __slots__ = ['path', 'lock_path', 'lock_file', 'transport']
97
105
 
98
106
    def __init__(self, path, transport):
107
115
        except FileExists:
108
116
            raise LockError('File %r already locked' % (self.path,))
109
117
 
 
118
    def __del__(self):
 
119
        """Should this warn, or actually try to cleanup?"""
 
120
        if self.lock_file:
 
121
            warning("SFTPLock %r not explicitly unlocked" % (self.path,))
 
122
            self.unlock()
 
123
 
110
124
    def unlock(self):
111
125
        if not self.lock_file:
112
126
            return
125
139
    # See _get_requests for an explanation.
126
140
    _max_request_size = 32768
127
141
 
128
 
    def __init__(self, original_offsets, relpath, _report_activity):
 
142
    def __init__(self, original_offsets, relpath):
129
143
        """Create a new readv helper.
130
144
 
131
145
        :param original_offsets: The original requests given by the caller of
132
146
            readv()
133
147
        :param relpath: The name of the file (if known)
134
 
        :param _report_activity: A Transport._report_activity bound method,
135
 
            to be called as data arrives.
136
148
        """
137
149
        self.original_offsets = list(original_offsets)
138
150
        self.relpath = relpath
139
 
        self._report_activity = _report_activity
140
151
 
141
152
    def _get_requests(self):
142
153
        """Break up the offsets into individual requests over sftp.
168
179
                requests.append((start, next_size))
169
180
                size -= next_size
170
181
                start += next_size
171
 
        if 'sftp' in debug.debug_flags:
172
 
            mutter('SFTP.readv(%s) %s offsets => %s coalesced => %s requests',
173
 
                self.relpath, len(sorted_offsets), len(coalesced),
174
 
                len(requests))
 
182
        mutter('SFTP.readv(%s) %s offsets => %s coalesced => %s requests',
 
183
               self.relpath, len(sorted_offsets), len(coalesced),
 
184
               len(requests))
175
185
        return requests
176
186
 
177
187
    def request_and_yield_offsets(self, fp):
208
218
            if len(data) != length:
209
219
                raise errors.ShortReadvError(self.relpath,
210
220
                    start, length, len(data))
211
 
            self._report_activity(length, 'read')
212
221
            if last_end is None:
213
222
                # This is the first request, just buffer it
214
223
                buffered_data = [data]
268
277
                    buffered = buffered[buffered_offset:]
269
278
                    buffered_data = [buffered]
270
279
                    buffered_len = len(buffered)
271
 
        # now that the data stream is done, close the handle
272
 
        fp.close()
273
280
        if buffered_len:
274
281
            buffered = ''.join(buffered_data)
275
282
            del buffered_data[:]
276
283
            data_chunks.append((input_start, buffered))
277
284
        if data_chunks:
278
 
            if 'sftp' in debug.debug_flags:
279
 
                mutter('SFTP readv left with %d out-of-order bytes',
280
 
                    sum(map(lambda x: len(x[1]), data_chunks)))
 
285
            mutter('SFTP readv left with %d out-of-order bytes',
 
286
                   sum(map(lambda x: len(x[1]), data_chunks)))
281
287
            # We've processed all the readv data, at this point, anything we
282
288
            # couldn't process is in data_chunks. This doesn't happen often, so
283
289
            # this code path isn't optimized
288
294
            # get the previous node
289
295
            while True:
290
296
                idx = bisect.bisect_left(data_chunks, (cur_offset,))
291
 
                if idx < len(data_chunks) and data_chunks[idx][0] == cur_offset:
292
 
                    # The data starts here
 
297
                if data_chunks[idx][0] == cur_offset: # The data starts here
293
298
                    data = data_chunks[idx][1][:cur_size]
294
299
                elif idx > 0:
295
300
                    # The data is in a portion of a previous page
330
335
    # up the request itself, rather than us having to worry about it
331
336
    _max_request_size = 32768
332
337
 
 
338
    def __init__(self, base, _from_transport=None):
 
339
        super(SFTPTransport, self).__init__(base,
 
340
                                            _from_transport=_from_transport)
 
341
 
333
342
    def _remote_path(self, relpath):
334
343
        """Return the path to be passed along the sftp protocol for relpath.
335
 
 
 
344
        
336
345
        :param relpath: is a urlencoded string.
337
346
        """
338
 
        remote_path = self._parsed_url.clone(relpath).path
 
347
        relative = urlutils.unescape(relpath).encode('utf-8')
 
348
        remote_path = self._combine_paths(self._path, relative)
339
349
        # the initial slash should be removed from the path, and treated as a
340
350
        # homedir relative path (the path begins with a double slash if it is
341
351
        # absolute).  see draft-ietf-secsh-scp-sftp-ssh-uri-03.txt
360
370
        in base url at transport creation time.
361
371
        """
362
372
        if credentials is None:
363
 
            password = self._parsed_url.password
 
373
            password = self._password
364
374
        else:
365
375
            password = credentials
366
376
 
367
377
        vendor = ssh._get_ssh_vendor()
368
 
        user = self._parsed_url.user
369
 
        if user is None:
370
 
            auth = config.AuthenticationConfig()
371
 
            user = auth.get_user('ssh', self._parsed_url.host,
372
 
                self._parsed_url.port)
373
 
        connection = vendor.connect_sftp(self._parsed_url.user, password,
374
 
            self._parsed_url.host, self._parsed_url.port)
375
 
        return connection, (user, password)
376
 
 
377
 
    def disconnect(self):
378
 
        connection = self._get_connection()
379
 
        if connection is not None:
380
 
            connection.close()
 
378
        connection = vendor.connect_sftp(self._user, password,
 
379
                                         self._host, self._port)
 
380
        return connection, password
381
381
 
382
382
    def _get_sftp(self):
383
383
        """Ensures that a connection is established"""
394
394
        """
395
395
        try:
396
396
            self._get_sftp().stat(self._remote_path(relpath))
397
 
            # stat result is about 20 bytes, let's say
398
 
            self._report_activity(20, 'read')
399
397
            return True
400
398
        except IOError:
401
399
            return False
402
400
 
403
401
    def get(self, relpath):
404
 
        """Get the file at the given relative path.
 
402
        """
 
403
        Get the file at the given relative path.
405
404
 
406
405
        :param relpath: The relative path to the file
407
406
        """
415
414
            self._translate_io_exception(e, path, ': error retrieving',
416
415
                failure_exc=errors.ReadError)
417
416
 
418
 
    def get_bytes(self, relpath):
419
 
        # reimplement this here so that we can report how many bytes came back
420
 
        f = self.get(relpath)
421
 
        try:
422
 
            bytes = f.read()
423
 
            self._report_activity(len(bytes), 'read')
424
 
            return bytes
425
 
        finally:
426
 
            f.close()
427
 
 
428
417
    def _readv(self, relpath, offsets):
429
418
        """See Transport.readv()"""
430
419
        # We overload the default readv() because we want to use a file
439
428
            readv = getattr(fp, 'readv', None)
440
429
            if readv:
441
430
                return self._sftp_readv(fp, offsets, relpath)
442
 
            if 'sftp' in debug.debug_flags:
443
 
                mutter('seek and read %s offsets', len(offsets))
 
431
            mutter('seek and read %s offsets', len(offsets))
444
432
            return self._seek_and_read(fp, offsets, relpath)
445
433
        except (IOError, paramiko.SSHException), e:
446
434
            self._translate_io_exception(e, path, ': error retrieving')
453
441
        """
454
442
        return 64 * 1024
455
443
 
456
 
    def _sftp_readv(self, fp, offsets, relpath):
 
444
    def _sftp_readv(self, fp, offsets, relpath='<unknown>'):
457
445
        """Use the readv() member of fp to do async readv.
458
446
 
459
 
        Then read them using paramiko.readv(). paramiko.readv()
 
447
        And then read them using paramiko.readv(). paramiko.readv()
460
448
        does not support ranges > 64K, so it caps the request size, and
461
 
        just reads until it gets all the stuff it wants.
 
449
        just reads until it gets all the stuff it wants
462
450
        """
463
 
        helper = _SFTPReadvHelper(offsets, relpath, self._report_activity)
 
451
        helper = _SFTPReadvHelper(offsets, relpath)
464
452
        return helper.request_and_yield_offsets(fp)
465
453
 
466
454
    def put_file(self, relpath, f, mode=None):
492
480
            #      sticky bit. So it is probably best to stop chmodding, and
493
481
            #      just tell users that they need to set the umask correctly.
494
482
            #      The attr.st_mode = mode, in _sftp_open_exclusive
495
 
            #      will handle when the user wants the final mode to be more
496
 
            #      restrictive. And then we avoid a round trip. Unless
 
483
            #      will handle when the user wants the final mode to be more 
 
484
            #      restrictive. And then we avoid a round trip. Unless 
497
485
            #      paramiko decides to expose an async chmod()
498
486
 
499
487
            # This is designed to chmod() right before we close.
500
 
            # Because we set_pipelined() earlier, theoretically we might
 
488
            # Because we set_pipelined() earlier, theoretically we might 
501
489
            # avoid the round trip for fout.close()
502
490
            if mode is not None:
503
491
                self._get_sftp().chmod(tmp_abspath, mode)
545
533
                                                 ': unable to open')
546
534
 
547
535
                # This is designed to chmod() right before we close.
548
 
                # Because we set_pipelined() earlier, theoretically we might
 
536
                # Because we set_pipelined() earlier, theoretically we might 
549
537
                # avoid the round trip for fout.close()
550
538
                if mode is not None:
551
539
                    self._get_sftp().chmod(abspath, mode)
602
590
 
603
591
    def iter_files_recursive(self):
604
592
        """Walk the relative paths of all files in this transport."""
605
 
        # progress is handled by list_dir
606
593
        queue = list(self.list_dir('.'))
607
594
        while queue:
608
595
            relpath = queue.pop(0)
619
606
        else:
620
607
            local_mode = mode
621
608
        try:
622
 
            self._report_activity(len(abspath), 'write')
623
609
            self._get_sftp().mkdir(abspath, local_mode)
624
 
            self._report_activity(1, 'read')
625
610
            if mode is not None:
626
611
                # chmod a dir through sftp will erase any sgid bit set
627
612
                # on the server side.  So, if the bit mode are already
649
634
    def open_write_stream(self, relpath, mode=None):
650
635
        """See Transport.open_write_stream."""
651
636
        # initialise the file to zero-length
652
 
        # this is three round trips, but we don't use this
653
 
        # api more than once per write_group at the moment so
 
637
        # this is three round trips, but we don't use this 
 
638
        # api more than once per write_group at the moment so 
654
639
        # it is a tolerable overhead. Better would be to truncate
655
640
        # the file after opening. RBC 20070805
656
641
        self.put_bytes_non_atomic(relpath, "", mode)
679
664
        :param failure_exc: Paramiko has the super fun ability to raise completely
680
665
                           opaque errors that just set "e.args = ('Failure',)" with
681
666
                           no more information.
682
 
                           If this parameter is set, it defines the exception
 
667
                           If this parameter is set, it defines the exception 
683
668
                           to raise in these cases.
684
669
        """
685
670
        # paramiko seems to generate detailless errors.
694
679
            # strange but true, for the paramiko server.
695
680
            if (e.args == ('Failure',)):
696
681
                raise failure_exc(path, str(e) + more_info)
697
 
            # Can be something like args = ('Directory not empty:
698
 
            # '/srv/bazaar.launchpad.net/blah...: '
699
 
            # [Errno 39] Directory not empty',)
700
 
            if (e.args[0].startswith('Directory not empty: ')
701
 
                or getattr(e, 'errno', None) == errno.ENOTEMPTY):
702
 
                raise errors.DirectoryNotEmpty(path, str(e))
703
 
            if e.args == ('Operation unsupported',):
704
 
                raise errors.TransportNotPossible()
705
682
            mutter('Raising exception with args %s', e.args)
706
683
        if getattr(e, 'errno', None) is not None:
707
684
            mutter('Raising exception with errno %s', e.errno)
734
711
 
735
712
    def _rename_and_overwrite(self, abs_from, abs_to):
736
713
        """Do a fancy rename on the remote server.
737
 
 
 
714
        
738
715
        Using the implementation provided by osutils.
739
716
        """
740
717
        try:
759
736
            self._get_sftp().remove(path)
760
737
        except (IOError, paramiko.SSHException), e:
761
738
            self._translate_io_exception(e, path, ': unable to delete')
762
 
 
 
739
            
763
740
    def external_url(self):
764
741
        """See bzrlib.transport.Transport.external_url."""
765
742
        # the external path for SFTP is the base
780
757
        path = self._remote_path(relpath)
781
758
        try:
782
759
            entries = self._get_sftp().listdir(path)
783
 
            self._report_activity(sum(map(len, entries)), 'read')
784
760
        except (IOError, paramiko.SSHException), e:
785
761
            self._translate_io_exception(e, path, ': failed to list_dir')
786
762
        return [urlutils.escape(entry) for entry in entries]
797
773
        """Return the stat information for a file."""
798
774
        path = self._remote_path(relpath)
799
775
        try:
800
 
            return self._get_sftp().lstat(path)
 
776
            return self._get_sftp().stat(path)
801
777
        except (IOError, paramiko.SSHException), e:
802
778
            self._translate_io_exception(e, path, ': unable to stat')
803
779
 
804
 
    def readlink(self, relpath):
805
 
        """See Transport.readlink."""
806
 
        path = self._remote_path(relpath)
807
 
        try:
808
 
            return self._get_sftp().readlink(path)
809
 
        except (IOError, paramiko.SSHException), e:
810
 
            self._translate_io_exception(e, path, ': unable to readlink')
811
 
 
812
 
    def symlink(self, source, link_name):
813
 
        """See Transport.symlink."""
814
 
        try:
815
 
            conn = self._get_sftp()
816
 
            sftp_retval = conn.symlink(source, link_name)
817
 
            if SFTP_OK != sftp_retval:
818
 
                raise TransportError(
819
 
                    '%r: unable to create symlink to %r' % (link_name, source),
820
 
                    sftp_retval
821
 
                )
822
 
        except (IOError, paramiko.SSHException), e:
823
 
            self._translate_io_exception(e, link_name,
824
 
                                         ': unable to create symlink to %r' % (source))
825
 
 
826
780
    def lock_read(self, relpath):
827
781
        """
828
782
        Lock the given file for shared (read) access.
865
819
        """
866
820
        # TODO: jam 20060816 Paramiko >= 1.6.2 (probably earlier) supports
867
821
        #       using the 'x' flag to indicate SFTP_FLAG_EXCL.
868
 
        #       However, there is no way to set the permission mode at open
 
822
        #       However, there is no way to set the permission mode at open 
869
823
        #       time using the sftp_client.file() functionality.
870
824
        path = self._get_sftp()._adjust_cwd(abspath)
871
825
        # mutter('sftp abspath %s => %s', abspath, path)
872
826
        attr = SFTPAttributes()
873
827
        if mode is not None:
874
828
            attr.st_mode = mode
875
 
        omode = (SFTP_FLAG_WRITE | SFTP_FLAG_CREATE
 
829
        omode = (SFTP_FLAG_WRITE | SFTP_FLAG_CREATE 
876
830
                | SFTP_FLAG_TRUNC | SFTP_FLAG_EXCL)
877
831
        try:
878
832
            t, msg = self._get_sftp()._request(CMD_OPEN, path, omode, attr)
891
845
        else:
892
846
            return True
893
847
 
 
848
# ------------- server test implementation --------------
 
849
import threading
 
850
 
 
851
from bzrlib.tests.stub_sftp import StubServer, StubSFTPServer
 
852
 
 
853
STUB_SERVER_KEY = """
 
854
-----BEGIN RSA PRIVATE KEY-----
 
855
MIICWgIBAAKBgQDTj1bqB4WmayWNPB+8jVSYpZYk80Ujvj680pOTh2bORBjbIAyz
 
856
oWGW+GUjzKxTiiPvVmxFgx5wdsFvF03v34lEVVhMpouqPAYQ15N37K/ir5XY+9m/
 
857
d8ufMCkjeXsQkKqFbAlQcnWMCRnOoPHS3I4vi6hmnDDeeYTSRvfLbW0fhwIBIwKB
 
858
gBIiOqZYaoqbeD9OS9z2K9KR2atlTxGxOJPXiP4ESqP3NVScWNwyZ3NXHpyrJLa0
 
859
EbVtzsQhLn6rF+TzXnOlcipFvjsem3iYzCpuChfGQ6SovTcOjHV9z+hnpXvQ/fon
 
860
soVRZY65wKnF7IAoUwTmJS9opqgrN6kRgCd3DASAMd1bAkEA96SBVWFt/fJBNJ9H
 
861
tYnBKZGw0VeHOYmVYbvMSstssn8un+pQpUm9vlG/bp7Oxd/m+b9KWEh2xPfv6zqU
 
862
avNwHwJBANqzGZa/EpzF4J8pGti7oIAPUIDGMtfIcmqNXVMckrmzQ2vTfqtkEZsA
 
863
4rE1IERRyiJQx6EJsz21wJmGV9WJQ5kCQQDwkS0uXqVdFzgHO6S++tjmjYcxwr3g
 
864
H0CoFYSgbddOT6miqRskOQF3DZVkJT3kyuBgU2zKygz52ukQZMqxCb1fAkASvuTv
 
865
qfpH87Qq5kQhNKdbbwbmd2NxlNabazPijWuphGTdW0VfJdWfklyS2Kr+iqrs/5wV
 
866
HhathJt636Eg7oIjAkA8ht3MQ+XSl9yIJIS8gVpbPxSw5OMfw0PjVE7tBdQruiSc
 
867
nvuQES5C9BMHjF39LZiGH1iLQy7FgdHyoP+eodI7
 
868
-----END RSA PRIVATE KEY-----
 
869
"""
 
870
 
 
871
 
 
872
class SocketListener(threading.Thread):
 
873
 
 
874
    def __init__(self, callback):
 
875
        threading.Thread.__init__(self)
 
876
        self._callback = callback
 
877
        self._socket = socket.socket()
 
878
        self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
 
879
        self._socket.bind(('localhost', 0))
 
880
        self._socket.listen(1)
 
881
        self.port = self._socket.getsockname()[1]
 
882
        self._stop_event = threading.Event()
 
883
 
 
884
    def stop(self):
 
885
        # called from outside this thread
 
886
        self._stop_event.set()
 
887
        # use a timeout here, because if the test fails, the server thread may
 
888
        # never notice the stop_event.
 
889
        self.join(5.0)
 
890
        self._socket.close()
 
891
 
 
892
    def run(self):
 
893
        while True:
 
894
            readable, writable_unused, exception_unused = \
 
895
                select.select([self._socket], [], [], 0.1)
 
896
            if self._stop_event.isSet():
 
897
                return
 
898
            if len(readable) == 0:
 
899
                continue
 
900
            try:
 
901
                s, addr_unused = self._socket.accept()
 
902
                # because the loopback socket is inline, and transports are
 
903
                # never explicitly closed, best to launch a new thread.
 
904
                threading.Thread(target=self._callback, args=(s,)).start()
 
905
            except socket.error, x:
 
906
                sys.excepthook(*sys.exc_info())
 
907
                warning('Socket error during accept() within unit test server'
 
908
                        ' thread: %r' % x)
 
909
            except Exception, x:
 
910
                # probably a failed test; unit test thread will log the
 
911
                # failure/error
 
912
                sys.excepthook(*sys.exc_info())
 
913
                warning('Exception from within unit test server thread: %r' % 
 
914
                        x)
 
915
 
 
916
 
 
917
class SocketDelay(object):
 
918
    """A socket decorator to make TCP appear slower.
 
919
 
 
920
    This changes recv, send, and sendall to add a fixed latency to each python
 
921
    call if a new roundtrip is detected. That is, when a recv is called and the
 
922
    flag new_roundtrip is set, latency is charged. Every send and send_all
 
923
    sets this flag.
 
924
 
 
925
    In addition every send, sendall and recv sleeps a bit per character send to
 
926
    simulate bandwidth.
 
927
 
 
928
    Not all methods are implemented, this is deliberate as this class is not a
 
929
    replacement for the builtin sockets layer. fileno is not implemented to
 
930
    prevent the proxy being bypassed. 
 
931
    """
 
932
 
 
933
    simulated_time = 0
 
934
    _proxied_arguments = dict.fromkeys([
 
935
        "close", "getpeername", "getsockname", "getsockopt", "gettimeout",
 
936
        "setblocking", "setsockopt", "settimeout", "shutdown"])
 
937
 
 
938
    def __init__(self, sock, latency, bandwidth=1.0, 
 
939
                 really_sleep=True):
 
940
        """ 
 
941
        :param bandwith: simulated bandwith (MegaBit)
 
942
        :param really_sleep: If set to false, the SocketDelay will just
 
943
        increase a counter, instead of calling time.sleep. This is useful for
 
944
        unittesting the SocketDelay.
 
945
        """
 
946
        self.sock = sock
 
947
        self.latency = latency
 
948
        self.really_sleep = really_sleep
 
949
        self.time_per_byte = 1 / (bandwidth / 8.0 * 1024 * 1024) 
 
950
        self.new_roundtrip = False
 
951
 
 
952
    def sleep(self, s):
 
953
        if self.really_sleep:
 
954
            time.sleep(s)
 
955
        else:
 
956
            SocketDelay.simulated_time += s
 
957
 
 
958
    def __getattr__(self, attr):
 
959
        if attr in SocketDelay._proxied_arguments:
 
960
            return getattr(self.sock, attr)
 
961
        raise AttributeError("'SocketDelay' object has no attribute %r" %
 
962
                             attr)
 
963
 
 
964
    def dup(self):
 
965
        return SocketDelay(self.sock.dup(), self.latency, self.time_per_byte,
 
966
                           self._sleep)
 
967
 
 
968
    def recv(self, *args):
 
969
        data = self.sock.recv(*args)
 
970
        if data and self.new_roundtrip:
 
971
            self.new_roundtrip = False
 
972
            self.sleep(self.latency)
 
973
        self.sleep(len(data) * self.time_per_byte)
 
974
        return data
 
975
 
 
976
    def sendall(self, data, flags=0):
 
977
        if not self.new_roundtrip:
 
978
            self.new_roundtrip = True
 
979
            self.sleep(self.latency)
 
980
        self.sleep(len(data) * self.time_per_byte)
 
981
        return self.sock.sendall(data, flags)
 
982
 
 
983
    def send(self, data, flags=0):
 
984
        if not self.new_roundtrip:
 
985
            self.new_roundtrip = True
 
986
            self.sleep(self.latency)
 
987
        bytes_sent = self.sock.send(data, flags)
 
988
        self.sleep(bytes_sent * self.time_per_byte)
 
989
        return bytes_sent
 
990
 
 
991
 
 
992
class SFTPServer(Server):
 
993
    """Common code for SFTP server facilities."""
 
994
 
 
995
    def __init__(self, server_interface=StubServer):
 
996
        self._original_vendor = None
 
997
        self._homedir = None
 
998
        self._server_homedir = None
 
999
        self._listener = None
 
1000
        self._root = None
 
1001
        self._vendor = ssh.ParamikoVendor()
 
1002
        self._server_interface = server_interface
 
1003
        # sftp server logs
 
1004
        self.logs = []
 
1005
        self.add_latency = 0
 
1006
 
 
1007
    def _get_sftp_url(self, path):
 
1008
        """Calculate an sftp url to this server for path."""
 
1009
        return 'sftp://foo:bar@localhost:%d/%s' % (self._listener.port, path)
 
1010
 
 
1011
    def log(self, message):
 
1012
        """StubServer uses this to log when a new server is created."""
 
1013
        self.logs.append(message)
 
1014
 
 
1015
    def _run_server_entry(self, sock):
 
1016
        """Entry point for all implementations of _run_server.
 
1017
        
 
1018
        If self.add_latency is > 0.000001 then sock is given a latency adding
 
1019
        decorator.
 
1020
        """
 
1021
        if self.add_latency > 0.000001:
 
1022
            sock = SocketDelay(sock, self.add_latency)
 
1023
        return self._run_server(sock)
 
1024
 
 
1025
    def _run_server(self, s):
 
1026
        ssh_server = paramiko.Transport(s)
 
1027
        key_file = pathjoin(self._homedir, 'test_rsa.key')
 
1028
        f = open(key_file, 'w')
 
1029
        f.write(STUB_SERVER_KEY)
 
1030
        f.close()
 
1031
        host_key = paramiko.RSAKey.from_private_key_file(key_file)
 
1032
        ssh_server.add_server_key(host_key)
 
1033
        server = self._server_interface(self)
 
1034
        ssh_server.set_subsystem_handler('sftp', paramiko.SFTPServer,
 
1035
                                         StubSFTPServer, root=self._root,
 
1036
                                         home=self._server_homedir)
 
1037
        event = threading.Event()
 
1038
        ssh_server.start_server(event, server)
 
1039
        event.wait(5.0)
 
1040
    
 
1041
    def setUp(self, backing_server=None):
 
1042
        # XXX: TODO: make sftpserver back onto backing_server rather than local
 
1043
        # disk.
 
1044
        if not (backing_server is None or
 
1045
                isinstance(backing_server, local.LocalURLServer)):
 
1046
            raise AssertionError(
 
1047
                "backing_server should not be %r, because this can only serve the "
 
1048
                "local current working directory." % (backing_server,))
 
1049
        self._original_vendor = ssh._ssh_vendor_manager._cached_ssh_vendor
 
1050
        ssh._ssh_vendor_manager._cached_ssh_vendor = self._vendor
 
1051
        if sys.platform == 'win32':
 
1052
            # Win32 needs to use the UNICODE api
 
1053
            self._homedir = getcwd()
 
1054
        else:
 
1055
            # But Linux SFTP servers should just deal in bytestreams
 
1056
            self._homedir = os.getcwd()
 
1057
        if self._server_homedir is None:
 
1058
            self._server_homedir = self._homedir
 
1059
        self._root = '/'
 
1060
        if sys.platform == 'win32':
 
1061
            self._root = ''
 
1062
        self._listener = SocketListener(self._run_server_entry)
 
1063
        self._listener.setDaemon(True)
 
1064
        self._listener.start()
 
1065
 
 
1066
    def tearDown(self):
 
1067
        """See bzrlib.transport.Server.tearDown."""
 
1068
        self._listener.stop()
 
1069
        ssh._ssh_vendor_manager._cached_ssh_vendor = self._original_vendor
 
1070
 
 
1071
    def get_bogus_url(self):
 
1072
        """See bzrlib.transport.Server.get_bogus_url."""
 
1073
        # this is chosen to try to prevent trouble with proxies, wierd dns, etc
 
1074
        # we bind a random socket, so that we get a guaranteed unused port
 
1075
        # we just never listen on that port
 
1076
        s = socket.socket()
 
1077
        s.bind(('localhost', 0))
 
1078
        return 'sftp://%s:%s/' % s.getsockname()
 
1079
 
 
1080
 
 
1081
class SFTPFullAbsoluteServer(SFTPServer):
 
1082
    """A test server for sftp transports, using absolute urls and ssh."""
 
1083
 
 
1084
    def get_url(self):
 
1085
        """See bzrlib.transport.Server.get_url."""
 
1086
        homedir = self._homedir
 
1087
        if sys.platform != 'win32':
 
1088
            # Remove the initial '/' on all platforms but win32
 
1089
            homedir = homedir[1:]
 
1090
        return self._get_sftp_url(urlutils.escape(homedir))
 
1091
 
 
1092
 
 
1093
class SFTPServerWithoutSSH(SFTPServer):
 
1094
    """An SFTP server that uses a simple TCP socket pair rather than SSH."""
 
1095
 
 
1096
    def __init__(self):
 
1097
        super(SFTPServerWithoutSSH, self).__init__()
 
1098
        self._vendor = ssh.LoopbackVendor()
 
1099
 
 
1100
    def _run_server(self, sock):
 
1101
        # Re-import these as locals, so that they're still accessible during
 
1102
        # interpreter shutdown (when all module globals get set to None, leading
 
1103
        # to confusing errors like "'NoneType' object has no attribute 'error'".
 
1104
        class FakeChannel(object):
 
1105
            def get_transport(self):
 
1106
                return self
 
1107
            def get_log_channel(self):
 
1108
                return 'paramiko'
 
1109
            def get_name(self):
 
1110
                return '1'
 
1111
            def get_hexdump(self):
 
1112
                return False
 
1113
            def close(self):
 
1114
                pass
 
1115
 
 
1116
        server = paramiko.SFTPServer(
 
1117
            FakeChannel(), 'sftp', StubServer(self), StubSFTPServer,
 
1118
            root=self._root, home=self._server_homedir)
 
1119
        try:
 
1120
            server.start_subsystem(
 
1121
                'sftp', None, ssh.SocketAsChannelAdapter(sock))
 
1122
        except socket.error, e:
 
1123
            if (len(e.args) > 0) and (e.args[0] == errno.EPIPE):
 
1124
                # it's okay for the client to disconnect abruptly
 
1125
                # (bug in paramiko 1.6: it should absorb this exception)
 
1126
                pass
 
1127
            else:
 
1128
                raise
 
1129
        except Exception, e:
 
1130
            # This typically seems to happen during interpreter shutdown, so
 
1131
            # most of the useful ways to report this error are won't work.
 
1132
            # Writing the exception type, and then the text of the exception,
 
1133
            # seems to be the best we can do.
 
1134
            import sys
 
1135
            sys.stderr.write('\nEXCEPTION %r: ' % (e.__class__,))
 
1136
            sys.stderr.write('%s\n\n' % (e,))
 
1137
        server.finish_subsystem()
 
1138
 
 
1139
 
 
1140
class SFTPAbsoluteServer(SFTPServerWithoutSSH):
 
1141
    """A test server for sftp transports, using absolute urls."""
 
1142
 
 
1143
    def get_url(self):
 
1144
        """See bzrlib.transport.Server.get_url."""
 
1145
        homedir = self._homedir
 
1146
        if sys.platform != 'win32':
 
1147
            # Remove the initial '/' on all platforms but win32
 
1148
            homedir = homedir[1:]
 
1149
        return self._get_sftp_url(urlutils.escape(homedir))
 
1150
 
 
1151
 
 
1152
class SFTPHomeDirServer(SFTPServerWithoutSSH):
 
1153
    """A test server for sftp transports, using homedir relative urls."""
 
1154
 
 
1155
    def get_url(self):
 
1156
        """See bzrlib.transport.Server.get_url."""
 
1157
        return self._get_sftp_url("~/")
 
1158
 
 
1159
 
 
1160
class SFTPSiblingAbsoluteServer(SFTPAbsoluteServer):
 
1161
    """A test server for sftp transports where only absolute paths will work.
 
1162
 
 
1163
    It does this by serving from a deeply-nested directory that doesn't exist.
 
1164
    """
 
1165
 
 
1166
    def setUp(self, backing_server=None):
 
1167
        self._server_homedir = '/dev/noone/runs/tests/here'
 
1168
        super(SFTPSiblingAbsoluteServer, self).setUp(backing_server)
 
1169
 
894
1170
 
895
1171
def get_test_permutations():
896
1172
    """Return the permutations to be used in testing."""
897
 
    from bzrlib.tests import stub_sftp
898
 
    return [(SFTPTransport, stub_sftp.SFTPAbsoluteServer),
899
 
            (SFTPTransport, stub_sftp.SFTPHomeDirServer),
900
 
            (SFTPTransport, stub_sftp.SFTPSiblingAbsoluteServer),
 
1173
    return [(SFTPTransport, SFTPAbsoluteServer),
 
1174
            (SFTPTransport, SFTPHomeDirServer),
 
1175
            (SFTPTransport, SFTPSiblingAbsoluteServer),
901
1176
            ]