~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/tests/HttpServer.py

  • Committer: Martin Pool
  • Date: 2005-09-30 05:56:05 UTC
  • mto: (1185.14.2)
  • mto: This revision was merged to the branch mainline in revision 1396.
  • Revision ID: mbp@sourcefrog.net-20050930055605-a2c534529b392a7d
- fix upgrade for transport changes

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2006 Canonical Ltd
2
 
#
3
 
# This program is free software; you can redistribute it and/or modify
4
 
# it under the terms of the GNU General Public License as published by
5
 
# the Free Software Foundation; either version 2 of the License, or
6
 
# (at your option) any later version.
7
 
#
8
 
# This program is distributed in the hope that it will be useful,
9
 
# but WITHOUT ANY WARRANTY; without even the implied warranty of
10
 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11
 
# GNU General Public License for more details.
12
 
#
13
 
# You should have received a copy of the GNU General Public License
14
 
# along with this program; if not, write to the Free Software
15
 
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
16
 
 
17
 
import BaseHTTPServer
18
 
import errno
19
 
import os
20
 
from SimpleHTTPServer import SimpleHTTPRequestHandler
21
 
import socket
22
 
import posixpath
23
 
import random
24
 
import re
25
 
import sys
26
 
import threading
27
 
import time
28
 
import urllib
29
 
import urlparse
30
 
 
31
 
from bzrlib.transport import Server
32
 
 
33
 
 
34
 
class WebserverNotAvailable(Exception):
35
 
    pass
36
 
 
37
 
 
38
 
class BadWebserverPath(ValueError):
39
 
    def __str__(self):
40
 
        return 'path %s is not in %s' % self.args
41
 
 
42
 
 
43
 
class TestingHTTPRequestHandler(SimpleHTTPRequestHandler):
44
 
 
45
 
    def log_message(self, format, *args):
46
 
        self.server.test_case.log('webserver - %s - - [%s] %s "%s" "%s"',
47
 
                                  self.address_string(),
48
 
                                  self.log_date_time_string(),
49
 
                                  format % args,
50
 
                                  self.headers.get('referer', '-'),
51
 
                                  self.headers.get('user-agent', '-'))
52
 
 
53
 
    def handle_one_request(self):
54
 
        """Handle a single HTTP request.
55
 
 
56
 
        You normally don't need to override this method; see the class
57
 
        __doc__ string for information on how to handle specific HTTP
58
 
        commands such as GET and POST.
59
 
 
60
 
        """
61
 
        for i in xrange(1,11): # Don't try more than 10 times
62
 
            try:
63
 
                self.raw_requestline = self.rfile.readline()
64
 
            except socket.error, e:
65
 
                if e.args[0] in (errno.EAGAIN, errno.EWOULDBLOCK):
66
 
                    # omitted for now because some tests look at the log of
67
 
                    # the server and expect to see no errors.  see recent
68
 
                    # email thread. -- mbp 20051021. 
69
 
                    ## self.log_message('EAGAIN (%d) while reading from raw_requestline' % i)
70
 
                    time.sleep(0.01)
71
 
                    continue
72
 
                raise
73
 
            else:
74
 
                break
75
 
        if not self.raw_requestline:
76
 
            self.close_connection = 1
77
 
            return
78
 
        if not self.parse_request(): # An error code has been sent, just exit
79
 
            return
80
 
        mname = 'do_' + self.command
81
 
        if getattr(self, mname, None) is None:
82
 
            self.send_error(501, "Unsupported method (%r)" % self.command)
83
 
            return
84
 
        method = getattr(self, mname)
85
 
        method()
86
 
 
87
 
    _range_regexp = re.compile(r'^(?P<start>\d+)-(?P<end>\d+)$')
88
 
    _tail_regexp = re.compile(r'^-(?P<tail>\d+)$')
89
 
 
90
 
    def parse_ranges(self, ranges_header):
91
 
        """Parse the range header value and returns ranges and tail.
92
 
 
93
 
        RFC2616 14.35 says that syntactically invalid range
94
 
        specifiers MUST be ignored. In that case, we return 0 for
95
 
        tail and [] for ranges.
96
 
        """
97
 
        tail = 0
98
 
        ranges = []
99
 
        if not ranges_header.startswith('bytes='):
100
 
            # Syntactically invalid header
101
 
            return 0, []
102
 
 
103
 
        ranges_header = ranges_header[len('bytes='):]
104
 
        for range_str in ranges_header.split(','):
105
 
            # FIXME: RFC2616 says end is optional and default to file_size
106
 
            range_match = self._range_regexp.match(range_str)
107
 
            if range_match is not None:
108
 
                start = int(range_match.group('start'))
109
 
                end = int(range_match.group('end'))
110
 
                if start > end:
111
 
                    # Syntactically invalid range
112
 
                    return 0, []
113
 
                ranges.append((start, end))
114
 
            else:
115
 
                tail_match = self._tail_regexp.match(range_str)
116
 
                if tail_match is not None:
117
 
                    tail = int(tail_match.group('tail'))
118
 
                else:
119
 
                    # Syntactically invalid range
120
 
                    return 0, []
