~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/tests/test_http.py

  • Committer: Canonical.com Patch Queue Manager
  • Date: 2008-07-17 08:29:40 UTC
  • mfrom: (3549.1.4 stacking)
  • Revision ID: pqm@pqm.ubuntu.com-20080717082940-zdwz5cqhdoot1swx
(mbp) stacking post-merge review tweaks

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# Copyright (C) 2005, 2006 Canonical Ltd
 
2
#
 
3
# This program is free software; you can redistribute it and/or modify
 
4
# it under the terms of the GNU General Public License as published by
 
5
# the Free Software Foundation; either version 2 of the License, or
 
6
# (at your option) any later version.
 
7
#
 
8
# This program is distributed in the hope that it will be useful,
 
9
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
10
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
11
# GNU General Public License for more details.
 
12
#
 
13
# You should have received a copy of the GNU General Public License
 
14
# along with this program; if not, write to the Free Software
 
15
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 
16
 
 
17
"""Tests for HTTP implementations.
 
18
 
 
19
This module defines a load_tests() method that parametrize tests classes for
 
20
transport implementation, http protocol versions and authentication schemes.
 
21
"""
 
22
 
 
23
# TODO: Should be renamed to bzrlib.transport.http.tests?
 
24
# TODO: What about renaming to bzrlib.tests.transport.http ?
 
25
 
 
26
from cStringIO import StringIO
 
27
import httplib
 
28
import os
 
29
import select
 
30
import SimpleHTTPServer
 
31
import socket
 
32
import sys
 
33
import threading
 
34
 
 
35
import bzrlib
 
36
from bzrlib import (
 
37
    config,
 
38
    errors,
 
39
    osutils,
 
40
    tests,
 
41
    transport,
 
42
    ui,
 
43
    urlutils,
 
44
    )
 
45
from bzrlib.tests import (
 
46
    http_server,
 
47
    http_utils,
 
48
    )
 
49
from bzrlib.transport import (
 
50
    http,
 
51
    remote,
 
52
    )
 
53
from bzrlib.transport.http import (
 
54
    _urllib,
 
55
    _urllib2_wrappers,
 
56
    )
 
57
 
 
58
 
 
59
try:
 
60
    from bzrlib.transport.http._pycurl import PyCurlTransport
 
61
    pycurl_present = True
 
62
except errors.DependencyNotPresent:
 
63
    pycurl_present = False
 
64
 
 
65
 
 
66
class TransportAdapter(tests.TestScenarioApplier):
 
67
    """Generate the same test for each transport implementation."""
 
68
 
 
69
    def __init__(self):
 
70
        transport_scenarios = [
 
71
            ('urllib', dict(_transport=_urllib.HttpTransport_urllib,
 
72
                            _server=http_server.HttpServer_urllib,
 
73
                            _qualified_prefix='http+urllib',)),
 
74
            ]
 
75
        if pycurl_present:
 
76
            transport_scenarios.append(
 
77
                ('pycurl', dict(_transport=PyCurlTransport,
 
78
                                _server=http_server.HttpServer_PyCurl,
 
79
                                _qualified_prefix='http+pycurl',)))
 
80
        self.scenarios = transport_scenarios
 
81
 
 
82
 
 
83
class TransportProtocolAdapter(TransportAdapter):
 
84
    """Generate the same test for each protocol implementation.
 
85
 
 
86
    In addition to the transport adaptatation that we inherit from.
 
87
    """
 
88
 
 
89
    def __init__(self):
 
90
        super(TransportProtocolAdapter, self).__init__()
 
91
        protocol_scenarios = [
 
92
            ('HTTP/1.0',  dict(_protocol_version='HTTP/1.0')),
 
93
            ('HTTP/1.1',  dict(_protocol_version='HTTP/1.1')),
 
94
            ]
 
95
        self.scenarios = tests.multiply_scenarios(self.scenarios,
 
96
                                                  protocol_scenarios)
 
97
 
 
98
 
 
99
class TransportProtocolAuthenticationAdapter(TransportProtocolAdapter):
 
100
    """Generate the same test for each authentication scheme implementation.
 
101
 
 
102
    In addition to the protocol adaptatation that we inherit from.
 
103
    """
 
104
 
 
105
    def __init__(self):
 
106
        super(TransportProtocolAuthenticationAdapter, self).__init__()
 
107
        auth_scheme_scenarios = [
 
108
            ('basic', dict(_auth_scheme='basic')),
 
109
            ('digest', dict(_auth_scheme='digest')),
 
110
            ]
 
111
 
 
112
        self.scenarios = tests.multiply_scenarios(self.scenarios,
 
113
                                                  auth_scheme_scenarios)
 
114
 
 
115
def load_tests(standard_tests, module, loader):
 
116
    """Multiply tests for http clients and protocol versions."""
 
117
    # one for each transport
 
118
    t_adapter = TransportAdapter()
 
119
    t_classes= (TestHttpTransportRegistration,
 
120
                TestHttpTransportUrls,
 
121
                )
 
122
    is_testing_for_transports = tests.condition_isinstance(t_classes)
 
123
 
 
124
    # multiplied by one for each protocol version
 
125
    tp_adapter = TransportProtocolAdapter()
 
126
    tp_classes= (SmartHTTPTunnellingTest,
 
127
                 TestDoCatchRedirections,
 
128
                 TestHTTPConnections,
 
129
                 TestHTTPRedirections,
 
130
                 TestHTTPSilentRedirections,
 
131
                 TestLimitedRangeRequestServer,
 
132
                 TestPost,
 
133
                 TestProxyHttpServer,
 
134
                 TestRanges,
 
135
                 TestSpecificRequestHandler,
 
136
                 )
 
137
    is_also_testing_for_protocols = tests.condition_isinstance(tp_classes)
 
138
 
 
139
    # multiplied by one for each authentication scheme
 
140
    tpa_adapter = TransportProtocolAuthenticationAdapter()
 
141
    tpa_classes = (TestAuth,
 
142
                   )
 
143
    is_also_testing_for_authentication = tests.condition_isinstance(
 
144
        tpa_classes)
 
145
 
 
146
    result = loader.suiteClass()
 
147
    for test_class in tests.iter_suite_tests(standard_tests):
 
148
        # Each test class is either standalone or testing for some combination
 
149
        # of transport, protocol version, authentication scheme. Use the right
 
150
        # adpater (or none) depending on the class.
 
151
        if is_testing_for_transports(test_class):
 
152
            result.addTests(t_adapter.adapt(test_class))
 
153
        elif is_also_testing_for_protocols(test_class):
 
154
            result.addTests(tp_adapter.adapt(test_class))
 
155
        elif is_also_testing_for_authentication(test_class):
 
156
            result.addTests(tpa_adapter.adapt(test_class))
 
157
        else:
 
158
            result.addTest(test_class)
 
159
    return result
 
160
 
 
161
 
 
162
class FakeManager(object):
 
163
 
 
164
    def __init__(self):
 
165
        self.credentials = []
 
166
 
 
167
    def add_password(self, realm, host, username, password):
 
168
        self.credentials.append([realm, host, username, password])
 
169
 
 
170
 
 
171
class RecordingServer(object):
 
172
    """A fake HTTP server.
 
173
    
 
174
    It records the bytes sent to it, and replies with a 200.
 
175
    """
 
176
 
 
177
    def __init__(self, expect_body_tail=None):
 
178
        """Constructor.
 
179
 
 
180
        :type expect_body_tail: str
 
181
        :param expect_body_tail: a reply won't be sent until this string is
 
182
            received.
 
183
        """
 
184
        self._expect_body_tail = expect_body_tail
 
185
        self.host = None
 
186
        self.port = None
 
187
        self.received_bytes = ''
 
188
 
 
189
    def setUp(self):
 
190
        self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 
191
        self._sock.bind(('127.0.0.1', 0))
 
192
        self.host, self.port = self._sock.getsockname()
 
193
        self._ready = threading.Event()
 
194
        self._thread = threading.Thread(target=self._accept_read_and_reply)
 
195
        self._thread.setDaemon(True)
 
196
        self._thread.start()
 
197
        self._ready.wait(5)
 
198
 
 
199
    def _accept_read_and_reply(self):
 
200
        self._sock.listen(1)
 
201
        self._ready.set()
 
202
        self._sock.settimeout(5)
 
203
        try:
 
204
            conn, address = self._sock.accept()
 
205
            # On win32, the accepted connection will be non-blocking to start
 
206
            # with because we're using settimeout.
 
207
            conn.setblocking(True)
 
208
            while not self.received_bytes.endswith(self._expect_body_tail):
 
209
                self.received_bytes += conn.recv(4096)
 
210
            conn.sendall('HTTP/1.1 200 OK\r\n')
 
211
        except socket.timeout:
 
212
            # Make sure the client isn't stuck waiting for us to e.g. accept.
 
213
            self._sock.close()
 
214
        except socket.error:
 
215
            # The client may have already closed the socket.
 
216
            pass
 
217
 
 
218
    def tearDown(self):
 
219
        try:
 
220
            self._sock.close()
 
221
        except socket.error:
 
222
            # We might have already closed it.  We don't care.
 
223
            pass
 
224
        self.host = None
 
225
        self.port = None
 
226
 
 
227
 
 
228
class TestHTTPServer(tests.TestCase):
 
229
    """Test the HTTP servers implementations."""
 
230
 
 
231
    def test_invalid_protocol(self):
 
232
        class BogusRequestHandler(http_server.TestingHTTPRequestHandler):
 
233
 
 
234
            protocol_version = 'HTTP/0.1'
 
235
 
 
236
        server = http_server.HttpServer(BogusRequestHandler)
 
237
        try:
 
238
            self.assertRaises(httplib.UnknownProtocol,server.setUp)
 
239
        except:
 
240
            server.tearDown()
 
241
            self.fail('HTTP Server creation did not raise UnknownProtocol')
 
242
 
 
243
    def test_force_invalid_protocol(self):
 
244
        server = http_server.HttpServer(protocol_version='HTTP/0.1')
 
245
        try:
 
246
            self.assertRaises(httplib.UnknownProtocol,server.setUp)
 
247
        except:
 
248
            server.tearDown()
 
249
            self.fail('HTTP Server creation did not raise UnknownProtocol')
 
250
 
 
251
    def test_server_start_and_stop(self):
 
