~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/tests/test_wsgi.py

[merge] robert's knit-performance work

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2006-2009, 2011 Canonical Ltd
2
 
#
3
 
# This program is free software; you can redistribute it and/or modify
4
 
# it under the terms of the GNU General Public License as published by
5
 
# the Free Software Foundation; either version 2 of the License, or
6
 
# (at your option) any later version.
7
 
#
8
 
# This program is distributed in the hope that it will be useful,
9
 
# but WITHOUT ANY WARRANTY; without even the implied warranty of
10
 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11
 
# GNU General Public License for more details.
12
 
#
13
 
# You should have received a copy of the GNU General Public License
14
 
# along with this program; if not, write to the Free Software
15
 
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
16
 
 
17
 
"""Tests for WSGI application"""
18
 
 
19
 
from cStringIO import StringIO
20
 
 
21
 
from bzrlib import tests
22
 
from bzrlib.smart import medium, protocol
23
 
from bzrlib.transport.http import wsgi
24
 
from bzrlib.transport import chroot, memory
25
 
 
26
 
 
27
 
class WSGITestMixin(object):
28
 
 
29
 
    def build_environ(self, updates=None):
30
 
        """Builds an environ dict with all fields required by PEP 333.
31
 
 
32
 
        :param updates: a dict to that will be incorporated into the returned
33
 
            dict using dict.update(updates).
34
 
        """
35
 
        environ = {
36
 
            # Required CGI variables
37
 
            'REQUEST_METHOD': 'GET',
38
 
            'SCRIPT_NAME': '/script/name/',
39
 
            'PATH_INFO': 'path/info',
40
 
            'SERVER_NAME': 'test',
41
 
            'SERVER_PORT': '9999',
42
 
            'SERVER_PROTOCOL': 'HTTP/1.0',
43
 
 
44
 
            # Required WSGI variables
45
 
            'wsgi.version': (1,0),
46
 
            'wsgi.url_scheme': 'http',
47
 
            'wsgi.input': StringIO(''),
48
 
            'wsgi.errors': StringIO(),
49
 
            'wsgi.multithread': False,
50
 
            'wsgi.multiprocess': False,
51
 
            'wsgi.run_once': True,
52
 
        }
53
 
        if updates is not None:
54
 
            environ.update(updates)
55
 
        return environ
56
 
 
57
 
    def read_response(self, iterable):
58
 
        response = ''
59
 
        for string in iterable:
60
 
            response += string
61
 
        return response
62
 
 
63
 
    def start_response(self, status, headers):
64
 
        self.status = status
65
 
        self.headers = headers
66
 
 
67
 
 
68
 
class TestWSGI(tests.TestCase, WSGITestMixin):
69
 
 
70
 
    def setUp(self):
71
 
        tests.TestCase.setUp(self)
72
 
        self.status = None
73
 
        self.headers = None
74
 
 
75
 
    def test_construct(self):
76
 
        app = wsgi.SmartWSGIApp(FakeTransport())
77
 
        self.assertIsInstance(
78
 
            app.backing_transport, chroot.ChrootTransport)
79
 
 
80
 
    def test_http_get_rejected(self):
81
 
        # GET requests are rejected.
82
 
        app = wsgi.SmartWSGIApp(FakeTransport())
83
 
        environ = self.build_environ({'REQUEST_METHOD': 'GET'})
84
 
        iterable = app(environ, self.start_response)
85
 
        self.read_response(iterable)
86
 
        self.assertEqual('405 Method not allowed', self.status)
87
 
        self.assertTrue(('Allow', 'POST') in self.headers)
88
 
 
89
 
    def _fake_make_request(self, transport, write_func, bytes, rcp):
90
 
        request = FakeRequest(transport, write_func)
91
 
        request.accept_bytes(bytes)
92
 
        self.request = request
93
 
        return request
94
 
 
95
 
    def test_smart_wsgi_app_uses_given_relpath(self):
96
 
        # The SmartWSGIApp should use the "bzrlib.relpath" field from the
