~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/tests/test_server.py

  • Committer: Martin Pool
  • Date: 2010-04-01 04:41:18 UTC
  • mto: This revision was merged to the branch mainline in revision 5128.
  • Revision ID: mbp@sourcefrog.net-20100401044118-shyctqc02ob08ngz
ignore .testrepository

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2010 Canonical Ltd
 
1
# Copyright (C) 2005, 2006, 2007, 2008, 2010 Canonical Ltd
2
2
#
3
3
# This program is free software; you can redistribute it and/or modify
4
4
# it under the terms of the GNU General Public License as published by
14
14
# along with this program; if not, write to the Free Software
15
15
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
16
16
 
17
 
import errno
18
 
import socket
19
 
import SocketServer
20
 
import select
21
 
import sys
22
 
import threading
23
 
 
24
 
 
25
17
from bzrlib import (
26
 
    osutils,
27
18
    transport,
28
19
    urlutils,
29
20
    )
31
22
    chroot,
32
23
    pathfilter,
33
24
    )
34
 
from bzrlib.smart import (
35
 
    medium,
36
 
    server,
37
 
    )
38
 
 
39
 
 
40
 
def debug_threads():
41
 
    # FIXME: There is a dependency loop between bzrlib.tests and
42
 
    # bzrlib.tests.test_server that needs to be fixed. In the mean time
43
 
    # defining this function is enough for our needs. -- vila 20100611
44
 
    from bzrlib import tests
45
 
    return 'threads' in tests.selftest_debug_flags
 
25
from bzrlib.smart import server
46
26
 
47
27
 
48
28
class TestServer(transport.Server):
243
223
        raise NotImplementedError
244
224
 
245
225
 
246
 
class ThreadWithException(threading.Thread):
247
 
    """A catching exception thread.
248
 
 
249
 
    If an exception occurs during the thread execution, it's caught and
250
 
    re-raised when the thread is joined().
251
 
    """
252
 
 
253
 
    def __init__(self, *args, **kwargs):
254
 
        # There are cases where the calling thread must wait, yet, if an
255
 
        # exception occurs, the event should be set so the caller is not
256
 
        # blocked. The main example is a calling thread that want to wait for
257
 
        # the called thread to be in a given state before continuing.
258
 
        try:
259
 
            event = kwargs.pop('event')
260
 
        except KeyError:
261
 
            # If the caller didn't pass a specific event, create our own
262
 
            event = threading.Event()
263
 
        super(ThreadWithException, self).__init__(*args, **kwargs)
264
 
        self.set_ready_event(event)
265
 
        self.exception = None
266
 
        self.ignored_exceptions = None # see set_ignored_exceptions
267
 
 
268
 
    # compatibility thunk for python-2.4 and python-2.5...
269
 
    if sys.version_info < (2, 6):
270
 
        name = property(threading.Thread.getName, threading.Thread.setName)
271
 
 
272
 
    def set_ready_event(self, event):
273
 
        """Set the ``ready`` event used to synchronize exception catching.
274
 
 
275
 
        When the thread uses an event to synchronize itself with another thread
276
 
        (setting it when the other thread can wake up from a ``wait`` call),
277
 
        the event must be set after catching an exception or the other thread
278
 
        will hang.
279
 
 
280
 
        Some threads require multiple events and should set the relevant one
281
 
        when appropriate.
282
 
        """
283
 
        self.ready = event
284
 
 
285
 
    def set_ignored_exceptions(self, ignored):
286
 
        """Declare which exceptions will be ignored.
287
 
 
288
 
        :param ignored: Can be either:
289
 
           - None: all exceptions will be raised,
290
 
           - an exception class: the instances of this class will be ignored,
291
 
           - a tuple of exception classes: the instances of any class of the
292
 
             list will be ignored,
293
 
           - a callable: that will be passed the exception object
294
 
             and should return True if the exception should be ignored
295
 
        """
296
 
        if ignored is None:
297
 
            self.ignored_exceptions = None
298
 
        elif isinstance(ignored, (Exception, tuple)):