121
 
        return tail, ranges
122
 
 
123
 
    def send_range_content(self, file, start, length):
124
 
        file.seek(start)
125
 
        self.wfile.write(file.read(length))
126
 
 
127
 
    def get_single_range(self, file, file_size, start, end):
128
 
        self.send_response(206)
129
 
        length = end - start + 1
130
 
        self.send_header('Accept-Ranges', 'bytes')
131
 
        self.send_header("Content-Length", "%d" % length)
132
 
 
133
 
        self.send_header("Content-Type", 'application/octet-stream')
134
 
        self.send_header("Content-Range", "bytes %d-%d/%d" % (start,
135
 
                                                              end,
136
 
                                                              file_size))
137
 
        self.end_headers()
138
 
        self.send_range_content(file, start, length)
139
 
 
140
 
    def get_multiple_ranges(self, file, file_size, ranges):
141
 
        self.send_response(206)
142
 
        self.send_header('Accept-Ranges', 'bytes')
143
 
        boundary = "%d" % random.randint(0,0x7FFFFFFF)
144
 
        self.send_header("Content-Type",
145
 
                         "multipart/byteranges; boundary=%s" % boundary)
146
 
        self.end_headers()
147
 
        for (start, end) in ranges:
148
 
            self.wfile.write("--%s\r\n" % boundary)
149
 
            self.send_header("Content-type", 'application/octet-stream')
150
 
            self.send_header("Content-Range", "bytes %d-%d/%d" % (start,
151
 
                                                                  end,
152
 
                                                                  file_size))
153
 
            self.end_headers()
154
 
            self.send_range_content(file, start, end - start + 1)
155
 
            self.wfile.write("--%s\r\n" % boundary)
156
 
            pass
157
 
 
158
 
    def do_GET(self):
159
 
        """Serve a GET request.
160
 
 
161
 
        Handles the Range header.
162
 
        """
163
 
 
164
 
        path = self.translate_path(self.path)
165
 
        ranges_header_value = self.headers.get('Range')
166
 
        if ranges_header_value is None or os.path.isdir(path):
167
 
            # Let the mother class handle most cases
168
 
            return SimpleHTTPRequestHandler.do_GET(self)
169
 
 
170
 
        try:
171
 
            # Always read in binary mode. Opening files in text
172
 
            # mode may cause newline translations, making the
173
 
            # actual size of the content transmitted *less* than
174
 
            # the content-length!
175
 
            file = open(path, 'rb')
176
 
        except IOError:
177
 
            self.send_error(404, "File not found")
178
 
            return
179
 
 
180
 
        file_size = os.fstat(file.fileno())[6]
181
 
        tail, ranges = self.parse_ranges(ranges_header_value)
182
 
        # Normalize tail into ranges
183
 
        if tail != 0:
184
 
            ranges.append((file_size - tail, file_size))
185
 
 
186
 
        self._satisfiable_ranges = True
187
 
        if len(ranges) == 0:
188
 
            self._satisfiable_ranges = False
189
 
        else:
190
 
            def check_range(range_specifier):
191
 
                start, end = range_specifier
192
 
                # RFC2616 14.35, ranges are invalid if start >= file_size
193
 
                if start >= file_size:
194
 
                    self._satisfiable_ranges = False # Side-effect !
195
 
                    return 0, 0
196
 
                # RFC2616 14.35, end values should be truncated
197
 
                # to file_size -1 if they exceed it
198
 
                end = min(end, file_size - 1)
199
 
                return start, end
200
 
 
201
 
            ranges = map(check_range, ranges)
202
 
 
203
 
        if not self._satisfiable_ranges:
204
 
            # RFC2616 14.16 and 14.35 says that when a server
205
 
            # encounters unsatisfiable range specifiers, it
206
 
            # SHOULD return a 416.
207
 
            file.close()
208
 
            # FIXME: We SHOULD send a Content-Range header too,
209
 
            # but the implementation of send_error does not
210
 
            # allows that. So far.
211
 
            self.send_error(416, "Requested range not satisfiable")
212
 
            return
213
 
 
214
 
        if len(ranges) == 1:
215
 
            (start, end) = ranges[0]
216
 
            self.get_single_range(file, file_size, start, end)
217
 
        else:
218
 
            self.get_multiple_ranges(file, file_size, ranges)
219
 
        file.close()
220
 
 
221
 
    if sys.platform == 'win32':
222
 
        # On win32 you cannot access non-ascii filenames without
223
 
        # decoding them into unicode first.
224
 
        # However, under Linux, you can access bytestream paths
225
 
        # without any problems. If this function was always active
226
 
        # it would probably break tests when LANG=C was set
227
 
        def translate_path(self, path):
228
 
            """Translate a /-separated PATH to the local filename syntax.
229
 
 
230
 
            For bzr, all url paths are considered to be utf8 paths.
231
 
            On Linux, you can access these paths directly over the bytestream
232
 
            request, but on win32, you must decode them, and access them
233
 
            as Unicode files.
234
 
            """
235
 
            # abandon query parameters
236
 
            path = urlparse.urlparse(path)[2]