97
 
        # WSGI environ to clone from its backing transport to get a specific
98
 
        # transport for this request.
99
 
        transport = FakeTransport()
100
 
        wsgi_app = wsgi.SmartWSGIApp(transport)
101
 
        wsgi_app.backing_transport = transport
102
 
        wsgi_app.make_request = self._fake_make_request
103
 
        fake_input = StringIO('fake request')
104
 
        environ = self.build_environ({
105
 
            'REQUEST_METHOD': 'POST',
106
 
            'CONTENT_LENGTH': len(fake_input.getvalue()),
107
 
            'wsgi.input': fake_input,
108
 
            'bzrlib.relpath': 'foo/bar',
109
 
        })
110
 
        iterable = wsgi_app(environ, self.start_response)
111
 
        response = self.read_response(iterable)
112
 
        self.assertEqual([('clone', 'foo/bar/')] , transport.calls)
113
 
 
114
 
    def test_smart_wsgi_app_request_and_response(self):
115
 
        # SmartWSGIApp reads the smart request from the 'wsgi.input' file-like
116
 
        # object in the environ dict, and returns the response via the iterable
117
 
        # returned to the WSGI handler.
118
 
        transport = memory.MemoryTransport()
119
 
        transport.put_bytes('foo', 'some bytes')
120
 
        wsgi_app = wsgi.SmartWSGIApp(transport)
121
 
        wsgi_app.make_request = self._fake_make_request
122
 
        fake_input = StringIO('fake request')
123
 
        environ = self.build_environ({
124
 
            'REQUEST_METHOD': 'POST',
125
 
            'CONTENT_LENGTH': len(fake_input.getvalue()),
126
 
            'wsgi.input': fake_input,
127
 
            'bzrlib.relpath': 'foo',
128
 
        })
129
 
        iterable = wsgi_app(environ, self.start_response)
130
 
        response = self.read_response(iterable)
131
 
        self.assertEqual('200 OK', self.status)
132
 
        self.assertEqual('got bytes: fake request', response)
133
 
 
134
 
    def test_relpath_setter(self):
135
 
        # wsgi.RelpathSetter is WSGI "middleware" to set the 'bzrlib.relpath'
136
 
        # variable.
137
 
        calls = []
138
 
        def fake_app(environ, start_response):
139
 
            calls.append(environ['bzrlib.relpath'])
140
 
        wrapped_app = wsgi.RelpathSetter(
141
 
            fake_app, prefix='/abc/', path_var='FOO')
142
 
        wrapped_app({'FOO': '/abc/xyz/.bzr/smart'}, None)
143
 
        self.assertEqual(['xyz'], calls)
144
 
 
145
 
    def test_relpath_setter_bad_path_prefix(self):
146
 
        # wsgi.RelpathSetter will reject paths with that don't match the prefix
147
 
        # with a 404.  This is probably a sign of misconfiguration; a server
148
 
        # shouldn't ever be invoking our WSGI application with bad paths.
149
 
        def fake_app(environ, start_response):
150
 
            self.fail('The app should never be called when the path is wrong')
151
 
        wrapped_app = wsgi.RelpathSetter(
152
 
            fake_app, prefix='/abc/', path_var='FOO')
153
 
        iterable = wrapped_app(
154
 
            {'FOO': 'AAA/abc/xyz/.bzr/smart'}, self.start_response)
155
 
        self.read_response(iterable)
156
 
        self.assertTrue(self.status.startswith('404'))
157
 
 
158
 
    def test_relpath_setter_bad_path_suffix(self):
159
 
        # Similar to test_relpath_setter_bad_path_prefix: wsgi.RelpathSetter
160
 
        # will reject paths with that don't match the suffix '.bzr/smart' with a
161
 
        # 404 as well.  Again, this shouldn't be seen by our WSGI application if
162
 
        # the server is configured correctly.
163
 
        def fake_app(environ, start_response):
164
 
            self.fail('The app should never be called when the path is wrong')