299
 
            self.ignored_exceptions = lambda e: isinstance(e, ignored)
300
 
        else:
301
 
            self.ignored_exceptions = ignored
302
 
 
303
 
    def run(self):
304
 
        """Overrides Thread.run to capture any exception."""
305
 
        self.ready.clear()
306
 
        try:
307
 
            try:
308
 
                super(ThreadWithException, self).run()
309
 
            except:
310
 
                self.exception = sys.exc_info()
311
 
        finally:
312
 
            # Make sure the calling thread is released
313
 
            self.ready.set()
314
 
 
315
 
 
316
 
    def join(self, timeout=5):
317
 
        """Overrides Thread.join to raise any exception caught.
318
 
 
319
 
 
320
 
        Calling join(timeout=0) will raise the caught exception or return None
321
 
        if the thread is still alive.
322
 
 
323
 
        The default timeout is set to 5 and should expire only when a thread
324
 
        serving a client connection is hung.
325
 
        """
326
 
        super(ThreadWithException, self).join(timeout)
327
 
        if self.exception is not None:
328
 
            exc_class, exc_value, exc_tb = self.exception
329
 
            self.exception = None # The exception should be raised only once
330
 
            if (self.ignored_exceptions is None
331
 
                or not self.ignored_exceptions(exc_value)):
332
 
                # Raise non ignored exceptions
333
 
                raise exc_class, exc_value, exc_tb
334
 
        if timeout and self.isAlive():
335
 
            # The timeout expired without joining the thread, the thread is
336
 
            # therefore stucked and that's a failure as far as the test is
337
 
            # concerned. We used to hang here.
338
 
 
339
 
            # FIXME: we need to kill the thread, but as far as the test is
340
 
            # concerned, raising an assertion is too strong. On most of the
341
 
            # platforms, this doesn't occur, so just mentioning the problem is
342
 
            # enough for now -- vila 2010824
343
 
            sys.stderr.write('thread %s hung\n' % (self.name,))
344
 
            #raise AssertionError('thread %s hung' % (self.name,))
345
 
 
346
 
    def pending_exception(self):
347
 
        """Raise the caught exception.
348
 
 
349
 
        This does nothing if no exception occurred.
350
 
        """
351
 
        self.join(timeout=0)
352
 
 
353
 
 
354
 
class TestingTCPServerMixin:
355
 
    """Mixin to support running SocketServer.TCPServer in a thread.
356
 
 
357
 
    Tests are connecting from the main thread, the server has to be run in a
358
 
    separate thread.
359
 
    """
360
 
 
361
 
    def __init__(self):
362
 
        self.started = threading.Event()
363
 
        self.serving = None
364
 
        self.stopped = threading.Event()
365
 
        # We collect the resources used by the clients so we can release them
366
 
        # when shutting down
367
 
        self.clients = []
368
 
        self.ignored_exceptions = None
369
 
 
370
 
    def server_bind(self):
371
 
        self.socket.bind(self.server_address)
372
 
        self.server_address = self.socket.getsockname()
373
 
 
374
 
    def serve(self):
375
 
        self.serving = True
376
 
        self.stopped.clear()
377
 
        # We are listening and ready to accept connections
378
 
        self.started.set()
379
 
        try:
380
 
            while self.serving:
381
 
                # Really a connection but the python framework is generic and
382
 
                # call them requests
383
 
                self.handle_request()
384
 
            # Let's close the listening socket
385
 
            self.server_close()
386
 
        finally:
387
 
            self.stopped.set()
388
 
 
389
 
    def handle_request(self):
390
 
        """Handle one request.
391
 
 
392
 
        The python version swallows some socket exceptions and we don't use
393
 
        timeout, so we override it to better control the server behavior.
394
 
        """
395
 
        request, client_address = self.get_request()
396
 
        if self.verify_request(request, client_address):
397
 
            try:
398
 
                self.process_request(request, client_address)
399
 
            except:
400
 
                self.handle_error(request, client_address)
401
 
                self.close_request(request)
