~bzr-pqm/bzr/bzr.dev

2400.1.3 by Andrew Bennetts
Split smart transport code into several separate modules.
1
# Copyright (C) 2006,2007 Canonical Ltd
2
#
3
# This program is free software; you can redistribute it and/or modify
4
# it under the terms of the GNU General Public License as published by
5
# the Free Software Foundation; either version 2 of the License, or
6
# (at your option) any later version.
7
#
8
# This program is distributed in the hope that it will be useful,
9
# but WITHOUT ANY WARRANTY; without even the implied warranty of
10
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11
# GNU General Public License for more details.
12
#
13
# You should have received a copy of the GNU General Public License
14
# along with this program; if not, write to the Free Software
15
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
16
17
"""The 'medium' layer for the smart servers and clients.
18
19
"Medium" here is the noun meaning "a means of transmission", not the adjective
20
for "the quality between big and small."
21
22
Media carry the bytes of the requests somehow (e.g. via TCP, wrapped in HTTP, or
23
over SSH), and pass them to and from the protocol logic.  See the overview in
24
bzrlib/transport/smart/__init__.py.
25
"""
26
27
import os
28
import socket
29
import sys
30
from bzrlib import errors
31
from bzrlib.smart.protocol import SmartServerRequestProtocolOne
32
33
try:
34
    from bzrlib.transport import ssh
35
except errors.ParamikoNotPresent:
36
    # no paramiko.  SmartSSHClientMedium will break.
37
    pass
38
39
40
class SmartServerStreamMedium(object):
    """Handles smart commands coming over a stream.

    The stream may be a pipe connected to sshd, or a tcp socket, or an
    in-process fifo for testing.

    One instance is created for each connected client; it can serve multiple
    requests in the lifetime of the connection.

    The server passes requests through to an underlying backing transport,
    which will typically be a LocalTransport looking at the server's filesystem.

    Subclasses must implement _serve_one_request_unguarded, _write_out and
    terminate_due_to_error.
    """

    def __init__(self, backing_transport):
        """Construct new server.

        :param backing_transport: Transport for the directory served.
        """
        # backing_transport could be passed to serve instead of __init__
        self.backing_transport = backing_transport
        # Set to True once no more requests should be served; checked by
        # the serve() loop.
        self.finished = False

    def serve(self):
        """Serve requests until the client disconnects.

        A new SmartServerRequestProtocolOne is created for every request.  An
        exception escaping the loop is reported on stderr and re-raised.
        """
        # Keep references to stderr and exc_info because the sys module's
        # globals get set to None during interpreter shutdown.
        from sys import stderr, exc_info
        try:
            while not self.finished:
                protocol = SmartServerRequestProtocolOne(self.backing_transport,
                                                         self._write_out)
                self._serve_one_request(protocol)
        except Exception:
            # exc_info() instead of binding the exception in the except
            # clause keeps this valid across Python versions.
            error = exc_info()[1]
            stderr.write("%s terminating on exception %s\n" % (self, error))
            raise

    def _serve_one_request(self, protocol):
        """Read one request from input, process, send back a response.

        Any exception other than KeyboardInterrupt is handled by calling
        terminate_due_to_error() instead of being propagated.

        :param protocol: a SmartServerRequestProtocol.
        """
        try:
            self._serve_one_request_unguarded(protocol)
        except KeyboardInterrupt:
            # A user interrupt is not a protocol error: always propagate it.
            raise
        except Exception:
            self.terminate_due_to_error()

    def _serve_one_request_unguarded(self, protocol):
        """Feed one request's bytes into protocol, letting errors escape.

        :param protocol: a SmartServerRequestProtocol.
        """
        raise NotImplementedError(self._serve_one_request_unguarded)

    def _write_out(self, bytes):
        """Send response bytes to the client.

        serve() hands this method to each protocol as its write function.
        """
        raise NotImplementedError(self._write_out)

    def terminate_due_to_error(self):
        """Called when an unhandled exception from the protocol occurs."""
        raise NotImplementedError(self.terminate_due_to_error)
