~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/tests/test_server.py

  • Committer: John Arbash Meinel
  • Date: 2011-10-03 14:15:44 UTC
  • mto: This revision was merged to the branch mainline in revision 6186.
  • Revision ID: john@arbash-meinel.com-20111003141544-2upoh3swgxqerfv7
Separate the comments, for vila.

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2005, 2006, 2007, 2008, 2010 Canonical Ltd
 
1
# Copyright (C) 2010, 2011 Canonical Ltd
2
2
#
3
3
# This program is free software; you can redistribute it and/or modify
4
4
# it under the terms of the GNU General Public License as published by
14
14
# along with this program; if not, write to the Free Software
15
15
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
16
16
 
 
17
import errno
 
18
import socket
 
19
import SocketServer
 
20
import sys
 
21
import threading
 
22
import traceback
 
23
 
 
24
 
17
25
from bzrlib import (
 
26
    cethread,
 
27
    errors,
 
28
    osutils,
18
29
    transport,
19
30
    urlutils,
20
31
    )
22
33
    chroot,
23
34
    pathfilter,
24
35
    )
25
 
from bzrlib.smart import server
 
36
from bzrlib.smart import (
 
37
    medium,
 
38
    server,
 
39
    )
 
40
 
 
41
 
 
42
def debug_threads():
 
43
    # FIXME: There is a dependency loop between bzrlib.tests and
 
44
    # bzrlib.tests.test_server that needs to be fixed. In the mean time
 
45
    # defining this function is enough for our needs. -- vila 20100611
 
46
    from bzrlib import tests
 
47
    return 'threads' in tests.selftest_debug_flags
26
48
 
27
49
 
28
50
class TestServer(transport.Server):
192
214
    def start_server(self, backing_server=None):
193
215
        """Setup the Chroot on backing_server."""
194
216
        if backing_server is not None:
195
 
            self.backing_transport = transport.get_transport(
 
217
            self.backing_transport = transport.get_transport_from_url(
196
218
                backing_server.get_url())
197
219
        else:
198
 
            self.backing_transport = transport.get_transport('.')
 
220
            self.backing_transport = transport.get_transport_from_path('.')
199
221
        self.backing_transport.clone('added-by-filter').ensure_base()
200
222
        self.filter_func = lambda x: 'added-by-filter/' + x
201
223
        super(TestingPathFilteringServer, self).start_server()
213
235
    def start_server(self, backing_server=None):
214
236
        """Setup the Chroot on backing_server."""
215
237
        if backing_server is not None:
216
 
            self.backing_transport = transport.get_transport(
 
238
            self.backing_transport = transport.get_transport_from_url(
217
239
                backing_server.get_url())
218
240
        else:
219
 
            self.backing_transport = transport.get_transport('.')
 
241
            self.backing_transport = transport.get_transport_from_path('.')
220
242
        super(TestingChrootServer, self).start_server()
221
243
 
222
244
    def get_bogus_url(self):
223
245
        raise NotImplementedError
224
246
 
225
247
 
226
 
class SmartTCPServer_for_testing(server.SmartTCPServer):
 
248
class TestThread(cethread.CatchingExceptionThread):
 
249
 
 
250
    def join(self, timeout=5):
 
251
        """Overrides to use a default timeout.
 
252
 
 
253
        The default timeout is set to 5 and should expire only when a thread
 
254
        serving a client connection is hung.
 
255
        """
 
256
        super(TestThread, self).join(timeout)
 
257
        if timeout and self.isAlive():
 
258
            # The timeout expired without joining the thread, the thread is
 
259
            # therefore stucked and that's a failure as far as the test is
 
260
            # concerned. We used to hang here.
 
261
 
 
262
            # FIXME: we need to kill the thread, but as far as the test is
 
263
            # concerned, raising an assertion is too strong. On most of the
 
264
            # platforms, this doesn't occur, so just mentioning the problem is
 
265
            # enough for now -- vila 2010824
 
266
            sys.stderr.write('thread %s hung\n' % (self.name,))
 
267
            #raise AssertionError('thread %s hung' % (self.name,))
 
268
 
 
269
 
 
270
class TestingTCPServerMixin(object):
 
271
    """Mixin to support running SocketServer.TCPServer in a thread.
 
272
 
 
273
    Tests are connecting from the main thread, the server has to be run in a
 
274
    separate thread.
 
275
    """
 
276
 
 
277
    def __init__(self):
 
278
        self.started = threading.Event()
 
279
        self.serving = None
 
280
        self.stopped = threading.Event()
 
281
        # We collect the resources used by the clients so we can release them
 
282
        # when shutting down
 
283
        self.clients = []
 
284
        self.ignored_exceptions = None
 
285
 
 
286
    def server_bind(self):
 
287
        self.socket.bind(self.server_address)
 
288
        self.server_address = self.socket.getsockname()
 
289
 
 
290
    def serve(self):
 
291
        self.serving = True
 
292
        self.stopped.clear()
 
293
        # We are listening and ready to accept connections
 
294
        self.started.set()
 
295
        try:
 
296
            while self.serving:
 
297
                # Really a connection but the python framework is generic and
 
298
                # call them requests
 
299
                self.handle_request()
 
300
            # Let's close the listening socket
 
301
            self.server_close()
 
302
        finally:
 
303
            self.stopped.set()
 
304
 
 
305
    def handle_request(self):
 
306
        """Handle one request.
 
307
 
 
308
        The python version swallows some socket exceptions and we don't use
 
309
        timeout, so we override it to better control the server behavior.
 
310
        """
 
311
        request, client_address = self.get_request()
 
312
        if self.verify_request(request, client_address):
 
313
            try:
 
314
                self.process_request(request, client_address)
 
315
            except:
 
316
                self.handle_error(request, client_address)
 
317
        else:
 
318
            self.close_request(request)
 
319
 
 
320
    def get_request(self):
 
321
        return self.socket.accept()
 
322
 
 
323
    def verify_request(self, request, client_address):
 
324
        """Verify the request.
 
325
 
 
326
        Return True if we should proceed with this request, False if we should
 
327
        not even touch a single byte in the socket ! This is useful when we
 
328
        stop the server with a dummy last connection.
 
329
        """
 
330
        return self.serving
 
331
 
 
332
    def handle_error(self, request, client_address):
 
333
        # Stop serving and re-raise the last exception seen
 
334
        self.serving = False
 
335
        # The following can be used for debugging purposes, it will display the
 
336
        # exception and the traceback just when it occurs instead of waiting
 
337
        # for the thread to be joined.
 
338
        # SocketServer.BaseServer.handle_error(self, request, client_address)
 
339
 
 
340
        # We call close_request manually, because we are going to raise an
 
341
        # exception. The SocketServer implementation calls:
 
342
        #   handle_error(...)
 
343
        #   close_request(...)
 
344
        # But because we raise the exception, close_request will never be
 
345
        # triggered. This helps client not block waiting for a response when
 
346
        # the server gets an exception.
 
347
        self.close_request(request)
 
348
        raise
 
349
 
 
350
    def ignored_exceptions_during_shutdown(self, e):
 
351
        if sys.platform == 'win32':
 
352
            accepted_errnos = [errno.EBADF,
 
353
                               errno.EPIPE,
 
354
                               errno.WSAEBADF,
 
355
                               errno.WSAECONNRESET,
 
356
                               errno.WSAENOTCONN,
 
357
                               errno.WSAESHUTDOWN,
 
358
                               ]
 
359
        else:
 
360
            accepted_errnos = [errno.EBADF,
 
361
                               errno.ECONNRESET,
 
362
                               errno.ENOTCONN,
 
363
                               errno.EPIPE,
 
364
                               ]
 
365
        if isinstance(e, socket.error) and e[0] in accepted_errnos:
 
366
            return True
 
367
        return False
 
368
 
 
369
    # The following methods are called by the main thread
 
370
 
 
371
    def stop_client_connections(self):
 
372
        while self.clients:
 
373
            c = self.clients.pop()
 
374
            self.shutdown_client(c)
 
375
 
 
376
    def shutdown_socket(self, sock):
 
377
        """Properly shutdown a socket.
 
378
 
 
379
        This should be called only when no other thread is trying to use the
 
380
        socket.
 
381
        """
 
382
        try:
 
383
            sock.shutdown(socket.SHUT_RDWR)
 
384
            sock.close()
 
385
        except Exception, e:
 
386
            if self.ignored_exceptions(e):
 
387
                pass
 
388
            else:
 
389
                raise
 
390
 
 
391
    # The following methods are called by the main thread
 
392
 
 
393
    def set_ignored_exceptions(self, thread, ignored_exceptions):
 
394
        self.ignored_exceptions = ignored_exceptions
 
395
        thread.set_ignored_exceptions(self.ignored_exceptions)
 
396
 
 
397
    def _pending_exception(self, thread):
 
398
        """Raise server uncaught exception.
 
399
 
 
400
        Daughter classes can override this if they use daughter threads.
 
401
        """
 
402
        thread.pending_exception()
 
403
 
 
404
 
 
405
class TestingTCPServer(TestingTCPServerMixin, SocketServer.TCPServer):
 
406
 
 
407
    def __init__(self, server_address, request_handler_class):
 
408
        TestingTCPServerMixin.__init__(self)
 
409
        SocketServer.TCPServer.__init__(self, server_address,
 
410
                                        request_handler_class)
 
411
 
 
412
    def get_request(self):
 
413
        """Get the request and client address from the socket."""
 
414
        sock, addr = TestingTCPServerMixin.get_request(self)
 
415
        self.clients.append((sock, addr))
 
416
        return sock, addr
 
417
 
 
418
    # The following methods are called by the main thread
 
419
 
 
420
    def shutdown_client(self, client):
 
421
        sock, addr = client
 
422
        self.shutdown_socket(sock)
 
423
 
 
424
 
 
425
class TestingThreadingTCPServer(TestingTCPServerMixin,
 
426
                                SocketServer.ThreadingTCPServer):
 
427
 
 
428
    def __init__(self, server_address, request_handler_class):
 
429
        TestingTCPServerMixin.__init__(self)
 
430
        SocketServer.ThreadingTCPServer.__init__(self, server_address,
 
431
                                                 request_handler_class)
 
432
 
 
433
    def get_request(self):
 
434
        """Get the request and client address from the socket."""
 
435
        sock, addr = TestingTCPServerMixin.get_request(self)
 
436
        # The thread is not create yet, it will be updated in process_request
 
437
        self.clients.append((sock, addr, None))
 
438
        return sock, addr
 
439
 
 
440
    def process_request_thread(self, started, stopped, request, client_address):
 
441
        started.set()
 
442
        SocketServer.ThreadingTCPServer.process_request_thread(
 
443
            self, request, client_address)
 
444
        self.close_request(request)
 
445
        stopped.set()
 
446
 
 
447
    def process_request(self, request, client_address):
 
448
        """Start a new thread to process the request."""
 
449
        started = threading.Event()
 
450
        stopped = threading.Event()
 
451
        t = TestThread(
 
452
            sync_event=stopped,
 
453
            name='%s -> %s' % (client_address, self.server_address),
 
454
            target=self.process_request_thread,
 
455
            args=(started, stopped, request, client_address))
 
456
        # Update the client description
 
457
        self.clients.pop()
 
458
        self.clients.append((request, client_address, t))
 
459
        # Propagate the exception handler since we must use the same one as
 
460
        # TestingTCPServer for connections running in their own threads.
 
461
        t.set_ignored_exceptions(self.ignored_exceptions)
 
462
        t.start()
 
463
        started.wait()
 
464
        if debug_threads():
 
465
            sys.stderr.write('Client thread %s started\n' % (t.name,))
 
466
        # If an exception occured during the thread start, it will get raised.
 
467
        t.pending_exception()
 
468
 
 
469
    # The following methods are called by the main thread
 
470
 
 
471
    def shutdown_client(self, client):
 
472
        sock, addr, connection_thread = client
 
473
        self.shutdown_socket(sock)
 
474
        if connection_thread is not None:
 
475
            # The thread has been created only if the request is processed but
 
476
            # after the connection is inited. This could happen during server
 
477
            # shutdown. If an exception occurred in the thread it will be
 
478
            # re-raised
 
479
            if debug_threads():
 
480
                sys.stderr.write('Client thread %s will be joined\n'
 
481
                                 % (connection_thread.name,))
 
482
            connection_thread.join()
 
483
 
 
484
    def set_ignored_exceptions(self, thread, ignored_exceptions):
 
485
        TestingTCPServerMixin.set_ignored_exceptions(self, thread,
 
486
                                                     ignored_exceptions)
 
487
        for sock, addr, connection_thread in self.clients:
 
488
            if connection_thread is not None:
 
489
                connection_thread.set_ignored_exceptions(
 
490
                    self.ignored_exceptions)
 
491
 
 
492
    def _pending_exception(self, thread):
 
493
        for sock, addr, connection_thread in self.clients:
 
494
            if connection_thread is not None:
 
495
                connection_thread.pending_exception()
 
496
        TestingTCPServerMixin._pending_exception(self, thread)
 
497
 
 
498
 
 
499
class TestingTCPServerInAThread(transport.Server):
 
500
    """A server in a thread that re-raise thread exceptions."""
 
501
 
 
502
    def __init__(self, server_address, server_class, request_handler_class):
 
503
        self.server_class = server_class
 
504
        self.request_handler_class = request_handler_class
 
505
        self.host, self.port = server_address
 
506
        self.server = None
 
507
        self._server_thread = None
 
508
 
 
509
    def __repr__(self):
 
510
        return "%s(%s:%s)" % (self.__class__.__name__, self.host, self.port)
 
511
 
 
512
    def create_server(self):
 
513
        return self.server_class((self.host, self.port),
 
514
                                 self.request_handler_class)
 
515
 
 
516
    def start_server(self):
 
517
        self.server = self.create_server()
 
518
        self._server_thread = TestThread(
 
519
            sync_event=self.server.started,
 
520
            target=self.run_server)
 
521
        self._server_thread.start()
 
522
        # Wait for the server thread to start (i.e release the lock)
 
523
        self.server.started.wait()
 
524
        # Get the real address, especially the port
 
525
        self.host, self.port = self.server.server_address
 
526
        self._server_thread.name = self.server.server_address
 
527
        if debug_threads():
 
528
            sys.stderr.write('Server thread %s started\n'
 
529
                             % (self._server_thread.name,))
 
530
        # If an exception occured during the server start, it will get raised,
 
531
        # otherwise, the server is blocked on its accept() call.
 
532
        self._server_thread.pending_exception()
 
533
        # From now on, we'll use a different event to ensure the server can set
 
534
        # its exception
 
535
        self._server_thread.set_sync_event(self.server.stopped)
 
536
 
 
537
    def run_server(self):
 
538
        self.server.serve()
 
539
 
 
540
    def stop_server(self):
 
541
        if self.server is None:
 
542
            return
 
543
        try:
 
544
            # The server has been started successfully, shut it down now.  As
 
545
            # soon as we stop serving, no more connection are accepted except
 
546
            # one to get out of the blocking listen.
 
547
            self.set_ignored_exceptions(
 
548
                self.server.ignored_exceptions_during_shutdown)
 
549
            self.server.serving = False
 
550
            if debug_threads():
 
551
                sys.stderr.write('Server thread %s will be joined\n'
 
552
                                 % (self._server_thread.name,))
 
553
            # The server is listening for a last connection, let's give it:
 
554
            last_conn = None
 
555
            try:
 
556
                last_conn = osutils.connect_socket((self.host, self.port))
 
557
            except socket.error, e:
 
558
                # But ignore connection errors as the point is to unblock the
 
559
                # server thread, it may happen that it's not blocked or even
 
560
                # not started.
 
561
                pass
 
562
            # We start shutting down the clients while the server itself is
 
563
            # shutting down.
 
564
            self.server.stop_client_connections()
 
565
            # Now we wait for the thread running self.server.serve() to finish
 
566
            self.server.stopped.wait()
 
567
            if last_conn is not None:
 
568
                # Close the last connection without trying to use it. The
 
569
                # server will not process a single byte on that socket to avoid
 
570
                # complications (SSL starts with a handshake for example).
 
571
                last_conn.close()
 
572
            # Check for any exception that could have occurred in the server
 
573
            # thread
 
574
            try:
 
575
                self._server_thread.join()
 
576
            except Exception, e:
 
577
                if self.server.ignored_exceptions(e):
 
578
                    pass
 
579
                else:
 
580
                    raise
 
581
        finally:
 
582
            # Make sure we can be called twice safely, note that this means
 
583
            # that we will raise a single exception even if several occurred in
 
584
            # the various threads involved.
 
585
            self.server = None
 
586
 
 
587
    def set_ignored_exceptions(self, ignored_exceptions):
 
588
        """Install an exception handler for the server."""
 
589
        self.server.set_ignored_exceptions(self._server_thread,
 
590
                                           ignored_exceptions)
 
591
 
 
592
    def pending_exception(self):
 
593
        """Raise uncaught exception in the server."""
 
594
        self.server._pending_exception(self._server_thread)
 
595
 
 
596
 
 
597
class TestingSmartConnectionHandler(SocketServer.BaseRequestHandler,
 
598
                                    medium.SmartServerSocketStreamMedium):
 
599
 
 
600
    def __init__(self, request, client_address, server):
 
601
        medium.SmartServerSocketStreamMedium.__init__(
 
602
            self, request, server.backing_transport,
 
603
            server.root_client_path,
 
604
            timeout=_DEFAULT_TESTING_CLIENT_TIMEOUT)
 
605
        request.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
 
606
        SocketServer.BaseRequestHandler.__init__(self, request, client_address,
 
607
                                                 server)
 
608
 
 
609
    def handle(self):
 
610
        try:
 
611
            while not self.finished:
 
612
                server_protocol = self._build_protocol()
 
613
                self._serve_one_request(server_protocol)
 
614
        except errors.ConnectionTimeout:
 
615
            # idle connections aren't considered a failure of the server
 
616
            return
 
617
 
 
618
 
 
619
_DEFAULT_TESTING_CLIENT_TIMEOUT = 4.0
 
620
 
 
621
class TestingSmartServer(TestingThreadingTCPServer, server.SmartTCPServer):
 
622
 
 
623
    def __init__(self, server_address, request_handler_class,
 
624
                 backing_transport, root_client_path):
 
625
        TestingThreadingTCPServer.__init__(self, server_address,
 
626
                                           request_handler_class)
 
627
        server.SmartTCPServer.__init__(self, backing_transport,
 
628
            root_client_path, client_timeout=_DEFAULT_TESTING_CLIENT_TIMEOUT)
 
629
 
 
630
    def serve(self):
 
631
        self.run_server_started_hooks()
 
632
        try:
 
633
            TestingThreadingTCPServer.serve(self)
 
634
        finally:
 
635
            self.run_server_stopped_hooks()
 
636
 
 
637
    def get_url(self):
 
638
        """Return the url of the server"""
 
639
        return "bzr://%s:%d/" % self.server_address
 
640
 
 
641
 
 
642
class SmartTCPServer_for_testing(TestingTCPServerInAThread):
227
643
    """Server suitable for use by transport tests.
228
644
 
229
645
    This server is backed by the process's cwd.
230
646
    """
231
 
 
232
647
    def __init__(self, thread_name_suffix=''):
233
 
        super(SmartTCPServer_for_testing, self).__init__(None)
234
648
        self.client_path_extra = None
235
649
        self.thread_name_suffix = thread_name_suffix
236
 
 
237
 
    def get_backing_transport(self, backing_transport_server):
238
 
        """Get a backing transport from a server we are decorating."""
239
 
        return transport.get_transport(backing_transport_server.get_url())
 
650
        self.host = '127.0.0.1'
 
651
        self.port = 0
 
652
        super(SmartTCPServer_for_testing, self).__init__(
 
653
                (self.host, self.port),
 
654
                TestingSmartServer,
 
655
                TestingSmartConnectionHandler)
 
656
 
 
657
    def create_server(self):
 
658
        return self.server_class((self.host, self.port),
 
659
                                 self.request_handler_class,
 
660
                                 self.backing_transport,
 
661
                                 self.root_client_path)
 
662
 
240
663
 
241
664
    def start_server(self, backing_transport_server=None,
242
 
              client_path_extra='/extra/'):
 
665
                     client_path_extra='/extra/'):
243
666
        """Set up server for testing.
244
667
 
245
668
        :param backing_transport_server: backing server to use.  If not
254
677
        """
255
678
        if not client_path_extra.startswith('/'):
256
679
            raise ValueError(client_path_extra)
 
680
        self.root_client_path = self.client_path_extra = client_path_extra
257
681
        from bzrlib.transport.chroot import ChrootServer
258
682
        if backing_transport_server is None:
259
683
            backing_transport_server = LocalURLServer()
260
684
        self.chroot_server = ChrootServer(
261
685
            self.get_backing_transport(backing_transport_server))
262
686
        self.chroot_server.start_server()
263
 
        self.backing_transport = transport.get_transport(
 
687
        self.backing_transport = transport.get_transport_from_url(
264
688
            self.chroot_server.get_url())
265
 
        self.root_client_path = self.client_path_extra = client_path_extra
266
 
        self.start_background_thread(self.thread_name_suffix)
 
689
        super(SmartTCPServer_for_testing, self).start_server()
267
690
 
268
691
    def stop_server(self):
269
 
        self.stop_background_thread()
270
 
        self.chroot_server.stop_server()
 
692
        try:
 
693
            super(SmartTCPServer_for_testing, self).stop_server()
 
694
        finally:
 
695
            self.chroot_server.stop_server()
 
696
 
 
697
    def get_backing_transport(self, backing_transport_server):
 
698
        """Get a backing transport from a server we are decorating."""
 
699
        return transport.get_transport_from_url(
 
700
            backing_transport_server.get_url())
271
701
 
272
702
    def get_url(self):
273
 
        url = super(SmartTCPServer_for_testing, self).get_url()
 
703
        url = self.server.get_url()
274
704
        return url[:-1] + self.client_path_extra
275
705
 
276
706
    def get_bogus_url(self):
284
714
    def get_backing_transport(self, backing_transport_server):
285
715
        """Get a backing transport from a server we are decorating."""
286
716
        url = 'readonly+' + backing_transport_server.get_url()
287
 
        return transport.get_transport(url)
 
717
        return transport.get_transport_from_url(url)
288
718
 
289
719
 
290
720
class SmartTCPServer_for_testing_v2_only(SmartTCPServer_for_testing):
305
735
    def get_backing_transport(self, backing_transport_server):
306
736
        """Get a backing transport from a server we are decorating."""
307
737
        url = 'readonly+' + backing_transport_server.get_url()
308
 
        return transport.get_transport(url)
309
 
 
310
 
 
311
 
 
312
 
 
 
738
        return transport.get_transport_from_url(url)