~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/selftest/HTTPTestUtil.py

merge merge tweaks from aaron, which includes latest .dev

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2005 Canonical Ltd
2
 
#
 
1
# Copyright (C) 2005 by Canonical Ltd
 
2
 
3
3
# This program is free software; you can redistribute it and/or modify
4
4
# it under the terms of the GNU General Public License as published by
5
5
# the Free Software Foundation; either version 2 of the License, or
6
6
# (at your option) any later version.
7
 
#
 
7
 
8
8
# This program is distributed in the hope that it will be useful,
9
9
# but WITHOUT ANY WARRANTY; without even the implied warranty of
10
10
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11
11
# GNU General Public License for more details.
12
 
#
 
12
 
13
13
# You should have received a copy of the GNU General Public License
14
14
# along with this program; if not, write to the Free Software
15
15
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
16
16
 
17
 
from cStringIO import StringIO
18
 
import errno
19
 
import md5
20
 
import re
21
 
import sha
22
 
import socket
23
 
import threading
24
 
import time
25
 
import urllib2
26
 
import urlparse
27
 
 
28
 
from bzrlib import (
29
 
    errors,
30
 
    tests,
31
 
    transport,
32
 
    )
33
 
from bzrlib.smart import medium, protocol
34
 
from bzrlib.tests import http_server
35
 
 
36
 
 
37
 
class HTTPServerWithSmarts(http_server.HttpServer):
38
 
    """HTTPServerWithSmarts extends the HttpServer with POST methods that will
39
 
    trigger a smart server to execute with a transport rooted at the rootdir of
40
 
    the HTTP server.
41
 
    """
42
 
 
43
 
    def __init__(self, protocol_version=None):
44
 
        http_server.HttpServer.__init__(self, SmartRequestHandler,
45
 
                                        protocol_version=protocol_version)
46
 
 
47
 
 
48
 
class SmartRequestHandler(http_server.TestingHTTPRequestHandler):
49
 
    """Extend TestingHTTPRequestHandler to support smart client POSTs."""
50
 
 
51
 
    def do_POST(self):
52
 
        """Hand the request off to a smart server instance."""
53
 
        self.send_response(200)
54
 
        self.send_header("Content-type", "application/octet-stream")
55
 
        t = transport.get_transport(self.server.test_case_server._home_dir)
56
 
        # if this fails, we should return 400 bad request, but failure is
57
 
        # failure for now - RBC 20060919
58
 
        data_length = int(self.headers['Content-Length'])
59
 
        # TODO: We might like to support streaming responses.  1.0 allows no
60
 
        # Content-length in this case, so for integrity we should perform our
61
 
        # own chunking within the stream.
62
 
        # 1.1 allows chunked responses, and in this case we could chunk using
63
 
        # the HTTP chunking as this will allow HTTP persistence safely, even if
64
 
        # we have to stop early due to error, but we would also have to use the
65
 
        # HTTP trailer facility which may not be widely available.
66
 
        request_bytes = self.rfile.read(data_length)
67
 
        protocol_factory, unused_bytes = medium._get_protocol_factory_for_bytes(
68
 
            request_bytes)
69
 
        out_buffer = StringIO()
70
 
        smart_protocol_request = protocol_factory(t, out_buffer.write, '/')
71
 
        # Perhaps there should be a SmartServerHTTPMedium that takes care of
72
 
        # feeding the bytes in the http request to the smart_protocol_request,
73
 
        # but for now it's simpler to just feed the bytes directly.
74
 
        smart_protocol_request.accept_bytes(unused_bytes)
75
 
        if not (smart_protocol_request.next_read_size() == 0):
76
 
            raise errors.SmartProtocolError(
77
 
                "not finished reading, but all data sent to protocol.")
78
 
        self.send_header("Content-Length", str(len(out_buffer.getvalue())))
79
 
        self.end_headers()
80
 
        self.wfile.write(out_buffer.getvalue())
