~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/tests/test_server.py

  • Committer: Robert Collins
  • Date: 2005-11-04 23:27:47 UTC
  • Revision ID: robertc@robertcollins.net-20051104232747-5872c68d759bc7be
Bugfix the config test suite to not create .bazaar in the dir where it is run.

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2010, 2011 Canonical Ltd
2
 
#
3
 
# This program is free software; you can redistribute it and/or modify
4
 
# it under the terms of the GNU General Public License as published by
5
 
# the Free Software Foundation; either version 2 of the License, or
6
 
# (at your option) any later version.
7
 
#
8
 
# This program is distributed in the hope that it will be useful,
9
 
# but WITHOUT ANY WARRANTY; without even the implied warranty of
10
 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11
 
# GNU General Public License for more details.
12
 
#
13
 
# You should have received a copy of the GNU General Public License
14
 
# along with this program; if not, write to the Free Software
15
 
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
16
 
 
17
 
import errno
18
 
import socket
19
 
import SocketServer
20
 
import sys
21
 
import threading
22
 
 
23
 
 
24
 
from bzrlib import (
25
 
    cethread,
26
 
    osutils,
27
 
    transport,
28
 
    urlutils,
29
 
    )
30
 
from bzrlib.transport import (
31
 
    chroot,
32
 
    pathfilter,
33
 
    )
34
 
from bzrlib.smart import (
35
 
    medium,
36
 
    server,
37
 
    )
38
 
 
39
 
 
40
 
def debug_threads():
41
 
    # FIXME: There is a dependency loop between bzrlib.tests and
42
 
    # bzrlib.tests.test_server that needs to be fixed. In the mean time
43
 
    # defining this function is enough for our needs. -- vila 20100611
44
 
    from bzrlib import tests
45
 
    return 'threads' in tests.selftest_debug_flags
46
 
 
47
 
 
48
 
class TestServer(transport.Server):
49
 
    """A Transport Server dedicated to tests.
50
 
 
51
 
    The TestServer interface provides a server for a given transport. We use
52
 
    these servers as loopback testing tools. For any given transport the
53
 
    Servers it provides must either allow writing, or serve the contents
54
 
    of os.getcwdu() at the time start_server is called.
55
 
 
56
 
    Note that these are real servers - they must implement all the things
57
 
    that we want bzr transports to take advantage of.
58
 
    """
59
 
 
60
 
    def get_url(self):
61
 
        """Return a url for this server.
62
 
 
63
 
        If the transport does not represent a disk directory (i.e. it is
64
 
        a database like svn, or a memory only transport, it should return
65
 
        a connection to a newly established resource for this Server.
66
 
        Otherwise it should return a url that will provide access to the path
67
 
        that was os.getcwdu() when start_server() was called.
68
 
 
69
 
        Subsequent calls will return the same resource.
70
 
        """
71
 
        raise NotImplementedError
72
 
 
73
 
    def get_bogus_url(self):
74
 
        """Return a url for this protocol, that will fail to connect.
75
 
 
76
 
        This may raise NotImplementedError to indicate that this server cannot
77
 
        provide bogus urls.
78
 
        """
79
 
        raise NotImplementedError
80
 
 
81
 
 
82
 
class LocalURLServer(TestServer):
83
 
    """A pretend server for local transports, using file:// urls.
84
 
 
85
 
    Of course no actual server is required to access the local filesystem, so
86
 
    this just exists to tell the test code how to get to it.
87
 
    """
88
 
 
89
 
    def start_server(self):
90
 
        pass
91
 
 
92
 
    def get_url(self):
93
 
        """See Transport.Server.get_url."""
94
 
        return urlutils.local_path_to_url('')
95
 
 
96
 
 
97
 
class DecoratorServer(TestServer):
98
 
    """Server for the TransportDecorator for testing with.
99
 
 
100
 
    To use this when subclassing TransportDecorator, override override the
101
 
    get_decorator_class method.
102
 
    """
103
 
 
104
 
    def start_server(self, server=None):
105
 
        """See bzrlib.transport.Server.start_server.
106
 
 
107
 
        :server: decorate the urls given by server. If not provided a
108
 
        LocalServer is created.
109
 
        """
110
 
        if server is not None:
111
 
            self._made_server = False
