~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/tests/test_server.py

  • Committer: Patch Queue Manager
  • Date: 2012-01-18 16:23:31 UTC
  • mfrom: (6439.1.1 work)
  • Revision ID: pqm@pqm.ubuntu.com-20120118162331-md4sf1tw6hyuw344
(vila) Ensure people get an easy access to the release details from
 announcements. (Vincent Ladeuil)

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# Copyright (C) 2010, 2011 Canonical Ltd
 
2
#
 
3
# This program is free software; you can redistribute it and/or modify
 
4
# it under the terms of the GNU General Public License as published by
 
5
# the Free Software Foundation; either version 2 of the License, or
 
6
# (at your option) any later version.
 
7
#
 
8
# This program is distributed in the hope that it will be useful,
 
9
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
10
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
11
# GNU General Public License for more details.
 
12
#
 
13
# You should have received a copy of the GNU General Public License
 
14
# along with this program; if not, write to the Free Software
 
15
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 
16
 
 
17
import errno
 
18
import socket
 
19
import SocketServer
 
20
import sys
 
21
import threading
 
22
import traceback
 
23
 
 
24
 
 
25
from bzrlib import (
 
26
    cethread,
 
27
    errors,
 
28
    osutils,
 
29
    transport,
 
30
    urlutils,
 
31
    )
 
32
from bzrlib.transport import (
 
33
    chroot,
 
34
    pathfilter,
 
35
    )
 
36
from bzrlib.smart import (
 
37
    medium,
 
38
    server,
 
39
    )
 
40
 
 
41
 
 
42
def debug_threads():
 
43
    # FIXME: There is a dependency loop between bzrlib.tests and
 
44
    # bzrlib.tests.test_server that needs to be fixed. In the mean time
 
45
    # defining this function is enough for our needs. -- vila 20100611
 
46
    from bzrlib import tests
 
47
    return 'threads' in tests.selftest_debug_flags
 
48
 
 
49
 
 
50
class TestServer(transport.Server):
 
51
    """A Transport Server dedicated to tests.
 
52
 
 
53
    The TestServer interface provides a server for a given transport. We use
 
54
    these servers as loopback testing tools. For any given transport the
 
55
    Servers it provides must either allow writing, or serve the contents
 
56
    of os.getcwdu() at the time start_server is called.
 
57
 
 
58
    Note that these are real servers - they must implement all the things
 
59
    that we want bzr transports to take advantage of.
 
60
    """
 
61
 
 
62
    def get_url(self):
 
63
        """Return a url for this server.
 
64
 
 
65
        If the transport does not represent a disk directory (i.e. it is
 
66
        a database like svn, or a memory only transport, it should return
 
67
        a connection to a newly established resource for this Server.
 
68
        Otherwise it should return a url that will provide access to the path
 
69
        that was os.getcwdu() when start_server() was called.
 
70
 
 
71
        Subsequent calls will return the same resource.
 
72
        """
 
73
        raise NotImplementedError
 
74
 
 
75
    def get_bogus_url(self):
 
76
        """Return a url for this protocol, that will fail to connect.
 
77
 
 
78
        This may raise NotImplementedError to indicate that this server cannot
 
79
        provide bogus urls.
 
80
        """
 
81
        raise NotImplementedError
 
82
 
 
83
 
 
84
class LocalURLServer(TestServer):
 
85
    """A pretend server for local transports, using file:// urls.
 
86
 
 
87
    Of course no actual server is required to access the local filesystem, so
 
88
    this just exists to tell the test code how to get to it.
 
89
    """
 
90
 
 
91
    def start_server(self):
 
92
        pass
 
93
 
 
94
    def get_url(self):
 
95
        """See Transport.Server.get_url."""
 
96
        return urlutils.local_path_to_url('')
 
97
 
 
98
 
 
99
class DecoratorServer(TestServer):
 
