~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/tests/test_server.py

  • Committer: Aaron Bentley
  • Date: 2009-08-14 18:17:15 UTC
  • mto: (4603.1.22 shelve-editor)
  • mto: This revision was merged to the branch mainline in revision 4795.
  • Revision ID: aaron@aaronbentley.com-20090814181715-59qnhbov2stgzqt2
Allow configuring change editor.

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2010, 2011 Canonical Ltd
2
 
#
3
 
# This program is free software; you can redistribute it and/or modify
4
 
# it under the terms of the GNU General Public License as published by
5
 
# the Free Software Foundation; either version 2 of the License, or
6
 
# (at your option) any later version.
7
 
#
8
 
# This program is distributed in the hope that it will be useful,
9
 
# but WITHOUT ANY WARRANTY; without even the implied warranty of
10
 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11
 
# GNU General Public License for more details.
12
 
#
13
 
# You should have received a copy of the GNU General Public License
14
 
# along with this program; if not, write to the Free Software
15
 
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
16
 
 
17
 
import errno
18
 
import socket
19
 
import SocketServer
20
 
import sys
21
 
import threading
22
 
import traceback
23
 
 
24
 
 
25
 
from bzrlib import (
26
 
    cethread,
27
 
    errors,
28
 
    osutils,
29
 
    transport,
30
 
    urlutils,
31
 
    )
32
 
from bzrlib.transport import (
33
 
    chroot,
34
 
    pathfilter,
35
 
    )
36
 
from bzrlib.smart import (
37
 
    medium,
38
 
    server,
39
 
    )
40
 
 
41
 
 
42
 
def debug_threads():
43
 
    # FIXME: There is a dependency loop between bzrlib.tests and
44
 
    # bzrlib.tests.test_server that needs to be fixed. In the mean time
45
 
    # defining this function is enough for our needs. -- vila 20100611
46
 
    from bzrlib import tests
47
 
    return 'threads' in tests.selftest_debug_flags
48
 
 
49
 
 
50
 
class TestServer(transport.Server):
51
 
    """A Transport Server dedicated to tests.
52
 
 
53
 
    The TestServer interface provides a server for a given transport. We use
54
 
    these servers as loopback testing tools. For any given transport the
55
 
    Servers it provides must either allow writing, or serve the contents
56
 
    of os.getcwdu() at the time start_server is called.
57
 
 
58
 
    Note that these are real servers - they must implement all the things
59
 
    that we want bzr transports to take advantage of.
60
 
    """
61
 
 
62
 
    def get_url(self):
63
 
        """Return a url for this server.
64
 
 
65
 
        If the transport does not represent a disk directory (i.e. it is
66
 
        a database like svn, or a memory only transport, it should return
67
 
        a connection to a newly established resource for this Server.
68
 
        Otherwise it should return a url that will provide access to the path
69
 
        that was os.getcwdu() when start_server() was called.
70
 
 
71
 
        Subsequent calls will return the same resource.
72
 
        """
73
 
        raise NotImplementedError
74
 
 
75
 
    def get_bogus_url(self):
76
 
        """Return a url for this protocol, that will fail to connect.
77
 
 
78
 
        This may raise NotImplementedError to indicate that this server cannot
79
 
        provide bogus urls.
80
 
        """
81
 
        raise NotImplementedError
82
 
 
83
 
 
84
 
class LocalURLServer(TestServer):
85
 
    """A pretend server for local transports, using file:// urls.
86
 
 
87
 
    Of course no actual server is required to access the local filesystem, so
88
 
    this just exists to tell the test code how to get to it.
89
 
    """
90
 
 
91
 
    def start_server(self):
92
 
        pass
93
 
 
94
 
    def get_url(self):
95
 
        """See Transport.Server.get_url."""
96
 
        return urlutils.local_path_to_url('')
97
 
 
98
 
 
99
 
class DecoratorServer(TestServer):
100
 
    """Server for the TransportDecorator for testing with.
101
 
 
102
 
    To use this when subclassing TransportDecorator, override override the
103
 
    get_decorator_class method.
104
 
    """
105
 
 
106
 
    def start_server(self, server=None):
107
 
        """See bzrlib.transport.Server.start_server.
108
 
 
109
 
        :server: decorate the urls given by server. If not provided a
110
 
        LocalServer is created.
111
 
        """
112
 
        if server is not None:
113
 
            self._made_server = False
114
 
            self._server = server