112
 
            self._server = server
113
 
        else:
114
 
            self._made_server = True
115
 
            self._server = LocalURLServer()
116
 
            self._server.start_server()
117
 
 
118
 
    def stop_server(self):
119
 
        if self._made_server:
120
 
            self._server.stop_server()
121
 
 
122
 
    def get_decorator_class(self):
123
 
        """Return the class of the decorators we should be constructing."""
124
 
        raise NotImplementedError(self.get_decorator_class)
125
 
 
126
 
    def get_url_prefix(self):
127
 
        """What URL prefix does this decorator produce?"""
128
 
        return self.get_decorator_class()._get_url_prefix()
129
 
 
130
 
    def get_bogus_url(self):
131
 
        """See bzrlib.transport.Server.get_bogus_url."""
132
 
        return self.get_url_prefix() + self._server.get_bogus_url()
133
 
 
134
 
    def get_url(self):
135
 
        """See bzrlib.transport.Server.get_url."""
136
 
        return self.get_url_prefix() + self._server.get_url()
137
 
 
138
 
 
139
 
class BrokenRenameServer(DecoratorServer):
140
 
    """Server for the BrokenRenameTransportDecorator for testing with."""
141
 
 
142
 
    def get_decorator_class(self):
143
 
        from bzrlib.transport import brokenrename
144
 
        return brokenrename.BrokenRenameTransportDecorator
145
 
 
146
 
 
147
 
class FakeNFSServer(DecoratorServer):
148
 
    """Server for the FakeNFSTransportDecorator for testing with."""
149
 
 
150
 
    def get_decorator_class(self):
151
 
        from bzrlib.transport import fakenfs
152
 
        return fakenfs.FakeNFSTransportDecorator
153
 
 
154
 
 
155
 
class FakeVFATServer(DecoratorServer):
156
 
    """A server that suggests connections through FakeVFATTransportDecorator
157
 
 
158
 
    For use in testing.
159
 
    """
160
 
 
161
 
    def get_decorator_class(self):
162
 
        from bzrlib.transport import fakevfat
163
 
        return fakevfat.FakeVFATTransportDecorator
164
 
 
165
 
 
166
 
class LogDecoratorServer(DecoratorServer):
167
 
    """Server for testing."""
168
 
 
169
 
    def get_decorator_class(self):
170
 
        from bzrlib.transport import log
171
 
        return log.TransportLogDecorator
172
 
 
173
 
 
174
 
class NoSmartTransportServer(DecoratorServer):
175
 
    """Server for the NoSmartTransportDecorator for testing with."""
176
 
 
177
 
    def get_decorator_class(self):
178
 
        from bzrlib.transport import nosmart
179
 
        return nosmart.NoSmartTransportDecorator
180
 
 
181
 
 
182
 
class ReadonlyServer(DecoratorServer):
183
 
    """Server for the ReadonlyTransportDecorator for testing with."""
184
 
 
185
 
    def get_decorator_class(self):
186
 
        from bzrlib.transport import readonly
187
 
        return readonly.ReadonlyTransportDecorator
188
 
 
189
 
 
190
 
class TraceServer(DecoratorServer):
191
 
    """Server for the TransportTraceDecorator for testing with."""
192
 
 
193
 
    def get_decorator_class(self):
194
 
        from bzrlib.transport import trace
195
 
        return trace.TransportTraceDecorator
196
 
 
197
 
 
198
 
class UnlistableServer(DecoratorServer):
199
 
    """Server for the UnlistableTransportDecorator for testing with."""
200
 
 
201
 
    def get_decorator_class(self):
202
 
        from bzrlib.transport import unlistable
203
 
        return unlistable.UnlistableTransportDecorator
204
 
 
205
 
 
206
 
class TestingPathFilteringServer(pathfilter.PathFilteringServer):
207
 
 
208
 
    def __init__(self):
209
 
        """TestingPathFilteringServer is not usable until start_server
210
 
        is called."""
211
 
 
212
 
    def start_server(self, backing_server=None):
213
 
        """Setup the Chroot on backing_server."""
214
 
        if backing_server is not None:
215
 
            self.backing_transport = transport.get_transport(
216
 
                backing_server.get_url())
217
 
        else:
218
 
            self.backing_transport = transport.get_transport('.')