252
        server = http_server.HttpServer()
 
253
        server.setUp()
 
254
        self.assertTrue(server._http_running)
 
255
        server.tearDown()
 
256
        self.assertFalse(server._http_running)
 
257
 
 
258
    def test_create_http_server_one_zero(self):
 
259
        class RequestHandlerOneZero(http_server.TestingHTTPRequestHandler):
 
260
 
 
261
            protocol_version = 'HTTP/1.0'
 
262
 
 
263
        server = http_server.HttpServer(RequestHandlerOneZero)
 
264
        server.setUp()
 
265
        self.addCleanup(server.tearDown)
 
266
        self.assertIsInstance(server._httpd, http_server.TestingHTTPServer)
 
267
 
 
268
    def test_create_http_server_one_one(self):
 
269
        class RequestHandlerOneOne(http_server.TestingHTTPRequestHandler):
 
270
 
 
271
            protocol_version = 'HTTP/1.1'
 
272
 
 
273
        server = http_server.HttpServer(RequestHandlerOneOne)
 
274
        server.setUp()
 
275
        self.addCleanup(server.tearDown)
 
276
        self.assertIsInstance(server._httpd,
 
277
                              http_server.TestingThreadingHTTPServer)
 
278
 
 
279
    def test_create_http_server_force_one_one(self):
 
280
        class RequestHandlerOneZero(http_server.TestingHTTPRequestHandler):
 
281
 
 
282
            protocol_version = 'HTTP/1.0'
 
283
 
 
284
        server = http_server.HttpServer(RequestHandlerOneZero,
 
285
                                        protocol_version='HTTP/1.1')
 
286
        server.setUp()
 
287
        self.addCleanup(server.tearDown)
 
288
        self.assertIsInstance(server._httpd,
 
289
                              http_server.TestingThreadingHTTPServer)
 
290
 
 
291
    def test_create_http_server_force_one_zero(self):
 
292
        class RequestHandlerOneOne(http_server.TestingHTTPRequestHandler):
 
293
 
 
294
            protocol_version = 'HTTP/1.1'
 
295
 
 
296
        server = http_server.HttpServer(RequestHandlerOneOne,
 
297
                                        protocol_version='HTTP/1.0')
 
298
        server.setUp()
 
299
        self.addCleanup(server.tearDown)
 
300
        self.assertIsInstance(server._httpd,
 
301
                              http_server.TestingHTTPServer)
 
302
 
 
303
 
 
304
class TestWithTransport_pycurl(object):
 
305
    """Test case to inherit from if pycurl is present"""
 
306
 
 
307
    def _get_pycurl_maybe(self):
 
308
        try:
 
309
            from bzrlib.transport.http._pycurl import PyCurlTransport
 
310
            return PyCurlTransport
 
311
        except errors.DependencyNotPresent:
 
312
            raise tests.TestSkipped('pycurl not present')
 
313
 
 
314
    _transport = property(_get_pycurl_maybe)
 
315
 
 
316
 
 
317
class TestHttpUrls(tests.TestCase):
 
318
 
 
319
    # TODO: This should be moved to authorization tests once they
 
320
    # are written.
 
321
 
 
322
    def test_url_parsing(self):
 
323
        f = FakeManager()
 
324
        url = http.extract_auth('http://example.com', f)
 
325
        self.assertEquals('http://example.com', url)
 
326
        self.assertEquals(0, len(f.credentials))
 
327
        url = http.extract_auth(
 
328
            'http://user:pass@www.bazaar-vcs.org/bzr/bzr.dev', f)
 
329
        self.assertEquals('http://www.bazaar-vcs.org/bzr/bzr.dev', url)
 
330
        self.assertEquals(1, len(f.credentials))
 
331
        self.assertEquals([None, 'www.bazaar-vcs.org', 'user', 'pass'],
 
332
                          f.credentials[0])
 
333
 
 
334
 
 
335
class TestHttpTransportUrls(tests.TestCase):
 
336
    """Test the http urls."""
 
337
 
 
338
    def test_abs_url(self):
 
339
        """Construction of absolute http URLs"""
 
340
        t = self._transport('http://bazaar-vcs.org/bzr/bzr.dev/')
 
341
        eq = self.assertEqualDiff
 
342
        eq(t.abspath('.'), 'http://bazaar-vcs.org/bzr/bzr.dev')
 
343
        eq(t.abspath('foo/bar'), 'http://bazaar-vcs.org/bzr/bzr.dev/foo/bar')
 
344
        eq(t.abspath('.bzr'), 'http://bazaar-vcs.org/bzr/bzr.dev/.bzr')
 
345
        eq(t.abspath('.bzr/1//2/./3'),
 
346
           'http://bazaar-vcs.org/bzr/bzr.dev/.bzr/1/2/3')
 
347
 
 
348
    def test_invalid_http_urls(self):
 
349
        """Trap invalid construction of urls"""
 
350
        t = self._transport('http://bazaar-vcs.org/bzr/bzr.dev/')
 
351
        self.assertRaises(errors.InvalidURL,
 
352
                          self._transport,
 
353
                          'http://http://bazaar-vcs.org/bzr/bzr.dev/')
 
354
 
 
355
    def test_http_root_urls(self):
 
356
        """Construction of URLs from server root"""
 
357
        t = self._transport('http://bzr.ozlabs.org/')
 
358
        eq = self.assertEqualDiff
 
359
        eq(t.abspath('.bzr/tree-version'),
 
360
           'http://bzr.ozlabs.org/.bzr/tree-version')
 
361
 
 
362
    def test_http_impl_urls(self):
 
363
        """There are servers which ask for particular clients to connect"""
 
364
        server = self._server()
 
365
        try:
 
366
            server.setUp()
 
367
            url = server.get_url()
 
368
            self.assertTrue(url.startswith('%s://' % self._qualified_prefix))
 
369
        finally:
 
370
            server.tearDown()
 
371
 
 
372
 
 
373
class TestHttps_pycurl(TestWithTransport_pycurl, tests.TestCase):
 
374
 
 
375
    # TODO: This should really be moved into another pycurl
 
376
    # specific test. When https tests will be implemented, take
 
377
    # this one into account.
 
378
    def test_pycurl_without_https_support(self):
 
379
        """Test that pycurl without SSL do not fail with a traceback.
 
380
 
 
381
        For the purpose of the test, we force pycurl to ignore
 
382
        https by supplying a fake version_info that do not
 
383
        support it.
 
384
        """
 
385
        try:
 
386
            import pycurl
 
387
        except ImportError:
 
388
            raise tests.TestSkipped('pycurl not present')
 
389
 
 
390
        version_info_orig = pycurl.version_info
 
391
        try:
 
392
            # Now that we have pycurl imported, we can fake its version_info
 
393
            # This was taken from a windows pycurl without SSL
 
394
            # (thanks to bialix)
 
395
            pycurl.version_info = lambda : (2,
 
396
                                            '7.13.2',
 
397
                                            462082,
 
398
                                            'i386-pc-win32',
 
399
                                            2576,
 
400
                                            None,
 
401
                                            0,
 
402
                                            None,
 
403
                                            ('ftp', 'gopher', 'telnet',
 
404
                                             'dict', 'ldap', 'http', 'file'),
 
405
                                            None,
 
406
                                            0,
 
407
                                            None)
 
408
            self.assertRaises(errors.DependencyNotPresent, self._transport,
 
409
                              'https://launchpad.net')
 
410
        finally:
 
411
            # Restore the right function
 
412
            pycurl.version_info = version_info_orig
 
413
 
 
414
 
 
415
class TestHTTPConnections(http_utils.TestCaseWithWebserver):
 
416
    """Test the http connections."""
 
417
 
 
418
    def setUp(self):
 
419
        http_utils.TestCaseWithWebserver.setUp(self)
 
420
        self.build_tree(['foo/', 'foo/bar'], line_endings='binary',
 
421
                        transport=self.get_transport())
 
422
 
 
423
    def test_http_has(self):
 
424
        server = self.get_readonly_server()
 
425
        t = self._transport(server.get_url())
 
426
        self.assertEqual(t.has('foo/bar'), True)
 
427
        self.assertEqual(len(server.logs), 1)
 
428
        self.assertContainsRe(server.logs[0],
 
429
            r'"HEAD /foo/bar HTTP/1.." (200|302) - "-" "bzr/')
 
430
 
 
431
    def test_http_has_not_found(self):
 
432
        server = self.get_readonly_server()
 
433
        t = self._transport(server.get_url())
 
434
        self.assertEqual(t.has('not-found'), False)
 
435
        self.assertContainsRe(server.logs[1],
 
436
            r'"HEAD /not-found HTTP/1.." 404 - "-" "bzr/')
 
437
 
 
438
    def test_http_get(self):
 
439
        server = self.get_readonly_server()
 
440
        t = self._transport(server.get_url())
 
441
        fp = t.get('foo/bar')
 
442
        self.assertEqualDiff(
 
443
            fp.read(),
 
444
            'contents of foo/bar\n')
 
445
        self.assertEqual(len(server.logs), 1)
 
446
        self.assertTrue(server.logs[0].find(
 
447
            '"GET /foo/bar HTTP/1.1" 200 - "-" "bzr/%s'
 
448
            % bzrlib.__version__) > -1)
 
449
 
 
450
    def test_get_smart_medium(self):
 
451
        # For HTTP, get_smart_medium should return the transport object.
 
452
        server = self.get_readonly_server()
 
453
        http_transport = self._transport(server.get_url())
 
454
        medium = http_transport.get_smart_medium()
 
455
        self.assertIs(medium, http_transport)
 
456
 
 
457
    def test_has_on_bogus_host(self):
 
458
        # Get a free address and don't 'accept' on it, so that we
 
459
        # can be sure there is no http handler there, but set a
 
460
        # reasonable timeout to not slow down tests too much.
 
461
        default_timeout = socket.getdefaulttimeout()
 
462
        try:
 
463
            socket.setdefaulttimeout(2)
 
464
            s = socket.socket()
 
465
            s.bind(('localhost', 0))
 
466
            t = self._transport('http://%s:%s/' % s.getsockname())
 
467
            self.assertRaises(errors.ConnectionError, t.has, 'foo/bar')
 
468
        finally:
 
469
            socket.setdefaulttimeout(default_timeout)
 