91
92
93
class SmartServerSocketStreamMedium(SmartServerStreamMedium):
    """Serves smart requests read from, and answered over, a socket."""

    def __init__(self, sock, backing_transport):
        """Constructor.

        :param sock: the socket the server will read from.  It will be put
            into blocking mode.
        :param backing_transport: Transport for the directory served.
        """
        SmartServerStreamMedium.__init__(self, backing_transport)
        # Bytes received past the end of the previous request; they are fed
        # to the next request's protocol before reading the socket again.
        self.push_back = ''
        sock.setblocking(True)
        self.socket = sock

    def _serve_one_request_unguarded(self, protocol):
        """Feed bytes into protocol until it reports it needs no more."""
        while protocol.next_read_size():
            pending = self.push_back
            if pending:
                protocol.accept_bytes(pending)
                self.push_back = ''
                continue
            chunk = self.socket.recv(4096)
            if chunk == '':
                # The peer closed the connection.
                self.finished = True
                return
            protocol.accept_bytes(chunk)
        self.push_back = protocol.excess_buffer

    def terminate_due_to_error(self):
        """Close the socket and stop serving after an unhandled error."""
        # TODO: This should log to a server log file, but no such thing
        # exists yet.  Andrew Bennetts 2006-09-29.
        self.socket.close()
        self.finished = True

    def _write_out(self, bytes):
        """Send all of bytes to the client over the socket."""
        self.socket.sendall(bytes)
129
130
131
class SmartServerPipeStreamMedium(SmartServerStreamMedium):
    """Serves smart requests read from one file and answered on another."""

    def __init__(self, in_file, out_file, backing_transport):
        """Construct new server.

        :param in_file: Python file from which requests can be read.
        :param out_file: Python file to write responses.
        :param backing_transport: Transport for the directory served.
        """
        SmartServerStreamMedium.__init__(self, backing_transport)
        if sys.platform == 'win32':
            # Windows: switch both streams to binary mode.
            import msvcrt
            for stream in (in_file, out_file):
                get_fileno = getattr(stream, 'fileno', None)
                if get_fileno:
                    msvcrt.setmode(get_fileno(), os.O_BINARY)
        self._in = in_file
        self._out = out_file

    def _serve_one_request_unguarded(self, protocol):
        """Read from the input file until protocol needs no more bytes.

        Reading an empty string means the input was closed, which marks the
        medium as finished.  The output is flushed when the loop ends
        normally.
        """
        wanted = protocol.next_read_size()
        while wanted != 0:
            chunk = self._in.read(wanted)
            if chunk == '':
                # Connection has been closed.
                self.finished = True
                break
            protocol.accept_bytes(chunk)
            wanted = protocol.next_read_size()
        self._out.flush()

    def terminate_due_to_error(self):
        """Close the output file and stop serving after an unhandled error."""
        # TODO: This should log to a server log file, but no such thing
        # exists yet.  Andrew Bennetts 2006-09-29.
        self._out.close()
        self.finished = True

    def _write_out(self, bytes):
        """Write response bytes to the output file (unflushed)."""
        self._out.write(bytes)