100
    """Server for the TransportDecorator for testing with.
 
101
 
 
102
    To use this when subclassing TransportDecorator, override override the
 
103
    get_decorator_class method.
 
104
    """
 
105
 
 
106
    def start_server(self, server=None):
 
107
        """See bzrlib.transport.Server.start_server.
 
108
 
 
109
        :server: decorate the urls given by server. If not provided a
 
110
        LocalServer is created.
 
111
        """
 
112
        if server is not None:
 
113
            self._made_server = False
 
114
            self._server = server
 
115
        else:
 
116
            self._made_server = True
 
117
            self._server = LocalURLServer()
 
118
            self._server.start_server()
 
119
 
 
120
    def stop_server(self):
 
121
        if self._made_server:
 
122
            self._server.stop_server()
 
123
 
 
124
    def get_decorator_class(self):
 
125
        """Return the class of the decorators we should be constructing."""
 
126
        raise NotImplementedError(self.get_decorator_class)
 
127
 
 
128
    def get_url_prefix(self):
 
129
        """What URL prefix does this decorator produce?"""
 
130
        return self.get_decorator_class()._get_url_prefix()
 
131
 
 
132
    def get_bogus_url(self):
 
133
        """See bzrlib.transport.Server.get_bogus_url."""
 
134
        return self.get_url_prefix() + self._server.get_bogus_url()
 
135
 
 
136
    def get_url(self):
 
137
        """See bzrlib.transport.Server.get_url."""
 
138
        return self.get_url_prefix() + self._server.get_url()
 
139
 
 
140
 
 
141
class BrokenRenameServer(DecoratorServer):
 
142
    """Server for the BrokenRenameTransportDecorator for testing with."""
 
143
 
 
144
    def get_decorator_class(self):
 
145
        from bzrlib.transport import brokenrename
 
146
        return brokenrename.BrokenRenameTransportDecorator
 
147
 
 
148
 
 
149
class FakeNFSServer(DecoratorServer):
 
150
    """Server for the FakeNFSTransportDecorator for testing with."""
 
151
 
 
152
    def get_decorator_class(self):
 
153
        from bzrlib.transport import fakenfs
 
154
        return fakenfs.FakeNFSTransportDecorator
 
155
 
 
156
 
 
157
class FakeVFATServer(DecoratorServer):
 
158
    """A server that suggests connections through FakeVFATTransportDecorator
 
159
 
 
160
    For use in testing.
 
161
    """
 
162
 
 
163
    def get_decorator_class(self):
 
164
        from bzrlib.transport import fakevfat
 
165
        return fakevfat.FakeVFATTransportDecorator
 
166
 
 
167
 
 
168
class LogDecoratorServer(DecoratorServer):
 
169
    """Server for testing."""
 
170
 
 
171
    def get_decorator_class(self):
 
172
        from bzrlib.transport import log
 
173
        return log.TransportLogDecorator
 
174
 
 
175
 
 
176
class NoSmartTransportServer(DecoratorServer):
 
177
    """Server for the NoSmartTransportDecorator for testing with."""
 
178
 
 
179
    def get_decorator_class(self):
 
180
        from bzrlib.transport import nosmart
 
181
        return nosmart.NoSmartTransportDecorator
 
182
 
 
183
 
 
184
class ReadonlyServer(DecoratorServer):
 
185
    """Server for the ReadonlyTransportDecorator for testing with."""
 
186
 
 
187
    def get_decorator_class(self):
 
188
        from bzrlib.transport import readonly
 
189
        return readonly.ReadonlyTransportDecorator
 
190
 
 
191
 
 
192
class TraceServer(DecoratorServer):
 
193
    """Server for the TransportTraceDecorator for testing with."""
 
194
 
 
195
    def get_decorator_class(self):
 
196
        from bzrlib.transport import trace
 
197
        return trace.TransportTraceDecorator
 
198
 
 
199
 
 
200
class UnlistableServer(DecoratorServer):
 
