~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/transport/smart.py

  • Committer: Aaron Bentley
  • Date: 2007-02-06 14:52:16 UTC
  • mfrom: (2266 +trunk)
  • mto: This revision was merged to the branch mainline in revision 2268.
  • Revision ID: abentley@panoramicfeedback.com-20070206145216-fcpi8o3ufvuzwbp9
Merge bzr.dev

Show diffs side-by-side

added added

removed removed

Lines of Context:
22
22
 
23
23
  SEP := '\001'
24
24
    Fields are separated by Ctrl-A.
25
 
  BULK_DATA := CHUNK+ TRAILER
 
25
  BULK_DATA := CHUNK TRAILER
26
26
    Chunks can be repeated as many times as necessary.
27
27
  CHUNK := CHUNK_LEN CHUNK_BODY
28
28
  CHUNK_LEN := DIGIT+ NEWLINE
46
46
URLs that include ~ should probably be passed across to the server verbatim
47
47
and the server can expand them.  This will proably not be meaningful when 
48
48
limited to a directory?
 
49
 
 
50
At the bottom level socket, pipes, HTTP server.  For sockets, we have the idea
 
51
that you have multiple requests and get a read error because the other side did
 
52
shutdown.  For pipes we have read pipe which will have a zero read which marks
 
53
end-of-file.  For HTTP server environment there is not end-of-stream because
 
54
each request coming into the server is independent.
 
55
 
 
56
So we need a wrapper around pipes and sockets to seperate out requests from
 
57
substrate and this will give us a single model which is consist for HTTP,
 
58
sockets and pipes.
 
59
 
 
60
Server-side
 
61
-----------
 
62
 
 
63
 MEDIUM  (factory for protocol, reads bytes & pushes to protocol,
 
64
          uses protocol to detect end-of-request, sends written
 
65
          bytes to client) e.g. socket, pipe, HTTP request handler.
 
66
  ^
 
67
  | bytes.
 
68
  v
 
69
 
 
70
PROTOCOL  (serialization, deserialization)  accepts bytes for one
 
71
          request, decodes according to internal state, pushes
 
72
          structured data to handler.  accepts structured data from
 
73
          handler and encodes and writes to the medium.  factory for
 
74
          handler.
 
75
  ^
 
76
  | structured data
 
77
  v
 
78
 
 
79
HANDLER   (domain logic) accepts structured data, operates state
 
80
          machine until the request can be satisfied,
 
81
          sends structured data to the protocol.
 
82
 
 
83
 
 
84
Client-side
 
85
-----------
 
86
 
 
87
 CLIENT             domain logic, accepts domain requests, generated structured
 
88
                    data, reads structured data from responses and turns into
 
89
                    domain data.  Sends structured data to the protocol.
 
90
                    Operates state machines until the request can be delivered
 
91
                    (e.g. reading from a bundle generated in bzrlib to deliver a
 
92
                    complete request).
 
93
 
 
94
                    Possibly this should just be RemoteBzrDir, RemoteTransport,
 
95
                    ...
 
96
  ^
 
97
  | structured data
 
98
  v
 
99
 
 
100
PROTOCOL  (serialization, deserialization)  accepts structured data for one
 
101
          request, encodes and writes to the medium.  Reads bytes from the
 
102
          medium, decodes and allows the client to read structured data.
 
103
  ^
 
104
  | bytes.
 
105
  v
 