402
 
 
403
 
    def get_request(self):
404
 
        return self.socket.accept()
405
 
 
406
 
    def verify_request(self, request, client_address):
407
 
        """Verify the request.
408
 
 
409
 
        Return True if we should proceed with this request, False if we should
410
 
        not even touch a single byte in the socket ! This is useful when we
411
 
        stop the server with a dummy last connection.
412
 
        """
413
 
        return self.serving
414
 
 
415
 
    def handle_error(self, request, client_address):
416
 
        # Stop serving and re-raise the last exception seen
417
 
        self.serving = False
418
 
        # The following can be used for debugging purposes, it will display the
419
 
        # exception and the traceback just when it occurs instead of waiting
420
 
        # for the thread to be joined.
421
 
 
422
 
        # SocketServer.BaseServer.handle_error(self, request, client_address)
423
 
        raise
424
 
 
425
 
    def ignored_exceptions_during_shutdown(self, e):
426
 
        if sys.platform == 'win32':
427
 
            accepted_errnos = [errno.EBADF,
428
 
                               errno.EPIPE,
429
 
                               errno.WSAEBADF,
430
 
                               errno.WSAECONNRESET,
431
 
                               errno.WSAENOTCONN,
432
 
                               errno.WSAESHUTDOWN,
433
 
                               ]
434
 
        else:
435
 
            accepted_errnos = [errno.EBADF,
436
 
                               errno.ECONNRESET,
437
 
                               errno.ENOTCONN,
438
 
                               errno.EPIPE,
439
 
                               ]
440
 
        if isinstance(e, socket.error) and e[0] in accepted_errnos:
441
 
            return True
442
 
        return False
443
 
 
444
 
    # The following methods are called by the main thread
445
 
 
446
 
    def stop_client_connections(self):
447
 
        while self.clients:
448
 
            c = self.clients.pop()
449
 
            self.shutdown_client(c)
450
 
 
451
 
    def shutdown_socket(self, sock):
452
 
        """Properly shutdown a socket.
453
 
 
454
 
        This should be called only when no other thread is trying to use the
455
 
        socket.
456
 
        """
457
 
        try:
458
 
            sock.shutdown(socket.SHUT_RDWR)
459
 
            sock.close()
460
 
        except Exception, e:
461
 
            if self.ignored_exceptions(e):
462
 
                pass
463
 
            else:
464
 
                raise
465
 
 
466
 
    # The following methods are called by the main thread
467
 
 
468
 
    def set_ignored_exceptions(self, thread, ignored_exceptions):
469
 
        self.ignored_exceptions = ignored_exceptions
470
 
        thread.set_ignored_exceptions(self.ignored_exceptions)
471
 
 
472
 
    def _pending_exception(self, thread):
473
 
        """Raise server uncaught exception.
474
 
 
475
 
        Daughter classes can override this if they use daughter threads.
476
 
        """
477
 
        thread.pending_exception()
478
 
 
479
 
 
480
 
class TestingTCPServer(TestingTCPServerMixin, SocketServer.TCPServer):
481
 
 
482
 
    def __init__(self, server_address, request_handler_class):
483
 
        TestingTCPServerMixin.__init__(self)
484
 
        SocketServer.TCPServer.__init__(self, server_address,
485
 
                                        request_handler_class)
486
 
 
487
 
    def get_request(self):
488
 
        """Get the request and client address from the socket."""
489
 
        sock, addr = TestingTCPServerMixin.get_request(self)
490
 
        self.clients.append((sock, addr))
491
 
        return sock, addr
492
 
 
493
 
    # The following methods are called by the main thread
494
 
 
495
 
    def shutdown_client(self, client):
496
 
        sock, addr = client
497
 
        self.shutdown_socket(sock)
498
 
 
499
 
 
500
 
class TestingThreadingTCPServer(TestingTCPServerMixin,
501
 
                                SocketServer.ThreadingTCPServer):
502
 
 
503
 
    def __init__(self, server_address, request_handler_class):