201
    """Server for the UnlistableTransportDecorator for testing with."""
 
202
 
 
203
    def get_decorator_class(self):
 
204
        from bzrlib.transport import unlistable
 
205
        return unlistable.UnlistableTransportDecorator
 
206
 
 
207
 
 
208
class TestingPathFilteringServer(pathfilter.PathFilteringServer):
 
209
 
 
210
    def __init__(self):
 
211
        """TestingPathFilteringServer is not usable until start_server
 
212
        is called."""
 
213
 
 
214
    def start_server(self, backing_server=None):
 
215
        """Setup the Chroot on backing_server."""
 
216
        if backing_server is not None:
 
217
            self.backing_transport = transport.get_transport_from_url(
 
218
                backing_server.get_url())
 
219
        else:
 
220
            self.backing_transport = transport.get_transport_from_path('.')
 
221
        self.backing_transport.clone('added-by-filter').ensure_base()
 
222
        self.filter_func = lambda x: 'added-by-filter/' + x
 
223
        super(TestingPathFilteringServer, self).start_server()
 
224
 
 
225
    def get_bogus_url(self):
 
226
        raise NotImplementedError
 
227
 
 
228
 
 
229
class TestingChrootServer(chroot.ChrootServer):
 
230
 
 
231
    def __init__(self):
 
232
        """TestingChrootServer is not usable until start_server is called."""
 
233
        super(TestingChrootServer, self).__init__(None)
 
234
 
 
235
    def start_server(self, backing_server=None):
 
236
        """Setup the Chroot on backing_server."""
 
237
        if backing_server is not None:
 
238
            self.backing_transport = transport.get_transport_from_url(
 
239
                backing_server.get_url())
 
240
        else:
 
241
            self.backing_transport = transport.get_transport_from_path('.')
 
242
        super(TestingChrootServer, self).start_server()
 
243
 
 
244
    def get_bogus_url(self):
 
245
        raise NotImplementedError
 
246
 
 
247
 
 
248
class TestThread(cethread.CatchingExceptionThread):
 
249
 
 
250
    def join(self, timeout=5):
 
251
        """Overrides to use a default timeout.
 
252
 
 
253
        The default timeout is set to 5 and should expire only when a thread
 
254
        serving a client connection is hung.
 
255
        """
 
256
        super(TestThread, self).join(timeout)
 
257
        if timeout and self.isAlive():
 
258
            # The timeout expired without joining the thread, the thread is
 
259
            # therefore stucked and that's a failure as far as the test is
 
260
            # concerned. We used to hang here.
 
261
 
 
262
            # FIXME: we need to kill the thread, but as far as the test is
 
263
            # concerned, raising an assertion is too strong. On most of the
 
264
            # platforms, this doesn't occur, so just mentioning the problem is
 
265
            # enough for now -- vila 2010824
 
266
            sys.stderr.write('thread %s hung\n' % (self.name,))
 
267
            #raise AssertionError('thread %s hung' % (self.name,))
 
268
 
 
269
 
 
270
class TestingTCPServerMixin(object):
 
271
    """Mixin to support running SocketServer.TCPServer in a thread.
 
272
 
 
273
    Tests are connecting from the main thread, the server has to be run in a
 
274
    separate thread.
 
275
    """
 
276
 
 
277
    def __init__(self):
 
278
        self.started = threading.Event()
 
279
        self.serving = None
 
280
        self.stopped = threading.Event()
 
281
        # We collect the resources used by the clients so we can release them
 
282
        # when shutting down
 
283
        self.clients = []
 
284
        self.ignored_exceptions = None
 
285
 
 
286
    def server_bind(self):
 
287
        self.socket.bind(self.server_address)
 
288
        self.server_address = self.socket.getsockname()
 
289
 
 
290
    def serve(self):
 
291
        self.serving = True
 
292
        # We are listening and ready to accept connections
 
