~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/tests/test_server.py

  • Committer: John Arbash Meinel
  • Author(s): Mark Hammond
  • Date: 2008-09-09 17:02:21 UTC
  • mto: This revision was merged to the branch mainline in revision 3697.
  • Revision ID: john@arbash-meinel.com-20080909170221-svim3jw2mrz0amp3
An updated transparent icon for bzr.

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2010, 2011 Canonical Ltd
2
 
#
3
 
# This program is free software; you can redistribute it and/or modify
4
 
# it under the terms of the GNU General Public License as published by
5
 
# the Free Software Foundation; either version 2 of the License, or
6
 
# (at your option) any later version.
7
 
#
8
 
# This program is distributed in the hope that it will be useful,
9
 
# but WITHOUT ANY WARRANTY; without even the implied warranty of
10
 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11
 
# GNU General Public License for more details.
12
 
#
13
 
# You should have received a copy of the GNU General Public License
14
 
# along with this program; if not, write to the Free Software
15
 
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
16
 
 
17
 
import errno
18
 
import socket
19
 
import SocketServer
20
 
import sys
21
 
import threading
22
 
import traceback
23
 
 
24
 
 
25
 
from bzrlib import (
26
 
    cethread,
27
 
    errors,
28
 
    osutils,
29
 
    transport,
30
 
    urlutils,
31
 
    )
32
 
from bzrlib.transport import (
33
 
    chroot,
34
 
    pathfilter,
35
 
    )
36
 
from bzrlib.smart import (
37
 
    medium,
38
 
    server,
39
 
    )
40
 
 
41
 
 
42
 
def debug_threads():
43
 
    # FIXME: There is a dependency loop between bzrlib.tests and
44
 
    # bzrlib.tests.test_server that needs to be fixed. In the mean time
45
 
    # defining this function is enough for our needs. -- vila 20100611
46
 
    from bzrlib import tests
47
 
    return 'threads' in tests.selftest_debug_flags
48
 
 
49
 
 
50
 
class TestServer(transport.Server):
51
 
    """A Transport Server dedicated to tests.
52
 
 
53
 
    The TestServer interface provides a server for a given transport. We use
54
 
    these servers as loopback testing tools. For any given transport the
55
 
    Servers it provides must either allow writing, or serve the contents
56
 
    of os.getcwdu() at the time start_server is called.
57
 
 
58
 
    Note that these are real servers - they must implement all the things
59
 
    that we want bzr transports to take advantage of.
60
 
    """
61
 
 
62
 
    def get_url(self):
63
 
        """Return a url for this server.
64
 
 
65
 
        If the transport does not represent a disk directory (i.e. it is
66
 
        a database like svn, or a memory only transport, it should return
67
 
        a connection to a newly established resource for this Server.
68
 
        Otherwise it should return a url that will provide access to the path
69
 
        that was os.getcwdu() when start_server() was called.
70
 
 
71
 
        Subsequent calls will return the same resource.
72
 
        """
73
 
        raise NotImplementedError
74
 
 
75
 
    def get_bogus_url(self):
76
 
        """Return a url for this protocol, that will fail to connect.
77
 
 
78
 
        This may raise NotImplementedError to indicate that this server cannot
79
 
        provide bogus urls.
80
 
        """
81
 
        raise NotImplementedError
82
 
 
83
 
 
84
 
class LocalURLServer(TestServer):
85
 
    """A pretend server for local transports, using file:// urls.
86
 
 
87
 
    Of course no actual server is required to access the local filesystem, so
88
 
    this just exists to tell the test code how to get to it.
89
 
    """
90
 
 
91
 
    def start_server(self):
92
 
        pass
93
 
 
94
 
    def get_url(self):
95
 
        """See Transport.Server.get_url."""
96
 
        return urlutils.local_path_to_url('')
97
 
 
98
 
 
99
 
