~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/transport/smart/medium.py

Start splitting bzrlib/transport/smart.py into a package.

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# Copyright (C) 2006 Canonical Ltd
 
2
#
 
3
# This program is free software; you can redistribute it and/or modify
 
4
# it under the terms of the GNU General Public License as published by
 
5
# the Free Software Foundation; either version 2 of the License, or
 
6
# (at your option) any later version.
 
7
#
 
8
# This program is distributed in the hope that it will be useful,
 
9
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
10
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
11
# GNU General Public License for more details.
 
12
#
 
13
# You should have received a copy of the GNU General Public License
 
14
# along with this program; if not, write to the Free Software
 
15
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 
16
 
 
17
import os
 
18
import socket
 
19
 
 
20
from bzrlib import errors
 
21
try:
 
22
    from bzrlib.transport import ssh
 
23
except errors.ParamikoNotPresent:
 
24
    # no paramiko.  SmartSSHClientMedium will break.
 
25
    pass
 
26
 
 
27
class SmartServerStreamMedium(object):
 
28
    """Handles smart commands coming over a stream.
 
29
 
 
30
    The stream may be a pipe connected to sshd, or a tcp socket, or an
 
31
    in-process fifo for testing.
 
32
 
 
33
    One instance is created for each connected client; it can serve multiple
 
34
    requests in the lifetime of the connection.
 
35
 
 
36
    The server passes requests through to an underlying backing transport, 
 
37
    which will typically be a LocalTransport looking at the server's filesystem.
 
38
    """
 
39
 
 
40
    def __init__(self, backing_transport):
 
41
        """Construct new server.
 
42
 
 
43
        :param backing_transport: Transport for the directory served.
 
44
        """
 
45
        # backing_transport could be passed to serve instead of __init__
 
46
        self.backing_transport = backing_transport
 
47
        self.finished = False
 
48
 
 
49
    def serve(self):
 
50
        """Serve requests until the client disconnects."""
 
51
        # Keep a reference to stderr because the sys module's globals get set to
 
52
        # None during interpreter shutdown.
 
53
        from sys import stderr
 
54
        from bzrlib.transport import _smart
 
55
        try:
 
56
            while not self.finished:
 
57
                protocol = _smart.SmartServerRequestProtocolOne(
 
58
                    self.backing_transport, self._write_out)
 
59
                self._serve_one_request(protocol)
 
60
        except Exception, e:
 
61
            stderr.write("%s terminating on exception %s\n" % (self, e))
 
62
            raise
 
63
 
 
64
    def _serve_one_request(self, protocol):
 
65
        """Read one request from input, process, send back a response.
 
66
        
 
67
        :param protocol: a SmartServerRequestProtocol.
 
68
        """
 
69
        try:
 
70
            self._serve_one_request_unguarded(protocol)
 
71
        except KeyboardInterrupt:
 
72
            raise
 
73
        except Exception, e:
 
74
            self.terminate_due_to_error()
 
75
 
 
76
    def terminate_due_to_error(self):
 
77
        """Called when an unhandled exception from the protocol occurs."""
 
78
        raise NotImplementedError(self.terminate_due_to_error)
 
79
 
 
80
 
 
81
class SmartServerSocketStreamMedium(SmartServerStreamMedium):
 
82
 
 
83
    def __init__(self, sock, backing_transport):
 
84
        """Constructor.
 
85
 
 
86
        :param sock: the socket the server will read from.  It will be put
 
87
            into blocking mode.
 
88
        """
 
89
        SmartServerStreamMedium.__init__(self, backing_transport)
 
90
        self.push_back = ''
 
91
        sock.setblocking(True)
 
92
        self.socket = sock
 
93
 
 
94
    def _serve_one_request_unguarded(self, protocol):
 
95
        while protocol.next_read_size():
 
96
            if self.push_back:
 
97
                protocol.accept_bytes(self.push_back)
 
98
                self.push_back = ''
 
99
            else:
 
100
                bytes = self.socket.recv(4096)
 
101
                if bytes == '':
 
102
                    self.finished = True
 
103
                    return
 