293
        self.started.set()
 
294
        try:
 
295
            while self.serving:
 
296
                # Really a connection but the python framework is generic and
 
297
                # call them requests
 
298
                self.handle_request()
 
299
            # Let's close the listening socket
 
300
            self.server_close()
 
301
        finally:
 
302
            self.stopped.set()
 
303
 
 
304
    def handle_request(self):
 
305
        """Handle one request.
 
306
 
 
307
        The python version swallows some socket exceptions and we don't use
 
308
        timeout, so we override it to better control the server behavior.
 
309
        """
 
310
        request, client_address = self.get_request()
 
311
        if self.verify_request(request, client_address):
 
312
            try:
 
313
                self.process_request(request, client_address)
 
314
            except:
 
315
                self.handle_error(request, client_address)
 
316
        else:
 
317
            self.close_request(request)
 
318
 
 
319
    def get_request(self):
 
320
        return self.socket.accept()
 
321
 
 
322
    def verify_request(self, request, client_address):
 
323
        """Verify the request.
 
324
 
 
325
        Return True if we should proceed with this request, False if we should
 
326
        not even touch a single byte in the socket ! This is useful when we
 
327
        stop the server with a dummy last connection.
 
328
        """
 
329
        return self.serving
 
330
 
 
331
    def handle_error(self, request, client_address):
 
332
        # Stop serving and re-raise the last exception seen
 
333
        self.serving = False
 
334
        # The following can be used for debugging purposes, it will display the
 
335
        # exception and the traceback just when it occurs instead of waiting
 
336
        # for the thread to be joined.
 
337
        # SocketServer.BaseServer.handle_error(self, request, client_address)
 
338
 
 
339
        # We call close_request manually, because we are going to raise an
 
340
        # exception. The SocketServer implementation calls:
 
341
        #   handle_error(...)
 
342
        #   close_request(...)
 
343
        # But because we raise the exception, close_request will never be
 
344
        # triggered. This helps client not block waiting for a response when
 
345
        # the server gets an exception.
 
346
        self.close_request(request)
 
347
        raise
 
348
 
 
349
    def ignored_exceptions_during_shutdown(self, e):
 
350
        if sys.platform == 'win32':
 
351
            accepted_errnos = [errno.EBADF,
 
352
                               errno.EPIPE,
 
353
                               errno.WSAEBADF,
 
354
                               errno.WSAECONNRESET,
 
355
                               errno.WSAENOTCONN,
 
356
                               errno.WSAESHUTDOWN,
 
357
                               ]
 
358
        else:
 
359
            accepted_errnos = [errno.EBADF,
 
360
                               errno.ECONNRESET,
 
361
                               errno.ENOTCONN,
 
362
                               errno.EPIPE,
 
363
                               ]
 
364
        if isinstance(e, socket.error) and e[0] in accepted_errnos:
 
365
            return True
 
366
        return False
 
367
 
 
368
    # The following methods are called by the main thread
 
369
 
 
370
    def stop_client_connections(self):
 
371
        while self.clients:
 
372
            c = self.clients.pop()
 
373
            self.shutdown_client(c)
 
374
 
 
375
    def shutdown_socket(self, sock):
 
376
        """Properly shutdown a socket.
 
377
 
 
378
        This should be called only when no other thread is trying to use the
 
379
        socket.
 
380
        """
 
381
        try:
 
382
            sock.shutdown(socket.SHUT_RDWR)
 
383
            sock.close()
 
384
        except Exception, e:
 
385
            if self.ignored_exceptions(e):
 
386
                pass
 
387
            else:
 
388
                raise
 
389
 
 
390
    # The following methods are called by the main thread
 
391
 
 
392
    def set_ignored_exceptions(self, thread, ignored_exceptions):
 
393
        self.ignored_exceptions = ignored_exceptions
 
394
        thread.set_ignored_exceptions(self.ignored_exceptions)
 