class DecoratorServer(TestServer):
100
 
    """Server for the TransportDecorator for testing with.
101
 
 
102
 
    To use this when subclassing TransportDecorator, override override the
103
 
    get_decorator_class method.
104
 
    """
105
 
 
106
 
    def start_server(self, server=None):
107
 
        """See bzrlib.transport.Server.start_server.
108
 
 
109
 
        :server: decorate the urls given by server. If not provided a
110
 
        LocalServer is created.
111
 
        """
112
 
        if server is not None:
113
 
            self._made_server = False
114
 
            self._server = server
115
 
        else:
116
 
            self._made_server = True
117
 
            self._server = LocalURLServer()
118
 
            self._server.start_server()
119
 
 
120
 
    def stop_server(self):
121
 
        if self._made_server:
122
 
            self._server.stop_server()
123
 
 
124
 
    def get_decorator_class(self):
125
 
        """Return the class of the decorators we should be constructing."""
126
 
        raise NotImplementedError(self.get_decorator_class)
127
 
 
128
 
    def get_url_prefix(self):
129
 
        """What URL prefix does this decorator produce?"""
130
 
        return self.get_decorator_class()._get_url_prefix()
131
 
 
132
 
    def get_bogus_url(self):
133
 
        """See bzrlib.transport.Server.get_bogus_url."""
134
 
        return self.get_url_prefix() + self._server.get_bogus_url()
135
 
 
136
 
    def get_url(self):
137
 
        """See bzrlib.transport.Server.get_url."""
138
 
        return self.get_url_prefix() + self._server.get_url()
139
 
 
140
 
 
141
 
class BrokenRenameServer(DecoratorServer):
142
 
    """Server for the BrokenRenameTransportDecorator for testing with."""
143
 
 
144
 
    def get_decorator_class(self):
145
 
        from bzrlib.transport import brokenrename
146
 
        return brokenrename.BrokenRenameTransportDecorator
147
 
 
148
 
 
149
 
class FakeNFSServer(DecoratorServer):
150
 
    """Server for the FakeNFSTransportDecorator for testing with."""
151
 
 
152
 
    def get_decorator_class(self):
153
 
        from bzrlib.transport import fakenfs
154
 
        return fakenfs.FakeNFSTransportDecorator
155
 
 
156
 
 
157
 
class FakeVFATServer(DecoratorServer):
158
 
    """A server that suggests connections through FakeVFATTransportDecorator
159
 
 
160
 
    For use in testing.
161
 
    """
162
 
 
163
 
    def get_decorator_class(self):
164
 
        from bzrlib.transport import fakevfat
165
 
        return fakevfat.FakeVFATTransportDecorator
166
 
 
167
 
 
168
 
class LogDecoratorServer(DecoratorServer):
169
 
    """Server for testing."""
170
 
 
171
 
    def get_decorator_class(self):
172
 
        from bzrlib.transport import log
173
 
        return log.TransportLogDecorator
174
 
 
175
 
 
176
 
class NoSmartTransportServer(DecoratorServer):
177
 
    """Server for the NoSmartTransportDecorator for testing with."""
178
 
 
179
 
    def get_decorator_class(self):
180
 
        from bzrlib.transport import nosmart
181
 
        return nosmart.NoSmartTransportDecorator
182
 
 
183
 
 
184
 
class ReadonlyServer(DecoratorServer):
185
 
    """Server for the ReadonlyTransportDecorator for testing with."""
186
 
 
187
 
    def get_decorator_class(self):
188
 
        from bzrlib.transport import readonly
189
 
        return readonly.ReadonlyTransportDecorator
190
 
 
191
 
 
192
 
class TraceServer(DecoratorServer):
193
 
    """Server for the TransportTraceDecorator for testing with."""
194
 
 
195
 
    def get_decorator_class(self):
196
 
        from bzrlib.transport import trace
197
 
        return trace.TransportTraceDecorator
198
 
 
199
 
 
200
 