104
                protocol.accept_bytes(bytes)
 
105
        
 
106
        self.push_back = protocol.excess_buffer
 
107
    
 
108
    def terminate_due_to_error(self):
 
109
        """Called when an unhandled exception from the protocol occurs."""
 
110
        # TODO: This should log to a server log file, but no such thing
 
111
        # exists yet.  Andrew Bennetts 2006-09-29.
 
112
        self.socket.close()
 
113
        self.finished = True
 
114
 
 
115
    def _write_out(self, bytes):
 
116
        self.socket.sendall(bytes)
 
117
 
 
118
 
 
119
class SmartServerPipeStreamMedium(SmartServerStreamMedium):
 
120
 
 
121
    def __init__(self, in_file, out_file, backing_transport):
 
122
        """Construct new server.
 
123
 
 
124
        :param in_file: Python file from which requests can be read.
 
125
        :param out_file: Python file to write responses.
 
126
        :param backing_transport: Transport for the directory served.
 
127
        """
 
128
        SmartServerStreamMedium.__init__(self, backing_transport)
 
129
        self._in = in_file
 
130
        self._out = out_file
 
131
 
 
132
    def _serve_one_request_unguarded(self, protocol):
 
133
        while True:
 
134
            bytes_to_read = protocol.next_read_size()
 
135
            if bytes_to_read == 0:
 
136
                # Finished serving this request.
 
137
                self._out.flush()
 
138
                return
 
139
            bytes = self._in.read(bytes_to_read)
 
140
            if bytes == '':
 
141
                # Connection has been closed.
 
142
                self.finished = True
 
143
                self._out.flush()
 
144
                return
 
145
            protocol.accept_bytes(bytes)
 
146
 
 
147
    def terminate_due_to_error(self):
 
148
        # TODO: This should log to a server log file, but no such thing
 
149
        # exists yet.  Andrew Bennetts 2006-09-29.
 
150
        self._out.close()
 
151
        self.finished = True
 
152
 
 
153
    def _write_out(self, bytes):
 
154
        self._out.write(bytes)
 
155
 
 
156
 
 
157
class SmartClientMediumRequest(object):
 
158
    """A request on a SmartClientMedium.
 
159
 
 
160
    Each request allows bytes to be provided to it via accept_bytes, and then
 
161
    the response bytes to be read via read_bytes.
 
162
 
 
163
    For instance:
 
164
    request.accept_bytes('123')
 
165
    request.finished_writing()
 
166
    result = request.read_bytes(3)
 
167
    request.finished_reading()
 
168
 
 
169
    It is up to the individual SmartClientMedium whether multiple concurrent
 
170
    requests can exist. See SmartClientMedium.get_request to obtain instances 
 
171
    of SmartClientMediumRequest, and the concrete Medium you are using for 
 
172
    details on concurrency and pipelining.
 
173
    """
 
174
 
 
175
    def __init__(self, medium):
 
176
        """Construct a SmartClientMediumRequest for the medium medium."""
 
177
        self._medium = medium
 
178
        # we track state by constants - we may want to use the same
 
179
        # pattern as BodyReader if it gets more complex.
 
180
        # valid states are: "writing", "reading", "done"
 
181
        self._state = "writing"
 
182
 
 
183
    def accept_bytes(self, bytes):
 
184
        """Accept bytes for inclusion in this request.
 
185
 
 
186
        This method may not be be called after finished_writing() has been
 
187
        called.  It depends upon the Medium whether or not the bytes will be
 
188
        immediately transmitted. Message based Mediums will tend to buffer the
 
189
        bytes until finished_writing() is called.
 
190
 
 
191
        :param bytes: A bytestring.
 
192
        """
 
193
        if self._state != "writing":
 
194
            raise errors.WritingCompleted(self)
 
195
        self._accept_bytes(bytes)
 
196
 
 
197
    def _accept_bytes(self, bytes):
 
198
        """Helper for accept_bytes.
 
199
 
 
200
        Accept_bytes checks the state of the request to determing if bytes
 
201
        should be accepted. After that it hands off to _accept_bytes to do the
 
202
        actual acceptance.
 
203
        """
 
