~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/smart/server.py

(jameinel) Allow 'bzr serve' to interpret SIGHUP as a graceful shutdown.
 (bug #795025) (John A Meinel)

Show diffs side-by-side

added added

removed removed

Lines of Context:
20
20
import os.path
21
21
import socket
22
22
import sys
 
23
import time
23
24
import threading
24
25
 
25
26
from bzrlib.hooks import Hooks
31
32
from bzrlib.i18n import gettext
32
33
from bzrlib.lazy_import import lazy_import
33
34
lazy_import(globals(), """
34
 
from bzrlib.smart import medium
 
35
from bzrlib.smart import (
 
36
    medium,
 
37
    signals,
 
38
    )
35
39
from bzrlib.transport import (
36
40
    chroot,
37
41
    pathfilter,
38
42
    )
39
43
from bzrlib import (
 
44
    config,
40
45
    urlutils,
41
46
    )
42
47
""")
51
56
    hooks: An instance of SmartServerHooks.
52
57
    """
53
58
 
54
 
    def __init__(self, backing_transport, root_client_path='/'):
 
59
    # This is the timeout on the socket we use .accept() on. It is exposed here
 
60
    # so the test suite can set it faster. (It thread.interrupt_main() will not
 
61
    # fire a KeyboardInterrupt during socket.accept)
 
62
    _ACCEPT_TIMEOUT = 1.0
 
63
    _SHUTDOWN_POLL_TIMEOUT = 1.0
 
64
    _LOG_WAITING_TIMEOUT = 10.0
 
65
 
 
66
    _timer = time.time
 
67
 
 
68
    def __init__(self, backing_transport, root_client_path='/',
 
69
                 client_timeout=None):
55
70
        """Construct a new server.
56
71
 
57
72
        To actually start it running, call either start_background_thread or
60
75
        :param backing_transport: The transport to serve.
61
76
        :param root_client_path: The client path that will correspond to root
62
77
            of backing_transport.
 
78
        :param client_timeout: See SmartServerSocketStreamMedium's timeout
 
79
            parameter.
63
80
        """
64
81
        self.backing_transport = backing_transport
65
82
        self.root_client_path = root_client_path
 
83
        self._client_timeout = client_timeout
 
84
        self._active_connections = []
 
85
        # This is set to indicate we want to wait for clients to finish before
 
86
        # we disconnect.
 
87
        self._gracefully_stopping = False
66
88
 
67
89
    def start_server(self, host, port):
68
90
        """Create the server listening socket.
94
116
        self._sockname = self._server_socket.getsockname()
95
117
        self.port = self._sockname[1]
96
118
        self._server_socket.listen(1)
97
 
        self._server_socket.settimeout(1)
 
119
        self._server_socket.settimeout(self._ACCEPT_TIMEOUT)
 
120
        # Once we start accept()ing connections, we set started.
98
121
        self._started = threading.Event()
 
122
        # Once we stop accept()ing connections (and are closing the socket) we
 
123
        # set _stopped
99
124
        self._stopped = threading.Event()
 
125
        # Once we have finished waiting for all clients, etc. We set
 
126
        # _fully_stopped
 
127
        self._fully_stopped = threading.Event()
100
128
 
101
129
    def _backing_urls(self):
102
130
        # There are three interesting urls:
135
163
        for hook in SmartTCPServer.hooks['server_stopped']:
136
164
            hook(backing_urls, self.get_url())
137
165
 
 
166
    def _stop_gracefully(self):
 
167
        trace.note(gettext('Requested to stop gracefully'))
 
168
        self._should_terminate = True
 
169
        self._gracefully_stopping = True
 
170
        for handler, _ in self._active_connections:
 
171
            handler._stop_gracefully()
 
172
 
 
173
    def _wait_for_clients_to_disconnect(self):
 
174
        self._poll_active_connections()
 
175
        if not self._active_connections:
 
176
            return
 
177
        trace.note(gettext('Waiting for %d client(s) to finish')
 
178
                   % (len(self._active_connections),))
 
179
        t_next_log = self._timer() + self._LOG_WAITING_TIMEOUT
 
180
        while self._active_connections:
 
181
            now = self._timer()
 
182
            if now >= t_next_log:
 
183
                trace.note(gettext('Still waiting for %d client(s) to finish')
 
184
                           % (len(self._active_connections),))
 
185
                t_next_log = now + self._LOG_WAITING_TIMEOUT
 
186
            self._poll_active_connections(self._SHUTDOWN_POLL_TIMEOUT)
 
187
 
138
188
    def serve(self, thread_name_suffix=''):
 
189
        # Note: There is a temptation to do
 
190
        #       signals.register_on_hangup(id(self), self._stop_gracefully)
 
191
        #       However, that creates a temporary object which is a bound
 
192
        #       method. signals._on_sighup is a WeakKeyDictionary so it
 
193
        #       immediately gets garbage collected, because nothing else
 
194
        #       references it. Instead, we need to keep a real reference to the
 
195
        #       bound method for the lifetime of the serve() function.
 
196
        stop_gracefully = self._stop_gracefully
 
197
        signals.register_on_hangup(id(self), stop_gracefully)
139
198
        self._should_terminate = False
140
199
        # for hooks we are letting code know that a server has started (and
141
200
        # later stopped).
151
210
                        pass
152
211
                    except self._socket_error, e:
153
212
                        # if the socket is closed by stop_background_thread
154
 
                        # we might get a EBADF here, any other socket errors
155
 
                        # should get logged.
156
 
                        if e.args[0] != errno.EBADF:
157
 
                            trace.warning("listening socket error: %s", e)
 
213
                        # we might get a EBADF here, or if we get a signal we
 
214
                        # can get EINTR, any other socket errors should get
 
215
                        # logged.
 
216
                        if e.args[0] not in (errno.EBADF, errno.EINTR):
 
217
                            trace.warning(gettext("listening socket error: %s")
 
218
                                          % (e,))
158
219
                    else:
159
220
                        if self._should_terminate:
 
221
                            conn.close()
160
222
                            break
161
223
                        self.serve_conn(conn, thread_name_suffix)
 
224
                    # Cleanout any threads that have finished processing.
 
225
                    self._poll_active_connections()
162
226
            except KeyboardInterrupt:
163
227
                # dont log when CTRL-C'd.
164
228
                raise
166
230
                trace.report_exception(sys.exc_info(), sys.stderr)
167
231
                raise
168
232
        finally:
169
 
            self._stopped.set()
170
233
            try:
171
234
                # ensure the server socket is closed.
172
235
                self._server_socket.close()
173
236
            except self._socket_error:
174
237
                # ignore errors on close
175
238
                pass
 
239
            self._stopped.set()
 
240
            signals.unregister_on_hangup(id(self))
176
241
            self.run_server_stopped_hooks()
 
242
        if self._gracefully_stopping:
 
243
            self._wait_for_clients_to_disconnect()
 
244
        self._fully_stopped.set()
177
245
 
178
246
    def get_url(self):
179
247
        """Return the url of the server"""
180
248
        return "bzr://%s:%s/" % (self._sockname[0], self._sockname[1])
181
249
 
 
250
    def _make_handler(self, conn):
 
251
        return medium.SmartServerSocketStreamMedium(
 
252
            conn, self.backing_transport, self.root_client_path,
 
253
            timeout=self._client_timeout)
 
254
 
 
255
    def _poll_active_connections(self, timeout=0.0):
 
256
        """Check to see if any active connections have finished.
 
257
 
 
258
        This will iterate through self._active_connections, and update any
 
259
        connections that are finished.
 
260
 
 
261
        :param timeout: The timeout to pass to thread.join(). By default, we
 
262
            set it to 0, so that we don't hang if threads are not done yet.
 
263
        :return: None
 
264
        """
 
265
        still_active = []
 
266
        for handler, thread in self._active_connections:
 
267
            thread.join(timeout)
 
268
            if thread.isAlive():
 
269
                still_active.append((handler, thread))
 
270
        self._active_connections = still_active
 
271
 
182
272
    def serve_conn(self, conn, thread_name_suffix):
183
273
        # For WIN32, where the timeout value from the listening socket
184
274
        # propagates to the newly accepted socket.
185
275
        conn.setblocking(True)
186
276
        conn.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
187
 
        handler = medium.SmartServerSocketStreamMedium(
188
 
            conn, self.backing_transport, self.root_client_path)
189
277
        thread_name = 'smart-server-child' + thread_name_suffix
 
278
        handler = self._make_handler(conn)
190
279
        connection_thread = threading.Thread(
191
280
            None, handler.serve, name=thread_name)
192
 
        # FIXME: This thread is never joined, it should at least be collected
193
 
        # somewhere so that tests that want to check for leaked threads can get
194
 
        # rid of them -- vila 20100531
 
281
        self._active_connections.append((handler, connection_thread))
195
282
        connection_thread.setDaemon(True)
196
283
        connection_thread.start()
197
284
        return connection_thread
341
428
            transport = _mod_transport.get_transport_from_url(expand_userdirs.get_url())
342
429
        self.transport = transport
343
430
 
344
 
    def _make_smart_server(self, host, port, inet):
 
431
    def _get_stdin_stdout(self):
 
432
        return sys.stdin, sys.stdout
 
433
 
 
434
    def _make_smart_server(self, host, port, inet, timeout):
 
435
        if timeout is None:
 
436
            c = config.GlobalStack()
 
437
            timeout = c.get('serve.client_timeout')
345
438
        if inet:
 
439
            stdin, stdout = self._get_stdin_stdout()
346
440
            smart_server = medium.SmartServerPipeStreamMedium(
347
 
                sys.stdin, sys.stdout, self.transport)
 
441
                stdin, stdout, self.transport, timeout=timeout)