504
 
        TestingTCPServerMixin.__init__(self)
505
 
        SocketServer.ThreadingTCPServer.__init__(self, server_address,
506
 
                                                 request_handler_class)
507
 
 
508
 
    def get_request (self):
509
 
        """Get the request and client address from the socket."""
510
 
        sock, addr = TestingTCPServerMixin.get_request(self)
511
 
        # The thread is not create yet, it will be updated in process_request
512
 
        self.clients.append((sock, addr, None))
513
 
        return sock, addr
514
 
 
515
 
    def process_request_thread(self, started, stopped, request, client_address):
516
 
        started.set()
517
 
        SocketServer.ThreadingTCPServer.process_request_thread(
518
 
            self, request, client_address)
519
 
        self.close_request(request)
520
 
        stopped.set()
521
 
 
522
 
    def process_request(self, request, client_address):
523
 
        """Start a new thread to process the request."""
524
 
        started = threading.Event()
525
 
        stopped = threading.Event()
526
 
        t = ThreadWithException(
527
 
            event=stopped,
528
 
            name='%s -> %s' % (client_address, self.server_address),
529
 
            target = self.process_request_thread,
530
 
            args = (started, stopped, request, client_address))
531
 
        # Update the client description
532
 
        self.clients.pop()
533
 
        self.clients.append((request, client_address, t))
534
 
        # Propagate the exception handler since we must use the same one for
535
 
        # connections running in their own threads than TestingTCPServer.
536
 
        t.set_ignored_exceptions(self.ignored_exceptions)
537
 
        t.start()
538
 
        started.wait()
539
 
        if debug_threads():
540
 
            sys.stderr.write('Client thread %s started\n' % (t.name,))
541
 
        # If an exception occured during the thread start, it will get raised.
542
 
        t.pending_exception()
543
 
 
544
 
    # The following methods are called by the main thread
545
 
 
546
 
    def shutdown_client(self, client):
547
 
        sock, addr, connection_thread = client
548
 
        self.shutdown_socket(sock)
549
 
        if connection_thread is not None:
550
 
            # The thread has been created only if the request is processed but
551
 
            # after the connection is inited. This could happen during server
552
 
            # shutdown. If an exception occurred in the thread it will be
553
 
            # re-raised
554
 
            if debug_threads():
555
 
                sys.stderr.write('Client thread %s will be joined\n'
556
 
                                 % (connection_thread.name,))
557
 
            connection_thread.join()
558
 
 
559
 
    def set_ignored_exceptions(self, thread, ignored_exceptions):
560
 
        TestingTCPServerMixin.set_ignored_exceptions(self, thread,
561
 
                                                     ignored_exceptions)
562
 
        for sock, addr, connection_thread in self.clients:
563
 
            if connection_thread is not None:
564
 
                connection_thread.set_ignored_exceptions(
565
 
                    self.ignored_exceptions)
566
 
 
567
 
    def _pending_exception(self, thread):
568
 
        for sock, addr, connection_thread in self.clients:
569
 
            if connection_thread is not None:
570
 
                connection_thread.pending_exception()
571
 
        TestingTCPServerMixin._pending_exception(self, thread)
572
 
 
573
 
 
574
 
class TestingTCPServerInAThread(transport.Server):
575
 
    """A server in a thread that re-raise thread exceptions."""
576
 
 
577
 
    def __init__(self, server_address, server_class, request_handler_class):
578
 
        self.server_class = server_class
579
 
        self.request_handler_class = request_handler_class
580
 
        self.host, self.port = server_address
581
 
        self.server = None
582
 
        self._server_thread = None
583
 
 
584
 
    def __repr__(self):
585
 
        return "%s(%s:%s)" % (self.__class__.__name__, self.host, self.port)
586
 
 
587
 
    def create_server(self):
588
 
        return self.server_class((self.host, self.port),
589
 
                                 self.request_handler_class)
590
 
 
591
 
    def start_server(self):
592
 
        self.server = self.create_server()
593
 
        self._server_thread = ThreadWithException(
594
 
            event=self.server.started,
595
 
            target=self.run_server)
