~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/tests/test_http.py

  • Committer: John Arbash Meinel
  • Author(s): Mark Hammond
  • Date: 2008-09-09 17:02:21 UTC
  • mto: This revision was merged to the branch mainline in revision 3697.
  • Revision ID: john@arbash-meinel.com-20080909170221-svim3jw2mrz0amp3
An updated transparent icon for bzr.

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# Copyright (C) 2005, 2006 Canonical Ltd
 
2
#
 
3
# This program is free software; you can redistribute it and/or modify
 
4
# it under the terms of the GNU General Public License as published by
 
5
# the Free Software Foundation; either version 2 of the License, or
 
6
# (at your option) any later version.
 
7
#
 
8
# This program is distributed in the hope that it will be useful,
 
9
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
10
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
11
# GNU General Public License for more details.
 
12
#
 
13
# You should have received a copy of the GNU General Public License
 
14
# along with this program; if not, write to the Free Software
 
15
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 
16
 
 
17
"""Tests for HTTP implementations.
 
18
 
 
19
This module defines a load_tests() method that parametrize tests classes for
 
20
transport implementation, http protocol versions and authentication schemes.
 
21
"""
 
22
 
 
23
# TODO: Should be renamed to bzrlib.transport.http.tests?
 
24
# TODO: What about renaming to bzrlib.tests.transport.http ?
 
25
 
 
26
from cStringIO import StringIO
 
27
import httplib
 
28
import os
 
29
import select
 
30
import SimpleHTTPServer
 
31
import socket
 
32
import sys
 
33
import threading
 
34
 
 
35
import bzrlib
 
36
from bzrlib import (
 
37
    bzrdir,
 
38
    config,
 
39
    errors,
 
40
    osutils,
 
41
    remote as _mod_remote,
 
42
    tests,
 
43
    transport,
 
44
    ui,
 
45
    urlutils,
 
46
    )
 
47
from bzrlib.tests import (
 
48
    http_server,
 
49
    http_utils,
 
50
    )
 
51
from bzrlib.transport import (
 
52
    http,
 
53
    remote,
 
54
    )
 
55
from bzrlib.transport.http import (
 
56
    _urllib,
 
57
    _urllib2_wrappers,
 
58
    )
 
59
 
 
60
 
 
61
try:
 
62
    from bzrlib.transport.http._pycurl import PyCurlTransport
 
63
    pycurl_present = True
 
64
except errors.DependencyNotPresent:
 
65
    pycurl_present = False
 
66
 
 
67
 
 
68
class TransportAdapter(tests.TestScenarioApplier):
 
69
    """Generate the same test for each transport implementation."""
 
70
 
 
71
    def __init__(self):
 
72
        transport_scenarios = [
 
73
            ('urllib', dict(_transport=_urllib.HttpTransport_urllib,
 
74
                            _server=http_server.HttpServer_urllib,
 
75
                            _qualified_prefix='http+urllib',)),
 
76
            ]
 
77
        if pycurl_present:
 
78
            transport_scenarios.append(
 
79
                ('pycurl', dict(_transport=PyCurlTransport,
 
80
                                _server=http_server.HttpServer_PyCurl,
 
81
                                _qualified_prefix='http+pycurl',)))
 
82
        self.scenarios = transport_scenarios
 
83
 
 
84
 
 
85
class TransportProtocolAdapter(TransportAdapter):
 
86
    """Generate the same test for each protocol implementation.
 
87
 
 
88
    In addition to the transport adaptatation that we inherit from.
 
89
    """
 
90
 
 
91
    def __init__(self):
 
92
        super(TransportProtocolAdapter, self).__init__()
 
93
        protocol_scenarios = [
 
94
            ('HTTP/1.0',  dict(_protocol_version='HTTP/1.0')),
 
95
            ('HTTP/1.1',  dict(_protocol_version='HTTP/1.1')),
 
96
            ]
 
97
        self.scenarios = tests.multiply_scenarios(self.scenarios,
 
98
                                                  protocol_scenarios)
 
99
 
 
100
 
 
101
class TransportProtocolAuthenticationAdapter(TransportProtocolAdapter):
 
102
    """Generate the same test for each authentication scheme implementation.
 
103
 
 
104
    In addition to the protocol adaptatation that we inherit from.
 
105
    """
 
106
 
 
107
    def __init__(self):
 
108
        super(TransportProtocolAuthenticationAdapter, self).__init__()
 
109
        auth_scheme_scenarios = [
 
110
            ('basic', dict(_auth_scheme='basic')),
 
111
            ('digest', dict(_auth_scheme='digest')),
 
112
            ]
 
113
 
 
114
        self.scenarios = tests.multiply_scenarios(self.scenarios,
 
115
                                                  auth_scheme_scenarios)
 
116
 
 
117
def load_tests(standard_tests, module, loader):
 
118
    """Multiply tests for http clients and protocol versions."""
 
119
    # one for each transport
 
120
    t_adapter = TransportAdapter()
 
121
    t_classes= (TestHttpTransportRegistration,
 
122
                TestHttpTransportUrls,
 
123
                )
 
124
    is_testing_for_transports = tests.condition_isinstance(t_classes)
 
125
 
 
126
    # multiplied by one for each protocol version
 
127
    tp_adapter = TransportProtocolAdapter()
 
128
    tp_classes= (SmartHTTPTunnellingTest,
 
129
                 TestDoCatchRedirections,
 
130
                 TestHTTPConnections,
 
131
                 TestHTTPRedirections,
 
132
                 TestHTTPSilentRedirections,
 
133
                 TestLimitedRangeRequestServer,
 
134
                 TestPost,
 
135
                 TestProxyHttpServer,
 
136
                 TestRanges,
 
137
                 TestSpecificRequestHandler,
 
138
                 )
 
139
    is_also_testing_for_protocols = tests.condition_isinstance(tp_classes)
 
140
 
 
141
    # multiplied by one for each authentication scheme
 
142
    tpa_adapter = TransportProtocolAuthenticationAdapter()
 
143
    tpa_classes = (TestAuth,
 
144
                   )
 
145
    is_also_testing_for_authentication = tests.condition_isinstance(
 
146
        tpa_classes)
 
147
 
 
148
    result = loader.suiteClass()
 
149
    for test_class in tests.iter_suite_tests(standard_tests):
 
150
        # Each test class is either standalone or testing for some combination
 
151
        # of transport, protocol version, authentication scheme. Use the right
 
152
        # adpater (or none) depending on the class.
 
153
        if is_testing_for_transports(test_class):
 
154
            result.addTests(t_adapter.adapt(test_class))
 
155
        elif is_also_testing_for_protocols(test_class):
 
156
            result.addTests(tp_adapter.adapt(test_class))
 
157
        elif is_also_testing_for_authentication(test_class):
 
158
            result.addTests(tpa_adapter.adapt(test_class))
 
159
        else:
 
160
            result.addTest(test_class)
 
161
    return result
 
162
 
 
163
 
 
164
class FakeManager(object):
 
165
 
 
166
    def __init__(self):
 
167
        self.credentials = []
 
168
 
 
169
    def add_password(self, realm, host, username, password):
 
170
        self.credentials.append([realm, host, username, password])
 
171
 
 
172
 
 
173
class RecordingServer(object):
 
174
    """A fake HTTP server.
 
175
    
 
176
    It records the bytes sent to it, and replies with a 200.
 
177
    """
 
178
 
 
179
    def __init__(self, expect_body_tail=None):
 
180
        """Constructor.
 
181
 
 
182
        :type expect_body_tail: str
 
183
        :param expect_body_tail: a reply won't be sent until this string is
 
184
            received.
 
185
        """
 
186
        self._expect_body_tail = expect_body_tail
 
187
        self.host = None
 
188
        self.port = None
 
189
        self.received_bytes = ''
 
190
 
 
191
    def setUp(self):
 
192
        self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 
193
        self._sock.bind(('127.0.0.1', 0))
 
194
        self.host, self.port = self._sock.getsockname()
 
195
        self._ready = threading.Event()
 
196
        self._thread = threading.Thread(target=self._accept_read_and_reply)
 
197
        self._thread.setDaemon(True)
 
198
        self._thread.start()
 
199
        self._ready.wait(5)
 
200
 
 
201
    def _accept_read_and_reply(self):
 
202
        self._sock.listen(1)
 
203
        self._ready.set()
 
204
        self._sock.settimeout(5)
 
205
        try:
 
206
            conn, address = self._sock.accept()
 
207
            # On win32, the accepted connection will be non-blocking to start
 
208
            # with because we're using settimeout.
 
209
            conn.setblocking(True)
 
210
            while not self.received_bytes.endswith(self._expect_body_tail):
 
211
                self.received_bytes += conn.recv(4096)
 
212
            conn.sendall('HTTP/1.1 200 OK\r\n')
 
213
        except socket.timeout:
 
214
            # Make sure the client isn't stuck waiting for us to e.g. accept.
 
215
            self._sock.close()
 
216
        except socket.error:
 
217
            # The client may have already closed the socket.
 
218
            pass
 
219
 
 
220
    def tearDown(self):
 
221
        try:
 
222
            self._sock.close()
 
223
        except socket.error:
 
224
            # We might have already closed it.  We don't care.
 
225
            pass
 
226
        self.host = None
 
227
        self.port = None
 
228
 
 
229
 
 
230
class TestHTTPServer(tests.TestCase):
 
231
    """Test the HTTP servers implementations."""
 
232
 
 
233
    def test_invalid_protocol(self):
 
234
        class BogusRequestHandler(http_server.TestingHTTPRequestHandler):
 
235
 
 
236
            protocol_version = 'HTTP/0.1'
 
237
 
 
238
        server = http_server.HttpServer(BogusRequestHandler)
 
239
        try:
 
240
            self.assertRaises(httplib.UnknownProtocol,server.setUp)
 
241
        except:
 
242
            server.tearDown()
 
243
            self.fail('HTTP Server creation did not raise UnknownProtocol')
 
244
 
 
245
    def test_force_invalid_protocol(self):
 
246
        server = http_server.HttpServer(protocol_version='HTTP/0.1')
 
247
        try:
 
248
            self.assertRaises(httplib.UnknownProtocol,server.setUp)
 
249
        except:
 
250
            server.tearDown()
 
251
            self.fail('HTTP Server creation did not raise UnknownProtocol')
 
252
 
 
253
    def test_server_start_and_stop(self):
 