348
442
        else:
349
443
            if host is None:
350
444
                host = medium.BZR_DEFAULT_INTERFACE
351
445
            if port is None:
352
446
                port = medium.BZR_DEFAULT_PORT
353
 
            smart_server = SmartTCPServer(self.transport)
 
447
            smart_server = SmartTCPServer(self.transport,
 
448
                                          client_timeout=timeout)
354
449
            smart_server.start_server(host, port)
355
450
            trace.note(gettext('listening on port: %s') % smart_server.port)
356
451
        self.smart_server = smart_server
369
464
        self.cleanups.append(restore_default_ui_factory_and_lockdir_timeout)
370
465
        ui.ui_factory = ui.SilentUIFactory()
371
466
        lockdir._DEFAULT_TIMEOUT_SECONDS = 0
 
467
        orig = signals.install_sighup_handler()
 
468
        def restore_signals():
 
469
            signals.restore_sighup_handler(orig)
 
470
        self.cleanups.append(restore_signals)
372
471
 
373
 
    def set_up(self, transport, host, port, inet):
 
472
    def set_up(self, transport, host, port, inet, timeout):
374
473
        self._make_backing_transport(transport)
375
 
        self._make_smart_server(host, port, inet)
 
474
        self._make_smart_server(host, port, inet, timeout)
376
475
        self._change_globals()
377
476
 
378
477
    def tear_down(self):
379
478
        for cleanup in reversed(self.cleanups):
380
479
            cleanup()
381
480
 
382
 
def serve_bzr(transport, host=None, port=None, inet=False):
 
481
 
 
482
def serve_bzr(transport, host=None, port=None, inet=False, timeout=None):
383
483
    """This is the default implementation of 'bzr serve'.
384
 
    
 
484
 
385
485
    It creates a TCP or pipe smart server on 'transport, and runs it.  The
386
486
    transport will be decorated with a chroot and pathfilter (using
387
487
    os.path.expanduser).
388
488
    """
389
489
    bzr_server = BzrServerFactory()
390
490
    try:
391
 
        bzr_server.set_up(transport, host, port, inet)
 
491
        bzr_server.set_up(transport, host, port, inet, timeout)
392
492
        bzr_server.smart_server.serve()
393
493
    except:
394
494
        hook_caught_exception = False