~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/tests/test_server.py

  • Committer: Vincent Ladeuil
  • Date: 2010-09-24 09:56:50 UTC
  • mto: This revision was merged to the branch mainline in revision 5446.
  • Revision ID: v.ladeuil+lp@free.fr-20100924095650-okd49n2o18q9zkmb
Clarify SRU bug nomination.

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# Copyright (C) 2010 Canonical Ltd
 
2
#
 
3
# This program is free software; you can redistribute it and/or modify
 
4
# it under the terms of the GNU General Public License as published by
 
5
# the Free Software Foundation; either version 2 of the License, or
 
6
# (at your option) any later version.
 
7
#
 
8
# This program is distributed in the hope that it will be useful,
 
9
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
10
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
11
# GNU General Public License for more details.
 
12
#
 
13
# You should have received a copy of the GNU General Public License
 
14
# along with this program; if not, write to the Free Software
 
15
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 
16
 
 
17
import errno
 
18
import socket
 
19
import SocketServer
 
20
import select
 
21
import sys
 
22
import threading
 
23
 
 
24
 
 
25
from bzrlib import (
 
26
    osutils,
 
27
    transport,
 
28
    urlutils,
 
29
    )
 
30
from bzrlib.transport import (
 
31
    chroot,
 
32
    pathfilter,
 
33
    )
 
34
from bzrlib.smart import (
 
35
    medium,
 
36
    server,
 
37
    )
 
38
 
 
39
 
 
40
def debug_threads():
 
41
    # FIXME: There is a dependency loop between bzrlib.tests and
 
42
    # bzrlib.tests.test_server that needs to be fixed. In the mean time
 
43
    # defining this function is enough for our needs. -- vila 20100611
 
44
    from bzrlib import tests
 
45
    return 'threads' in tests.selftest_debug_flags
 
46
 
 
47
 
 
48
class TestServer(transport.Server):
 
49
    """A Transport Server dedicated to tests.
 
50
 
 
51
    The TestServer interface provides a server for a given transport. We use
 
52
    these servers as loopback testing tools. For any given transport the
 
53
    Servers it provides must either allow writing, or serve the contents
 
54
    of os.getcwdu() at the time start_server is called.
 
55
 
 
56
    Note that these are real servers - they must implement all the things
 
57
    that we want bzr transports to take advantage of.
 
58
    """
 
59
 
 
60
    def get_url(self):
 
61
        """Return a url for this server.
 
62
 
 
63
        If the transport does not represent a disk directory (i.e. it is
 
64
        a database like svn, or a memory only transport, it should return
 
65
        a connection to a newly established resource for this Server.
 
66
        Otherwise it should return a url that will provide access to the path
 
67
        that was os.getcwdu() when start_server() was called.
 
68
 
 
69
        Subsequent calls will return the same resource.
 
70
        """
 
71
        raise NotImplementedError
 
72
 
 
73
    def get_bogus_url(self):
 
74
        """Return a url for this protocol, that will fail to connect.
 
75
 
 
76
        This may raise NotImplementedError to indicate that this server cannot
 
77
        provide bogus urls.
 
78
        """
 
79
        raise NotImplementedError
 
80
 
 
81
 
 
82
class LocalURLServer(TestServer):
 
83
    """A pretend server for local transports, using file:// urls.
 
84
 
 
85
    Of course no actual server is required to access the local filesystem, so
 
86
    this just exists to tell the test code how to get to it.
 
87
    """
 
88
 
 
89
    def start_server(self):
 
90
        pass
 
91
 
 
92
    def get_url(self):
 
93
        """See Transport.Server.get_url."""
 
94
        return urlutils.local_path_to_url('')
 
95
 
 
96
 
 
97
class DecoratorServer(TestServer):
 
98
    """Server for the TransportDecorator for testing with.
 
99
 
 
100
    To use this when subclassing TransportDecorator, override override the
 
101
    get_decorator_class method.
 
102
    """
 
103
 
 
104
    def start_server(self, server=None):
 
105
        """See bzrlib.transport.Server.start_server.
 
106
 
 
107
        :server: decorate the urls given by server. If not provided a
 
108
        LocalServer is created.
 
109
        """
 