254
        server = http_server.HttpServer()
 
255
        server.setUp()
 
256
        self.assertTrue(server._http_running)
 
257
        server.tearDown()
 
258
        self.assertFalse(server._http_running)
 
259
 
 
260
    def test_create_http_server_one_zero(self):
 
261
        class RequestHandlerOneZero(http_server.TestingHTTPRequestHandler):
 
262
 
 
263
            protocol_version = 'HTTP/1.0'
 
264
 
 
265
        server = http_server.HttpServer(RequestHandlerOneZero)
 
266
        server.setUp()
 
267
        self.addCleanup(server.tearDown)
 
268
        self.assertIsInstance(server._httpd, http_server.TestingHTTPServer)
 
269
 
 
270
    def test_create_http_server_one_one(self):
 
271
        class RequestHandlerOneOne(http_server.TestingHTTPRequestHandler):
 
272
 
 
273
            protocol_version = 'HTTP/1.1'
 
274
 
 
275
        server = http_server.HttpServer(RequestHandlerOneOne)
 
276
        server.setUp()
 
277
        self.addCleanup(server.tearDown)
 
278
        self.assertIsInstance(server._httpd,
 
279
                              http_server.TestingThreadingHTTPServer)
 
280
 
 
281
    def test_create_http_server_force_one_one(self):
 
282
        class RequestHandlerOneZero(http_server.TestingHTTPRequestHandler):
 
283
 
 
284
            protocol_version = 'HTTP/1.0'
 
285
 
 
286
        server = http_server.HttpServer(RequestHandlerOneZero,
 
287
                                        protocol_version='HTTP/1.1')
 
288
        server.setUp()
 
289
        self.addCleanup(server.tearDown)
 
290
        self.assertIsInstance(server._httpd,
 
291
                              http_server.TestingThreadingHTTPServer)
 
292
 
 
293
    def test_create_http_server_force_one_zero(self):
 
294
        class RequestHandlerOneOne(http_server.TestingHTTPRequestHandler):
 
295
 
 
296
            protocol_version = 'HTTP/1.1'
 
297
 
 
298
        server = http_server.HttpServer(RequestHandlerOneOne,
 
299
                                        protocol_version='HTTP/1.0')
 
300
        server.setUp()
 
301
        self.addCleanup(server.tearDown)
 
302
        self.assertIsInstance(server._httpd,
 
303
                              http_server.TestingHTTPServer)
 
304
 
 
305
 
 
306
class TestWithTransport_pycurl(object):
 
307
    """Test case to inherit from if pycurl is present"""
 
308
 
 
309
    def _get_pycurl_maybe(self):
 
310
        try:
 
311
            from bzrlib.transport.http._pycurl import PyCurlTransport
 
312
            return PyCurlTransport
 
313
        except errors.DependencyNotPresent:
 
314
            raise tests.TestSkipped('pycurl not present')
 
315
 
 
316
    _transport = property(_get_pycurl_maybe)
 
317
 
 
318
 
 
319
class TestHttpUrls(tests.TestCase):
 
320
 
 
321
    # TODO: This should be moved to authorization tests once they
 
322
    # are written.
 
323
 
 
324
    def test_url_parsing(self):
 
325
        f = FakeManager()
 
326
        url = http.extract_auth('http://example.com', f)
 
327
        self.assertEquals('http://example.com', url)
 
328
        self.assertEquals(0, len(f.credentials))
 
329
        url = http.extract_auth(
 
330
            'http://user:pass@www.bazaar-vcs.org/bzr/bzr.dev', f)
 
331
        self.assertEquals('http://www.bazaar-vcs.org/bzr/bzr.dev', url)
 
332
        self.assertEquals(1, len(f.credentials))
 
333
        self.assertEquals([None, 'www.bazaar-vcs.org', 'user', 'pass'],
 
334
                          f.credentials[0])
 
335
 
 
336
 
 
337
class TestHttpTransportUrls(tests.TestCase):
 
338
    """Test the http urls."""
 
339
 
 
340
    def test_abs_url(self):
 
341
        """Construction of absolute http URLs"""
 
342
        t = self._transport('http://bazaar-vcs.org/bzr/bzr.dev/')
 
343
        eq = self.assertEqualDiff
 
344
        eq(t.abspath('.'), 'http://bazaar-vcs.org/bzr/bzr.dev')
 
345
        eq(t.abspath('foo/bar'), 'http://bazaar-vcs.org/bzr/bzr.dev/foo/bar')
 
346
        eq(t.abspath('.bzr'), 'http://bazaar-vcs.org/bzr/bzr.dev/.bzr')
 
347
        eq(t.abspath('.bzr/1//2/./3'),
 
348
           'http://bazaar-vcs.org/bzr/bzr.dev/.bzr/1/2/3')
 
349
 
 
350
    def test_invalid_http_urls(self):
 
351
        """Trap invalid construction of urls"""
 
352
        t = self._transport('http://bazaar-vcs.org/bzr/bzr.dev/')
 
353
        self.assertRaises(errors.InvalidURL,
 
354
                          self._transport,
 
355
                          'http://http://bazaar-vcs.org/bzr/bzr.dev/')
 
356
 
 
357
    def test_http_root_urls(self):
 
358
        """Construction of URLs from server root"""
 
359
        t = self._transport('http://bzr.ozlabs.org/')
 
360
        eq = self.assertEqualDiff
 
361
        eq(t.abspath('.bzr/tree-version'),
 
362
           'http://bzr.ozlabs.org/.bzr/tree-version')
 
363
 
 
364
    def test_http_impl_urls(self):
 
365
        """There are servers which ask for particular clients to connect"""
 
366
        server = self._server()
 
367
        try:
 
368
            server.setUp()
 
369
            url = server.get_url()
 
370
            self.assertTrue(url.startswith('%s://' % self._qualified_prefix))
 
371
        finally:
 
372
            server.tearDown()
 
373
 
 
374
 
 
375
class TestHttps_pycurl(TestWithTransport_pycurl, tests.TestCase):
 
376
 
 
377
    # TODO: This should really be moved into another pycurl
 
378
    # specific test. When https tests will be implemented, take
 
379
    # this one into account.
 
380
    def test_pycurl_without_https_support(self):
 
381
        """Test that pycurl without SSL do not fail with a traceback.
 
382
 
 
383
        For the purpose of the test, we force pycurl to ignore
 
384
        https by supplying a fake version_info that do not
 
385
        support it.
 
386
        """
 
387
        try:
 
388
            import pycurl
 
389
        except ImportError:
 
390
            raise tests.TestSkipped('pycurl not present')
 
391
 
 
392
        version_info_orig = pycurl.version_info
 
393
        try:
 
394
            # Now that we have pycurl imported, we can fake its version_info
 
395
            # This was taken from a windows pycurl without SSL
 
396
            # (thanks to bialix)
 
397
            pycurl.version_info = lambda : (2,
 
398
                                            '7.13.2',
 
399
                                            462082,
 
400
                                            'i386-pc-win32',
 
401
                                            2576,
 
402
                                            None,
 
403
                                            0,
 
404
                                            None,
 
405
                                            ('ftp', 'gopher', 'telnet',
 
406
                                             'dict', 'ldap', 'http', 'file'),
 
407
                                            None,
 
408
                                            0,
 
409
                                            None)
 
410
            self.assertRaises(errors.DependencyNotPresent, self._transport,
 
411
                              'https://launchpad.net')
 
412
        finally:
 
413
            # Restore the right function
 
414
            pycurl.version_info = version_info_orig
 
415
 
 
416
 
 
417
class TestHTTPConnections(http_utils.TestCaseWithWebserver):
 
418
    """Test the http connections."""
 
419
 
 
420
    def setUp(self):
 
421
        http_utils.TestCaseWithWebserver.setUp(self)
 
422
        self.build_tree(['foo/', 'foo/bar'], line_endings='binary',
 
423
                        transport=self.get_transport())
 
424
 
 
425
    def test_http_has(self):
 
426
        server = self.get_readonly_server()
 
427
        t = self._transport(server.get_url())
 
428
        self.assertEqual(t.has('foo/bar'), True)
 
429
        self.assertEqual(len(server.logs), 1)
 
430
        self.assertContainsRe(server.logs[0],
 
431
            r'"HEAD /foo/bar HTTP/1.." (200|302) - "-" "bzr/')
 
432
 
 
433
    def test_http_has_not_found(self):
 
434
        server = self.get_readonly_server()
 
435
        t = self._transport(server.get_url())
 
436
        self.assertEqual(t.has('not-found'), False)
 
437
        self.assertContainsRe(server.logs[1],
 
438
            r'"HEAD /not-found HTTP/1.." 404 - "-" "bzr/')
 
439
 
 
440
    def test_http_get(self):
 
441
        server = self.get_readonly_server()
 
442
        t = self._transport(server.get_url())
 
443
        fp = t.get('foo/bar')
 
444
        self.assertEqualDiff(
 
445
            fp.read(),
 
446
            'contents of foo/bar\n')
 
447
        self.assertEqual(len(server.logs), 1)
 
448
        self.assertTrue(server.logs[0].find(
 
449
            '"GET /foo/bar HTTP/1.1" 200 - "-" "bzr/%s'
 
450
            % bzrlib.__version__) > -1)
 
451
 
 
452
    def test_get_smart_medium(self):
 
453
        # For HTTP, get_smart_medium should return the transport object.
 
454
        server = self.get_readonly_server()
 
455
        http_transport = self._transport(server.get_url())
 
456
        medium = http_transport.get_smart_medium()
 
457
        self.assertIs(medium, http_transport)
 
458
 
 
459
    def test_has_on_bogus_host(self):
 
460
        # Get a free address and don't 'accept' on it, so that we
 
461
        # can be sure there is no http handler there, but set a
 
462
        # reasonable timeout to not slow down tests too much.
 
463
        default_timeout = socket.getdefaulttimeout()
 
464
        try:
 
465
            socket.setdefaulttimeout(2)
 
466
            s = socket.socket()
 
467
            s.bind(('localhost', 0))
 
468
            t = self._transport('http://%s:%s/' % s.getsockname())
 
469
            self.assertRaises(errors.ConnectionError, t.has, 'foo/bar')
 
470
        finally:
 
471
            socket.setdefaulttimeout(default_timeout)
 
472
 
 
473
 
 
474
class TestHttpTransportRegistration(tests.TestCase):
 