395
 
 
396
    def _pending_exception(self, thread):
 
397
        """Raise server uncaught exception.
 
398
 
 
399
        Daughter classes can override this if they use daughter threads.
 
400
        """
 
401
        thread.pending_exception()
 
402
 
 
403
 
 
404
class TestingTCPServer(TestingTCPServerMixin, SocketServer.TCPServer):
 
405
 
 
406
    def __init__(self, server_address, request_handler_class):
 
407
        TestingTCPServerMixin.__init__(self)
 
408
        SocketServer.TCPServer.__init__(self, server_address,
 
409
                                        request_handler_class)
 
410
 
 
411
    def get_request(self):
 
412
        """Get the request and client address from the socket."""
 
413
        sock, addr = TestingTCPServerMixin.get_request(self)
 
414
        self.clients.append((sock, addr))
 
415
        return sock, addr
 
416
 
 
417
    # The following methods are called by the main thread
 
418
 
 
419
    def shutdown_client(self, client):
 
420
        sock, addr = client
 
421
        self.shutdown_socket(sock)
 
422
 
 
423
 
 
424
class TestingThreadingTCPServer(TestingTCPServerMixin,
 
425
                                SocketServer.ThreadingTCPServer):
 
426
 
 
427
    def __init__(self, server_address, request_handler_class):
 
428
        TestingTCPServerMixin.__init__(self)
 
429
        SocketServer.ThreadingTCPServer.__init__(self, server_address,
 
430
                                                 request_handler_class)
 
431
 
 
432
    def get_request(self):
 
433
        """Get the request and client address from the socket."""
 
434
        sock, addr = TestingTCPServerMixin.get_request(self)
 
435
        # The thread is not created yet, it will be updated in process_request
 
436
        self.clients.append((sock, addr, None))
 
437
        return sock, addr
 
438
 
 
439
    def process_request_thread(self, started, detached, stopped,
 
440
                               request, client_address):
 
441
        started.set()
 
442
        # We will be on our own once the server tells us we're detached
 
443
        detached.wait()
 
444
        SocketServer.ThreadingTCPServer.process_request_thread(
 
445
            self, request, client_address)
 
446
        self.close_request(request)
 
447
        stopped.set()
 
448
 
 
449
    def process_request(self, request, client_address):
 
450
        """Start a new thread to process the request."""
 
451
        started = threading.Event()
 
452
        detached = threading.Event()
 
453
        stopped = threading.Event()
 
454
        t = TestThread(
 
455
            sync_event=stopped,
 
456
            name='%s -> %s' % (client_address, self.server_address),
 
457
            target = self.process_request_thread,
 
458
            args = (started, detached, stopped, request, client_address))
 
459
        # Update the client description
 
460
        self.clients.pop()
 
461
        self.clients.append((request, client_address, t))
 
462
        # Propagate the exception handler since we must use the same one as
 
463
        # TestingTCPServer for connections running in their own threads.
 
464
        t.set_ignored_exceptions(self.ignored_exceptions)
 
465
        t.start()
 
466
        started.wait()
 
467
        # If an exception occured during the thread start, it will get raised.
 
468
        t.pending_exception()
 
469
        if debug_threads():
 
470
            sys.stderr.write('Client thread %s started\n' % (t.name,))
 
471
        # Tell the thread, it's now on its own for exception handling.
 
472
        detached.set()
 
473
 
 
474
    # The following methods are called by the main thread
 
475
 
 
476
    def shutdown_client(self, client):
 
477
        sock, addr, connection_thread = client
 
478
        self.shutdown_socket(sock)
 
479
        if connection_thread is not None:
 
480
            # The thread has been created only if the request is processed but
 
481
            # after the connection is inited. This could happen during server
 
482
            # shutdown. If an exception occurred in the thread it will be
 
483
            # re-raised
 
484
            if debug_threads():
 
