~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/smart/server.py

  • Committer: Canonical.com Patch Queue Manager
  • Date: 2010-08-30 22:49:20 UTC
  • mfrom: (5397.1.6 jam-integration)
  • Revision ID: pqm@pqm.ubuntu.com-20100830224920-w9zw1vhsd5oiyljv
(vila, jam) Get PQM running correctly again (bug #626667),
        skip test_bzr_connect_to_bzr_ssh (bug #626876)

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# Copyright (C) 2006-2010 Canonical Ltd
 
2
#
 
3
# This program is free software; you can redistribute it and/or modify
 
4
# it under the terms of the GNU General Public License as published by
 
5
# the Free Software Foundation; either version 2 of the License, or
 
6
# (at your option) any later version.
 
7
#
 
8
# This program is distributed in the hope that it will be useful,
 
9
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
10
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
11
# GNU General Public License for more details.
 
12
#
 
13
# You should have received a copy of the GNU General Public License
 
14
# along with this program; if not, write to the Free Software
 
15
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 
16
 
 
17
"""Server for smart-server protocol."""
 
18
 
 
19
import errno
 
20
import os.path
 
21
import select
 
22
import socket
 
23
import sys
 
24
import threading
 
25
 
 
26
from bzrlib.hooks import HookPoint, Hooks
 
27
from bzrlib import (
 
28
    errors,
 
29
    trace,
 
30
    transport,
 
31
)
 
32
from bzrlib.lazy_import import lazy_import
 
33
lazy_import(globals(), """
 
34
from bzrlib.smart import medium
 
35
from bzrlib.transport import (
 
36
    chroot,
 
37
    get_transport,
 
38
    pathfilter,
 
39
    )
 
40
from bzrlib import (
 
41
    urlutils,
 
42
    )
 
43
""")
 
44
 
 
45
 
 
46
class SmartTCPServer(object):
 
47
    """Listens on a TCP socket and accepts connections from smart clients.
 
48
 
 
49
    Each connection will be served by a SmartServerSocketStreamMedium running in
 
50
    a thread.
 
51
 
 
52
    hooks: An instance of SmartServerHooks.
 
53
    """
 
54
 
 
55
    def __init__(self, backing_transport, root_client_path='/'):
 
56
        """Construct a new server.
 
57
 
 
58
        To actually start it running, call either start_background_thread or
 
59
        serve.
 
60
 
 
61
        :param backing_transport: The transport to serve.
 
62
        :param root_client_path: The client path that will correspond to root
 
63
            of backing_transport.
 
64
        """
 
65
        self.backing_transport = backing_transport
 
66
        self.root_client_path = root_client_path
 
67
 
 
68
    def start_server(self, host, port):
 
69
        """Create the server listening socket.
 
70
 
 
71
        :param host: Name of the interface to listen on.
 
72
        :param port: TCP port to listen on, or 0 to allocate a transient port.
 
73
        """
 
74
        # let connections timeout so that we get a chance to terminate
 
75
        # Keep a reference to the exceptions we want to catch because the socket
 
76
        # module's globals get set to None during interpreter shutdown.
 
77
        from socket import timeout as socket_timeout
 
78
        from socket import error as socket_error
 
79
        self._socket_error = socket_error
 
80
        self._socket_timeout = socket_timeout
 
81
        addrs = socket.getaddrinfo(host, port, socket.AF_UNSPEC,
 
82
            socket.SOCK_STREAM, 0, socket.AI_PASSIVE)[0]
 
83
 
 
84
        (family, socktype, proto, canonname, sockaddr) = addrs
 
85
 
 
86
        self._server_socket = socket.socket(family, socktype, proto)
 
87
        # SO_REUSERADDR has a different meaning on Windows
 
88
        if sys.platform != 'win32':
 
89
            self._server_socket.setsockopt(socket.SOL_SOCKET,
 
90
                socket.SO_REUSEADDR, 1)
 
91
        try:
 
92
            self._server_socket.bind(sockaddr)
 
93
        except self._socket_error, message:
 
94
            raise errors.CannotBindAddress(host, port, message)
 
95
        self._sockname = self._server_socket.getsockname()
 
96
        self.port = self._sockname[1]
 
97
        self._server_socket.listen(1)
 
98
        self._server_socket.settimeout(1)
 
99
        self._started = threading.Event()
 
100
        self._stopped = threading.Event()
 
101
 
 
102
    def _backing_urls(self):
 
103
        # There are three interesting urls:
 
104
        # The URL the server can be contacted on. (e.g. bzr://host/)
 
105
        # The URL that a commit done on the same machine as the server will
 
106
        # have within the servers space. (e.g. file:///home/user/source)
 
107
        # The URL that will be given to other hooks in the same process -
 
108
        # the URL of the backing transport itself. (e.g. chroot+:///)
 
109
        # We need all three because:
 
110
        #  * other machines see the first
 
111
        #  * local commits on this machine should be able to be mapped to
 
112
        #    this server
 
113
        #  * commits the server does itself need to be mapped across to this
 
114
        #    server.
 
115
        # The latter two urls are different aliases to the servers url,
 
116
        # so we group those in a list - as there might be more aliases
 
117
        # in the future.
 
118
        urls = [self.backing_transport.base]
 
119
        try:
 
120
            urls.append(self.backing_transport.external_url())
 
121
        except errors.InProcessTransport:
 
122
            pass
 
123
        return urls
 
124
 
 
125
    def run_server_started_hooks(self, backing_urls=None):
 
126
        if backing_urls is None:
 
127
            backing_urls = self._backing_urls()
 
128
        for hook in SmartTCPServer.hooks['server_started']:
 
129
            hook(backing_urls, self.get_url())
 
130
        for hook in SmartTCPServer.hooks['server_started_ex']:
 
131
            hook(backing_urls, self)
 
132
 
 
133
    def run_server_stopped_hooks(self, backing_urls=None):
 
134
        if backing_urls is None:
 
135
            backing_urls = self._backing_urls()
 
136
        for hook in SmartTCPServer.hooks['server_stopped']:
 
137
            hook(backing_urls, self.get_url())
 
138
 
 
139
    def serve(self, thread_name_suffix=''):
 
140
        self._should_terminate = False
 
141
        # for hooks we are letting code know that a server has started (and
 
142
        # later stopped).
 
143
        self.run_server_started_hooks()
 
144
        self._started.set()
 
145
        try:
 
146
            try:
 
147
                while not self._should_terminate:
 
148
                    try:
 
149
                        conn, client_addr = self._server_socket.accept()
 
150
                    except self._socket_timeout:
 
151
                        # just check if we're asked to stop
 
152
                        pass
 
153
                    except self._socket_error, e:
 
154
                        # if the socket is closed by stop_background_thread
 
155
                        # we might get a EBADF here, any other socket errors
 
156
                        # should get logged.
 
157
                        if e.args[0] != errno.EBADF:
 
158
                            trace.warning("listening socket error: %s", e)
 
159
                    else:
 
160
                        if self._should_terminate:
 
161
                            break
 
162
                        self.serve_conn(conn, thread_name_suffix)
 
163
            except KeyboardInterrupt:
 
164
                # dont log when CTRL-C'd.
 
165
                raise
 
166
            except Exception, e:
 
167
                trace.report_exception(sys.exc_info(), sys.stderr)
 
168
                raise
 
169
        finally:
 
170
            self._stopped.set()
 
171
            try:
 
172
                # ensure the server socket is closed.
 
173
                self._server_socket.close()
 
174
            except self._socket_error:
 
175
                # ignore errors on close
 
176
                pass
 
177
            self.run_server_stopped_hooks()
 
178
 
 
179
    def get_url(self):
 
180
        """Return the url of the server"""
 
181
        return "bzr://%s:%d/" % self._sockname
 
182
 
 
183
    def serve_conn(self, conn, thread_name_suffix):
 
184
        # For WIN32, where the timeout value from the listening socket
 
185
        # propagates to the newly accepted socket.
 
186
        conn.setblocking(True)
 
187
        conn.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
 
188
        handler = medium.SmartServerSocketStreamMedium(
 
189
            conn, self.backing_transport, self.root_client_path)
 
190
        thread_name = 'smart-server-child' + thread_name_suffix
 
191
        connection_thread = threading.Thread(
 
192
            None, handler.serve, name=thread_name)
 
193
        # FIXME: This thread is never joined, it should at least be collected
 
194
        # somewhere so that tests that want to check for leaked threads can get
 
195
        # rid of them -- vila 20100531
 
196
        connection_thread.setDaemon(True)
 
197
        connection_thread.start()
 
198
        return connection_thread
 
199
 
 
200
    def start_background_thread(self, thread_name_suffix=''):
 
201
        self._started.clear()
 
202
        self._server_thread = threading.Thread(None,
 
203
                self.serve, args=(thread_name_suffix,),
 
204
                name='server-' + self.get_url())
 
205
        self._server_thread.setDaemon(True)
 
206
        self._server_thread.start()
 
207
        self._started.wait()
 
208
 
 
209
    def stop_background_thread(self):
 
210
        self._stopped.clear()
 
211
        # tell the main loop to quit on the next iteration.
 
212
        self._should_terminate = True
 
213
        # close the socket - gives error to connections from here on in,
 
214
        # rather than a connection reset error to connections made during
 
215
        # the period between setting _should_terminate = True and
 
216
        # the current request completing/aborting. It may also break out the
 
217
        # main loop if it was currently in accept() (on some platforms).
 
218
        try:
 
219
            self._server_socket.close()
 
220
        except self._socket_error:
 
221
            # ignore errors on close
 
222
            pass
 
223
        if not self._stopped.isSet():
 
224
            # server has not stopped (though it may be stopping)
 
225
            # its likely in accept(), so give it a connection
 
226
            temp_socket = socket.socket()
 
227
            temp_socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
 
228
            if not temp_socket.connect_ex(self._sockname):
 
229
                # and close it immediately: we dont choose to send any requests.
 
230
                temp_socket.close()
 
231
        self._stopped.wait()
 
232
        self._server_thread.join()
 
233
 
 
234
 
 
235
class SmartServerHooks(Hooks):
 
236
    """Hooks for the smart server."""
 
237
 
 
238
    def __init__(self):
 
239
        """Create the default hooks.
 
240
 
 
241
        These are all empty initially, because by default nothing should get
 
242
        notified.
 
243
        """
 
244
        Hooks.__init__(self)
 
245
        self.create_hook(HookPoint('server_started',
 
246
            "Called by the bzr server when it starts serving a directory. "
 
247
            "server_started is called with (backing urls, public url), "
 
248
            "where backing_url is a list of URLs giving the "
 
249
            "server-specific directory locations, and public_url is the "
 
250
            "public URL for the directory being served.", (0, 16), None))
 
251
        self.create_hook(HookPoint('server_started_ex',
 
252
            "Called by the bzr server when it starts serving a directory. "
 
253
            "server_started is called with (backing_urls, server_obj).",
 
254
            (1, 17), None))
 
255
        self.create_hook(HookPoint('server_stopped',
 
256
            "Called by the bzr server when it stops serving a directory. "
 
257
            "server_stopped is called with the same parameters as the "
 
258
            "server_started hook: (backing_urls, public_url).", (0, 16), None))
 
259
 
 
260
SmartTCPServer.hooks = SmartServerHooks()
 
261
 
 
262
 
 
263
def _local_path_for_transport(transport):
 
264
    """Return a local path for transport, if reasonably possible.
 
265
    
 
266
    This function works even if transport's url has a "readonly+" prefix,
 
267
    unlike local_path_from_url.
 
268
    
 
269
    This essentially recovers the --directory argument the user passed to "bzr
 
270
    serve" from the transport passed to serve_bzr.
 
271
    """
 
272
    try:
 
273
        base_url = transport.external_url()
 
274
    except (errors.InProcessTransport, NotImplementedError):
 
275
        return None
 
276
    else:
 
277
        # Strip readonly prefix
 
278
        if base_url.startswith('readonly+'):
 
279
            base_url = base_url[len('readonly+'):]
 
280
        try:
 
281
            return urlutils.local_path_from_url(base_url)
 
282
        except errors.InvalidURL:
 
283
            return None
 
284
 
 
285
 
 
286
class BzrServerFactory(object):
 
287
    """Helper class for serve_bzr."""
 
288
 
 
289
    def __init__(self, userdir_expander=None, get_base_path=None):
 
290
        self.cleanups = []
 
291
        self.base_path = None
 
292
        self.backing_transport = None
 
293
        if userdir_expander is None:
 
294
            userdir_expander = os.path.expanduser
 
295
        self.userdir_expander = userdir_expander
 
296
        if get_base_path is None:
 
297
            get_base_path = _local_path_for_transport
 
298
        self.get_base_path = get_base_path
 
299
 
 
300
    def _expand_userdirs(self, path):
 
301
        """Translate /~/ or /~user/ to e.g. /home/foo, using
 
302
        self.userdir_expander (os.path.expanduser by default).
 
303
 
 
304
        If the translated path would fall outside base_path, or the path does
 
305
        not start with ~, then no translation is applied.
 
306
 
 
307
        If the path is inside, it is adjusted to be relative to the base path.
 
308
 
 
309
        e.g. if base_path is /home, and the expanded path is /home/joe, then
 
310
        the translated path is joe.
 
311
        """
 
312
        result = path
 
313
        if path.startswith('~'):
 
314
            expanded = self.userdir_expander(path)
 
315
            if not expanded.endswith('/'):
 
316
                expanded += '/'
 
317
            if expanded.startswith(self.base_path):
 
318
                result = expanded[len(self.base_path):]
 
319
        return result
 
320
 
 
321
    def _make_expand_userdirs_filter(self, transport):
 
322
        return pathfilter.PathFilteringServer(transport, self._expand_userdirs)
 
323
 
 
324
    def _make_backing_transport(self, transport):
 
325
        """Chroot transport, and decorate with userdir expander."""
 
326
        self.base_path = self.get_base_path(transport)
 
327
        chroot_server = chroot.ChrootServer(transport)
 
328
        chroot_server.start_server()
 
329
        self.cleanups.append(chroot_server.stop_server)
 
330
        transport = get_transport(chroot_server.get_url())
 
331
        if self.base_path is not None:
 
332
            # Decorate the server's backing transport with a filter that can
 
333
            # expand homedirs.
 
334
            expand_userdirs = self._make_expand_userdirs_filter(transport)
 
335
            expand_userdirs.start_server()
 
336
            self.cleanups.append(expand_userdirs.stop_server)
 
337
            transport = get_transport(expand_userdirs.get_url())
 
338
        self.transport = transport
 
339
 
 
340
    def _make_smart_server(self, host, port, inet):
 
341
        if inet:
 
342
            smart_server = medium.SmartServerPipeStreamMedium(
 
343
                sys.stdin, sys.stdout, self.transport)
 
344
        else:
 
345
            if host is None:
 
346
                host = medium.BZR_DEFAULT_INTERFACE
 
347
            if port is None:
 
348
                port = medium.BZR_DEFAULT_PORT
 
349
            smart_server = SmartTCPServer(self.transport)
 
350
            smart_server.start_server(host, port)
 
351
            trace.note('listening on port: %s' % smart_server.port)
 
352
        self.smart_server = smart_server
 
353
 
 
354
    def _change_globals(self):
 
355
        from bzrlib import lockdir, ui
 
356
        # For the duration of this server, no UI output is permitted. note
 
357
        # that this may cause problems with blackbox tests. This should be
 
358
        # changed with care though, as we dont want to use bandwidth sending
 
359
        # progress over stderr to smart server clients!
 
360
        old_factory = ui.ui_factory
 
361
        old_lockdir_timeout = lockdir._DEFAULT_TIMEOUT_SECONDS
 
362
        def restore_default_ui_factory_and_lockdir_timeout():
 
363
            ui.ui_factory = old_factory
 
364
            lockdir._DEFAULT_TIMEOUT_SECONDS = old_lockdir_timeout
 
365
        self.cleanups.append(restore_default_ui_factory_and_lockdir_timeout)
 
366
        ui.ui_factory = ui.SilentUIFactory()
 
367
        lockdir._DEFAULT_TIMEOUT_SECONDS = 0
 
368
 
 
369
    def set_up(self, transport, host, port, inet):
 
370
        self._make_backing_transport(transport)
 
371
        self._make_smart_server(host, port, inet)
 
372
        self._change_globals()
 
373
 
 
374
    def tear_down(self):
 
375
        for cleanup in reversed(self.cleanups):
 
376
            cleanup()
 
377
 
 
378
 
 
379
def serve_bzr(transport, host=None, port=None, inet=False):
 
380
    """This is the default implementation of 'bzr serve'.
 
381
    
 
382
    It creates a TCP or pipe smart server on 'transport, and runs it.  The
 
383
    transport will be decorated with a chroot and pathfilter (using
 
384
    os.path.expanduser).
 
385
    """
 
386
    bzr_server = BzrServerFactory()
 
387
    try:
 
388
        bzr_server.set_up(transport, host, port, inet)
 
389
        bzr_server.smart_server.serve()
 
390
    finally:
 
391
        bzr_server.tear_down()
 
392