219
 
        self.backing_transport.clone('added-by-filter').ensure_base()
220
 
        self.filter_func = lambda x: 'added-by-filter/' + x
221
 
        super(TestingPathFilteringServer, self).start_server()
222
 
 
223
 
    def get_bogus_url(self):
224
 
        raise NotImplementedError
225
 
 
226
 
 
227
 
class TestingChrootServer(chroot.ChrootServer):
228
 
 
229
 
    def __init__(self):
230
 
        """TestingChrootServer is not usable until start_server is called."""
231
 
        super(TestingChrootServer, self).__init__(None)
232
 
 
233
 
    def start_server(self, backing_server=None):
234
 
        """Setup the Chroot on backing_server."""
235
 
        if backing_server is not None:
236
 
            self.backing_transport = transport.get_transport(
237
 
                backing_server.get_url())
238
 
        else:
239
 
            self.backing_transport = transport.get_transport('.')
240
 
        super(TestingChrootServer, self).start_server()
241
 
 
242
 
    def get_bogus_url(self):
243
 
        raise NotImplementedError
244
 
 
245
 
 
246
 
class TestThread(cethread.CatchingExceptionThread):
247
 
 
248
 
    def join(self, timeout=5):
249
 
        """Overrides to use a default timeout.
250
 
 
251
 
        The default timeout is set to 5 and should expire only when a thread
252
 
        serving a client connection is hung.
253
 
        """
254
 
        super(TestThread, self).join(timeout)
255
 
        if timeout and self.isAlive():
256
 
            # The timeout expired without joining the thread, the thread is
257
 
            # therefore stucked and that's a failure as far as the test is
258
 
            # concerned. We used to hang here.
259
 
 
260
 
            # FIXME: we need to kill the thread, but as far as the test is
261
 
            # concerned, raising an assertion is too strong. On most of the
262
 
            # platforms, this doesn't occur, so just mentioning the problem is
263
 
            # enough for now -- vila 2010824
264
 
            sys.stderr.write('thread %s hung\n' % (self.name,))
265
 
            #raise AssertionError('thread %s hung' % (self.name,))
266
 
 
267
 
 
268
 
class TestingTCPServerMixin:
269
 
    """Mixin to support running SocketServer.TCPServer in a thread.
270
 
 
271
 
    Tests are connecting from the main thread, the server has to be run in a
272
 
    separate thread.
273
 
    """
274
 
 
275
 
    def __init__(self):
276
 
        self.started = threading.Event()
277
 
        self.serving = None
278
 
        self.stopped = threading.Event()
279
 
        # We collect the resources used by the clients so we can release them
280
 
        # when shutting down
281
 
        self.clients = []
282
 
        self.ignored_exceptions = None
283
 
 
284
 
    def server_bind(self):
285
 
        self.socket.bind(self.server_address)
286
 
        self.server_address = self.socket.getsockname()
287
 
 
288
 
    def serve(self):
289
 
        self.serving = True
290
 
        self.stopped.clear()
291
 
        # We are listening and ready to accept connections
292
 
        self.started.set()
293
 
        try:
294
 
            while self.serving:
295
 
                # Really a connection but the python framework is generic and
296
 
                # call them requests
297
 
                self.handle_request()
298
 
            # Let's close the listening socket
299
 
            self.server_close()
300
 
        finally:
301
 
            self.stopped.set()
302
 
 
303
 
    def handle_request(self):
304
 
        """Handle one request.
305
 
 
306
 
        The python version swallows some socket exceptions and we don't use
307
 
        timeout, so we override it to better control the server behavior.
308
 
        """
309
 
        request, client_address = self.get_request()
310
 
        if self.verify_request(request, client_address):
311
 
            try:
312
 
                self.process_request(request, client_address)
313
 
            except:
314
 
                self.handle_error(request, client_address)
315
 
                self.close_request(request)
316
 
 
317
 
    def get_request(self):
318
 
        return self.socket.accept()
319
 
 
320
 
    def verify_request(self, request, client_address):
321
 
        """Verify the request.
322
 
 
323
 
        Return True if we should proceed with this request, False if we should
324
 
        not even touch a single byte in the socket ! This is useful when we
325
 
        stop the server with a dummy last connection.
326
 
        """