485
                sys.stderr.write('Client thread %s will be joined\n'
 
486
                                 % (connection_thread.name,))
 
487
            connection_thread.join()
 
488
 
 
489
    def set_ignored_exceptions(self, thread, ignored_exceptions):
 
490
        TestingTCPServerMixin.set_ignored_exceptions(self, thread,
 
491
                                                     ignored_exceptions)
 
492
        for sock, addr, connection_thread in self.clients:
 
493
            if connection_thread is not None:
 
494
                connection_thread.set_ignored_exceptions(
 
495
                    self.ignored_exceptions)
 
496
 
 
497
    def _pending_exception(self, thread):
 
498
        for sock, addr, connection_thread in self.clients:
 
499
            if connection_thread is not None:
 
500
                connection_thread.pending_exception()
 
501
        TestingTCPServerMixin._pending_exception(self, thread)
 
502
 
 
503
 
 
504
class TestingTCPServerInAThread(transport.Server):
 
505
    """A server in a thread that re-raise thread exceptions."""
 
506
 
 
507
    def __init__(self, server_address, server_class, request_handler_class):
 
508
        self.server_class = server_class
 
509
        self.request_handler_class = request_handler_class
 
510
        self.host, self.port = server_address
 
511
        self.server = None
 
512
        self._server_thread = None
 
513
 
 
514
    def __repr__(self):
 
515
        return "%s(%s:%s)" % (self.__class__.__name__, self.host, self.port)
 
516
 
 
517
    def create_server(self):
 
518
        return self.server_class((self.host, self.port),
 
519
                                 self.request_handler_class)
 
520
 
 
521
    def start_server(self):
 
522
        self.server = self.create_server()
 
523
        self._server_thread = TestThread(
 
524
            sync_event=self.server.started,
 
525
            target=self.run_server)
 
526
        self._server_thread.start()
 
527
        # Wait for the server thread to start (i.e. release the lock)
 
528
        self.server.started.wait()
 
529
        # Get the real address, especially the port
 
530
        self.host, self.port = self.server.server_address
 
531
        self._server_thread.name = self.server.server_address
 
532
        if debug_threads():
 
533
            sys.stderr.write('Server thread %s started\n'
 
534
                             % (self._server_thread.name,))
 
535
        # If an exception occured during the server start, it will get raised,
 
536
        # otherwise, the server is blocked on its accept() call.
 
537
        self._server_thread.pending_exception()
 
538
        # From now on, we'll use a different event to ensure the server can set
 
539
        # its exception
 
540
        self._server_thread.set_sync_event(self.server.stopped)
 
541
 
 
542
    def run_server(self):
 
543
        self.server.serve()
 
544
 
 
545
    def stop_server(self):
 
546
        if self.server is None:
 
547
            return
 
548
        try:
 
549
            # The server has been started successfully, shut it down now.  As
 
550
            # soon as we stop serving, no more connection are accepted except
 
551
            # one to get out of the blocking listen.
 
552
            self.set_ignored_exceptions(
 
553
                self.server.ignored_exceptions_during_shutdown)
 
554
            self.server.serving = False
 
555
            if debug_threads():
 
556
                sys.stderr.write('Server thread %s will be joined\n'
 
557
                                 % (self._server_thread.name,))
 
558
            # The server is listening for a last connection, let's give it:
 
559
            last_conn = None
 
560
            try:
 
561
                last_conn = osutils.connect_socket((self.host, self.port))
 
562
            except socket.error, e:
 
563
                # But ignore connection errors as the point is to unblock the
 
564
                # server thread, it may happen that it's not blocked or even
 
565
                # not started.
 
566
                pass
 
567
            # We start shutting down the clients while the server itself is
 
568
            # shutting down.
 
569
            self.server.stop_client_connections()
 
570
            # Now we wait for the thread running self.server.serve() to finish
 
571
            self.server.stopped.wait()
 