class UnlistableServer(DecoratorServer):
201
 
    """Server for the UnlistableTransportDecorator for testing with."""
202
 
 
203
 
    def get_decorator_class(self):
204
 
        from bzrlib.transport import unlistable
205
 
        return unlistable.UnlistableTransportDecorator
206
 
 
207
 
 
208
 
class TestingPathFilteringServer(pathfilter.PathFilteringServer):
209
 
 
210
 
    def __init__(self):
211
 
        """TestingPathFilteringServer is not usable until start_server
212
 
        is called."""
213
 
 
214
 
    def start_server(self, backing_server=None):
215
 
        """Setup the Chroot on backing_server."""
216
 
        if backing_server is not None:
217
 
            self.backing_transport = transport.get_transport_from_url(
218
 
                backing_server.get_url())
219
 
        else:
220
 
            self.backing_transport = transport.get_transport_from_path('.')
221
 
        self.backing_transport.clone('added-by-filter').ensure_base()
222
 
        self.filter_func = lambda x: 'added-by-filter/' + x
223
 
        super(TestingPathFilteringServer, self).start_server()
224
 
 
225
 
    def get_bogus_url(self):
226
 
        raise NotImplementedError
227
 
 
228
 
 
229
 
class TestingChrootServer(chroot.ChrootServer):
230
 
 
231
 
    def __init__(self):
232
 
        """TestingChrootServer is not usable until start_server is called."""
233
 
        super(TestingChrootServer, self).__init__(None)
234
 
 
235
 
    def start_server(self, backing_server=None):
236
 
        """Setup the Chroot on backing_server."""
237
 
        if backing_server is not None:
238
 
            self.backing_transport = transport.get_transport_from_url(
239
 
                backing_server.get_url())
240
 
        else:
241
 
            self.backing_transport = transport.get_transport_from_path('.')
242
 
        super(TestingChrootServer, self).start_server()
243
 
 
244
 
    def get_bogus_url(self):
245
 
        raise NotImplementedError
246
 
 
247
 
 
248
 
class TestThread(cethread.CatchingExceptionThread):
249
 
 
250
 
    def join(self, timeout=5):
251
 
        """Overrides to use a default timeout.
252
 
 
253
 
        The default timeout is set to 5 and should expire only when a thread
254
 
        serving a client connection is hung.
255
 
        """
256
 
        super(TestThread, self).join(timeout)
257
 
        if timeout and self.isAlive():
258
 
            # The timeout expired without joining the thread, the thread is
259
 
            # therefore stucked and that's a failure as far as the test is
260
 
            # concerned. We used to hang here.
261
 
 
262
 
            # FIXME: we need to kill the thread, but as far as the test is
263
 
            # concerned, raising an assertion is too strong. On most of the
264
 
            # platforms, this doesn't occur, so just mentioning the problem is
265
 
            # enough for now -- vila 2010824
266
 
            sys.stderr.write('thread %s hung\n' % (self.name,))
267
 
            #raise AssertionError('thread %s hung' % (self.name,))
268
 
 
269
 
 
270
 
class TestingTCPServerMixin(object):
271
 
    """Mixin to support running SocketServer.TCPServer in a thread.
272
 
 
273
 
    Tests are connecting from the main thread, the server has to be run in a
274
 
    separate thread.
275
 
    """
276
 
 
277
 
    def __init__(self):
278
 
        self.started = threading.Event()
279
 
        self.serving = None
280
 
        self.stopped = threading.Event()
281
 
        # We collect the resources used by the clients so we can release them
282
 
        # when shutting down
283
 
        self.clients = []
284
 
        self.ignored_exceptions = None
285
 
 
286
 
    def server_bind(self):
287
 
        self.socket.bind(self.server_address)
288
 
        self.server_address = self.socket.getsockname()
289
 
 
290
 
    def serve(self):
291
 
        self.serving = True