110
        if server is not None:
 
111
            self._made_server = False
 
112
            self._server = server
 
113
        else:
 
114
            self._made_server = True
 
115
            self._server = LocalURLServer()
 
116
            self._server.start_server()
 
117
 
 
118
    def stop_server(self):
 
119
        if self._made_server:
 
120
            self._server.stop_server()
 
121
 
 
122
    def get_decorator_class(self):
 
123
        """Return the class of the decorators we should be constructing."""
 
124
        raise NotImplementedError(self.get_decorator_class)
 
125
 
 
126
    def get_url_prefix(self):
 
127
        """What URL prefix does this decorator produce?"""
 
128
        return self.get_decorator_class()._get_url_prefix()
 
129
 
 
130
    def get_bogus_url(self):
 
131
        """See bzrlib.transport.Server.get_bogus_url."""
 
132
        return self.get_url_prefix() + self._server.get_bogus_url()
 
133
 
 
134
    def get_url(self):
 
135
        """See bzrlib.transport.Server.get_url."""
 
136
        return self.get_url_prefix() + self._server.get_url()
 
137
 
 
138
 
 
139
class BrokenRenameServer(DecoratorServer):
 
140
    """Server for the BrokenRenameTransportDecorator for testing with."""
 
141
 
 
142
    def get_decorator_class(self):
 
143
        from bzrlib.transport import brokenrename
 
144
        return brokenrename.BrokenRenameTransportDecorator
 
145
 
 
146
 
 
147
class FakeNFSServer(DecoratorServer):
 
148
    """Server for the FakeNFSTransportDecorator for testing with."""
 
149
 
 
150
    def get_decorator_class(self):
 
151
        from bzrlib.transport import fakenfs
 
152
        return fakenfs.FakeNFSTransportDecorator
 
153
 
 
154
 
 
155
class FakeVFATServer(DecoratorServer):
 
156
    """A server that suggests connections through FakeVFATTransportDecorator
 
157
 
 
158
    For use in testing.
 
159
    """
 
160
 
 
161
    def get_decorator_class(self):
 
162
        from bzrlib.transport import fakevfat
 
163
        return fakevfat.FakeVFATTransportDecorator
 
164
 
 
165
 
 
166
class LogDecoratorServer(DecoratorServer):
 
167
    """Server for testing."""
 
168
 
 
169
    def get_decorator_class(self):
 
170
        from bzrlib.transport import log
 
171
        return log.TransportLogDecorator
 
172
 
 
173
 
 
174
class NoSmartTransportServer(DecoratorServer):
 
175
    """Server for the NoSmartTransportDecorator for testing with."""
 
176
 
 
177
    def get_decorator_class(self):
 
178
        from bzrlib.transport import nosmart
 
179
        return nosmart.NoSmartTransportDecorator
 
180
 
 
181
 
 
182
class ReadonlyServer(DecoratorServer):
 
183
    """Server for the ReadonlyTransportDecorator for testing with."""
 
184
 
 
185
    def get_decorator_class(self):
 
186
        from bzrlib.transport import readonly
 
187
        return readonly.ReadonlyTransportDecorator
 
188
 
 
189
 
 
190
class TraceServer(DecoratorServer):
 
191
    """Server for the TransportTraceDecorator for testing with."""
 
192
 
 
193
    def get_decorator_class(self):
 
194
        from bzrlib.transport import trace
 
195
        return trace.TransportTraceDecorator
 
196
 
 
197
 
 
198
class UnlistableServer(DecoratorServer):
 
199
    """Server for the UnlistableTransportDecorator for testing with."""
 
200
 
 
201
    def get_decorator_class(self):
 
202
        from bzrlib.transport import unlistable
 
203
        return unlistable.UnlistableTransportDecorator
 
204
 
 
205
 
 
206
class TestingPathFilteringServer(pathfilter.PathFilteringServer):
 
207
 
 
208
    def __init__(self):
 
209
        """TestingPathFilteringServer is not usable until start_server
 
210
        is called."""
 
211
 
 
212
    def start_server(self, backing_server=None):
 
213
        """Setup the Chroot on backing_server."""
 
214
        if backing_server is not None:
 