237
 
            path = posixpath.normpath(urllib.unquote(path))
238
 
            path = path.decode('utf-8')
239
 
            words = path.split('/')
240
 
            words = filter(None, words)
241
 
            path = os.getcwdu()
242
 
            for word in words:
243
 
                drive, word = os.path.splitdrive(word)
244
 
                head, word = os.path.split(word)
245
 
                if word in (os.curdir, os.pardir): continue
246
 
                path = os.path.join(path, word)
247
 
            return path
248
 
 
249
 
 
250
 
class TestingHTTPServer(BaseHTTPServer.HTTPServer):
251
 
    def __init__(self, server_address, RequestHandlerClass, test_case):
252
 
        BaseHTTPServer.HTTPServer.__init__(self, server_address,
253
 
                                                RequestHandlerClass)
254
 
        self.test_case = test_case
255
 
 
256
 
 
257
 
class HttpServer(Server):
258
 
    """A test server for http transports.
259
 
 
260
 
    Subclasses can provide a specific request handler.
261
 
    """
262
 
 
263
 
    # used to form the url that connects to this server
264
 
    _url_protocol = 'http'
265
 
 
266
 
    # Subclasses can provide a specific request handler
267
 
    def __init__(self, request_handler=TestingHTTPRequestHandler):
268
 
        Server.__init__(self)
269
 
        self.request_handler = request_handler
270
 
 
271
 
    def _get_httpd(self):
272
 
        return TestingHTTPServer(('localhost', 0),
273
 
                                  self.request_handler,
274
 
                                  self)
275
 
 
276
 
    def _http_start(self):
277
 
        httpd = None
278
 
        httpd = self._get_httpd()
279
 
        host, self.port = httpd.socket.getsockname()
280
 
        self._http_base_url = '%s://localhost:%s/' % (self._url_protocol,
281
 
                                                      self.port)
282
 
        self._http_starting.release()
283
 
        httpd.socket.settimeout(0.1)
284
 
 
285
 
        while self._http_running:
286
 
            try:
287
 
                httpd.handle_request()
288
 
            except socket.timeout:
289
 
                pass
290
 
 
291
 
    def _get_remote_url(self, path):
292
 
        path_parts = path.split(os.path.sep)
293
 
        if os.path.isabs(path):
294
 
            if path_parts[:len(self._local_path_parts)] != \
295
 
                   self._local_path_parts:
296
 
                raise BadWebserverPath(path, self.test_dir)
297
 
            remote_path = '/'.join(path_parts[len(self._local_path_parts):])
298
 
        else:
299
 
            remote_path = '/'.join(path_parts)
300
 
 
301
 
        return self._http_base_url + remote_path
302
 
 
303
 
    def log(self, format, *args):
304
 
        """Capture Server log output."""
305
 
        self.logs.append(format % args)
306
 
 
307
 
    def setUp(self):
308
 
        """See bzrlib.transport.Server.setUp."""
309
 
        self._home_dir = os.getcwdu()
310
 
        self._local_path_parts = self._home_dir.split(os.path.sep)
311
 
        self._http_starting = threading.Lock()
312
 
        self._http_starting.acquire()
313
 
        self._http_running = True
314
 
        self._http_base_url = None
315
 
        self._http_thread = threading.Thread(target=self._http_start)
316
 
        self._http_thread.setDaemon(True)
317
 
        self._http_thread.start()
318
 
        # Wait for the server thread to start (i.e release the lock)
319
 
        self._http_starting.acquire()
320
 
        self._http_starting.release()
321
 
        self.logs = []
322
 
 
323
 
    def tearDown(self):
324
 
        """See bzrlib.transport.Server.tearDown."""
325
 
        self._http_running = False
326
 
        self._http_thread.join()
327
 
 
328
 
    def get_url(self):
329
 
        """See bzrlib.transport.Server.get_url."""
330
 
        return self._get_remote_url(self._home_dir)
331
 
 
332
 
    def get_bogus_url(self):
333
 
        """See bzrlib.transport.Server.get_bogus_url."""
334
 
        # this is chosen to try to prevent trouble with proxies, weird dns,
335
 
        # etc
336
 
        return 'http://127.0.0.1:1/'
337
 
 
338
 
 
339
 
class HttpServer_urllib(HttpServer):
340
 
    """Subclass of HttpServer that gives http+urllib urls.
341
 
 
342
 
    This is for use in testing: connections to this server will always go
343
 
    through urllib where possible.
344
 
    """
345
 
 
346
 
    # urls returned by this server should require the urllib client impl
347
 
    _url_protocol = 'http+urllib'
348
 
 
349
 
 
350
 
class HttpServer_PyCurl(HttpServer):
351
 
    """Subclass of HttpServer that gives http+pycurl urls.
352
 
 
353
 
    This is for use in testing: connections to this server will always go
354
 
    through pycurl where possible.
355
 
    """
356
 
 
357
 
    # We don't care about checking the pycurl availability as
358
 
    # this server will be required only when pycurl is present
359
 
 
360
 
    # urls returned by this server should require the pycurl client impl
361
 
    _url_protocol = 'http+pycurl'