292
 
        # We are listening and ready to accept connections
293
 
        self.started.set()
294
 
        try:
295
 
            while self.serving:
296
 
                # Really a connection but the python framework is generic and
297
 
                # call them requests
298
 
                self.handle_request()
299
 
            # Let's close the listening socket
300
 
            self.server_close()
301
 
        finally:
302
 
            self.stopped.set()
303
 
 
304
 
    def handle_request(self):
305
 
        """Handle one request.
306
 
 
307
 
        The python version swallows some socket exceptions and we don't use
308
 
        timeout, so we override it to better control the server behavior.
309
 
        """
310
 
        request, client_address = self.get_request()
311
 
        if self.verify_request(request, client_address):
312
 
            try:
313
 
                self.process_request(request, client_address)
314
 
            except:
315
 
                self.handle_error(request, client_address)
316
 
        else:
317
 
            self.close_request(request)
318
 
 
319
 
    def get_request(self):
320
 
        return self.socket.accept()
321
 
 
322
 
    def verify_request(self, request, client_address):
323
 
        """Verify the request.
324
 
 
325
 
        Return True if we should proceed with this request, False if we should
326
 
        not even touch a single byte in the socket ! This is useful when we
327
 
        stop the server with a dummy last connection.
328
 
        """
329
 
        return self.serving
330
 
 
331
 
    def handle_error(self, request, client_address):
332
 
        # Stop serving and re-raise the last exception seen
333
 
        self.serving = False
334
 
        # The following can be used for debugging purposes, it will display the
335
 
        # exception and the traceback just when it occurs instead of waiting
336
 
        # for the thread to be joined.
337
 
        # SocketServer.BaseServer.handle_error(self, request, client_address)
338
 
 
339
 
        # We call close_request manually, because we are going to raise an
340
 
        # exception. The SocketServer implementation calls:
341
 
        #   handle_error(...)
342
 
        #   close_request(...)
343
 
        # But because we raise the exception, close_request will never be
344
 
        # triggered. This helps client not block waiting for a response when
345
 
        # the server gets an exception.
346
 
        self.close_request(request)
347
 
        raise
348
 
 
349
 
    def ignored_exceptions_during_shutdown(self, e):
350
 
        if sys.platform == 'win32':
351
 
            accepted_errnos = [errno.EBADF,
352
 
                               errno.EPIPE,
353
 
                               errno.WSAEBADF,
354
 
                               errno.WSAECONNRESET,
355
 
                               errno.WSAENOTCONN,
356
 
                               errno.WSAESHUTDOWN,
357
 
                               ]
358
 
        else:
359
 
            accepted_errnos = [errno.EBADF,
360
 
                               errno.ECONNRESET,
361
 
                               errno.ENOTCONN,
362
 
                               errno.EPIPE,
363
 
                               ]
364
 
        if isinstance(e, socket.error) and e[0] in accepted_errnos:
365
 
            return True
366
 
        return False
367
 
 
368
 
    # The following methods are called by the main thread
369
 
 
370
 
    def stop_client_connections(self):
371
 
        while self.clients:
372
 
            c = self.clients.pop()
373
 
            self.shutdown_client(c)
374
 
 
375
 
    def shutdown_socket(self, sock):
376
 
        """Properly shutdown a socket.
377
 
 
378
 
        This should be called only when no other thread is trying to use the
379
 
        socket.
380
 
        """
381
 
        try:
382
 
            sock.shutdown(socket.SHUT_RDWR)
383
 
            sock.close()
384
 
        except Exception, e:
385
 
            if self.ignored_exceptions(e):
386
 
                pass
387
 
            else:
388
 
                raise
389
 
 
390
 
    # The following methods are called by the main thread
391
 
 
392
 
    def set_ignored_exceptions(self, thread, ignored_exceptions):
393
 
        self.ignored_exceptions = ignored_exceptions
394
 
        thread.set_ignored_exceptions(self.ignored_exceptions)