174
175
2400.1.6 by Andrew Bennetts
Cosmetic changes to minimise the difference between this branch and the hpss branch.
176
class SmartClientMediumRequest(object):
    """A request on a SmartClientMedium.

    Each request allows bytes to be provided to it via accept_bytes, and then
    the response bytes to be read via read_bytes.

    For instance:
    request.accept_bytes('123')
    request.finished_writing()
    result = request.read_bytes(3)
    request.finished_reading()

    A request moves through three states in order: "writing" (accepting
    request bytes), "reading" (after finished_writing) and "done" (after
    finished_reading).  The public methods enforce that ordering and then
    delegate to the _accept_bytes, _finished_writing, _read_bytes and
    _finished_reading hooks, which subclasses must implement.

    It is up to the individual SmartClientMedium whether multiple concurrent
    requests can exist. Use the concrete medium's get_request method to
    obtain instances of SmartClientMediumRequest, and see that medium for
    details on concurrency and pipelining.
    """

    def __init__(self, medium):
        """Construct a SmartClientMediumRequest for the medium medium."""
        self._medium = medium
        # we track state by constants - we may want to use the same
        # pattern as BodyReader if it gets more complex.
        # valid states are: "writing", "reading", "done"
        self._state = "writing"

    def accept_bytes(self, bytes):
        """Accept bytes for inclusion in this request.

        This method may not be called after finished_writing() has been
        called.  It depends upon the Medium whether or not the bytes will be
        immediately transmitted. Message based Mediums will tend to buffer the
        bytes until finished_writing() is called.

        :param bytes: A bytestring.
        :raises WritingCompleted: if finished_writing() was already called.
        """
        if self._state != "writing":
            raise errors.WritingCompleted(self)
        self._accept_bytes(bytes)

    def _accept_bytes(self, bytes):
        """Helper for accept_bytes.

        accept_bytes checks the state of the request to determine if bytes
        should be accepted. After that it hands off to _accept_bytes to do the
        actual acceptance.
        """
        raise NotImplementedError(self._accept_bytes)

    def finished_reading(self):
        """Inform the request that all desired data has been read.

        This will remove the request from the pipeline for its medium (if the
        medium supports pipelining).  Afterwards read_bytes and
        finished_reading raise ReadingCompleted, and accept_bytes and
        finished_writing raise WritingCompleted.

        :raises WritingNotComplete: if finished_writing() has not been called.
        :raises ReadingCompleted: if finished_reading() was already called.
        """
        if self._state == "writing":
            raise errors.WritingNotComplete(self)
        if self._state != "reading":
            raise errors.ReadingCompleted(self)
        self._state = "done"
        self._finished_reading()

    def _finished_reading(self):
        """Helper for finished_reading.

        finished_reading checks the state of the request to determine if
        finished_reading is allowed, and if it is hands off to _finished_reading
        to perform the action.
        """
        raise NotImplementedError(self._finished_reading)

    def finished_writing(self):
        """Finish the writing phase of this request.

        This will flush all pending data for this request along the medium.
        After calling finished_writing, you may not call accept_bytes anymore.

        :raises WritingCompleted: if finished_writing() was already called.
        """
        if self._state != "writing":
            raise errors.WritingCompleted(self)
        self._state = "reading"
        self._finished_writing()

    def _finished_writing(self):
        """Helper for finished_writing.

        finished_writing checks the state of the request to determine if
        finished_writing is allowed, and if it is hands off to _finished_writing
        to perform the action.
        """
        raise NotImplementedError(self._finished_writing)

    def read_bytes(self, count):
        """Read bytes from this request's response.

        This method will block and wait for count bytes to be read. It may not
        be invoked until finished_writing() has been called - this is to ensure
        a message-based approach to requests, for compatibility with message
        based mediums like HTTP.

        :raises WritingNotComplete: if finished_writing() has not been called.
        :raises ReadingCompleted: if finished_reading() was already called.
        """
        if self._state == "writing":
            raise errors.WritingNotComplete(self)
        if self._state != "reading":
            raise errors.ReadingCompleted(self)
        return self._read_bytes(count)

    def _read_bytes(self, count):
        """Helper for read_bytes.

        read_bytes checks the state of the request to determine if bytes
        should be read. After that it hands off to _read_bytes to do the
        actual read.
        """
        raise NotImplementedError(self._read_bytes)
290
291
2400.1.3 by Andrew Bennetts
Split smart transport code into several separate modules.
292
class SmartClientMedium(object):
    """A medium over which smart protocol requests are sent by a client."""

    def get_request(self):
        """Return a new SmartClientMediumRequest for this medium.

        Concrete mediums must override this; whether several requests may
        be outstanding at once is up to each medium.
        """
        raise NotImplementedError(self.get_request)

    def disconnect(self):
        """If this medium maintains a persistent connection, close it.

        The default implementation does nothing.
        """
300
        
