~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/smart/medium.py

  • Committer: Eric Holmberg
  • Date: 2008-05-06 15:02:27 UTC
  • mto: This revision was merged to the branch mainline in revision 3449.
  • Revision ID: eholmberg@arrow.com-20080506150227-l3arq1yntdvnoxum
Fix for Bug #215426 in which bzr can cause a MemoryError in socket.recv while
downloading large packs over http.  This patch limits the request size for
socket.recv to avoid this problem.

Changes:
Added mock file object bzrlib.tests.file_utils.
Added new parameters to bzrlib.osutils.pumpfile.
Added unit tests for bzrlib.osutils.pumpfile.
Added usage of bzrlib.osutils.pumpfile to bzrlib.transport.http.response.

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# Copyright (C) 2006 Canonical Ltd
 
2
#
 
3
# This program is free software; you can redistribute it and/or modify
 
4
# it under the terms of the GNU General Public License as published by
 
5
# the Free Software Foundation; either version 2 of the License, or
 
6
# (at your option) any later version.
 
7
#
 
8
# This program is distributed in the hope that it will be useful,
 
9
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
10
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
11
# GNU General Public License for more details.
 
12
#
 
13
# You should have received a copy of the GNU General Public License
 
14
# along with this program; if not, write to the Free Software
 
15
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 
16
 
 
17
"""The 'medium' layer for the smart servers and clients.
 
18
 
 
19
"Medium" here is the noun meaning "a means of transmission", not the adjective
 
20
for "the quality between big and small."
 
21
 
 
22
Media carry the bytes of the requests somehow (e.g. via TCP, wrapped in HTTP, or
 
23
over SSH), and pass them to and from the protocol logic.  See the overview in
 
24
bzrlib/transport/smart/__init__.py.
 
25
"""
 
26
 
 
27
import os
 
28
import socket
 
29
import sys
 
30
 
 
31
from bzrlib import (
 
32
    errors,
 
33
    osutils,
 
34
    symbol_versioning,
 
35
    )
 
36
from bzrlib.smart.protocol import (
 
37
    REQUEST_VERSION_TWO,
 
38
    SmartClientRequestProtocolOne,
 
39
    SmartServerRequestProtocolOne,
 
40
    SmartServerRequestProtocolTwo,
 
41
    )
 
42
from bzrlib.transport import ssh
 
43
 
 
44
 
 
45
class SmartServerStreamMedium(object):
 
46
    """Handles smart commands coming over a stream.
 
47
 
 
48
    The stream may be a pipe connected to sshd, or a tcp socket, or an
 
49
    in-process fifo for testing.
 
50
 
 
51
    One instance is created for each connected client; it can serve multiple
 
52
    requests in the lifetime of the connection.
 
53
 
 
54
    The server passes requests through to an underlying backing transport, 
 
55
    which will typically be a LocalTransport looking at the server's filesystem.
 
56
 
 
57
    :ivar _push_back_buffer: a str of bytes that have been read from the stream
 
58
        but not used yet, or None if there are no buffered bytes.  Subclasses
 
59
        should make sure to exhaust this buffer before reading more bytes from
 
60
        the stream.  See also the _push_back method.
 
61
    """
 
62
 
 
63
    def __init__(self, backing_transport, root_client_path='/'):
 
64
        """Construct new server.
 
65
 
 
66
        :param backing_transport: Transport for the directory served.
 
67
        """
 
68
        # backing_transport could be passed to serve instead of __init__
 
69
        self.backing_transport = backing_transport
 
70
        self.root_client_path = root_client_path
 
71
        self.finished = False
 
72
        self._push_back_buffer = None
 
73
 
 
74
    def _push_back(self, bytes):
 
75
        """Return unused bytes to the medium, because they belong to the next
 
76
        request(s).
 
77
 
 
78
        This sets the _push_back_buffer to the given bytes.
 
79
        """
 
80
        assert self._push_back_buffer is None, (
 
81
            "_push_back called when self._push_back_buffer is %r"
 
82
            % (self._push_back_buffer,))
 
83
        if bytes == '':
 
84
            return
 
85
        self._push_back_buffer = bytes
 