204
        raise NotImplementedError(self._accept_bytes)
 
205
 
 
206
    def finished_reading(self):
 
207
        """Inform the request that all desired data has been read.
 
208
 
 
209
        This will remove the request from the pipeline for its medium (if the
 
210
        medium supports pipelining) and any further calls to methods on the
 
211
        request will raise ReadingCompleted.
 
212
        """
 
213
        if self._state == "writing":
 
214
            raise errors.WritingNotComplete(self)
 
215
        if self._state != "reading":
 
216
            raise errors.ReadingCompleted(self)
 
217
        self._state = "done"
 
218
        self._finished_reading()
 
219
 
 
220
    def _finished_reading(self):
 
221
        """Helper for finished_reading.
 
222
 
 
223
        finished_reading checks the state of the request to determine if 
 
224
        finished_reading is allowed, and if it is hands off to _finished_reading
 
225
        to perform the action.
 
226
        """
 
227
        raise NotImplementedError(self._finished_reading)
 
228
 
 
229
    def finished_writing(self):
 
230
        """Finish the writing phase of this request.
 
231
 
 
232
        This will flush all pending data for this request along the medium.
 
233
        After calling finished_writing, you may not call accept_bytes anymore.
 
234
        """
 
235
        if self._state != "writing":
 
236
            raise errors.WritingCompleted(self)
 
237
        self._state = "reading"
 
238
        self._finished_writing()
 
239
 
 
240
    def _finished_writing(self):
 
241
        """Helper for finished_writing.
 
242
 
 
243
        finished_writing checks the state of the request to determine if 
 
244
        finished_writing is allowed, and if it is hands off to _finished_writing
 
245
        to perform the action.
 
246
        """
 
247
        raise NotImplementedError(self._finished_writing)
 
248
 
 
249
    def read_bytes(self, count):
 
250
        """Read bytes from this requests response.
 
251
 
 
252
        This method will block and wait for count bytes to be read. It may not
 
253
        be invoked until finished_writing() has been called - this is to ensure
 
254
        a message-based approach to requests, for compatability with message
 
255
        based mediums like HTTP.
 
256
        """
 
257
        if self._state == "writing":
 
258
            raise errors.WritingNotComplete(self)
 
259
        if self._state != "reading":
 
260
            raise errors.ReadingCompleted(self)
 
261
        return self._read_bytes(count)
 
262
 
 
263
    def _read_bytes(self, count):
 
264
        """Helper for read_bytes.
 
265
 
 
266
        read_bytes checks the state of the request to determing if bytes
 
267
        should be read. After that it hands off to _read_bytes to do the
 
268
        actual read.
 
269
        """
 
270
        raise NotImplementedError(self._read_bytes)
 
271
 
 
272
 
 
273
class SmartClientMedium(object):
 
274
    """Smart client is a medium for sending smart protocol requests over."""
 
275
 
 
276
    def disconnect(self):
 
277
        """If this medium maintains a persistent connection, close it.
 
278
        
 
279
        The default implementation does nothing.
 
280
        """
 
281
        
 
282
 
 
283
class SmartClientStreamMedium(SmartClientMedium):
 
284
    """Stream based medium common class.
 
285
 
 
286
    SmartClientStreamMediums operate on a stream. All subclasses use a common
 
287
    SmartClientStreamMediumRequest for their requests, and should implement
 
288
    _accept_bytes and _read_bytes to allow the request objects to send and
 
289
    receive bytes.
 
290
    """
 
291
 
 
292
    def __init__(self):
 
293
        self._current_request = None
 
294
 
 
295
    def accept_bytes(self, bytes):
 
296
        self._accept_bytes(bytes)
 
297
 
 
298
    def __del__(self):
 
299
        """The SmartClientStreamMedium knows how to close the stream when it is
 
300
        finished with it.
 
301
        """
 
302
        self.disconnect()
 
303
 
 
304
    def _flush(self):
 
305
        """Flush the output stream.
 
306
        
 
307
        This method is used by the SmartClientStreamMediumRequest to ensure that
 
308
        all data for a request is sent, to avoid long timeouts or deadlocks.
 
309
        """
 