572
            if last_conn is not None:
 
573
                # Close the last connection without trying to use it. The
 
574
                # server will not process a single byte on that socket to avoid
 
575
                # complications (SSL starts with a handshake for example).
 
576
                last_conn.close()
 
577
            # Check for any exception that could have occurred in the server
 
578
            # thread
 
579
            try:
 
580
                self._server_thread.join()
 
581
            except Exception, e:
 
582
                if self.server.ignored_exceptions(e):
 
583
                    pass
 
584
                else:
 
585
                    raise
 
586
        finally:
 
587
            # Make sure we can be called twice safely, note that this means
 
588
            # that we will raise a single exception even if several occurred in
 
589
            # the various threads involved.
 
590
            self.server = None
 
591
 
 
592
    def set_ignored_exceptions(self, ignored_exceptions):
 
593
        """Install an exception handler for the server."""
 
594
        self.server.set_ignored_exceptions(self._server_thread,
 
595
                                           ignored_exceptions)
 
596
 
 
597
    def pending_exception(self):
 
598
        """Raise uncaught exception in the server."""
 
599
        self.server._pending_exception(self._server_thread)
 
600
 
 
601
 
 
602
class TestingSmartConnectionHandler(SocketServer.BaseRequestHandler,
 
603
                                    medium.SmartServerSocketStreamMedium):
 
604
 
 
605
    def __init__(self, request, client_address, server):
 
606
        medium.SmartServerSocketStreamMedium.__init__(
 
607
            self, request, server.backing_transport,
 
608
            server.root_client_path,
 
609
            timeout=_DEFAULT_TESTING_CLIENT_TIMEOUT)
 
610
        request.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
 
611
        SocketServer.BaseRequestHandler.__init__(self, request, client_address,
 
612
                                                 server)
 
613
 
 
614
    def handle(self):
 
615
        try:
 
616
            while not self.finished:
 
617
                server_protocol = self._build_protocol()
 
618
                self._serve_one_request(server_protocol)
 
619
        except errors.ConnectionTimeout:
 
620
            # idle connections aren't considered a failure of the server
 
621
            return
 
622
 
 
623
 
 
624
_DEFAULT_TESTING_CLIENT_TIMEOUT = 60.0
 
625
 
 
626
class TestingSmartServer(TestingThreadingTCPServer, server.SmartTCPServer):
 
627
 
 
628
    def __init__(self, server_address, request_handler_class,
 
629
                 backing_transport, root_client_path):
 
630
        TestingThreadingTCPServer.__init__(self, server_address,
 
631
                                           request_handler_class)
 
632
        server.SmartTCPServer.__init__(self, backing_transport,
 
633
            root_client_path, client_timeout=_DEFAULT_TESTING_CLIENT_TIMEOUT)
 
634
 
 
635
    def serve(self):
 
636
        self.run_server_started_hooks()
 
637
        try:
 
638
            TestingThreadingTCPServer.serve(self)
 
639
        finally:
 
640
            self.run_server_stopped_hooks()
 
641
 
 
642
    def get_url(self):
 
643
        """Return the url of the server"""
 
644
        return "bzr://%s:%d/" % self.server_address
 
645
 
 
646
 
 
647
class SmartTCPServer_for_testing(TestingTCPServerInAThread):
 
648
    """Server suitable for use by transport tests.
 
649
 
 
650
    This server is backed by the process's cwd.
 
651
    """
 
652
    def __init__(self, thread_name_suffix=''):
 
653
        self.client_path_extra = None
 
654
        self.thread_name_suffix = thread_name_suffix
 
655
        self.host = '127.0.0.1'
 
656
        self.port = 0
 
657
        super(SmartTCPServer_for_testing, self).__init__(
 
658
                (self.host, self.port),
 
659
                TestingSmartServer,
 
660
                TestingSmartConnectionHandler)
 
661
 
 
662
    def create_server(self):
 