86
 
 
87
    def _get_push_back_buffer(self):
 
88
        assert self._push_back_buffer != '', (
 
89
            '%s._push_back_buffer should never be the empty string, '
 
90
            'which can be confused with EOF' % (self,))
 
91
        bytes = self._push_back_buffer
 
92
        self._push_back_buffer = None
 
93
        return bytes
 
94
 
 
95
    def serve(self):
 
96
        """Serve requests until the client disconnects."""
 
97
        # Keep a reference to stderr because the sys module's globals get set to
 
98
        # None during interpreter shutdown.
 
99
        from sys import stderr
 
100
        try:
 
101
            while not self.finished:
 
102
                server_protocol = self._build_protocol()
 
103
                self._serve_one_request(server_protocol)
 
104
        except Exception, e:
 
105
            stderr.write("%s terminating on exception %s\n" % (self, e))
 
106
            raise
 
107
 
 
108
    def _build_protocol(self):
 
109
        """Identifies the version of the incoming request, and returns an
 
110
        a protocol object that can interpret it.
 
111
 
 
112
        If more bytes than the version prefix of the request are read, they will
 
113
        be fed into the protocol before it is returned.
 
114
 
 
115
        :returns: a SmartServerRequestProtocol.
 
116
        """
 
117
        # Identify the protocol version.
 
118
        bytes = self._get_line()
 
119
        if bytes.startswith(REQUEST_VERSION_TWO):
 
120
            protocol_class = SmartServerRequestProtocolTwo
 
121
            bytes = bytes[len(REQUEST_VERSION_TWO):]
 
122
        else:
 
123
            protocol_class = SmartServerRequestProtocolOne
 
124
        protocol = protocol_class(
 
125
            self.backing_transport, self._write_out, self.root_client_path)
 
126
        protocol.accept_bytes(bytes)
 
127
        return protocol
 
128
 
 
129
    def _serve_one_request(self, protocol):
 
130
        """Read one request from input, process, send back a response.
 
131
        
 
132
        :param protocol: a SmartServerRequestProtocol.
 
133
        """
 
134
        try:
 
135
            self._serve_one_request_unguarded(protocol)
 
136
        except KeyboardInterrupt:
 
137
            raise
 
138
        except Exception, e:
 
139
            self.terminate_due_to_error()
 
140
 
 
141
    def terminate_due_to_error(self):
 
142
        """Called when an unhandled exception from the protocol occurs."""
 
143
        raise NotImplementedError(self.terminate_due_to_error)
 
144
 
 
145
    def _get_bytes(self, desired_count):
 
146
        """Get some bytes from the medium.
 
147
 
 
148
        :param desired_count: number of bytes we want to read.
 
149
        """
 
150
        raise NotImplementedError(self._get_bytes)
 
151
 
 
152
    def _get_line(self):
 
153
        """Read bytes from this request's response until a newline byte.
 
154
        
 
155
        This isn't particularly efficient, so should only be used when the
 
156
        expected size of the line is quite short.
 
157
 
 
158
        :returns: a string of bytes ending in a newline (byte 0x0A).
 
159
        """
 
160
        newline_pos = -1
 
161
        bytes = ''
 
162
        while newline_pos == -1:
 
163
            new_bytes = self._get_bytes(1)
 
164
            bytes += new_bytes
 
165
            if new_bytes == '':
 
166
                # Ran out of bytes before receiving a complete line.
 
167
                return bytes
 
168
            newline_pos = bytes.find('\n')
 
169
        line = bytes[:newline_pos+1]
 
170
        self._push_back(bytes[newline_pos+1:])
 
171
        return line
 
172
 
 
173
 
 
174
class SmartServerSocketStreamMedium(SmartServerStreamMedium):
 
175
 
 
176
    def __init__(self, sock, backing_transport, root_client_path='/'):
 
177
        """Constructor.
 
178
 
 
179
        :param sock: the socket the server will read from.  It will be put
 
180
            into blocking mode.
 
181
        """
 
182
        SmartServerStreamMedium.__init__(
 
183
            self, backing_transport, root_client_path=root_client_path)
 
