~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/tests/test_smart_request.py

Merge upstream.

Show diffs side-by-side

added added

removed removed

Lines of Context:
12
12
#
13
13
# You should have received a copy of the GNU General Public License
14
14
# along with this program; if not, write to the Free Software
15
 
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 
15
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
16
16
 
17
17
"""Tests for smart server request infrastructure (bzrlib.smart.request)."""
18
18
 
 
19
import threading
 
20
 
19
21
from bzrlib import errors
20
22
from bzrlib.smart import request
21
 
from bzrlib.tests import TestCase
 
23
from bzrlib.tests import TestCase, TestCaseWithMemoryTransport
 
24
from bzrlib.transport import get_transport
22
25
 
23
26
 
24
27
class NoBodyRequest(request.SmartServerRequest):
28
31
        return request.SuccessfulSmartServerResponse(('ok',))
29
32
 
30
33
 
 
34
class DoErrorRequest(request.SmartServerRequest):
    """Request whose do() step always fails with NoSuchFile('xyzzy')."""

    def do(self):
        missing_path = 'xyzzy'
        raise errors.NoSuchFile(missing_path)
 
39
 
 
40
 
 
41
class ChunkErrorRequest(request.SmartServerRequest):
    """Request that accepts its arguments but fails on the first body chunk.

    do() does nothing; do_chunk() always raises NoSuchFile('xyzzy').
    """

    def do(self):
        """Accept the request arguments without doing anything."""
        return None

    def do_chunk(self, bytes):
        missing_path = 'xyzzy'
        raise errors.NoSuchFile(missing_path)
 
50
 
 
51
 
 
52
class EndErrorRequest(request.SmartServerRequest):
    """Request that accepts arguments and body chunks but fails at the end.

    do() and do_chunk() do nothing; do_end() always raises
    NoSuchFile('xyzzy').
    """

    def do(self):
        """Accept the request arguments without doing anything."""
        return None

    def do_chunk(self, bytes):
        """Accept a body chunk without doing anything."""
        return None

    def do_end(self):
        missing_path = 'xyzzy'
        raise errors.NoSuchFile(missing_path)
 
65
 
 
66
 
 
67
class CheckJailRequest(request.SmartServerRequest):
    """Request that records the active jail at each stage of its lifetime.

    Each of do(), do_chunk() and do_end() appends the current value of
    request.jail_info.transports to self.jail_transports_log, so a test can
    check which transports were jailed while request code was running.
    """

    def __init__(self, *args, **kwargs):
        # Forward keyword arguments as well as positional ones, so this
        # helper does not reject a keyword-style construction that the base
        # SmartServerRequest constructor would accept.
        request.SmartServerRequest.__init__(self, *args, **kwargs)
        self.jail_transports_log = []

    def _record_jail(self):
        """Append the currently jailed transports to jail_transports_log."""
        self.jail_transports_log.append(request.jail_info.transports)

    def do(self):
        self._record_jail()

    def do_chunk(self, bytes):
        self._record_jail()

    def do_end(self):
        self._record_jail()
 
81
 
 
82
 
31
83
class TestSmartRequest(TestCase):
32
84
 
33
85
    def test_request_class_without_do_body(self):
43
95
        handler.end_received()
44
96
        # Request done, no exception was raised.
45
97
 
 
98
    def test_only_request_code_is_jailed(self):
 
99
        transport = 'dummy transport'
 
100
        handler = request.SmartServerRequestHandler(
 
101
            transport, {'foo': CheckJailRequest}, '/')
 
102
        handler.args_received(('foo',))
 
103
        self.assertEqual(None, request.jail_info.transports)
 
104
        handler.accept_body('bytes')
 
105
        self.assertEqual(None, request.jail_info.transports)
 
106
        handler.end_received()
 
107
        self.assertEqual(None, request.jail_info.transports)
 
108
        self.assertEqual(
 
109
            [[transport]] * 3, handler._command.jail_transports_log)
 