470
 
 
471
 
 
472
class TestHttpTransportRegistration(tests.TestCase):
 
473
    """Test registrations of various http implementations"""
 
474
 
 
475
    def test_http_registered(self):
 
476
        t = transport.get_transport('%s://foo.com/' % self._qualified_prefix)
 
477
        self.assertIsInstance(t, transport.Transport)
 
478
        self.assertIsInstance(t, self._transport)
 
479
 
 
480
 
 
481
class TestPost(tests.TestCase):
 
482
 
 
483
    def test_post_body_is_received(self):
 
484
        server = RecordingServer(expect_body_tail='end-of-body')
 
485
        server.setUp()
 
486
        self.addCleanup(server.tearDown)
 
487
        scheme = self._qualified_prefix
 
488
        url = '%s://%s:%s/' % (scheme, server.host, server.port)
 
489
        http_transport = self._transport(url)
 
490
        code, response = http_transport._post('abc def end-of-body')
 
491
        self.assertTrue(
 
492
            server.received_bytes.startswith('POST /.bzr/smart HTTP/1.'))
 
493
        self.assertTrue('content-length: 19\r' in server.received_bytes.lower())
 
494
        # The transport should not be assuming that the server can accept
 
495
        # chunked encoding the first time it connects, because HTTP/1.1, so we
 
496
        # check for the literal string.
 
497
        self.assertTrue(
 
498
            server.received_bytes.endswith('\r\n\r\nabc def end-of-body'))
 
499
 
 
500
 
 
501
class TestRangeHeader(tests.TestCase):
 
502
    """Test range_header method"""
 
503
 
 
504
    def check_header(self, value, ranges=[], tail=0):
 
505
        offsets = [ (start, end - start + 1) for start, end in ranges]
 
506
        coalesce = transport.Transport._coalesce_offsets
 
507
        coalesced = list(coalesce(offsets, limit=0, fudge_factor=0))
 
508
        range_header = http.HttpTransportBase._range_header
 
509
        self.assertEqual(value, range_header(coalesced, tail))
 
510
 
 
511
    def test_range_header_single(self):
 
512
        self.check_header('0-9', ranges=[(0,9)])
 
513
        self.check_header('100-109', ranges=[(100,109)])
 
514
 
 
515
    def test_range_header_tail(self):
 
516
        self.check_header('-10', tail=10)
 
517
        self.check_header('-50', tail=50)
 
518
 
 
519
    def test_range_header_multi(self):
 
520
        self.check_header('0-9,100-200,300-5000',
 
521
                          ranges=[(0,9), (100, 200), (300,5000)])
 
522
 
 
523
    def test_range_header_mixed(self):
 
524
        self.check_header('0-9,300-5000,-50',
 
525
                          ranges=[(0,9), (300,5000)],
 
526
                          tail=50)
 
527
 
 
528
 
 
529
class TestSpecificRequestHandler(http_utils.TestCaseWithWebserver):
 
530
    """Tests a specific request handler.
 
531
 
 
532
    Daughter classes are expected to override _req_handler_class
 
533
    """
 
534
 
 
535
    # Provide a useful default
 
536
    _req_handler_class = http_server.TestingHTTPRequestHandler
 
537
 
 
538
    def create_transport_readonly_server(self):
 
539
        return http_server.HttpServer(self._req_handler_class,
 
540
                                      protocol_version=self._protocol_version)
 
541
 
 
542
    def _testing_pycurl(self):
 
543
        return pycurl_present and self._transport == PyCurlTransport
 
544
 
 
545
 
 
546
class WallRequestHandler(http_server.TestingHTTPRequestHandler):
 
547
    """Whatever request comes in, close the connection"""
 
548
 
 
549
    def handle_one_request(self):
 
550
        """Handle a single HTTP request, by abruptly closing the connection"""
 
551
        self.close_connection = 1
 
552
 
 
553
 
 
554
class TestWallServer(TestSpecificRequestHandler):
 
555
    """Tests exceptions during the connection phase"""
 
556
 
 
557
    _req_handler_class = WallRequestHandler
 
558
 
 
559
    def test_http_has(self):
 
560
        server = self.get_readonly_server()
 
561
        t = self._transport(server.get_url())
 
562
        # Unfortunately httplib (see HTTPResponse._read_status
 
563
        # for details) make no distinction between a closed
 
564
        # socket and badly formatted status line, so we can't
 
565
        # just test for ConnectionError, we have to test
 
566
        # InvalidHttpResponse too.
 
567
        self.assertRaises((errors.ConnectionError, errors.InvalidHttpResponse),
 
568
                          t.has, 'foo/bar')
 
569
 
 
570
    def test_http_get(self):
 
571
        server = self.get_readonly_server()
 
572
        t = self._transport(server.get_url())
 
573
        self.assertRaises((errors.ConnectionError, errors.InvalidHttpResponse),
 
574
                          t.get, 'foo/bar')
 
575
 
 
576
 
 
577
class BadStatusRequestHandler(http_server.TestingHTTPRequestHandler):
 
578
    """Whatever request comes in, returns a bad status"""
 
579
 
 
580
    def parse_request(self):
 
581
        """Fakes handling a single HTTP request, returns a bad status"""
 
582
        ignored = http_server.TestingHTTPRequestHandler.parse_request(self)
 
583
        self.send_response(0, "Bad status")
 
584
        self.close_connection = 1
 
585
        return False
 
586
 
 
587
 
 
588
class TestBadStatusServer(TestSpecificRequestHandler):
 
589
    """Tests bad status from server."""
 
590
 
 
591
    _req_handler_class = BadStatusRequestHandler
 
592
 
 
593
    def test_http_has(self):
 
594
        server = self.get_readonly_server()
 
595
        t = self._transport(server.get_url())
 
596
        self.assertRaises(errors.InvalidHttpResponse, t.has, 'foo/bar')
 
597
 
 
598
    def test_http_get(self):
 
599
        server = self.get_readonly_server()
 
600
        t = self._transport(server.get_url())
 
601
        self.assertRaises(errors.InvalidHttpResponse, t.get, 'foo/bar')
 
602
 
 
603
 
 
604
class InvalidStatusRequestHandler(http_server.TestingHTTPRequestHandler):
 
605
    """Whatever request comes in, returns an invalid status"""
 
606
 
 
607
    def parse_request(self):
 
608
        """Fakes handling a single HTTP request, returns a bad status"""
 
609
        ignored = http_server.TestingHTTPRequestHandler.parse_request(self)
 
610
        self.wfile.write("Invalid status line\r\n")
 
611
        return False
 
612
 
 
613
 
 
614
class TestInvalidStatusServer(TestBadStatusServer):
 
615
    """Tests invalid status from server.
 
616
 
 
617
    Both implementations raises the same error as for a bad status.
 
618
    """
 
619
 
 
620
    _req_handler_class = InvalidStatusRequestHandler
 
621
 
 
622
    def test_http_has(self):
 
623
        if self._testing_pycurl() and self._protocol_version == 'HTTP/1.1':
 
624
            raise tests.KnownFailure(
 
625
                'pycurl hangs if the server send back garbage')
 
626
        super(TestInvalidStatusServer, self).test_http_has()
 
627
 
 
628
    def test_http_get(self):
 
629
        if self._testing_pycurl() and self._protocol_version == 'HTTP/1.1':
 
630
            raise tests.KnownFailure(
 
631
                'pycurl hangs if the server send back garbage')
 
632
        super(TestInvalidStatusServer, self).test_http_get()
 
633
 
 
634
 
 
635
class BadProtocolRequestHandler(http_server.TestingHTTPRequestHandler):
 
636
    """Whatever request comes in, returns a bad protocol version"""
 
637
 
 
638
    def parse_request(self):
 
639
        """Fakes handling a single HTTP request, returns a bad status"""
 
640
        ignored = http_server.TestingHTTPRequestHandler.parse_request(self)
 
641
        # Returns an invalid protocol version, but curl just
 
642
        # ignores it and those cannot be tested.
 
643
        self.wfile.write("%s %d %s\r\n" % ('HTTP/0.0',
 
644
                                           404,
 
645
                                           'Look at my protocol version'))
 
646
        return False
 
647
 
 
648
 
 
649
class TestBadProtocolServer(TestSpecificRequestHandler):
 
650
    """Tests bad protocol from server."""
 
651
 
 
652
    _req_handler_class = BadProtocolRequestHandler
 
653
 
 
654
    def setUp(self):
 
655
        if pycurl_present and self._transport == PyCurlTransport:
 
656
            raise tests.TestNotApplicable(
 
657
                "pycurl doesn't check the protocol version")
 
658
        super(TestBadProtocolServer, self).setUp()
 
659
 
 
660
    def test_http_has(self):
 
661
        server = self.get_readonly_server()
 
662
        t = self._transport(server.get_url())
 
663
        self.assertRaises(errors.InvalidHttpResponse, t.has, 'foo/bar')
 
664
 
 
665
    def test_http_get(self):
 
666
        server = self.get_readonly_server()
 
667
        t = self._transport(server.get_url())
 
668
        self.assertRaises(errors.InvalidHttpResponse, t.get, 'foo/bar')
 
669
 
 
670
 
 
671
class ForbiddenRequestHandler(http_server.TestingHTTPRequestHandler):
 
672
    """Whatever request comes in, returns a 403 code"""
 
673
 
 
674
    def parse_request(self):
 
675
        """Handle a single HTTP request, by replying we cannot handle it"""
 
676
        ignored = http_server.TestingHTTPRequestHandler.parse_request(self)
 
677
        self.send_error(403)
 
678
        return False
 
679
 
 
680
 
 
681
class TestForbiddenServer(TestSpecificRequestHandler):
 
682
    """Tests forbidden server"""
 
683
 
 
684
    _req_handler_class = ForbiddenRequestHandler
 
685
 
 
686
    def test_http_has(self):
 
687
        server = self.get_readonly_server()
 
688
        t = self._transport(server.get_url())
 
689
        self.assertRaises(errors.TransportError, t.has, 'foo/bar')
 
690
 
 
691
    def test_http_get(self):
 
692
        server = self.get_readonly_server()
 
693
        t = self._transport(server.get_url())
 
694
        self.assertRaises(errors.TransportError, t.get, 'foo/bar')
 
695
 
 
696
 
 
697
class TestRecordingServer(tests.TestCase):
 