475
    """Test registrations of various http implementations"""
 
476
 
 
477
    def test_http_registered(self):
 
478
        t = transport.get_transport('%s://foo.com/' % self._qualified_prefix)
 
479
        self.assertIsInstance(t, transport.Transport)
 
480
        self.assertIsInstance(t, self._transport)
 
481
 
 
482
 
 
483
class TestPost(tests.TestCase):
 
484
 
 
485
    def test_post_body_is_received(self):
 
486
        server = RecordingServer(expect_body_tail='end-of-body')
 
487
        server.setUp()
 
488
        self.addCleanup(server.tearDown)
 
489
        scheme = self._qualified_prefix
 
490
        url = '%s://%s:%s/' % (scheme, server.host, server.port)
 
491
        http_transport = self._transport(url)
 
492
        code, response = http_transport._post('abc def end-of-body')
 
493
        self.assertTrue(
 
494
            server.received_bytes.startswith('POST /.bzr/smart HTTP/1.'))
 
495
        self.assertTrue('content-length: 19\r' in server.received_bytes.lower())
 
496
        # The transport should not be assuming that the server can accept
 
497
        # chunked encoding the first time it connects, because HTTP/1.1, so we
 
498
        # check for the literal string.
 
499
        self.assertTrue(
 
500
            server.received_bytes.endswith('\r\n\r\nabc def end-of-body'))
 
501
 
 
502
 
 
503
class TestRangeHeader(tests.TestCase):
 
504
    """Test range_header method"""
 
505
 
 
506
    def check_header(self, value, ranges=[], tail=0):
 
507
        offsets = [ (start, end - start + 1) for start, end in ranges]
 
508
        coalesce = transport.Transport._coalesce_offsets
 
509
        coalesced = list(coalesce(offsets, limit=0, fudge_factor=0))
 
510
        range_header = http.HttpTransportBase._range_header
 
511
        self.assertEqual(value, range_header(coalesced, tail))
 
512
 
 
513
    def test_range_header_single(self):
 
514
        self.check_header('0-9', ranges=[(0,9)])
 
515
        self.check_header('100-109', ranges=[(100,109)])
 
516
 
 
517
    def test_range_header_tail(self):
 
518
        self.check_header('-10', tail=10)
 
519
        self.check_header('-50', tail=50)
 
520
 
 
521
    def test_range_header_multi(self):
 
522
        self.check_header('0-9,100-200,300-5000',
 
523
                          ranges=[(0,9), (100, 200), (300,5000)])
 
524
 
 
525
    def test_range_header_mixed(self):
 
526
        self.check_header('0-9,300-5000,-50',
 
527
                          ranges=[(0,9), (300,5000)],
 
528
                          tail=50)
 
529
 
 
530
 
 
531
class TestSpecificRequestHandler(http_utils.TestCaseWithWebserver):
 
532
    """Tests a specific request handler.
 
533
 
 
534
    Daughter classes are expected to override _req_handler_class
 
535
    """
 
536
 
 
537
    # Provide a useful default
 
538
    _req_handler_class = http_server.TestingHTTPRequestHandler
 
539
 
 
540
    def create_transport_readonly_server(self):
 
541
        return http_server.HttpServer(self._req_handler_class,
 
542
                                      protocol_version=self._protocol_version)
 
543
 
 
544
    def _testing_pycurl(self):
 
545
        return pycurl_present and self._transport == PyCurlTransport
 
546
 
 
547
 
 
548
class WallRequestHandler(http_server.TestingHTTPRequestHandler):
 
549
    """Whatever request comes in, close the connection"""
 
550
 
 
551
    def handle_one_request(self):
 
552
        """Handle a single HTTP request, by abruptly closing the connection"""
 
553
        self.close_connection = 1
 
554
 
 
555
 
 
556
class TestWallServer(TestSpecificRequestHandler):
 
557
    """Tests exceptions during the connection phase"""
 
558
 
 
559
    _req_handler_class = WallRequestHandler
 
560
 
 
561
    def test_http_has(self):
 
562
        server = self.get_readonly_server()
 
563
        t = self._transport(server.get_url())
 
564
        # Unfortunately httplib (see HTTPResponse._read_status
 
565
        # for details) make no distinction between a closed
 
566
        # socket and badly formatted status line, so we can't
 
567
        # just test for ConnectionError, we have to test
 
568
        # InvalidHttpResponse too.
 
569
        self.assertRaises((errors.ConnectionError, errors.InvalidHttpResponse),
 
570
                          t.has, 'foo/bar')
 
571
 
 
572
    def test_http_get(self):
 
573
        server = self.get_readonly_server()
 
574
        t = self._transport(server.get_url())
 
575
        self.assertRaises((errors.ConnectionError, errors.InvalidHttpResponse),
 
576
                          t.get, 'foo/bar')
 
577
 
 
578
 
 
579
class BadStatusRequestHandler(http_server.TestingHTTPRequestHandler):
 
580
    """Whatever request comes in, returns a bad status"""
 
581
 
 
582
    def parse_request(self):
 
583
        """Fakes handling a single HTTP request, returns a bad status"""
 
584
        ignored = http_server.TestingHTTPRequestHandler.parse_request(self)
 
585
        self.send_response(0, "Bad status")
 
586
        self.close_connection = 1
 
587
        return False
 
588
 
 
589
 
 
590
class TestBadStatusServer(TestSpecificRequestHandler):
 
591
    """Tests bad status from server."""
 
592
 
 
593
    _req_handler_class = BadStatusRequestHandler
 
594
 
 
595
    def test_http_has(self):
 
596
        server = self.get_readonly_server()
 
597
        t = self._transport(server.get_url())
 
598
        self.assertRaises(errors.InvalidHttpResponse, t.has, 'foo/bar')
 
599
 
 
600
    def test_http_get(self):
 
601
        server = self.get_readonly_server()
 
602
        t = self._transport(server.get_url())
 
603
        self.assertRaises(errors.InvalidHttpResponse, t.get, 'foo/bar')
 
604
 
 
605
 
 
606
class InvalidStatusRequestHandler(http_server.TestingHTTPRequestHandler):
 
607
    """Whatever request comes in, returns an invalid status"""
 
608
 
 
609
    def parse_request(self):
 
610
        """Fakes handling a single HTTP request, returns a bad status"""
 
611
        ignored = http_server.TestingHTTPRequestHandler.parse_request(self)
 
612
        self.wfile.write("Invalid status line\r\n")
 
613
        return False
 
614
 
 
615
 
 
616
class TestInvalidStatusServer(TestBadStatusServer):
 
617
    """Tests invalid status from server.
 
618
 
 
619
    Both implementations raises the same error as for a bad status.
 
620
    """
 
621
 
 
622
    _req_handler_class = InvalidStatusRequestHandler
 
623
 
 
624
    def test_http_has(self):
 
625
        if self._testing_pycurl() and self._protocol_version == 'HTTP/1.1':
 
626
            raise tests.KnownFailure(
 
627
                'pycurl hangs if the server send back garbage')
 
628
        super(TestInvalidStatusServer, self).test_http_has()
 
629
 
 
630
    def test_http_get(self):
 
631
        if self._testing_pycurl() and self._protocol_version == 'HTTP/1.1':
 
632
            raise tests.KnownFailure(
 
633
                'pycurl hangs if the server send back garbage')
 
634
        super(TestInvalidStatusServer, self).test_http_get()
 
635
 
 
636
 
 
637
class BadProtocolRequestHandler(http_server.TestingHTTPRequestHandler):
 
638
    """Whatever request comes in, returns a bad protocol version"""
 
639
 
 
640
    def parse_request(self):
 
641
        """Fakes handling a single HTTP request, returns a bad status"""
 
642
        ignored = http_server.TestingHTTPRequestHandler.parse_request(self)
 
643
        # Returns an invalid protocol version, but curl just
 
644
        # ignores it and those cannot be tested.
 
645
        self.wfile.write("%s %d %s\r\n" % ('HTTP/0.0',
 
646
                                           404,
 
647
                                           'Look at my protocol version'))
 
648
        return False
 
649
 
 
650
 
 
651
class TestBadProtocolServer(TestSpecificRequestHandler):
 
652
    """Tests bad protocol from server."""
 
653
 
 
654
    _req_handler_class = BadProtocolRequestHandler
 
655
 
 
656
    def setUp(self):
 
657
        if pycurl_present and self._transport == PyCurlTransport:
 
658
            raise tests.TestNotApplicable(
 
659
                "pycurl doesn't check the protocol version")
 
660
        super(TestBadProtocolServer, self).setUp()
 
661
 
 
662
    def test_http_has(self):
 
663
        server = self.get_readonly_server()
 
664
        t = self._transport(server.get_url())
 
665
        self.assertRaises(errors.InvalidHttpResponse, t.has, 'foo/bar')
 
666
 
 
667
    def test_http_get(self):
 
668
        server = self.get_readonly_server()
 
669
        t = self._transport(server.get_url())
 
670
        self.assertRaises(errors.InvalidHttpResponse, t.get, 'foo/bar')
 
671
 
 
672
 
 
673
class ForbiddenRequestHandler(http_server.TestingHTTPRequestHandler):
 
674
    """Whatever request comes in, returns a 403 code"""
 
675
 
 
676
    def parse_request(self):
 
677
        """Handle a single HTTP request, by replying we cannot handle it"""
 
678
        ignored = http_server.TestingHTTPRequestHandler.parse_request(self)
 
679
        self.send_error(403)
 
680
        return False
 
681
 
 
682
 
 
683
class TestForbiddenServer(TestSpecificRequestHandler):
 
684
    """Tests forbidden server"""
 
685
 
 
686
    _req_handler_class = ForbiddenRequestHandler
 
687
 
 
688
    def test_http_has(self):
 
689
        server = self.get_readonly_server()
 
690
        t = self._transport(server.get_url())
 
691
        self.assertRaises(errors.TransportError, t.has, 'foo/bar')
 
692
 
 
693
    def test_http_get(self):
 
694
        server = self.get_readonly_server()
 
695
        t = self._transport(server.get_url())
 
696
        self.assertRaises(errors.TransportError, t.get, 'foo/bar')
 
697
 
 
698
 
 
699
class TestRecordingServer(tests.TestCase):
 
700
 
 
701
    def test_create(self):
 
702
        server = RecordingServer(expect_body_tail=None)
 