165
 
        wrapped_app = wsgi.RelpathSetter(
166
 
            fake_app, prefix='/abc/', path_var='FOO')
167
 
        iterable = wrapped_app(
168
 
            {'FOO': '/abc/xyz/.bzr/AAA'}, self.start_response)
169
 
        self.read_response(iterable)
170
 
        self.assertTrue(self.status.startswith('404'))
171
 
 
172
 
    def test_make_app(self):
173
 
        # The make_app helper constructs a SmartWSGIApp wrapped in a
174
 
        # RelpathSetter.
175
 
        app = wsgi.make_app(
176
 
            root='a root',
177
 
            prefix='a prefix',
178
 
            path_var='a path_var')
179
 
        self.assertIsInstance(app, wsgi.RelpathSetter)
180
 
        self.assertIsInstance(app.app, wsgi.SmartWSGIApp)
181
 
        self.assertStartsWith(app.app.backing_transport.base, 'chroot-')
182
 
        backing_transport = app.app.backing_transport
183
 
        chroot_backing_transport = backing_transport.server.backing_transport
184
 
        self.assertEndsWith(chroot_backing_transport.base, 'a%20root/')
185
 
        self.assertEqual(app.app.root_client_path, 'a prefix')
186
 
        self.assertEqual(app.path_var, 'a path_var')
187
 
 
188
 
    def test_incomplete_request(self):
189
 
        transport = FakeTransport()
190
 
        wsgi_app = wsgi.SmartWSGIApp(transport)
191
 
        def make_request(transport, write_func, bytes, root_client_path):
192
 
            request = IncompleteRequest(transport, write_func)
193
 
            request.accept_bytes(bytes)
194
 
            self.request = request
195
 
            return request
196
 
        wsgi_app.make_request = make_request
197
 
 
198
 
        fake_input = StringIO('incomplete request')
199
 
        environ = self.build_environ({
200
 
            'REQUEST_METHOD': 'POST',
201
 
            'CONTENT_LENGTH': len(fake_input.getvalue()),
202
 
            'wsgi.input': fake_input,
203
 
            'bzrlib.relpath': 'foo/bar',
204
 
        })
205
 
        iterable = wsgi_app(environ, self.start_response)
206
 
        response = self.read_response(iterable)
207
 
        self.assertEqual('200 OK', self.status)
208
 
        self.assertEqual('error\x01incomplete request\n', response)
209
 
 
210
 
    def test_protocol_version_detection_one(self):
211
 
        # SmartWSGIApp detects requests that don't start with
212
 
        # REQUEST_VERSION_TWO as version one.
213
 
        transport = memory.MemoryTransport()
214
 
        wsgi_app = wsgi.SmartWSGIApp(transport)
215
 
        fake_input = StringIO('hello\n')
216
 
        environ = self.build_environ({
217
 
            'REQUEST_METHOD': 'POST',
218
 
            'CONTENT_LENGTH': len(fake_input.getvalue()),
219
 
            'wsgi.input': fake_input,
220
 
            'bzrlib.relpath': 'foo',
221
 
        })
222
 
        iterable = wsgi_app(environ, self.start_response)
223
 
        response = self.read_response(iterable)
224
 
        self.assertEqual('200 OK', self.status)
225
 
        # Expect a version 1-encoded response.
226
 
        self.assertEqual('ok\x012\n', response)
227
 
 
228
 
    def test_protocol_version_detection_two(self):
229
 
        # SmartWSGIApp detects requests that start with REQUEST_VERSION_TWO
230
 
        # as version two.
231
 
        transport = memory.MemoryTransport()
232
 
        wsgi_app = wsgi.SmartWSGIApp(transport)
233
 
        fake_input = StringIO(protocol.REQUEST_VERSION_TWO + 'hello\n')
234
 
        environ = self.build_environ({
235
 
            'REQUEST_METHOD': 'POST',
236
 
            'CONTENT_LENGTH': len(fake_input.getvalue()),
237
 
            'wsgi.input': fake_input,
238
 
            'bzrlib.relpath': 'foo',
239
 
        })