327
 
        return self.serving
328
 
 
329
 
    def handle_error(self, request, client_address):
330
 
        # Stop serving and re-raise the last exception seen
331
 
        self.serving = False
332
 
        # The following can be used for debugging purposes, it will display the
333
 
        # exception and the traceback just when it occurs instead of waiting
334
 
        # for the thread to be joined.
335
 
 
336
 
        # SocketServer.BaseServer.handle_error(self, request, client_address)
337
 
        raise
338
 
 
339
 
    def ignored_exceptions_during_shutdown(self, e):
340
 
        if sys.platform == 'win32':
341
 
            accepted_errnos = [errno.EBADF,
342
 
                               errno.EPIPE,
343
 
                               errno.WSAEBADF,
344
 
                               errno.WSAECONNRESET,
345
 
                               errno.WSAENOTCONN,
346
 
                               errno.WSAESHUTDOWN,
347
 
                               ]
348
 
        else:
349
 
            accepted_errnos = [errno.EBADF,
350
 
                               errno.ECONNRESET,
351
 
                               errno.ENOTCONN,
352
 
                               errno.EPIPE,
353
 
                               ]
354
 
        if isinstance(e, socket.error) and e[0] in accepted_errnos:
355
 
            return True
356
 
        return False
357
 
 
358
 
    # The following methods are called by the main thread
359
 
 
360
 
    def stop_client_connections(self):
361
 
        while self.clients:
362
 
            c = self.clients.pop()
363
 
            self.shutdown_client(c)
364
 
 
365
 
    def shutdown_socket(self, sock):
366
 
        """Properly shutdown a socket.
367
 
 
368
 
        This should be called only when no other thread is trying to use the
369
 
        socket.
370
 
        """
371
 
        try:
372
 
            sock.shutdown(socket.SHUT_RDWR)
373
 
            sock.close()
374
 
        except Exception, e:
375
 
            if self.ignored_exceptions(e):
376
 
                pass
377
 
            else:
378
 
                raise
379
 
 
380
 
    # The following methods are called by the main thread
381
 
 
382
 
    def set_ignored_exceptions(self, thread, ignored_exceptions):
383
 
        self.ignored_exceptions = ignored_exceptions
384
 
        thread.set_ignored_exceptions(self.ignored_exceptions)
385
 
 
386
 
    def _pending_exception(self, thread):
387
 
        """Raise server uncaught exception.
388
 
 
389
 
        Daughter classes can override this if they use daughter threads.
390
 
        """
391
 
        thread.pending_exception()
392
 
 
393
 
 
394
 
class TestingTCPServer(TestingTCPServerMixin, SocketServer.TCPServer):
395
 
 
396
 
    def __init__(self, server_address, request_handler_class):
397
 
        TestingTCPServerMixin.__init__(self)
398
 
        SocketServer.TCPServer.__init__(self, server_address,
399
 
                                        request_handler_class)
400
 
 
401
 
    def get_request(self):
402
 
        """Get the request and client address from the socket."""
403
 
        sock, addr = TestingTCPServerMixin.get_request(self)
404
 
        self.clients.append((sock, addr))
405
 
        return sock, addr
406
 
 
407
 
    # The following methods are called by the main thread
408
 
 
409
 
    def shutdown_client(self, client):
410
 
        sock, addr = client
411
 
        self.shutdown_socket(sock)
412
 
 
413
 
 
414
 
class TestingThreadingTCPServer(TestingTCPServerMixin,
415
 
                                SocketServer.ThreadingTCPServer):
416
 
 
417
 
    def __init__(self, server_address, request_handler_class):
418
 
        TestingTCPServerMixin.__init__(self)
419
 
        SocketServer.ThreadingTCPServer.__init__(self, server_address,
420
 
                                                 request_handler_class)
421
 
 
422
 
    def get_request (self):
423
 
        """Get the request and client address from the socket."""
424
 
        sock, addr = TestingTCPServerMixin.get_request(self)
425
 
        # The thread is not create yet, it will be updated in process_request
426
 
        self.clients.append((sock, addr, None))
427
 
        return sock, addr
428
 
 
429
 
    def process_request_thread(self, started, stopped, request, client_address):
430
 
        started.set()
431
 
        SocketServer.ThreadingTCPServer.process_request_thread(
432
 
            self, request, client_address)