703
        self.assertEqual('', server.received_bytes)
 
704
        self.assertEqual(None, server.host)
 
705
        self.assertEqual(None, server.port)
 
706
 
 
707
    def test_setUp_and_tearDown(self):
 
708
        server = RecordingServer(expect_body_tail=None)
 
709
        server.setUp()
 
710
        try:
 
711
            self.assertNotEqual(None, server.host)
 
712
            self.assertNotEqual(None, server.port)
 
713
        finally:
 
714
            server.tearDown()
 
715
        self.assertEqual(None, server.host)
 
716
        self.assertEqual(None, server.port)
 
717
 
 
718
    def test_send_receive_bytes(self):
 
719
        server = RecordingServer(expect_body_tail='c')
 
720
        server.setUp()
 
721
        self.addCleanup(server.tearDown)
 
722
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 
723
        sock.connect((server.host, server.port))
 
724
        sock.sendall('abc')
 
725
        self.assertEqual('HTTP/1.1 200 OK\r\n',
 
726
                         osutils.recv_all(sock, 4096))
 
727
        self.assertEqual('abc', server.received_bytes)
 
728
 
 
729
 
 
730
class TestRangeRequestServer(TestSpecificRequestHandler):
 
731
    """Tests readv requests against server.
 
732
 
 
733
    We test against default "normal" server.
 
734
    """
 
735
 
 
736
    def setUp(self):
 
737
        super(TestRangeRequestServer, self).setUp()
 
738
        self.build_tree_contents([('a', '0123456789')],)
 
739
 
 
740
    def test_readv(self):
 
741
        server = self.get_readonly_server()
 
742
        t = self._transport(server.get_url())
 
743
        l = list(t.readv('a', ((0, 1), (1, 1), (3, 2), (9, 1))))
 
744
        self.assertEqual(l[0], (0, '0'))
 
745
        self.assertEqual(l[1], (1, '1'))
 
746
        self.assertEqual(l[2], (3, '34'))
 
747
        self.assertEqual(l[3], (9, '9'))
 
748
 
 
749
    def test_readv_out_of_order(self):
 
750
        server = self.get_readonly_server()
 
751
        t = self._transport(server.get_url())
 
752
        l = list(t.readv('a', ((1, 1), (9, 1), (0, 1), (3, 2))))
 
753
        self.assertEqual(l[0], (1, '1'))
 
754
        self.assertEqual(l[1], (9, '9'))
 
755
        self.assertEqual(l[2], (0, '0'))
 
756
        self.assertEqual(l[3], (3, '34'))
 
757
 
 
758
    def test_readv_invalid_ranges(self):
 
759
        server = self.get_readonly_server()
 
760
        t = self._transport(server.get_url())
 
761
 
 
762
        # This is intentionally reading off the end of the file
 
763
        # since we are sure that it cannot get there
 
764
        self.assertListRaises((errors.InvalidRange, errors.ShortReadvError,),
 
765
                              t.readv, 'a', [(1,1), (8,10)])
 
766
 
 
767
        # This is trying to seek past the end of the file, it should
 
768
        # also raise a special error
 
769
        self.assertListRaises((errors.InvalidRange, errors.ShortReadvError,),
 
770
                              t.readv, 'a', [(12,2)])
 
771
 
 
772
    def test_readv_multiple_get_requests(self):
 
773
        server = self.get_readonly_server()
 
774
        t = self._transport(server.get_url())
 
775
        # force transport to issue multiple requests
 
776
        t._max_readv_combine = 1
 
777
        t._max_get_ranges = 1
 
778
        l = list(t.readv('a', ((0, 1), (1, 1), (3, 2), (9, 1))))
 
779
        self.assertEqual(l[0], (0, '0'))
 
780
        self.assertEqual(l[1], (1, '1'))
 
781
        self.assertEqual(l[2], (3, '34'))
 
782
        self.assertEqual(l[3], (9, '9'))
 
783
        # The server should have issued 4 requests
 
784
        self.assertEqual(4, server.GET_request_nb)
 
785
 
 
786
    def test_readv_get_max_size(self):
 
787
        server = self.get_readonly_server()
 
788
        t = self._transport(server.get_url())
 
789
        # force transport to issue multiple requests by limiting the number of
 
790
        # bytes by request. Note that this apply to coalesced offsets only, a
 
791
        # single range will keep its size even if bigger than the limit.
 
792
        t._get_max_size = 2
 
793
        l = list(t.readv('a', ((0, 1), (1, 1), (2, 4), (6, 4))))
 
794
        self.assertEqual(l[0], (0, '0'))
 
795
        self.assertEqual(l[1], (1, '1'))
 
796
        self.assertEqual(l[2], (2, '2345'))
 
797
        self.assertEqual(l[3], (6, '6789'))
 
798
        # The server should have issued 3 requests
 
799
        self.assertEqual(3, server.GET_request_nb)
 
800
 
 
801
    def test_complete_readv_leave_pipe_clean(self):
 
802
        server = self.get_readonly_server()
 
803
        t = self._transport(server.get_url())
 
804
        # force transport to issue multiple requests
 
805
        t._get_max_size = 2
 
806
        l = list(t.readv('a', ((0, 1), (1, 1), (2, 4), (6, 4))))
 
807
        # The server should have issued 3 requests
 
808
        self.assertEqual(3, server.GET_request_nb)
 
809
        self.assertEqual('0123456789', t.get_bytes('a'))
 
810
        self.assertEqual(4, server.GET_request_nb)
 
811
 
 
812
    def test_incomplete_readv_leave_pipe_clean(self):
 
813
        server = self.get_readonly_server()
 
814
        t = self._transport(server.get_url())
 
815
        # force transport to issue multiple requests
 
816
        t._get_max_size = 2
 
817
        # Don't collapse readv results into a list so that we leave unread
 
818
        # bytes on the socket
 
819
        ireadv = iter(t.readv('a', ((0, 1), (1, 1), (2, 4), (6, 4))))
 
820
        self.assertEqual((0, '0'), ireadv.next())
 
821
        # The server should have issued one request so far 
 
822
        self.assertEqual(1, server.GET_request_nb)
 
823
        self.assertEqual('0123456789', t.get_bytes('a'))
 
824
        # get_bytes issued an additional request, the readv pending ones are
 
825
        # lost
 
826
        self.assertEqual(2, server.GET_request_nb)
 
827
 
 
828
 
 
829
class SingleRangeRequestHandler(http_server.TestingHTTPRequestHandler):
 
830
    """Always reply to range request as if they were single.
 
831
 
 
832
    Don't be explicit about it, just to annoy the clients.
 
833
    """
 
834
 
 
835
    def get_multiple_ranges(self, file, file_size, ranges):
 
836
        """Answer as if it was a single range request and ignores the rest"""
 
837
        (start, end) = ranges[0]
 
838
        return self.get_single_range(file, file_size, start, end)
 
839
 
 
840
 
 
841
class TestSingleRangeRequestServer(TestRangeRequestServer):
 
842
    """Test readv against a server which accept only single range requests"""
 
843
 
 
844
    _req_handler_class = SingleRangeRequestHandler
 
845
 
 
846
 
 
847
class SingleOnlyRangeRequestHandler(http_server.TestingHTTPRequestHandler):
 
848
    """Only reply to simple range requests, errors out on multiple"""
 
849
 
 
850
    def get_multiple_ranges(self, file, file_size, ranges):
 
851
        """Refuses the multiple ranges request"""
 
852
        if len(ranges) > 1:
 
853
            file.close()
 
854
            self.send_error(416, "Requested range not satisfiable")
 
855
            return
 
856
        (start, end) = ranges[0]
 
857
        return self.get_single_range(file, file_size, start, end)
 
858
 
 
859
 
 
860
class TestSingleOnlyRangeRequestServer(TestRangeRequestServer):
 
861
    """Test readv against a server which only accept single range requests"""
 
862
 
 
863
    _req_handler_class = SingleOnlyRangeRequestHandler
 
864
 
 
865
 
 
866
class NoRangeRequestHandler(http_server.TestingHTTPRequestHandler):
 
867
    """Ignore range requests without notice"""
 
868
 
 
869
    def do_GET(self):
 
870
        # Update the statistics
 
871
        self.server.test_case_server.GET_request_nb += 1
 
872
        # Just bypass the range handling done by TestingHTTPRequestHandler
 
873
        return SimpleHTTPServer.SimpleHTTPRequestHandler.do_GET(self)
 
874
 
 
875
 
 
876
class TestNoRangeRequestServer(TestRangeRequestServer):
 
877
    """Test readv against a server which do not accept range requests"""
 
878
 
 
879
    _req_handler_class = NoRangeRequestHandler
 
880
 
 
881
 
 
882
class MultipleRangeWithoutContentLengthRequestHandler(
 
883
    http_server.TestingHTTPRequestHandler):
 
884
    """Reply to multiple range requests without content length header."""
 
885
 
 
886
    def get_multiple_ranges(self, file, file_size, ranges):
 
887
        self.send_response(206)
 
888
        self.send_header('Accept-Ranges', 'bytes')
 
889
        boundary = "%d" % random.randint(0,0x7FFFFFFF)
 
890
        self.send_header("Content-Type",
 
891
                         "multipart/byteranges; boundary=%s" % boundary)
 
892
        self.end_headers()
 
893
        for (start, end) in ranges:
 
894
            self.wfile.write("--%s\r\n" % boundary)
 
895
            self.send_header("Content-type", 'application/octet-stream')
 
896
            self.send_header("Content-Range", "bytes %d-%d/%d" % (start,
 
897
                                                                  end,
 
898
                                                                  file_size))
 
899
            self.end_headers()
 
900
            self.send_range_content(file, start, end - start + 1)
 
901
        # Final boundary
 
902
        self.wfile.write("--%s\r\n" % boundary)
 
903
 
 
904
 
 
905
class TestMultipleRangeWithoutContentLengthServer(TestRangeRequestServer):
 
906
 
 
907
    _req_handler_class = MultipleRangeWithoutContentLengthRequestHandler
 
908
 
 
909
 
 
910
class TruncatedMultipleRangeRequestHandler(
 
911
    http_server.TestingHTTPRequestHandler):
 
912
    """Reply to multiple range requests truncating the last ones.
 
913
 
 
914
    This server generates responses whose Content-Length describes all the
 
915
    ranges, but fail to include the last ones leading to client short reads.
 
916
    This has been observed randomly with lighttpd (bug #179368).
 
917
    """
 