81
 
 
82
 
 
83
 
class TestCaseWithWebserver(tests.TestCaseWithTransport):
84
 
    """A support class that provides readonly urls that are http://.
85
 
 
86
 
    This is done by forcing the readonly server to be an http
87
 
    one. This will currently fail if the primary transport is not
88
 
    backed by regular disk files.
89
 
    """
90
 
    def setUp(self):
91
 
        super(TestCaseWithWebserver, self).setUp()
92
 
        self.transport_readonly_server = http_server.HttpServer
93
 
 
94
 
 
95
 
class TestCaseWithTwoWebservers(TestCaseWithWebserver):
96
 
    """A support class providing readonly urls on two servers that are http://.
97
 
 
98
 
    We set up two webservers to allows various tests involving
99
 
    proxies or redirections from one server to the other.
100
 
    """
101
 
    def setUp(self):
102
 
        super(TestCaseWithTwoWebservers, self).setUp()
103
 
        self.transport_secondary_server = http_server.HttpServer
104
 
        self.__secondary_server = None
105
 
 
106
 
    def create_transport_secondary_server(self):
107
 
        """Create a transport server from class defined at init.
108
 
 
109
 
        This is mostly a hook for daughter classes.
110
 
        """
111
 
        return self.transport_secondary_server()
112
 
 
113
 
    def get_secondary_server(self):
114
 
        """Get the server instance for the secondary transport."""
115
 
        if self.__secondary_server is None:
116
 
            self.__secondary_server = self.create_transport_secondary_server()
117
 
            self.__secondary_server.setUp()
118
 
            self.addCleanup(self.__secondary_server.tearDown)
119
 
        return self.__secondary_server
120
 
 
121
 
 
122
 
class ProxyServer(http_server.HttpServer):
123
 
    """A proxy test server for http transports."""
124
 
 
125
 
    proxy_requests = True
126
 
 
127
 
 
128
 
class RedirectRequestHandler(http_server.TestingHTTPRequestHandler):
129
 
    """Redirect all request to the specified server"""
130
 
 
131
 
    def parse_request(self):
132
 
        """Redirect a single HTTP request to another host"""
133
 
        valid = http_server.TestingHTTPRequestHandler.parse_request(self)
134
 
        if valid:
135
 
            tcs = self.server.test_case_server
136
 
            code, target = tcs.is_redirected(self.path)
137
 
            if code is not None and target is not None:
138
 
                # Redirect as instructed
139
 
                self.send_response(code)
140
 
                self.send_header('Location', target)
141
 
                # We do not send a body
142
 
                self.send_header('Content-Length', '0')
143
 
                self.end_headers()
144
 
                return False # The job is done
 
17
import BaseHTTPServer, SimpleHTTPServer
 
18
from bzrlib.selftest import TestCaseInTempDir
 
19
 
 
20
 
 
21
class WebserverNotAvailable(Exception):
 
22
    pass
 
23
 
 
24
class BadWebserverPath(ValueError):
 
25
    def __str__(self):
 
26
        return 'path %s is not in %s' % self.args
 
27
 
 
28
class TestingHTTPRequestHandler(SimpleHTTPServer.SimpleHTTPRequestHandler):
 
29
    def log_message(self, format, *args):
 
30
        self.server.test_case.log("webserver - %s - - [%s] %s\n" %
 
31
                                  (self.address_string(),
 
32
                                   self.log_date_time_string(),
 
33
                                   format%args))
 
34
 
 
35
class TestingHTTPServer(BaseHTTPServer.HTTPServer):
 
36
    def __init__(self, server_address, RequestHandlerClass, test_case):
 
37
        BaseHTTPServer.HTTPServer.__init__(self, server_address,
 
38
                                                RequestHandlerClass)
 
39
        self.test_case = test_case
 
40
 
 
41
class TestCaseWithWebserver(TestCaseInTempDir):
 