240
 
        iterable = wsgi_app(environ, self.start_response)
241
 
        response = self.read_response(iterable)
242
 
        self.assertEqual('200 OK', self.status)
243
 
        # Expect a version 2-encoded response.
244
 
        self.assertEqual(
245
 
            protocol.RESPONSE_VERSION_TWO + 'success\nok\x012\n', response)
246
 
 
247
 
 
248
 
class TestWSGIJail(tests.TestCaseWithMemoryTransport, WSGITestMixin):
249
 
 
250
 
    def make_hpss_wsgi_request(self, wsgi_relpath, *args):
251
 
        write_buf = StringIO()
252
 
        request_medium = medium.SmartSimplePipesClientMedium(
253
 
            None, write_buf, 'fake:' + wsgi_relpath)
254
 
        request_encoder = protocol.ProtocolThreeRequester(
255
 
            request_medium.get_request())
256
 
        request_encoder.call(*args)
257
 
        write_buf.seek(0)
258
 
        environ = self.build_environ({
259
 
            'REQUEST_METHOD': 'POST',
260
 
            'CONTENT_LENGTH': len(write_buf.getvalue()),
261
 
            'wsgi.input': write_buf,
262
 
            'bzrlib.relpath': wsgi_relpath,
263
 
        })
264
 
        return environ
265
 
 
266
 
    def test_jail_root(self):
267
 
        """The WSGI HPSS glue allows access to the whole WSGI backing
268
 
        transport, regardless of which HTTP path the request was delivered
269
 
        to.
270
 
        """
271
 
        # make a branch in a shared repo
272
 
        self.make_repository('repo', shared=True)
273
 
        branch = self.make_bzrdir('repo/branch').create_branch()
274
 
        # serve the repo via bzr+http WSGI
275
 
        wsgi_app = wsgi.SmartWSGIApp(self.get_transport())
276
 
        # send a request to /repo/branch that will have to access /repo.
277
 
        environ = self.make_hpss_wsgi_request(
278
 
            '/repo/branch', 'BzrDir.open_branchV2', '.')
279
 
        iterable = wsgi_app(environ, self.start_response)
280
 
        response_bytes = self.read_response(iterable)
281
 
        self.assertEqual('200 OK', self.status)
282
 
        # expect a successful response, rather than a jail break error
283
 
        from bzrlib.tests.test_smart_transport import LoggingMessageHandler
284
 
        message_handler = LoggingMessageHandler()
285
 
        decoder = protocol.ProtocolThreeDecoder(
286
 
            message_handler, expect_version_marker=True)
287
 
        decoder.accept_bytes(response_bytes)
288
 
        self.assertTrue(
289
 
            ('structure', ('branch', branch._format.network_name()))
290
 
            in message_handler.event_log)
291
 
 
292
 
 
293
 
class FakeRequest(object):
294
 
 
295
 
    def __init__(self, transport, write_func):
296
 
        self.transport = transport
297
 
        self.write_func = write_func
298
 
        self.accepted_bytes = ''
299
 
 
300
 
    def accept_bytes(self, bytes):
301
 
        self.accepted_bytes = bytes
302
 
        self.write_func('got bytes: ' + bytes)
303
 
 
304
 
    def next_read_size(self):
305
 
        return 0
306
 
 
307
 
 
308
 
class FakeTransport(object):
309
 
 
310
 
    def __init__(self):
311
 
        self.calls = []
312
 
        self.base = 'fake:///'
313
 
 
314
 
    def abspath(self, relpath):
315
 
        return 'fake:///' + relpath
316
 
 
317
 
    def clone(self, relpath):
318
 
        self.calls.append(('clone', relpath))
319
 
        return self
320
 
 
321
 
 
322
 
class IncompleteRequest(FakeRequest):
323
 
    """A request-like object that always expects to read more bytes."""
324
 
 
325
 
    def next_read_size(self):
326
 
        # this request always asks for more
327
 
        return 1
328