918
 
 
919
    _truncated_ranges = 2
 
920
 
 
921
    def get_multiple_ranges(self, file, file_size, ranges):
 
922
        self.send_response(206)
 
923
        self.send_header('Accept-Ranges', 'bytes')
 
924
        boundary = 'tagada'
 
925
        self.send_header('Content-Type',
 
926
                         'multipart/byteranges; boundary=%s' % boundary)
 
927
        boundary_line = '--%s\r\n' % boundary
 
928
        # Calculate the Content-Length
 
929
        content_length = 0
 
930
        for (start, end) in ranges:
 
931
            content_length += len(boundary_line)
 
932
            content_length += self._header_line_length(
 
933
                'Content-type', 'application/octet-stream')
 
934
            content_length += self._header_line_length(
 
935
                'Content-Range', 'bytes %d-%d/%d' % (start, end, file_size))
 
936
            content_length += len('\r\n') # end headers
 
937
            content_length += end - start # + 1
 
938
        content_length += len(boundary_line)
 
939
        self.send_header('Content-length', content_length)
 
940
        self.end_headers()
 
941
 
 
942
        # Send the multipart body
 
943
        cur = 0
 
944
        for (start, end) in ranges:
 
945
            self.wfile.write(boundary_line)
 
946
            self.send_header('Content-type', 'application/octet-stream')
 
947
            self.send_header('Content-Range', 'bytes %d-%d/%d'
 
948
                             % (start, end, file_size))
 
949
            self.end_headers()
 
950
            if cur + self._truncated_ranges >= len(ranges):
 
951
                # Abruptly ends the response and close the connection
 
952
                self.close_connection = 1
 
953
                return
 
954
            self.send_range_content(file, start, end - start + 1)
 
955
            cur += 1
 
956
        # No final boundary
 
957
        self.wfile.write(boundary_line)
 
958
 
 
959
 
 
960
class TestTruncatedMultipleRangeServer(TestSpecificRequestHandler):
 
961
 
 
962
    _req_handler_class = TruncatedMultipleRangeRequestHandler
 
963
 
 
964
    def setUp(self):
 
965
        super(TestTruncatedMultipleRangeServer, self).setUp()
 
966
        self.build_tree_contents([('a', '0123456789')],)
 
967
 
 
968
    def test_readv_with_short_reads(self):
 
969
        server = self.get_readonly_server()
 
970
        t = self._transport(server.get_url())
 
971
        # Force separate ranges for each offset
 
972
        t._bytes_to_read_before_seek = 0
 
973
        ireadv = iter(t.readv('a', ((0, 1), (2, 1), (4, 2), (9, 1))))
 
974
        self.assertEqual((0, '0'), ireadv.next())
 
975
        self.assertEqual((2, '2'), ireadv.next())
 
976
        if not self._testing_pycurl():
 
977
            # Only one request have been issued so far (except for pycurl that
 
978
            # try to read the whole response at once)
 
979
            self.assertEqual(1, server.GET_request_nb)
 
980
        self.assertEqual((4, '45'), ireadv.next())
 
981
        self.assertEqual((9, '9'), ireadv.next())
 
982
        # Both implementations issue 3 requests but:
 
983
        # - urllib does two multiple (4 ranges, then 2 ranges) then a single
 
984
        #   range,
 
985
        # - pycurl does two multiple (4 ranges, 4 ranges) then a single range
 
986
        self.assertEqual(3, server.GET_request_nb)
 
987
        # Finally the client have tried a single range request and stays in
 
988
        # that mode
 
989
        self.assertEqual('single', t._range_hint)
 
990
 
 
991
class LimitedRangeRequestHandler(http_server.TestingHTTPRequestHandler):
 
992
    """Errors out when range specifiers exceed the limit"""
 
993
 
 
994
    def get_multiple_ranges(self, file, file_size, ranges):
 
995
        """Refuses the multiple ranges request"""
 
996
        tcs = self.server.test_case_server
 
997
        if tcs.range_limit is not None and len(ranges) > tcs.range_limit:
 
998
            file.close()
 
999
            # Emulate apache behavior
 
1000
            self.send_error(400, "Bad Request")
 
1001
            return
 
1002
        return http_server.TestingHTTPRequestHandler.get_multiple_ranges(
 
1003
            self, file, file_size, ranges)
 
1004
 
 
1005
 
 
1006
class LimitedRangeHTTPServer(http_server.HttpServer):
 
1007
    """An HttpServer erroring out on requests with too much range specifiers"""
 
1008
 
 
1009
    def __init__(self, request_handler=LimitedRangeRequestHandler,
 
1010
                 protocol_version=None,
 
1011
                 range_limit=None):
 
1012
        http_server.HttpServer.__init__(self, request_handler,
 
1013
                                        protocol_version=protocol_version)
 
1014
        self.range_limit = range_limit
 
1015
 
 
1016
 
 
1017
class TestLimitedRangeRequestServer(http_utils.TestCaseWithWebserver):
 
1018
    """Tests readv requests against a server erroring out on too much ranges."""
 
1019
 
 
1020
    # Requests with more range specifiers will error out
 
1021
    range_limit = 3
 
1022
 
 
1023
    def create_transport_readonly_server(self):
 
1024
        return LimitedRangeHTTPServer(range_limit=self.range_limit,
 
1025
                                      protocol_version=self._protocol_version)
 
1026
 
 
1027
    def get_transport(self):
 
1028
        return self._transport(self.get_readonly_server().get_url())
 
1029
 
 
1030
    def setUp(self):
 
1031
        http_utils.TestCaseWithWebserver.setUp(self)
 
1032
        # We need to manipulate ranges that correspond to real chunks in the
 
1033
        # response, so we build a content appropriately.
 
1034
        filler = ''.join(['abcdefghij' for x in range(102)])
 
1035
        content = ''.join(['%04d' % v + filler for v in range(16)])
 
1036
        self.build_tree_contents([('a', content)],)
 
1037
 
 
1038
    def test_few_ranges(self):
 
1039
        t = self.get_transport()
 
1040
        l = list(t.readv('a', ((0, 4), (1024, 4), )))
 
1041
        self.assertEqual(l[0], (0, '0000'))
 
1042
        self.assertEqual(l[1], (1024, '0001'))
 
1043
        self.assertEqual(1, self.get_readonly_server().GET_request_nb)
 
1044
 
 
1045
    def test_more_ranges(self):
 
1046
        t = self.get_transport()
 
1047
        l = list(t.readv('a', ((0, 4), (1024, 4), (4096, 4), (8192, 4))))
 
1048
        self.assertEqual(l[0], (0, '0000'))
 
1049
        self.assertEqual(l[1], (1024, '0001'))
 
1050
        self.assertEqual(l[2], (4096, '0004'))
 
1051
        self.assertEqual(l[3], (8192, '0008'))
 
1052
        # The server will refuse to serve the first request (too much ranges),
 
1053
        # a second request will succeed.
 
1054
        self.assertEqual(2, self.get_readonly_server().GET_request_nb)
 
1055
 
 
1056
 
 
1057
class TestHttpProxyWhiteBox(tests.TestCase):
 
1058
    """Whitebox test proxy http authorization.
 
1059
 
 
1060
    Only the urllib implementation is tested here.
 
1061
    """
 
1062
 
 
1063
    def setUp(self):
 
1064
        tests.TestCase.setUp(self)
 
1065
        self._old_env = {}
 
1066
 
 
1067
    def tearDown(self):
 
1068
        self._restore_env()
 
1069
        tests.TestCase.tearDown(self)
 
1070
 
 
1071
    def _install_env(self, env):
 
1072
        for name, value in env.iteritems():
 
1073
            self._old_env[name] = osutils.set_or_unset_env(name, value)
 
1074
 
 
1075
    def _restore_env(self):
 
1076
        for name, value in self._old_env.iteritems():
 
1077
            osutils.set_or_unset_env(name, value)
 
1078
 
 
1079
    def _proxied_request(self):
 
1080
        handler = _urllib2_wrappers.ProxyHandler()
 
1081
        request = _urllib2_wrappers.Request('GET','http://baz/buzzle')
 
1082
        handler.set_proxy(request, 'http')
 
1083
        return request
 
1084
 
 
1085
    def test_empty_user(self):
 
1086
        self._install_env({'http_proxy': 'http://bar.com'})
 
1087
        request = self._proxied_request()
 
1088
        self.assertFalse(request.headers.has_key('Proxy-authorization'))
 
1089
 
 
1090
    def test_invalid_proxy(self):
 
1091
        """A proxy env variable without scheme"""
 
1092
        self._install_env({'http_proxy': 'host:1234'})
 
1093
        self.assertRaises(errors.InvalidURL, self._proxied_request)
 
1094
 
 
1095
 
 
1096
class TestProxyHttpServer(http_utils.TestCaseWithTwoWebservers):
 
1097
    """Tests proxy server.
 
1098
 
 
1099
    Be aware that we do not setup a real proxy here. Instead, we
 
1100
    check that the *connection* goes through the proxy by serving
 
1101
    different content (the faked proxy server append '-proxied'
 
1102
    to the file names).
 
1103
    """
 
1104
 
 
1105
    # FIXME: We don't have an https server available, so we don't
 
1106
    # test https connections.
 
1107
 
 
1108
    def setUp(self):
 
1109
        super(TestProxyHttpServer, self).setUp()
 
1110
        self.build_tree_contents([('foo', 'contents of foo\n'),
 
1111
                                  ('foo-proxied', 'proxied contents of foo\n')])
 
1112
        # Let's setup some attributes for tests
 
1113
        self.server = self.get_readonly_server()
 
1114
        self.proxy_address = '%s:%d' % (self.server.host, self.server.port)
 
1115
        if self._testing_pycurl():
 
1116
            # Oh my ! pycurl does not check for the port as part of
 
1117
            # no_proxy :-( So we just test the host part
 
1118
            self.no_proxy_host = 'localhost'
 
1119
        else:
 
1120
            self.no_proxy_host = self.proxy_address
 
1121
        # The secondary server is the proxy
 
1122
        self.proxy = self.get_secondary_server()
 
1123
        self.proxy_url = self.proxy.get_url()
 
1124
        self._old_env = {}
 
1125
 
 
1126
    def _testing_pycurl(self):
 
1127
        return pycurl_present and self._transport == PyCurlTransport
 