433
 
        self.close_request(request)
434
 
        stopped.set()
435
 
 
436
 
    def process_request(self, request, client_address):
437
 
        """Start a new thread to process the request."""
438
 
        started = threading.Event()
439
 
        stopped = threading.Event()
440
 
        t = TestThread(
441
 
            sync_event=stopped,
442
 
            name='%s -> %s' % (client_address, self.server_address),
443
 
            target = self.process_request_thread,
444
 
            args = (started, stopped, request, client_address))
445
 
        # Update the client description
446
 
        self.clients.pop()
447
 
        self.clients.append((request, client_address, t))
448
 
        # Propagate the exception handler since we must use the same one as
449
 
        # TestingTCPServer for connections running in their own threads.
450
 
        t.set_ignored_exceptions(self.ignored_exceptions)
451
 
        t.start()
452
 
        started.wait()
453
 
        if debug_threads():
454
 
            sys.stderr.write('Client thread %s started\n' % (t.name,))
455
 
        # If an exception occured during the thread start, it will get raised.
456
 
        t.pending_exception()
457
 
 
458
 
    # The following methods are called by the main thread
459
 
 
460
 
    def shutdown_client(self, client):
461
 
        sock, addr, connection_thread = client
462
 
        self.shutdown_socket(sock)
463
 
        if connection_thread is not None:
464
 
            # The thread has been created only if the request is processed but
465
 
            # after the connection is inited. This could happen during server
466
 
            # shutdown. If an exception occurred in the thread it will be
467
 
            # re-raised
468
 
            if debug_threads():
469
 
                sys.stderr.write('Client thread %s will be joined\n'
470
 
                                 % (connection_thread.name,))
471
 
            connection_thread.join()
472
 
 
473
 
    def set_ignored_exceptions(self, thread, ignored_exceptions):
474
 
        TestingTCPServerMixin.set_ignored_exceptions(self, thread,
475
 
                                                     ignored_exceptions)
476
 
        for sock, addr, connection_thread in self.clients:
477
 
            if connection_thread is not None:
478
 
                connection_thread.set_ignored_exceptions(
479
 
                    self.ignored_exceptions)
480
 
 
481
 
    def _pending_exception(self, thread):
482
 
        for sock, addr, connection_thread in self.clients:
483
 
            if connection_thread is not None:
484
 
                connection_thread.pending_exception()
485
 
        TestingTCPServerMixin._pending_exception(self, thread)
486
 
 
487
 
 
488
 
class TestingTCPServerInAThread(transport.Server):
489
 
    """A server in a thread that re-raise thread exceptions."""
490
 
 
491
 
    def __init__(self, server_address, server_class, request_handler_class):
492
 
        self.server_class = server_class
493
 
        self.request_handler_class = request_handler_class
494
 
        self.host, self.port = server_address
495
 
        self.server = None
496
 
        self._server_thread = None
497
 
 
498
 
    def __repr__(self):
499
 
        return "%s(%s:%s)" % (self.__class__.__name__, self.host, self.port)
500
 
 
501
 
    def create_server(self):
502
 
        return self.server_class((self.host, self.port),
503
 
                                 self.request_handler_class)
504
 
 
505
 
    def start_server(self):
506
 
        self.server = self.create_server()
507
 
        self._server_thread = TestThread(
508
 
            sync_event=self.server.started,
509
 
            target=self.run_server)
510
 
        self._server_thread.start()
511
 
        # Wait for the server thread to start (i.e release the lock)
512
 
        self.server.started.wait()
513
 
        # Get the real address, especially the port
514
 
        self.host, self.port = self.server.server_address
515
 
        self._server_thread.name = self.server.server_address
516
 
        if debug_threads():
517
 
            sys.stderr.write('Server thread %s started\n'
518
 
                             % (self._server_thread.name,))
519
 
        # If an exception occured during the server start, it will get raised,
520
 
        # otherwise, the server is blocked on its accept() call.
521
 
        self._server_thread.pending_exception()
522
 
        # From now on, we'll use a different event to ensure the server can set
523
 
        # its exception
524
 
        self._server_thread.set_sync_event(self.server.stopped)
525
 
 
526
 
    def run_server(self):
527
 
        self.server.serve()