663
        return self.server_class((self.host, self.port),
 
664
                                 self.request_handler_class,
 
665
                                 self.backing_transport,
 
666
                                 self.root_client_path)
 
667
 
 
668
 
 
669
    def start_server(self, backing_transport_server=None,
 
670
                     client_path_extra='/extra/'):
 
671
        """Set up server for testing.
 
672
 
 
673
        :param backing_transport_server: backing server to use.  If not
 
674
            specified, a LocalURLServer at the current working directory will
 
675
            be used.
 
676
        :param client_path_extra: a path segment starting with '/' to append to
 
677
            the root URL for this server.  For instance, a value of '/foo/bar/'
 
678
            will mean the root of the backing transport will be published at a
 
679
            URL like `bzr://127.0.0.1:nnnn/foo/bar/`, rather than
 
680
            `bzr://127.0.0.1:nnnn/`.  Default value is `extra`, so that tests
 
681
            by default will fail unless they do the necessary path translation.
 
682
        """
 
683
        if not client_path_extra.startswith('/'):
 
684
            raise ValueError(client_path_extra)
 
685
        self.root_client_path = self.client_path_extra = client_path_extra
 
686
        from bzrlib.transport.chroot import ChrootServer
 
687
        if backing_transport_server is None:
 
688
            backing_transport_server = LocalURLServer()
 
689
        self.chroot_server = ChrootServer(
 
690
            self.get_backing_transport(backing_transport_server))
 
691
        self.chroot_server.start_server()
 
692
        self.backing_transport = transport.get_transport_from_url(
 
693
            self.chroot_server.get_url())
 
694
        super(SmartTCPServer_for_testing, self).start_server()
 
695
 
 
696
    def stop_server(self):
 
697
        try:
 
698
            super(SmartTCPServer_for_testing, self).stop_server()
 
699
        finally:
 
700
            self.chroot_server.stop_server()
 
701
 
 
702
    def get_backing_transport(self, backing_transport_server):
 
703
        """Get a backing transport from a server we are decorating."""
 
704
        return transport.get_transport_from_url(
 
705
            backing_transport_server.get_url())
 
706
 
 
707
    def get_url(self):
 
708
        url = self.server.get_url()
 
709
        return url[:-1] + self.client_path_extra
 
710
 
 
711
    def get_bogus_url(self):
 
712
        """Return a URL which will fail to connect"""
 
713
        return 'bzr://127.0.0.1:1/'
 
714
 
 
715
 
 
716
class ReadonlySmartTCPServer_for_testing(SmartTCPServer_for_testing):
 
717
    """Get a readonly server for testing."""
 
718
 
 
719
    def get_backing_transport(self, backing_transport_server):
 
720
        """Get a backing transport from a server we are decorating."""
 
721
        url = 'readonly+' + backing_transport_server.get_url()
 
722
        return transport.get_transport_from_url(url)
 
723
 
 
724
 
 
725
class SmartTCPServer_for_testing_v2_only(SmartTCPServer_for_testing):
 
726
    """A variation of SmartTCPServer_for_testing that limits the client to
 
727
    using RPCs in protocol v2 (i.e. bzr <= 1.5).
 
728
    """
 
729
 
 
730
    def get_url(self):
 
731
        url = super(SmartTCPServer_for_testing_v2_only, self).get_url()
 
732
        url = 'bzr-v2://' + url[len('bzr://'):]
 
733
        return url
 
734
 
 
735
 
 
736
class ReadonlySmartTCPServer_for_testing_v2_only(
 
737
    SmartTCPServer_for_testing_v2_only):
 
738
    """Get a readonly server for testing."""
 
739
 
 
740
    def get_backing_transport(self, backing_transport_server):
 
741
        """Get a backing transport from a server we are decorating."""
 
742
        url = 'readonly+' + backing_transport_server.get_url()
 
743
        return transport.get_transport_from_url(url)