184
        sock.setblocking(True)
 
185
        self.socket = sock
 
186
 
 
187
    def _serve_one_request_unguarded(self, protocol):
 
188
        while protocol.next_read_size():
 
189
            bytes = self._get_bytes(4096)
 
190
            if bytes == '':
 
191
                self.finished = True
 
192
                return
 
193
            protocol.accept_bytes(bytes)
 
194
        
 
195
        self._push_back(protocol.excess_buffer)
 
196
 
 
197
    def _get_bytes(self, desired_count):
 
198
        if self._push_back_buffer is not None:
 
199
            return self._get_push_back_buffer()
 
200
        # We ignore the desired_count because on sockets it's more efficient to
 
201
        # read 4k at a time.
 
202
        return self.socket.recv(4096)
 
203
    
 
204
    def terminate_due_to_error(self):
 
205
        """Called when an unhandled exception from the protocol occurs."""
 
206
        # TODO: This should log to a server log file, but no such thing
 
207
        # exists yet.  Andrew Bennetts 2006-09-29.
 
208
        self.socket.close()
 
209
        self.finished = True
 
210
 
 
211
    def _write_out(self, bytes):
 
212
        osutils.send_all(self.socket, bytes)
 
213
 
 
214
 
 
215
class SmartServerPipeStreamMedium(SmartServerStreamMedium):
 
216
 
 
217
    def __init__(self, in_file, out_file, backing_transport):
 
218
        """Construct new server.
 
219
 
 
220
        :param in_file: Python file from which requests can be read.
 
221
        :param out_file: Python file to write responses.
 
222
        :param backing_transport: Transport for the directory served.
 
223
        """
 
224
        SmartServerStreamMedium.__init__(self, backing_transport)
 
225
        if sys.platform == 'win32':
 
226
            # force binary mode for files
 
227
            import msvcrt
 
228
            for f in (in_file, out_file):
 
229
                fileno = getattr(f, 'fileno', None)
 
230
                if fileno:
 
231
                    msvcrt.setmode(fileno(), os.O_BINARY)
 
232
        self._in = in_file
 
233
        self._out = out_file
 
234
 
 
235
    def _serve_one_request_unguarded(self, protocol):
 
236
        while True:
 
237
            bytes_to_read = protocol.next_read_size()
 
238
            if bytes_to_read == 0:
 
239
                # Finished serving this request.
 
240
                self._out.flush()
 
241
                return
 
242
            bytes = self._get_bytes(bytes_to_read)
 
243
            if bytes == '':
 
244
                # Connection has been closed.
 
245
                self.finished = True
 
246
                self._out.flush()
 
247
                return
 
248
            protocol.accept_bytes(bytes)
 
249
 
 
250
    def _get_bytes(self, desired_count):
 
251
        if self._push_back_buffer is not None:
 
252
            return self._get_push_back_buffer()
 
253
        return self._in.read(desired_count)
 
254
 
 
255
    def terminate_due_to_error(self):
 
256
        # TODO: This should log to a server log file, but no such thing
 
257
        # exists yet.  Andrew Bennetts 2006-09-29.
 
258
        self._out.close()
 
259
        self.finished = True
 
260
 
 
261
    def _write_out(self, bytes):
 
262
        self._out.write(bytes)
 
263
 
 
264
 
 
265
class SmartClientMediumRequest(object):
 
266
    """A request on a SmartClientMedium.
 
267
 
 
268
    Each request allows bytes to be provided to it via accept_bytes, and then
 
269
    the response bytes to be read via read_bytes.
 
270
 
 
271
    For instance:
 
272
    request.accept_bytes('123')
 
273
    request.finished_writing()
 
274
    result = request.read_bytes(3)
 
275
    request.finished_reading()
 
276
 
 
277
    It is up to the individual SmartClientMedium whether multiple concurrent
 
278
    requests can exist. See SmartClientMedium.get_request to obtain instances 
 
279
    of SmartClientMediumRequest, and the concrete Medium you are using for 
 
280
    details on concurrency and pipelining.
 
281
    """
 
282
 
 
283
    def __init__(self, medium):
 
284
        """Construct a SmartClientMediumRequest for the medium medium."""
 