115
 
        else:
116
 
            self._made_server = True
117
 
            self._server = LocalURLServer()
118
 
            self._server.start_server()
119
 
 
120
 
    def stop_server(self):
121
 
        if self._made_server:
122
 
            self._server.stop_server()
123
 
 
124
 
    def get_decorator_class(self):
125
 
        """Return the class of the decorators we should be constructing."""
126
 
        raise NotImplementedError(self.get_decorator_class)
127
 
 
128
 
    def get_url_prefix(self):
129
 
        """What URL prefix does this decorator produce?"""
130
 
        return self.get_decorator_class()._get_url_prefix()
131
 
 
132
 
    def get_bogus_url(self):
133
 
        """See bzrlib.transport.Server.get_bogus_url."""
134
 
        return self.get_url_prefix() + self._server.get_bogus_url()
135
 
 
136
 
    def get_url(self):
137
 
        """See bzrlib.transport.Server.get_url."""
138
 
        return self.get_url_prefix() + self._server.get_url()
139
 
 
140
 
 
141
 
class BrokenRenameServer(DecoratorServer):
142
 
    """Server for the BrokenRenameTransportDecorator for testing with."""
143
 
 
144
 
    def get_decorator_class(self):
145
 
        from bzrlib.transport import brokenrename
146
 
        return brokenrename.BrokenRenameTransportDecorator
147
 
 
148
 
 
149
 
class FakeNFSServer(DecoratorServer):
150
 
    """Server for the FakeNFSTransportDecorator for testing with."""
151
 
 
152
 
    def get_decorator_class(self):
153
 
        from bzrlib.transport import fakenfs
154
 
        return fakenfs.FakeNFSTransportDecorator
155
 
 
156
 
 
157
 
class FakeVFATServer(DecoratorServer):
158
 
    """A server that suggests connections through FakeVFATTransportDecorator
159
 
 
160
 
    For use in testing.
161
 
    """
162
 
 
163
 
    def get_decorator_class(self):
164
 
        from bzrlib.transport import fakevfat
165
 
        return fakevfat.FakeVFATTransportDecorator
166
 
 
167
 
 
168
 
class LogDecoratorServer(DecoratorServer):
169
 
    """Server for testing."""
170
 
 
171
 
    def get_decorator_class(self):
172
 
        from bzrlib.transport import log
173
 
        return log.TransportLogDecorator
174
 
 
175
 
 
176
 
class NoSmartTransportServer(DecoratorServer):
177
 
    """Server for the NoSmartTransportDecorator for testing with."""
178
 
 
179
 
    def get_decorator_class(self):
180
 
        from bzrlib.transport import nosmart
181
 
        return nosmart.NoSmartTransportDecorator
182
 
 
183
 
 
184
 
class ReadonlyServer(DecoratorServer):
185
 
    """Server for the ReadonlyTransportDecorator for testing with."""
186
 
 
187
 
    def get_decorator_class(self):
188
 
        from bzrlib.transport import readonly
189
 
        return readonly.ReadonlyTransportDecorator
190
 
 
191
 
 
192
 
class TraceServer(DecoratorServer):
193
 
    """Server for the TransportTraceDecorator for testing with."""
194
 
 
195
 
    def get_decorator_class(self):
196
 
        from bzrlib.transport import trace
197
 
        return trace.TransportTraceDecorator
198
 
 
199
 
 
200
 
class UnlistableServer(DecoratorServer):
201
 
    """Server for the UnlistableTransportDecorator for testing with."""
202
 
 
203
 
    def get_decorator_class(self):
204
 
        from bzrlib.transport import unlistable
205
 
        return unlistable.UnlistableTransportDecorator
206
 
 
207
 
 
208
 
class TestingPathFilteringServer(pathfilter.PathFilteringServer):
209
 
 
210
 
    def __init__(self):
211
 
        """TestingPathFilteringServer is not usable until start_server
212
 
        is called."""
213
 
 
214
 
    def start_server(self, backing_server=None):
215
 
        """Setup the Chroot on backing_server."""
216
 
        if backing_server is not None:
217
 
            self.backing_transport = transport.get_transport_from_url(
218
 
                backing_server.get_url())
219
 
        else:
220
 
            self.backing_transport = transport.get_transport_from_path('.')
221
 
        self.backing_transport.clone('added-by-filter').ensure_base()
222
 
        self.filter_func = lambda x: 'added-by-filter/' + x