528
 
 
529
 
    def stop_server(self):
530
 
        if self.server is None:
531
 
            return
532
 
        try:
533
 
            # The server has been started successfully, shut it down now.  As
534
 
            # soon as we stop serving, no more connection are accepted except
535
 
            # one to get out of the blocking listen.
536
 
            self.set_ignored_exceptions(
537
 
                self.server.ignored_exceptions_during_shutdown)
538
 
            self.server.serving = False
539
 
            if debug_threads():
540
 
                sys.stderr.write('Server thread %s will be joined\n'
541
 
                                 % (self._server_thread.name,))
542
 
            # The server is listening for a last connection, let's give it:
543
 
            last_conn = None
544
 
            try:
545
 
                last_conn = osutils.connect_socket((self.host, self.port))
546
 
            except socket.error, e:
547
 
                # But ignore connection errors as the point is to unblock the
548
 
                # server thread, it may happen that it's not blocked or even
549
 
                # not started.
550
 
                pass
551
 
            # We start shutting down the clients while the server itself is
552
 
            # shutting down.
553
 
            self.server.stop_client_connections()
554
 
            # Now we wait for the thread running self.server.serve() to finish
555
 
            self.server.stopped.wait()
556
 
            if last_conn is not None:
557
 
                # Close the last connection without trying to use it. The
558
 
                # server will not process a single byte on that socket to avoid
559
 
                # complications (SSL starts with a handshake for example).
560
 
                last_conn.close()
561
 
            # Check for any exception that could have occurred in the server
562
 
            # thread
563
 
            try:
564
 
                self._server_thread.join()
565
 
            except Exception, e:
566
 
                if self.server.ignored_exceptions(e):
567
 
                    pass
568
 
                else:
569
 
                    raise
570
 
        finally:
571
 
            # Make sure we can be called twice safely, note that this means
572
 
            # that we will raise a single exception even if several occurred in
573
 
            # the various threads involved.
574
 
            self.server = None
575
 
 
576
 
    def set_ignored_exceptions(self, ignored_exceptions):
577
 
        """Install an exception handler for the server."""
578
 
        self.server.set_ignored_exceptions(self._server_thread,
579
 
                                           ignored_exceptions)
580
 
 
581
 
    def pending_exception(self):
582
 
        """Raise uncaught exception in the server."""
583
 
        self.server._pending_exception(self._server_thread)
584
 
 
585
 
 
586
 
class TestingSmartConnectionHandler(SocketServer.BaseRequestHandler,
587
 
                                    medium.SmartServerSocketStreamMedium):
588
 
 
589
 
    def __init__(self, request, client_address, server):
590
 
        medium.SmartServerSocketStreamMedium.__init__(
591
 
            self, request, server.backing_transport,
592
 
            server.root_client_path)
593
 
        request.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
594
 
        SocketServer.BaseRequestHandler.__init__(self, request, client_address,
595
 
                                                 server)
596
 
 
597
 
    def handle(self):
598
 
        while not self.finished:
599
 
            server_protocol = self._build_protocol()
600
 
            self._serve_one_request(server_protocol)
601
 
 
602
 
 
603
 
class TestingSmartServer(TestingThreadingTCPServer, server.SmartTCPServer):
604
 
 
605
 
    def __init__(self, server_address, request_handler_class,
606
 
                 backing_transport, root_client_path):
607
 
        TestingThreadingTCPServer.__init__(self, server_address,
608
 
                                           request_handler_class)
609
 
        server.SmartTCPServer.__init__(self, backing_transport,
610
 
                                       root_client_path)
611
 
    def serve(self):
612
 
        self.run_server_started_hooks()
613
 
        try:
614
 
            TestingThreadingTCPServer.serve(self)
615
 
        finally:
616
 
            self.run_server_stopped_hooks()
617
 
 
618
 
    def get_url(self):
619
 
        """Return the url of the server"""
620
 
        return "bzr://%s:%d/" % self.server_address
621
 
 
622
 
 
623
 
class SmartTCPServer_for_testing(TestingTCPServerInAThread):
624
 
    """Server suitable for use by transport tests.
625
 
 
626
 
    This server is backed by the process's cwd.
627
 
    """
628
 
    def __init__(self, thread_name_suffix=''):