285
        self._medium = medium
 
286
        # we track state by constants - we may want to use the same
 
287
        # pattern as BodyReader if it gets more complex.
 
288
        # valid states are: "writing", "reading", "done"
 
289
        self._state = "writing"
 
290
 
 
291
    def accept_bytes(self, bytes):
 
292
        """Accept bytes for inclusion in this request.
 
293
 
 
294
        This method may not be be called after finished_writing() has been
 
295
        called.  It depends upon the Medium whether or not the bytes will be
 
296
        immediately transmitted. Message based Mediums will tend to buffer the
 
297
        bytes until finished_writing() is called.
 
298
 
 
299
        :param bytes: A bytestring.
 
300
        """
 
301
        if self._state != "writing":
 
302
            raise errors.WritingCompleted(self)
 
303
        self._accept_bytes(bytes)
 
304
 
 
305
    def _accept_bytes(self, bytes):
 
306
        """Helper for accept_bytes.
 
307
 
 
308
        Accept_bytes checks the state of the request to determing if bytes
 
309
        should be accepted. After that it hands off to _accept_bytes to do the
 
310
        actual acceptance.
 
311
        """
 
312
        raise NotImplementedError(self._accept_bytes)
 
313
 
 
314
    def finished_reading(self):
 
315
        """Inform the request that all desired data has been read.
 
316
 
 
317
        This will remove the request from the pipeline for its medium (if the
 
318
        medium supports pipelining) and any further calls to methods on the
 
319
        request will raise ReadingCompleted.
 
320
        """
 
321
        if self._state == "writing":
 
322
            raise errors.WritingNotComplete(self)
 
323
        if self._state != "reading":
 
324
            raise errors.ReadingCompleted(self)
 
325
        self._state = "done"
 
326
        self._finished_reading()
 
327
 
 
328
    def _finished_reading(self):
 
329
        """Helper for finished_reading.
 
330
 
 
331
        finished_reading checks the state of the request to determine if 
 
332
        finished_reading is allowed, and if it is hands off to _finished_reading
 
333
        to perform the action.
 
334
        """
 
335
        raise NotImplementedError(self._finished_reading)
 
336
 
 
337
    def finished_writing(self):
 
338
        """Finish the writing phase of this request.
 
339
 
 
340
        This will flush all pending data for this request along the medium.
 
341
        After calling finished_writing, you may not call accept_bytes anymore.
 
342
        """
 
343
        if self._state != "writing":
 
344
            raise errors.WritingCompleted(self)
 
345
        self._state = "reading"
 
346
        self._finished_writing()
 
347
 
 
348
    def _finished_writing(self):
 
349
        """Helper for finished_writing.
 
350
 
 
351
        finished_writing checks the state of the request to determine if 
 
352
        finished_writing is allowed, and if it is hands off to _finished_writing
 
353
        to perform the action.
 
354
        """
 
355
        raise NotImplementedError(self._finished_writing)
 
356
 
 
357
    def read_bytes(self, count):
 
358
        """Read bytes from this requests response.
 
359
 
 
360
        This method will block and wait for count bytes to be read. It may not
 
361
        be invoked until finished_writing() has been called - this is to ensure
 
362
        a message-based approach to requests, for compatibility with message
 
363
        based mediums like HTTP.
 
364
        """
 
365
        if self._state == "writing":
 
366
            raise errors.WritingNotComplete(self)
 
367
        if self._state != "reading":
 
368
            raise errors.ReadingCompleted(self)
 
369
        return self._read_bytes(count)
 
370
 
 
371
    def _read_bytes(self, count):
 
372
        """Helper for read_bytes.
 
373
 
 
374
        read_bytes checks the state of the request to determing if bytes
 
375
        should be read. After that it hands off to _read_bytes to do the
 
376
        actual read.
 
377
        """
 
378
        raise NotImplementedError(self._read_bytes)
 
379
 
 
380
    def read_line(self):
 
381
        """Read bytes from this request's response until a newline byte.
 
382
        
 
383
        This isn't particularly efficient, so should only be used when the
 
384
        expected size of the line is quite short.
 
385
 
 
386
        :returns: a string of bytes ending in a newline (byte 0x0A).
 
387
        """
 