223
 
        super(TestingPathFilteringServer, self).start_server()
224
 
 
225
 
    def get_bogus_url(self):
226
 
        raise NotImplementedError
227
 
 
228
 
 
229
 
class TestingChrootServer(chroot.ChrootServer):
230
 
 
231
 
    def __init__(self):
232
 
        """TestingChrootServer is not usable until start_server is called."""
233
 
        super(TestingChrootServer, self).__init__(None)
234
 
 
235
 
    def start_server(self, backing_server=None):
236
 
        """Setup the Chroot on backing_server."""
237
 
        if backing_server is not None:
238
 
            self.backing_transport = transport.get_transport_from_url(
239
 
                backing_server.get_url())
240
 
        else:
241
 
            self.backing_transport = transport.get_transport_from_path('.')
242
 
        super(TestingChrootServer, self).start_server()
243
 
 
244
 
    def get_bogus_url(self):
245
 
        raise NotImplementedError
246
 
 
247
 
 
248
 
class TestThread(cethread.CatchingExceptionThread):
249
 
 
250
 
    def join(self, timeout=5):
251
 
        """Overrides to use a default timeout.
252
 
 
253
 
        The default timeout is set to 5 and should expire only when a thread
254
 
        serving a client connection is hung.
255
 
        """
256
 
        super(TestThread, self).join(timeout)
257
 
        if timeout and self.isAlive():
258
 
            # The timeout expired without joining the thread, the thread is
259
 
            # therefore stucked and that's a failure as far as the test is
260
 
            # concerned. We used to hang here.
261
 
 
262
 
            # FIXME: we need to kill the thread, but as far as the test is
263
 
            # concerned, raising an assertion is too strong. On most of the
264
 
            # platforms, this doesn't occur, so just mentioning the problem is
265
 
            # enough for now -- vila 2010824
266
 
            sys.stderr.write('thread %s hung\n' % (self.name,))
267
 
            #raise AssertionError('thread %s hung' % (self.name,))
268
 
 
269
 
 
270
 
class TestingTCPServerMixin(object):
271
 
    """Mixin to support running SocketServer.TCPServer in a thread.
272
 
 
273
 
    Tests are connecting from the main thread, the server has to be run in a
274
 
    separate thread.
275
 
    """
276
 
 
277
 
    def __init__(self):
278
 
        self.started = threading.Event()
279
 
        self.serving = None
280
 
        self.stopped = threading.Event()
281
 
        # We collect the resources used by the clients so we can release them
282
 
        # when shutting down
283
 
        self.clients = []
284
 
        self.ignored_exceptions = None
285
 
 
286
 
    def server_bind(self):
287
 
        self.socket.bind(self.server_address)
288
 
        self.server_address = self.socket.getsockname()
289
 
 
290
 
    def serve(self):
291
 
        self.serving = True
292
 
        # We are listening and ready to accept connections
293
 
        self.started.set()
294
 
        try:
295
 
            while self.serving:
296
 
                # Really a connection but the python framework is generic and
297
 
                # call them requests
298
 
                self.handle_request()
299
 
            # Let's close the listening socket
300
 
            self.server_close()
301
 
        finally:
302
 
            self.stopped.set()
303
 
 
304
 
    def handle_request(self):
305
 
        """Handle one request.
306
 
 
307
 
        The python version swallows some socket exceptions and we don't use
308
 
        timeout, so we override it to better control the server behavior.
309
 
        """
310
 
        request, client_address = self.get_request()
311
 
        if self.verify_request(request, client_address):
312
 
            try:
313
 
                self.process_request(request, client_address)
314
 
            except:
315
 
                self.handle_error(request, client_address)
316
 
        else:
317
 
            self.close_request(request)
318
 
 
319
 
    def get_request(self):
320
 
        return self.socket.accept()
321
 
 
322
 
    def verify_request(self, request, client_address):
323
 
        """Verify the request.
324
 
 
325
 
        Return True if we should proceed with this request, False if we should
326
 
        not even touch a single byte in the socket ! This is useful when we
327
 
        stop the server with a dummy last connection.
328
 
        """
329
 
        return self.serving
330
 
 
331
 
    def handle_error(self, request, client_address):
332
 
        # Stop serving and re-raise the last exception seen
333
 
        self.serving = False
334
 
        # The following can be used for debugging purposes, it will display the
335
 
        # exception and the traceback just when it occurs instead of waiting
336
 
        # for the thread to be joined.
