~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/tests/test_server.py

(jameinel) Allow 'bzr serve' to interpret SIGHUP as a graceful shutdown.
 (bug #795025) (John A Meinel)

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# Copyright (C) 2010, 2011 Canonical Ltd
 
2
#
 
3
# This program is free software; you can redistribute it and/or modify
 
4
# it under the terms of the GNU General Public License as published by
 
5
# the Free Software Foundation; either version 2 of the License, or
 
6
# (at your option) any later version.
 
7
#
 
8
# This program is distributed in the hope that it will be useful,
 
9
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
10
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
11
# GNU General Public License for more details.
 
12
#
 
13
# You should have received a copy of the GNU General Public License
 
14
# along with this program; if not, write to the Free Software
 
15
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 
16
 
 
17
import errno
 
18
import socket
 
19
import SocketServer
 
20
import sys
 
21
import threading
 
22
 
 
23
 
 
24
from bzrlib import (
 
25
    cethread,
 
26
    osutils,
 
27
    transport,
 
28
    urlutils,
 
29
    )
 
30
from bzrlib.transport import (
 
31
    chroot,
 
32
    pathfilter,
 
33
    )
 
34
from bzrlib.smart import (
 
35
    medium,
 
36
    server,
 
37
    )
 
38
 
 
39
 
 
40
def debug_threads():
 
41
    # FIXME: There is a dependency loop between bzrlib.tests and
 
42
    # bzrlib.tests.test_server that needs to be fixed. In the mean time
 
43
    # defining this function is enough for our needs. -- vila 20100611
 
44
    from bzrlib import tests
 
45
    return 'threads' in tests.selftest_debug_flags
 
46
 
 
47
 
 
48
class TestServer(transport.Server):
 
49
    """A Transport Server dedicated to tests.
 
50
 
 
51
    The TestServer interface provides a server for a given transport. We use
 
52
    these servers as loopback testing tools. For any given transport the
 
53
    Servers it provides must either allow writing, or serve the contents
 
54
    of os.getcwdu() at the time start_server is called.
 
55
 
 
56
    Note that these are real servers - they must implement all the things
 
57
    that we want bzr transports to take advantage of.
 
58
    """
 
59
 
 
60
    def get_url(self):
 
61
        """Return a url for this server.
 
62
 
 
63
        If the transport does not represent a disk directory (i.e. it is
 
64
        a database like svn, or a memory only transport, it should return
 
65
        a connection to a newly established resource for this Server.
 
66
        Otherwise it should return a url that will provide access to the path
 
67
        that was os.getcwdu() when start_server() was called.
 
68
 
 
69
        Subsequent calls will return the same resource.
 
70
        """
 
71
        raise NotImplementedError
 
72
 
 
73
    def get_bogus_url(self):
 
74
        """Return a url for this protocol, that will fail to connect.
 
75
 
 
76
        This may raise NotImplementedError to indicate that this server cannot
 
77
        provide bogus urls.
 
78
        """
 
79
        raise NotImplementedError
 
80
 
 
81
 
 
82
class LocalURLServer(TestServer):
 
83
    """A pretend server for local transports, using file:// urls.
 
84
 
 
85
    Of course no actual server is required to access the local filesystem, so
 
86
    this just exists to tell the test code how to get to it.
 
87
    """
 
88
 
 
89
    def start_server(self):
 
90
        pass
 
91
 
 
92
    def get_url(self):
 
93
        """See Transport.Server.get_url."""
 
94
        return urlutils.local_path_to_url('')
 
95
 
 
96
 
 
97
class DecoratorServer(TestServer):
 
98
    """Server for the TransportDecorator for testing with.
 
99
 
 
100
    To use this when subclassing TransportDecorator, override override the
 
101
    get_decorator_class method.
 
102
    """
 
103
 
 
104
    def start_server(self, server=None):
 
105
        """See bzrlib.transport.Server.start_server.
 
106
 
 
107
        :server: decorate the urls given by server. If not provided a
 
108
        LocalServer is created.
 
109
        """
 
110
        if server is not None:
 
111
            self._made_server = False
 
112
            self._server = server
 
113
        else:
 
114
            self._made_server = True
 
115
            self._server = LocalURLServer()
 
116
            self._server.start_server()
 
117
 
 
118
    def stop_server(self):
 
119
        if self._made_server:
 
120
            self._server.stop_server()
 
121
 
 
122
    def get_decorator_class(self):
 
123
        """Return the class of the decorators we should be constructing."""
 
124
        raise NotImplementedError(self.get_decorator_class)
 
125
 
 
126
    def get_url_prefix(self):
 
127
        """What URL prefix does this decorator produce?"""
 
128
        return self.get_decorator_class()._get_url_prefix()
 
129
 
 
130
    def get_bogus_url(self):
 
131
        """See bzrlib.transport.Server.get_bogus_url."""
 
132
        return self.get_url_prefix() + self._server.get_bogus_url()
 
133
 
 
134
    def get_url(self):
 
135
        """See bzrlib.transport.Server.get_url."""
 
136
        return self.get_url_prefix() + self._server.get_url()
 
137
 
 
138
 
 
139
class BrokenRenameServer(DecoratorServer):
 
140
    """Server for the BrokenRenameTransportDecorator for testing with."""
 
141
 
 
142
    def get_decorator_class(self):
 
143
        from bzrlib.transport import brokenrename
 
144
        return brokenrename.BrokenRenameTransportDecorator
 
145
 
 
146
 
 
147
class FakeNFSServer(DecoratorServer):
 
148
    """Server for the FakeNFSTransportDecorator for testing with."""
 
149
 
 
150
    def get_decorator_class(self):
 
151
        from bzrlib.transport import fakenfs
 
152
        return fakenfs.FakeNFSTransportDecorator
 
153
 
 
154
 
 
155
class FakeVFATServer(DecoratorServer):
 
156
    """A server that suggests connections through FakeVFATTransportDecorator
 
157
 
 
158
    For use in testing.
 
159
    """
 
160
 
 
161
    def get_decorator_class(self):
 
162
        from bzrlib.transport import fakevfat
 
163
        return fakevfat.FakeVFATTransportDecorator
 
164
 
 
165
 
 
166
class LogDecoratorServer(DecoratorServer):
 
167
    """Server for testing."""
 
168
 
 
169
    def get_decorator_class(self):
 
170
        from bzrlib.transport import log
 
171
        return log.TransportLogDecorator
 
172
 
 
173
 
 
174
class NoSmartTransportServer(DecoratorServer):
 
175
    """Server for the NoSmartTransportDecorator for testing with."""
 
176
 
 
177
    def get_decorator_class(self):
 
178
        from bzrlib.transport import nosmart
 
179
        return nosmart.NoSmartTransportDecorator
 
180
 
 
181
 
 
182
class ReadonlyServer(DecoratorServer):
 
183
    """Server for the ReadonlyTransportDecorator for testing with."""
 
184
 
 
185
    def get_decorator_class(self):
 
186
        from bzrlib.transport import readonly
 
187
        return readonly.ReadonlyTransportDecorator
 
188
 
 
189
 
 
190
class TraceServer(DecoratorServer):
 
191
    """Server for the TransportTraceDecorator for testing with."""
 
192
 
 
193
    def get_decorator_class(self):
 
194
        from bzrlib.transport import trace
 
195
        return trace.TransportTraceDecorator
 
196
 
 
197
 
 
198
class UnlistableServer(DecoratorServer):
 
199
    """Server for the UnlistableTransportDecorator for testing with."""
 
200
 
 
201
    def get_decorator_class(self):
 
202
        from bzrlib.transport import unlistable
 
203
        return unlistable.UnlistableTransportDecorator
 
204
 
 
205
 
 
206
class TestingPathFilteringServer(pathfilter.PathFilteringServer):
 
207
 
 
208
    def __init__(self):
 
209
        """TestingPathFilteringServer is not usable until start_server
 
210
        is called."""
 
211
 
 
212
    def start_server(self, backing_server=None):
 
213
        """Setup the Chroot on backing_server."""
 
214
        if backing_server is not None:
 
215
            self.backing_transport = transport.get_transport_from_url(
 
216
                backing_server.get_url())
 
217
        else:
 
218
            self.backing_transport = transport.get_transport_from_path('.')
 
219
        self.backing_transport.clone('added-by-filter').ensure_base()
 
220
        self.filter_func = lambda x: 'added-by-filter/' + x
 
221
        super(TestingPathFilteringServer, self).start_server()
 
222
 
 
223
    def get_bogus_url(self):
 
224
        raise NotImplementedError
 
225
 
 
226
 
 
227
class TestingChrootServer(chroot.ChrootServer):
 
228
 
 
229
    def __init__(self):
 
230
        """TestingChrootServer is not usable until start_server is called."""
 
231
        super(TestingChrootServer, self).__init__(None)
 
232
 
 
233
    def start_server(self, backing_server=None):
 
234
        """Setup the Chroot on backing_server."""
 
235
        if backing_server is not None:
 
236
            self.backing_transport = transport.get_transport_from_url(
 
237
                backing_server.get_url())
 
238
        else:
 
239
            self.backing_transport = transport.get_transport_from_path('.')
 
240
        super(TestingChrootServer, self).start_server()
 
241
 
 
242
    def get_bogus_url(self):
 
243
        raise NotImplementedError
 
244
 
 
245
 
 
246
class TestThread(cethread.CatchingExceptionThread):
 
247
 
 
248
    def join(self, timeout=5):
 
249
        """Overrides to use a default timeout.
 
250
 
 
251
        The default timeout is set to 5 and should expire only when a thread
 
252
        serving a client connection is hung.
 
253
        """
 
254
        super(TestThread, self).join(timeout)
 
255
        if timeout and self.isAlive():
 
256
            # The timeout expired without joining the thread, the thread is
 
257
            # therefore stucked and that's a failure as far as the test is
 
258
            # concerned. We used to hang here.
 
259
 
 
260
            # FIXME: we need to kill the thread, but as far as the test is
 
261
            # concerned, raising an assertion is too strong. On most of the
 
262
            # platforms, this doesn't occur, so just mentioning the problem is
 
263
            # enough for now -- vila 2010824
 
264
            sys.stderr.write('thread %s hung\n' % (self.name,))
 
265
            #raise AssertionError('thread %s hung' % (self.name,))
 
266
 
 
267
 
 
268
class TestingTCPServerMixin(object):
 
269
    """Mixin to support running SocketServer.TCPServer in a thread.
 
270
 
 
271
    Tests are connecting from the main thread, the server has to be run in a
 
272
    separate thread.
 
273
    """
 
274
 
 
275
    def __init__(self):
 
276
        self.started = threading.Event()
 
277
        self.serving = None
 
278
        self.stopped = threading.Event()
 
279
        # We collect the resources used by the clients so we can release them
 
280
        # when shutting down
 
281
        self.clients = []
 
282
        self.ignored_exceptions = None
 
283
 
 
284
    def server_bind(self):
 
285
        self.socket.bind(self.server_address)
 
286
        self.server_address = self.socket.getsockname()
 
287
 
 
288
    def serve(self):
 
289
        self.serving = True
 
290
        self.stopped.clear()
 
291
        # We are listening and ready to accept connections
 
292
        self.started.set()
 
293
        try:
 
294
            while self.serving:
 
295
                # Really a connection but the python framework is generic and
 
296
                # call them requests
 
297
                self.handle_request()
 
298
            # Let's close the listening socket
 
299
            self.server_close()
 
300
        finally:
 
301
            self.stopped.set()
 
302
 
 
303
    def handle_request(self):
 
304
        """Handle one request.
 
305
 
 
306
        The python version swallows some socket exceptions and we don't use
 
307
        timeout, so we override it to better control the server behavior.
 
308
        """
 
309
        request, client_address = self.get_request()
 
310
        if self.verify_request(request, client_address):
 
311
            try:
 
312
                self.process_request(request, client_address)
 
313
            except:
 
314
                self.handle_error(request, client_address)
 
315
                self.close_request(request)
 
316
 
 
317
    def get_request(self):
 
318
        return self.socket.accept()
 
319
 
 
320
    def verify_request(self, request, client_address):
 
321
        """Verify the request.
 
322
 
 
323
        Return True if we should proceed with this request, False if we should
 
324
        not even touch a single byte in the socket ! This is useful when we
 
325
        stop the server with a dummy last connection.
 
326
        """
 
327
        return self.serving
 
328
 
 
329
    def handle_error(self, request, client_address):
 
330
        # Stop serving and re-raise the last exception seen
 
331
        self.serving = False
 
332
        # The following can be used for debugging purposes, it will display the
 
333
        # exception and the traceback just when it occurs instead of waiting
 
334
        # for the thread to be joined.
 
335
 
 
336
        # SocketServer.BaseServer.handle_error(self, request, client_address)
 
337
        raise
 
338
 
 
339
    def ignored_exceptions_during_shutdown(self, e):
 
340
        if sys.platform == 'win32':
 
341
            accepted_errnos = [errno.EBADF,
 
342
                               errno.EPIPE,
 
343
                               errno.WSAEBADF,
 
344
                               errno.WSAECONNRESET,
 
345
                               errno.WSAENOTCONN,
 
346
                               errno.WSAESHUTDOWN,
 
347
                               ]
 
348
        else:
 
349
            accepted_errnos = [errno.EBADF,
 
350
                               errno.ECONNRESET,
 
351
                               errno.ENOTCONN,
 
352
                               errno.EPIPE,
 
353
                               ]
 
354
        if isinstance(e, socket.error) and e[0] in accepted_errnos:
 
355
            return True
 
356
        return False
 
357
 
 
358
    # The following methods are called by the main thread
 
359
 
 
360
    def stop_client_connections(self):
 
361
        while self.clients:
 
362
            c = self.clients.pop()
 
363
            self.shutdown_client(c)
 
364
 
 
365
    def shutdown_socket(self, sock):
 
366
        """Properly shutdown a socket.
 
367
 
 
368
        This should be called only when no other thread is trying to use the
 
369
        socket.
 
370
        """
 
371
        try:
 
372
            sock.shutdown(socket.SHUT_RDWR)
 
373
            sock.close()
 
374
        except Exception, e:
 
375
            if self.ignored_exceptions(e):
 
376
                pass
 
377
            else:
 
378
                raise
 
379
 
 
380
    # The following methods are called by the main thread
 
381
 
 
382
    def set_ignored_exceptions(self, thread, ignored_exceptions):
 
383
        self.ignored_exceptions = ignored_exceptions
 
384
        thread.set_ignored_exceptions(self.ignored_exceptions)
 
385
 
 
386
    def _pending_exception(self, thread):
 
387
        """Raise server uncaught exception.
 
388
 
 
389
        Daughter classes can override this if they use daughter threads.
 
390
        """
 
391
        thread.pending_exception()
 
392
 
 
393
 
 
394
class TestingTCPServer(TestingTCPServerMixin, SocketServer.TCPServer):
 
395
 
 
396
    def __init__(self, server_address, request_handler_class):
 
397
        TestingTCPServerMixin.__init__(self)
 
398
        SocketServer.TCPServer.__init__(self, server_address,
 
399
                                        request_handler_class)
 
400
 
 
401
    def get_request(self):
 
402
        """Get the request and client address from the socket."""
 
403
        sock, addr = TestingTCPServerMixin.get_request(self)
 
404
        self.clients.append((sock, addr))
 
405
        return sock, addr
 
406
 
 
407
    # The following methods are called by the main thread
 
408
 
 
409
    def shutdown_client(self, client):
 
410
        sock, addr = client
 
411
        self.shutdown_socket(sock)
 
412
 
 
413
 
 
414
class TestingThreadingTCPServer(TestingTCPServerMixin,
 
415
                                SocketServer.ThreadingTCPServer):
 
416
 
 
417
    def __init__(self, server_address, request_handler_class):
 
418
        TestingTCPServerMixin.__init__(self)
 
419
        SocketServer.ThreadingTCPServer.__init__(self, server_address,
 
420
                                                 request_handler_class)
 
421
 
 
422
    def get_request (self):
 
423
        """Get the request and client address from the socket."""
 
424
        sock, addr = TestingTCPServerMixin.get_request(self)
 
425
        # The thread is not create yet, it will be updated in process_request
 
426
        self.clients.append((sock, addr, None))
 
427
        return sock, addr
 
428
 
 
429
    def process_request_thread(self, started, stopped, request, client_address):
 
430
        started.set()
 
431
        SocketServer.ThreadingTCPServer.process_request_thread(
 
432
            self, request, client_address)
 
433
        self.close_request(request)
 
434
        stopped.set()
 
435
 
 
436
    def process_request(self, request, client_address):
 
437
        """Start a new thread to process the request."""
 
438
        started = threading.Event()
 
439
        stopped = threading.Event()
 
440
        t = TestThread(
 
441
            sync_event=stopped,
 
442
            name='%s -> %s' % (client_address, self.server_address),
 
443
            target = self.process_request_thread,
 
444
            args = (started, stopped, request, client_address))
 
445
        # Update the client description
 
446
        self.clients.pop()
 
447
        self.clients.append((request, client_address, t))
 
448
        # Propagate the exception handler since we must use the same one as
 
449
        # TestingTCPServer for connections running in their own threads.
 
450
        t.set_ignored_exceptions(self.ignored_exceptions)
 
451
        t.start()
 
452
        started.wait()
 
453
        if debug_threads():
 
454
            sys.stderr.write('Client thread %s started\n' % (t.name,))
 
455
        # If an exception occured during the thread start, it will get raised.
 
456
        t.pending_exception()
 
457
 
 
458
    # The following methods are called by the main thread
 
459
 
 
460
    def shutdown_client(self, client):
 
461
        sock, addr, connection_thread = client
 
462
        self.shutdown_socket(sock)
 
463
        if connection_thread is not None:
 
464
            # The thread has been created only if the request is processed but
 
465
            # after the connection is inited. This could happen during server
 
466
            # shutdown. If an exception occurred in the thread it will be
 
467
            # re-raised
 
468
            if debug_threads():
 
469
                sys.stderr.write('Client thread %s will be joined\n'
 
470
                                 % (connection_thread.name,))
 
471
            connection_thread.join()
 
472
 
 
473
    def set_ignored_exceptions(self, thread, ignored_exceptions):
 
474
        TestingTCPServerMixin.set_ignored_exceptions(self, thread,
 
475
                                                     ignored_exceptions)
 
476
        for sock, addr, connection_thread in self.clients:
 
477
            if connection_thread is not None:
 
478
                connection_thread.set_ignored_exceptions(
 
479
                    self.ignored_exceptions)
 
480
 
 
481
    def _pending_exception(self, thread):
 
482
        for sock, addr, connection_thread in self.clients:
 
483
            if connection_thread is not None:
 
484
                connection_thread.pending_exception()
 
485
        TestingTCPServerMixin._pending_exception(self, thread)
 
486
 
 
487
 
 
488
class TestingTCPServerInAThread(transport.Server):
 
489
    """A server in a thread that re-raise thread exceptions."""
 
490
 
 
491
    def __init__(self, server_address, server_class, request_handler_class):
 
492
        self.server_class = server_class
 
493
        self.request_handler_class = request_handler_class
 
494
        self.host, self.port = server_address
 
495
        self.server = None
 
496
        self._server_thread = None
 
497
 
 
498
    def __repr__(self):
 
499
        return "%s(%s:%s)" % (self.__class__.__name__, self.host, self.port)
 
500
 
 
501
    def create_server(self):
 
502
        return self.server_class((self.host, self.port),
 
503
                                 self.request_handler_class)
 
504
 
 
505
    def start_server(self):
 
506
        self.server = self.create_server()
 
507
        self._server_thread = TestThread(
 
508
            sync_event=self.server.started,
 
509
            target=self.run_server)
 
510
        self._server_thread.start()
 
511
        # Wait for the server thread to start (i.e release the lock)
 
512
        self.server.started.wait()
 
513
        # Get the real address, especially the port
 
514
        self.host, self.port = self.server.server_address
 
515
        self._server_thread.name = self.server.server_address
 
516
        if debug_threads():
 
517
            sys.stderr.write('Server thread %s started\n'
 
518
                             % (self._server_thread.name,))
 
519
        # If an exception occured during the server start, it will get raised,
 
520
        # otherwise, the server is blocked on its accept() call.
 
521
        self._server_thread.pending_exception()
 
522
        # From now on, we'll use a different event to ensure the server can set
 
523
        # its exception
 
524
        self._server_thread.set_sync_event(self.server.stopped)
 
525
 
 
526
    def run_server(self):
 
527
        self.server.serve()
 
528
 
 
529
    def stop_server(self):
 
530
        if self.server is None:
 
531
            return
 
532
        try:
 
533
            # The server has been started successfully, shut it down now.  As
 
534
            # soon as we stop serving, no more connection are accepted except
 
535
            # one to get out of the blocking listen.
 
536
            self.set_ignored_exceptions(
 
537
                self.server.ignored_exceptions_during_shutdown)
 
538
            self.server.serving = False
 
539
            if debug_threads():
 
540
                sys.stderr.write('Server thread %s will be joined\n'
 
541
                                 % (self._server_thread.name,))
 
542
            # The server is listening for a last connection, let's give it:
 
543
            last_conn = None
 
544
            try:
 
545
                last_conn = osutils.connect_socket((self.host, self.port))
 
546
            except socket.error, e:
 
547
                # But ignore connection errors as the point is to unblock the
 
548
                # server thread, it may happen that it's not blocked or even
 
549
                # not started.
 
550
                pass
 
551
            # We start shutting down the clients while the server itself is
 
552
            # shutting down.
 
553
            self.server.stop_client_connections()
 
554
            # Now we wait for the thread running self.server.serve() to finish
 
555
            self.server.stopped.wait()
 
556
            if last_conn is not None:
 
557
                # Close the last connection without trying to use it. The
 
558
                # server will not process a single byte on that socket to avoid
 
559
                # complications (SSL starts with a handshake for example).
 
560
                last_conn.close()
 
561
            # Check for any exception that could have occurred in the server
 
562
            # thread
 
563
            try:
 
564
                self._server_thread.join()
 
565
            except Exception, e:
 
566
                if self.server.ignored_exceptions(e):
 
567
                    pass
 
568
                else:
 
569
                    raise
 
570
        finally:
 
571
            # Make sure we can be called twice safely, note that this means
 
572
            # that we will raise a single exception even if several occurred in
 
573
            # the various threads involved.
 
574
            self.server = None
 
575
 
 
576
    def set_ignored_exceptions(self, ignored_exceptions):
 
577
        """Install an exception handler for the server."""
 
578
        self.server.set_ignored_exceptions(self._server_thread,
 
579
                                           ignored_exceptions)
 
580
 
 
581
    def pending_exception(self):
 
582
        """Raise uncaught exception in the server."""
 
583
        self.server._pending_exception(self._server_thread)
 
584
 
 
585
 
 
586
class TestingSmartConnectionHandler(SocketServer.BaseRequestHandler,
 
587
                                    medium.SmartServerSocketStreamMedium):
 
588
 
 
589
    def __init__(self, request, client_address, server):
 
590
        medium.SmartServerSocketStreamMedium.__init__(
 
591
            self, request, server.backing_transport,
 
592
            server.root_client_path,
 
593
            timeout=_DEFAULT_TESTING_CLIENT_TIMEOUT)
 
594
        request.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
 
595
        SocketServer.BaseRequestHandler.__init__(self, request, client_address,
 
596
                                                 server)
 
597
 
 
598
    def handle(self):
 
599
        while not self.finished:
 
600
            server_protocol = self._build_protocol()
 
601
            self._serve_one_request(server_protocol)
 
602
 
 
603
 
 
604
_DEFAULT_TESTING_CLIENT_TIMEOUT = 4.0
 
605
 
 
606
class TestingSmartServer(TestingThreadingTCPServer, server.SmartTCPServer):
 
607
 
 
608
    def __init__(self, server_address, request_handler_class,
 
609
                 backing_transport, root_client_path):
 
610
        TestingThreadingTCPServer.__init__(self, server_address,
 
611
                                           request_handler_class)
 
612
        server.SmartTCPServer.__init__(self, backing_transport,
 
613
            root_client_path, client_timeout=_DEFAULT_TESTING_CLIENT_TIMEOUT)
 
614
 
 
615
    def serve(self):
 
616
        self.run_server_started_hooks()
 
617
        try:
 
618
            TestingThreadingTCPServer.serve(self)
 
619
        finally:
 
620
            self.run_server_stopped_hooks()
 
621
 
 
622
    def get_url(self):
 
623
        """Return the url of the server"""
 
624
        return "bzr://%s:%d/" % self.server_address
 
625
 
 
626
 
 
627
class SmartTCPServer_for_testing(TestingTCPServerInAThread):
 
628
    """Server suitable for use by transport tests.
 
629
 
 
630
    This server is backed by the process's cwd.
 
631
    """
 
632
    def __init__(self, thread_name_suffix=''):
 
633
        self.client_path_extra = None
 
634
        self.thread_name_suffix = thread_name_suffix
 
635
        self.host = '127.0.0.1'
 
636
        self.port = 0
 
637
        super(SmartTCPServer_for_testing, self).__init__(
 
638
                (self.host, self.port),
 
639
                TestingSmartServer,
 
640
                TestingSmartConnectionHandler)
 
641
 
 
642
    def create_server(self):
 
643
        return self.server_class((self.host, self.port),
 
644
                                 self.request_handler_class,
 
645
                                 self.backing_transport,
 
646
                                 self.root_client_path)
 
647
 
 
648
 
 
649
    def start_server(self, backing_transport_server=None,
 
650
                     client_path_extra='/extra/'):
 
651
        """Set up server for testing.
 
652
 
 
653
        :param backing_transport_server: backing server to use.  If not
 
654
            specified, a LocalURLServer at the current working directory will
 
655
            be used.
 
656
        :param client_path_extra: a path segment starting with '/' to append to
 
657
            the root URL for this server.  For instance, a value of '/foo/bar/'
 
658
            will mean the root of the backing transport will be published at a
 
659
            URL like `bzr://127.0.0.1:nnnn/foo/bar/`, rather than
 
660
            `bzr://127.0.0.1:nnnn/`.  Default value is `extra`, so that tests
 
661
            by default will fail unless they do the necessary path translation.
 
662
        """
 
663
        if not client_path_extra.startswith('/'):
 
664
            raise ValueError(client_path_extra)
 
665
        self.root_client_path = self.client_path_extra = client_path_extra
 
666
        from bzrlib.transport.chroot import ChrootServer
 
667
        if backing_transport_server is None:
 
668
            backing_transport_server = LocalURLServer()
 
669
        self.chroot_server = ChrootServer(
 
670
            self.get_backing_transport(backing_transport_server))
 
671
        self.chroot_server.start_server()
 
672
        self.backing_transport = transport.get_transport_from_url(
 
673
            self.chroot_server.get_url())
 
674
        super(SmartTCPServer_for_testing, self).start_server()
 
675
 
 
676
    def stop_server(self):
 
677
        try:
 
678
            super(SmartTCPServer_for_testing, self).stop_server()
 
679
        finally:
 
680
            self.chroot_server.stop_server()
 
681
 
 
682
    def get_backing_transport(self, backing_transport_server):
 
683
        """Get a backing transport from a server we are decorating."""
 
684
        return transport.get_transport_from_url(
 
685
            backing_transport_server.get_url())
 
686
 
 
687
    def get_url(self):
 
688
        url = self.server.get_url()
 
689
        return url[:-1] + self.client_path_extra
 
690
 
 
691
    def get_bogus_url(self):
 
692
        """Return a URL which will fail to connect"""
 
693
        return 'bzr://127.0.0.1:1/'
 
694
 
 
695
 
 
696
class ReadonlySmartTCPServer_for_testing(SmartTCPServer_for_testing):
 
697
    """Get a readonly server for testing."""
 
698
 
 
699
    def get_backing_transport(self, backing_transport_server):
 
700
        """Get a backing transport from a server we are decorating."""
 
701
        url = 'readonly+' + backing_transport_server.get_url()
 
702
        return transport.get_transport_from_url(url)
 
703
 
 
704
 
 
705
class SmartTCPServer_for_testing_v2_only(SmartTCPServer_for_testing):
 
706
    """A variation of SmartTCPServer_for_testing that limits the client to
 
707
    using RPCs in protocol v2 (i.e. bzr <= 1.5).
 
708
    """
 
709
 
 
710
    def get_url(self):
 
711
        url = super(SmartTCPServer_for_testing_v2_only, self).get_url()
 
712
        url = 'bzr-v2://' + url[len('bzr://'):]
 
713
        return url
 
714
 
 
715
 
 
716
class ReadonlySmartTCPServer_for_testing_v2_only(
 
717
    SmartTCPServer_for_testing_v2_only):
 
718
    """Get a readonly server for testing."""
 
719
 
 
720
    def get_backing_transport(self, backing_transport_server):
 
721
        """Get a backing transport from a server we are decorating."""
 
722
        url = 'readonly+' + backing_transport_server.get_url()
 
723
        return transport.get_transport_from_url(url)