~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/smart/medium.py

Merge trunk.

Show diffs side-by-side

added added

removed removed

Lines of Context:
30
30
 
31
31
from bzrlib import (
32
32
    errors,
 
33
    osutils,
33
34
    symbol_versioning,
34
35
    )
35
36
from bzrlib.smart.protocol import (
36
37
    REQUEST_VERSION_TWO,
 
38
    SmartClientRequestProtocolOne,
37
39
    SmartServerRequestProtocolOne,
38
40
    SmartServerRequestProtocolTwo,
39
41
    )
40
 
 
41
 
try:
42
 
    from bzrlib.transport import ssh
43
 
except errors.ParamikoNotPresent:
44
 
    # no paramiko.  SmartSSHClientMedium will break.
45
 
    pass
 
42
from bzrlib.transport import ssh
46
43
 
47
44
 
48
45
class SmartServerStreamMedium(object):
56
53
 
57
54
    The server passes requests through to an underlying backing transport, 
58
55
    which will typically be a LocalTransport looking at the server's filesystem.
 
56
 
 
57
    :ivar _push_back_buffer: a str of bytes that have been read from the stream
 
58
        but not used yet, or None if there are no buffered bytes.  Subclasses
 
59
        should make sure to exhaust this buffer before reading more bytes from
 
60
        the stream.  See also the _push_back method.
59
61
    """
60
62
 
61
 
    def __init__(self, backing_transport):
 
63
    def __init__(self, backing_transport, root_client_path='/'):
62
64
        """Construct new server.
63
65
 
64
66
        :param backing_transport: Transport for the directory served.
65
67
        """
66
68
        # backing_transport could be passed to serve instead of __init__
67
69
        self.backing_transport = backing_transport
 
70
        self.root_client_path = root_client_path
68
71
        self.finished = False
 
72
        self._push_back_buffer = None
 
73
 
 
74
    def _push_back(self, bytes):
 
75
        """Return unused bytes to the medium, because they belong to the next
 
76
        request(s).
 
77
 
 
78
        This sets the _push_back_buffer to the given bytes.
 
79
        """
 
80
        assert self._push_back_buffer is None, (
 
81
            "_push_back called when self._push_back_buffer is %r"
 
82
            % (self._push_back_buffer,))
 
83
        if bytes == '':
 
84
            return
 
85
        self._push_back_buffer = bytes
 
86
 
 
87
    def _get_push_back_buffer(self):
 
88
        assert self._push_back_buffer != '', (
 
89
            '%s._push_back_buffer should never be the empty string, '
 
90
            'which can be confused with EOF' % (self,))
 
91
        bytes = self._push_back_buffer
 
92
        self._push_back_buffer = None
 
93
        return bytes
69
94
 
70
95
    def serve(self):
71
96
        """Serve requests until the client disconnects."""
96
121
            bytes = bytes[len(REQUEST_VERSION_TWO):]
97
122
        else:
98
123
            protocol_class = SmartServerRequestProtocolOne
99
 
        protocol = protocol_class(self.backing_transport, self._write_out)
 
124
        protocol = protocol_class(
 
125
            self.backing_transport, self._write_out, self.root_client_path)
100
126
        protocol.accept_bytes(bytes)
101
127
        return protocol
102
128
 
131
157
 
132
158
        :returns: a string of bytes ending in a newline (byte 0x0A).
133
159
        """
134
 
        # XXX: this duplicates SmartClientRequestProtocolOne._recv_tuple
135
 
        line = ''
136
 
        while not line or line[-1] != '\n':
137
 
            new_char = self._get_bytes(1)
138
 
            line += new_char
139
 
            if new_char == '':
 
160
        newline_pos = -1
 
161
        bytes = ''
 
162
        while newline_pos == -1:
 
163
            new_bytes = self._get_bytes(1)
 
164
            bytes += new_bytes
 
165
            if new_bytes == '':
140
166
                # Ran out of bytes before receiving a complete line.
141
 
                break
 
167
                return bytes
 
168
            newline_pos = bytes.find('\n')
 
169
        line = bytes[:newline_pos+1]
 
170
        self._push_back(bytes[newline_pos+1:])
142
171
        return line
143
 
 
 
172
 
144
173
 
145
174
class SmartServerSocketStreamMedium(SmartServerStreamMedium):
146
175
 
147
 
    def __init__(self, sock, backing_transport):
 
176
    def __init__(self, sock, backing_transport, root_client_path='/'):
148
177
        """Constructor.
149
178
 
150
179
        :param sock: the socket the server will read from.  It will be put
151
180
            into blocking mode.
152
181
        """
153
 
        SmartServerStreamMedium.__init__(self, backing_transport)
154
 
        self.push_back = ''
 
182
        SmartServerStreamMedium.__init__(
 
183
            self, backing_transport, root_client_path=root_client_path)
155
184
        sock.setblocking(True)
156
185
        self.socket = sock
157
186
 
158
187
    def _serve_one_request_unguarded(self, protocol):
159
188
        while protocol.next_read_size():
160
 
            if self.push_back:
161
 
                protocol.accept_bytes(self.push_back)
162
 
                self.push_back = ''
163
 
            else:
164
 
                bytes = self._get_bytes(4096)
165
 
                if bytes == '':
166
 
                    self.finished = True