596
 
        self._server_thread.start()
597
 
        # Wait for the server thread to start (i.e release the lock)
598
 
        self.server.started.wait()
599
 
        # Get the real address, especially the port
600
 
        self.host, self.port = self.server.server_address
601
 
        self._server_thread.name = self.server.server_address
602
 
        if debug_threads():
603
 
            sys.stderr.write('Server thread %s started\n'
604
 
                             % (self._server_thread.name,))
605
 
        # If an exception occured during the server start, it will get raised,
606
 
        # otherwise, the server is blocked on its accept() call.
607
 
        self._server_thread.pending_exception()
608
 
        # From now on, we'll use a different event to ensure the server can set
609
 
        # its exception
610
 
        self._server_thread.set_ready_event(self.server.stopped)
611
 
 
612
 
    def run_server(self):
613
 
        self.server.serve()
614
 
 
615
 
    def stop_server(self):
616
 
        if self.server is None:
617
 
            return
618
 
        try:
619
 
            # The server has been started successfully, shut it down now.  As
620
 
            # soon as we stop serving, no more connection are accepted except
621
 
            # one to get out of the blocking listen.
622
 
            self.set_ignored_exceptions(
623
 
                self.server.ignored_exceptions_during_shutdown)
624
 
            self.server.serving = False
625
 
            if debug_threads():
626
 
                sys.stderr.write('Server thread %s will be joined\n'
627
 
                                 % (self._server_thread.name,))
628
 
            # The server is listening for a last connection, let's give it:
629
 
            last_conn = None
630
 
            try:
631
 
                last_conn = osutils.connect_socket((self.host, self.port))
632
 
            except socket.error, e:
633
 
                # But ignore connection errors as the point is to unblock the
634
 
                # server thread, it may happen that it's not blocked or even
635
 
                # not started.
636
 
                pass
637
 
            # We start shutting down the client while the server itself is
638
 
            # shutting down.
639
 
            self.server.stop_client_connections()
640
 
            # Now we wait for the thread running self.server.serve() to finish
641
 
            self.server.stopped.wait()
642
 
            if last_conn is not None:
643
 
                # Close the last connection without trying to use it. The
644
 
                # server will not process a single byte on that socket to avoid
645
 
                # complications (SSL starts with a handshake for example).
646
 
                last_conn.close()
647
 
            # Check for any exception that could have occurred in the server
648
 
            # thread
649
 
            try:
650
 
                self._server_thread.join()
651
 
            except Exception, e:
652
 
                if self.server.ignored_exceptions(e):
653
 
                    pass
654
 
                else:
655
 
                    raise
656
 
        finally:
657
 
            # Make sure we can be called twice safely, note that this means
658
 
            # that we will raise a single exception even if several occurred in
659
 
            # the various threads involved.
660
 
            self.server = None
661
 
 
662
 
    def set_ignored_exceptions(self, ignored_exceptions):
663
 
        """Install an exception handler for the server."""
664
 
        self.server.set_ignored_exceptions(self._server_thread,
665
 
                                           ignored_exceptions)
666
 
 
667
 
    def pending_exception(self):
668
 
        """Raise uncaught exception in the server."""
669
 
        self.server._pending_exception(self._server_thread)
670
 
 
671
 
 
672
 
class TestingSmartConnectionHandler(SocketServer.BaseRequestHandler,
673
 
                                    medium.SmartServerSocketStreamMedium):
674
 
 
675
 
    def __init__(self, request, client_address, server):
676
 
        medium.SmartServerSocketStreamMedium.__init__(
677
 
            self, request, server.backing_transport,
678
 
            server.root_client_path)
679
 
        request.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
680
 
        SocketServer.BaseRequestHandler.__init__(self, request, client_address,
681
 
                                                 server)
682
 
 
683
 
    def handle(self):
684
 
        while not self.finished:
685
 
            server_protocol = self._build_protocol()
686
 
            self._serve_one_request(server_protocol)
687
 
 
688
 
 
689
 