42
    """Derived class that starts a localhost-only webserver
 
43
    (in addition to what TestCaseInTempDir does).
 
44
 
 
45
    This is useful for testing RemoteBranch.
 
46
    """
 
47
 
 
48
    _HTTP_PORTS = range(13000, 0x8000)
 
49
 
 
50
    def _http_start(self):
 
51
        import SimpleHTTPServer, BaseHTTPServer, socket, errno
 
52
        httpd = None
 
53
        for port in self._HTTP_PORTS:
 
54
            try:
 
55
                httpd = TestingHTTPServer(('localhost', port),
 
56
                                          TestingHTTPRequestHandler,
 
57
                                          self)
 
58
            except socket.error, e:
 
59
                if e.args[0] == errno.EADDRINUSE:
 
60
                    continue
 
61
                print >>sys.stderr, "Cannot run webserver :-("
 
62
                raise
145
63
            else:
146
 
                # We leave the parent class serve the request
 
64
                break
 
65
 
 
66
        if httpd is None:
 
67
            raise WebserverNotAvailable("Cannot run webserver :-( "
 
68
                                        "no free ports in range %s..%s" %
 
69
                                        (_HTTP_PORTS[0], _HTTP_PORTS[-1]))
 
70
 
 
71
        self._http_base_url = 'http://localhost:%s/' % port
 
72
        self._http_starting.release()
 
73
        httpd.socket.settimeout(1)
 
74
 
 
75
        while self._http_running:
 
76
            try:
 
77
                httpd.handle_request()
 
78
            except socket.timeout:
147
79
                pass
148
 
        return valid
149
 
 
150
 
 
151
 
class HTTPServerRedirecting(http_server.HttpServer):
152
 
    """An HttpServer redirecting to another server """
153
 
 
154
 
    def __init__(self, request_handler=RedirectRequestHandler,
155
 
                 protocol_version=None):
156
 
        http_server.HttpServer.__init__(self, request_handler,
157
 
                                        protocol_version=protocol_version)
158
 
        # redirections is a list of tuples (source, target, code)
159
 
        # - source is a regexp for the paths requested
160
 
        # - target is a replacement for re.sub describing where
161
 
        #   the request will be redirected
162
 
        # - code is the http error code associated to the
163
 
        #   redirection (301 permanent, 302 temporarry, etc
164
 
        self.redirections = []
165
 
 
166
 
    def redirect_to(self, host, port):
167
 
        """Redirect all requests to a specific host:port"""
168
 
        self.redirections = [('(.*)',
169
 
                              r'http://%s:%s\1' % (host, port) ,
170
 
                              301)]
171
 
 
172
 
    def is_redirected(self, path):
173
 
        """Is the path redirected by this server.
174
 
 
175
 
        :param path: the requested relative path
176
 
 
177
 
        :returns: a tuple (code, target) if a matching
178
 
             redirection is found, (None, None) otherwise.
179
 
        """
180
 
        code = None
181
 
        target = None
182
 
        for (rsource, rtarget, rcode) in self.redirections:
183
 
            target, match = re.subn(rsource, rtarget, path)
184
 
            if match:
185
 
                code = rcode
186
 
                break # The first match wins
187
 
            else:
188
 
                target = None
189
 
        return code, target
190
 
 
191
 
 
192
 
class TestCaseWithRedirectedWebserver(TestCaseWithTwoWebservers):
193
 
   """A support class providing redirections from one server to another.
194
 
 
195
 
   We set up two webservers to allows various tests involving
196
 
   redirections.
197
 
   The 'old' server is redirected to the 'new' server.
198
 
   """
199
 
 
200
 
   def create_transport_secondary_server(self):
201
 
       """Create the secondary server redirecting to the primary server"""
202
 
       new = self.get_readonly_server()
203
 
       redirecting = HTTPServerRedirecting()
204
 
       redirecting.redirect_to(new.host, new.port)
205
 
       return redirecting
206
 
 
207
 
   def setUp(self):
