~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/smart/server.py

Merge from bzr.dev, resolving conflicts.

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2006 Canonical Ltd
 
1
# Copyright (C) 2006, 2007 Canonical Ltd
2
2
#
3
3
# This program is free software; you can redistribute it and/or modify
4
4
# it under the terms of the GNU General Public License as published by
16
16
 
17
17
"""Server for smart-server protocol."""
18
18
 
 
19
import errno
19
20
import socket
20
21
import os
21
22
import threading
22
23
 
 
24
from bzrlib.hooks import Hooks
23
25
from bzrlib.smart import medium
24
26
from bzrlib import (
25
27
    trace,
26
28
    transport,
27
29
    urlutils,
28
30
)
 
31
from bzrlib.smart.medium import SmartServerSocketStreamMedium
29
32
 
30
33
 
31
34
class SmartTCPServer(object):
32
35
    """Listens on a TCP socket and accepts connections from smart clients.
33
 
    
 
36
 
34
37
    Each connection will be served by a SmartServerSocketStreamMedium running in
35
 
    thread.
 
38
    a thread.
 
39
 
 
40
    hooks: An instance of SmartServerHooks.
36
41
    """
37
42
 
38
43
    def __init__(self, backing_transport, host='127.0.0.1', port=0):
44
49
        :param host: Name of the interface to listen on.
45
50
        :param port: TCP port to listen on, or 0 to allocate a transient port.