301
302
class SmartClientStreamMedium(SmartClientMedium):
    """Stream based medium common class.

    SmartClientStreamMediums operate on a stream. All subclasses use a common
    SmartClientStreamMediumRequest for their requests, and should implement
    _accept_bytes, _read_bytes and _flush to allow the request objects to
    send and receive bytes.
    """

    def __init__(self):
        # The in-flight request on this stream, or None.  Set and cleared
        # by SmartClientStreamMediumRequest.
        self._current_request = None

    def accept_bytes(self, bytes):
        """Send bytes over the stream; delegates to _accept_bytes."""
        self._accept_bytes(bytes)

    def _accept_bytes(self, bytes):
        """Write bytes to the underlying stream.  Subclasses must implement."""
        raise NotImplementedError(self._accept_bytes)

    def __del__(self):
        """The SmartClientStreamMedium knows how to close the stream when it is
        finished with it.
        """
        self.disconnect()

    def _flush(self):
        """Flush the output stream.

        This method is used by the SmartClientStreamMediumRequest to ensure that
        all data for a request is sent, to avoid long timeouts or deadlocks.
        """
        raise NotImplementedError(self._flush)

    def get_request(self):
        """See SmartClientMedium.get_request().

        SmartClientStreamMedium always returns a SmartClientStreamMediumRequest
        for get_request.
        """
        return SmartClientStreamMediumRequest(self)

    def read_bytes(self, count):
        """Read up to count bytes from the stream; delegates to _read_bytes."""
        return self._read_bytes(count)

    def _read_bytes(self, count):
        """Read bytes from the underlying stream.  Subclasses must implement."""
        raise NotImplementedError(self._read_bytes)
341
342
343
class SmartSimplePipesClientMedium(SmartClientStreamMedium):
    """A client medium using simple pipes.

    This client does not manage the pipes: it assumes they will always be open.
    """

    def __init__(self, readable_pipe, writeable_pipe):
        """Wrap an already-open pair of pipes.

        :param readable_pipe: file-like object responses are read from.
        :param writeable_pipe: file-like object requests are written to.
        """
        SmartClientStreamMedium.__init__(self)
        self._readable_pipe, self._writeable_pipe = readable_pipe, writeable_pipe

    def _accept_bytes(self, bytes):
        """See SmartClientStreamMedium.accept_bytes."""
        sink = self._writeable_pipe
        sink.write(bytes)

    def _read_bytes(self, count):
        """See SmartClientStreamMedium._read_bytes."""
        source = self._readable_pipe
        return source.read(count)

    def _flush(self):
        """See SmartClientStreamMedium._flush()."""
        sink = self._writeable_pipe
        sink.flush()
365
366
367
class SmartSSHClientMedium(SmartClientStreamMedium):
    """A client medium using SSH."""

    def __init__(self, host, port=None, username=None, password=None,
            vendor=None):
        """Creates a client that will connect on the first use.

        :param vendor: An optional override for the ssh vendor to use. See
            bzrlib.transport.ssh for details on ssh vendors.
        """
        SmartClientStreamMedium.__init__(self)
        self._connected = False
        self._host = host
        self._password = password
        self._port = port
        self._username = username
        self._read_from = None
        self._ssh_connection = None
        self._vendor = vendor
        self._write_to = None

    def _accept_bytes(self, bytes):
        """See SmartClientStreamMedium.accept_bytes."""
        self._ensure_connection()
        self._write_to.write(bytes)

    def disconnect(self):
        """See SmartClientMedium.disconnect()."""
        if not self._connected:
            return
        self._read_from.close()
        self._write_to.close()
        self._ssh_connection.close()
        # Drop the closed channels so nothing uses them after disconnecting;
        # the next write reconnects via _ensure_connection.
        self._read_from = None
        self._write_to = None
        self._ssh_connection = None
        self._connected = False

    def _ensure_connection(self):
        """Connect this medium if not already connected.

        Runs "bzr serve --inet --directory=/ --allow-writes" on the remote
        host; the remote executable can be overridden with the
        BZR_REMOTE_PATH environment variable.
        """
        if self._connected:
            return
        executable = os.environ.get('BZR_REMOTE_PATH', 'bzr')
        if self._vendor is None:
            # NOTE: relies on the module-level import of bzrlib.transport.ssh,
            # which is skipped when paramiko is not present.
            vendor = ssh._get_ssh_vendor()
        else:
            vendor = self._vendor
        self._ssh_connection = vendor.connect_ssh(self._username,
                self._password, self._host, self._port,
                command=[executable, 'serve', '--inet', '--directory=/',
                         '--allow-writes'])
        self._read_from, self._write_to = \
            self._ssh_connection.get_filelike_channels()
        self._connected = True

    def _flush(self):
        """See SmartClientStreamMedium._flush().

        When the medium is not connected there is no channel to flush, so
        this does nothing (like SmartTCPClientMedium._flush).
        """
        if not self._connected:
            return
        self._write_to.flush()

    def _read_bytes(self, count):
        """See SmartClientStreamMedium.read_bytes.

        :raises MediumNotConnected: if the medium has not connected yet.
        """
        if not self._connected:
            raise errors.MediumNotConnected(self)
        return self._read_from.read(count)