629
 
        self.client_path_extra = None
630
 
        self.thread_name_suffix = thread_name_suffix
631
 
        self.host = '127.0.0.1'
632
 
        self.port = 0
633
 
        super(SmartTCPServer_for_testing, self).__init__(
634
 
                (self.host, self.port),
635
 
                TestingSmartServer,
636
 
                TestingSmartConnectionHandler)
637
 
 
638
 
    def create_server(self):
639
 
        return self.server_class((self.host, self.port),
640
 
                                 self.request_handler_class,
641
 
                                 self.backing_transport,
642
 
                                 self.root_client_path)
643
 
 
644
 
 
645
 
    def start_server(self, backing_transport_server=None,
646
 
                     client_path_extra='/extra/'):
647
 
        """Set up server for testing.
648
 
 
649
 
        :param backing_transport_server: backing server to use.  If not
650
 
            specified, a LocalURLServer at the current working directory will
651
 
            be used.
652
 
        :param client_path_extra: a path segment starting with '/' to append to
653
 
            the root URL for this server.  For instance, a value of '/foo/bar/'
654
 
            will mean the root of the backing transport will be published at a
655
 
            URL like `bzr://127.0.0.1:nnnn/foo/bar/`, rather than
656
 
            `bzr://127.0.0.1:nnnn/`.  Default value is `extra`, so that tests
657
 
            by default will fail unless they do the necessary path translation.
658
 
        """
659
 
        if not client_path_extra.startswith('/'):
660
 
            raise ValueError(client_path_extra)
661
 
        self.root_client_path = self.client_path_extra = client_path_extra
662
 
        from bzrlib.transport.chroot import ChrootServer
663
 
        if backing_transport_server is None:
664
 
            backing_transport_server = LocalURLServer()
665
 
        self.chroot_server = ChrootServer(
666
 
            self.get_backing_transport(backing_transport_server))
667
 
        self.chroot_server.start_server()
668
 
        self.backing_transport = transport.get_transport(
669
 
            self.chroot_server.get_url())
670
 
        super(SmartTCPServer_for_testing, self).start_server()
671
 
 
672
 
    def stop_server(self):
673
 
        try:
674
 
            super(SmartTCPServer_for_testing, self).stop_server()
675
 
        finally:
676
 
            self.chroot_server.stop_server()
677
 
 
678
 
    def get_backing_transport(self, backing_transport_server):
679
 
        """Get a backing transport from a server we are decorating."""
680
 
        return transport.get_transport(backing_transport_server.get_url())
681
 
 
682
 
    def get_url(self):
683
 
        url = self.server.get_url()
684
 
        return url[:-1] + self.client_path_extra
685
 
 
686
 
    def get_bogus_url(self):
687
 
        """Return a URL which will fail to connect"""
688
 
        return 'bzr://127.0.0.1:1/'
689
 
 
690
 
 
691
 
class ReadonlySmartTCPServer_for_testing(SmartTCPServer_for_testing):
692
 
    """Get a readonly server for testing."""
693
 
 
694
 
    def get_backing_transport(self, backing_transport_server):
695
 
        """Get a backing transport from a server we are decorating."""
696
 
        url = 'readonly+' + backing_transport_server.get_url()
697
 
        return transport.get_transport(url)
698
 
 
699
 
 
700
 
class SmartTCPServer_for_testing_v2_only(SmartTCPServer_for_testing):
701
 
    """A variation of SmartTCPServer_for_testing that limits the client to
702
 
    using RPCs in protocol v2 (i.e. bzr <= 1.5).
703
 
    """
704
 
 
705
 
    def get_url(self):
706
 
        url = super(SmartTCPServer_for_testing_v2_only, self).get_url()
707
 
        url = 'bzr-v2://' + url[len('bzr://'):]
708
 
        return url
709
 
 
710
 
 
711
 
class ReadonlySmartTCPServer_for_testing_v2_only(
712
 
    SmartTCPServer_for_testing_v2_only):
713
 
    """Get a readonly server for testing."""
714
 
 
715
 
    def get_backing_transport(self, backing_transport_server):
716
 
        """Get a backing transport from a server we are decorating."""
717
 
        url = 'readonly+' + backing_transport_server.get_url()
718
 
        return transport.get_transport(url)