208
 
       super(TestCaseWithRedirectedWebserver, self).setUp()
209
 
       # The redirections will point to the new server
210
 
       self.new_server = self.get_readonly_server()
211
 
       # The requests to the old server will be redirected
212
 
       self.old_server = self.get_secondary_server()
213
 
 
214
 
 
215
 
class AuthRequestHandler(http_server.TestingHTTPRequestHandler):
216
 
    """Requires an authentication to process requests.
217
 
 
218
 
    This is intended to be used with a server that always and
219
 
    only use one authentication scheme (implemented by daughter
220
 
    classes).
221
 
    """
222
 
 
223
 
    # The following attributes should be defined in the server
224
 
    # - auth_header_sent: the header name sent to require auth
225
 
    # - auth_header_recv: the header received containing auth
226
 
    # - auth_error_code: the error code to indicate auth required
227
 
 
228
 
    def do_GET(self):
229
 
        if self.authorized():
230
 
            return http_server.TestingHTTPRequestHandler.do_GET(self)
 
80
 
 
81
    def get_remote_url(self, path):
 
82
        import os
 
83
 
 
84
        path_parts = path.split(os.path.sep)
 
85
        if os.path.isabs(path):
 
86
            if path_parts[:len(self._local_path_parts)] != \
 
87
                   self._local_path_parts:
 
88
                raise BadWebserverPath(path, self.test_dir)
 
89
            remote_path = '/'.join(path_parts[len(self._local_path_parts):])
231
90
        else:
232
 
            # Note that we must update test_case_server *before*
233
 
            # sending the error or the client may try to read it
234
 
            # before we have sent the whole error back.
235
 
            tcs = self.server.test_case_server
236
 
            tcs.auth_required_errors += 1
237
 
            self.send_response(tcs.auth_error_code)
238
 
            self.send_header_auth_reqed()
239
 
            # We do not send a body
240
 
            self.send_header('Content-Length', '0')
241
 
            self.end_headers()
242
 
            return
243
 
 
244
 
 
245
 
class BasicAuthRequestHandler(AuthRequestHandler):
246
 
    """Implements the basic authentication of a request"""
247
 
 
248
 
    def authorized(self):
249
 
        tcs = self.server.test_case_server
250
 
        if tcs.auth_scheme != 'basic':
251
 
            return False
252
 
 
253
 
        auth_header = self.headers.get(tcs.auth_header_recv, None)
254
 
        if auth_header:
255
 
            scheme, raw_auth = auth_header.split(' ', 1)
256
 
            if scheme.lower() == tcs.auth_scheme:
257
 
                user, password = raw_auth.decode('base64').split(':')
258
 
                return tcs.authorized(user, password)
259
 
 
260
 
        return False
261
 
 
262
 
    def send_header_auth_reqed(self):
263
 
        tcs = self.server.test_case_server
264
 
        self.send_header(tcs.auth_header_sent,
265
 
                         'Basic realm="%s"' % tcs.auth_realm)
266
 
 
267
 
 
268
 
# FIXME: We could send an Authentication-Info header too when
269
 
# the authentication is succesful
270
 
 
271
 
class DigestAuthRequestHandler(AuthRequestHandler):
272
 
    """Implements the digest authentication of a request.
273
 
 
274
 
    We need persistence for some attributes and that can't be
275
 
    achieved here since we get instantiated for each request. We
276
 
    rely on the DigestAuthServer to take care of them.
277
 
    """
278
 
 
279
 
    def authorized(self):
280
 
        tcs = self.server.test_case_server
281
 
        if tcs.auth_scheme != 'digest':
282
 
            return False
283
 
 
284
 
        auth_header = self.headers.get(tcs.auth_header_recv, None)
285
 
        if auth_header is None:
286
 
            return False
287
 
        scheme, auth = auth_header.split(None, 1)
288
 
        if scheme.lower() == tcs.auth_scheme:
289
 
            auth_dict = urllib2.parse_keqv_list(urllib2.parse_http_list(auth))
290
 
 
291
 
            return tcs.digest_authorized(auth_dict, self.command)
292
 
 
293
 
        return False
294
 
 
295
 
    def send_header_auth_reqed(self):
296
 
        tcs = self.server.test_case_server
297
 
        header = 'Digest realm="%s", ' % tcs.auth_realm
298
 
        header += 'nonce="%s", algorithm="%s", qop="auth"' % (tcs.auth_nonce,
299
 
                                                              'MD5')
300
 
        self.send_header(tcs.auth_header_sent,header)
301
 
 
302
 
 
303
 
class AuthServer(http_server.HttpServer):
304
 
    """Extends HttpServer with a dictionary of passwords.
305
 
 
306
 
    This is used as a base class for various schemes which should
307
 
    all use or redefined the associated AuthRequestHandler.
308
 
 
309
 
    Note that no users are defined by default, so add_user should
310
 
    be called before issuing the first request.
311
 
    """
312
 
 
313
 
    # The following attributes should be set dy daughter classes
314
 
    # and are used by AuthRequestHandler.
315
 
    auth_header_sent = None
316
 
    auth_header_recv = None
317
 
    auth_error_code = None
318
 
    auth_realm = "Thou should not pass"
319
 
 
320
 
    def __init__(self, request_handler, auth_scheme,
321
 
                 protocol_version=None):
322
 
        http_server.HttpServer.__init__(self, request_handler,
323
 
                                        protocol_version=protocol_version)
324
 
        self.auth_scheme = auth_scheme
325
 
        self.password_of = {}
326
 
        self.auth_required_errors = 0
327
 
 
328
 
    def add_user(self, user, password):
329
 
        """Declare a user with an associated password.
330
 
 
331
 
        password can be empty, use an empty string ('') in that
332
 
        case, not None.
333
 
        """
334
 
        self.password_of[user] = password
335
 
 
336
 
    def authorized(self, user, password):
337
 
        """Check that the given user provided the right password"""
338
 
        expected_password = self.password_of.get(user, None)
339
 
        return expected_password is not None and password == expected_password
340
 
 
341
 
 
342
 
# FIXME: There is some code duplication with
343
 
# _urllib2_wrappers.py.DigestAuthHandler. If that duplication
344
 
# grows, it may require a refactoring. Also, we don't implement
345
 
# SHA algorithm nor MD5-sess here, but that does not seem worth
346
 
# it.
347
 
class DigestAuthServer(AuthServer):
348
 
    """A digest authentication server"""
349
 
 
350
 
    auth_nonce = 'now!'
351
 
 
352
 
    def __init__(self, request_handler, auth_scheme,
353
 
                 protocol_version=None):
354
 
        AuthServer.__init__(self, request_handler, auth_scheme,
355
 
                            protocol_version=protocol_version)
356
 
 
357
 
    def digest_authorized(self, auth, command):
358
 
        nonce = auth['nonce']
359
 
        if nonce != self.auth_nonce:
360
 
            return False
361
 
        realm = auth['realm']
362
 
        if realm != self.auth_realm:
363
 
            return False
364
 
        user = auth['username']
365
 
        if not self.password_of.has_key(user):
366
 
            return False
367
 
        algorithm= auth['algorithm']
368
 
        if algorithm != 'MD5':
369
 
            return False
370
 
        qop = auth['qop']
371
 
        if qop != 'auth':
372
 
            return False
373
 
 
374
 
        password = self.password_of[user]
375
 
 
376
 
        # Recalculate the response_digest to compare with the one
377
 
        # sent by the client
378
 
        A1 = '%s:%s:%s' % (user, realm, password)
379
 
        A2 = '%s:%s' % (command, auth['uri'])
380
 
 
381
 
        H = lambda x: md5.new(x).hexdigest()
382
 
        KD = lambda secret, data: H("%s:%s" % (secret, data))