698
 
 
699
    def test_create(self):
 
700
        server = RecordingServer(expect_body_tail=None)
 
701
        self.assertEqual('', server.received_bytes)
 
702
        self.assertEqual(None, server.host)
 
703
        self.assertEqual(None, server.port)
 
704
 
 
705
    def test_setUp_and_tearDown(self):
 
706
        server = RecordingServer(expect_body_tail=None)
 
707
        server.setUp()
 
708
        try:
 
709
            self.assertNotEqual(None, server.host)
 
710
            self.assertNotEqual(None, server.port)
 
711
        finally:
 
712
            server.tearDown()
 
713
        self.assertEqual(None, server.host)
 
714
        self.assertEqual(None, server.port)
 
715
 
 
716
    def test_send_receive_bytes(self):
 
717
        server = RecordingServer(expect_body_tail='c')
 
718
        server.setUp()
 
719
        self.addCleanup(server.tearDown)
 
720
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 
721
        sock.connect((server.host, server.port))
 
722
        sock.sendall('abc')
 
723
        self.assertEqual('HTTP/1.1 200 OK\r\n',
 
724
                         osutils.recv_all(sock, 4096))
 
725
        self.assertEqual('abc', server.received_bytes)
 
726
 
 
727
 
 
728
class TestRangeRequestServer(TestSpecificRequestHandler):
 
729
    """Tests readv requests against server.
 
730
 
 
731
    We test against default "normal" server.
 
732
    """
 
733
 
 
734
    def setUp(self):
 
735
        super(TestRangeRequestServer, self).setUp()
 
736
        self.build_tree_contents([('a', '0123456789')],)
 
737
 
 
738
    def test_readv(self):
 
739
        server = self.get_readonly_server()
 
740
        t = self._transport(server.get_url())
 
741
        l = list(t.readv('a', ((0, 1), (1, 1), (3, 2), (9, 1))))
 
742
        self.assertEqual(l[0], (0, '0'))
 
743
        self.assertEqual(l[1], (1, '1'))
 
744
        self.assertEqual(l[2], (3, '34'))
 
745
        self.assertEqual(l[3], (9, '9'))
 
746
 
 
747
    def test_readv_out_of_order(self):
 
748
        server = self.get_readonly_server()
 
749
        t = self._transport(server.get_url())
 
750
        l = list(t.readv('a', ((1, 1), (9, 1), (0, 1), (3, 2))))
 
751
        self.assertEqual(l[0], (1, '1'))
 
752
        self.assertEqual(l[1], (9, '9'))
 
753
        self.assertEqual(l[2], (0, '0'))
 
754
        self.assertEqual(l[3], (3, '34'))
 
755
 
 
756
    def test_readv_invalid_ranges(self):
 
757
        server = self.get_readonly_server()
 
758
        t = self._transport(server.get_url())
 
759
 
 
760
        # This is intentionally reading off the end of the file
 
761
        # since we are sure that it cannot get there
 
762
        self.assertListRaises((errors.InvalidRange, errors.ShortReadvError,),
 
763
                              t.readv, 'a', [(1,1), (8,10)])
 
764
 
 
765
        # This is trying to seek past the end of the file, it should
 
766
        # also raise a special error
 
767
        self.assertListRaises((errors.InvalidRange, errors.ShortReadvError,),
 
768
                              t.readv, 'a', [(12,2)])
 
769
 
 
770
    def test_readv_multiple_get_requests(self):
 
771
        server = self.get_readonly_server()
 
772
        t = self._transport(server.get_url())
 
773
        # force transport to issue multiple requests
 
774
        t._max_readv_combine = 1
 
775
        t._max_get_ranges = 1
 
776
        l = list(t.readv('a', ((0, 1), (1, 1), (3, 2), (9, 1))))
 
777
        self.assertEqual(l[0], (0, '0'))
 
778
        self.assertEqual(l[1], (1, '1'))
 
779
        self.assertEqual(l[2], (3, '34'))
 
780
        self.assertEqual(l[3], (9, '9'))
 
781
        # The server should have issued 4 requests
 
782
        self.assertEqual(4, server.GET_request_nb)
 
783
 
 
784
    def test_readv_get_max_size(self):
 
785
        server = self.get_readonly_server()
 
786
        t = self._transport(server.get_url())
 
787
        # force transport to issue multiple requests by limiting the number of
 
788
        # bytes by request. Note that this apply to coalesced offsets only, a
 
789
        # single range will keep its size even if bigger than the limit.
 
790
        t._get_max_size = 2
 
791
        l = list(t.readv('a', ((0, 1), (1, 1), (2, 4), (6, 4))))
 
792
        self.assertEqual(l[0], (0, '0'))
 
793
        self.assertEqual(l[1], (1, '1'))
 
794
        self.assertEqual(l[2], (2, '2345'))
 
795
        self.assertEqual(l[3], (6, '6789'))
 
796
        # The server should have issued 3 requests
 
797
        self.assertEqual(3, server.GET_request_nb)
 
798
 
 
799
    def test_complete_readv_leave_pipe_clean(self):
 
800
        server = self.get_readonly_server()
 
801
        t = self._transport(server.get_url())
 
802
        # force transport to issue multiple requests
 
803
        t._get_max_size = 2
 
804
        l = list(t.readv('a', ((0, 1), (1, 1), (2, 4), (6, 4))))
 
805
        # The server should have issued 3 requests
 
806
        self.assertEqual(3, server.GET_request_nb)
 
807
        self.assertEqual('0123456789', t.get_bytes('a'))
 
808
        self.assertEqual(4, server.GET_request_nb)
 
809
 
 
810
    def test_incomplete_readv_leave_pipe_clean(self):
 
811
        server = self.get_readonly_server()
 
812
        t = self._transport(server.get_url())
 
813
        # force transport to issue multiple requests
 
814
        t._get_max_size = 2
 
815
        # Don't collapse readv results into a list so that we leave unread
 
816
        # bytes on the socket
 
817
        ireadv = iter(t.readv('a', ((0, 1), (1, 1), (2, 4), (6, 4))))
 
818
        self.assertEqual((0, '0'), ireadv.next())
 
819
        # The server should have issued one request so far 
 
820
        self.assertEqual(1, server.GET_request_nb)
 
821
        self.assertEqual('0123456789', t.get_bytes('a'))
 
822
        # get_bytes issued an additional request, the readv pending ones are
 
823
        # lost
 
824
        self.assertEqual(2, server.GET_request_nb)
 
825
 
 
826
 
 
827
class SingleRangeRequestHandler(http_server.TestingHTTPRequestHandler):
 
828
    """Always reply to range request as if they were single.
 
829
 
 
830
    Don't be explicit about it, just to annoy the clients.
 
831
    """
 
832
 
 
833
    def get_multiple_ranges(self, file, file_size, ranges):
 
834
        """Answer as if it was a single range request and ignores the rest"""
 
835
        (start, end) = ranges[0]
 
836
        return self.get_single_range(file, file_size, start, end)
 
837
 
 
838
 
 
839
class TestSingleRangeRequestServer(TestRangeRequestServer):
 
840
    """Test readv against a server which accept only single range requests"""
 
841
 
 
842
    _req_handler_class = SingleRangeRequestHandler
 
843
 
 
844
 
 
845
class SingleOnlyRangeRequestHandler(http_server.TestingHTTPRequestHandler):
 
846
    """Only reply to simple range requests, errors out on multiple"""
 
847
 
 
848
    def get_multiple_ranges(self, file, file_size, ranges):
 
849
        """Refuses the multiple ranges request"""
 
850
        if len(ranges) > 1:
 
851
            file.close()
 
852
            self.send_error(416, "Requested range not satisfiable")
 
853
            return
 
854
        (start, end) = ranges[0]
 
855
        return self.get_single_range(file, file_size, start, end)
 
856
 
 
857
 
 
858
class TestSingleOnlyRangeRequestServer(TestRangeRequestServer):
 
859
    """Test readv against a server which only accept single range requests"""
 
860
 
 
861
    _req_handler_class = SingleOnlyRangeRequestHandler
 
862
 
 
863
 
 
864
class NoRangeRequestHandler(http_server.TestingHTTPRequestHandler):
 
865
    """Ignore range requests without notice"""
 
866
 
 
867
    def do_GET(self):
 
868
        # Update the statistics
 
869
        self.server.test_case_server.GET_request_nb += 1
 
870
        # Just bypass the range handling done by TestingHTTPRequestHandler
 
871
        return SimpleHTTPServer.SimpleHTTPRequestHandler.do_GET(self)
 
872
 
 
873
 
 
874
class TestNoRangeRequestServer(TestRangeRequestServer):
 
875
    """Test readv against a server which do not accept range requests"""
 
876
 
 
877
    _req_handler_class = NoRangeRequestHandler
 
878
 
 
879
 
 
880
class MultipleRangeWithoutContentLengthRequestHandler(
 
881
    http_server.TestingHTTPRequestHandler):
 
882
    """Reply to multiple range requests without content length header."""
 
883
 
 
884
    def get_multiple_ranges(self, file, file_size, ranges):
 
885
        self.send_response(206)
 
886
        self.send_header('Accept-Ranges', 'bytes')
 
887
        boundary = "%d" % random.randint(0,0x7FFFFFFF)
 
888
        self.send_header("Content-Type",
 
889
                         "multipart/byteranges; boundary=%s" % boundary)
 
890
        self.end_headers()
 
891
        for (start, end) in ranges:
 
892
            self.wfile.write("--%s\r\n" % boundary)
 
893
            self.send_header("Content-type", 'application/octet-stream')
 
894
            self.send_header("Content-Range", "bytes %d-%d/%d" % (start,
 
895
                                                                  end,
 
896
                                                                  file_size))
 
897
            self.end_headers()
 
898
            self.send_range_content(file, start, end - start + 1)
 
899
        # Final boundary
 
900
        self.wfile.write("--%s\r\n" % boundary)
 
901
 
 
902
 
 
903
class TestMultipleRangeWithoutContentLengthServer(TestRangeRequestServer):
 
904
 
 
905
    _req_handler_class = MultipleRangeWithoutContentLengthRequestHandler
 
906
 
 
907
 
 
908
class TruncatedMultipleRangeRequestHandler(
 
909
    http_server.TestingHTTPRequestHandler):
 
