~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/tests/test_server.py

  • Committer: Martin Pool
  • Date: 2011-11-29 00:35:22 UTC
  • mto: This revision was merged to the branch mainline in revision 6320.
  • Revision ID: mbp@canonical.com-20111129003522-8ki2s26327416iie
Set a timeout of 120s per test during 'make check'

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2005, 2006, 2007, 2008, 2010 Canonical Ltd
 
1
# Copyright (C) 2010, 2011 Canonical Ltd
2
2
#
3
3
# This program is free software; you can redistribute it and/or modify
4
4
# it under the terms of the GNU General Public License as published by
14
14
# along with this program; if not, write to the Free Software
15
15
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
16
16
 
 
17
import errno
 
18
import socket
 
19
import SocketServer
 
20
import sys
 
21
import threading
 
22
import traceback
 
23
 
 
24
 
17
25
from bzrlib import (
 
26
    cethread,
 
27
    errors,
 
28
    osutils,
18
29
    transport,
19
30
    urlutils,
20
31
    )
22
33
    chroot,
23
34
    pathfilter,
24
35
    )
25
 
from bzrlib.smart import server
 
36
from bzrlib.smart import (
 
37
    medium,
 
38
    server,
 
39
    )
 
40
 
 
41
 
 
42
def debug_threads():
 
43
    # FIXME: There is a dependency loop between bzrlib.tests and
 
44
    # bzrlib.tests.test_server that needs to be fixed. In the mean time
 
45
    # defining this function is enough for our needs. -- vila 20100611
 
46
    from bzrlib import tests
 
47
    return 'threads' in tests.selftest_debug_flags
26
48
 
27
49
 
28
50
class TestServer(transport.Server):
192
214
    def start_server(self, backing_server=None):
193
215
        """Setup the Chroot on backing_server."""
194
216
        if backing_server is not None:
195
 
            self.backing_transport = transport.get_transport(
 
217
            self.backing_transport = transport.get_transport_from_url(
196
218
                backing_server.get_url())
197
219
        else:
198
 
            self.backing_transport = transport.get_transport('.')
 
220
            self.backing_transport = transport.get_transport_from_path('.')
199
221
        self.backing_transport.clone('added-by-filter').ensure_base()
200
222
        self.filter_func = lambda x: 'added-by-filter/' + x
201
223
        super(TestingPathFilteringServer, self).start_server()
213
235
    def start_server(self, backing_server=None):
214
236
        """Setup the Chroot on backing_server."""
215
237
        if backing_server is not None:
216
 
            self.backing_transport = transport.get_transport(
 
238
            self.backing_transport = transport.get_transport_from_url(
217
239
                backing_server.get_url())
218
240
        else:
219
 
            self.backing_transport = transport.get_transport('.')
 
241
            self.backing_transport = transport.get_transport_from_path('.')
220
242
        super(TestingChrootServer, self).start_server()
221
243
 
222
244
    def get_bogus_url(self):
223
245
        raise NotImplementedError
224
246
 
225
247
 
226
 
class SmartTCPServer_for_testing(server.SmartTCPServer):
 
248
class TestThread(cethread.CatchingExceptionThread):
 
249
 
 
250
    def join(self, timeout=5):
 
251
        """Overrides to use a default timeout.
 
252
 
 
253
        The default timeout is set to 5 and should expire only when a thread
 
254
        serving a client connection is hung.
 
255
        """
 
256
        super(TestThread, self).join(timeout)
 
257
        if timeout and self.isAlive():
 
258
            # The timeout expired without joining the thread, the thread is
 
259
            # therefore stucked and that's a failure as far as the test is
 
260
            # concerned. We used to hang here.
 
261
 
 
262
            # FIXME: we need to kill the thread, but as far as the test is
 
263
            # concerned, raising an assertion is too strong. On most of the
 
264
            # platforms, this doesn't occur, so just mentioning the problem is
 
265
            # enough for now -- vila 2010824
 
266
            sys.stderr.write('thread %s hung\n' % (self.name,))
 
267
            #raise AssertionError('thread %s hung' % (self.name,))
 
268
 
 
269
 
 
270
class TestingTCPServerMixin(object):
 
271
    """Mixin to support running SocketServer.TCPServer in a thread.
 
272
 
 
273
    Tests are connecting from the main thread, the server has to be run in a
 
274
    separate thread.
 
275
    """
 
276
 
 
277
    def __init__(self):
 
278
        self.started = threading.Event()
 
279
        self.serving = None
 
280
        self.stopped = threading.Event()
 
281
        # We collect the resources used by the clients so we can release them
 
282
        # when shutting down
 
283
        self.clients = []
 
284
        self.ignored_exceptions = None
 
285
 
 
286
    def server_bind(self):
 
287
        self.socket.bind(self.server_address)
 
288
        self.server_address = self.socket.getsockname()
 
289
 
 
290
    def serve(self):
 
291
        self.serving = True
 
292
        # We are listening and ready to accept connections
 
293
        self.started.set()
 
294
        try:
 
295
            while self.serving:
 
296
                # Really a connection but the python framework is generic and
 
297
                # call them requests
 
298
                self.handle_request()
 
299
            # Let's close the listening socket
 
300
            self.server_close()
 
301
        finally:
 
302
            self.stopped.set()
 
303
 
 
304
    def handle_request(self):
 
305
        """Handle one request.
 
306
 
 
307
        The python version swallows some socket exceptions and we don't use
 
308
        timeout, so we override it to better control the server behavior.
 
309
        """
 
310
        request, client_address = self.get_request()
 
311
        if self.verify_request(request, client_address):
 
312
            try:
 
313
                self.process_request(request, client_address)
 
314
            except:
 
315
                self.handle_error(request, client_address)
 
316
        else:
 
317
            self.close_request(request)
 
318
 
 
319
    def get_request(self):
 
320
        return self.socket.accept()
 
321
 
 
322
    def verify_request(self, request, client_address):
 
323
        """Verify the request.
 
324
 
 
325
        Return True if we should proceed with this request, False if we should
 
326
        not even touch a single byte in the socket ! This is useful when we
 
327
        stop the server with a dummy last connection.
 
328
        """
 
329
        return self.serving
 
330
 
 
331
    def handle_error(self, request, client_address):
 
332
        # Stop serving and re-raise the last exception seen
 
333
        self.serving = False
 
334
        # The following can be used for debugging purposes, it will display the
 
335
        # exception and the traceback just when it occurs instead of waiting
 
336
        # for the thread to be joined.
 
337
        # SocketServer.BaseServer.handle_error(self, request, client_address)
 
338
 
 
339
        # We call close_request manually, because we are going to raise an
 
340
        # exception. The SocketServer implementation calls:
 
341
        #   handle_error(...)
 
342
        #   close_request(...)
 
343
        # But because we raise the exception, close_request will never be
 
344
        # triggered. This helps client not block waiting for a response when
 
345
        # the server gets an exception.
 
346
        self.close_request(request)
 
347
        raise
 
348
 
 
349
    def ignored_exceptions_during_shutdown(self, e):
 
350
        if sys.platform == 'win32':
 
351
            accepted_errnos = [errno.EBADF,
 
352
                               errno.EPIPE,
 
353
                               errno.WSAEBADF,
 
354
                               errno.WSAECONNRESET,
 
355
                               errno.WSAENOTCONN,
 
356
                               errno.WSAESHUTDOWN,
 
357
                               ]
 
358
        else:
 
359
            accepted_errnos = [errno.EBADF,
 
360
                               errno.ECONNRESET,
 
361
                               errno.ENOTCONN,
 
362
                               errno.EPIPE,
 
363
                               ]
 
364
        if isinstance(e, socket.error) and e[0] in accepted_errnos:
 
365
            return True
 
366
        return False
 
367
 
 
368
    # The following methods are called by the main thread
 
369
 
 
370
    def stop_client_connections(self):
 
371
        while self.clients:
 
372
            c = self.clients.pop()
 
373
            self.shutdown_client(c)
 
374
 
 
375
    def shutdown_socket(self, sock):
 
376
        """Properly shutdown a socket.
 
377
 
 
378
        This should be called only when no other thread is trying to use the
 
379
        socket.
 
380
        """
 
381
        try:
 
382
            sock.shutdown(socket.SHUT_RDWR)
 
383
            sock.close()
 
384
        except Exception, e:
 
385
            if self.ignored_exceptions(e):
 
386
                pass
 
387
            else:
 
388
                raise
 
389
 
 
390
    # The following methods are called by the main thread
 
391
 
 
392
    def set_ignored_exceptions(self, thread, ignored_exceptions):
 
393
        self.ignored_exceptions = ignored_exceptions
 
394
        thread.set_ignored_exceptions(self.ignored_exceptions)
 
395
 
 
396
    def _pending_exception(self, thread):
 
397
        """Raise server uncaught exception.
 
398
 
 
399
        Daughter classes can override this if they use daughter threads.
 
400
        """
 
401
        thread.pending_exception()
 
402
 
 
403
 
 
404
class TestingTCPServer(TestingTCPServerMixin, SocketServer.TCPServer):
 
405
 
 
406
    def __init__(self, server_address, request_handler_class):
 
407
        TestingTCPServerMixin.__init__(self)
 
408
        SocketServer.TCPServer.__init__(self, server_address,
 
409
                                        request_handler_class)
 
410
 
 
411
    def get_request(self):
 
412
        """Get the request and client address from the socket."""
 
413
        sock, addr = TestingTCPServerMixin.get_request(self)
 
414
        self.clients.append((sock, addr))
 
415
        return sock, addr
 
416
 
 
417
    # The following methods are called by the main thread
 
418
 
 
419
    def shutdown_client(self, client):
 
420
        sock, addr = client
 
421
        self.shutdown_socket(sock)
 
422
 
 
423
 
 
424
class TestingThreadingTCPServer(TestingTCPServerMixin,
 
425
                                SocketServer.ThreadingTCPServer):
 
426
 
 
427
    def __init__(self, server_address, request_handler_class):
 
428
        TestingTCPServerMixin.__init__(self)
 
429
        SocketServer.ThreadingTCPServer.__init__(self, server_address,
 
430
                                                 request_handler_class)
 
431
 
 
432
    def get_request(self):
 
433
        """Get the request and client address from the socket."""
 
434
        sock, addr = TestingTCPServerMixin.get_request(self)
 
435
        # The thread is not create yet, it will be updated in process_request
 
436
        self.clients.append((sock, addr, None))
 
437
        return sock, addr
 
438
 
 
439
    def process_request_thread(self, started, stopped, request, client_address):
 
440
        started.set()
 
441
        SocketServer.ThreadingTCPServer.process_request_thread(
 
442
            self, request, client_address)
 
443
        self.close_request(request)
 
444
        stopped.set()
 
445
 
 
446
    def process_request(self, request, client_address):
 
447
        """Start a new thread to process the request."""
 
448
        started = threading.Event()
 
449
        stopped = threading.Event()
 
450
        t = TestThread(
 
451
            sync_event=stopped,
 
452
            name='%s -> %s' % (client_address, self.server_address),
 
453
            target=self.process_request_thread,
 
454
            args=(started, stopped, request, client_address))
 
455
        # Update the client description
 
456
        self.clients.pop()
 
457
        self.clients.append((request, client_address, t))
 
458
        # Propagate the exception handler since we must use the same one as
 
459
        # TestingTCPServer for connections running in their own threads.
 
460
        t.set_ignored_exceptions(self.ignored_exceptions)
 
461
        t.start()
 
462
        started.wait()
 
463
        if debug_threads():
 
464
            sys.stderr.write('Client thread %s started\n' % (t.name,))
 
465
        # If an exception occured during the thread start, it will get raised.
 
466
        # In rare cases, an exception raised during the request processing may
 
467
        # also get caught here (see http://pad.lv/869366)
 
468
        t.pending_exception()
 
469
 
 
470
    # The following methods are called by the main thread
 
471
 
 
472
    def shutdown_client(self, client):
 
473
        sock, addr, connection_thread = client
 
474
        self.shutdown_socket(sock)
 
475
        if connection_thread is not None:
 
476
            # The thread has been created only if the request is processed but
 
477
            # after the connection is inited. This could happen during server
 
478
            # shutdown. If an exception occurred in the thread it will be
 
479
            # re-raised
 
480
            if debug_threads():
 
481
                sys.stderr.write('Client thread %s will be joined\n'
 
482
                                 % (connection_thread.name,))
 
483
            connection_thread.join()
 
484
 
 
485
    def set_ignored_exceptions(self, thread, ignored_exceptions):
 
486
        TestingTCPServerMixin.set_ignored_exceptions(self, thread,
 
487
                                                     ignored_exceptions)
 
488
        for sock, addr, connection_thread in self.clients:
 
489
            if connection_thread is not None:
 
490
                connection_thread.set_ignored_exceptions(
 
491
                    self.ignored_exceptions)
 
492
 
 
493
    def _pending_exception(self, thread):
 
494
        for sock, addr, connection_thread in self.clients:
 
495
            if connection_thread is not None:
 
496
                connection_thread.pending_exception()
 
497
        TestingTCPServerMixin._pending_exception(self, thread)
 
498
 
 
499
 
 
500
class TestingTCPServerInAThread(transport.Server):
 
501
    """A server in a thread that re-raise thread exceptions."""
 
502
 
 
503
    def __init__(self, server_address, server_class, request_handler_class):
 
504
        self.server_class = server_class
 
505
        self.request_handler_class = request_handler_class
 
506
        self.host, self.port = server_address
 
507
        self.server = None
 
508
        self._server_thread = None
 
509
 
 
510
    def __repr__(self):
 
511
        return "%s(%s:%s)" % (self.__class__.__name__, self.host, self.port)
 
512
 
 
513
    def create_server(self):
 
514
        return self.server_class((self.host, self.port),
 
515
                                 self.request_handler_class)
 
516
 
 
517
    def start_server(self):
 
518
        self.server = self.create_server()
 
519
        self._server_thread = TestThread(
 
520
            sync_event=self.server.started,
 
521
            target=self.run_server)
 
522
        self._server_thread.start()
 
523
        # Wait for the server thread to start (i.e. release the lock)
 
524
        self.server.started.wait()
 
525
        # Get the real address, especially the port
 
526
        self.host, self.port = self.server.server_address
 
527
        self._server_thread.name = self.server.server_address
 
528
        if debug_threads():
 
529
            sys.stderr.write('Server thread %s started\n'
 
530
                             % (self._server_thread.name,))
 
531
        # If an exception occured during the server start, it will get raised,
 
532
        # otherwise, the server is blocked on its accept() call.
 
533
        self._server_thread.pending_exception()
 
534
        # From now on, we'll use a different event to ensure the server can set
 
535
        # its exception
 
536
        self._server_thread.set_sync_event(self.server.stopped)
 
537
 
 
538
    def run_server(self):
 
539
        self.server.serve()
 
540
 
 
541
    def stop_server(self):
 
542
        if self.server is None:
 
543
            return
 
544
        try:
 
545
            # The server has been started successfully, shut it down now.  As
 
546
            # soon as we stop serving, no more connection are accepted except
 
547
            # one to get out of the blocking listen.
 
548
            self.set_ignored_exceptions(
 
549
                self.server.ignored_exceptions_during_shutdown)
 
550
            self.server.serving = False
 
551
            if debug_threads():
 
552
                sys.stderr.write('Server thread %s will be joined\n'
 
553
                                 % (self._server_thread.name,))
 
554
            # The server is listening for a last connection, let's give it:
 
555
            last_conn = None
 
556
            try:
 
557
                last_conn = osutils.connect_socket((self.host, self.port))
 
558
            except socket.error, e:
 
559
                # But ignore connection errors as the point is to unblock the
 
560
                # server thread, it may happen that it's not blocked or even
 
561
                # not started.
 
562
                pass
 
563
            # We start shutting down the clients while the server itself is
 
564
            # shutting down.
 
565
            self.server.stop_client_connections()
 
566
            # Now we wait for the thread running self.server.serve() to finish
 
567
            self.server.stopped.wait()
 
568
            if last_conn is not None:
 
569
                # Close the last connection without trying to use it. The
 
570
                # server will not process a single byte on that socket to avoid
 
571
                # complications (SSL starts with a handshake for example).
 
572
                last_conn.close()
 
573
            # Check for any exception that could have occurred in the server
 
574
            # thread
 
575
            try:
 
576
                self._server_thread.join()
 
577
            except Exception, e:
 
578
                if self.server.ignored_exceptions(e):
 
579
                    pass
 
580
                else:
 
581
                    raise
 
582
        finally:
 
583
            # Make sure we can be called twice safely, note that this means
 
584
            # that we will raise a single exception even if several occurred in
 
585
            # the various threads involved.
 
586
            self.server = None
 
587
 
 
588
    def set_ignored_exceptions(self, ignored_exceptions):
 
589
        """Install an exception handler for the server."""
 
590
        self.server.set_ignored_exceptions(self._server_thread,
 
591
                                           ignored_exceptions)
 
592
 
 
593
    def pending_exception(self):
 
594
        """Raise uncaught exception in the server."""
 
595
        self.server._pending_exception(self._server_thread)
 
596
 
 
597
 
 
598
class TestingSmartConnectionHandler(SocketServer.BaseRequestHandler,
 
599
                                    medium.SmartServerSocketStreamMedium):
 
600
 
 
601
    def __init__(self, request, client_address, server):
 
602
        medium.SmartServerSocketStreamMedium.__init__(
 
603
            self, request, server.backing_transport,
 
604
            server.root_client_path,
 
605
            timeout=_DEFAULT_TESTING_CLIENT_TIMEOUT)
 
606
        request.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
 
607
        SocketServer.BaseRequestHandler.__init__(self, request, client_address,
 
608
                                                 server)
 
609
 
 
610
    def handle(self):
 
611
        try:
 
612
            while not self.finished:
 
613
                server_protocol = self._build_protocol()
 
614
                self._serve_one_request(server_protocol)
 
615
        except errors.ConnectionTimeout:
 
616
            # idle connections aren't considered a failure of the server
 
617
            return
 
618
 
 
619
 
 
620
_DEFAULT_TESTING_CLIENT_TIMEOUT = 60.0
 
621
 
 
622
class TestingSmartServer(TestingThreadingTCPServer, server.SmartTCPServer):
 
623
 
 
624
    def __init__(self, server_address, request_handler_class,
 
625
                 backing_transport, root_client_path):
 
626
        TestingThreadingTCPServer.__init__(self, server_address,
 
627
                                           request_handler_class)
 
628
        server.SmartTCPServer.__init__(self, backing_transport,
 
629
            root_client_path, client_timeout=_DEFAULT_TESTING_CLIENT_TIMEOUT)
 
630
 
 
631
    def serve(self):
 
632
        self.run_server_started_hooks()
 
633
        try:
 
634
            TestingThreadingTCPServer.serve(self)
 
635
        finally:
 
636
            self.run_server_stopped_hooks()
 
637
 
 
638
    def get_url(self):
 
639
        """Return the url of the server"""
 
640
        return "bzr://%s:%d/" % self.server_address
 
641
 
 
642
 
 
643
class SmartTCPServer_for_testing(TestingTCPServerInAThread):
227
644
    """Server suitable for use by transport tests.
228
645
 
229
646
    This server is backed by the process's cwd.
230
647
    """
231
 
 
232
648
    def __init__(self, thread_name_suffix=''):
233
 
        super(SmartTCPServer_for_testing, self).__init__(None)
234
649
        self.client_path_extra = None
235
650
        self.thread_name_suffix = thread_name_suffix
236
 
 
237
 
    def get_backing_transport(self, backing_transport_server):
238
 
        """Get a backing transport from a server we are decorating."""
239
 
        return transport.get_transport(backing_transport_server.get_url())
 
651
        self.host = '127.0.0.1'
 
652
        self.port = 0
 
653
        super(SmartTCPServer_for_testing, self).__init__(
 
654
                (self.host, self.port),
 
655
                TestingSmartServer,
 
656
                TestingSmartConnectionHandler)
 
657
 
 
658
    def create_server(self):
 
659
        return self.server_class((self.host, self.port),
 
660
                                 self.request_handler_class,
 
661
                                 self.backing_transport,
 
662
                                 self.root_client_path)
 
663
 
240
664
 
241
665
    def start_server(self, backing_transport_server=None,
242
 
              client_path_extra='/extra/'):
 
666
                     client_path_extra='/extra/'):
243
667
        """Set up server for testing.
244
668
 
245
669
        :param backing_transport_server: backing server to use.  If not
254
678
        """
255
679
        if not client_path_extra.startswith('/'):
256
680
            raise ValueError(client_path_extra)
 
681
        self.root_client_path = self.client_path_extra = client_path_extra
257
682
        from bzrlib.transport.chroot import ChrootServer
258
683
        if backing_transport_server is None:
259
684
            backing_transport_server = LocalURLServer()
260
685
        self.chroot_server = ChrootServer(
261
686
            self.get_backing_transport(backing_transport_server))
262
687
        self.chroot_server.start_server()
263
 
        self.backing_transport = transport.get_transport(
 
688
        self.backing_transport = transport.get_transport_from_url(
264
689
            self.chroot_server.get_url())
265
 
        self.root_client_path = self.client_path_extra = client_path_extra
266
 
        self.start_background_thread(self.thread_name_suffix)
 
690
        super(SmartTCPServer_for_testing, self).start_server()
267
691
 
268
692
    def stop_server(self):
269
 
        self.stop_background_thread()
270
 
        self.chroot_server.stop_server()
 
693
        try:
 
694
            super(SmartTCPServer_for_testing, self).stop_server()
 
695
        finally:
 
696
            self.chroot_server.stop_server()
 
697
 
 
698
    def get_backing_transport(self, backing_transport_server):
 
699
        """Get a backing transport from a server we are decorating."""
 
700
        return transport.get_transport_from_url(
 
701
            backing_transport_server.get_url())
271
702
 
272
703
    def get_url(self):
273
 
        url = super(SmartTCPServer_for_testing, self).get_url()
 
704
        url = self.server.get_url()
274
705
        return url[:-1] + self.client_path_extra
275
706
 
276
707
    def get_bogus_url(self):
284
715
    def get_backing_transport(self, backing_transport_server):
285
716
        """Get a backing transport from a server we are decorating."""
286
717
        url = 'readonly+' + backing_transport_server.get_url()
287
 
        return transport.get_transport(url)
 
718
        return transport.get_transport_from_url(url)
288
719
 
289
720
 
290
721
class SmartTCPServer_for_testing_v2_only(SmartTCPServer_for_testing):
305
736
    def get_backing_transport(self, backing_transport_server):
306
737
        """Get a backing transport from a server we are decorating."""
307
738
        url = 'readonly+' + backing_transport_server.get_url()
308
 
        return transport.get_transport(url)
309
 
 
310
 
 
311
 
 
312
 
 
 
739
        return transport.get_transport_from_url(url)