46
51
        """
 
52
        # let connections timeout so that we get a chance to terminate
 
53
        # Keep a reference to the exceptions we want to catch because the socket
 
54
        # module's globals get set to None during interpreter shutdown.
 
55
        from socket import timeout as socket_timeout
 
56
        from socket import error as socket_error
 
57
        self._socket_error = socket_error
 
58
        self._socket_timeout = socket_timeout
47
59
        self._server_socket = socket.socket()
48
60
        self._server_socket.bind((host, port))
49
 
        self.port = self._server_socket.getsockname()[1]
 
61
        self._sockname = self._server_socket.getsockname()
 
62
        self.port = self._sockname[1]
50
63
        self._server_socket.listen(1)
51
64
        self._server_socket.settimeout(1)
52
65
        self.backing_transport = backing_transport
 
66
        self._started = threading.Event()
 
67
        self._stopped = threading.Event()
53
68
 
54
69
    def serve(self):
55
 
        # let connections timeout so that we get a chance to terminate
56
 
        # Keep a reference to the exceptions we want to catch because the socket
57
 
        # module's globals get set to None during interpreter shutdown.
58
 
        from socket import timeout as socket_timeout
59
 
        from socket import error as socket_error
60
70
        self._should_terminate = False
61
 
        while not self._should_terminate:
62
 
            try:
63
 
                self.accept_and_serve()
64
 
            except socket_timeout:
65
 
                # just check if we're asked to stop
66
 
                pass
67
 
            except socket_error, e:
68
 
                trace.warning("client disconnected: %s", e)
69
 
                pass
 
71
        for hook in SmartTCPServer.hooks['server_started']:
 
72
            hook(self.backing_transport.base, self.get_url())
 
73
        self._started.set()
 
74
        try:
 
75
            try:
 
76
                while not self._should_terminate:
 
77
                    try:
 
78
                        conn, client_addr = self._server_socket.accept()
 
79
                    except self._socket_timeout:
 
80
                        # just check if we're asked to stop
 
81
                        pass
 
82
                    except self._socket_error, e:
 
83
                        # if the socket is closed by stop_background_thread
 
84
                        # we might get a EBADF here, any other socket errors
 
85
                        # should get logged.
 
86
                        if e.args[0] != errno.EBADF:
 
87
                            trace.warning("listening socket error: %s", e)
 
88
                    else:
 
89
                        self.serve_conn(conn)
 
90
            except KeyboardInterrupt:
 
91
                # dont log when CTRL-C'd.
 
92
                raise
 
93
            except Exception, e:
 
94
                trace.error("Unhandled smart server error.")
 
95
                trace.log_exception_quietly()
 
96
                raise
 
97
        finally:
 
98
            self._stopped.set()
 
99
            try:
 
100
                # ensure the server socket is closed.
 
101
                self._server_socket.close()
 
102
            except self._socket_error:
 
103
                # ignore errors on close
 
104
                pass
 
105
            for hook in SmartTCPServer.hooks['server_stopped']:
 
106
                hook(self.backing_transport.base, self.get_url())
70
107
 
71
108
    def get_url(self):
72
109
        """Return the url of the server"""
73
 
        return "bzr://%s:%d/" % self._server_socket.getsockname()
 
110
        return "bzr://%s:%d/" % self._sockname
74
111
 
75
 
    def accept_and_serve(self):
76
 
        conn, client_addr = self._server_socket.accept()
 
112
    def serve_conn(self, conn):
77
113
        # For WIN32, where the timeout value from the listening socket
78
114
        # propogates to the newly accepted socket.
79
115
        conn.setblocking(True)
80
116
        conn.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
81
 
        handler = medium.SmartServerSocketStreamMedium(conn, self.backing_transport)
82
 
        connection_thread = threading.Thread(
83
 
            None, handler.serve, name='smart-server-child')
 
117
        handler = SmartServerSocketStreamMedium(conn, self.backing_transport)
 
118
        connection_thread = threading.Thread(None, handler.serve, name='smart-server-child')
84
119
        connection_thread.setDaemon(True)
85
120
        connection_thread.start()
86
121
 
87
122
    def start_background_thread(self):
 
123
        self._started.clear()
88
124
        self._server_thread = threading.Thread(None,
89
125
                self.serve,
90
126
                name='server-' + self.get_url())
91
127
        self._server_thread.setDaemon(True)
92
128
        self._server_thread.start()
 
129
        self._started.wait()
93
130
 
94
131
    def stop_background_thread(self):
 
132
        self._stopped.clear()
 
133
        # tell the main loop to quit on the next iteration.
95
134
        self._should_terminate = True
96
 
        # self._server_socket.close()
97
 
        # we used to join the thread, but it's not really necessary; it will
98
 
        # terminate in time
99
 
        ## self._server_thread.join()
100
 
 
 
135
        # close the socket - gives error to connections from here on in,
 
136
        # rather than a connection reset error to connections made during
 
137
        # the period between setting _should_terminate = True and 
 
138
        # the current request completing/aborting. It may also break out the
 
139
        # main loop if it was currently in accept() (on some platforms).
 
140
        try:
 
141
            self._server_socket.close()
 
142
        except self._socket_error:
 
143
            # ignore errors on close
 
144
            pass
 
145
        if not self._stopped.isSet():
 
146
            # server has not stopped (though it may be stopping)
 
147
            # its likely in accept(), so give it a connection
 
148
            temp_socket = socket.socket()
 
149
            temp_socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
 
150
            if not temp_socket.connect_ex(self._sockname):
 
151
                # and close it immediately: we dont choose to send any requests.
 
152
                temp_socket.close()
 
153
        self._stopped.wait()
 
154
        self._server_thread.join()
 
155
 
 
156
 
 
157
class SmartServerHooks(Hooks):
 
158
    """Hooks for the smart server."""
 
159
 
 
160
    def __init__(self):
 
161
        """Create the default hooks.
 
162
 
 
163
        These are all empty initially, because by default nothing should get
 
164
        notified.
 
165
        """
 
166
        Hooks.__init__(self)
 
167
        # Introduced in 0.16:
 
168
        # invoked whenever the server starts serving a directory.
 
169
        # The api signature is (backing url, public url).
 
170
        self['server_started'] = []
 
171
        # Introduced in 0.16:
 
172
        # invoked whenever the server stops serving a directory.
 
173
        # The api signature is (backing url, public url).
 
174
        self['server_stopped'] = []
 
175
 
 
176
SmartTCPServer.hooks = SmartServerHooks()
101
177
 
102
178
 
103
179
class SmartTCPServer_for_testing(SmartTCPServer):
104
180
    """Server suitable for use by transport tests.
105
181
    
106
 
    This server has a _homedir of the current cwd.
 
182
    This server is backed by the process's cwd.
107
183
    """
108
184
 
109
185
    def __init__(self):
110
 
        # The server is set up by default like for inetd access: the backing
111
 
        # transport is connected to a local path that is not '/'.
112
186
        SmartTCPServer.__init__(self, None)
113
 
 
 
187
        
114
188
    def get_backing_transport(self, backing_transport_server):
115
189
        """Get a backing transport from a server we are decorating."""
116
190
        return transport.get_transport(backing_transport_server.get_url())