910
    """Reply to multiple range requests truncating the last ones.
 
911
 
 
912
    This server generates responses whose Content-Length describes all the
 
913
    ranges, but fail to include the last ones leading to client short reads.
 
914
    This has been observed randomly with lighttpd (bug #179368).
 
915
    """
 
916
 
 
917
    _truncated_ranges = 2
 
918
 
 
919
    def get_multiple_ranges(self, file, file_size, ranges):
 
920
        self.send_response(206)
 
921
        self.send_header('Accept-Ranges', 'bytes')
 
922
        boundary = 'tagada'
 
923
        self.send_header('Content-Type',
 
924
                         'multipart/byteranges; boundary=%s' % boundary)
 
925
        boundary_line = '--%s\r\n' % boundary
 
926
        # Calculate the Content-Length
 
927
        content_length = 0
 
928
        for (start, end) in ranges:
 
929
            content_length += len(boundary_line)
 
930
            content_length += self._header_line_length(
 
931
                'Content-type', 'application/octet-stream')
 
932
            content_length += self._header_line_length(
 
933
                'Content-Range', 'bytes %d-%d/%d' % (start, end, file_size))
 
934
            content_length += len('\r\n') # end headers
 
935
            content_length += end - start # + 1
 
936
        content_length += len(boundary_line)
 
937
        self.send_header('Content-length', content_length)
 
938
        self.end_headers()
 
939
 
 
940
        # Send the multipart body
 
941
        cur = 0
 
942
        for (start, end) in ranges:
 
943
            self.wfile.write(boundary_line)
 
944
            self.send_header('Content-type', 'application/octet-stream')
 
945
            self.send_header('Content-Range', 'bytes %d-%d/%d'
 
946
                             % (start, end, file_size))
 
947
            self.end_headers()
 
948
            if cur + self._truncated_ranges >= len(ranges):
 
949
                # Abruptly ends the response and close the connection
 
950
                self.close_connection = 1
 
951
                return
 
952
            self.send_range_content(file, start, end - start + 1)
 
953
            cur += 1
 
954
        # No final boundary
 
955
        self.wfile.write(boundary_line)
 
956
 
 
957
 
 
958
class TestTruncatedMultipleRangeServer(TestSpecificRequestHandler):
 
959
 
 
960
    _req_handler_class = TruncatedMultipleRangeRequestHandler
 
961
 
 
962
    def setUp(self):
 
963
        super(TestTruncatedMultipleRangeServer, self).setUp()
 
964
        self.build_tree_contents([('a', '0123456789')],)
 
965
 
 
966
    def test_readv_with_short_reads(self):
 
967
        server = self.get_readonly_server()
 
968
        t = self._transport(server.get_url())
 
969
        # Force separate ranges for each offset
 
970
        t._bytes_to_read_before_seek = 0
 
971
        ireadv = iter(t.readv('a', ((0, 1), (2, 1), (4, 2), (9, 1))))
 
972
        self.assertEqual((0, '0'), ireadv.next())
 
973
        self.assertEqual((2, '2'), ireadv.next())
 
974
        if not self._testing_pycurl():
 
975
            # Only one request have been issued so far (except for pycurl that
 
976
            # try to read the whole response at once)
 
977
            self.assertEqual(1, server.GET_request_nb)
 
978
        self.assertEqual((4, '45'), ireadv.next())
 
979
        self.assertEqual((9, '9'), ireadv.next())
 
980
        # Both implementations issue 3 requests but:
 
981
        # - urllib does two multiple (4 ranges, then 2 ranges) then a single
 
982
        #   range,
 
983
        # - pycurl does two multiple (4 ranges, 4 ranges) then a single range
 
984
        self.assertEqual(3, server.GET_request_nb)
 
985
        # Finally the client have tried a single range request and stays in
 
986
        # that mode
 
987
        self.assertEqual('single', t._range_hint)
 
988
 
 
989
class LimitedRangeRequestHandler(http_server.TestingHTTPRequestHandler):
 
990
    """Errors out when range specifiers exceed the limit"""
 
991
 
 
992
    def get_multiple_ranges(self, file, file_size, ranges):
 
993
        """Refuses the multiple ranges request"""
 
994
        tcs = self.server.test_case_server
 
995
        if tcs.range_limit is not None and len(ranges) > tcs.range_limit:
 
996
            file.close()
 
997
            # Emulate apache behavior
 
998
            self.send_error(400, "Bad Request")
 
999
            return
 
1000
        return http_server.TestingHTTPRequestHandler.get_multiple_ranges(
 
1001
            self, file, file_size, ranges)
 
1002
 
 
1003
 
 
1004
class LimitedRangeHTTPServer(http_server.HttpServer):
 
1005
    """An HttpServer erroring out on requests with too much range specifiers"""
 
1006
 
 
1007
    def __init__(self, request_handler=LimitedRangeRequestHandler,
 
1008
                 protocol_version=None,
 
1009
                 range_limit=None):
 
1010
        http_server.HttpServer.__init__(self, request_handler,
 
1011
                                        protocol_version=protocol_version)
 
1012
        self.range_limit = range_limit
 
1013
 
 
1014
 
 
1015
class TestLimitedRangeRequestServer(http_utils.TestCaseWithWebserver):
 
1016
    """Tests readv requests against a server erroring out on too much ranges."""
 
1017
 
 
1018
    # Requests with more range specifiers will error out
 
1019
    range_limit = 3
 
1020
 
 
1021
    def create_transport_readonly_server(self):
 
1022
        return LimitedRangeHTTPServer(range_limit=self.range_limit,
 
1023
                                      protocol_version=self._protocol_version)
 
1024
 
 
1025
    def get_transport(self):
 
1026
        return self._transport(self.get_readonly_server().get_url())
 
1027
 
 
1028
    def setUp(self):
 
1029
        http_utils.TestCaseWithWebserver.setUp(self)
 
1030
        # We need to manipulate ranges that correspond to real chunks in the
 
1031
        # response, so we build a content appropriately.
 
1032
        filler = ''.join(['abcdefghij' for x in range(102)])
 
1033
        content = ''.join(['%04d' % v + filler for v in range(16)])
 
1034
        self.build_tree_contents([('a', content)],)
 
1035
 
 
1036
    def test_few_ranges(self):
 
1037
        t = self.get_transport()
 
1038
        l = list(t.readv('a', ((0, 4), (1024, 4), )))
 
1039
        self.assertEqual(l[0], (0, '0000'))
 
1040
        self.assertEqual(l[1], (1024, '0001'))
 
1041
        self.assertEqual(1, self.get_readonly_server().GET_request_nb)
 
1042
 
 
1043
    def test_more_ranges(self):
 
1044
        t = self.get_transport()
 
1045
        l = list(t.readv('a', ((0, 4), (1024, 4), (4096, 4), (8192, 4))))
 
1046
        self.assertEqual(l[0], (0, '0000'))
 
1047
        self.assertEqual(l[1], (1024, '0001'))
 
1048
        self.assertEqual(l[2], (4096, '0004'))
 
1049
        self.assertEqual(l[3], (8192, '0008'))
 
1050
        # The server will refuse to serve the first request (too much ranges),
 
1051
        # a second request will succeed.
 
1052
        self.assertEqual(2, self.get_readonly_server().GET_request_nb)
 
1053
 
 
1054
 
 
1055
class TestHttpProxyWhiteBox(tests.TestCase):
 
1056
    """Whitebox test proxy http authorization.
 
1057
 
 
1058
    Only the urllib implementation is tested here.
 
1059
    """
 
1060
 
 
1061
    def setUp(self):
 
1062
        tests.TestCase.setUp(self)
 
1063
        self._old_env = {}
 
1064
 
 
1065
    def tearDown(self):
 
1066
        self._restore_env()
 
1067
        tests.TestCase.tearDown(self)
 
1068
 
 
1069
    def _install_env(self, env):
 
1070
        for name, value in env.iteritems():
 
1071
            self._old_env[name] = osutils.set_or_unset_env(name, value)
 
1072
 
 
1073
    def _restore_env(self):
 
1074
        for name, value in self._old_env.iteritems():
 
1075
            osutils.set_or_unset_env(name, value)
 
1076
 
 
1077
    def _proxied_request(self):
 
1078
        handler = _urllib2_wrappers.ProxyHandler()
 
1079
        request = _urllib2_wrappers.Request('GET','http://baz/buzzle')
 
1080
        handler.set_proxy(request, 'http')
 
1081
        return request
 
1082
 
 
1083
    def test_empty_user(self):
 
1084
        self._install_env({'http_proxy': 'http://bar.com'})
 
1085
        request = self._proxied_request()
 
1086
        self.assertFalse(request.headers.has_key('Proxy-authorization'))
 
1087
 
 
1088
    def test_invalid_proxy(self):
 
1089
        """A proxy env variable without scheme"""
 
1090
        self._install_env({'http_proxy': 'host:1234'})
 
1091
        self.assertRaises(errors.InvalidURL, self._proxied_request)
 
1092
 
 
1093
 
 
1094
class TestProxyHttpServer(http_utils.TestCaseWithTwoWebservers):
 
1095
    """Tests proxy server.
 
1096
 
 
1097
    Be aware that we do not setup a real proxy here. Instead, we
 
1098
    check that the *connection* goes through the proxy by serving
 
1099
    different content (the faked proxy server append '-proxied'
 
1100
    to the file names).
 
1101
    """
 
1102
 
 
1103
    # FIXME: We don't have an https server available, so we don't
 
1104
    # test https connections.
 
1105
 
 
1106
    def setUp(self):
 
1107
        super(TestProxyHttpServer, self).setUp()
 
1108
        self.build_tree_contents([('foo', 'contents of foo\n'),
 
1109
                                  ('foo-proxied', 'proxied contents of foo\n')])
 
1110
        # Let's setup some attributes for tests
 
1111
        self.server = self.get_readonly_server()
 
1112
        self.proxy_address = '%s:%d' % (self.server.host, self.server.port)
 
1113
        if self._testing_pycurl():
 
1114
            # Oh my ! pycurl does not check for the port as part of
 
1115
            # no_proxy :-( So we just test the host part
 
1116
            self.no_proxy_host = 'localhost'
 
1117
        else:
 
1118
            self.no_proxy_host = self.proxy_address
 
1119
        # The secondary server is the proxy
 
1120
        self.proxy = self.get_secondary_server()
 