383
 
 
384
 
        nonce_count = int(auth['nc'], 16)
385
 
 
386
 
        ncvalue = '%08x' % nonce_count
387
 
 
388
 
        cnonce = auth['cnonce']
389
 
        noncebit = '%s:%s:%s:%s:%s' % (nonce, ncvalue, cnonce, qop, H(A2))
390
 
        response_digest = KD(H(A1), noncebit)
391
 
 
392
 
        return response_digest == auth['response']
393
 
 
394
 
class HTTPAuthServer(AuthServer):
395
 
    """An HTTP server requiring authentication"""
396
 
 
397
 
    def init_http_auth(self):
398
 
        self.auth_header_sent = 'WWW-Authenticate'
399
 
        self.auth_header_recv = 'Authorization'
400
 
        self.auth_error_code = 401
401
 
 
402
 
 
403
 
class ProxyAuthServer(AuthServer):
404
 
    """A proxy server requiring authentication"""
405
 
 
406
 
    def init_proxy_auth(self):
407
 
        self.proxy_requests = True
408
 
        self.auth_header_sent = 'Proxy-Authenticate'
409
 
        self.auth_header_recv = 'Proxy-Authorization'
410
 
        self.auth_error_code = 407
411
 
 
412
 
 
413
 
class HTTPBasicAuthServer(HTTPAuthServer):
414
 
    """An HTTP server requiring basic authentication"""
415
 
 
416
 
    def __init__(self, protocol_version=None):
417
 
        HTTPAuthServer.__init__(self, BasicAuthRequestHandler, 'basic',
418
 
                                protocol_version=protocol_version)
419
 
        self.init_http_auth()
420
 
 
421
 
 
422
 
class HTTPDigestAuthServer(DigestAuthServer, HTTPAuthServer):
423
 
    """An HTTP server requiring digest authentication"""
424
 
 
425
 
    def __init__(self, protocol_version=None):
426
 
        DigestAuthServer.__init__(self, DigestAuthRequestHandler, 'digest',
427
 
                                  protocol_version=protocol_version)
428
 
        self.init_http_auth()
429
 
 
430
 
 
431
 
class ProxyBasicAuthServer(ProxyAuthServer):
432
 
    """A proxy server requiring basic authentication"""
433
 
 
434
 
    def __init__(self, protocol_version=None):
435
 
        ProxyAuthServer.__init__(self, BasicAuthRequestHandler, 'basic',
436
 
                                 protocol_version=protocol_version)
437
 
        self.init_proxy_auth()
438
 
 
439
 
 
440
 
class ProxyDigestAuthServer(DigestAuthServer, ProxyAuthServer):
441
 
    """A proxy server requiring basic authentication"""
442
 
 
443
 
    def __init__(self, protocol_version=None):
444
 
        ProxyAuthServer.__init__(self, DigestAuthRequestHandler, 'digest',
445
 
                                 protocol_version=protocol_version)
446
 
        self.init_proxy_auth()
447
 
 
448
 
 
 
91
            remote_path = '/'.join(path_parts)
 
92
 
 
93
        self._http_starting.acquire()
 
94
        self._http_starting.release()
 
95
        return self._http_base_url + remote_path
 
96
 
 
97
    def setUp(self):
 
98
        super(TestCaseWithWebserver, self).setUp()
 
99
        import threading, os
 
100
        self._local_path_parts = self.test_dir.split(os.path.sep)
 
101
        self._http_starting = threading.Lock()
 
102
        self._http_starting.acquire()
 
103
        self._http_running = True
 
104
        self._http_base_url = None
 
105
        self._http_thread = threading.Thread(target=self._http_start)
 
106
        self._http_thread.setDaemon(True)
 
107
        self._http_thread.start()
 
108
 
 
109
    def tearDown(self):
 
110
        self._http_running = False
 
111
        self._http_thread.join()
 
112
        super(TestCaseWithWebserver, self).tearDown()