395
 
 
396
 
    def _pending_exception(self, thread):
397
 
        """Raise server uncaught exception.
398
 
 
399
 
        Daughter classes can override this if they use daughter threads.
400
 
        """
401
 
        thread.pending_exception()
402
 
 
403
 
 
404
 
class TestingTCPServer(TestingTCPServerMixin, SocketServer.TCPServer):
405
 
 
406
 
    def __init__(self, server_address, request_handler_class):
407
 
        TestingTCPServerMixin.__init__(self)
408
 
        SocketServer.TCPServer.__init__(self, server_address,
409
 
                                        request_handler_class)
410
 
 
411
 
    def get_request(self):
412
 
        """Get the request and client address from the socket."""
413
 
        sock, addr = TestingTCPServerMixin.get_request(self)
414
 
        self.clients.append((sock, addr))
415
 
        return sock, addr
416
 
 
417
 
    # The following methods are called by the main thread
418
 
 
419
 
    def shutdown_client(self, client):
420
 
        sock, addr = client
421
 
        self.shutdown_socket(sock)
422
 
 
423
 
 
424
 
class TestingThreadingTCPServer(TestingTCPServerMixin,
425
 
                                SocketServer.ThreadingTCPServer):
426
 
 
427
 
    def __init__(self, server_address, request_handler_class):
428
 
        TestingTCPServerMixin.__init__(self)
429
 
        SocketServer.ThreadingTCPServer.__init__(self, server_address,
430
 
                                                 request_handler_class)
431
 
 
432
 
    def get_request(self):
433
 
        """Get the request and client address from the socket."""
434
 
        sock, addr = TestingTCPServerMixin.get_request(self)
435
 
        # The thread is not created yet, it will be updated in process_request
436
 
        self.clients.append((sock, addr, None))
437
 
        return sock, addr
438
 
 
439
 
    def process_request_thread(self, started, detached, stopped,
440
 
                               request, client_address):
441
 
        started.set()
442
 
        # We will be on our own once the server tells us we're detached
443
 
        detached.wait()
444
 
        SocketServer.ThreadingTCPServer.process_request_thread(
445
 
            self, request, client_address)
446
 
        self.close_request(request)
447
 
        stopped.set()
448
 
 
449
 
    def process_request(self, request, client_address):
450
 
        """Start a new thread to process the request."""
451
 
        started = threading.Event()
452
 
        detached = threading.Event()
453
 
        stopped = threading.Event()
454
 
        t = TestThread(
455
 
            sync_event=stopped,
456
 
            name='%s -> %s' % (client_address, self.server_address),
457
 
            target = self.process_request_thread,
458
 
            args = (started, detached, stopped, request, client_address))
459
 
        # Update the client description
460
 
        self.clients.pop()
461
 
        self.clients.append((request, client_address, t))
462
 
        # Propagate the exception handler since we must use the same one as
463
 
        # TestingTCPServer for connections running in their own threads.
464
 
        t.set_ignored_exceptions(self.ignored_exceptions)
465
 
        t.start()
466
 
        started.wait()
467
 
        # If an exception occured during the thread start, it will get raised.
468
 
        t.pending_exception()
469
 
        if debug_threads():
470
 
            sys.stderr.write('Client thread %s started\n' % (t.name,))
471
 
        # Tell the thread, it's now on its own for exception handling.
472
 
        detached.set()
473
 
 
474
 
    # The following methods are called by the main thread
475
 
 
476
 
    def shutdown_client(self, client):
477
 
        sock, addr, connection_thread = client
478
 
        self.shutdown_socket(sock)
479
 
        if connection_thread is not None:
480
 
            # The thread has been created only if the request is processed but
481
 
            # after the connection is inited. This could happen during server
482
 
            # shutdown. If an exception occurred in the thread it will be
483
 
            # re-raised
484
 
            if debug_threads():
