~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/tests/test_server.py

  • Committer: Jelmer Vernooij
  • Date: 2011-12-16 16:40:10 UTC
  • mto: This revision was merged to the branch mainline in revision 6391.
  • Revision ID: jelmer@samba.org-20111216164010-z3hy00xrnclnkf7a
Update tests.

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2010 Canonical Ltd
 
1
# Copyright (C) 2010, 2011 Canonical Ltd
2
2
#
3
3
# This program is free software; you can redistribute it and/or modify
4
4
# it under the terms of the GNU General Public License as published by
17
17
import errno
18
18
import socket
19
19
import SocketServer
20
 
import select
21
20
import sys
22
21
import threading
 
22
import traceback
23
23
 
24
24
 
25
25
from bzrlib import (
 
26
    cethread,
 
27
    errors,
26
28
    osutils,
27
29
    transport,
28
30
    urlutils,
212
214
    def start_server(self, backing_server=None):
213
215
        """Setup the Chroot on backing_server."""
214
216
        if backing_server is not None:
215
 
            self.backing_transport = transport.get_transport(
 
217
            self.backing_transport = transport.get_transport_from_url(
216
218
                backing_server.get_url())
217
219
        else:
218
 
            self.backing_transport = transport.get_transport('.')
 
220
            self.backing_transport = transport.get_transport_from_path('.')
219
221
        self.backing_transport.clone('added-by-filter').ensure_base()
220
222
        self.filter_func = lambda x: 'added-by-filter/' + x
221
223
        super(TestingPathFilteringServer, self).start_server()
233
235
    def start_server(self, backing_server=None):
234
236
        """Setup the Chroot on backing_server."""
235
237
        if backing_server is not None:
236
 
            self.backing_transport = transport.get_transport(
 
238
            self.backing_transport = transport.get_transport_from_url(
237
239
                backing_server.get_url())
238
240
        else:
239
 
            self.backing_transport = transport.get_transport('.')
 
241
            self.backing_transport = transport.get_transport_from_path('.')
240
242
        super(TestingChrootServer, self).start_server()
241
243
 
242
244
    def get_bogus_url(self):
243
245
        raise NotImplementedError
244
246
 
245
247
 
246
 
class ThreadWithException(threading.Thread):
247
 
    """A catching exception thread.
248
 
 
249
 
    If an exception occurs during the thread execution, it's caught and
250
 
    re-raised when the thread is joined().
251
 
    """
252
 
 
253
 
    def __init__(self, *args, **kwargs):
254
 
        # There are cases where the calling thread must wait, yet, if an
255
 
        # exception occurs, the event should be set so the caller is not
256
 
        # blocked. The main example is a calling thread that want to wait for
257
 
        # the called thread to be in a given state before continuing.
258
 
        try:
259
 
            event = kwargs.pop('event')
260
 
        except KeyError:
261
 
            # If the caller didn't pass a specific event, create our own
262
 
            event = threading.Event()
263
 
        super(ThreadWithException, self).__init__(*args, **kwargs)
264
 
        self.set_ready_event(event)
265
 
        self.exception = None
266
 
        self.ignored_exceptions = None # see set_ignored_exceptions
267
 
 
268
 
    # compatibility thunk for python-2.4 and python-2.5...
269
 
    if sys.version_info < (2, 6):
270
 
        name = property(threading.Thread.getName, threading.Thread.setName)
271
 
 
272
 
    def set_ready_event(self, event):
273
 
        """Set the ``ready`` event used to synchronize exception catching.
274
 
 
275
 
        When the thread uses an event to synchronize itself with another thread
276
 
        (setting it when the other thread can wake up from a ``wait`` call),
277
 
        the event must be set after catching an exception or the other thread
278
 
        will hang.
279
 
 
280
 
        Some threads require multiple events and should set the relevant one
281
 
        when appropriate.
282
 
        """
283
 
        self.ready = event
284
 
 
285
 
    def set_ignored_exceptions(self, ignored):
286
 
        """Declare which exceptions will be ignored.
287
 
 
288
 
        :param ignored: Can be either:
289
 
           - None: all exceptions will be raised,
290
 
           - an exception class: the instances of this class will be ignored,
291
 
           - a tuple of exception classes: the instances of any class of the
292
 
             list will be ignored,
293
 
           - a callable: that will be passed the exception object
294
 
             and should return True if the exception should be ignored
295
 
        """
296
 
        if ignored is None:
297
 
            self.ignored_exceptions = None
298
 
        elif isinstance(ignored, (Exception, tuple)):
299
 
            self.ignored_exceptions = lambda e: isinstance(e, ignored)
300
 
        else:
301
 
            self.ignored_exceptions = ignored
302
 
 
303
 
    def run(self):
304
 
        """Overrides Thread.run to capture any exception."""
305
 
        self.ready.clear()
306
 
        try:
307
 
            try:
308
 
                super(ThreadWithException, self).run()
309
 
            except:
310
 
                self.exception = sys.exc_info()
311
 
        finally:
312
 
            # Make sure the calling thread is released
313
 
            self.ready.set()
314
 
 
 
248
class TestThread(cethread.CatchingExceptionThread):
315
249
 
316
250
    def join(self, timeout=5):
317
 
        """Overrides Thread.join to raise any exception caught.
318
 
 
319
 
 
320
 
        Calling join(timeout=0) will raise the caught exception or return None
321
 
        if the thread is still alive.
 
251
        """Overrides to use a default timeout.
322
252
 
323
253
        The default timeout is set to 5 and should expire only when a thread
324
254
        serving a client connection is hung.
325
255
        """
326
 
        super(ThreadWithException, self).join(timeout)
327
 
        if self.exception is not None:
328
 
            exc_class, exc_value, exc_tb = self.exception
329
 
            self.exception = None # The exception should be raised only once
330
 
            if (self.ignored_exceptions is None
331
 
                or not self.ignored_exceptions(exc_value)):
332
 
                # Raise non ignored exceptions
333
 
                raise exc_class, exc_value, exc_tb
 
256
        super(TestThread, self).join(timeout)
334
257
        if timeout and self.isAlive():
335
258
            # The timeout expired without joining the thread, the thread is
336
259
            # therefore stucked and that's a failure as far as the test is
343
266
            sys.stderr.write('thread %s hung\n' % (self.name,))
344
267
            #raise AssertionError('thread %s hung' % (self.name,))
345
268
 
346
 
    def pending_exception(self):
347
 
        """Raise the caught exception.
348
 
 
349
 
        This does nothing if no exception occurred.
350
 
        """
351
 
        self.join(timeout=0)
352
 
 
353
 
 
354
 
class TestingTCPServerMixin:
 
269
 
 
270
class TestingTCPServerMixin(object):
355
271
    """Mixin to support running SocketServer.TCPServer in a thread.
356
272
 
357
273
    Tests are connecting from the main thread, the server has to be run in a
373
289
 
374
290
    def serve(self):
375
291
        self.serving = True
376
 
        self.stopped.clear()
377
292
        # We are listening and ready to accept connections
378
293
        self.started.set()
379
294
        try:
398
313
                self.process_request(request, client_address)
399
314
            except:
400
315
                self.handle_error(request, client_address)
401
 
                self.close_request(request)
 
316
        else:
 
317
            self.close_request(request)
402
318
 
403
319
    def get_request(self):
404
320
        return self.socket.accept()
418
334
        # The following can be used for debugging purposes, it will display the
419
335
        # exception and the traceback just when it occurs instead of waiting
420
336
        # for the thread to be joined.
421
 
 
422
337
        # SocketServer.BaseServer.handle_error(self, request, client_address)
 
338
 
 
339
        # We call close_request manually, because we are going to raise an
 
340
        # exception. The SocketServer implementation calls:
 
341
        #   handle_error(...)
 
342
        #   close_request(...)
 
343
        # But because we raise the exception, close_request will never be
 
344
        # triggered. This helps client not block waiting for a response when
 
345
        # the server gets an exception.
 
346
        self.close_request(request)
423
347
        raise
424
348
 
425
349
    def ignored_exceptions_during_shutdown(self, e):
505
429
        SocketServer.ThreadingTCPServer.__init__(self, server_address,
506
430
                                                 request_handler_class)
507
431
 
508
 
    def get_request (self):
 
432
    def get_request(self):
509
433
        """Get the request and client address from the socket."""
510
434
        sock, addr = TestingTCPServerMixin.get_request(self)
511
 
        # The thread is not create yet, it will be updated in process_request
 
435
        # The thread is not created yet, it will be updated in process_request
512
436
        self.clients.append((sock, addr, None))
513
437
        return sock, addr
514
438
 
515
 
    def process_request_thread(self, started, stopped, request, client_address):
 
439
    def process_request_thread(self, started, detached, stopped,
 
440
                               request, client_address):
516
441
        started.set()
 
442
        # We will be on our own once the server tells us we're detached
 
443
        detached.wait()
517
444
        SocketServer.ThreadingTCPServer.process_request_thread(
518
445
            self, request, client_address)
519
446
        self.close_request(request)
522
449
    def process_request(self, request, client_address):
523
450
        """Start a new thread to process the request."""
524
451
        started = threading.Event()
 
452
        detached = threading.Event()
525
453
        stopped = threading.Event()
526
 
        t = ThreadWithException(
527
 
            event=stopped,
 
454
        t = TestThread(
 
455
            sync_event=stopped,
528
456
            name='%s -> %s' % (client_address, self.server_address),
529
457
            target = self.process_request_thread,
530
 
            args = (started, stopped, request, client_address))
 
458
            args = (started, detached, stopped, request, client_address))
531
459
        # Update the client description
532
460
        self.clients.pop()
533
461
        self.clients.append((request, client_address, t))
534
 
        # Propagate the exception handler since we must use the same one for
535
 
        # connections running in their own threads than TestingTCPServer.
 
462
        # Propagate the exception handler since we must use the same one as
 
463
        # TestingTCPServer for connections running in their own threads.
536
464
        t.set_ignored_exceptions(self.ignored_exceptions)
537
465
        t.start()
538
466
        started.wait()
539
 
        if debug_threads():
540
 
            sys.stderr.write('Client thread %s started\n' % (t.name,))
541
467
        # If an exception occured during the thread start, it will get raised.
542
468
        t.pending_exception()
 
469
        if debug_threads():
 
470
            sys.stderr.write('Client thread %s started\n' % (t.name,))
 
471
        # Tell the thread, it's now on its own for exception handling.
 
472
        detached.set()
543
473
 
544
474
    # The following methods are called by the main thread
545
475
 
590
520
 
591
521
    def start_server(self):
592
522
        self.server = self.create_server()
593
 
        self._server_thread = ThreadWithException(
594
 
            event=self.server.started,
 
523
        self._server_thread = TestThread(
 
524
            sync_event=self.server.started,
595
525
            target=self.run_server)
596
526
        self._server_thread.start()
597
 
        # Wait for the server thread to start (i.e release the lock)
 
527
        # Wait for the server thread to start (i.e. release the lock)
598
528
        self.server.started.wait()
599
529
        # Get the real address, especially the port
600
530
        self.host, self.port = self.server.server_address
607
537
        self._server_thread.pending_exception()
608
538
        # From now on, we'll use a different event to ensure the server can set
609
539
        # its exception
610
 
        self._server_thread.set_ready_event(self.server.stopped)
 
540
        self._server_thread.set_sync_event(self.server.stopped)
611
541
 
612
542
    def run_server(self):
613
543
        self.server.serve()
634
564
                # server thread, it may happen that it's not blocked or even
635
565
                # not started.
636
566
                pass
637
 
            # We start shutting down the client while the server itself is
 
567
            # We start shutting down the clients while the server itself is
638
568
            # shutting down.
639
569
            self.server.stop_client_connections()
640
570
            # Now we wait for the thread running self.server.serve() to finish
675
605
    def __init__(self, request, client_address, server):
676
606
        medium.SmartServerSocketStreamMedium.__init__(
677
607
            self, request, server.backing_transport,
678
 
            server.root_client_path)
 
608
            server.root_client_path,
 
609
            timeout=_DEFAULT_TESTING_CLIENT_TIMEOUT)
679
610
        request.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
680
611
        SocketServer.BaseRequestHandler.__init__(self, request, client_address,
681
612
                                                 server)
682
613
 
683
614
    def handle(self):
684
 
        while not self.finished:
685
 
            server_protocol = self._build_protocol()
686
 
            self._serve_one_request(server_protocol)
687
 
 
 
615
        try:
 
616
            while not self.finished:
 
617
                server_protocol = self._build_protocol()
 
618
                self._serve_one_request(server_protocol)
 
619
        except errors.ConnectionTimeout:
 
620
            # idle connections aren't considered a failure of the server
 
621
            return
 
622
 
 
623
 
 
624
_DEFAULT_TESTING_CLIENT_TIMEOUT = 60.0
688
625
 
689
626
class TestingSmartServer(TestingThreadingTCPServer, server.SmartTCPServer):
690
627
 
693
630
        TestingThreadingTCPServer.__init__(self, server_address,
694
631
                                           request_handler_class)
695
632
        server.SmartTCPServer.__init__(self, backing_transport,
696
 
                                       root_client_path)
 
633
            root_client_path, client_timeout=_DEFAULT_TESTING_CLIENT_TIMEOUT)
 
634
 
697
635
    def serve(self):
698
 
        # FIXME: No test are exercising the hooks for the test server
699
 
        # -- vila 20100618
700
636
        self.run_server_started_hooks()
701
637
        try:
702
638
            TestingThreadingTCPServer.serve(self)
753
689
        self.chroot_server = ChrootServer(
754
690
            self.get_backing_transport(backing_transport_server))
755
691
        self.chroot_server.start_server()
756
 
        self.backing_transport = transport.get_transport(
 
692
        self.backing_transport = transport.get_transport_from_url(
757
693
            self.chroot_server.get_url())
758
694
        super(SmartTCPServer_for_testing, self).start_server()
759
695
 
765
701
 
766
702
    def get_backing_transport(self, backing_transport_server):
767
703
        """Get a backing transport from a server we are decorating."""
768
 
        return transport.get_transport(backing_transport_server.get_url())
 
704
        return transport.get_transport_from_url(
 
705
            backing_transport_server.get_url())
769
706
 
770
707
    def get_url(self):
771
708
        url = self.server.get_url()
782
719
    def get_backing_transport(self, backing_transport_server):
783
720
        """Get a backing transport from a server we are decorating."""
784
721
        url = 'readonly+' + backing_transport_server.get_url()
785
 
        return transport.get_transport(url)
 
722
        return transport.get_transport_from_url(url)
786
723
 
787
724
 
788
725
class SmartTCPServer_for_testing_v2_only(SmartTCPServer_for_testing):
803
740
    def get_backing_transport(self, backing_transport_server):
804
741
        """Get a backing transport from a server we are decorating."""
805
742
        url = 'readonly+' + backing_transport_server.get_url()
806
 
        return transport.get_transport(url)
807
 
 
808
 
 
809
 
 
810
 
 
 
743
        return transport.get_transport_from_url(url)