1121
        self.proxy_url = self.proxy.get_url()
 
1122
        self._old_env = {}
 
1123
 
 
1124
    def _testing_pycurl(self):
 
1125
        return pycurl_present and self._transport == PyCurlTransport
 
1126
 
 
1127
    def create_transport_secondary_server(self):
 
1128
        """Creates an http server that will serve files with
 
1129
        '-proxied' appended to their names.
 
1130
        """
 
1131
        return http_utils.ProxyServer(protocol_version=self._protocol_version)
 
1132
 
 
1133
    def _install_env(self, env):
 
1134
        for name, value in env.iteritems():
 
1135
            self._old_env[name] = osutils.set_or_unset_env(name, value)
 
1136
 
 
1137
    def _restore_env(self):
 
1138
        for name, value in self._old_env.iteritems():
 
1139
            osutils.set_or_unset_env(name, value)
 
1140
 
 
1141
    def proxied_in_env(self, env):
 
1142
        self._install_env(env)
 
1143
        url = self.server.get_url()
 
1144
        t = self._transport(url)
 
1145
        try:
 
1146
            self.assertEqual(t.get('foo').read(), 'proxied contents of foo\n')
 
1147
        finally:
 
1148
            self._restore_env()
 
1149
 
 
1150
    def not_proxied_in_env(self, env):
 
1151
        self._install_env(env)
 
1152
        url = self.server.get_url()
 
1153
        t = self._transport(url)
 
1154
        try:
 
1155
            self.assertEqual(t.get('foo').read(), 'contents of foo\n')
 
1156
        finally:
 
1157
            self._restore_env()
 
1158
 
 
1159
    def test_http_proxy(self):
 
1160
        self.proxied_in_env({'http_proxy': self.proxy_url})
 
1161
 
 
1162
    def test_HTTP_PROXY(self):
 
1163
        if self._testing_pycurl():
 
1164
            # pycurl does not check HTTP_PROXY for security reasons
 
1165
            # (for use in a CGI context that we do not care
 
1166
            # about. Should we ?)
 
1167
            raise tests.TestNotApplicable(
 
1168
                'pycurl does not check HTTP_PROXY for security reasons')
 
1169
        self.proxied_in_env({'HTTP_PROXY': self.proxy_url})
 
1170
 
 
1171
    def test_all_proxy(self):
 
1172
        self.proxied_in_env({'all_proxy': self.proxy_url})
 
1173
 
 
1174
    def test_ALL_PROXY(self):
 
1175
        self.proxied_in_env({'ALL_PROXY': self.proxy_url})
 
1176
 
 
1177
    def test_http_proxy_with_no_proxy(self):
 
1178
        self.not_proxied_in_env({'http_proxy': self.proxy_url,
 
1179
                                 'no_proxy': self.no_proxy_host})
 
1180
 
 
1181
    def test_HTTP_PROXY_with_NO_PROXY(self):
 
1182
        if self._testing_pycurl():
 
1183
            raise tests.TestNotApplicable(
 
1184
                'pycurl does not check HTTP_PROXY for security reasons')
 
1185
        self.not_proxied_in_env({'HTTP_PROXY': self.proxy_url,
 
1186
                                 'NO_PROXY': self.no_proxy_host})
 
1187
 
 
1188
    def test_all_proxy_with_no_proxy(self):
 
1189
        self.not_proxied_in_env({'all_proxy': self.proxy_url,
 
1190
                                 'no_proxy': self.no_proxy_host})
 
1191
 
 
1192
    def test_ALL_PROXY_with_NO_PROXY(self):
 
1193
        self.not_proxied_in_env({'ALL_PROXY': self.proxy_url,
 
1194
                                 'NO_PROXY': self.no_proxy_host})
 
1195
 
 
1196
    def test_http_proxy_without_scheme(self):
 
1197
        if self._testing_pycurl():
 
1198
            # pycurl *ignores* invalid proxy env variables. If that ever change
 
1199
            # in the future, this test will fail indicating that pycurl do not
 
1200
            # ignore anymore such variables.
 
1201
            self.not_proxied_in_env({'http_proxy': self.proxy_address})
 
1202
        else:
 
1203
            self.assertRaises(errors.InvalidURL,
 
1204
                              self.proxied_in_env,
 
1205
                              {'http_proxy': self.proxy_address})
 
1206
 
 
1207
 
 
1208
class TestRanges(http_utils.TestCaseWithWebserver):
 
1209
    """Test the Range header in GET methods."""
 
1210
 
 
1211
    def setUp(self):
 
1212
        http_utils.TestCaseWithWebserver.setUp(self)
 
1213
        self.build_tree_contents([('a', '0123456789')],)
 
1214
        server = self.get_readonly_server()
 
1215
        self.transport = self._transport(server.get_url())
 
1216
 
 
1217
    def create_transport_readonly_server(self):
 
1218
        return http_server.HttpServer(protocol_version=self._protocol_version)
 
1219
 
 
1220
    def _file_contents(self, relpath, ranges):
 
1221
        offsets = [ (start, end - start + 1) for start, end in ranges]
 
1222
        coalesce = self.transport._coalesce_offsets
 
1223
        coalesced = list(coalesce(offsets, limit=0, fudge_factor=0))
 
1224
        code, data = self.transport._get(relpath, coalesced)
 
1225
        self.assertTrue(code in (200, 206),'_get returns: %d' % code)
 
1226
        for start, end in ranges:
 
1227
            data.seek(start)
 
1228
            yield data.read(end - start + 1)
 
1229
 
 
1230
    def _file_tail(self, relpath, tail_amount):
 
1231
        code, data = self.transport._get(relpath, [], tail_amount)
 
1232
        self.assertTrue(code in (200, 206),'_get returns: %d' % code)
 
1233
        data.seek(-tail_amount, 2)
 
1234
        return data.read(tail_amount)
 
1235
 
 
1236
    def test_range_header(self):
 
1237
        # Valid ranges
 
1238
        map(self.assertEqual,['0', '234'],
 
1239
            list(self._file_contents('a', [(0,0), (2,4)])),)
 
1240
 
 
1241
    def test_range_header_tail(self):
 
1242
        self.assertEqual('789', self._file_tail('a', 3))
 
1243
 
 
1244
    def test_syntactically_invalid_range_header(self):
 
1245
        self.assertListRaises(errors.InvalidHttpRange,
 
1246
                          self._file_contents, 'a', [(4, 3)])
 
1247
 
 
1248
    def test_semantically_invalid_range_header(self):
 
1249
        self.assertListRaises(errors.InvalidHttpRange,
 
1250
                          self._file_contents, 'a', [(42, 128)])
 
1251
 
 
1252
 
 
1253
class TestHTTPRedirections(http_utils.TestCaseWithRedirectedWebserver):
 
1254
    """Test redirection between http servers."""
 
1255
 
 
1256
    def create_transport_secondary_server(self):
 
1257
        """Create the secondary server redirecting to the primary server"""
 
1258
        new = self.get_readonly_server()
 
1259
 
 
1260
        redirecting = http_utils.HTTPServerRedirecting(
 
1261
            protocol_version=self._protocol_version)
 
1262
        redirecting.redirect_to(new.host, new.port)
 
1263
        return redirecting
 
1264
 
 
1265
    def setUp(self):
 
1266
        super(TestHTTPRedirections, self).setUp()
 
1267
        self.build_tree_contents([('a', '0123456789'),
 
1268
                                  ('bundle',
 
1269
                                  '# Bazaar revision bundle v0.9\n#\n')
 
1270
                                  ],)
 
1271
 
 
1272
        self.old_transport = self._transport(self.old_server.get_url())
 
1273
 
 
1274
    def test_redirected(self):
 
1275
        self.assertRaises(errors.RedirectRequested, self.old_transport.get, 'a')
 
1276
        t = self._transport(self.new_server.get_url())
 
1277
        self.assertEqual('0123456789', t.get('a').read())
 
1278
 
 
1279
    def test_read_redirected_bundle_from_url(self):
 
1280
        from bzrlib.bundle import read_bundle_from_url
 
1281
        url = self.old_transport.abspath('bundle')
 
1282
        bundle = read_bundle_from_url(url)
 
1283
        # If read_bundle_from_url was successful we get an empty bundle
 
1284
        self.assertEqual([], bundle.revisions)
 
1285
 
 
1286
 
 
1287
class RedirectedRequest(_urllib2_wrappers.Request):
 
1288
    """Request following redirections. """
 
1289
 
 
1290
    init_orig = _urllib2_wrappers.Request.__init__
 
1291
 
 
1292
    def __init__(self, method, url, *args, **kwargs):
 
1293
        """Constructor.
 
1294
 
 
1295
        """
 
1296
        # Since the tests using this class will replace
 
1297
        # _urllib2_wrappers.Request, we can't just call the base class __init__
 
1298
        # or we'll loop.
 
1299
        RedirectedRequest.init_orig(self, method, url, args, kwargs)
 
1300
        self.follow_redirections = True
 
1301
 
 
1302
 
 
1303
class TestHTTPSilentRedirections(http_utils.TestCaseWithRedirectedWebserver):
 
1304
    """Test redirections.
 
1305
 
 
1306
    http implementations do not redirect silently anymore (they
 
1307
    do not redirect at all in fact). The mechanism is still in
 
1308
    place at the _urllib2_wrappers.Request level and these tests
 
1309
    exercise it.
 
1310
 
 
1311
    For the pycurl implementation
 
1312
    the redirection have been deleted as we may deprecate pycurl
 
1313
    and I have no place to keep a working implementation.
 
1314
    -- vila 20070212
 
1315
    """
 
1316
 
 
1317
    def setUp(self):
 
1318
        if pycurl_present and self._transport == PyCurlTransport:
 
1319
            raise tests.TestNotApplicable(
 
1320
                "pycurl doesn't redirect silently annymore")
 
1321
        super(TestHTTPSilentRedirections, self).setUp()
 
1322
        self.setup_redirected_request()
 
1323
        self.addCleanup(self.cleanup_redirected_request)
 