215
            self.backing_transport = transport.get_transport(
 
216
                backing_server.get_url())
 
217
        else:
 
218
            self.backing_transport = transport.get_transport('.')
 
219
        self.backing_transport.clone('added-by-filter').ensure_base()
 
220
        self.filter_func = lambda x: 'added-by-filter/' + x
 
221
        super(TestingPathFilteringServer, self).start_server()
 
222
 
 
223
    def get_bogus_url(self):
 
224
        raise NotImplementedError
 
225
 
 
226
 
 
227
class TestingChrootServer(chroot.ChrootServer):
 
228
 
 
229
    def __init__(self):
 
230
        """TestingChrootServer is not usable until start_server is called."""
 
231
        super(TestingChrootServer, self).__init__(None)
 
232
 
 
233
    def start_server(self, backing_server=None):
 
234
        """Setup the Chroot on backing_server."""
 
235
        if backing_server is not None:
 
236
            self.backing_transport = transport.get_transport(
 
237
                backing_server.get_url())
 
238
        else:
 
239
            self.backing_transport = transport.get_transport('.')
 
240
        super(TestingChrootServer, self).start_server()
 
241
 
 
242
    def get_bogus_url(self):
 
243
        raise NotImplementedError
 
244
 
 
245
 
 
246
class ThreadWithException(threading.Thread):
 
247
    """A catching exception thread.
 
248
 
 
249
    If an exception occurs during the thread execution, it's caught and
 
250
    re-raised when the thread is joined().
 
251
    """
 
252
 
 
253
    def __init__(self, *args, **kwargs):
 
254
        # There are cases where the calling thread must wait, yet, if an
 
255
        # exception occurs, the event should be set so the caller is not
 
256
        # blocked. The main example is a calling thread that want to wait for
 
257
        # the called thread to be in a given state before continuing.
 
258
        try:
 
259
            event = kwargs.pop('event')
 
260
        except KeyError:
 
261
            # If the caller didn't pass a specific event, create our own
 
262
            event = threading.Event()
 
263
        super(ThreadWithException, self).__init__(*args, **kwargs)
 
264
        self.set_ready_event(event)
 
265
        self.exception = None
 
266
        self.ignored_exceptions = None # see set_ignored_exceptions
 
267
 
 
268
    # compatibility thunk for python-2.4 and python-2.5...
 
269
    if sys.version_info < (2, 6):
 
270
        name = property(threading.Thread.getName, threading.Thread.setName)
 
271
 
 
272
    def set_ready_event(self, event):
 
273
        """Set the ``ready`` event used to synchronize exception catching.
 
274
 
 
275
        When the thread uses an event to synchronize itself with another thread
 
276
        (setting it when the other thread can wake up from a ``wait`` call),
 
277
        the event must be set after catching an exception or the other thread
 
278
        will hang.
 
279
 
 
280
        Some threads require multiple events and should set the relevant one
 
281
        when appropriate.
 
282
        """
 
283
        self.ready = event
 
284
 
 
285
    def set_ignored_exceptions(self, ignored):
 
286
        """Declare which exceptions will be ignored.
 
287
 
 
288
        :param ignored: Can be either:
 
289
           - None: all exceptions will be raised,
 
290
           - an exception class: the instances of this class will be ignored,
 
291
           - a tuple of exception classes: the instances of any class of the
 
292
             list will be ignored,
 
293
           - a callable: that will be passed the exception object
 
294
             and should return True if the exception should be ignored
 
295
        """
 
296
        if ignored is None:
 
297
            self.ignored_exceptions = None
 
298
        elif isinstance(ignored, (Exception, tuple)):
 
299
            self.ignored_exceptions = lambda e: isinstance(e, ignored)
 
300
        else:
 
301
            self.ignored_exceptions = ignored
 
302
 
 
303
    def run(self):
 
304
        """Overrides Thread.run to capture any exception."""
 
305
        self.ready.clear()
 
306
        try:
 
307
            try:
 
308
                super(ThreadWithException, self).run()
 
309
            except:
 
310
                self.exception = sys.exc_info()
 
311
        finally:
 
312
            # Make sure the calling thread is released
 
313
            self.ready.set()
 
314
 
 
315
 
 
316
    def join(self, timeout=5):
 