337
 
        # SocketServer.BaseServer.handle_error(self, request, client_address)
338
 
 
339
 
        # We call close_request manually, because we are going to raise an
340
 
        # exception. The SocketServer implementation calls:
341
 
        #   handle_error(...)
342
 
        #   close_request(...)
343
 
        # But because we raise the exception, close_request will never be
344
 
        # triggered. This helps client not block waiting for a response when
345
 
        # the server gets an exception.
346
 
        self.close_request(request)
347
 
        raise
348
 
 
349
 
    def ignored_exceptions_during_shutdown(self, e):
350
 
        if sys.platform == 'win32':
351
 
            accepted_errnos = [errno.EBADF,
352
 
                               errno.EPIPE,
353
 
                               errno.WSAEBADF,
354
 
                               errno.WSAECONNRESET,
355
 
                               errno.WSAENOTCONN,
356
 
                               errno.WSAESHUTDOWN,
357
 
                               ]
358
 
        else:
359
 
            accepted_errnos = [errno.EBADF,
360
 
                               errno.ECONNRESET,
361
 
                               errno.ENOTCONN,
362
 
                               errno.EPIPE,
363
 
                               ]
364
 
        if isinstance(e, socket.error) and e[0] in accepted_errnos:
365
 
            return True
366
 
        return False
367
 
 
368
 
    # The following methods are called by the main thread
369
 
 
370
 
    def stop_client_connections(self):
371
 
        while self.clients:
372
 
            c = self.clients.pop()
373
 
            self.shutdown_client(c)
374
 
 
375
 
    def shutdown_socket(self, sock):
376
 
        """Properly shutdown a socket.
377
 
 
378
 
        This should be called only when no other thread is trying to use the
379
 
        socket.
380
 
        """
381
 
        try:
382
 
            sock.shutdown(socket.SHUT_RDWR)
383
 
            sock.close()
384
 
        except Exception, e:
385
 
            if self.ignored_exceptions(e):
386
 
                pass
387
 
            else:
388
 
                raise
389
 
 
390
 
    # The following methods are called by the main thread
391
 
 
392
 
    def set_ignored_exceptions(self, thread, ignored_exceptions):
393
 
        self.ignored_exceptions = ignored_exceptions
394
 
        thread.set_ignored_exceptions(self.ignored_exceptions)
395
 
 
396
 
    def _pending_exception(self, thread):
397
 
        """Raise server uncaught exception.
398
 
 
399
 
        Daughter classes can override this if they use daughter threads.
400
 
        """
401
 
        thread.pending_exception()
402
 
 
403
 
 
404
 
class TestingTCPServer(TestingTCPServerMixin, SocketServer.TCPServer):
405
 
 
406
 
    def __init__(self, server_address, request_handler_class):
407
 
        TestingTCPServerMixin.__init__(self)
408
 
        SocketServer.TCPServer.__init__(self, server_address,
409
 
                                        request_handler_class)
410
 
 
411
 
    def get_request(self):
412
 
        """Get the request and client address from the socket."""
413
 
        sock, addr = TestingTCPServerMixin.get_request(self)
414
 
        self.clients.append((sock, addr))
415
 
        return sock, addr
416
 
 
417
 
    # The following methods are called by the main thread
418
 
 
419
 
    def shutdown_client(self, client):
420
 
        sock, addr = client
421
 
        self.shutdown_socket(sock)
422
 
 
423
 
 
424
 
class TestingThreadingTCPServer(TestingTCPServerMixin,
425
 
                                SocketServer.ThreadingTCPServer):
426
 
 
427
 
    def __init__(self, server_address, request_handler_class):
428
 
        TestingTCPServerMixin.__init__(self)
429
 
        SocketServer.ThreadingTCPServer.__init__(self, server_address,
430
 
                                                 request_handler_class)
431
 
 
432
 
    def get_request(self):
433
 
        """Get the request and client address from the socket."""
434
 
        sock, addr = TestingTCPServerMixin.get_request(self)
435
 
        # The thread is not create yet, it will be updated in process_request
436
 
        self.clients.append((sock, addr, None))
437
 
        return sock, addr
438
 
 
439
 
    def process_request_thread(self, started, stopped, request, client_address):
440
 
        started.set()
441
 
        SocketServer.ThreadingTCPServer.process_request_thread(
442
 
            self, request, client_address)
443
 
        self.close_request(request)
444
 
        stopped.set()