485
 
                sys.stderr.write('Client thread %s will be joined\n'
486
 
                                 % (connection_thread.name,))
487
 
            connection_thread.join()
488
 
 
489
 
    def set_ignored_exceptions(self, thread, ignored_exceptions):
490
 
        TestingTCPServerMixin.set_ignored_exceptions(self, thread,
491
 
                                                     ignored_exceptions)
492
 
        for sock, addr, connection_thread in self.clients:
493
 
            if connection_thread is not None:
494
 
                connection_thread.set_ignored_exceptions(
495
 
                    self.ignored_exceptions)
496
 
 
497
 
    def _pending_exception(self, thread):
498
 
        for sock, addr, connection_thread in self.clients:
499
 
            if connection_thread is not None:
500
 
                connection_thread.pending_exception()
501
 
        TestingTCPServerMixin._pending_exception(self, thread)
502
 
 
503
 
 
504
 
class TestingTCPServerInAThread(transport.Server):
505
 
    """A server in a thread that re-raise thread exceptions."""
506
 
 
507
 
    def __init__(self, server_address, server_class, request_handler_class):
508
 
        self.server_class = server_class
509
 
        self.request_handler_class = request_handler_class
510
 
        self.host, self.port = server_address
511
 
        self.server = None
512
 
        self._server_thread = None
513
 
 
514
 
    def __repr__(self):
515
 
        return "%s(%s:%s)" % (self.__class__.__name__, self.host, self.port)
516
 
 
517
 
    def create_server(self):
518
 
        return self.server_class((self.host, self.port),
519
 
                                 self.request_handler_class)
520
 
 
521
 
    def start_server(self):
522
 
        self.server = self.create_server()
523
 
        self._server_thread = TestThread(
524
 
            sync_event=self.server.started,
525
 
            target=self.run_server)
526
 
        self._server_thread.start()
527
 
        # Wait for the server thread to start (i.e. release the lock)
528
 
        self.server.started.wait()
529
 
        # Get the real address, especially the port
530
 
        self.host, self.port = self.server.server_address
531
 
        self._server_thread.name = self.server.server_address
532
 
        if debug_threads():
533
 
            sys.stderr.write('Server thread %s started\n'
534
 
                             % (self._server_thread.name,))
535
 
        # If an exception occured during the server start, it will get raised,
536
 
        # otherwise, the server is blocked on its accept() call.
537
 
        self._server_thread.pending_exception()
538
 
        # From now on, we'll use a different event to ensure the server can set
539
 
        # its exception
540
 
        self._server_thread.set_sync_event(self.server.stopped)
541
 
 
542
 
    def run_server(self):
543
 
        self.server.serve()
544
 
 
545
 
    def stop_server(self):
546
 
        if self.server is None:
547
 
            return
548
 
        try:
549
 
            # The server has been started successfully, shut it down now.  As
550
 
            # soon as we stop serving, no more connection are accepted except
551
 
            # one to get out of the blocking listen.
552
 
            self.set_ignored_exceptions(
553
 
                self.server.ignored_exceptions_during_shutdown)
554
 
            self.server.serving = False
555
 
            if debug_threads():
556
 
                sys.stderr.write('Server thread %s will be joined\n'
557
 
                                 % (self._server_thread.name,))
558
 
            # The server is listening for a last connection, let's give it:
559
 
            last_conn = None
560
 
            try:
561
 
                last_conn = osutils.connect_socket((self.host, self.port))
562
 
            except socket.error, e:
563
 
                # But ignore connection errors as the point is to unblock the
564
 
                # server thread, it may happen that it's not blocked or even
565
 
                # not started.
566
 
                pass
567
 
            # We start shutting down the clients while the server itself is
568
 
            # shutting down.
569
 
            self.server.stop_client_connections()
570
 
            # Now we wait for the thread running self.server.serve() to finish
571
 
            self.server.stopped.wait()