317
        """Overrides Thread.join to raise any exception caught.
 
318
 
 
319
 
 
320
        Calling join(timeout=0) will raise the caught exception or return None
 
321
        if the thread is still alive.
 
322
 
 
323
        The default timeout is set to 5 and should expire only when a thread
 
324
        serving a client connection is hung.
 
325
        """
 
326
        super(ThreadWithException, self).join(timeout)
 
327
        if self.exception is not None:
 
328
            exc_class, exc_value, exc_tb = self.exception
 
329
            self.exception = None # The exception should be raised only once
 
330
            if (self.ignored_exceptions is None
 
331
                or not self.ignored_exceptions(exc_value)):
 
332
                # Raise non ignored exceptions
 
333
                raise exc_class, exc_value, exc_tb
 
334
        if timeout and self.isAlive():
 
335
            # The timeout expired without joining the thread, the thread is
 
336
            # therefore stucked and that's a failure as far as the test is
 
337
            # concerned. We used to hang here.
 
338
 
 
339
            # FIXME: we need to kill the thread, but as far as the test is
 
340
            # concerned, raising an assertion is too strong. On most of the
 
341
            # platforms, this doesn't occur, so just mentioning the problem is
 
342
            # enough for now -- vila 2010824
 
343
            sys.stderr.write('thread %s hung\n' % (self.name,))
 
344
            #raise AssertionError('thread %s hung' % (self.name,))
 
345
 
 
346
    def pending_exception(self):
 
347
        """Raise the caught exception.
 
348
 
 
349
        This does nothing if no exception occurred.
 
350
        """
 
351
        self.join(timeout=0)
 
352
 
 
353
 
 
354
class TestingTCPServerMixin:
 
355
    """Mixin to support running SocketServer.TCPServer in a thread.
 
356
 
 
357
    Tests are connecting from the main thread, the server has to be run in a
 
358
    separate thread.
 
359
    """
 
360
 
 
361
    def __init__(self):
 
362
        self.started = threading.Event()
 
363
        self.serving = None
 
364
        self.stopped = threading.Event()
 
365
        # We collect the resources used by the clients so we can release them
 
366
        # when shutting down
 
367
        self.clients = []
 
368
        self.ignored_exceptions = None
 
369
 
 
370
    def server_bind(self):
 
371
        self.socket.bind(self.server_address)
 
372
        self.server_address = self.socket.getsockname()
 
373
 
 
374
    def serve(self):
 
375
        self.serving = True
 
376
        self.stopped.clear()
 
377
        # We are listening and ready to accept connections
 
378
        self.started.set()
 
379
        try:
 
380
            while self.serving:
 
381
                # Really a connection but the python framework is generic and
 
382
                # call them requests
 
383
                self.handle_request()
 
384
            # Let's close the listening socket
 
385
            self.server_close()
 
386
        finally:
 
387
            self.stopped.set()
 
388
 
 
389
    def handle_request(self):
 
390
        """Handle one request.
 
391
 
 
392
        The python version swallows some socket exceptions and we don't use
 
393
        timeout, so we override it to better control the server behavior.
 
394
        """
 
395
        request, client_address = self.get_request()
 
396
        if self.verify_request(request, client_address):
 
397
            try:
 
398
                self.process_request(request, client_address)
 
399
            except:
 
400
                self.handle_error(request, client_address)
 
401
                self.close_request(request)
 
402
 
 
403
    def get_request(self):
 
404
        return self.socket.accept()
 
405
 
 
406
    def verify_request(self, request, client_address):
 
407
        """Verify the request.
 
408
 
 
409
        Return True if we should proceed with this request, False if we should
 
410
        not even touch a single byte in the socket ! This is useful when we
 
411
        stop the server with a dummy last connection.
 
412
        """
 
413
        return self.serving
 
414
 
 
415
    def handle_error(self, request, client_address):
 
416
        # Stop serving and re-raise the last exception seen
 
417
        self.serving = False
 
418
        # The following can be used for debugging purposes, it will display the
 
419
        # exception and the traceback just when it occurs instead of waiting
 
420
        # for the thread to be joined.
 
421
 
 
422
        # SocketServer.BaseServer.handle_error(self, request, client_address)
 