388
        # XXX: this duplicates SmartClientRequestProtocolOne._recv_tuple
 
389
        line = ''
 
390
        while not line or line[-1] != '\n':
 
391
            new_char = self.read_bytes(1)
 
392
            line += new_char
 
393
            if new_char == '':
 
394
                # end of file encountered reading from server
 
395
                raise errors.ConnectionReset(
 
396
                    "please check connectivity and permissions",
 
397
                    "(and try -Dhpss if further diagnosis is required)")
 
398
        return line
 
399
 
 
400
 
 
401
class SmartClientMedium(object):
 
402
    """Smart client is a medium for sending smart protocol requests over."""
 
403
 
 
404
    def __init__(self):
 
405
        super(SmartClientMedium, self).__init__()
 
406
        self._protocol_version_error = None
 
407
        self._protocol_version = None
 
408
 
 
409
    def protocol_version(self):
 
410
        """Find out the best protocol version to use."""
 
411
        if self._protocol_version_error is not None:
 
412
            raise self._protocol_version_error
 
413
        if self._protocol_version is None:
 
414
            try:
 
415
                medium_request = self.get_request()
 
416
                # Send a 'hello' request in protocol version one, for maximum
 
417
                # backwards compatibility.
 
418
                client_protocol = SmartClientRequestProtocolOne(medium_request)
 
419
                self._protocol_version = client_protocol.query_version()
 
420
            except errors.SmartProtocolError, e:
 
421
                # Cache the error, just like we would cache a successful
 
422
                # result.
 
423
                self._protocol_version_error = e
 
424
                raise
 
425
        return self._protocol_version
 
426
 
 
427
    def disconnect(self):
 
428
        """If this medium maintains a persistent connection, close it.
 
429
        
 
430
        The default implementation does nothing.
 
431
        """
 
432
        
 
433
 
 
434
class SmartClientStreamMedium(SmartClientMedium):
 
435
    """Stream based medium common class.
 
436
 
 
437
    SmartClientStreamMediums operate on a stream. All subclasses use a common
 
438
    SmartClientStreamMediumRequest for their requests, and should implement
 
439
    _accept_bytes and _read_bytes to allow the request objects to send and
 
440
    receive bytes.
 
441
    """
 
442
 
 
443
    def __init__(self):
 
444
        SmartClientMedium.__init__(self)
 
445
        self._current_request = None
 
446
        # Be optimistic: we assume the remote end can accept new remote
 
447
        # requests until we get an error saying otherwise.  (1.2 adds some
 
448
        # requests that send bodies, which confuses older servers.)
 
449
        self._remote_is_at_least_1_2 = True
 
450
 
 
451
    def accept_bytes(self, bytes):
 
452
        self._accept_bytes(bytes)
 
453
 
 
454
    def __del__(self):
 
455
        """The SmartClientStreamMedium knows how to close the stream when it is
 
456
        finished with it.
 
457
        """
 
458
        self.disconnect()
 
459
 
 
460
    def _flush(self):
 
461
        """Flush the output stream.
 
462
        
 
463
        This method is used by the SmartClientStreamMediumRequest to ensure that
 
464
        all data for a request is sent, to avoid long timeouts or deadlocks.
 
465
        """
 
466
        raise NotImplementedError(self._flush)
 
467
 
 
468
    def get_request(self):
 
469
        """See SmartClientMedium.get_request().
 
470
 
 
471
        SmartClientStreamMedium always returns a SmartClientStreamMediumRequest
 
472
        for get_request.
 
473
        """
 
474
        return SmartClientStreamMediumRequest(self)
 
475
 
 
476
    def read_bytes(self, count):
 
477
        return self._read_bytes(count)
 
478
 
 
479
 
 
480
class SmartSimplePipesClientMedium(SmartClientStreamMedium):
 
481
    """A client medium using simple pipes.
 
482
    
 
483
    This client does not manage the pipes: it assumes they will always be open.
 
484
    """
 
485
 
 
486
    def __init__(self, readable_pipe, writeable_pipe):
 