572
 
            if last_conn is not None:
573
 
                # Close the last connection without trying to use it. The
574
 
                # server will not process a single byte on that socket to avoid
575
 
                # complications (SSL starts with a handshake for example).
576
 
                last_conn.close()
577
 
            # Check for any exception that could have occurred in the server
578
 
            # thread
579
 
            try:
580
 
                self._server_thread.join()
581
 
            except Exception, e:
582
 
                if self.server.ignored_exceptions(e):
583
 
                    pass
584
 
                else:
585
 
                    raise
586
 
        finally:
587
 
            # Make sure we can be called twice safely, note that this means
588
 
            # that we will raise a single exception even if several occurred in
589
 
            # the various threads involved.
590
 
            self.server = None
591
 
 
592
 
    def set_ignored_exceptions(self, ignored_exceptions):
593
 
        """Install an exception handler for the server."""
594
 
        self.server.set_ignored_exceptions(self._server_thread,
595
 
                                           ignored_exceptions)
596
 
 
597
 
    def pending_exception(self):
598
 
        """Raise uncaught exception in the server."""
599
 
        self.server._pending_exception(self._server_thread)
600
 
 
601
 
 
602
 
class TestingSmartConnectionHandler(SocketServer.BaseRequestHandler,
603
 
                                    medium.SmartServerSocketStreamMedium):
604
 
 
605
 
    def __init__(self, request, client_address, server):
606
 
        medium.SmartServerSocketStreamMedium.__init__(
607
 
            self, request, server.backing_transport,
608
 
            server.root_client_path,
609
 
            timeout=_DEFAULT_TESTING_CLIENT_TIMEOUT)
610
 
        request.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
611
 
        SocketServer.BaseRequestHandler.__init__(self, request, client_address,
612
 
                                                 server)
613
 
 
614
 
    def handle(self):
615
 
        try:
616
 
            while not self.finished:
617
 
                server_protocol = self._build_protocol()
618
 
                self._serve_one_request(server_protocol)
619
 
        except errors.ConnectionTimeout:
620
 
            # idle connections aren't considered a failure of the server
621
 
            return
622
 
 
623
 
 
624
 
_DEFAULT_TESTING_CLIENT_TIMEOUT = 60.0
625
 
 
626
 
class TestingSmartServer(TestingThreadingTCPServer, server.SmartTCPServer):
627
 
 
628
 
    def __init__(self, server_address, request_handler_class,
629
 
                 backing_transport, root_client_path):
630
 
        TestingThreadingTCPServer.__init__(self, server_address,
631
 
                                           request_handler_class)
632
 
        server.SmartTCPServer.__init__(self, backing_transport,
633
 
            root_client_path, client_timeout=_DEFAULT_TESTING_CLIENT_TIMEOUT)
634
 
 
635
 
    def serve(self):
636
 
        self.run_server_started_hooks()
637
 
        try:
638
 
            TestingThreadingTCPServer.serve(self)
639
 
        finally:
640
 
            self.run_server_stopped_hooks()
641
 
 
642
 
    def get_url(self):
643
 
        """Return the url of the server"""
644
 
        return "bzr://%s:%d/" % self.server_address
645
 
 
646
 
 
647
 
class SmartTCPServer_for_testing(TestingTCPServerInAThread):
648
 
    """Server suitable for use by transport tests.
649
 
 
650
 
    This server is backed by the process's cwd.
651
 
    """
652
 
    def __init__(self, thread_name_suffix=''):
653
 
        self.client_path_extra = None
654
 
        self.thread_name_suffix = thread_name_suffix
655
 
        self.host = '127.0.0.1'
656
 
        self.port = 0
657
 
        super(SmartTCPServer_for_testing, self).__init__(
658
 
                (self.host, self.port),
659
 
                TestingSmartServer,
660
 
                TestingSmartConnectionHandler)
661
 
 
662
 
    def create_server(self):