423
        raise
 
424
 
 
425
    def ignored_exceptions_during_shutdown(self, e):
 
426
        if sys.platform == 'win32':
 
427
            accepted_errnos = [errno.EBADF,
 
428
                               errno.EPIPE,
 
429
                               errno.WSAEBADF,
 
430
                               errno.WSAECONNRESET,
 
431
                               errno.WSAENOTCONN,
 
432
                               errno.WSAESHUTDOWN,
 
433
                               ]
 
434
        else:
 
435
            accepted_errnos = [errno.EBADF,
 
436
                               errno.ECONNRESET,
 
437
                               errno.ENOTCONN,
 
438
                               errno.EPIPE,
 
439
                               ]
 
440
        if isinstance(e, socket.error) and e[0] in accepted_errnos:
 
441
            return True
 
442
        return False
 
443
 
 
444
    # The following methods are called by the main thread
 
445
 
 
446
    def stop_client_connections(self):
 
447
        while self.clients:
 
448
            c = self.clients.pop()
 
449
            self.shutdown_client(c)
 
450
 
 
451
    def shutdown_socket(self, sock):
 
452
        """Properly shutdown a socket.
 
453
 
 
454
        This should be called only when no other thread is trying to use the
 
455
        socket.
 
456
        """
 
457
        try:
 
458
            sock.shutdown(socket.SHUT_RDWR)
 
459
            sock.close()
 
460
        except Exception, e:
 
461
            if self.ignored_exceptions(e):
 
462
                pass
 
463
            else:
 
464
                raise
 
465
 
 
466
    # The following methods are called by the main thread
 
467
 
 
468
    def set_ignored_exceptions(self, thread, ignored_exceptions):
 
469
        self.ignored_exceptions = ignored_exceptions
 
470
        thread.set_ignored_exceptions(self.ignored_exceptions)
 
471
 
 
472
    def _pending_exception(self, thread):
 
473
        """Raise server uncaught exception.
 
474
 
 
475
        Daughter classes can override this if they use daughter threads.
 
476
        """
 
477
        thread.pending_exception()
 
478
 
 
479
 
 
480
class TestingTCPServer(TestingTCPServerMixin, SocketServer.TCPServer):
 
481
 
 
482
    def __init__(self, server_address, request_handler_class):
 
483
        TestingTCPServerMixin.__init__(self)
 
484
        SocketServer.TCPServer.__init__(self, server_address,
 
485
                                        request_handler_class)
 
486
 
 
487
    def get_request(self):
 
488
        """Get the request and client address from the socket."""
 
489
        sock, addr = TestingTCPServerMixin.get_request(self)
 
490
        self.clients.append((sock, addr))
 
491
        return sock, addr
 
492
 
 
493
    # The following methods are called by the main thread
 
494
 
 
495
    def shutdown_client(self, client):
 
496
        sock, addr = client
 
497
        self.shutdown_socket(sock)
 
498
 
 
499
 
 
500
class TestingThreadingTCPServer(TestingTCPServerMixin,
 
501
                                SocketServer.ThreadingTCPServer):
 
502
 
 
503
    def __init__(self, server_address, request_handler_class):
 
504
        TestingTCPServerMixin.__init__(self)
 
505
        SocketServer.ThreadingTCPServer.__init__(self, server_address,
 
506
                                                 request_handler_class)
 
507
 
 
508
    def get_request (self):
 
509
        """Get the request and client address from the socket."""
 
510
        sock, addr = TestingTCPServerMixin.get_request(self)
 
511
        # The thread is not create yet, it will be updated in process_request
 
512
        self.clients.append((sock, addr, None))
 
513
        return sock, addr
 
514
 
 
515
    def process_request_thread(self, started, stopped, request, client_address):
 
516
        started.set()
 
517
        SocketServer.ThreadingTCPServer.process_request_thread(
 
518
            self, request, client_address)
 
519
        self.close_request(request)
 
520
        stopped.set()
 
521
 
 
522
    def process_request(self, request, client_address):
 
523
        """Start a new thread to process the request."""
 
524
        started = threading.Event()
 
525
        stopped = threading.Event()
 