106
 
 
107
 MEDIUM  (accepts bytes from the protocol & delivers to the remote server.
 
108
          Allows the potocol to read bytes e.g. socket, pipe, HTTP request.
49
109
"""
50
110
 
51
111
 
133
193
# TODO: SmartBzrDir class, proxying all Branch etc methods across to another
134
194
# branch doing file-level operations.
135
195
#
136
 
# TODO: jam 20060915 _decode_tuple is acting directly on input over
137
 
#       the socket, and it assumes everything is UTF8 sections separated
138
 
#       by \001. Which means a request like '\002' Will abort the connection
139
 
#       because of a UnicodeDecodeError. It does look like invalid data will
140
 
#       kill the SmartStreamServer, but only with an abort + exception, and 
141
 
#       the overall server shouldn't die.
142
196
 
143
197
from cStringIO import StringIO
144
 
import errno
145
198
import os
146
199
import socket
147
 
import sys
148
200
import tempfile
149
201
import threading
150
202
import urllib
159
211
    urlutils,
160
212
    )
161
213
from bzrlib.bundle.serializer import write_bundle
162
 
from bzrlib.trace import mutter
163
 
from bzrlib.transport import local
 
214
try:
 
215
    from bzrlib.transport import ssh
 
216
except errors.ParamikoNotPresent:
 
217
    # no paramiko.  SmartSSHClientMedium will break.
 
218
    pass
164
219
 
165
220
# must do this otherwise urllib can't parse the urls properly :(
166
 
for scheme in ['ssh', 'bzr', 'bzr+loopback', 'bzr+ssh']:
 
221
for scheme in ['ssh', 'bzr', 'bzr+loopback', 'bzr+ssh', 'bzr+http']:
167
222
    transport.register_urlparse_netloc_protocol(scheme)
168
223
del scheme
169
224
 
178
233
        return None
179
234
    if req_line[-1] != '\n':
180
235
        raise errors.SmartProtocolError("request %r not terminated" % req_line)
181
 
    return tuple((a.decode('utf-8') for a in req_line[:-1].split('\x01')))
 
236
    return tuple(req_line[:-1].split('\x01'))
182
237
 
183
238
 
184
239
def _encode_tuple(args):
185
240
    """Encode the tuple args to a bytestream."""
186
 
    return '\x01'.join((a.encode('utf-8') for a in args)) + '\n'
 
241
    return '\x01'.join(args) + '\n'
187
242
 
188
243
 
189
244
class SmartProtocolBase(object):
190
245
    """Methods common to client and server"""
191
246
 
192
 
    def _send_bulk_data(self, body):
193
 
        """Send chunked body data"""
194
 
        assert isinstance(body, str)
195
 
        bytes = ''.join(('%d\n' % len(body), body, 'done\n'))
196
 
        self._write_and_flush(bytes)
197
 
 
198
 
    # TODO: this only actually accomodates a single block; possibly should support
199
 
    # multiple chunks?
200
 
    def _recv_bulk(self):
201
 
        chunk_len = self._in.readline()
202
 
        try:
203
 
            chunk_len = int(chunk_len)
204
 
        except ValueError:
205
 
            raise errors.SmartProtocolError("bad chunk length line %r" % chunk_len)
206
 
        bulk = self._in.read(chunk_len)
207
 
        if len(bulk) != chunk_len:
208
 
            raise errors.SmartProtocolError("short read fetching bulk data chunk")
209
 
        self._recv_trailer()
210
 
        return bulk
211
 
 
212
 
    def _recv_tuple(self):
213
 
        return _recv_tuple(self._in)
214
 
 
215
 
    def _recv_trailer(self):
216
 
        resp = self._recv_tuple()
217
 
        if resp == ('done', ):
218
 
            return
219
 
        else:
220
 
            self._translate_error(resp)
 
247
    # TODO: this only actually accomodates a single block; possibly should
 
248
    # support multiple chunks?
 
249
    def _encode_bulk_data(self, body):
 
250
        """Encode body as a bulk data chunk."""
 
251
        return ''.join(('%d\n' % len(body), body, 'done\n'))
221
252
 
222
253
    def _serialise_offsets(self, offsets):
223
254
        """Serialise a readv offset list."""
225
256
        for start, length in offsets:
226
257
            txt.append('%d,%d' % (start, length))
227
258
        return '\n'.join(txt)
228
 
 
229
 
    def _write_and_flush(self, bytes):
230
 
        """Write bytes to self._out and flush it."""
231
 
        # XXX: this will be inefficient.  Just ask Robert.
232
 
        self._out.write(bytes)
233
 
        self._out.flush()
234
 
 
235
 
 
236
 
class SmartStreamServer(SmartProtocolBase):
 
259
        
 
260
 
 
261
class SmartServerRequestProtocolOne(SmartProtocolBase):
 
262
    """Server-side encoding and decoding logic for smart version 1."""
 
263
    
 
264
    def __init__(self, backing_transport, write_func):
 
265
        self._backing_transport = backing_transport
 
266
        self.excess_buffer = ''
 
267
        self._finished = False
 
268
        self.in_buffer = ''
 
269
        self.has_dispatched = False
 
270
        self.request = None
 
271
        self._body_decoder = None
 
272
        self._write_func = write_func
 
273
 
 
274
    def accept_bytes(self, bytes):
 
275
        """Take bytes, and advance the internal state machine appropriately.
 
276
        
 
277
        :param bytes: must be a byte string
 
278
        """
 
279
        assert isinstance(bytes, str)
 
280
        self.in_buffer += bytes
 
281
        if not self.has_dispatched:
 
282
            if '\n' not in self.in_buffer:
 
283
                # no command line yet
 
284
                return
 
285
            self.has_dispatched = True
 
286
            try:
 
287
                first_line, self.in_buffer = self.in_buffer.split('\n', 1)
 
288
                first_line += '\n'
 
289
                req_args = _decode_tuple(first_line)
 
290
                self.request = SmartServerRequestHandler(
 
291
                    self._backing_transport)
 
292
                self.request.dispatch_command(req_args[0], req_args[1:])
 
293
                if self.request.finished_reading:
 
294
                    # trivial request
 
295
                    self.excess_buffer = self.in_buffer
 
296
                    self.in_buffer = ''
 
297
                    self._send_response(self.request.response.args,
 
298
                        self.request.response.body)
 
299
            except KeyboardInterrupt:
 
300
                raise
 
301
            except Exception, exception:
 
302
                # everything else: pass to client, flush, and quit
 
303
                self._send_response(('error', str(exception)))
 
304
                return
 
305
 
 
306
        if self.has_dispatched:
 
307
            if self._finished:
 
308
                # nothing to do.XXX: this routine should be a single state 
 
309
                # machine too.
 
310
                self.excess_buffer += self.in_buffer
 
311
                self.in_buffer = ''
 
312
                return
 
313
            if self._body_decoder is None:
 
314
                self._body_decoder = LengthPrefixedBodyDecoder()
 
315
            self._body_decoder.accept_bytes(self.in_buffer)
 
316
            self.in_buffer = self._body_decoder.unused_data
 
317
            body_data = self._body_decoder.read_pending_data()
 
318
            self.request.accept_body(body_data)
 
319
            if self._body_decoder.finished_reading:
 
320
                self.request.end_of_body()
 
321
                assert self.request.finished_reading, \
 
322
                    "no more body, request not finished"
 
323
            if self.request.response is not None:
 
324
                self._send_response(self.request.response.args,
 
325
                    self.request.response.body)
 
326
                self.excess_buffer = self.in_buffer
 
327
                self.in_buffer = ''
 
328
            else:
 
329
                assert not self.request.finished_reading, \
 
330
                    "no response and we have finished reading."
 
331
 
 
332
    def _send_response(self, args, body=None):
 
333
        """Send a smart server response down the output stream."""
 
334
        assert not self._finished, 'response already sent'
 
335
        self._finished = True
 
336
        self._write_func(_encode_tuple(args))
 
337
        if body is not None:
 
338
            assert isinstance(body, str), 'body must be a str'
 
339
            bytes = self._encode_bulk_data(body)
 
340
            self._write_func(bytes)
 
341
 
 
342
    def next_read_size(self):
 
343
        if self._finished:
 
344
            return 0
 
345
        if self._body_decoder is None:
 
346
            return 1
 
347
        else:
 
348
            return self._body_decoder.next_read_size()
 
349
 
 
350
 
 
351
class LengthPrefixedBodyDecoder(object):
 
352
    """Decodes the length-prefixed bulk data."""
 
353
    
 
354
    def __init__(self):
 
355
        self.bytes_left = None
 
356
        self.finished_reading = False
 
357
        self.unused_data = ''
 
358
        self.state_accept = self._state_accept_expecting_length
 
359
        self.state_read = self._state_read_no_data
 
360
        self._in_buffer = ''
 
361
        self._trailer_buffer = ''
 
362
    
 
363
    def accept_bytes(self, bytes):
 
364
        """Decode as much of bytes as possible.
 
365
 
 
366
        If 'bytes' contains too much data it will be appended to
 
367
        self.unused_data.
 
368
 
 
369
        finished_reading will be set when no more data is required.  Further
 
370
        data will be appended to self.unused_data.
 
371
        """
 
372
        # accept_bytes is allowed to change the state
 
373
        current_state = self.state_accept
 
374
        self.state_accept(bytes)
 
375
        while current_state != self.state_accept:
 
376
            current_state = self.state_accept
 
377
            self.state_accept('')
 
378
 
 
379
    def next_read_size(self):
 
380
        if self.bytes_left is not None:
 
381
            # Ideally we want to read all the remainder of the body and the
 
382
            # trailer in one go.
 
383
            return self.bytes_left + 5
 
384
        elif self.state_accept == self._state_accept_reading_trailer:
 
385
            # Just the trailer left
 
386
            return 5 - len(self._trailer_buffer)
 
387
        elif self.state_accept == self._state_accept_expecting_length:
 
388
            # There's still at least 6 bytes left ('\n' to end the length, plus
 
389
            # 'done\n').
 
390
            return 6
 
391
        else:
 
392
            # Reading excess data.  Either way, 1 byte at a time is fine.
 
393
            return 1
 
394
        
 
395
    def read_pending_data(self):
 
396
        """Return any pending data that has been decoded."""
 
397
        return self.state_read()
 
398
 
 
399
    def _state_accept_expecting_length(self, bytes):
 
400
        self._in_buffer += bytes
 
401
        pos = self._in_buffer.find('\n')
 
402
        if pos == -1:
 
403
            return
 
404
        self.bytes_left = int(self._in_buffer[:pos])
 
405
        self._in_buffer = self._in_buffer[pos+1:]
 
406
        self.bytes_left -= len(self._in_buffer)
 
407
        self.state_accept = self._state_accept_reading_body
 
408
        self.state_read = self._state_read_in_buffer
 
409
 
 
410
    def _state_accept_reading_body(self, bytes):
 
411
        self._in_buffer += bytes
 
412
        self.bytes_left -= len(bytes)
 
413
        if self.bytes_left <= 0:
 
414
            # Finished with body
 
415
            if self.bytes_left != 0:
 
416
                self._trailer_buffer = self._in_buffer[self.bytes_left:]
 
417
                self._in_buffer = self._in_buffer[:self.bytes_left]
 
418
            self.bytes_left = None
 
419
            self.state_accept = self._state_accept_reading_trailer
 
420
        
 
421
    def _state_accept_reading_trailer(self, bytes):
 
422
        self._trailer_buffer += bytes
 
423
        # TODO: what if the trailer does not match "done\n"?  Should this raise
 
424
        # a ProtocolViolation exception?
 
425
        if self._trailer_buffer.startswith('done\n'):
 
426
            self.unused_data = self._trailer_buffer[len('done\n'):]
 
427
            self.state_accept = self._state_accept_reading_unused
 
428
            self.finished_reading = True
 
429
    
 
430
    def _state_accept_reading_unused(self, bytes):
 
431
        self.unused_data += bytes
 
432
 
 
433
    def _state_read_no_data(self):
 
434
        return ''
 
435
 
 
436
    def _state_read_in_buffer(self):
 
437
        result = self._in_buffer
 
438
        self._in_buffer = ''
 
439
        return result
 
440
 
 
441
 
 
442
class SmartServerStreamMedium(object):
237
443
    """Handles smart commands coming over a stream.
238
444
 
239
445
    The stream may be a pipe connected to sshd, or a tcp socket, or an
246
452
    which will typically be a LocalTransport looking at the server's filesystem.
247
453
    """
248
454
 
 
455
    def __init__(self, backing_transport):
 
456
        """Construct new server.
 
457
 
 
458
        :param backing_transport: Transport for the directory served.
 
459
        """
 
460
        # backing_transport could be passed to serve instead of __init__
 
461
        self.backing_transport = backing_transport
 
462
        self.finished = False
 
463
 
 
464
    def serve(self):
 
465
        """Serve requests until the client disconnects."""
 
466
        # Keep a reference to stderr because the sys module's globals get set to
 
467
        # None during interpreter shutdown.
 
468
        from sys import stderr
 
469
        try:
 
470
            while not self.finished:
 
471
                protocol = SmartServerRequestProtocolOne(self.backing_transport,
 
472
                                                         self._write_out)
 
473
                self._serve_one_request(protocol)
 
474
        except Exception, e:
 
475
            stderr.write("%s terminating on exception %s\n" % (self, e))
 
476
            raise
 
477
 
 
478
    def _serve_one_request(self, protocol):
 
479
        """Read one request from input, process, send back a response.
 
480
        
 
481
        :param protocol: a SmartServerRequestProtocol.
 
482
        """
 
483
        try:
 
484
            self._serve_one_request_unguarded(protocol)
 
485
        except KeyboardInterrupt:
 
486
            raise
 
487
        except Exception, e:
 
488
            self.terminate_due_to_error()
 
489
 
 
490
    def terminate_due_to_error(self):
 
491
        """Called when an unhandled exception from the protocol occurs."""
 
492
        raise NotImplementedError(self.terminate_due_to_error)
 
493
 
 
494
 
 
495
class SmartServerSocketStreamMedium(SmartServerStreamMedium):
 
496
 
 
497
    def __init__(self, sock, backing_transport):
 
498
        """Constructor.
 
499
 
 
500
        :param sock: the socket the server will read from.  It will be put
 
501
            into blocking mode.
 
502
        """
 
503
        SmartServerStreamMedium.__init__(self, backing_transport)
 
504
        self.push_back = ''
 
505
        sock.setblocking(True)
 
506
        self.socket = sock
 
507
 
 
508
    def _serve_one_request_unguarded(self, protocol):
 
509
        while protocol.next_read_size():
 
510
            if self.push_back:
 
511
                protocol.accept_bytes(self.push_back)
 
512
                self.push_back = ''
 
513
            else:
 
514
                bytes = self.socket.recv(4096)
 
515
                if bytes == '':
 
516
                    self.finished = True
 
517
                    return
 
518
                protocol.accept_bytes(bytes)
 
519
        
 
520
        self.push_back = protocol.excess_buffer
 
521
    
 
522
    def terminate_due_to_error(self):
 
523
        """Called when an unhandled exception from the protocol occurs."""
 
524
        # TODO: This should log to a server log file, but no such thing
 
525
        # exists yet.  Andrew Bennetts 2006-09-29.
 
526
        self.socket.close()
 
527
        self.finished = True
 
528
 
 
529
    def _write_out(self, bytes):
 
530
        self.socket.sendall(bytes)
 
531
 
 
532
 
 
533
class SmartServerPipeStreamMedium(SmartServerStreamMedium):
 
534
 
249
535
    def __init__(self, in_file, out_file, backing_transport):
250
536
        """Construct new server.
251
537
 
253
539
        :param out_file: Python file to write responses.
254
540
        :param backing_transport: Transport for the directory served.
255
541
        """
 
542
        SmartServerStreamMedium.__init__(self, backing_transport)
256
543
        self._in = in_file
257
544
        self._out = out_file
258
 
        self.smart_server = SmartServer(backing_transport)
259
 
        # server can call back to us to get bulk data - this is not really
260
 
        # ideal, they should get it per request instead
261
 
        self.smart_server._recv_body = self._recv_bulk
262
 
 
263
 
    def _recv_tuple(self):
264
 
        """Read a request from the client and return as a tuple.
265
 
        
266
 
        Returns None at end of file (if the client closed the connection.)
267
 
        """
268
 
        return _recv_tuple(self._in)
269
 
 
270
 
    def _send_tuple(self, args):
271
 
        """Send response header"""
272
 
        return self._write_and_flush(_encode_tuple(args))
273
 
 
274
 
    def _send_error_and_disconnect(self, exception):
275
 
        self._send_tuple(('error', str(exception)))
276
 
        ## self._out.close()
277
 
        ## self._in.close()
278
 
 
279
 
    def _serve_one_request(self):
280
 
        """Read one request from input, process, send back a response.
281
 
        
282
 
        :return: False if the server should terminate, otherwise None.
283
 
        """
284
 
        req_args = self._recv_tuple()
285
 
        if req_args == None:
286
 
            # client closed connection
287
 
            return False  # shutdown server
288
 
        try:
289
 
            response = self.smart_server.dispatch_command(req_args[0], req_args[1:])
290
 
            self._send_tuple(response.args)
291
 
            if response.body is not None:
292
 
                self._send_bulk_data(response.body)
293
 
        except KeyboardInterrupt:
294
 
            raise
295
 
        except Exception, e:
296
 
            # everything else: pass to client, flush, and quit
297
 
            self._send_error_and_disconnect(e)
298
 
            return False
299
 
 
300
 
    def serve(self):
301
 
        """Serve requests until the client disconnects."""
302
 
        # Keep a reference to stderr because the sys module's globals get set to
303
 
        # None during interpreter shutdown.
304
 
        from sys import stderr
305
 
        try:
306
 
            while self._serve_one_request() != False:
307
 
                pass
308
 
        except Exception, e:
309
 
            stderr.write("%s terminating on exception %s\n" % (self, e))
310
 
            raise
 
545
 
 
546
    def _serve_one_request_unguarded(self, protocol):
 
547
        while True:
 
548
            bytes_to_read = protocol.next_read_size()
 
549
            if bytes_to_read == 0:
 
550
                # Finished serving this request.
 
551
                self._out.flush()
 
552
                return
 
553
            bytes = self._in.read(bytes_to_read)
 
554
            if bytes == '':
 
555
                # Connection has been closed.
 
556
                self.finished = True
 
557
                self._out.flush()
 
558
                return
 
559
            protocol.accept_bytes(bytes)
 
560
 
 
561
    def terminate_due_to_error(self):
 
562
        # TODO: This should log to a server log file, but no such thing
 
563
        # exists yet.  Andrew Bennetts 2006-09-29.
 
564
        self._out.close()
 
565
        self.finished = True
 
566
 
 
567
    def _write_out(self, bytes):
 
568
        self._out.write(bytes)
311
569
 
312
570
 
313
571
class SmartServerResponse(object):
314
 
    """Response generated by SmartServer."""
 
572
    """Response generated by SmartServerRequestHandler."""
315
573
 
316
574
    def __init__(self, args, body=None):
317
575
        self.args = args
318
576
        self.body = body
319
577
 
320
 
# XXX: TODO: Create a SmartServerRequest which will take the responsibility
 
578
# XXX: TODO: Create a SmartServerRequestHandler which will take the responsibility
321
579
# for delivering the data for a request. This could be done with as the
322
580
# StreamServer, though that would create conflation between request and response
323
581
# which may be undesirable.
324
582
 
325
583
 
326
 
class SmartServer(object):
 
584
class SmartServerRequestHandler(object):
327
585
    """Protocol logic for smart server.
328
586
    
329
587
    This doesn't handle serialization at all, it just processes requests and
330
588
    creates responses.
331
589
    """
332
590
 
333
 
    # IMPORTANT FOR IMPLEMENTORS: It is important that SmartServer not contain
334
 
    # encoding or decoding logic to allow the wire protocol to vary from the
335
 
    # object protocol: we will want to tweak the wire protocol separate from
336
 
    # the object model, and ideally we will be able to do that without having
337
 
    # a SmartServer subclass for each wire protocol, rather just a Protocol
338
 
    # subclass.
 
591
    # IMPORTANT FOR IMPLEMENTORS: It is important that SmartServerRequestHandler
 
592
    # not contain encoding or decoding logic to allow the wire protocol to vary
 
593
    # from the object protocol: we will want to tweak the wire protocol separate
 
594
    # from the object model, and ideally we will be able to do that without
 
595
    # having a SmartServerRequestHandler subclass for each wire protocol, rather
 
596
    # just a Protocol subclass.
339
597
 
340
598
    # TODO: Better way of representing the body for commands that take it,
341
599
    # and allow it to be streamed into the server.
342
600
    
343
601
    def __init__(self, backing_transport):
344
602
        self._backing_transport = backing_transport
 
603
        self._converted_command = False
 
604
        self.finished_reading = False
 
605
        self._body_bytes = ''
 
606
        self.response = None
 
607
 
 
608
    def accept_body(self, bytes):
 
609
        """Accept body data.
 
610
 
 
611
        This should be overriden for each command that desired body data to
 
612
        handle the right format of that data. I.e. plain bytes, a bundle etc.
 
613
 
 
614
        The deserialisation into that format should be done in the Protocol
 
615
        object. Set self.desired_body_format to the format your method will
 
616
        handle.
 
617
        """
 
618
        # default fallback is to accumulate bytes.
 
619
        self._body_bytes += bytes
 
620
        
 
621
    def _end_of_body_handler(self):
 
622
        """An unimplemented end of body handler."""
 
623
        raise NotImplementedError(self._end_of_body_handler)
345
624
        
346
625
    def do_hello(self):
347
626
        """Answer a version request with my version."""
363
642
            return int(mode)
364
643
 
365
644
    def do_append(self, relpath, mode):
 
645
        self._converted_command = True
 
646
        self._relpath = relpath
 
647
        self._mode = self._deserialise_optional_mode(mode)
 
648
        self._end_of_body_handler = self._handle_do_append_end
 
649
    
 
650
    def _handle_do_append_end(self):
366
651
        old_length = self._backing_transport.append_bytes(
367
 
            relpath, self._recv_body(), self._deserialise_optional_mode(mode))
368
 
        return SmartServerResponse(('appended', '%d' % old_length))
 
652
            self._relpath, self._body_bytes, self._mode)
 
653
        self.response = SmartServerResponse(('appended', '%d' % old_length))
369
654
 
370
655
    def do_delete(self, relpath):
371
656
        self._backing_transport.delete(relpath)
372
657
 
373
 
    def do_iter_files_recursive(self, abspath):
374
 
        # XXX: the path handling needs some thought.
375
 
        #relpath = self._backing_transport.relpath(abspath)
376
 
        transport = self._backing_transport.clone(abspath)
 
658
    def do_iter_files_recursive(self, relpath):
 
659
        transport = self._backing_transport.clone(relpath)
377
660
        filenames = transport.iter_files_recursive()
378
661
        return SmartServerResponse(('names',) + tuple(filenames))
379
662
 
389
672
        self._backing_transport.move(rel_from, rel_to)
390
673
 
391
674
    def do_put(self, relpath, mode):
392
 
        self._backing_transport.put_bytes(relpath,
393
 
                self._recv_body(),
394
 
                self._deserialise_optional_mode(mode))
 
675
        self._converted_command = True
 
676
        self._relpath = relpath
 
677
        self._mode = self._deserialise_optional_mode(mode)
 
678
        self._end_of_body_handler = self._handle_do_put
 
679
 
 
680
    def _handle_do_put(self):
 
681
        self._backing_transport.put_bytes(self._relpath,
 
682
                self._body_bytes, self._mode)
 
683
        self.response = SmartServerResponse(('ok',))
395
684
 
396
685
    def _deserialise_offsets(self, text):
397
686
        # XXX: FIXME this should be on the protocol object.
404
693
        return offsets
405
694
 
406
695
    def do_put_non_atomic(self, relpath, mode, create_parent, dir_mode):
407
 
        create_parent_dir = (create_parent == 'T')
408
 
        self._backing_transport.put_bytes_non_atomic(relpath,
409
 
                self._recv_body(),
410
 
                mode=self._deserialise_optional_mode(mode),
411
 
                create_parent_dir=create_parent_dir,
412
 
                dir_mode=self._deserialise_optional_mode(dir_mode))
 
696
        self._converted_command = True
 
697
        self._end_of_body_handler = self._handle_put_non_atomic
 
698
        self._relpath = relpath
 
699
        self._dir_mode = self._deserialise_optional_mode(dir_mode)
 
700
        self._mode = self._deserialise_optional_mode(mode)
 
701
        # a boolean would be nicer XXX
 
702
        self._create_parent = (create_parent == 'T')
 
703
 
 
704
    def _handle_put_non_atomic(self):
 
705
        self._backing_transport.put_bytes_non_atomic(self._relpath,
 
706
                self._body_bytes,
 
707
                mode=self._mode,
 
708
                create_parent_dir=self._create_parent,
 
709
                dir_mode=self._dir_mode)
 
710
        self.response = SmartServerResponse(('ok',))
413
711
 
414
712
    def do_readv(self, relpath):
415
 
        offsets = self._deserialise_offsets(self._recv_body())
 
713
        self._converted_command = True
 
714
        self._end_of_body_handler = self._handle_readv_offsets
 
715
        self._relpath = relpath
 
716
 
 
717
    def end_of_body(self):
 
718
        """No more body data will be received."""
 
719
        self._run_handler_code(self._end_of_body_handler, (), {})
 
720
        # cannot read after this.
 
721
        self.finished_reading = True
 
722
 
 
723
    def _handle_readv_offsets(self):
 
724
        """accept offsets for a readv request."""
 
725
        offsets = self._deserialise_offsets(self._body_bytes)
416
726
        backing_bytes = ''.join(bytes for offset, bytes in
417
 
                             self._backing_transport.readv(relpath, offsets))
418
 
        return SmartServerResponse(('readv',), backing_bytes)
 
727
            self._backing_transport.readv(self._relpath, offsets))
 
728
        self.response = SmartServerResponse(('readv',), backing_bytes)
419
729
        
420
730
    def do_rename(self, rel_from, rel_to):
421
731
        self._backing_transport.rename(rel_from, rel_to)
439
749
        return SmartServerResponse((), tmpf.read())
440
750
 
441
751
    def dispatch_command(self, cmd, args):
 
752
        """Deprecated compatibility method.""" # XXX XXX
442
753
        func = getattr(self, 'do_' + cmd, None)
443
754
        if func is None:
444
755
            raise errors.SmartProtocolError("bad request %r" % (cmd,))
 
756
        self._run_handler_code(func, args, {})
 
757
 
 
758
    def _run_handler_code(self, callable, args, kwargs):
 
759
        """Run some handler specific code 'callable'.
 
760
 
 
761
        If a result is returned, it is considered to be the commands response,
 
762
        and finished_reading is set true, and its assigned to self.response.
 
763
 
 
764
        Any exceptions caught are translated and a response object created
 
765
        from them.
 
766
        """
 
767
        result = self._call_converting_errors(callable, args, kwargs)
 
768
        if result is not None:
 
769
            self.response = result
 
770
            self.finished_reading = True
 
771
        # handle unconverted commands
 
772
        if not self._converted_command:
 
773
            self.finished_reading = True
 
774
            if result is None:
 
775
                self.response = SmartServerResponse(('ok',))
 
776
 
 
777
    def _call_converting_errors(self, callable, args, kwargs):
 
778
        """Call callable converting errors to Response objects."""
445
779
        try:
446
 
            result = func(*args)
447
 
            if result is None: 
448
 
                result = SmartServerResponse(('ok',))
449
 
            return result
 
780
            return callable(*args, **kwargs)
450
781
        except errors.NoSuchFile, e:
451
782
            return SmartServerResponse(('NoSuchFile', e.path))
452
783
        except errors.FileExists, e:
461
792
            # with a plain string
462
793
            str_or_unicode = e.object
463
794
            if isinstance(str_or_unicode, unicode):
464
 
                val = u'u:' + str_or_unicode
 
795
                # XXX: UTF-8 might have \x01 (our seperator byte) in it.  We
 
796
                # should escape it somehow.
 
797
                val = 'u:' + str_or_unicode.encode('utf-8')
465
798
            else:
466
 
                val = u's:' + str_or_unicode.encode('base64')
 
799
                val = 's:' + str_or_unicode.encode('base64')
467
800
            # This handles UnicodeEncodeError or UnicodeDecodeError
468
801
            return SmartServerResponse((e.__class__.__name__,
469
802
                    e.encoding, val, str(e.start), str(e.end), e.reason))
477
810
class SmartTCPServer(object):
478
811
    """Listens on a TCP socket and accepts connections from smart clients"""
479
812
 
480
 
    def __init__(self, backing_transport=None, host='127.0.0.1', port=0):
 
813
    def __init__(self, backing_transport, host='127.0.0.1', port=0):
481
814
        """Construct a new server.
482
815
 
483
816
        To actually start it running, call either start_background_thread or
486
819
        :param host: Name of the interface to listen on.
487
820
        :param port: TCP port to listen on, or 0 to allocate a transient port.
488
821
        """
489
 
        if backing_transport is None:
490
 
            backing_transport = memory.MemoryTransport()
491
822
        self._server_socket = socket.socket()
492
823
        self._server_socket.bind((host, port))
493
824
        self.port = self._server_socket.getsockname()[1]
522
853
        # propogates to the newly accepted socket.
523
854
        conn.setblocking(True)
524
855
        conn.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
525
 
        from_client = conn.makefile('r')
526
 
        to_client = conn.makefile('w')
527
 
        handler = SmartStreamServer(from_client, to_client,
528
 
                self.backing_transport)
 
856
        handler = SmartServerSocketStreamMedium(conn, self.backing_transport)
529
857
        connection_thread = threading.Thread(None, handler.serve, name='smart-server-child')
530
858
        connection_thread.setDaemon(True)
531
859
        connection_thread.start()
552
880
    """
553
881
 
554
882
    def __init__(self):
555
 
        self._homedir = os.getcwd()
 
883
        self._homedir = urlutils.local_path_to_url(os.getcwd())[7:]
556
884
        # The server is set up by default like for ssh access: the client
557
885
        # passes filesystem-absolute paths; therefore the server must look
558
886
        # them up relative to the root directory.  it might be better to act
559
887
        # a public server and have the server rewrite paths into the test
560
888
        # directory.
561
 
        SmartTCPServer.__init__(self, transport.get_transport("file:///"))
 
889
        SmartTCPServer.__init__(self,
 
890
            transport.get_transport(urlutils.local_path_to_url('/')))
562
891
        
563
892
    def setUp(self):
564
893
        """Set up server for testing"""
570
899
    def get_url(self):
571
900
        """Return the url of the server"""
572
901
        host, port = self._server_socket.getsockname()
573
 
        # XXX: I think this is likely to break on windows -- self._homedir will
574
 
        # have backslashes (and maybe a drive letter?).
575
 
        #  -- Andrew Bennetts, 2006-08-29
576
902
        return "bzr://%s:%d%s" % (host, port, urlutils.escape(self._homedir))
577
903
 
578
904
    def get_bogus_url(self):
610
936
    # SmartTransport is an adapter from the Transport object model to the 
611
937
    # SmartClient model, not an encoder.
612
938
 
613
 
    def __init__(self, url, clone_from=None, client=None):
 
939
    def __init__(self, url, clone_from=None, medium=None):
614
940
        """Constructor.
615
941
 
616
 
        :param client: ignored when clone_from is not None.
 
942
        :param medium: The medium to use for this RemoteTransport. This must be
 
943
            supplied if clone_from is None.
617
944
        """
618
945
        ### Technically super() here is faulty because Transport's __init__
619
946
        ### fails to take 2 parameters, and if super were to choose a silly
624
951
        self._scheme, self._username, self._password, self._host, self._port, self._path = \
625
952
                transport.split_url(url)
626
953
        if clone_from is None:
627
 
            if client is None:
628
 
                self._client = SmartStreamClient(self._connect_to_server)
629
 
            else:
630
 
                self._client = client
 
954
            self._medium = medium
631
955
        else:
632
956
            # credentials may be stripped from the base in some circumstances
633
957
            # as yet to be clearly defined or documented, so copy them.
634
958
            self._username = clone_from._username
635
959
            # reuse same connection
636
 
            self._client = clone_from._client
 
960
            self._medium = clone_from._medium
 
961
        assert self._medium is not None
637
962
 
638
963
    def abspath(self, relpath):
639
964
        """Return the full url to the given relative path.
649
974
        This essentially opens a handle on a different remote directory.
650
975
        """
651
976
        if relative_url is None:
652
 
            return self.__class__(self.base, self)
 
977
            return SmartTransport(self.base, self)
653
978
        else:
654
 
            return self.__class__(self.abspath(relative_url), self)
 
979
            return SmartTransport(self.abspath(relative_url), self)
655
980
 
656
981
    def is_readonly(self):
657
982
        """Smart server transport can do read/write file operations."""
658
983
        return False
659
984
                                                   
660
985
    def get_smart_client(self):
661
 
        return self._client
 
986
        return self._medium
 
987
 
 
988
    def get_smart_medium(self):
 
989
        return self._medium
662
990
                                                   
663
991
    def _unparse_url(self, path):
664
992
        """Return URL for a path.
681
1009
        """Returns the Unicode version of the absolute path for relpath."""
682
1010
        return self._combine_paths(self._path, relpath)
683
1011
 
 
1012
    def _call(self, method, *args):
 
1013
        resp = self._call2(method, *args)
 
1014
        self._translate_error(resp)
 
1015
 
 
1016
    def _call2(self, method, *args):
 
1017
        """Call a method on the remote server."""
 
1018
        protocol = SmartClientRequestProtocolOne(self._medium.get_request())
 
1019
        protocol.call(method, *args)
 
1020
        return protocol.read_response_tuple()
 
1021
 
 
1022
    def _call_with_body_bytes(self, method, args, body):
 
1023
        """Call a method on the remote server with body bytes."""
 
1024
        protocol = SmartClientRequestProtocolOne(self._medium.get_request())
 
1025
        protocol.call_with_body_bytes((method, ) + args, body)
 
1026
        return protocol.read_response_tuple()
 
1027
 
684
1028
    def has(self, relpath):
685
1029
        """Indicate whether a remote file of the given name exists or not.
686
1030
 
687
1031
        :see: Transport.has()
688
1032
        """
689
 
        resp = self._client._call('has', self._remote_path(relpath))
 
1033
        resp = self._call2('has', self._remote_path(relpath))
690
1034
        if resp == ('yes', ):
691
1035
            return True
692
1036
        elif resp == ('no', ):
699
1043
        
700
1044
        :see: Transport.get_bytes()/get_file()
701
1045
        """
 
1046
        return StringIO(self.get_bytes(relpath))
 
1047
 
 
1048
    def get_bytes(self, relpath):
702
1049
        remote = self._remote_path(relpath)
703
 
        resp = self._client._call('get', remote)
 
1050
        protocol = SmartClientRequestProtocolOne(self._medium.get_request())
 
1051
        protocol.call('get', remote)
 
1052
        resp = protocol.read_response_tuple(True)
704
1053
        if resp != ('ok', ):
 
1054
            protocol.cancel_read_body()
705
1055
            self._translate_error(resp, relpath)
706
 
        return StringIO(self._client._recv_bulk())
 
1056
        return protocol.read_body_bytes()
707
1057
 
708
1058
    def _serialise_optional_mode(self, mode):
709
1059
        if mode is None:
712
1062
            return '%d' % mode
713
1063
 
714
1064
    def mkdir(self, relpath, mode=None):
715
 
        resp = self._client._call('mkdir', 
716
 
                                  self._remote_path(relpath), 
717
 
                                  self._serialise_optional_mode(mode))
 
1065
        resp = self._call2('mkdir', self._remote_path(relpath),
 
1066
            self._serialise_optional_mode(mode))
718
1067
        self._translate_error(resp)
719
1068
 
720
1069
    def put_bytes(self, relpath, upload_contents, mode=None):
721
1070
        # FIXME: upload_file is probably not safe for non-ascii characters -
722
1071
        # should probably just pass all parameters as length-delimited
723
1072
        # strings?
724
 
        resp = self._client._call_with_upload(
725
 
            'put',
 
1073
        resp = self._call_with_body_bytes('put',
726
1074
            (self._remote_path(relpath), self._serialise_optional_mode(mode)),
727
1075
            upload_contents)
728
1076
        self._translate_error(resp)
736
1084
        if create_parent_dir:
737
1085
            create_parent_str = 'T'
738
1086
 
739
 
        resp = self._client._call_with_upload(
 
1087
        resp = self._call_with_body_bytes(
740
1088
            'put_non_atomic',
741
1089
            (self._remote_path(relpath), self._serialise_optional_mode(mode),
742
1090
             create_parent_str, self._serialise_optional_mode(dir_mode)),
765
1113
        return self.append_bytes(relpath, from_file.read(), mode)
766
1114
        
767
1115
    def append_bytes(self, relpath, bytes, mode=None):
768
 
        resp = self._client._call_with_upload(
 
1116
        resp = self._call_with_body_bytes(
769
1117
            'append',
770
1118
            (self._remote_path(relpath), self._serialise_optional_mode(mode)),
771
1119
            bytes)
774
1122
        self._translate_error(resp)
775
1123
 
776
1124
    def delete(self, relpath):
777
 
        resp = self._client._call('delete', self._remote_path(relpath))
 
1125
        resp = self._call2('delete', self._remote_path(relpath))
778
1126
        self._translate_error(resp)
779
1127
 
780
1128
    def readv(self, relpath, offsets):
791
1139
                               limit=self._max_readv_combine,
792
1140
                               fudge_factor=self._bytes_to_read_before_seek))
793
1141
 
794
 
 
795
 
        resp = self._client._call_with_upload(
796
 
            'readv',
797
 
            (self._remote_path(relpath),),
798
 
            self._client._serialise_offsets((c.start, c.length) for c in coalesced))
 
1142
        protocol = SmartClientRequestProtocolOne(self._medium.get_request())
 
1143
        protocol.call_with_body_readv_array(
 
1144
            ('readv', self._remote_path(relpath)),
 
1145
            [(c.start, c.length) for c in coalesced])
 
1146
        resp = protocol.read_response_tuple(True)
799
1147
 
800
1148
        if resp[0] != 'readv':
801
1149
            # This should raise an exception
 
1150
            protocol.cancel_read_body()
802
1151
            self._translate_error(resp)
803
1152
            return
804
1153
 
805
 
        data = self._client._recv_bulk()
 
1154
        # FIXME: this should know how many bytes are needed, for clarity.
 
1155
        data = protocol.read_body_bytes()
806
1156
        # Cache the results, but only until they have been fulfilled
807
1157
        data_map = {}
808
1158
        for c_offset in coalesced:
821
1171
                cur_offset_and_size = offset_stack.next()
822
1172
 
823
1173
    def rename(self, rel_from, rel_to):
824
 
        self._call('rename', 
 
1174
        self._call('rename',
825
1175
                   self._remote_path(rel_from),
826
1176
                   self._remote_path(rel_to))
827
1177
 
828
1178
    def move(self, rel_from, rel_to):
829
 
        self._call('move', 
 
1179
        self._call('move',
830
1180
                   self._remote_path(rel_from),
831
1181
                   self._remote_path(rel_to))
832
1182
 
833
1183
    def rmdir(self, relpath):
834
1184
        resp = self._call('rmdir', self._remote_path(relpath))
835
1185
 
836
 
    def _call(self, method, *args):
837
 
        resp = self._client._call(method, *args)
838
 
        self._translate_error(resp)
839
 
 
840
1186
    def _translate_error(self, resp, orig_path=None):
841
1187
        """Raise an exception from a response"""
842
1188
        if resp is None:
867
1213
            end = int(resp[4])
868
1214
            reason = str(resp[5]) # reason must always be a string
869
1215
            if val.startswith('u:'):
870
 
                val = val[2:]
 
1216
                val = val[2:].decode('utf-8')
871
1217
            elif val.startswith('s:'):
872
1218
                val = val[2:].decode('base64')
873
1219
            if what == 'UnicodeDecodeError':
879
1225
        else:
880
1226
            raise errors.SmartProtocolError('unexpected smart server error: %r' % (resp,))
881
1227
 
882
 
    def _send_tuple(self, args):
883
 
        self._client._send_tuple(args)
884
 
 
885
 
    def _recv_tuple(self):
886
 
        return self._client._recv_tuple()
887
 
 
888
1228
    def disconnect(self):
889
 
        self._client.disconnect()
 
1229
        self._medium.disconnect()
890
1230
 
891
1231
    def delete_tree(self, relpath):
892
1232
        raise errors.TransportNotPossible('readonly transport')
893
1233
 
894
1234
    def stat(self, relpath):
895
 
        resp = self._client._call('stat', self._remote_path(relpath))
 
1235
        resp = self._call2('stat', self._remote_path(relpath))
896
1236
        if resp[0] == 'stat':
897
1237
            return SmartStat(int(resp[1]), int(resp[2], 8))
898
1238
        else:
915
1255
        return True
916
1256
 
917
1257
    def list_dir(self, relpath):
918
 
        resp = self._client._call('list_dir',
919
 
                                  self._remote_path(relpath))
 
1258
        resp = self._call2('list_dir', self._remote_path(relpath))
920
1259
        if resp[0] == 'names':
921
1260
            return [name.encode('ascii') for name in resp[1:]]
922
1261
        else:
923
1262
            self._translate_error(resp)
924
1263
 
925
1264
    def iter_files_recursive(self):
926
 
        resp = self._client._call('iter_files_recursive',
927
 
                                  self._remote_path(''))
 
1265
        resp = self._call2('iter_files_recursive', self._remote_path(''))
928
1266
        if resp[0] == 'names':
929
1267
            return resp[1:]
930
1268
        else:
931
1269
            self._translate_error(resp)
932
1270
 
933
1271
 
934
 
class SmartStreamClient(SmartProtocolBase):
935
 
    """Connection to smart server over two streams"""
936
 
 
937
 
    def __init__(self, connect_func):
938
 
        self._connect_func = connect_func
939
 
        self._connected = False
940
 
 
941
 
    def __del__(self):
942
 
        self.disconnect()
943
 
 
944
 
    def _ensure_connection(self):
945
 
        if not self._connected:
946
 
            self._in, self._out = self._connect_func()
947
 
            self._connected = True
948
 
 
949
 
    def _send_tuple(self, args):
950
 
        self._ensure_connection()
951
 
        return self._write_and_flush(_encode_tuple(args))
952
 
 
953
 
    def _send_bulk_data(self, body):
954
 
        self._ensure_connection()
955
 
        SmartProtocolBase._send_bulk_data(self, body)
956
 
        
957
 
    def _recv_bulk(self):
958
 
        self._ensure_connection()
959
 
        return SmartProtocolBase._recv_bulk(self)
 
1272
class SmartClientMediumRequest(object):
 
1273
    """A request on a SmartClientMedium.
 
1274
 
 
1275
    Each request allows bytes to be provided to it via accept_bytes, and then
 
1276
    the response bytes to be read via read_bytes.
 
1277
 
 
1278
    For instance:
 
1279
    request.accept_bytes('123')
 
1280
    request.finished_writing()
 
1281
    result = request.read_bytes(3)
 
1282
    request.finished_reading()
 
1283
 
 
1284
    It is up to the individual SmartClientMedium whether multiple concurrent
 
1285
    requests can exist. See SmartClientMedium.get_request to obtain instances 
 
1286
    of SmartClientMediumRequest, and the concrete Medium you are using for 
 
1287
    details on concurrency and pipelining.
 
1288
    """
 
1289
 
 
1290
    def __init__(self, medium):
 
1291
        """Construct a SmartClientMediumRequest for the medium medium."""
 
1292
        self._medium = medium
 
1293
        # we track state by constants - we may want to use the same
 
1294
        # pattern as BodyReader if it gets more complex.
 
1295
        # valid states are: "writing", "reading", "done"
 
1296
        self._state = "writing"
 
1297
 
 
1298
    def accept_bytes(self, bytes):
 
1299
        """Accept bytes for inclusion in this request.
 
1300
 
 
1301
        This method may not be be called after finished_writing() has been
 
1302
        called.  It depends upon the Medium whether or not the bytes will be
 
1303
        immediately transmitted. Message based Mediums will tend to buffer the
 
1304
        bytes until finished_writing() is called.
 
1305
 
 
1306
        :param bytes: A bytestring.
 
1307
        """
 
1308
        if self._state != "writing":
 
1309
            raise errors.WritingCompleted(self)
 
1310
        self._accept_bytes(bytes)
 
1311
 
 
1312
    def _accept_bytes(self, bytes):
 
1313
        """Helper for accept_bytes.
 
1314
 
 
1315
        Accept_bytes checks the state of the request to determing if bytes
 
1316
        should be accepted. After that it hands off to _accept_bytes to do the
 
1317
        actual acceptance.
 
1318
        """
 
1319
        raise NotImplementedError(self._accept_bytes)
 
1320
 
 
1321
    def finished_reading(self):
 
1322
        """Inform the request that all desired data has been read.
 
1323
 
 
1324
        This will remove the request from the pipeline for its medium (if the
 
1325
        medium supports pipelining) and any further calls to methods on the
 
1326
        request will raise ReadingCompleted.
 
1327
        """
 
1328
        if self._state == "writing":
 
1329
            raise errors.WritingNotComplete(self)
 
1330
        if self._state != "reading":
 
1331
            raise errors.ReadingCompleted(self)
 
1332
        self._state = "done"
 
1333
        self._finished_reading()
 
1334
 
 
1335
    def _finished_reading(self):
 
1336
        """Helper for finished_reading.
 
1337
 
 
1338
        finished_reading checks the state of the request to determine if 
 
1339
        finished_reading is allowed, and if it is hands off to _finished_reading
 
1340
        to perform the action.
 
1341
        """
 
1342
        raise NotImplementedError(self._finished_reading)
 
1343
 
 
1344
    def finished_writing(self):
 
1345
        """Finish the writing phase of this request.
 
1346
 
 
1347
        This will flush all pending data for this request along the medium.
 
1348
        After calling finished_writing, you may not call accept_bytes anymore.
 
1349
        """
 
1350
        if self._state != "writing":
 
1351
            raise errors.WritingCompleted(self)
 
1352
        self._state = "reading"
 
1353
        self._finished_writing()
 
1354
 
 
1355
    def _finished_writing(self):
 
1356
        """Helper for finished_writing.
 
1357
 
 
1358
        finished_writing checks the state of the request to determine if 
 
1359
        finished_writing is allowed, and if it is hands off to _finished_writing
 
1360
        to perform the action.
 
1361
        """
 
1362
        raise NotImplementedError(self._finished_writing)
 
1363
 
 
1364
    def read_bytes(self, count):
 
1365
        """Read bytes from this requests response.
 
1366
 
 
1367
        This method will block and wait for count bytes to be read. It may not
 
1368
        be invoked until finished_writing() has been called - this is to ensure
 
1369
        a message-based approach to requests, for compatability with message
 
1370
        based mediums like HTTP.
 
1371
        """
 
1372
        if self._state == "writing":
 
1373
            raise errors.WritingNotComplete(self)
 
1374
        if self._state != "reading":
 
1375
            raise errors.ReadingCompleted(self)
 
1376
        return self._read_bytes(count)
 
1377
 
 
1378
    def _read_bytes(self, count):
 
1379
        """Helper for read_bytes.
 
1380
 
 
1381
        read_bytes checks the state of the request to determing if bytes
 
1382
        should be read. After that it hands off to _read_bytes to do the
 
1383
        actual read.
 
1384
        """
 
1385
        raise NotImplementedError(self._read_bytes)
 
1386
 
 
1387
 
 
1388
class SmartClientStreamMediumRequest(SmartClientMediumRequest):
 
1389
    """A SmartClientMediumRequest that works with an SmartClientStreamMedium."""
 
1390
 
 
1391
    def __init__(self, medium):
 
1392
        SmartClientMediumRequest.__init__(self, medium)
 
1393
        # check that we are safe concurrency wise. If some streams start
 
1394
        # allowing concurrent requests - i.e. via multiplexing - then this
 
1395
        # assert should be moved to SmartClientStreamMedium.get_request,
 
1396
        # and the setting/unsetting of _current_request likewise moved into
 
1397
        # that class : but its unneeded overhead for now. RBC 20060922
 
1398
        if self._medium._current_request is not None:
 
1399
            raise errors.TooManyConcurrentRequests(self._medium)
 
1400
        self._medium._current_request = self
 
1401
 
 
1402
    def _accept_bytes(self, bytes):
 
1403
        """See SmartClientMediumRequest._accept_bytes.
 
1404
        
 
1405
        This forwards to self._medium._accept_bytes because we are operating
 
1406
        on the mediums stream.
 
1407
        """
 
1408
        self._medium._accept_bytes(bytes)
 
1409
 
 
1410
    def _finished_reading(self):
 
1411
        """See SmartClientMediumRequest._finished_reading.
 
1412
 
 
1413
        This clears the _current_request on self._medium to allow a new 
 
1414
        request to be created.
 
1415
        """
 
1416
        assert self._medium._current_request is self
 
1417
        self._medium._current_request = None
 
1418
        
 
1419
    def _finished_writing(self):
 
1420
        """See SmartClientMediumRequest._finished_writing.
 
1421
 
 
1422
        This invokes self._medium._flush to ensure all bytes are transmitted.
 
1423
        """
 
1424
        self._medium._flush()
 
1425
 
 
1426
    def _read_bytes(self, count):
 
1427
        """See SmartClientMediumRequest._read_bytes.
 
1428
        
 
1429
        This forwards to self._medium._read_bytes because we are operating
 
1430
        on the mediums stream.
 
1431
        """
 
1432
        return self._medium._read_bytes(count)
 
1433
 
 
1434
 
 
1435
class SmartClientRequestProtocolOne(SmartProtocolBase):
 
1436
    """The client-side protocol for smart version 1."""
 
1437
 
 
1438
    def __init__(self, request):
 
1439
        """Construct a SmartClientRequestProtocolOne.
 
1440
 
 
1441
        :param request: A SmartClientMediumRequest to serialise onto and
 
1442
            deserialise from.
 
1443
        """
 
1444
        self._request = request
 
1445
        self._body_buffer = None
 
1446
 
 
1447
    def call(self, *args):
 
1448
        bytes = _encode_tuple(args)
 
1449
        self._request.accept_bytes(bytes)
 
1450
        self._request.finished_writing()
 
1451
 
 
1452
    def call_with_body_bytes(self, args, body):
 
1453
        """Make a remote call of args with body bytes 'body'.
 
1454
 
 
1455
        After calling this, call read_response_tuple to find the result out.
 
1456
        """
 
1457
        bytes = _encode_tuple(args)
 
1458
        self._request.accept_bytes(bytes)
 
1459
        bytes = self._encode_bulk_data(body)
 
1460
        self._request.accept_bytes(bytes)
 
1461
        self._request.finished_writing()
 
1462
 
 
1463
    def call_with_body_readv_array(self, args, body):
 
1464
        """Make a remote call with a readv array.
 
1465
 
 
1466
        The body is encoded with one line per readv offset pair. The numbers in
 
1467
        each pair are separated by a comma, and no trailing \n is emitted.
 
1468
        """
 
1469
        bytes = _encode_tuple(args)
 
1470
        self._request.accept_bytes(bytes)
 
1471
        readv_bytes = self._serialise_offsets(body)
 
1472
        bytes = self._encode_bulk_data(readv_bytes)
 
1473
        self._request.accept_bytes(bytes)
 
1474
        self._request.finished_writing()
 
1475
 
 
1476
    def cancel_read_body(self):
 
1477
        """After expecting a body, a response code may indicate one otherwise.
 
1478
 
 
1479
        This method lets the domain client inform the protocol that no body
 
1480
        will be transmitted. This is a terminal method: after calling it the
 
1481
        protocol is not able to be used further.
 
1482
        """
 
1483
        self._request.finished_reading()
 
1484
 
 
1485
    def read_response_tuple(self, expect_body=False):
 
1486
        """Read a response tuple from the wire.
 
1487
 
 
1488
        This should only be called once.
 
1489
        """
 
1490
        result = self._recv_tuple()
 
1491
        if not expect_body:
 
1492
            self._request.finished_reading()
 
1493
        return result
 
1494
 
 
1495
    def read_body_bytes(self, count=-1):
 
1496
        """Read bytes from the body, decoding into a byte stream.
 
1497
        
 
1498
        We read all bytes at once to ensure we've checked the trailer for 
 
1499
        errors, and then feed the buffer back as read_body_bytes is called.
 
1500
        """
 
1501
        if self._body_buffer is not None:
 
1502
            return self._body_buffer.read(count)
 
1503
        _body_decoder = LengthPrefixedBodyDecoder()
 
1504
 
 
1505
        while not _body_decoder.finished_reading:
 
1506
            bytes_wanted = _body_decoder.next_read_size()
 
1507
            bytes = self._request.read_bytes(bytes_wanted)
 
1508
            _body_decoder.accept_bytes(bytes)
 
1509
        self._request.finished_reading()
 
1510
        self._body_buffer = StringIO(_body_decoder.read_pending_data())
 
1511
        # XXX: TODO check the trailer result.
 
1512
        return self._body_buffer.read(count)
960
1513
 
961
1514
    def _recv_tuple(self):
962
 
        self._ensure_connection()
963
 
        return SmartProtocolBase._recv_tuple(self)
964
 
 
965
 
    def _recv_trailer(self):
966
 
        self._ensure_connection()
967
 
        return SmartProtocolBase._recv_trailer(self)
968
 
 
969
 
    def disconnect(self):
970
 
        """Close connection to the server"""
971
 
        if self._connected:
972
 
            self._out.close()
973
 
            self._in.close()
974
 
 
975
 
    def _call(self, *args):
976
 
        self._send_tuple(args)
977
 
        return self._recv_tuple()
978
 
 
979
 
    def _call_with_upload(self, method, args, body):
980
 
        """Call an rpc, supplying bulk upload data.
981
 
 
982
 
        :param method: method name to call
983
 
        :param args: parameter args tuple
984
 
        :param body: upload body as a byte string
985
 
        """
986
 
        self._send_tuple((method,) + args)
987
 
        self._send_bulk_data(body)
988
 
        return self._recv_tuple()
 
1515
        """Receive a tuple from the medium request."""
 
1516
        line = ''
 
1517
        while not line or line[-1] != '\n':
 
1518
            # TODO: this is inefficient - but tuples are short.
 
1519
            new_char = self._request.read_bytes(1)
 
1520
            line += new_char
 
1521
            assert new_char != '', "end of file reading from server."
 
1522
        return _decode_tuple(line)
989
1523
 
990
1524
    def query_version(self):
991
1525
        """Return protocol version number of the server."""
992
 
        # XXX: should make sure it's empty
993
 
        self._send_tuple(('hello',))
994
 
        resp = self._recv_tuple()
 
1526
        self.call('hello')
 
1527
        resp = self.read_response_tuple()
995
1528
        if resp == ('ok', '1'):
996
1529
            return 1
997
1530
        else:
998
1531
            raise errors.SmartProtocolError("bad response %r" % (resp,))
999
1532
 
1000
1533
 
1001
 
class SmartTCPTransport(SmartTransport):
1002
 
    """Connection to smart server over plain tcp"""
1003
 
 
1004
 
    def __init__(self, url, clone_from=None):
1005
 
        super(SmartTCPTransport, self).__init__(url, clone_from)
1006
 
        try:
1007
 
            self._port = int(self._port)
1008
 
        except (ValueError, TypeError), e:
1009
 
            raise errors.InvalidURL(path=url, extra="invalid port %s" % self._port)
1010
 
        self._socket = None
1011
 
 
1012
 
    def _connect_to_server(self):
 
1534
class SmartClientMedium(object):
 
1535
    """Smart client is a medium for sending smart protocol requests over."""
 
1536
 
 
1537
    def disconnect(self):
 
1538
        """If this medium maintains a persistent connection, close it.
 
1539
        
 
1540
        The default implementation does nothing.
 
1541
        """
 
1542
        
 
1543
 
 
1544
class SmartClientStreamMedium(SmartClientMedium):
 
1545
    """Stream based medium common class.
 
1546
 
 
1547
    SmartClientStreamMediums operate on a stream. All subclasses use a common
 
1548
    SmartClientStreamMediumRequest for their requests, and should implement
 
1549
    _accept_bytes and _read_bytes to allow the request objects to send and
 
1550
    receive bytes.
 
1551
    """
 
1552
 
 
1553
    def __init__(self):
 
1554
        self._current_request = None
 
1555
 
 
1556
    def accept_bytes(self, bytes):
 
1557
        self._accept_bytes(bytes)
 
1558
 
 
1559
    def __del__(self):
 
1560
        """The SmartClientStreamMedium knows how to close the stream when it is
 
1561
        finished with it.
 
1562
        """
 
1563
        self.disconnect()
 
1564
 
 
1565
    def _flush(self):
 
1566
        """Flush the output stream.
 
1567
        
 
1568
        This method is used by the SmartClientStreamMediumRequest to ensure that
 
1569
        all data for a request is sent, to avoid long timeouts or deadlocks.
 
1570
        """
 
1571
        raise NotImplementedError(self._flush)
 
1572
 
 
1573
    def get_request(self):
 
1574
        """See SmartClientMedium.get_request().
 
1575
 
 
1576
        SmartClientStreamMedium always returns a SmartClientStreamMediumRequest
 
1577
        for get_request.
 
1578
        """
 
1579
        return SmartClientStreamMediumRequest(self)
 
1580
 
 
1581
    def read_bytes(self, count):
 
1582
        return self._read_bytes(count)
 
1583
 
 
1584
 
 
1585
class SmartSimplePipesClientMedium(SmartClientStreamMedium):
 
1586
    """A client medium using simple pipes.
 
1587
    
 
1588
    This client does not manage the pipes: it assumes they will always be open.
 
1589
    """
 
1590
 
 
1591
    def __init__(self, readable_pipe, writeable_pipe):
 
1592
        SmartClientStreamMedium.__init__(self)
 
1593
        self._readable_pipe = readable_pipe
 
1594
        self._writeable_pipe = writeable_pipe
 
1595
 
 
1596
    def _accept_bytes(self, bytes):
 
1597
        """See SmartClientStreamMedium.accept_bytes."""
 
1598
        self._writeable_pipe.write(bytes)
 
1599
 
 
1600
    def _flush(self):
 
1601
        """See SmartClientStreamMedium._flush()."""
 
1602
        self._writeable_pipe.flush()
 
1603
 
 
1604
    def _read_bytes(self, count):
 
1605
        """See SmartClientStreamMedium._read_bytes."""
 
1606
        return self._readable_pipe.read(count)
 
1607
 
 
1608
 
 
1609
class SmartSSHClientMedium(SmartClientStreamMedium):
 
1610
    """A client medium using SSH."""
 
1611
    
 
1612
    def __init__(self, host, port=None, username=None, password=None,
 
1613
            vendor=None):
 
1614
        """Creates a client that will connect on the first use.
 
1615
        
 
1616
        :param vendor: An optional override for the ssh vendor to use. See
 
1617
            bzrlib.transport.ssh for details on ssh vendors.
 
1618
        """
 
1619
        SmartClientStreamMedium.__init__(self)
 
1620
        self._connected = False
 
1621
        self._host = host
 
1622
        self._password = password
 
1623
        self._port = port
 
1624
        self._username = username
 
1625
        self._read_from = None
 
1626
        self._ssh_connection = None
 
1627
        self._vendor = vendor
 
1628
        self._write_to = None
 
1629
 
 
1630
    def _accept_bytes(self, bytes):
 
1631
        """See SmartClientStreamMedium.accept_bytes."""
 
1632
        self._ensure_connection()
 
1633
        self._write_to.write(bytes)
 
1634
 
 
1635
    def disconnect(self):
 
1636
        """See SmartClientMedium.disconnect()."""
 
1637
        if not self._connected:
 
1638
            return
 
1639
        self._read_from.close()
 
1640
        self._write_to.close()
 
1641
        self._ssh_connection.close()
 
1642
        self._connected = False
 
1643
 
 
1644
    def _ensure_connection(self):
 
1645
        """Connect this medium if not already connected."""
 
1646
        if self._connected:
 
1647
            return
 
1648
        executable = os.environ.get('BZR_REMOTE_PATH', 'bzr')
 
1649
        if self._vendor is None:
 
1650
            vendor = ssh._get_ssh_vendor()
 
1651
        else:
 
1652
            vendor = self._vendor
 
1653
        self._ssh_connection = vendor.connect_ssh(self._username,
 
1654
                self._password, self._host, self._port,
 
1655
                command=[executable, 'serve', '--inet', '--directory=/',
 
1656
                         '--allow-writes'])
 
1657
        self._read_from, self._write_to = \
 
1658
            self._ssh_connection.get_filelike_channels()
 
1659
        self._connected = True
 
1660
 
 
1661
    def _flush(self):
 
1662
        """See SmartClientStreamMedium._flush()."""
 
1663
        self._write_to.flush()
 
1664
 
 
1665
    def _read_bytes(self, count):
 
1666
        """See SmartClientStreamMedium.read_bytes."""
 
1667
        if not self._connected:
 
1668
            raise errors.MediumNotConnected(self)
 
1669
        return self._read_from.read(count)
 
1670
 
 
1671
 
 
1672
class SmartTCPClientMedium(SmartClientStreamMedium):
 
1673
    """A client medium using TCP."""
 
1674
    
 
1675
    def __init__(self, host, port):
 
1676
        """Creates a client that will connect on the first use."""
 
1677
        SmartClientStreamMedium.__init__(self)
 
1678
        self._connected = False
 
1679
        self._host = host
 
1680
        self._port = port
 
1681
        self._socket = None
 
1682
 
 
1683
    def _accept_bytes(self, bytes):
 
1684
        """See SmartClientMedium.accept_bytes."""
 
1685
        self._ensure_connection()
 
1686
        self._socket.sendall(bytes)
 
1687
 
 
1688
    def disconnect(self):
 
1689
        """See SmartClientMedium.disconnect()."""
 
1690
        if not self._connected:
 
1691
            return
 
1692
        self._socket.close()
 
1693
        self._socket = None
 
1694
        self._connected = False
 
1695
 
 
1696
    def _ensure_connection(self):
 
1697
        """Connect this medium if not already connected."""
 
1698
        if self._connected:
 
1699
            return
1013
1700
        self._socket = socket.socket()
1014
1701
        self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
1015
1702
        result = self._socket.connect_ex((self._host, int(self._port)))
1016
1703
        if result:
1017
1704
            raise errors.ConnectionError("failed to connect to %s:%d: %s" %
1018
1705
                    (self._host, self._port, os.strerror(result)))
1019
 
        # TODO: May be more efficient to just treat them as sockets
1020
 
        # throughout?  But what about pipes to ssh?...
1021
 
        to_server = self._socket.makefile('w')
1022
 
        from_server = self._socket.makefile('r')
1023
 
        return from_server, to_server
1024
 
 
1025
 
    def disconnect(self):
1026
 
        super(SmartTCPTransport, self).disconnect()
1027
 
        # XXX: Is closing the socket as well as closing the files really
1028
 
        # necessary?
1029
 
        if self._socket is not None:
1030
 
            self._socket.close()
1031
 
 
1032
 
try:
1033
 
    from bzrlib.transport import sftp, ssh
1034
 
except errors.ParamikoNotPresent:
1035
 
    # no paramiko, no SSHTransport.
1036
 
    pass
1037
 
else:
1038
 
    class SmartSSHTransport(SmartTransport):
1039
 
        """Connection to smart server over SSH."""
1040
 
 
1041
 
        def __init__(self, url, clone_from=None):
1042
 
            # TODO: all this probably belongs in the parent class.
1043
 
            super(SmartSSHTransport, self).__init__(url, clone_from)
1044
 
            try:
1045
 
                if self._port is not None:
1046
 
                    self._port = int(self._port)
1047
 
            except (ValueError, TypeError), e:
1048
 
                raise errors.InvalidURL(path=url, extra="invalid port %s" % self._port)
1049
 
 
1050
 
        def _connect_to_server(self):
1051
 
            executable = os.environ.get('BZR_REMOTE_PATH', 'bzr')
1052
 
            vendor = ssh._get_ssh_vendor()
1053
 
            self._ssh_connection = vendor.connect_ssh(self._username,
1054
 
                    self._password, self._host, self._port,
1055
 
                    command=[executable, 'serve', '--inet', '--directory=/',
1056
 
                             '--allow-writes'])
1057
 
            return self._ssh_connection.get_filelike_channels()
1058
 
 
1059
 
        def disconnect(self):
1060
 
            super(SmartSSHTransport, self).disconnect()
1061
 
            self._ssh_connection.close()
 
1706
        self._connected = True
 
1707
 
 
1708
    def _flush(self):
 
1709
        """See SmartClientStreamMedium._flush().
 
1710
        
 
1711
        For TCP we do no flushing. We may want to turn off TCP_NODELAY and 
 
1712
        add a means to do a flush, but that can be done in the future.
 
1713
        """
 
1714
 
 
1715
    def _read_bytes(self, count):
 
1716
        """See SmartClientMedium.read_bytes."""
 
1717
        if not self._connected:
 
1718
            raise errors.MediumNotConnected(self)
 
1719
        return self._socket.recv(count)
 
1720
 
 
1721
 
 
1722
class SmartTCPTransport(SmartTransport):
 
1723
    """Connection to smart server over plain tcp.
 
1724
    
 
1725
    This is essentially just a factory to get 'RemoteTransport(url,
 
1726
        SmartTCPClientMedium).
 
1727
    """
 
1728
 
 
1729
    def __init__(self, url):
 
1730
        _scheme, _username, _password, _host, _port, _path = \
 
1731
            transport.split_url(url)
 
1732
        try:
 
1733
            _port = int(_port)
 
1734
        except (ValueError, TypeError), e:
 
1735
            raise errors.InvalidURL(path=url, extra="invalid port %s" % _port)
 
1736
        medium = SmartTCPClientMedium(_host, _port)
 
1737
        super(SmartTCPTransport, self).__init__(url, medium=medium)
 
1738
 
 
1739
 
 
1740
class SmartSSHTransport(SmartTransport):
 
1741
    """Connection to smart server over SSH.
 
1742
 
 
1743
    This is essentially just a factory to get 'RemoteTransport(url,
 
1744
        SmartSSHClientMedium).
 
1745
    """
 
1746
 
 
1747
    def __init__(self, url):
 
1748
        _scheme, _username, _password, _host, _port, _path = \
 
1749
            transport.split_url(url)
 
1750
        try:
 
1751
            if _port is not None:
 
1752
                _port = int(_port)
 
1753
        except (ValueError, TypeError), e:
 
1754
            raise errors.InvalidURL(path=url, extra="invalid port %s" % 
 
1755
                _port)
 
1756
        medium = SmartSSHClientMedium(_host, _port, _username, _password)
 
1757
        super(SmartSSHTransport, self).__init__(url, medium=medium)
 
1758
 
 
1759
 
 
1760
class SmartHTTPTransport(SmartTransport):
 
1761
    """Just a way to connect between a bzr+http:// url and http://.
 
1762
    
 
1763
    This connection operates slightly differently than the SmartSSHTransport.
 
1764
    It uses a plain http:// transport underneath, which defines what remote
 
1765
    .bzr/smart URL we are connected to. From there, all paths that are sent are
 
1766
    sent as relative paths, this way, the remote side can properly
 
1767
    de-reference them, since it is likely doing rewrite rules to translate an
 
1768
    HTTP path into a local path.
 
1769
    """
 
1770
 
 
1771
    def __init__(self, url, http_transport=None):
 
1772
        assert url.startswith('bzr+http://')
 
1773
 
 
1774
        if http_transport is None:
 
1775
            http_url = url[len('bzr+'):]
 
1776
            self._http_transport = transport.get_transport(http_url)
 
1777
        else:
 
1778
            self._http_transport = http_transport
 
1779
        http_medium = self._http_transport.get_smart_medium()
 
1780
        super(SmartHTTPTransport, self).__init__(url, medium=http_medium)
 
1781
 
 
1782
    def _remote_path(self, relpath):
 
1783
        """After connecting HTTP Transport only deals in relative URLs."""
 
1784
        if relpath == '.':
 
1785
            return ''
 
1786
        else:
 
1787
            return relpath
 
1788
 
 
1789
    def abspath(self, relpath):
 
1790
        """Return the full url to the given relative path.
 
1791
        
 
1792
        :param relpath: the relative path or path components
 
1793
        :type relpath: str or list
 
1794
        """
 
1795
        return self._unparse_url(self._combine_paths(self._path, relpath))
 
1796
 
 
1797
    def clone(self, relative_url):
 
1798
        """Make a new SmartHTTPTransport related to me.
 
1799
 
 
1800
        This is re-implemented rather than using the default
 
1801
        SmartTransport.clone() because we must be careful about the underlying
 
1802
        http transport.
 
1803
        """
 
1804
        if relative_url:
 
1805
            abs_url = self.abspath(relative_url)
 
1806
        else:
 
1807
            abs_url = self.base
 
1808
        # By cloning the underlying http_transport, we are able to share the
 
1809
        # connection.
 
1810
        new_transport = self._http_transport.clone(relative_url)
 
1811
        return SmartHTTPTransport(abs_url, http_transport=new_transport)
1062
1812
 
1063
1813
 
1064
1814
def get_test_permutations():
1065
 
    """Return (transport, server) permutations for testing"""
 
1815
    """Return (transport, server) permutations for testing."""
 
1816
    ### We may need a little more test framework support to construct an
 
1817
    ### appropriate RemoteTransport in the future.
1066
1818
    return [(SmartTCPTransport, SmartTCPServer_for_testing)]