110
 
 
111
 
 
112
 
 
113
class TestSmartRequestHandlerErrorTranslation(TestCase):
    """Check that SmartServerRequestHandler converts exceptions raised by a
    SmartServerRequest into FailedSmartServerResponses.
    """

    def _handler_for(self, request_class):
        """Build a handler with no backing transport that maps the 'foo'
        verb to request_class."""
        return request.SmartServerRequestHandler(
            None, {'foo': request_class}, '/')

    def assertNoResponse(self, handler):
        self.assertEqual(None, handler.response)

    def assertResponseIsTranslatedError(self, handler):
        expected = request.FailedSmartServerResponse(('NoSuchFile', 'xyzzy'))
        self.assertEqual(expected, handler.response)

    def test_error_translation_from_args_received(self):
        handler = self._handler_for(DoErrorRequest)
        handler.args_received(('foo',))
        self.assertResponseIsTranslatedError(handler)

    def test_error_translation_from_chunk_received(self):
        handler = self._handler_for(ChunkErrorRequest)
        handler.args_received(('foo',))
        # do() succeeded, so nothing has been answered yet.
        self.assertNoResponse(handler)
        handler.accept_body('bytes')
        self.assertResponseIsTranslatedError(handler)

    def test_error_translation_from_end_received(self):
        handler = self._handler_for(EndErrorRequest)
        handler.args_received(('foo',))
        # do() succeeded, so nothing has been answered yet.
        self.assertNoResponse(handler)
        handler.end_received()
        self.assertResponseIsTranslatedError(handler)
 
148
 
 
149
 
 
150
class TestRequestHanderErrorTranslation(TestCase):
    """Tests for bzrlib.smart.request._translate_error."""

    # NOTE(review): "Hander" is a typo for "Handler"; the class name is left
    # as-is so that existing test selections by name keep matching.

    def assertTranslationEqual(self, expected_tuple, error):
        translated = request._translate_error(error)
        self.assertEqual(expected_tuple, translated)

    def test_NoSuchFile(self):
        err = errors.NoSuchFile('path')
        self.assertTranslationEqual(('NoSuchFile', 'path'), err)

    def test_LockContention(self):
        err = errors.LockContention('lock', 'msg')
        self.assertTranslationEqual(('LockContention', 'lock', 'msg'), err)

    def test_TokenMismatch(self):
        err = errors.TokenMismatch('some-token', 'actual-token')
        self.assertTranslationEqual(
            ('TokenMismatch', 'some-token', 'actual-token'), err)
 
169
 
 
170
 
 
171
class TestRequestJail(TestCaseWithMemoryTransport):
    """Tests for SmartServerRequest.setup_jail and teardown_jail."""

    def test_jail(self):
        transport = self.get_transport('blah')
        req = request.SmartServerRequest(transport)
        self.assertEqual(None, request.jail_info.transports)
        req.setup_jail()
        # Tear the jail down even if the assertion fails.  Otherwise the
        # module-level request.jail_info would stay set and cause unrelated
        # transports to be rejected in later tests.
        try:
            self.assertEqual([transport], request.jail_info.transports)
        finally:
            req.teardown_jail()
        self.assertEqual(None, request.jail_info.transports)
 
181
 
 
182
 
 
183
class TestJailHook(TestCaseWithMemoryTransport):
 
184
 
 
185
    def tearDown(self):
 
186
        request.jail_info.transports = None
 
187
        TestCaseWithMemoryTransport.tearDown(self)
 
188
 
 
189
    def test_jail_hook(self):
 
190
        request.jail_info.transports = None
 
191
        _pre_open_hook = request._pre_open_hook
 
192
        # Any transport is fine if jail_info.transports is None
 
193
        t = self.get_transport('foo')
 
194
        _pre_open_hook(t)
 
195
        # A transport in jail_info.transports is allowed
 
196
        request.jail_info.transports = [t]
 
197
        _pre_open_hook(t)
 
198
        # A child of a transport in jail_info is allowed
 
199
        _pre_open_hook(t.clone('child'))
 
200
        # A parent is not allowed
 
201
        self.assertRaises(errors.BzrError, _pre_open_hook, t.clone('..'))
 
202
        # A completely unrelated transport is not allowed
 
203
        self.assertRaises(
 
204
            errors.BzrError, _pre_open_hook, get_transport('http://host/'))
 
205