310
        raise NotImplementedError(self._flush)
 
311
 
 
312
    def get_request(self):
 
313
        """See SmartClientMedium.get_request().
 
314
 
 
315
        SmartClientStreamMedium always returns a SmartClientStreamMediumRequest
 
316
        for get_request.
 
317
        """
 
318
        return SmartClientStreamMediumRequest(self)
 
319
 
 
320
    def read_bytes(self, count):
 
321
        return self._read_bytes(count)
 
322
 
 
323
 
 
324
class SmartSimplePipesClientMedium(SmartClientStreamMedium):
 
325
    """A client medium using simple pipes.
 
326
    
 
327
    This client does not manage the pipes: it assumes they will always be open.
 
328
    """
 
329
 
 
330
    def __init__(self, readable_pipe, writeable_pipe):
 
331
        SmartClientStreamMedium.__init__(self)
 
332
        self._readable_pipe = readable_pipe
 
333
        self._writeable_pipe = writeable_pipe
 
334
 
 
335
    def _accept_bytes(self, bytes):
 
336
        """See SmartClientStreamMedium.accept_bytes."""
 
337
        self._writeable_pipe.write(bytes)
 
338
 
 
339
    def _flush(self):
 
340
        """See SmartClientStreamMedium._flush()."""
 
341
        self._writeable_pipe.flush()
 
342
 
 
343
    def _read_bytes(self, count):
 
344
        """See SmartClientStreamMedium._read_bytes."""
 
345
        return self._readable_pipe.read(count)
 
346
 
 
347
 
 
348
class SmartSSHClientMedium(SmartClientStreamMedium):
 
349
    """A client medium using SSH."""
 
350
    
 
351
    def __init__(self, host, port=None, username=None, password=None,
 
352
            vendor=None):
 
353
        """Creates a client that will connect on the first use.
 
354
        
 
355
        :param vendor: An optional override for the ssh vendor to use. See
 
356
            bzrlib.transport.ssh for details on ssh vendors.
 
357
        """
 
358
        SmartClientStreamMedium.__init__(self)
 
359
        self._connected = False
 
360
        self._host = host
 
361
        self._password = password
 
362
        self._port = port
 
363
        self._username = username
 
364
        self._read_from = None
 
365
        self._ssh_connection = None
 
366
        self._vendor = vendor
 
367
        self._write_to = None
 
368
 
 
369
    def _accept_bytes(self, bytes):
 
370
        """See SmartClientStreamMedium.accept_bytes."""
 
371
        self._ensure_connection()
 
372
        self._write_to.write(bytes)
 
373
 
 
374
    def disconnect(self):
 
375
        """See SmartClientMedium.disconnect()."""
 
376
        if not self._connected:
 
377
            return
 
378
        self._read_from.close()
 
379
        self._write_to.close()
 
380
        self._ssh_connection.close()
 
381
        self._connected = False
 
382
 
 
383
    def _ensure_connection(self):
 
384
        """Connect this medium if not already connected."""
 
385
        if self._connected:
 
386
            return
 
387
        executable = os.environ.get('BZR_REMOTE_PATH', 'bzr')
 
388
        if self._vendor is None:
 
389
            vendor = ssh._get_ssh_vendor()
 
390
        else:
 
391
            vendor = self._vendor
 
392
        self._ssh_connection = vendor.connect_ssh(self._username,
 
393
                self._password, self._host, self._port,
 
394
                command=[executable, 'serve', '--inet', '--directory=/',
 
395
                         '--allow-writes'])
 
396
        self._read_from, self._write_to = \
 
397
            self._ssh_connection.get_filelike_channels()
 
398
        self._connected = True
 
399
 
 
400
    def _flush(self):
 
401
        """See SmartClientStreamMedium._flush()."""
 
402
        self._write_to.flush()
 
403
 
 
404
    def _read_bytes(self, count):
 
405
        """See SmartClientStreamMedium.read_bytes."""
 
406
        if not self._connected:
 
407
            raise errors.MediumNotConnected(self)
 