526
        t = ThreadWithException(
 
527
            event=stopped,
 
528
            name='%s -> %s' % (client_address, self.server_address),
 
529
            target = self.process_request_thread,
 
530
            args = (started, stopped, request, client_address))
 
531
        # Update the client description
 
532
        self.clients.pop()
 
533
        self.clients.append((request, client_address, t))
 
534
        # Propagate the exception handler since we must use the same one for
 
535
        # connections running in their own threads than TestingTCPServer.
 
536
        t.set_ignored_exceptions(self.ignored_exceptions)
 
537
        t.start()
 
538
        started.wait()
 
539
        if debug_threads():
 
540
            sys.stderr.write('Client thread %s started\n' % (t.name,))
 
541
        # If an exception occured during the thread start, it will get raised.
 
542
        t.pending_exception()
 
543
 
 
544
    # The following methods are called by the main thread
 
545
 
 
546
    def shutdown_client(self, client):
 
547
        sock, addr, connection_thread = client
 
548
        self.shutdown_socket(sock)
 
549
        if connection_thread is not None:
 
550
            # The thread has been created only if the request is processed but
 
551
            # after the connection is inited. This could happen during server
 
552
            # shutdown. If an exception occurred in the thread it will be
 
553
            # re-raised
 
554
            if debug_threads():
 
555
                sys.stderr.write('Client thread %s will be joined\n'
 
556
                                 % (connection_thread.name,))
 
557
            connection_thread.join()
 
558
 
 
559
    def set_ignored_exceptions(self, thread, ignored_exceptions):
 
560
        TestingTCPServerMixin.set_ignored_exceptions(self, thread,
 
561
                                                     ignored_exceptions)
 
562
        for sock, addr, connection_thread in self.clients:
 
563
            if connection_thread is not None:
 
564
                connection_thread.set_ignored_exceptions(
 
565
                    self.ignored_exceptions)
 
566
 
 
567
    def _pending_exception(self, thread):
 
568
        for sock, addr, connection_thread in self.clients:
 
569
            if connection_thread is not None:
 
570
                connection_thread.pending_exception()
 
571
        TestingTCPServerMixin._pending_exception(self, thread)
 
572
 
 
573
 
 
574
class TestingTCPServerInAThread(transport.Server):
 
575
    """A server in a thread that re-raise thread exceptions."""
 
576
 
 
577
    def __init__(self, server_address, server_class, request_handler_class):
 
578
        self.server_class = server_class
 
579
        self.request_handler_class = request_handler_class
 
580
        self.host, self.port = server_address
 
581
        self.server = None
 
582
        self._server_thread = None
 
583
 
 
584
    def __repr__(self):
 
585
        return "%s(%s:%s)" % (self.__class__.__name__, self.host, self.port)
 
586
 
 
587
    def create_server(self):
 
588
        return self.server_class((self.host, self.port),
 
589
                                 self.request_handler_class)
 
590
 
 
591
    def start_server(self):
 
592
        self.server = self.create_server()
 
593
        self._server_thread = ThreadWithException(
 
594
            event=self.server.started,
 
595
            target=self.run_server)
 
596
        self._server_thread.start()
 
597
        # Wait for the server thread to start (i.e release the lock)
 
598
        self.server.started.wait()
 
599
        # Get the real address, especially the port
 
600
        self.host, self.port = self.server.server_address
 
601
        self._server_thread.name = self.server.server_address
 
602
        if debug_threads():
 
603
            sys.stderr.write('Server thread %s started\n'
 
604
                             % (self._server_thread.name,))
 
605
        # If an exception occured during the server start, it will get raised,
 
606
        # otherwise, the server is blocked on its accept() call.
 
607
        self._server_thread.pending_exception()
 
608
        # From now on, we'll use a different event to ensure the server can set
 
609
        # its exception
 
610
        self._server_thread.set_ready_event(self.server.stopped)
 
611
 
 
612
    def run_server(self):
 
613
        self.server.serve()
 
614
 
 
615
    def stop_server(self):
 
616
        if self.server is None:
 
617
            return
 
618
        try:
 
619
            # The server has been started successfully, shut it down now.  As
 
620
            # soon as we stop serving, no more connection are accepted except
 
621
            # one to get out of the blocking listen.
 