487
        SmartClientStreamMedium.__init__(self)
 
488
        self._readable_pipe = readable_pipe
 
489
        self._writeable_pipe = writeable_pipe
 
490
 
 
491
    def _accept_bytes(self, bytes):
 
492
        """See SmartClientStreamMedium.accept_bytes."""
 
493
        self._writeable_pipe.write(bytes)
 
494
 
 
495
    def _flush(self):
 
496
        """See SmartClientStreamMedium._flush()."""
 
497
        self._writeable_pipe.flush()
 
498
 
 
499
    def _read_bytes(self, count):
 
500
        """See SmartClientStreamMedium._read_bytes."""
 
501
        return self._readable_pipe.read(count)
 
502
 
 
503
 
 
504
class SmartSSHClientMedium(SmartClientStreamMedium):
 
505
    """A client medium using SSH."""
 
506
    
 
507
    def __init__(self, host, port=None, username=None, password=None,
 
508
            vendor=None, bzr_remote_path=None):
 
509
        """Creates a client that will connect on the first use.
 
510
        
 
511
        :param vendor: An optional override for the ssh vendor to use. See
 
512
            bzrlib.transport.ssh for details on ssh vendors.
 
513
        """
 
514
        SmartClientStreamMedium.__init__(self)
 
515
        self._connected = False
 
516
        self._host = host
 
517
        self._password = password
 
518
        self._port = port
 
519
        self._username = username
 
520
        self._read_from = None
 
521
        self._ssh_connection = None
 
522
        self._vendor = vendor
 
523
        self._write_to = None
 
524
        self._bzr_remote_path = bzr_remote_path
 
525
        if self._bzr_remote_path is None:
 
526
            symbol_versioning.warn(
 
527
                'bzr_remote_path is required as of bzr 0.92',
 
528
                DeprecationWarning, stacklevel=2)
 
529
            self._bzr_remote_path = os.environ.get('BZR_REMOTE_PATH', 'bzr')
 
530
 
 
531
    def _accept_bytes(self, bytes):
 
532
        """See SmartClientStreamMedium.accept_bytes."""
 
533
        self._ensure_connection()
 
534
        self._write_to.write(bytes)
 
535
 
 
536
    def disconnect(self):
 
537
        """See SmartClientMedium.disconnect()."""
 
538
        if not self._connected:
 
539
            return
 
540
        self._read_from.close()
 
541
        self._write_to.close()
 
542
        self._ssh_connection.close()
 
543
        self._connected = False
 
544
 
 
545
    def _ensure_connection(self):
 
546
        """Connect this medium if not already connected."""
 
547
        if self._connected:
 
548
            return
 
549
        if self._vendor is None:
 
550
            vendor = ssh._get_ssh_vendor()
 
551
        else:
 
552
            vendor = self._vendor
 
553
        self._ssh_connection = vendor.connect_ssh(self._username,
 
554
                self._password, self._host, self._port,
 
555
                command=[self._bzr_remote_path, 'serve', '--inet',
 
556
                         '--directory=/', '--allow-writes'])
 
557
        self._read_from, self._write_to = \
 
558
            self._ssh_connection.get_filelike_channels()
 
559
        self._connected = True
 
560
 
 
561
    def _flush(self):
 
562
        """See SmartClientStreamMedium._flush()."""
 
563
        self._write_to.flush()
 
564
 
 
565
    def _read_bytes(self, count):
 
566
        """See SmartClientStreamMedium.read_bytes."""
 
567
        if not self._connected:
 
568
            raise errors.MediumNotConnected(self)
 
569
        return self._read_from.read(count)
 
570
 
 
571
 
 
572
# Port 4155 is the default port for bzr://, registered with IANA.
 
573
BZR_DEFAULT_INTERFACE = '0.0.0.0'
 
574
BZR_DEFAULT_PORT = 4155
 
575
 
 
576
 
 
577
class SmartTCPClientMedium(SmartClientStreamMedium):
 
578
    """A client medium using TCP."""
 
579
    
 
580
    def __init__(self, host, port):
 
581
        """Creates a client that will connect on the first use."""
 