408
        return self._read_from.read(count)
 
409
 
 
410
 
 
411
class SmartTCPClientMedium(SmartClientStreamMedium):
 
412
    """A client medium using TCP."""
 
413
    
 
414
    def __init__(self, host, port):
 
415
        """Creates a client that will connect on the first use."""
 
416
        SmartClientStreamMedium.__init__(self)
 
417
        self._connected = False
 
418
        self._host = host
 
419
        self._port = port
 
420
        self._socket = None
 
421
 
 
422
    def _accept_bytes(self, bytes):
 
423
        """See SmartClientMedium.accept_bytes."""
 
424
        self._ensure_connection()
 
425
        self._socket.sendall(bytes)
 
426
 
 
427
    def disconnect(self):
 
428
        """See SmartClientMedium.disconnect()."""
 
429
        if not self._connected:
 
430
            return
 
431
        self._socket.close()
 
432
        self._socket = None
 
433
        self._connected = False
 
434
 
 
435
    def _ensure_connection(self):
 
436
        """Connect this medium if not already connected."""
 
437
        if self._connected:
 
438
            return
 
439
        self._socket = socket.socket()
 
440
        self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
 
441
        result = self._socket.connect_ex((self._host, int(self._port)))
 
442
        if result:
 
443
            raise errors.ConnectionError("failed to connect to %s:%d: %s" %
 
444
                    (self._host, self._port, os.strerror(result)))
 
445
        self._connected = True
 
446
 
 
447
    def _flush(self):
 
448
        """See SmartClientStreamMedium._flush().
 
449
        
 
450
        For TCP we do no flushing. We may want to turn off TCP_NODELAY and 
 
451
        add a means to do a flush, but that can be done in the future.
 
452
        """
 
453
 
 
454
    def _read_bytes(self, count):
 
455
        """See SmartClientMedium.read_bytes."""
 
456
        if not self._connected:
 
457
            raise errors.MediumNotConnected(self)
 
458
        return self._socket.recv(count)
 
459
 
 
460
 
 
461
class SmartClientStreamMediumRequest(SmartClientMediumRequest):
 
462
    """A SmartClientMediumRequest that works with an SmartClientStreamMedium."""
 
463
 
 
464
    def __init__(self, medium):
 
465
        SmartClientMediumRequest.__init__(self, medium)
 
466
        # check that we are safe concurrency wise. If some streams start
 
467
        # allowing concurrent requests - i.e. via multiplexing - then this
 
468
        # assert should be moved to SmartClientStreamMedium.get_request,
 
469
        # and the setting/unsetting of _current_request likewise moved into
 
470
        # that class : but its unneeded overhead for now. RBC 20060922
 
471
        if self._medium._current_request is not None:
 
472
            raise errors.TooManyConcurrentRequests(self._medium)
 
473
        self._medium._current_request = self
 
474
 
 
475
    def _accept_bytes(self, bytes):
 
476
        """See SmartClientMediumRequest._accept_bytes.
 
477
        
 
478
        This forwards to self._medium._accept_bytes because we are operating
 
479
        on the mediums stream.
 
480
        """
 
481
        self._medium._accept_bytes(bytes)
 
482
 
 
483
    def _finished_reading(self):
 
484
        """See SmartClientMediumRequest._finished_reading.
 
485
 
 
486
        This clears the _current_request on self._medium to allow a new 
 
487
        request to be created.
 
488
        """
 
489
        assert self._medium._current_request is self
 
490
        self._medium._current_request = None
 
491
        
 
492
    def _finished_writing(self):
 
493
        """See SmartClientMediumRequest._finished_writing.
 
494
 
 
495
        This invokes self._medium._flush to ensure all bytes are transmitted.
 
496
        """
 
497
        self._medium._flush()
 
498
 
 
499
    def _read_bytes(self, count):
 
500
        """See SmartClientMediumRequest._read_bytes.
 
501
        
 
502
        This forwards to self._medium._read_bytes because we are operating
 
503
        on the mediums stream.
 
504
        """
 
505
        return self._medium._read_bytes(count)
 
506
 
 
507