class TestingSmartServer(TestingThreadingTCPServer, server.SmartTCPServer):
690
 
 
691
 
    def __init__(self, server_address, request_handler_class,
692
 
                 backing_transport, root_client_path):
693
 
        TestingThreadingTCPServer.__init__(self, server_address,
694
 
                                           request_handler_class)
695
 
        server.SmartTCPServer.__init__(self, backing_transport,
696
 
                                       root_client_path)
697
 
    def serve(self):
698
 
        # FIXME: No test are exercising the hooks for the test server
699
 
        # -- vila 20100618
700
 
        self.run_server_started_hooks()
701
 
        try:
702
 
            TestingThreadingTCPServer.serve(self)
703
 
        finally:
704
 
            self.run_server_stopped_hooks()
705
 
 
706
 
    def get_url(self):
707
 
        """Return the url of the server"""
708
 
        return "bzr://%s:%d/" % self.server_address
709
 
 
710
 
 
711
 
class SmartTCPServer_for_testing(TestingTCPServerInAThread):
 
226
class SmartTCPServer_for_testing(server.SmartTCPServer):
712
227
    """Server suitable for use by transport tests.
713
228
 
714
229
    This server is backed by the process's cwd.
715
230
    """
 
231
 
716
232
    def __init__(self, thread_name_suffix=''):
 
233
        super(SmartTCPServer_for_testing, self).__init__(None)
717
234
        self.client_path_extra = None
718
235
        self.thread_name_suffix = thread_name_suffix
719
 
        self.host = '127.0.0.1'
720
 
        self.port = 0
721
 
        super(SmartTCPServer_for_testing, self).__init__(
722
 
                (self.host, self.port),
723
 
                TestingSmartServer,
724
 
                TestingSmartConnectionHandler)
725
 
 
726
 
    def create_server(self):
727
 
        return self.server_class((self.host, self.port),
728
 
                                 self.request_handler_class,
729
 
                                 self.backing_transport,
730
 
                                 self.root_client_path)
731
 
 
 
236
 
 
237
    def get_backing_transport(self, backing_transport_server):
 
238
        """Get a backing transport from a server we are decorating."""
 
239
        return transport.get_transport(backing_transport_server.get_url())
732
240
 
733
241
    def start_server(self, backing_transport_server=None,
734
 
                     client_path_extra='/extra/'):
 
242
              client_path_extra='/extra/'):
735
243
        """Set up server for testing.
736
244
 
737
245
        :param backing_transport_server: backing server to use.  If not
746
254
        """
747
255
        if not client_path_extra.startswith('/'):
748
256
            raise ValueError(client_path_extra)
749
 
        self.root_client_path = self.client_path_extra = client_path_extra
750
257
        from bzrlib.transport.chroot import ChrootServer
751
258
        if backing_transport_server is None:
752
259
            backing_transport_server = LocalURLServer()
755
262
        self.chroot_server.start_server()
756
263
        self.backing_transport = transport.get_transport(
757
264
            self.chroot_server.get_url())
758
 
        super(SmartTCPServer_for_testing, self).start_server()
 
265
        self.root_client_path = self.client_path_extra = client_path_extra
 
266
        self.start_background_thread(self.thread_name_suffix)
759
267
 
760
268
    def stop_server(self):
761
 
        try:
762
 
            super(SmartTCPServer_for_testing, self).stop_server()
763
 
        finally:
764
 
            self.chroot_server.stop_server()
765
 
 
766
 
    def get_backing_transport(self, backing_transport_server):
767
 
        """Get a backing transport from a server we are decorating."""
768
 
        return transport.get_transport(backing_transport_server.get_url())
 
269
        self.stop_background_thread()
 
270
        self.chroot_server.stop_server()
769
271
 
770
272
    def get_url(self):
771
 
        url = self.server.get_url()
 
273
        url = super(SmartTCPServer_for_testing, self).get_url()
772
274
        return url[:-1] + self.client_path_extra
773
275
 
774
276
    def get_bogus_url(self):