622
            self.set_ignored_exceptions(
 
623
                self.server.ignored_exceptions_during_shutdown)
 
624
            self.server.serving = False
 
625
            if debug_threads():
 
626
                sys.stderr.write('Server thread %s will be joined\n'
 
627
                                 % (self._server_thread.name,))
 
628
            # The server is listening for a last connection, let's give it:
 
629
            last_conn = None
 
630
            try:
 
631
                last_conn = osutils.connect_socket((self.host, self.port))
 
632
            except socket.error, e:
 
633
                # But ignore connection errors as the point is to unblock the
 
634
                # server thread, it may happen that it's not blocked or even
 
635
                # not started.
 
636
                pass
 
637
            # We start shutting down the client while the server itself is
 
638
            # shutting down.
 
639
            self.server.stop_client_connections()
 
640
            # Now we wait for the thread running self.server.serve() to finish
 
641
            self.server.stopped.wait()
 
642
            if last_conn is not None:
 
643
                # Close the last connection without trying to use it. The
 
644
                # server will not process a single byte on that socket to avoid
 
645
                # complications (SSL starts with a handshake for example).
 
646
                last_conn.close()
 
647
            # Check for any exception that could have occurred in the server
 
648
            # thread
 
649
            try:
 
650
                self._server_thread.join()
 
651
            except Exception, e:
 
652
                if self.server.ignored_exceptions(e):
 
653
                    pass
 
654
                else:
 
655
                    raise
 
656
        finally:
 
657
            # Make sure we can be called twice safely, note that this means
 
658
            # that we will raise a single exception even if several occurred in
 
659
            # the various threads involved.
 
660
            self.server = None
 
661
 
 
662
    def set_ignored_exceptions(self, ignored_exceptions):
 
663
        """Install an exception handler for the server."""
 
664
        self.server.set_ignored_exceptions(self._server_thread,
 
665
                                           ignored_exceptions)
 
666
 
 
667
    def pending_exception(self):
 
668
        """Raise uncaught exception in the server."""
 
669
        self.server._pending_exception(self._server_thread)
 
670
 
 
671
 
 
672
class TestingSmartConnectionHandler(SocketServer.BaseRequestHandler,
 
673
                                    medium.SmartServerSocketStreamMedium):
 
674
 
 
675
    def __init__(self, request, client_address, server):
 
676
        medium.SmartServerSocketStreamMedium.__init__(
 
677
            self, request, server.backing_transport,
 
678
            server.root_client_path)
 
679
        request.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
 
680
        SocketServer.BaseRequestHandler.__init__(self, request, client_address,
 
681
                                                 server)
 
682
 
 
683
    def handle(self):
 
684
        while not self.finished:
 
685
            server_protocol = self._build_protocol()
 
686
            self._serve_one_request(server_protocol)
 
687
 
 
688
 
 
689
class TestingSmartServer(TestingThreadingTCPServer, server.SmartTCPServer):
 
690
 
 
691
    def __init__(self, server_address, request_handler_class,
 
692
                 backing_transport, root_client_path):
 
693
        TestingThreadingTCPServer.__init__(self, server_address,
 
694
                                           request_handler_class)
 
695
        server.SmartTCPServer.__init__(self, backing_transport,
 
696
                                       root_client_path)
 
697
    def serve(self):
 
698
        # FIXME: No test are exercising the hooks for the test server
 
699
        # -- vila 20100618
 
700
        self.run_server_started_hooks()
 
701
        try:
 
702
            TestingThreadingTCPServer.serve(self)
 
703
        finally:
 
704
            self.run_server_stopped_hooks()
 
705
 
 
706
    def get_url(self):
 
707
        """Return the url of the server"""
 
708
        return "bzr://%s:%d/" % self.server_address
 
709
 
 
710
 
 
711
class SmartTCPServer_for_testing(TestingTCPServerInAThread):
 
712
    """Server suitable for use by transport tests.
 
713
 
 
714
    This server is backed by the process's cwd.
 
715
    """
 
716
    def __init__(self, thread_name_suffix=''):
 
717
        self.client_path_extra = None
 
718
        self.thread_name_suffix = thread_name_suffix
 