445
 
 
446
 
    def process_request(self, request, client_address):
447
 
        """Start a new thread to process the request."""
448
 
        started = threading.Event()
449
 
        stopped = threading.Event()
450
 
        t = TestThread(
451
 
            sync_event=stopped,
452
 
            name='%s -> %s' % (client_address, self.server_address),
453
 
            target=self.process_request_thread,
454
 
            args=(started, stopped, request, client_address))
455
 
        # Update the client description
456
 
        self.clients.pop()
457
 
        self.clients.append((request, client_address, t))
458
 
        # Propagate the exception handler since we must use the same one as
459
 
        # TestingTCPServer for connections running in their own threads.
460
 
        t.set_ignored_exceptions(self.ignored_exceptions)
461
 
        t.start()
462
 
        started.wait()
463
 
        if debug_threads():
464
 
            sys.stderr.write('Client thread %s started\n' % (t.name,))
465
 
        # If an exception occured during the thread start, it will get raised.
466
 
        # In rare cases, an exception raised during the request processing may
467
 
        # also get caught here (see http://pad.lv/869366)
468
 
        t.pending_exception()
469
 
 
470
 
    # The following methods are called by the main thread
471
 
 
472
 
    def shutdown_client(self, client):
473
 
        sock, addr, connection_thread = client
474
 
        self.shutdown_socket(sock)
475
 
        if connection_thread is not None:
476
 
            # The thread has been created only if the request is processed but
477
 
            # after the connection is inited. This could happen during server
478
 
            # shutdown. If an exception occurred in the thread it will be
479
 
            # re-raised
480
 
            if debug_threads():
481
 
                sys.stderr.write('Client thread %s will be joined\n'
482
 
                                 % (connection_thread.name,))
483
 
            connection_thread.join()
484
 
 
485
 
    def set_ignored_exceptions(self, thread, ignored_exceptions):
486
 
        TestingTCPServerMixin.set_ignored_exceptions(self, thread,
487
 
                                                     ignored_exceptions)
488
 
        for sock, addr, connection_thread in self.clients:
489
 
            if connection_thread is not None:
490
 
                connection_thread.set_ignored_exceptions(
491
 
                    self.ignored_exceptions)
492
 
 
493
 
    def _pending_exception(self, thread):
494
 
        for sock, addr, connection_thread in self.clients:
495
 
            if connection_thread is not None:
496
 
                connection_thread.pending_exception()
497
 
        TestingTCPServerMixin._pending_exception(self, thread)
498
 
 
499
 
 
500
 
class TestingTCPServerInAThread(transport.Server):
501
 
    """A server in a thread that re-raise thread exceptions."""
502
 
 
503
 
    def __init__(self, server_address, server_class, request_handler_class):
504
 
        self.server_class = server_class
505
 
        self.request_handler_class = request_handler_class
506
 
        self.host, self.port = server_address
507
 
        self.server = None
508
 
        self._server_thread = None
509
 
 
510
 
    def __repr__(self):
511
 
        return "%s(%s:%s)" % (self.__class__.__name__, self.host, self.port)
512
 
 
513
 
    def create_server(self):
514
 
        return self.server_class((self.host, self.port),
515
 
                                 self.request_handler_class)
516
 
 
517
 
    def start_server(self):
518
 
        self.server = self.create_server()
519
 
        self._server_thread = TestThread(
520
 
            sync_event=self.server.started,
521
 
            target=self.run_server)
522
 
        self._server_thread.start()
523
 
        # Wait for the server thread to start (i.e. release the lock)
524
 
        self.server.started.wait()
525
 
        # Get the real address, especially the port
526
 
        self.host, self.port = self.server.server_address
527
 
        self._server_thread.name = self.server.server_address
528
 
        if debug_threads():
529
 
            sys.stderr.write('Server thread %s started\n'
530
 
                             % (self._server_thread.name,))
531
 
        # If an exception occured during the server start, it will get raised,
532
 
        # otherwise, the server is blocked on its accept() call.
533
 
        self._server_thread.pending_exception()
534
 
        # From now on, we'll use a different event to ensure the server can set
535
 
        # its exception
536
 
        self._server_thread.set_sync_event(self.server.stopped)
537
 
 
538
 
    def run_server(self):
539
 
        self.server.serve()
540
 
 
541
 
    def stop_server(self):
542
 
        if self.server is None:
543
 
            return
544
 
        try:
545
 
            # The server has been started successfully, shut it down now.  As
546
 
            # soon as we stop serving, no more connection are accepted except
547
 
            # one to get out of the blocking listen.
548
 
            self.set_ignored_exceptions(
549
 
                self.server.ignored_exceptions_during_shutdown)
550
 
            self.server.serving = False
551
 
            if debug_threads():
552
 
                sys.stderr.write('Server thread %s will be joined\n'
553
 
                                 % (self._server_thread.name,))
554
 
            # The server is listening for a last connection, let's give it:
555
 
            last_conn = None
556
 
            try:
557
 
                last_conn = osutils.connect_socket((self.host, self.port))
558
 
            except socket.error, e:
559
 
                # But ignore connection errors as the point is to unblock the
560
 
                # server thread, it may happen that it's not blocked or even
561
 
                # not started.
562
 
                pass
563
 
            # We start shutting down the clients while the server itself is
564
 
            # shutting down.
565
 
            self.server.stop_client_connections()
566
 
            # Now we wait for the thread running self.server.serve() to finish
567
 
            self.server.stopped.wait()
568
 
            if last_conn is not None:
569
 
                # Close the last connection without trying to use it. The
570
 
                # server will not process a single byte on that socket to avoid
571
 
                # complications (SSL starts with a handshake for example).
572
 
                last_conn.close()
573
 
            # Check for any exception that could have occurred in the server
574
 
            # thread
575
 
            try:
576
 
                self._server_thread.join()
577
 
            except Exception, e:
578
 
                if self.server.ignored_exceptions(e):
579
 
                    pass
580
 
                else:
581
 
                    raise
582
 
        finally:
583
 
            # Make sure we can be called twice safely, note that this means
584
 
            # that we will raise a single exception even if several occurred in
585
 
            # the various threads involved.
586
 
            self.server = None
587
 
 
588
 
    def set_ignored_exceptions(self, ignored_exceptions):
589
 
        """Install an exception handler for the server."""
590
 
        self.server.set_ignored_exceptions(self._server_thread,
591
 
                                           ignored_exceptions)
592
 
 
593
 
    def pending_exception(self):
594
 
        """Raise uncaught exception in the server."""
595
 
        self.server._pending_exception(self._server_thread)
596
 
 
597
 
 
598
 
class TestingSmartConnectionHandler(SocketServer.BaseRequestHandler,
599
 
                                    medium.SmartServerSocketStreamMedium):
600
 
 
601
 
    def __init__(self, request, client_address, server):
602
 
        medium.SmartServerSocketStreamMedium.__init__(
603
 
            self, request, server.backing_transport,
604
 
            server.root_client_path,
605
 
            timeout=_DEFAULT_TESTING_CLIENT_TIMEOUT)
606
 
        request.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
607
 
        SocketServer.BaseRequestHandler.__init__(self, request, client_address,
608
 
                                                 server)
609
 
 
610
 
    def handle(self):
611
 
        try:
612
 
            while not self.finished:
613
 
                server_protocol = self._build_protocol()
614
 
                self._serve_one_request(server_protocol)
615
 
        except errors.ConnectionTimeout:
616
 
            # idle connections aren't considered a failure of the server
617
 
            return
618
 
 
619
 
 
620
 
_DEFAULT_TESTING_CLIENT_TIMEOUT = 60.0
621
 
 
622
 
class TestingSmartServer(TestingThreadingTCPServer, server.SmartTCPServer):
623
 
 
624
 
    def __init__(self, server_address, request_handler_class,
625
 
                 backing_transport, root_client_path):
626
 
        TestingThreadingTCPServer.__init__(self, server_address,
627
 
                                           request_handler_class)
628
 
        server.SmartTCPServer.__init__(self, backing_transport,
629
 
            root_client_path, client_timeout=_DEFAULT_TESTING_CLIENT_TIMEOUT)
630
 
 
631
 
    def serve(self):
632
 
        self.run_server_started_hooks()
633
 
        try:
634
 
            TestingThreadingTCPServer.serve(self)
635
 
        finally:
636
 
            self.run_server_stopped_hooks()
637
 
 
638
 
    def get_url(self):
639
 
        """Return the url of the server"""
640
 
        return "bzr://%s:%d/" % self.server_address
641
 
 
642
 
 
643
 
class SmartTCPServer_for_testing(TestingTCPServerInAThread):
644
 
    """Server suitable for use by transport tests.
645
 
 
646
 
    This server is backed by the process's cwd.
647
 
    """
