1
# Copyright (C) 2010, 2011 Canonical Ltd
3
# This program is free software; you can redistribute it and/or modify
4
# it under the terms of the GNU General Public License as published by
5
# the Free Software Foundation; either version 2 of the License, or
6
# (at your option) any later version.
8
# This program is distributed in the hope that it will be useful,
9
# but WITHOUT ANY WARRANTY; without even the implied warranty of
10
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11
# GNU General Public License for more details.
13
# You should have received a copy of the GNU General Public License
14
# along with this program; if not, write to the Free Software
15
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
30
from bzrlib.transport import (
34
from bzrlib.smart import (
41
# FIXME: There is a dependency loop between bzrlib.tests and
42
# bzrlib.tests.test_server that needs to be fixed. In the mean time
43
# defining this function is enough for our needs. -- vila 20100611
44
from bzrlib import tests
45
return 'threads' in tests.selftest_debug_flags
48
class TestServer(transport.Server):
49
"""A Transport Server dedicated to tests.
51
The TestServer interface provides a server for a given transport. We use
52
these servers as loopback testing tools. For any given transport the
53
Servers it provides must either allow writing, or serve the contents
54
of os.getcwdu() at the time start_server is called.
56
Note that these are real servers - they must implement all the things
57
that we want bzr transports to take advantage of.
61
"""Return a url for this server.
63
If the transport does not represent a disk directory (i.e. it is
64
a database like svn, or a memory only transport, it should return
65
a connection to a newly established resource for this Server.
66
Otherwise it should return a url that will provide access to the path
67
that was os.getcwdu() when start_server() was called.
69
Subsequent calls will return the same resource.
71
raise NotImplementedError
73
def get_bogus_url(self):
    """Return a URL for this protocol that cannot be connected to.

    Servers unable to supply such a URL signal it by raising
    NotImplementedError, which is all this base implementation does.
    """
    raise NotImplementedError()
82
class LocalURLServer(TestServer):
83
"""A pretend server for local transports, using file:// urls.
85
Of course no actual server is required to access the local filesystem, so
86
this just exists to tell the test code how to get to it.
89
def start_server(self):
93
"""See Transport.Server.get_url."""
94
return urlutils.local_path_to_url('')
97
class DecoratorServer(TestServer):
98
"""Server for the TransportDecorator for testing with.
100
To use this when subclassing TransportDecorator, override override the
101
get_decorator_class method.
104
def start_server(self, server=None):
105
"""See bzrlib.transport.Server.start_server.
107
:server: decorate the urls given by server. If not provided a
108
LocalServer is created.
110
if server is not None:
111
self._made_server = False
112
self._server = server
114
self._made_server = True
115
self._server = LocalURLServer()
116
self._server.start_server()
118
def stop_server(self):
    """Stop the wrapped server, but only when this decorator created it.

    A server supplied by the caller of start_server() is left running; its
    lifetime belongs to that caller.
    """
    if not self._made_server:
        return
    self._server.stop_server()
122
def get_decorator_class(self):
    """Return the transport decorator class this server builds URLs for.

    Abstract hook: concrete decorator servers must override it.
    """
    unimplemented = self.get_decorator_class
    raise NotImplementedError(unimplemented)
126
def get_url_prefix(self):
    """Return the URL prefix that the decorator class prepends."""
    decorator_class = self.get_decorator_class()
    return decorator_class._get_url_prefix()
130
def get_bogus_url(self):
    """Return the wrapped server's bogus URL behind this decorator's prefix."""
    prefix = self.get_url_prefix()
    return prefix + self._server.get_bogus_url()
135
"""See bzrlib.transport.Server.get_url."""
136
return self.get_url_prefix() + self._server.get_url()
139
class BrokenRenameServer(DecoratorServer):
    """Test server handing out BrokenRenameTransportDecorator URLs."""

    def get_decorator_class(self):
        """Return BrokenRenameTransportDecorator, imported on first use."""
        from bzrlib.transport.brokenrename import (
            BrokenRenameTransportDecorator)
        return BrokenRenameTransportDecorator
147
class FakeNFSServer(DecoratorServer):
    """Test server handing out FakeNFSTransportDecorator URLs."""

    def get_decorator_class(self):
        """Return FakeNFSTransportDecorator, imported on first use."""
        from bzrlib.transport.fakenfs import FakeNFSTransportDecorator
        return FakeNFSTransportDecorator
155
class FakeVFATServer(DecoratorServer):
    """Test server whose URLs route connections through the fake-VFAT decorator.

    The decorator used is FakeVFATTransportDecorator.
    """

    def get_decorator_class(self):
        """Return FakeVFATTransportDecorator, imported on first use."""
        from bzrlib.transport.fakevfat import FakeVFATTransportDecorator
        return FakeVFATTransportDecorator
166
class LogDecoratorServer(DecoratorServer):
    """Test server handing out TransportLogDecorator URLs."""

    def get_decorator_class(self):
        """Return TransportLogDecorator, imported on first use."""
        from bzrlib.transport.log import TransportLogDecorator
        return TransportLogDecorator
174
class NoSmartTransportServer(DecoratorServer):
    """Test server handing out NoSmartTransportDecorator URLs."""

    def get_decorator_class(self):
        """Return NoSmartTransportDecorator, imported on first use."""
        from bzrlib.transport.nosmart import NoSmartTransportDecorator
        return NoSmartTransportDecorator
182
class ReadonlyServer(DecoratorServer):
    """Test server handing out ReadonlyTransportDecorator URLs."""

    def get_decorator_class(self):
        """Return ReadonlyTransportDecorator, imported on first use."""
        from bzrlib.transport.readonly import ReadonlyTransportDecorator
        return ReadonlyTransportDecorator
190
class TraceServer(DecoratorServer):
    """Test server handing out TransportTraceDecorator URLs."""

    def get_decorator_class(self):
        """Return TransportTraceDecorator, imported on first use."""
        from bzrlib.transport.trace import TransportTraceDecorator
        return TransportTraceDecorator
198
class UnlistableServer(DecoratorServer):
    """Test server handing out UnlistableTransportDecorator URLs."""

    def get_decorator_class(self):
        """Return UnlistableTransportDecorator, imported on first use."""
        from bzrlib.transport.unlistable import UnlistableTransportDecorator
        return UnlistableTransportDecorator
206
class TestingPathFilteringServer(pathfilter.PathFilteringServer):
209
"""TestingPathFilteringServer is not usable until start_server
212
def start_server(self, backing_server=None):
213
"""Setup the Chroot on backing_server."""
214
if backing_server is not None:
215
self.backing_transport = transport.get_transport_from_url(
216
backing_server.get_url())
218
self.backing_transport = transport.get_transport_from_path('.')
219
self.backing_transport.clone('added-by-filter').ensure_base()
220
self.filter_func = lambda x: 'added-by-filter/' + x
221
super(TestingPathFilteringServer, self).start_server()
223
def get_bogus_url(self):
224
raise NotImplementedError
227
class TestingChrootServer(chroot.ChrootServer):
230
"""TestingChrootServer is not usable until start_server is called."""
231
super(TestingChrootServer, self).__init__(None)
233
def start_server(self, backing_server=None):
234
"""Setup the Chroot on backing_server."""
235
if backing_server is not None:
236
self.backing_transport = transport.get_transport_from_url(
237
backing_server.get_url())
239
self.backing_transport = transport.get_transport_from_path('.')
240
super(TestingChrootServer, self).start_server()
242
def get_bogus_url(self):
243
raise NotImplementedError
246
class TestThread(cethread.CatchingExceptionThread):

    def join(self, timeout=5):
        """Join the thread, waiting at most ``timeout`` by default.

        The default of 5 should only be exceeded when a thread serving a
        client connection is hung. In that case the problem is reported on
        stderr instead of blocking the test forever.
        """
        super(TestThread, self).join(timeout)
        if not timeout:
            # A falsy timeout disables the hang diagnostic.
            return
        if not self.isAlive():
            return
        # The wait expired and the thread is still running, so it is stuck.
        # Ideally it would be killed and the test failed. Raising an assertion
        # was judged too strong, since this rarely happens on most platforms,
        # so for now the hang is only reported. -- vila 2010824
        sys.stderr.write('thread %s hung\n' % (self.name,))
268
class TestingTCPServerMixin(object):
269
"""Mixin to support running SocketServer.TCPServer in a thread.
271
Tests are connecting from the main thread, the server has to be run in a
276
self.started = threading.Event()
278
self.stopped = threading.Event()
279
# We collect the resources used by the clients so we can release them
282
self.ignored_exceptions = None
284
def server_bind(self):
    """Bind the listening socket and record the address actually obtained.

    Re-reading the address through getsockname() makes an OS-chosen port
    (e.g. when binding to port 0) visible to callers.
    """
    listening = self.socket
    listening.bind(self.server_address)
    self.server_address = listening.getsockname()
291
# We are listening and ready to accept connections
295
# Really a connection but the python framework is generic and
297
self.handle_request()
298
# Let's close the listening socket
303
def handle_request(self):
304
"""Handle one request.
306
The python version swallows some socket exceptions and we don't use
307
timeout, so we override it to better control the server behavior.
309
request, client_address = self.get_request()
310
if self.verify_request(request, client_address):
312
self.process_request(request, client_address)
314
self.handle_error(request, client_address)
315
self.close_request(request)
317
def get_request(self):
    """Block until a client connects and return its (socket, address) pair."""
    connection = self.socket.accept()
    return connection
320
def verify_request(self, request, client_address):
321
"""Verify the request.
323
Return True if we should proceed with this request, False if we should
324
not even touch a single byte in the socket ! This is useful when we
325
stop the server with a dummy last connection.
329
def handle_error(self, request, client_address):
330
# Stop serving and re-raise the last exception seen
332
# The following can be used for debugging purposes, it will display the
333
# exception and the traceback just when it occurs instead of waiting
334
# for the thread to be joined.
336
# SocketServer.BaseServer.handle_error(self, request, client_address)
339
def ignored_exceptions_during_shutdown(self, e):
340
if sys.platform == 'win32':
341
accepted_errnos = [errno.EBADF,
349
accepted_errnos = [errno.EBADF,
354
if isinstance(e, socket.error) and e[0] in accepted_errnos:
358
# The following methods are called by the main thread
360
def stop_client_connections(self):
362
c = self.clients.pop()
363
self.shutdown_client(c)
365
def shutdown_socket(self, sock):
366
"""Properly shutdown a socket.
368
This should be called only when no other thread is trying to use the
372
sock.shutdown(socket.SHUT_RDWR)
375
if self.ignored_exceptions(e):
380
# The following methods are called by the main thread
382
def set_ignored_exceptions(self, thread, ignored_exceptions):
    """Store the exception filter and install it on the server thread.

    :param thread: the thread running this server.
    :param ignored_exceptions: callable deciding whether an exception
        should be ignored; kept on the server as ``ignored_exceptions``.
    """
    self.ignored_exceptions = ignored_exceptions
    thread.set_ignored_exceptions(self.ignored_exceptions)
386
def _pending_exception(self, thread):
    """Re-raise any exception left uncaught by the server thread.

    Subclasses that spawn extra per-connection threads override this so
    those threads are checked as well.
    """
    thread.pending_exception()
394
class TestingTCPServer(TestingTCPServerMixin, SocketServer.TCPServer):
396
def __init__(self, server_address, request_handler_class):
    # Mixin state (events, client list, ...) first, then the TCPServer base.
    TestingTCPServerMixin.__init__(self)
    SocketServer.TCPServer.__init__(
        self, server_address, request_handler_class)
401
def get_request(self):
402
"""Get the request and client address from the socket."""
403
sock, addr = TestingTCPServerMixin.get_request(self)
404
self.clients.append((sock, addr))
407
# The following methods are called by the main thread
409
def shutdown_client(self, client):
411
self.shutdown_socket(sock)
414
class TestingThreadingTCPServer(TestingTCPServerMixin,
415
SocketServer.ThreadingTCPServer):
417
def __init__(self, server_address, request_handler_class):
    # Mixin state (events, client list, ...) first, then the threading base.
    TestingTCPServerMixin.__init__(self)
    SocketServer.ThreadingTCPServer.__init__(
        self, server_address, request_handler_class)
422
def get_request (self):
423
"""Get the request and client address from the socket."""
424
sock, addr = TestingTCPServerMixin.get_request(self)
425
# The thread is not create yet, it will be updated in process_request
426
self.clients.append((sock, addr, None))
429
def process_request_thread(self, started, stopped, request, client_address):
431
SocketServer.ThreadingTCPServer.process_request_thread(
432
self, request, client_address)
433
self.close_request(request)
436
def process_request(self, request, client_address):
437
"""Start a new thread to process the request."""
438
started = threading.Event()
439
stopped = threading.Event()
442
name='%s -> %s' % (client_address, self.server_address),
443
target = self.process_request_thread,
444
args = (started, stopped, request, client_address))
445
# Update the client description
447
self.clients.append((request, client_address, t))
448
# Propagate the exception handler since we must use the same one as
449
# TestingTCPServer for connections running in their own threads.
450
t.set_ignored_exceptions(self.ignored_exceptions)
454
sys.stderr.write('Client thread %s started\n' % (t.name,))
455
# If an exception occured during the thread start, it will get raised.
456
t.pending_exception()
458
# The following methods are called by the main thread
460
def shutdown_client(self, client):
461
sock, addr, connection_thread = client
462
self.shutdown_socket(sock)
463
if connection_thread is not None:
464
# The thread has been created only if the request is processed but
465
# after the connection is inited. This could happen during server
466
# shutdown. If an exception occurred in the thread it will be
469
sys.stderr.write('Client thread %s will be joined\n'
470
% (connection_thread.name,))
471
connection_thread.join()
473
def set_ignored_exceptions(self, thread, ignored_exceptions):
474
TestingTCPServerMixin.set_ignored_exceptions(self, thread,
476
for sock, addr, connection_thread in self.clients:
477
if connection_thread is not None:
478
connection_thread.set_ignored_exceptions(
479
self.ignored_exceptions)
481
def _pending_exception(self, thread):
    """Re-raise pending exceptions from client threads, then the server's."""
    for _sock, _addr, connection_thread in self.clients:
        if connection_thread is None:
            # The client was registered by get_request() but its thread has
            # not been created by process_request() yet.
            continue
        connection_thread.pending_exception()
    TestingTCPServerMixin._pending_exception(self, thread)
488
class TestingTCPServerInAThread(transport.Server):
489
"""A server in a thread that re-raise thread exceptions."""
491
def __init__(self, server_address, server_class, request_handler_class):
492
self.server_class = server_class
493
self.request_handler_class = request_handler_class
494
self.host, self.port = server_address
496
self._server_thread = None
499
return "%s(%s:%s)" % (self.__class__.__name__, self.host, self.port)
501
def create_server(self):
    """Instantiate ``server_class`` for the configured host and port."""
    address = (self.host, self.port)
    return self.server_class(address, self.request_handler_class)
505
def start_server(self):
506
self.server = self.create_server()
507
self._server_thread = TestThread(
508
sync_event=self.server.started,
509
target=self.run_server)
510
self._server_thread.start()
511
# Wait for the server thread to start (i.e release the lock)
512
self.server.started.wait()
513
# Get the real address, especially the port
514
self.host, self.port = self.server.server_address
515
self._server_thread.name = self.server.server_address
517
sys.stderr.write('Server thread %s started\n'
518
% (self._server_thread.name,))
519
# If an exception occured during the server start, it will get raised,
520
# otherwise, the server is blocked on its accept() call.
521
self._server_thread.pending_exception()
522
# From now on, we'll use a different event to ensure the server can set
524
self._server_thread.set_sync_event(self.server.stopped)
526
def run_server(self):
529
def stop_server(self):
530
if self.server is None:
533
# The server has been started successfully, shut it down now. As
534
# soon as we stop serving, no more connection are accepted except
535
# one to get out of the blocking listen.
536
self.set_ignored_exceptions(
537
self.server.ignored_exceptions_during_shutdown)
538
self.server.serving = False
540
sys.stderr.write('Server thread %s will be joined\n'
541
% (self._server_thread.name,))
542
# The server is listening for a last connection, let's give it:
545
last_conn = osutils.connect_socket((self.host, self.port))
546
except socket.error, e:
547
# But ignore connection errors as the point is to unblock the
548
# server thread, it may happen that it's not blocked or even
551
# We start shutting down the clients while the server itself is
553
self.server.stop_client_connections()
554
# Now we wait for the thread running self.server.serve() to finish
555
self.server.stopped.wait()
556
if last_conn is not None:
557
# Close the last connection without trying to use it. The
558
# server will not process a single byte on that socket to avoid
559
# complications (SSL starts with a handshake for example).
561
# Check for any exception that could have occurred in the server
564
self._server_thread.join()
566
if self.server.ignored_exceptions(e):
571
# Make sure we can be called twice safely, note that this means
572
# that we will raise a single exception even if several occurred in
573
# the various threads involved.
576
def set_ignored_exceptions(self, ignored_exceptions):
577
"""Install an exception handler for the server."""
578
self.server.set_ignored_exceptions(self._server_thread,
581
def pending_exception(self):
    """Re-raise, in the caller, any exception left uncaught by the server."""
    running_server = self.server
    running_server._pending_exception(self._server_thread)
586
class TestingSmartConnectionHandler(SocketServer.BaseRequestHandler,
587
medium.SmartServerSocketStreamMedium):
589
def __init__(self, request, client_address, server):
590
medium.SmartServerSocketStreamMedium.__init__(
591
self, request, server.backing_transport,
592
server.root_client_path,
593
timeout=_DEFAULT_TESTING_CLIENT_TIMEOUT)
594
request.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
595
SocketServer.BaseRequestHandler.__init__(self, request, client_address,
599
while not self.finished:
600
server_protocol = self._build_protocol()
601
self._serve_one_request(server_protocol)
604
# Client timeout shared by the smart-server test helpers. It is passed as
# ``timeout`` to TestingSmartConnectionHandler's medium and as
# ``client_timeout`` to TestingSmartServer. NOTE(review): presumably seconds,
# confirm against bzrlib.smart.
_DEFAULT_TESTING_CLIENT_TIMEOUT = 4.0
606
class TestingSmartServer(TestingThreadingTCPServer, server.SmartTCPServer):
608
def __init__(self, server_address, request_handler_class,
             backing_transport, root_client_path):
    # Socket/threading setup first, then the smart-server specific state.
    TestingThreadingTCPServer.__init__(
        self, server_address, request_handler_class)
    server.SmartTCPServer.__init__(
        self, backing_transport, root_client_path,
        client_timeout=_DEFAULT_TESTING_CLIENT_TIMEOUT)
616
self.run_server_started_hooks()
618
TestingThreadingTCPServer.serve(self)
620
self.run_server_stopped_hooks()
623
"""Return the url of the server"""
624
return "bzr://%s:%d/" % self.server_address
627
class SmartTCPServer_for_testing(TestingTCPServerInAThread):
628
"""Server suitable for use by transport tests.
630
This server is backed by the process's cwd.
632
def __init__(self, thread_name_suffix=''):
633
self.client_path_extra = None
634
self.thread_name_suffix = thread_name_suffix
635
self.host = '127.0.0.1'
637
super(SmartTCPServer_for_testing, self).__init__(
638
(self.host, self.port),
640
TestingSmartConnectionHandler)
642
def create_server(self):
    """Instantiate the smart server with its backing transport and root path."""
    address = (self.host, self.port)
    return self.server_class(address, self.request_handler_class,
                             self.backing_transport, self.root_client_path)
649
def start_server(self, backing_transport_server=None,
                 client_path_extra='/extra/'):
    """Set up server for testing.

    :param backing_transport_server: backing server to use. If not
        specified, a LocalURLServer at the current working directory will
        be used.
    :param client_path_extra: a path segment starting with '/' to append to
        the root URL for this server. For instance, a value of '/foo/bar/'
        will mean the root of the backing transport will be published at a
        URL like `bzr://127.0.0.1:nnnn/foo/bar/`, rather than
        `bzr://127.0.0.1:nnnn/`. The default value is '/extra/', so that
        tests by default will fail unless they do the necessary path
        translation.
    :raises ValueError: if client_path_extra does not start with '/'.
    """
    if not client_path_extra.startswith('/'):
        raise ValueError(
            'client_path_extra must start with "/": %r' % (client_path_extra,))
    self.root_client_path = self.client_path_extra = client_path_extra
    from bzrlib.transport.chroot import ChrootServer
    if backing_transport_server is None:
        backing_transport_server = LocalURLServer()
    # The smart server does not use the backing server directly. It is
    # handed a transport for the URL of a ChrootServer wrapping it.
    self.chroot_server = ChrootServer(
        self.get_backing_transport(backing_transport_server))
    self.chroot_server.start_server()
    self.backing_transport = transport.get_transport_from_url(
        self.chroot_server.get_url())
    super(SmartTCPServer_for_testing, self).start_server()
676
def stop_server(self):
678
super(SmartTCPServer_for_testing, self).stop_server()
680
self.chroot_server.stop_server()
682
def get_backing_transport(self, backing_transport_server):
    """Return a transport for the URL published by the decorated server."""
    backing_url = backing_transport_server.get_url()
    return transport.get_transport_from_url(backing_url)
688
url = self.server.get_url()
689
return url[:-1] + self.client_path_extra
691
def get_bogus_url(self):
    """Return a bzr:// URL on localhost port 1, expected to refuse connections."""
    bogus = 'bzr://127.0.0.1:1/'
    return bogus
696
class ReadonlySmartTCPServer_for_testing(SmartTCPServer_for_testing):
    """Smart test server whose backing transport is wrapped as readonly+."""

    def get_backing_transport(self, backing_transport_server):
        """Return a readonly+ transport over the decorated server's URL."""
        prefix = 'readonly+'
        readonly_url = prefix + backing_transport_server.get_url()
        return transport.get_transport_from_url(readonly_url)
705
class SmartTCPServer_for_testing_v2_only(SmartTCPServer_for_testing):
706
"""A variation of SmartTCPServer_for_testing that limits the client to
707
using RPCs in protocol v2 (i.e. bzr <= 1.5).
711
url = super(SmartTCPServer_for_testing_v2_only, self).get_url()
712
url = 'bzr-v2://' + url[len('bzr://'):]
716
class ReadonlySmartTCPServer_for_testing_v2_only(
        SmartTCPServer_for_testing_v2_only):
    """Protocol-v2-only smart test server with a readonly+ backing transport."""

    def get_backing_transport(self, backing_transport_server):
        """Return a readonly+ transport over the decorated server's URL."""
        prefix = 'readonly+'
        readonly_url = prefix + backing_transport_server.get_url()
        return transport.get_transport_from_url(readonly_url)