167
 
                    return
168
 
                protocol.accept_bytes(bytes)
 
189
            bytes = self._get_bytes(4096)
 
190
            if bytes == '':
 
191
                self.finished = True
 
192
                return
 
193
            protocol.accept_bytes(bytes)
169
194
        
170
 
        self.push_back = protocol.excess_buffer
 
195
        self._push_back(protocol.excess_buffer)
171
196
 
172
197
    def _get_bytes(self, desired_count):
 
198
        if self._push_back_buffer is not None:
 
199
            return self._get_push_back_buffer()
173
200
        # We ignore the desired_count because on sockets it's more efficient to
174
201
        # read 4k at a time.
175
202
        return self.socket.recv(4096)
182
209
        self.finished = True
183
210
 
184
211
    def _write_out(self, bytes):
185
 
        self.socket.sendall(bytes)
 
212
        osutils.send_all(self.socket, bytes)
186
213
 
187
214
 
188
215
class SmartServerPipeStreamMedium(SmartServerStreamMedium):
221
248
            protocol.accept_bytes(bytes)
222
249
 
223
250
    def _get_bytes(self, desired_count):
 
251
        if self._push_back_buffer is not None:
 
252
            return self._get_push_back_buffer()
224
253
        return self._in.read(desired_count)
225
254
 
226
255
    def terminate_due_to_error(self):
362
391
            new_char = self.read_bytes(1)
363
392
            line += new_char
364
393
            if new_char == '':
365
 
                raise errors.SmartProtocolError(
366
 
                    'unexpected end of file reading from server')
 
394
                # end of file encountered reading from server
 
395
                raise errors.ConnectionReset(
 
396
                    "please check connectivity and permissions",
 
397
                    "(and try -Dhpss if further diagnosis is required)")
367
398
        return line
368
399
 
369
400
 
370
401
class SmartClientMedium(object):
371
402
    """Smart client is a medium for sending smart protocol requests over."""
372
403
 
 
404
    def __init__(self):
 
405
        super(SmartClientMedium, self).__init__()
 
406
        self._protocol_version_error = None
 
407
        self._protocol_version = None
 
408
 
 
409
    def protocol_version(self):
 
410
        """Find out the best protocol version to use."""
 
411
        if self._protocol_version_error is not None:
 
412
            raise self._protocol_version_error
 
413
        if self._protocol_version is None:
 
414
            try:
 
415
                medium_request = self.get_request()
 
416
                # Send a 'hello' request in protocol version one, for maximum
 
417
                # backwards compatibility.
 
418
                client_protocol = SmartClientRequestProtocolOne(medium_request)
 
419
                self._protocol_version = client_protocol.query_version()
 
420
            except errors.SmartProtocolError, e:
 
421
                # Cache the error, just like we would cache a successful
 
422
                # result.
 
423
                self._protocol_version_error = e
 
424
                raise
 
425
        return self._protocol_version
 
426
 
373
427
    def disconnect(self):
374
428
        """If this medium maintains a persistent connection, close it.
375
429
        
387
441
    """
388
442
 
389
443
    def __init__(self):
 
444
        SmartClientMedium.__init__(self)
390
445
        self._current_request = None
 
446
        # Be optimistic: we assume the remote end can accept new remote
 
447
        # requests until we get an error saying otherwise.  (1.2 adds some
 
448
        # requests that send bodies, which confuses older servers.)
 
449
        self._remote_is_at_least_1_2 = True
391
450
 
392
451
    def accept_bytes(self, bytes):
393
452
        self._accept_bytes(bytes)
510
569
        return self._read_from.read(count)
511
570
 
512
571
 
 
572
# Port 4155 is the default port for bzr://, registered with IANA.
 
573
BZR_DEFAULT_INTERFACE = '0.0.0.0'
 
574
BZR_DEFAULT_PORT = 4155
 
575
 
 
576
 
513
577
class SmartTCPClientMedium(SmartClientStreamMedium):
514
578
    """A client medium using TCP."""
515
579
    
524
588
    def _accept_bytes(self, bytes):
525
589
        """See SmartClientMedium.accept_bytes."""
526
590
        self._ensure_connection()
527
 
        self._socket.sendall(bytes)
 
591
        osutils.send_all(self._socket, bytes)
528
592
 
529
593
    def disconnect(self):
530
594
        """See SmartClientMedium.disconnect()."""
540
604
            return
541
605
        self._socket = socket.socket()
542
606
        self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
543
 
        result = self._socket.connect_ex((self._host, int(self._port)))
544
 
        if result:
 
607
        if self._port is None:
 
608
            port = BZR_DEFAULT_PORT
 
609
        else:
 
610
            port = int(self._port)
 
611
        try:
 
612
            self._socket.connect((self._host, port))
 
613
        except socket.error, err:
 
614
            # socket errors either have a (string) or (errno, string) as their
 
615
            # args.
 
616
            if type(err.args) is str:
 
617
                err_msg = err.args
 
618
            else:
 
619
                err_msg = err.args[1]
545
620
            raise errors.ConnectionError("failed to connect to %s:%d: %s" %
546
 
                    (self._host, self._port, os.strerror(result)))
 
621
                    (self._host, port, err_msg))
547
622
        self._connected = True
548
623
 
549
624
    def _flush(self):