1128
 
 
1129
    def create_transport_secondary_server(self):
 
1130
        """Creates an http server that will serve files with
 
1131
        '-proxied' appended to their names.
 
1132
        """
 
1133
        return http_utils.ProxyServer(protocol_version=self._protocol_version)
 
1134
 
 
1135
    def _install_env(self, env):
 
1136
        for name, value in env.iteritems():
 
1137
            self._old_env[name] = osutils.set_or_unset_env(name, value)
 
1138
 
 
1139
    def _restore_env(self):
 
1140
        for name, value in self._old_env.iteritems():
 
1141
            osutils.set_or_unset_env(name, value)
 
1142
 
 
1143
    def proxied_in_env(self, env):
 
1144
        self._install_env(env)
 
1145
        url = self.server.get_url()
 
1146
        t = self._transport(url)
 
1147
        try:
 
1148
            self.assertEqual(t.get('foo').read(), 'proxied contents of foo\n')
 
1149
        finally:
 
1150
            self._restore_env()
 
1151
 
 
1152
    def not_proxied_in_env(self, env):
 
1153
        self._install_env(env)
 
1154
        url = self.server.get_url()
 
1155
        t = self._transport(url)
 
1156
        try:
 
1157
            self.assertEqual(t.get('foo').read(), 'contents of foo\n')
 
1158
        finally:
 
1159
            self._restore_env()
 
1160
 
 
1161
    def test_http_proxy(self):
 
1162
        self.proxied_in_env({'http_proxy': self.proxy_url})
 
1163
 
 
1164
    def test_HTTP_PROXY(self):
 
1165
        if self._testing_pycurl():
 
1166
            # pycurl does not check HTTP_PROXY for security reasons
 
1167
            # (for use in a CGI context that we do not care
 
1168
            # about. Should we ?)
 
1169
            raise tests.TestNotApplicable(
 
1170
                'pycurl does not check HTTP_PROXY for security reasons')
 
1171
        self.proxied_in_env({'HTTP_PROXY': self.proxy_url})
 
1172
 
 
1173
    def test_all_proxy(self):
 
1174
        self.proxied_in_env({'all_proxy': self.proxy_url})
 
1175
 
 
1176
    def test_ALL_PROXY(self):
 
1177
        self.proxied_in_env({'ALL_PROXY': self.proxy_url})
 
1178
 
 
1179
    def test_http_proxy_with_no_proxy(self):
 
1180
        self.not_proxied_in_env({'http_proxy': self.proxy_url,
 
1181
                                 'no_proxy': self.no_proxy_host})
 
1182
 
 
1183
    def test_HTTP_PROXY_with_NO_PROXY(self):
 
1184
        if self._testing_pycurl():
 
1185
            raise tests.TestNotApplicable(
 
1186
                'pycurl does not check HTTP_PROXY for security reasons')
 
1187
        self.not_proxied_in_env({'HTTP_PROXY': self.proxy_url,
 
1188
                                 'NO_PROXY': self.no_proxy_host})
 
1189
 
 
1190
    def test_all_proxy_with_no_proxy(self):
 
1191
        self.not_proxied_in_env({'all_proxy': self.proxy_url,
 
1192
                                 'no_proxy': self.no_proxy_host})
 
1193
 
 
1194
    def test_ALL_PROXY_with_NO_PROXY(self):
 
1195
        self.not_proxied_in_env({'ALL_PROXY': self.proxy_url,
 
1196
                                 'NO_PROXY': self.no_proxy_host})
 
1197
 
 
1198
    def test_http_proxy_without_scheme(self):
 
1199
        if self._testing_pycurl():
 
1200
            # pycurl *ignores* invalid proxy env variables. If that ever change
 
1201
            # in the future, this test will fail indicating that pycurl do not
 
1202
            # ignore anymore such variables.
 
1203
            self.not_proxied_in_env({'http_proxy': self.proxy_address})
 
1204
        else:
 
1205
            self.assertRaises(errors.InvalidURL,
 
1206
                              self.proxied_in_env,
 
1207
                              {'http_proxy': self.proxy_address})
 
1208
 
 
1209
 
 
1210
class TestRanges(http_utils.TestCaseWithWebserver):
 
1211
    """Test the Range header in GET methods."""
 
1212
 
 
1213
    def setUp(self):
 
1214
        http_utils.TestCaseWithWebserver.setUp(self)
 
1215
        self.build_tree_contents([('a', '0123456789')],)
 
1216
        server = self.get_readonly_server()
 
1217
        self.transport = self._transport(server.get_url())
 
1218
 
 
1219
    def create_transport_readonly_server(self):
 
1220
        return http_server.HttpServer(protocol_version=self._protocol_version)
 
1221
 
 
1222
    def _file_contents(self, relpath, ranges):
 
1223
        offsets = [ (start, end - start + 1) for start, end in ranges]
 
1224
        coalesce = self.transport._coalesce_offsets
 
1225
        coalesced = list(coalesce(offsets, limit=0, fudge_factor=0))
 
1226
        code, data = self.transport._get(relpath, coalesced)
 
1227
        self.assertTrue(code in (200, 206),'_get returns: %d' % code)
 
1228
        for start, end in ranges:
 
1229
            data.seek(start)
 
1230
            yield data.read(end - start + 1)
 
1231
 
 
1232
    def _file_tail(self, relpath, tail_amount):
 
1233
        code, data = self.transport._get(relpath, [], tail_amount)
 
1234
        self.assertTrue(code in (200, 206),'_get returns: %d' % code)
 
1235
        data.seek(-tail_amount, 2)
 
1236
        return data.read(tail_amount)
 
1237
 
 
1238
    def test_range_header(self):
 
1239
        # Valid ranges
 
1240
        map(self.assertEqual,['0', '234'],
 
1241
            list(self._file_contents('a', [(0,0), (2,4)])),)
 
1242
 
 
1243
    def test_range_header_tail(self):
 
1244
        self.assertEqual('789', self._file_tail('a', 3))
 
1245
 
 
1246
    def test_syntactically_invalid_range_header(self):
 
1247
        self.assertListRaises(errors.InvalidHttpRange,
 
1248
                          self._file_contents, 'a', [(4, 3)])
 
1249
 
 
1250
    def test_semantically_invalid_range_header(self):
 
1251
        self.assertListRaises(errors.InvalidHttpRange,
 
1252
                          self._file_contents, 'a', [(42, 128)])
 
1253
 
 
1254
 
 
1255
class TestHTTPRedirections(http_utils.TestCaseWithRedirectedWebserver):
 
1256
    """Test redirection between http servers."""
 
1257
 
 
1258
    def create_transport_secondary_server(self):
 
1259
        """Create the secondary server redirecting to the primary server"""
 
1260
        new = self.get_readonly_server()
 
1261
 
 
1262
        redirecting = http_utils.HTTPServerRedirecting(
 
1263
            protocol_version=self._protocol_version)
 
1264
        redirecting.redirect_to(new.host, new.port)
 
1265
        return redirecting
 
1266
 
 
1267
    def setUp(self):
 
1268
        super(TestHTTPRedirections, self).setUp()
 
1269
        self.build_tree_contents([('a', '0123456789'),
 
1270
                                  ('bundle',
 
1271
                                  '# Bazaar revision bundle v0.9\n#\n')
 
1272
                                  ],)
 
1273
 
 
1274
        self.old_transport = self._transport(self.old_server.get_url())
 
1275
 
 
1276
    def test_redirected(self):
 
1277
        self.assertRaises(errors.RedirectRequested, self.old_transport.get, 'a')
 
1278
        t = self._transport(self.new_server.get_url())
 
1279
        self.assertEqual('0123456789', t.get('a').read())
 
1280
 
 
1281
    def test_read_redirected_bundle_from_url(self):
 
1282
        from bzrlib.bundle import read_bundle_from_url
 
1283
        url = self.old_transport.abspath('bundle')
 
1284
        bundle = read_bundle_from_url(url)
 
1285
        # If read_bundle_from_url was successful we get an empty bundle
 
1286
        self.assertEqual([], bundle.revisions)
 
1287
 
 
1288
 
 
1289
class RedirectedRequest(_urllib2_wrappers.Request):
 
1290
    """Request following redirections. """
 
1291
 
 
1292
    init_orig = _urllib2_wrappers.Request.__init__
 
1293
 
 
1294
    def __init__(self, method, url, *args, **kwargs):
 
1295
        """Constructor.
 
1296
 
 
1297
        """
 
1298
        # Since the tests using this class will replace
 
1299
        # _urllib2_wrappers.Request, we can't just call the base class __init__
 
1300
        # or we'll loop.
 
1301
        RedirectedRequest.init_orig(self, method, url, args, kwargs)
 
1302
        self.follow_redirections = True
 
1303
 
 
1304
 
 
1305
class TestHTTPSilentRedirections(http_utils.TestCaseWithRedirectedWebserver):
 
1306
    """Test redirections.
 
1307
 
 
1308
    http implementations do not redirect silently anymore (they
 
1309
    do not redirect at all in fact). The mechanism is still in
 
1310
    place at the _urllib2_wrappers.Request level and these tests
 
1311
    exercise it.
 
1312
 
 
1313
    For the pycurl implementation
 
1314
    the redirection have been deleted as we may deprecate pycurl
 
1315
    and I have no place to keep a working implementation.
 
1316
    -- vila 20070212
 
1317
    """
 
1318
 
 
1319
    def setUp(self):
 
1320
        if pycurl_present and self._transport == PyCurlTransport:
 
1321
            raise tests.TestNotApplicable(
 
1322
                "pycurl doesn't redirect silently annymore")
 
1323
        super(TestHTTPSilentRedirections, self).setUp()
 
1324
        self.setup_redirected_request()
 
1325
        self.addCleanup(self.cleanup_redirected_request)
 
1326
        self.build_tree_contents([('a','a'),
 
1327
                                  ('1/',),
 
1328
                                  ('1/a', 'redirected once'),
 
1329
                                  ('2/',),
 
1330
                                  ('2/a', 'redirected twice'),
 
1331
                                  ('3/',),
 
1332
                                  ('3/a', 'redirected thrice'),
 
1333
                                  ('4/',),
 
1334
                                  ('4/a', 'redirected 4 times'),
 
1335
                                  ('5/',),
 
1336
                                  ('5/a', 'redirected 5 times'),
 
1337
                                  ],)
 