719
        self.host = '127.0.0.1'
 
720
        self.port = 0
 
721
        super(SmartTCPServer_for_testing, self).__init__(
 
722
                (self.host, self.port),
 
723
                TestingSmartServer,
 
724
                TestingSmartConnectionHandler)
 
725
 
 
726
    def create_server(self):
 
727
        return self.server_class((self.host, self.port),
 
728
                                 self.request_handler_class,
 
729
                                 self.backing_transport,
 
730
                                 self.root_client_path)
 
731
 
 
732
 
 
733
    def start_server(self, backing_transport_server=None,
 
734
                     client_path_extra='/extra/'):
 
735
        """Set up server for testing.
 
736
 
 
737
        :param backing_transport_server: backing server to use.  If not
 
738
            specified, a LocalURLServer at the current working directory will
 
739
            be used.
 
740
        :param client_path_extra: a path segment starting with '/' to append to
 
741
            the root URL for this server.  For instance, a value of '/foo/bar/'
 
742
            will mean the root of the backing transport will be published at a
 
743
            URL like `bzr://127.0.0.1:nnnn/foo/bar/`, rather than
 
744
            `bzr://127.0.0.1:nnnn/`.  Default value is `extra`, so that tests
 
745
            by default will fail unless they do the necessary path translation.
 
746
        """
 
747
        if not client_path_extra.startswith('/'):
 
748
            raise ValueError(client_path_extra)
 
749
        self.root_client_path = self.client_path_extra = client_path_extra
 
750
        from bzrlib.transport.chroot import ChrootServer
 
751
        if backing_transport_server is None:
 
752
            backing_transport_server = LocalURLServer()
 
753
        self.chroot_server = ChrootServer(
 
754
            self.get_backing_transport(backing_transport_server))
 
755
        self.chroot_server.start_server()
 
756
        self.backing_transport = transport.get_transport(
 
757
            self.chroot_server.get_url())
 
758
        super(SmartTCPServer_for_testing, self).start_server()
 
759
 
 
760
    def stop_server(self):
 
761
        try:
 
762
            super(SmartTCPServer_for_testing, self).stop_server()
 
763
        finally:
 
764
            self.chroot_server.stop_server()
 
765
 
 
766
    def get_backing_transport(self, backing_transport_server):
 
767
        """Get a backing transport from a server we are decorating."""
 
768
        return transport.get_transport(backing_transport_server.get_url())
 
769
 
 
770
    def get_url(self):
 
771
        url = self.server.get_url()
 
772
        return url[:-1] + self.client_path_extra
 
773
 
 
774
    def get_bogus_url(self):
 
775
        """Return a URL which will fail to connect"""
 
776
        return 'bzr://127.0.0.1:1/'
 
777
 
 
778
 
 
779
class ReadonlySmartTCPServer_for_testing(SmartTCPServer_for_testing):
 
780
    """Get a readonly server for testing."""
 
781
 
 
782
    def get_backing_transport(self, backing_transport_server):
 
783
        """Get a backing transport from a server we are decorating."""
 
784
        url = 'readonly+' + backing_transport_server.get_url()
 
785
        return transport.get_transport(url)
 
786
 
 
787
 
 
788
class SmartTCPServer_for_testing_v2_only(SmartTCPServer_for_testing):
 
789
    """A variation of SmartTCPServer_for_testing that limits the client to
 
790
    using RPCs in protocol v2 (i.e. bzr <= 1.5).
 
791
    """
 
792
 
 
793
    def get_url(self):
 
794
        url = super(SmartTCPServer_for_testing_v2_only, self).get_url()
 
795
        url = 'bzr-v2://' + url[len('bzr://'):]
 
796
        return url
 
797
 
 
798
 
 
799
class ReadonlySmartTCPServer_for_testing_v2_only(
 
800
    SmartTCPServer_for_testing_v2_only):
 
801
    """Get a readonly server for testing."""
 
802
 
 
803
    def get_backing_transport(self, backing_transport_server):
 
804
        """Get a backing transport from a server we are decorating."""
 
805
        url = 'readonly+' + backing_transport_server.get_url()
 
806
        return transport.get_transport(url)
 
807
 
 
808
 
 
809
 
 
810