428
429
430
class SmartTCPClientMedium(SmartClientStreamMedium):
    """A client medium using TCP."""

    def __init__(self, host, port):
        """Creates a client that will connect on the first use.

        :param host: host to connect to.
        :param port: port to connect to; any value int() accepts (an int or
            a numeric string).
        """
        SmartClientStreamMedium.__init__(self)
        self._connected = False
        self._host = host
        self._port = port
        self._socket = None

    def _accept_bytes(self, bytes):
        """See SmartClientMedium.accept_bytes."""
        self._ensure_connection()
        self._socket.sendall(bytes)

    def disconnect(self):
        """See SmartClientMedium.disconnect()."""
        if not self._connected:
            return
        self._socket.close()
        self._socket = None
        self._connected = False

    def _ensure_connection(self):
        """Connect this medium if not already connected.

        :raises ConnectionError: if the TCP connection cannot be established.
        """
        if self._connected:
            return
        # Normalise the port up front: it is used both for connecting and in
        # the error message, whose %d needs a number (the port may be given
        # as a string).
        port = int(self._port)
        sock = socket.socket()
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        result = sock.connect_ex((self._host, port))
        if result:
            # Don't leak the unconnected socket.
            sock.close()
            raise errors.ConnectionError("failed to connect to %s:%d: %s" %
                    (self._host, port, os.strerror(result)))
        self._socket = sock
        self._connected = True

    def _flush(self):
        """See SmartClientStreamMedium._flush().

        For TCP we do no flushing. We may want to turn off TCP_NODELAY and
        add a means to do a flush, but that can be done in the future.
        """

    def _read_bytes(self, count):
        """See SmartClientMedium.read_bytes.

        Returns at most count bytes (socket.recv semantics).

        :raises MediumNotConnected: if the medium has not connected yet.
        """
        if not self._connected:
            raise errors.MediumNotConnected(self)
        return self._socket.recv(count)
478
2400.1.6 by Andrew Bennetts
Cosmetic changes to minimise the difference between this branch and the hpss branch.
479
480
class SmartClientStreamMediumRequest(SmartClientMediumRequest):
    """A SmartClientMediumRequest that works with a SmartClientStreamMedium."""

    def __init__(self, medium):
        """Create a request and register it as medium's current request.

        :raises TooManyConcurrentRequests: if medium already has a request
            that has not finished reading.
        """
        SmartClientMediumRequest.__init__(self, medium)
        # check that we are safe concurrency wise. If some streams start
        # allowing concurrent requests - i.e. via multiplexing - then this
        # check should be moved to SmartClientStreamMedium.get_request,
        # and the setting/unsetting of _current_request likewise moved into
        # that class : but it's unneeded overhead for now. RBC 20060922
        if self._medium._current_request is not None:
            raise errors.TooManyConcurrentRequests(self._medium)
        self._medium._current_request = self

    def _accept_bytes(self, bytes):
        """See SmartClientMediumRequest._accept_bytes.

        This forwards to self._medium._accept_bytes because we are operating
        on the medium's stream.
        """
        self._medium._accept_bytes(bytes)

    def _finished_reading(self):
        """See SmartClientMediumRequest._finished_reading.

        This clears the _current_request on self._medium to allow a new
        request to be created.
        """
        # Internal invariant: only the registered request can finish.
        assert self._medium._current_request is self
        self._medium._current_request = None

    def _finished_writing(self):
        """See SmartClientMediumRequest._finished_writing.

        This invokes self._medium._flush to ensure all bytes are transmitted.
        """
        self._medium._flush()

    def _read_bytes(self, count):
        """See SmartClientMediumRequest._read_bytes.

        This forwards to self._medium._read_bytes because we are operating
        on the medium's stream.
        """
        return self._medium._read_bytes(count)
525
526