1324
        self.build_tree_contents([('a','a'),
 
1325
                                  ('1/',),
 
1326
                                  ('1/a', 'redirected once'),
 
1327
                                  ('2/',),
 
1328
                                  ('2/a', 'redirected twice'),
 
1329
                                  ('3/',),
 
1330
                                  ('3/a', 'redirected thrice'),
 
1331
                                  ('4/',),
 
1332
                                  ('4/a', 'redirected 4 times'),
 
1333
                                  ('5/',),
 
1334
                                  ('5/a', 'redirected 5 times'),
 
1335
                                  ],)
 
1336
 
 
1337
        self.old_transport = self._transport(self.old_server.get_url())
 
1338
 
 
1339
    def setup_redirected_request(self):
 
1340
        self.original_class = _urllib2_wrappers.Request
 
1341
        _urllib2_wrappers.Request = RedirectedRequest
 
1342
 
 
1343
    def cleanup_redirected_request(self):
 
1344
        _urllib2_wrappers.Request = self.original_class
 
1345
 
 
1346
    def create_transport_secondary_server(self):
 
1347
        """Create the secondary server, redirections are defined in the tests"""
 
1348
        return http_utils.HTTPServerRedirecting(
 
1349
            protocol_version=self._protocol_version)
 
1350
 
 
1351
    def test_one_redirection(self):
 
1352
        t = self.old_transport
 
1353
 
 
1354
        req = RedirectedRequest('GET', t.abspath('a'))
 
1355
        req.follow_redirections = True
 
1356
        new_prefix = 'http://%s:%s' % (self.new_server.host,
 
1357
                                       self.new_server.port)
 
1358
        self.old_server.redirections = \
 
1359
            [('(.*)', r'%s/1\1' % (new_prefix), 301),]
 
1360
        self.assertEquals('redirected once',t._perform(req).read())
 
1361
 
 
1362
    def test_five_redirections(self):
 
1363
        t = self.old_transport
 
1364
 
 
1365
        req = RedirectedRequest('GET', t.abspath('a'))
 
1366
        req.follow_redirections = True
 
1367
        old_prefix = 'http://%s:%s' % (self.old_server.host,
 
1368
                                       self.old_server.port)
 
1369
        new_prefix = 'http://%s:%s' % (self.new_server.host,
 
1370
                                       self.new_server.port)
 
1371
        self.old_server.redirections = [
 
1372
            ('/1(.*)', r'%s/2\1' % (old_prefix), 302),
 
1373
            ('/2(.*)', r'%s/3\1' % (old_prefix), 303),
 
1374
            ('/3(.*)', r'%s/4\1' % (old_prefix), 307),
 
1375
            ('/4(.*)', r'%s/5\1' % (new_prefix), 301),
 
1376
            ('(/[^/]+)', r'%s/1\1' % (old_prefix), 301),
 
1377
            ]
 
1378
        self.assertEquals('redirected 5 times',t._perform(req).read())
 
1379
 
 
1380
 
 
1381
class TestDoCatchRedirections(http_utils.TestCaseWithRedirectedWebserver):
 
1382
    """Test transport.do_catching_redirections."""
 
1383
 
 
1384
    def setUp(self):
 
1385
        super(TestDoCatchRedirections, self).setUp()
 
1386
        self.build_tree_contents([('a', '0123456789'),],)
 
1387
 
 
1388
        self.old_transport = self._transport(self.old_server.get_url())
 
1389
 
 
1390
    def get_a(self, transport):
 
1391
        return transport.get('a')
 
1392
 
 
1393
    def test_no_redirection(self):
 
1394
        t = self._transport(self.new_server.get_url())
 
1395
 
 
1396
        # We use None for redirected so that we fail if redirected
 
1397
        self.assertEquals('0123456789',
 
1398
                          transport.do_catching_redirections(
 
1399
                self.get_a, t, None).read())
 
1400
 
 
1401
    def test_one_redirection(self):
 
1402
        self.redirections = 0
 
1403
 
 
1404
        def redirected(transport, exception, redirection_notice):
 
1405
            self.redirections += 1
 
1406
            dir, file = urlutils.split(exception.target)
 
1407
            return self._transport(dir)
 
1408
 
 
1409
        self.assertEquals('0123456789',
 
1410
                          transport.do_catching_redirections(
 
1411
                self.get_a, self.old_transport, redirected).read())
 
1412
        self.assertEquals(1, self.redirections)
 
1413
 
 
1414
    def test_redirection_loop(self):
 
1415
 
 
1416
        def redirected(transport, exception, redirection_notice):
 
1417
            # By using the redirected url as a base dir for the
 
1418
            # *old* transport, we create a loop: a => a/a =>
 
1419
            # a/a/a
 
1420
            return self.old_transport.clone(exception.target)
 
1421
 
 
1422
        self.assertRaises(errors.TooManyRedirections,
 
1423
                          transport.do_catching_redirections,
 
1424
                          self.get_a, self.old_transport, redirected)
 
1425
 
 
1426
 
 
1427
class TestAuth(http_utils.TestCaseWithWebserver):
 
1428
    """Test authentication scheme"""
 
1429
 
 
1430
    _auth_header = 'Authorization'
 
1431
    _password_prompt_prefix = ''
 
1432
 
 
1433
    def setUp(self):
 
1434
        super(TestAuth, self).setUp()
 
1435
        self.server = self.get_readonly_server()
 
1436
        self.build_tree_contents([('a', 'contents of a\n'),
 
1437
                                  ('b', 'contents of b\n'),])
 
1438
 
 
1439
    def create_transport_readonly_server(self):
 
1440
        if self._auth_scheme == 'basic':
 
1441
            server = http_utils.HTTPBasicAuthServer(
 
1442
                protocol_version=self._protocol_version)
 
1443
        else:
 
1444
            if self._auth_scheme != 'digest':
 
1445
                raise AssertionError('Unknown auth scheme: %r'
 
1446
                                     % self._auth_scheme)
 
1447
            server = http_utils.HTTPDigestAuthServer(
 
1448
                protocol_version=self._protocol_version)
 
1449
        return server
 
1450
 
 
1451
    def _testing_pycurl(self):
 
1452
        return pycurl_present and self._transport == PyCurlTransport
 
1453
 
 
1454
    def get_user_url(self, user=None, password=None):
 
1455
        """Build an url embedding user and password"""
 
1456
        url = '%s://' % self.server._url_protocol
 
1457
        if user is not None:
 
1458
            url += user
 
1459
            if password is not None:
 
1460
                url += ':' + password
 
1461
            url += '@'
 
1462
        url += '%s:%s/' % (self.server.host, self.server.port)
 
1463
        return url
 
1464
 
 
1465
    def get_user_transport(self, user=None, password=None):
 
1466
        return self._transport(self.get_user_url(user, password))
 
1467
 
 
1468
    def test_no_user(self):
 
1469
        self.server.add_user('joe', 'foo')
 
1470
        t = self.get_user_transport()
 
1471
        self.assertRaises(errors.InvalidHttpResponse, t.get, 'a')
 
1472
        # Only one 'Authentication Required' error should occur
 
1473
        self.assertEqual(1, self.server.auth_required_errors)
 
1474
 
 
1475
    def test_empty_pass(self):
 
1476
        self.server.add_user('joe', '')
 
1477
        t = self.get_user_transport('joe', '')
 
1478
        self.assertEqual('contents of a\n', t.get('a').read())
 
1479
        # Only one 'Authentication Required' error should occur
 
1480
        self.assertEqual(1, self.server.auth_required_errors)
 
1481
 
 
1482
    def test_user_pass(self):
 
1483
        self.server.add_user('joe', 'foo')
 
1484
        t = self.get_user_transport('joe', 'foo')
 
1485
        self.assertEqual('contents of a\n', t.get('a').read())
 
1486
        # Only one 'Authentication Required' error should occur
 
1487
        self.assertEqual(1, self.server.auth_required_errors)
 
1488
 
 
1489
    def test_unknown_user(self):
 
1490
        self.server.add_user('joe', 'foo')
 
1491
        t = self.get_user_transport('bill', 'foo')
 
1492
        self.assertRaises(errors.InvalidHttpResponse, t.get, 'a')
 
1493
        # Two 'Authentication Required' errors should occur (the
 
1494
        # initial 'who are you' and 'I don't know you, who are
 
1495
        # you').
 
1496
        self.assertEqual(2, self.server.auth_required_errors)
 
1497
 
 
1498
    def test_wrong_pass(self):
 
1499
        self.server.add_user('joe', 'foo')
 
1500
        t = self.get_user_transport('joe', 'bar')
 
1501
        self.assertRaises(errors.InvalidHttpResponse, t.get, 'a')
 
1502
        # Two 'Authentication Required' errors should occur (the
 
1503
        # initial 'who are you' and 'this is not you, who are you')
 
1504
        self.assertEqual(2, self.server.auth_required_errors)
 
1505
 
 
1506
    def test_prompt_for_password(self):
 
1507
        if self._testing_pycurl():
 
1508
            raise tests.TestNotApplicable(
 
1509
                'pycurl cannot prompt, it handles auth by embedding'
 
1510
                ' user:pass in urls only')
 
1511
 
 
1512
        self.server.add_user('joe', 'foo')
 
1513
        t = self.get_user_transport('joe', None)
 
1514
        stdout = tests.StringIOWrapper()
 
1515
        ui.ui_factory = tests.TestUIFactory(stdin='foo\n', stdout=stdout)
 
1516
        self.assertEqual('contents of a\n',t.get('a').read())
 
1517
        # stdin should be empty
 
1518
        self.assertEqual('', ui.ui_factory.stdin.readline())
 
1519
        self._check_password_prompt(t._unqualified_scheme, 'joe',
 
1520
                                    stdout.getvalue())
 
1521
        # And we shouldn't prompt again for a different request
 
1522
        # against the same transport.
 
1523
        self.assertEqual('contents of b\n',t.get('b').read())
 
1524
        t2 = t.clone()
 
1525
        # And neither against a clone
 
1526
        self.assertEqual('contents of b\n',t2.get('b').read())
 
1527
        # Only one 'Authentication Required' error should occur
 
1528
        self.assertEqual(1, self.server.auth_required_errors)
 
1529
 
 
1530
    def _check_password_prompt(self, scheme, user, actual_prompt):
 
1531
        expected_prompt = (self._password_prompt_prefix
 
1532
                           + ("%s %s@%s:%d, Realm: '%s' password: "
 
1533
                              % (scheme.upper(),
 
1534
                                 user, self.server.host, self.server.port,
 
1535
                                 self.server.auth_realm)))
 