1338
 
 
1339
        self.old_transport = self._transport(self.old_server.get_url())
 
1340
 
 
1341
    def setup_redirected_request(self):
 
1342
        self.original_class = _urllib2_wrappers.Request
 
1343
        _urllib2_wrappers.Request = RedirectedRequest
 
1344
 
 
1345
    def cleanup_redirected_request(self):
 
1346
        _urllib2_wrappers.Request = self.original_class
 
1347
 
 
1348
    def create_transport_secondary_server(self):
 
1349
        """Create the secondary server, redirections are defined in the tests"""
 
1350
        return http_utils.HTTPServerRedirecting(
 
1351
            protocol_version=self._protocol_version)
 
1352
 
 
1353
    def test_one_redirection(self):
 
1354
        t = self.old_transport
 
1355
 
 
1356
        req = RedirectedRequest('GET', t.abspath('a'))
 
1357
        req.follow_redirections = True
 
1358
        new_prefix = 'http://%s:%s' % (self.new_server.host,
 
1359
                                       self.new_server.port)
 
1360
        self.old_server.redirections = \
 
1361
            [('(.*)', r'%s/1\1' % (new_prefix), 301),]
 
1362
        self.assertEquals('redirected once',t._perform(req).read())
 
1363
 
 
1364
    def test_five_redirections(self):
 
1365
        t = self.old_transport
 
1366
 
 
1367
        req = RedirectedRequest('GET', t.abspath('a'))
 
1368
        req.follow_redirections = True
 
1369
        old_prefix = 'http://%s:%s' % (self.old_server.host,
 
1370
                                       self.old_server.port)
 
1371
        new_prefix = 'http://%s:%s' % (self.new_server.host,
 
1372
                                       self.new_server.port)
 
1373
        self.old_server.redirections = [
 
1374
            ('/1(.*)', r'%s/2\1' % (old_prefix), 302),
 
1375
            ('/2(.*)', r'%s/3\1' % (old_prefix), 303),
 
1376
            ('/3(.*)', r'%s/4\1' % (old_prefix), 307),
 
1377
            ('/4(.*)', r'%s/5\1' % (new_prefix), 301),
 
1378
            ('(/[^/]+)', r'%s/1\1' % (old_prefix), 301),
 
1379
            ]
 
1380
        self.assertEquals('redirected 5 times',t._perform(req).read())
 
1381
 
 
1382
 
 
1383
class TestDoCatchRedirections(http_utils.TestCaseWithRedirectedWebserver):
 
1384
    """Test transport.do_catching_redirections."""
 
1385
 
 
1386
    def setUp(self):
 
1387
        super(TestDoCatchRedirections, self).setUp()
 
1388
        self.build_tree_contents([('a', '0123456789'),],)
 
1389
 
 
1390
        self.old_transport = self._transport(self.old_server.get_url())
 
1391
 
 
1392
    def get_a(self, transport):
 
1393
        return transport.get('a')
 
1394
 
 
1395
    def test_no_redirection(self):
 
1396
        t = self._transport(self.new_server.get_url())
 
1397
 
 
1398
        # We use None for redirected so that we fail if redirected
 
1399
        self.assertEquals('0123456789',
 
1400
                          transport.do_catching_redirections(
 
1401
                self.get_a, t, None).read())
 
1402
 
 
1403
    def test_one_redirection(self):
 
1404
        self.redirections = 0
 
1405
 
 
1406
        def redirected(transport, exception, redirection_notice):
 
1407
            self.redirections += 1
 
1408
            dir, file = urlutils.split(exception.target)
 
1409
            return self._transport(dir)
 
1410
 
 
1411
        self.assertEquals('0123456789',
 
1412
                          transport.do_catching_redirections(
 
1413
                self.get_a, self.old_transport, redirected).read())
 
1414
        self.assertEquals(1, self.redirections)
 
1415
 
 
1416
    def test_redirection_loop(self):
 
1417
 
 
1418
        def redirected(transport, exception, redirection_notice):
 
1419
            # By using the redirected url as a base dir for the
 
1420
            # *old* transport, we create a loop: a => a/a =>
 
1421
            # a/a/a
 
1422
            return self.old_transport.clone(exception.target)
 
1423
 
 
1424
        self.assertRaises(errors.TooManyRedirections,
 
1425
                          transport.do_catching_redirections,
 
1426
                          self.get_a, self.old_transport, redirected)
 
1427
 
 
1428
 
 
1429
class TestAuth(http_utils.TestCaseWithWebserver):
 
1430
    """Test authentication scheme"""
 
1431
 
 
1432
    _auth_header = 'Authorization'
 
1433
    _password_prompt_prefix = ''
 
1434
 
 
1435
    def setUp(self):
 
1436
        super(TestAuth, self).setUp()
 
1437
        self.server = self.get_readonly_server()
 
1438
        self.build_tree_contents([('a', 'contents of a\n'),
 
1439
                                  ('b', 'contents of b\n'),])
 
1440
 
 
1441
    def create_transport_readonly_server(self):
 
1442
        if self._auth_scheme == 'basic':
 
1443
            server = http_utils.HTTPBasicAuthServer(
 
1444
                protocol_version=self._protocol_version)
 
1445
        else:
 
1446
            if self._auth_scheme != 'digest':
 
1447
                raise AssertionError('Unknown auth scheme: %r'
 
1448
                                     % self._auth_scheme)
 
1449
            server = http_utils.HTTPDigestAuthServer(
 
1450
                protocol_version=self._protocol_version)
 
1451
        return server
 
1452
 
 
1453
    def _testing_pycurl(self):
 
1454
        return pycurl_present and self._transport == PyCurlTransport
 
1455
 
 
1456
    def get_user_url(self, user=None, password=None):
 
1457
        """Build an url embedding user and password"""
 
1458
        url = '%s://' % self.server._url_protocol
 
1459
        if user is not None:
 
1460
            url += user
 
1461
            if password is not None:
 
1462
                url += ':' + password
 
1463
            url += '@'
 
1464
        url += '%s:%s/' % (self.server.host, self.server.port)
 
1465
        return url
 
1466
 
 
1467
    def get_user_transport(self, user=None, password=None):
 
1468
        return self._transport(self.get_user_url(user, password))
 
1469
 
 
1470
    def test_no_user(self):
 
1471
        self.server.add_user('joe', 'foo')
 
1472
        t = self.get_user_transport()
 
1473
        self.assertRaises(errors.InvalidHttpResponse, t.get, 'a')
 
1474
        # Only one 'Authentication Required' error should occur
 
1475
        self.assertEqual(1, self.server.auth_required_errors)
 
1476
 
 
1477
    def test_empty_pass(self):
 
1478
        self.server.add_user('joe', '')
 
1479
        t = self.get_user_transport('joe', '')
 
1480
        self.assertEqual('contents of a\n', t.get('a').read())
 
1481
        # Only one 'Authentication Required' error should occur
 
1482
        self.assertEqual(1, self.server.auth_required_errors)
 
1483
 
 
1484
    def test_user_pass(self):
 
1485
        self.server.add_user('joe', 'foo')
 
1486
        t = self.get_user_transport('joe', 'foo')
 
1487
        self.assertEqual('contents of a\n', t.get('a').read())
 
1488
        # Only one 'Authentication Required' error should occur
 
1489
        self.assertEqual(1, self.server.auth_required_errors)
 
1490
 
 
1491
    def test_unknown_user(self):
 
1492
        self.server.add_user('joe', 'foo')
 
1493
        t = self.get_user_transport('bill', 'foo')
 
1494
        self.assertRaises(errors.InvalidHttpResponse, t.get, 'a')
 
1495
        # Two 'Authentication Required' errors should occur (the
 
1496
        # initial 'who are you' and 'I don't know you, who are
 
1497
        # you').
 
1498
        self.assertEqual(2, self.server.auth_required_errors)
 
1499
 
 
1500
    def test_wrong_pass(self):
 
1501
        self.server.add_user('joe', 'foo')
 
1502
        t = self.get_user_transport('joe', 'bar')
 
1503
        self.assertRaises(errors.InvalidHttpResponse, t.get, 'a')
 
1504
        # Two 'Authentication Required' errors should occur (the
 
1505
        # initial 'who are you' and 'this is not you, who are you')
 
1506
        self.assertEqual(2, self.server.auth_required_errors)
 
1507
 
 
1508
    def test_prompt_for_password(self):
 
1509
        if self._testing_pycurl():
 
1510
            raise tests.TestNotApplicable(
 
1511
                'pycurl cannot prompt, it handles auth by embedding'
 
1512
                ' user:pass in urls only')
 
1513
 
 
1514
        self.server.add_user('joe', 'foo')
 
1515
        t = self.get_user_transport('joe', None)
 
1516
        stdout = tests.StringIOWrapper()
 
1517
        ui.ui_factory = tests.TestUIFactory(stdin='foo\n', stdout=stdout)
 
1518
        self.assertEqual('contents of a\n',t.get('a').read())
 
1519
        # stdin should be empty
 
1520
        self.assertEqual('', ui.ui_factory.stdin.readline())
 
1521
        self._check_password_prompt(t._unqualified_scheme, 'joe',
 
1522
                                    stdout.getvalue())
 
1523
        # And we shouldn't prompt again for a different request
 
1524
        # against the same transport.
 
1525
        self.assertEqual('contents of b\n',t.get('b').read())
 
1526
        t2 = t.clone()
 
1527
        # And neither against a clone
 
1528
        self.assertEqual('contents of b\n',t2.get('b').read())
 
1529
        # Only one 'Authentication Required' error should occur
 
1530
        self.assertEqual(1, self.server.auth_required_errors)
 
1531
 
 
1532
    def _check_password_prompt(self, scheme, user, actual_prompt):
 
1533
        expected_prompt = (self._password_prompt_prefix
 
1534
                           + ("%s %s@%s:%d, Realm: '%s' password: "
 
1535
                              % (scheme.upper(),
 
1536
                                 user, self.server.host, self.server.port,
 
1537
                                 self.server.auth_realm)))
 
1538
        self.assertEquals(expected_prompt, actual_prompt)
 
1539
 
 
1540
    def test_no_prompt_for_password_when_using_auth_config(self):
 
1541
        if self._testing_pycurl():
 
1542
            raise tests.TestNotApplicable(
 
1543
                'pycurl does not support authentication.conf'
 
1544
                ' since it cannot prompt')
 