582
        SmartClientStreamMedium.__init__(self)
 
583
        self._connected = False
 
584
        self._host = host
 
585
        self._port = port
 
586
        self._socket = None
 
587
 
 
588
    def _accept_bytes(self, bytes):
 
589
        """See SmartClientMedium.accept_bytes."""
 
590
        self._ensure_connection()
 
591
        osutils.send_all(self._socket, bytes)
 
592
 
 
593
    def disconnect(self):
 
594
        """See SmartClientMedium.disconnect()."""
 
595
        if not self._connected:
 
596
            return
 
597
        self._socket.close()
 
598
        self._socket = None
 
599
        self._connected = False
 
600
 
 
601
    def _ensure_connection(self):
 
602
        """Connect this medium if not already connected."""
 
603
        if self._connected:
 
604
            return
 
605
        self._socket = socket.socket()
 
606
        self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
 
607
        if self._port is None:
 
608
            port = BZR_DEFAULT_PORT
 
609
        else:
 
610
            port = int(self._port)
 
611
        try:
 
612
            self._socket.connect((self._host, port))
 
613
        except socket.error, err:
 
614
            # socket errors either have a (string) or (errno, string) as their
 
615
            # args.
 
616
            if type(err.args) is str:
 
617
                err_msg = err.args
 
618
            else:
 
619
                err_msg = err.args[1]
 
620
            raise errors.ConnectionError("failed to connect to %s:%d: %s" %
 
621
                    (self._host, port, err_msg))
 
622
        self._connected = True
 
623
 
 
624
    def _flush(self):
 
625
        """See SmartClientStreamMedium._flush().
 
626
        
 
627
        For TCP we do no flushing. We may want to turn off TCP_NODELAY and 
 
628
        add a means to do a flush, but that can be done in the future.
 
629
        """
 
630
 
 
631
    def _read_bytes(self, count):
 
632
        """See SmartClientMedium.read_bytes."""
 
633
        if not self._connected:
 
634
            raise errors.MediumNotConnected(self)
 
635
        return self._socket.recv(count)
 
636
 
 
637
 
 
638
class SmartClientStreamMediumRequest(SmartClientMediumRequest):
 
639
    """A SmartClientMediumRequest that works with an SmartClientStreamMedium."""
 
640
 
 
641
    def __init__(self, medium):
 
642
        SmartClientMediumRequest.__init__(self, medium)
 
643
        # check that we are safe concurrency wise. If some streams start
 
644
        # allowing concurrent requests - i.e. via multiplexing - then this
 
645
        # assert should be moved to SmartClientStreamMedium.get_request,
 
646
        # and the setting/unsetting of _current_request likewise moved into
 
647
        # that class : but its unneeded overhead for now. RBC 20060922
 
648
        if self._medium._current_request is not None:
 
649
            raise errors.TooManyConcurrentRequests(self._medium)
 
650
        self._medium._current_request = self
 
651
 
 
652
    def _accept_bytes(self, bytes):
 
653
        """See SmartClientMediumRequest._accept_bytes.
 
654
        
 
655
        This forwards to self._medium._accept_bytes because we are operating
 
656
        on the mediums stream.
 
657
        """
 
658
        self._medium._accept_bytes(bytes)
 
659
 
 
660
    def _finished_reading(self):
 
661
        """See SmartClientMediumRequest._finished_reading.
 
662
 
 
663
        This clears the _current_request on self._medium to allow a new 
 
664
        request to be created.
 
665
        """
 
666
        assert self._medium._current_request is self
 
667
        self._medium._current_request = None
 
668
        
 
669
    def _finished_writing(self):
 
670
        """See SmartClientMediumRequest._finished_writing.
 
671
 
 
672
        This invokes self._medium._flush to ensure all bytes are transmitted.
 
673
        """
 
674
        self._medium._flush()
 
675
 
 
676
    def _read_bytes(self, count):
 
677
        """See SmartClientMediumRequest._read_bytes.
 
678
        
 
679
        This forwards to self._medium._read_bytes because we are operating
 
680
        on the mediums stream.
 
681
        """
 
682
        return self._medium._read_bytes(count)
 
683