1536
        self.assertEquals(expected_prompt, actual_prompt)
 
1537
 
 
1538
    def test_no_prompt_for_password_when_using_auth_config(self):
 
1539
        if self._testing_pycurl():
 
1540
            raise tests.TestNotApplicable(
 
1541
                'pycurl does not support authentication.conf'
 
1542
                ' since it cannot prompt')
 
1543
 
 
1544
        user =' joe'
 
1545
        password = 'foo'
 
1546
        stdin_content = 'bar\n'  # Not the right password
 
1547
        self.server.add_user(user, password)
 
1548
        t = self.get_user_transport(user, None)
 
1549
        ui.ui_factory = tests.TestUIFactory(stdin=stdin_content,
 
1550
                                            stdout=tests.StringIOWrapper())
 
1551
        # Create a minimal config file with the right password
 
1552
        conf = config.AuthenticationConfig()
 
1553
        conf._get_config().update(
 
1554
            {'httptest': {'scheme': 'http', 'port': self.server.port,
 
1555
                          'user': user, 'password': password}})
 
1556
        conf._save()
 
1557
        # Issue a request to the server to connect
 
1558
        self.assertEqual('contents of a\n',t.get('a').read())
 
1559
        # stdin should have  been left untouched
 
1560
        self.assertEqual(stdin_content, ui.ui_factory.stdin.readline())
 
1561
        # Only one 'Authentication Required' error should occur
 
1562
        self.assertEqual(1, self.server.auth_required_errors)
 
1563
 
 
1564
    def test_changing_nonce(self):
 
1565
        if self._auth_scheme != 'digest':
 
1566
            raise tests.TestNotApplicable('HTTP auth digest only test')
 
1567
        if self._testing_pycurl():
 
1568
            raise tests.KnownFailure(
 
1569
                'pycurl does not handle a nonce change')
 
1570
        self.server.add_user('joe', 'foo')
 
1571
        t = self.get_user_transport('joe', 'foo')
 
1572
        self.assertEqual('contents of a\n', t.get('a').read())
 
1573
        self.assertEqual('contents of b\n', t.get('b').read())
 
1574
        # Only one 'Authentication Required' error should have
 
1575
        # occured so far
 
1576
        self.assertEqual(1, self.server.auth_required_errors)
 
1577
        # The server invalidates the current nonce
 
1578
        self.server.auth_nonce = self.server.auth_nonce + '. No, now!'
 
1579
        self.assertEqual('contents of a\n', t.get('a').read())
 
1580
        # Two 'Authentication Required' errors should occur (the
 
1581
        # initial 'who are you' and a second 'who are you' with the new nonce)
 
1582
        self.assertEqual(2, self.server.auth_required_errors)
 
1583
 
 
1584
 
 
1585
 
 
1586
class TestProxyAuth(TestAuth):
 
1587
    """Test proxy authentication schemes."""
 
1588
 
 
1589
    _auth_header = 'Proxy-authorization'
 
1590
    _password_prompt_prefix='Proxy '
 
1591
 
 
1592
    def setUp(self):
 
1593
        super(TestProxyAuth, self).setUp()
 
1594
        self._old_env = {}
 
1595
        self.addCleanup(self._restore_env)
 
1596
        # Override the contents to avoid false positives
 
1597
        self.build_tree_contents([('a', 'not proxied contents of a\n'),
 
1598
                                  ('b', 'not proxied contents of b\n'),
 
1599
                                  ('a-proxied', 'contents of a\n'),
 
1600
                                  ('b-proxied', 'contents of b\n'),
 
1601
                                  ])
 
1602
 
 
1603
    def create_transport_readonly_server(self):
 
1604
        if self._auth_scheme == 'basic':
 
1605
            server = http_utils.ProxyBasicAuthServer(
 
1606
                protocol_version=self._protocol_version)
 
1607
        else:
 
1608
            if self._auth_scheme != 'digest':
 
1609
                raise AssertionError('Unknown auth scheme: %r'
 
1610
                                     % self._auth_scheme)
 
1611
            server = http_utils.ProxyDigestAuthServer(
 
1612
                protocol_version=self._protocol_version)
 
1613
        return server
 
1614
 
 
1615
    def get_user_transport(self, user=None, password=None):
 
1616
        self._install_env({'all_proxy': self.get_user_url(user, password)})
 
1617
        return self._transport(self.server.get_url())
 
1618
 
 
1619
    def _install_env(self, env):
 
1620
        for name, value in env.iteritems():
 
1621
            self._old_env[name] = osutils.set_or_unset_env(name, value)
 
1622
 
 
1623
    def _restore_env(self):
 
1624
        for name, value in self._old_env.iteritems():
 
1625
            osutils.set_or_unset_env(name, value)
 
1626
 
 
1627
    def test_empty_pass(self):
 
1628
        if self._testing_pycurl():
 
1629
            import pycurl
 
1630
            if pycurl.version_info()[1] < '7.16.0':
 
1631
                raise tests.KnownFailure(
 
1632
                    'pycurl < 7.16.0 does not handle empty proxy passwords')
 
1633
        super(TestProxyAuth, self).test_empty_pass()
 
1634
 
 
1635
 
 
1636
class SampleSocket(object):
 
1637
    """A socket-like object for use in testing the HTTP request handler."""
 
1638
 
 
1639
    def __init__(self, socket_read_content):
 
1640
        """Constructs a sample socket.
 
1641
 
 
1642
        :param socket_read_content: a byte sequence
 
1643
        """
 
1644
        # Use plain python StringIO so we can monkey-patch the close method to
 
1645
        # not discard the contents.
 
1646
        from StringIO import StringIO
 
1647
        self.readfile = StringIO(socket_read_content)
 
1648
        self.writefile = StringIO()
 
1649
        self.writefile.close = lambda: None
 
1650
 
 
1651
    def makefile(self, mode='r', bufsize=None):
 
1652
        if 'r' in mode:
 
1653
            return self.readfile
 
1654
        else:
 
1655
            return self.writefile
 
1656
 
 
1657
 
 
1658
class SmartHTTPTunnellingTest(tests.TestCaseWithTransport):
 
1659
 
 
1660
    def setUp(self):
 
1661
        super(SmartHTTPTunnellingTest, self).setUp()
 
1662
        # We use the VFS layer as part of HTTP tunnelling tests.
 
1663
        self._captureVar('BZR_NO_SMART_VFS', None)
 
1664
        self.transport_readonly_server = http_utils.HTTPServerWithSmarts
 
1665
 
 
1666
    def create_transport_readonly_server(self):
 
1667
        return http_utils.HTTPServerWithSmarts(
 
1668
            protocol_version=self._protocol_version)
 
1669
 
 
1670
    def test_bulk_data(self):
 
1671
        # We should be able to send and receive bulk data in a single message.
 
1672
        # The 'readv' command in the smart protocol both sends and receives
 
1673
        # bulk data, so we use that.
 
1674
        self.build_tree(['data-file'])
 
1675
        http_server = self.get_readonly_server()
 
1676
        http_transport = self._transport(http_server.get_url())
 
1677
        medium = http_transport.get_smart_medium()
 
1678
        # Since we provide the medium, the url below will be mostly ignored
 
1679
        # during the test, as long as the path is '/'.
 
1680
        remote_transport = remote.RemoteTransport('bzr://fake_host/',
 
1681
                                                  medium=medium)
 
1682
        self.assertEqual(
 
1683
            [(0, "c")], list(remote_transport.readv("data-file", [(0,1)])))
 
1684
 
 
1685
    def test_http_send_smart_request(self):
 
1686
 
 
1687
        post_body = 'hello\n'
 
1688
        expected_reply_body = 'ok\x012\n'
 
1689
 
 
1690
        http_server = self.get_readonly_server()
 
1691
        http_transport = self._transport(http_server.get_url())
 
1692
        medium = http_transport.get_smart_medium()
 
1693
        response = medium.send_http_smart_request(post_body)
 
1694
        reply_body = response.read()
 
1695
        self.assertEqual(expected_reply_body, reply_body)
 
1696
 
 
1697
    def test_smart_http_server_post_request_handler(self):
 
1698
        httpd = self.get_readonly_server()._get_httpd()
 
1699
 
 
1700
        socket = SampleSocket(
 
1701
            'POST /.bzr/smart %s \r\n' % self._protocol_version
 
1702
            # HTTP/1.1 posts must have a Content-Length (but it doesn't hurt
 
1703
            # for 1.0)
 
1704
            + 'Content-Length: 6\r\n'
 
1705
            '\r\n'
 
1706
            'hello\n')
 
1707
        # Beware: the ('localhost', 80) below is the
 
1708
        # client_address parameter, but we don't have one because
 
1709
        # we have defined a socket which is not bound to an
 
1710
        # address. The test framework never uses this client
 
1711
        # address, so far...
 
1712
        request_handler = http_utils.SmartRequestHandler(socket,
 
1713
                                                         ('localhost', 80),
 
1714
                                                         httpd)
 
1715
        response = socket.writefile.getvalue()
 
1716
        self.assertStartsWith(response, '%s 200 ' % self._protocol_version)
 
1717
        # This includes the end of the HTTP headers, and all the body.
 
1718
        expected_end_of_response = '\r\n\r\nok\x012\n'
 
1719
        self.assertEndsWith(response, expected_end_of_response)
 
1720
 
 
1721
 
 
1722
class ForbiddenRequestHandler(http_server.TestingHTTPRequestHandler):
 
1723
    """No smart server here request handler."""
 
1724
 
 
1725
    def do_POST(self):
 
1726
        self.send_error(403, "Forbidden")
 
1727
 
 
1728
 
 
1729
class SmartClientAgainstNotSmartServer(TestSpecificRequestHandler):
 
1730
    """Test smart client behaviour against an http server without smarts."""
 
1731
 
 
1732
    _req_handler_class = ForbiddenRequestHandler
 
1733
 
 
1734
    def test_probe_smart_server(self):
 
1735
        """Test error handling against server refusing smart requests."""
 
1736
        server = self.get_readonly_server()
 
1737
        t = self._transport(server.get_url())
 
1738
        # No need to build a valid smart request here, the server will not even
 
1739
        # try to interpret it.
 
1740
        self.assertRaises(errors.SmartProtocolError,
 
1741
                          t.send_http_smart_request, 'whatever')
 
1742