1545
 
 
1546
        user =' joe'
 
1547
        password = 'foo'
 
1548
        stdin_content = 'bar\n'  # Not the right password
 
1549
        self.server.add_user(user, password)
 
1550
        t = self.get_user_transport(user, None)
 
1551
        ui.ui_factory = tests.TestUIFactory(stdin=stdin_content,
 
1552
                                            stdout=tests.StringIOWrapper())
 
1553
        # Create a minimal config file with the right password
 
1554
        conf = config.AuthenticationConfig()
 
1555
        conf._get_config().update(
 
1556
            {'httptest': {'scheme': 'http', 'port': self.server.port,
 
1557
                          'user': user, 'password': password}})
 
1558
        conf._save()
 
1559
        # Issue a request to the server to connect
 
1560
        self.assertEqual('contents of a\n',t.get('a').read())
 
1561
        # stdin should have  been left untouched
 
1562
        self.assertEqual(stdin_content, ui.ui_factory.stdin.readline())
 
1563
        # Only one 'Authentication Required' error should occur
 
1564
        self.assertEqual(1, self.server.auth_required_errors)
 
1565
 
 
1566
    def test_changing_nonce(self):
 
1567
        if self._auth_scheme != 'digest':
 
1568
            raise tests.TestNotApplicable('HTTP auth digest only test')
 
1569
        if self._testing_pycurl():
 
1570
            raise tests.KnownFailure(
 
1571
                'pycurl does not handle a nonce change')
 
1572
        self.server.add_user('joe', 'foo')
 
1573
        t = self.get_user_transport('joe', 'foo')
 
1574
        self.assertEqual('contents of a\n', t.get('a').read())
 
1575
        self.assertEqual('contents of b\n', t.get('b').read())
 
1576
        # Only one 'Authentication Required' error should have
 
1577
        # occured so far
 
1578
        self.assertEqual(1, self.server.auth_required_errors)
 
1579
        # The server invalidates the current nonce
 
1580
        self.server.auth_nonce = self.server.auth_nonce + '. No, now!'
 
1581
        self.assertEqual('contents of a\n', t.get('a').read())
 
1582
        # Two 'Authentication Required' errors should occur (the
 
1583
        # initial 'who are you' and a second 'who are you' with the new nonce)
 
1584
        self.assertEqual(2, self.server.auth_required_errors)
 
1585
 
 
1586
 
 
1587
 
 
1588
class TestProxyAuth(TestAuth):
 
1589
    """Test proxy authentication schemes."""
 
1590
 
 
1591
    _auth_header = 'Proxy-authorization'
 
1592
    _password_prompt_prefix='Proxy '
 
1593
 
 
1594
    def setUp(self):
 
1595
        super(TestProxyAuth, self).setUp()
 
1596
        self._old_env = {}
 
1597
        self.addCleanup(self._restore_env)
 
1598
        # Override the contents to avoid false positives
 
1599
        self.build_tree_contents([('a', 'not proxied contents of a\n'),
 
1600
                                  ('b', 'not proxied contents of b\n'),
 
1601
                                  ('a-proxied', 'contents of a\n'),
 
1602
                                  ('b-proxied', 'contents of b\n'),
 
1603
                                  ])
 
1604
 
 
1605
    def create_transport_readonly_server(self):
 
1606
        if self._auth_scheme == 'basic':
 
1607
            server = http_utils.ProxyBasicAuthServer(
 
1608
                protocol_version=self._protocol_version)
 
1609
        else:
 
1610
            if self._auth_scheme != 'digest':
 
1611
                raise AssertionError('Unknown auth scheme: %r'
 
1612
                                     % self._auth_scheme)
 
1613
            server = http_utils.ProxyDigestAuthServer(
 
1614
                protocol_version=self._protocol_version)
 
1615
        return server
 
1616
 
 
1617
    def get_user_transport(self, user=None, password=None):
 
1618
        self._install_env({'all_proxy': self.get_user_url(user, password)})
 
1619
        return self._transport(self.server.get_url())
 
1620
 
 
1621
    def _install_env(self, env):
 
1622
        for name, value in env.iteritems():
 
1623
            self._old_env[name] = osutils.set_or_unset_env(name, value)
 
1624
 
 
1625
    def _restore_env(self):
 
1626
        for name, value in self._old_env.iteritems():
 
1627
            osutils.set_or_unset_env(name, value)
 
1628
 
 
1629
    def test_empty_pass(self):
 
1630
        if self._testing_pycurl():
 
1631
            import pycurl
 
1632
            if pycurl.version_info()[1] < '7.16.0':
 
1633
                raise tests.KnownFailure(
 
1634
                    'pycurl < 7.16.0 does not handle empty proxy passwords')
 
1635
        super(TestProxyAuth, self).test_empty_pass()
 
1636
 
 
1637
 
 
1638
class SampleSocket(object):
 
1639
    """A socket-like object for use in testing the HTTP request handler."""
 
1640
 
 
1641
    def __init__(self, socket_read_content):
 
1642
        """Constructs a sample socket.
 
1643
 
 
1644
        :param socket_read_content: a byte sequence
 
1645
        """
 
1646
        # Use plain python StringIO so we can monkey-patch the close method to
 
1647
        # not discard the contents.
 
1648
        from StringIO import StringIO
 
1649
        self.readfile = StringIO(socket_read_content)
 
1650
        self.writefile = StringIO()
 
1651
        self.writefile.close = lambda: None
 
1652
 
 
1653
    def makefile(self, mode='r', bufsize=None):
 
1654
        if 'r' in mode:
 
1655
            return self.readfile
 
1656
        else:
 
1657
            return self.writefile
 
1658
 
 
1659
 
 
1660
class SmartHTTPTunnellingTest(tests.TestCaseWithTransport):
 
1661
 
 
1662
    def setUp(self):
 
1663
        super(SmartHTTPTunnellingTest, self).setUp()
 
1664
        # We use the VFS layer as part of HTTP tunnelling tests.
 
1665
        self._captureVar('BZR_NO_SMART_VFS', None)
 
1666
        self.transport_readonly_server = http_utils.HTTPServerWithSmarts
 
1667
 
 
1668
    def create_transport_readonly_server(self):
 
1669
        return http_utils.HTTPServerWithSmarts(
 
1670
            protocol_version=self._protocol_version)
 
1671
 
 
1672
    def test_open_bzrdir(self):
 
1673
        branch = self.make_branch('relpath')
 
1674
        http_server = self.get_readonly_server()
 
1675
        url = http_server.get_url() + 'relpath'
 
1676
        bd = bzrdir.BzrDir.open(url)
 
1677
        self.assertIsInstance(bd, _mod_remote.RemoteBzrDir)
 
1678
 
 
1679
    def test_bulk_data(self):
 
1680
        # We should be able to send and receive bulk data in a single message.
 
1681
        # The 'readv' command in the smart protocol both sends and receives
 
1682
        # bulk data, so we use that.
 
1683
        self.build_tree(['data-file'])
 
1684
        http_server = self.get_readonly_server()
 
1685
        http_transport = self._transport(http_server.get_url())
 
1686
        medium = http_transport.get_smart_medium()
 
1687
        # Since we provide the medium, the url below will be mostly ignored
 
1688
        # during the test, as long as the path is '/'.
 
1689
        remote_transport = remote.RemoteTransport('bzr://fake_host/',
 
1690
                                                  medium=medium)
 
1691
        self.assertEqual(
 
1692
            [(0, "c")], list(remote_transport.readv("data-file", [(0,1)])))
 
1693
 
 
1694
    def test_http_send_smart_request(self):
 
1695
 
 
1696
        post_body = 'hello\n'
 
1697
        expected_reply_body = 'ok\x012\n'
 
1698
 
 
1699
        http_server = self.get_readonly_server()
 
1700
        http_transport = self._transport(http_server.get_url())
 
1701
        medium = http_transport.get_smart_medium()
 
1702
        response = medium.send_http_smart_request(post_body)
 
1703
        reply_body = response.read()
 
1704
        self.assertEqual(expected_reply_body, reply_body)
 
1705
 
 
1706
    def test_smart_http_server_post_request_handler(self):
 
1707
        httpd = self.get_readonly_server()._get_httpd()
 
1708
 
 
1709
        socket = SampleSocket(
 
1710
            'POST /.bzr/smart %s \r\n' % self._protocol_version
 
1711
            # HTTP/1.1 posts must have a Content-Length (but it doesn't hurt
 
1712
            # for 1.0)
 
1713
            + 'Content-Length: 6\r\n'
 
1714
            '\r\n'
 
1715
            'hello\n')
 
1716
        # Beware: the ('localhost', 80) below is the
 
1717
        # client_address parameter, but we don't have one because
 
1718
        # we have defined a socket which is not bound to an
 
1719
        # address. The test framework never uses this client
 
1720
        # address, so far...
 
1721
        request_handler = http_utils.SmartRequestHandler(socket,
 
1722
                                                         ('localhost', 80),
 
1723
                                                         httpd)
 
1724
        response = socket.writefile.getvalue()
 
1725
        self.assertStartsWith(response, '%s 200 ' % self._protocol_version)
 
1726
        # This includes the end of the HTTP headers, and all the body.
 
1727
        expected_end_of_response = '\r\n\r\nok\x012\n'
 
1728
        self.assertEndsWith(response, expected_end_of_response)
 
1729
 
 
1730
 
 
1731
class ForbiddenRequestHandler(http_server.TestingHTTPRequestHandler):
 
1732
    """No smart server here request handler."""
 
1733
 
 
1734
    def do_POST(self):
 
1735
        self.send_error(403, "Forbidden")
 
1736
 
 
1737
 
 
1738
class SmartClientAgainstNotSmartServer(TestSpecificRequestHandler):
 
1739
    """Test smart client behaviour against an http server without smarts."""
 
1740
 
 
1741
    _req_handler_class = ForbiddenRequestHandler
 
1742
 
 
1743
    def test_probe_smart_server(self):
 
1744
        """Test error handling against server refusing smart requests."""
 
1745
        server = self.get_readonly_server()
 
1746
        t = self._transport(server.get_url())
 
1747
        # No need to build a valid smart request here, the server will not even
 
1748
        # try to interpret it.
 
1749
        self.assertRaises(errors.SmartProtocolError,
 
1750
                          t.send_http_smart_request, 'whatever')
 
1751