663
 
        return self.server_class((self.host, self.port),
664
 
                                 self.request_handler_class,
665
 
                                 self.backing_transport,
666
 
                                 self.root_client_path)
667
 
 
668
 
 
669
 
    def start_server(self, backing_transport_server=None,
670
 
                     client_path_extra='/extra/'):
671
 
        """Set up server for testing.
672
 
 
673
 
        :param backing_transport_server: backing server to use.  If not
674
 
            specified, a LocalURLServer at the current working directory will
675
 
            be used.
676
 
        :param client_path_extra: a path segment starting with '/' to append to
677
 
            the root URL for this server.  For instance, a value of '/foo/bar/'
678
 
            will mean the root of the backing transport will be published at a
679
 
            URL like `bzr://127.0.0.1:nnnn/foo/bar/`, rather than
680
 
            `bzr://127.0.0.1:nnnn/`.  Default value is `extra`, so that tests
681
 
            by default will fail unless they do the necessary path translation.
682
 
        """
683
 
        if not client_path_extra.startswith('/'):
684
 
            raise ValueError(client_path_extra)
685
 
        self.root_client_path = self.client_path_extra = client_path_extra
686
 
        from bzrlib.transport.chroot import ChrootServer
687
 
        if backing_transport_server is None:
688
 
            backing_transport_server = LocalURLServer()
689
 
        self.chroot_server = ChrootServer(
690
 
            self.get_backing_transport(backing_transport_server))
691
 
        self.chroot_server.start_server()
692
 
        self.backing_transport = transport.get_transport_from_url(
693
 
            self.chroot_server.get_url())
694
 
        super(SmartTCPServer_for_testing, self).start_server()
695
 
 
696
 
    def stop_server(self):
697
 
        try:
698
 
            super(SmartTCPServer_for_testing, self).stop_server()
699
 
        finally:
700
 
            self.chroot_server.stop_server()
701
 
 
702
 
    def get_backing_transport(self, backing_transport_server):
703
 
        """Get a backing transport from a server we are decorating."""
704
 
        return transport.get_transport_from_url(
705
 
            backing_transport_server.get_url())
706
 
 
707
 
    def get_url(self):
708
 
        url = self.server.get_url()
709
 
        return url[:-1] + self.client_path_extra
710
 
 
711
 
    def get_bogus_url(self):
712
 
        """Return a URL which will fail to connect"""
713
 
        return 'bzr://127.0.0.1:1/'
714
 
 
715
 
 
716
 
class ReadonlySmartTCPServer_for_testing(SmartTCPServer_for_testing):
717
 
    """Get a readonly server for testing."""
718
 
 
719
 
    def get_backing_transport(self, backing_transport_server):
720
 
        """Get a backing transport from a server we are decorating."""
721
 
        url = 'readonly+' + backing_transport_server.get_url()
722
 
        return transport.get_transport_from_url(url)
723
 
 
724
 
 
725
 
class SmartTCPServer_for_testing_v2_only(SmartTCPServer_for_testing):
726
 
    """A variation of SmartTCPServer_for_testing that limits the client to
727
 
    using RPCs in protocol v2 (i.e. bzr <= 1.5).
728
 
    """
729
 
 
730
 
    def get_url(self):
731
 
        url = super(SmartTCPServer_for_testing_v2_only, self).get_url()
732
 
        url = 'bzr-v2://' + url[len('bzr://'):]
733
 
        return url
734
 
 
735
 
 
736
 
class ReadonlySmartTCPServer_for_testing_v2_only(
737
 
    SmartTCPServer_for_testing_v2_only):
738
 
    """Get a readonly server for testing."""
739
 
 
740
 
    def get_backing_transport(self, backing_transport_server):
741
 
        """Get a backing transport from a server we are decorating."""
742
 
        url = 'readonly+' + backing_transport_server.get_url()
743
 
        return transport.get_transport_from_url(url)