648
 
    def __init__(self, thread_name_suffix=''):
649
 
        self.client_path_extra = None
650
 
        self.thread_name_suffix = thread_name_suffix
651
 
        self.host = '127.0.0.1'
652
 
        self.port = 0
653
 
        super(SmartTCPServer_for_testing, self).__init__(
654
 
                (self.host, self.port),
655
 
                TestingSmartServer,
656
 
                TestingSmartConnectionHandler)
657
 
 
658
 
    def create_server(self):
659
 
        return self.server_class((self.host, self.port),
660
 
                                 self.request_handler_class,
661
 
                                 self.backing_transport,
662
 
                                 self.root_client_path)
663
 
 
664
 
 
665
 
    def start_server(self, backing_transport_server=None,
666
 
                     client_path_extra='/extra/'):
667
 
        """Set up server for testing.
668
 
 
669
 
        :param backing_transport_server: backing server to use.  If not
670
 
            specified, a LocalURLServer at the current working directory will
671
 
            be used.
672
 
        :param client_path_extra: a path segment starting with '/' to append to
673
 
            the root URL for this server.  For instance, a value of '/foo/bar/'
674
 
            will mean the root of the backing transport will be published at a
675
 
            URL like `bzr://127.0.0.1:nnnn/foo/bar/`, rather than
676
 
            `bzr://127.0.0.1:nnnn/`.  Default value is `extra`, so that tests
677
 
            by default will fail unless they do the necessary path translation.
678
 
        """
679
 
        if not client_path_extra.startswith('/'):
680
 
            raise ValueError(client_path_extra)
681
 
        self.root_client_path = self.client_path_extra = client_path_extra
682
 
        from bzrlib.transport.chroot import ChrootServer
683
 
        if backing_transport_server is None:
684
 
            backing_transport_server = LocalURLServer()
685
 
        self.chroot_server = ChrootServer(
686
 
            self.get_backing_transport(backing_transport_server))
687
 
        self.chroot_server.start_server()
688
 
        self.backing_transport = transport.get_transport_from_url(
689
 
            self.chroot_server.get_url())
690
 
        super(SmartTCPServer_for_testing, self).start_server()
691
 
 
692
 
    def stop_server(self):
693
 
        try:
694
 
            super(SmartTCPServer_for_testing, self).stop_server()
695
 
        finally:
696
 
            self.chroot_server.stop_server()
697
 
 
698
 
    def get_backing_transport(self, backing_transport_server):
699
 
        """Get a backing transport from a server we are decorating."""
700
 
        return transport.get_transport_from_url(
701
 
            backing_transport_server.get_url())
702
 
 
703
 
    def get_url(self):
704
 
        url = self.server.get_url()
705
 
        return url[:-1] + self.client_path_extra
706
 
 
707
 
    def get_bogus_url(self):
708
 
        """Return a URL which will fail to connect"""
709
 
        return 'bzr://127.0.0.1:1/'
710
 
 
711
 
 
712
 
class ReadonlySmartTCPServer_for_testing(SmartTCPServer_for_testing):
713
 
    """Get a readonly server for testing."""
714
 
 
715
 
    def get_backing_transport(self, backing_transport_server):
716
 
        """Get a backing transport from a server we are decorating."""
717
 
        url = 'readonly+' + backing_transport_server.get_url()
718
 
        return transport.get_transport_from_url(url)
719
 
 
720
 
 
721
 
class SmartTCPServer_for_testing_v2_only(SmartTCPServer_for_testing):
722
 
    """A variation of SmartTCPServer_for_testing that limits the client to
723
 
    using RPCs in protocol v2 (i.e. bzr <= 1.5).
724
 
    """
725
 
 
726
 
    def get_url(self):
727
 
        url = super(SmartTCPServer_for_testing_v2_only, self).get_url()
728
 
        url = 'bzr-v2://' + url[len('bzr://'):]
729
 
        return url
730
 
 
731
 
 
732
 
class ReadonlySmartTCPServer_for_testing_v2_only(
733
 
    SmartTCPServer_for_testing_v2_only):
734
 
    """Get a readonly server for testing."""
735
 
 
736
 
    def get_backing_transport(self, backing_transport_server):
737
 
        """Get a backing transport from a server we are decorating."""
738
 
        url = 'readonly+' + backing_transport_server.get_url()
739
 
        return transport.get_transport_from_url(url)