~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/smart/repository.py

  • Committer: Martin Pool
  • Date: 2009-08-20 04:53:23 UTC
  • mto: This revision was merged to the branch mainline in revision 4632.
  • Revision ID: mbp@sourcefrog.net-20090820045323-4hsicfa87pdq3l29
Correction to xdg_cache_dir and add a simple test

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# Copyright (C) 2006, 2007 Canonical Ltd
 
2
#
 
3
# This program is free software; you can redistribute it and/or modify
 
4
# it under the terms of the GNU General Public License as published by
 
5
# the Free Software Foundation; either version 2 of the License, or
 
6
# (at your option) any later version.
 
7
#
 
8
# This program is distributed in the hope that it will be useful,
 
9
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
10
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
11
# GNU General Public License for more details.
 
12
#
 
13
# You should have received a copy of the GNU General Public License
 
14
# along with this program; if not, write to the Free Software
 
15
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 
16
 
 
17
"""Server-side repository related request implmentations."""
 
18
 
 
19
import bz2
 
20
import os
 
21
import Queue
 
22
import sys
 
23
import tarfile
 
24
import tempfile
 
25
import threading
 
26
 
 
27
from bzrlib import (
 
28
    bencode,
 
29
    errors,
 
30
    graph,
 
31
    osutils,
 
32
    pack,
 
33
    )
 
34
from bzrlib.bzrdir import BzrDir
 
35
from bzrlib.smart.request import (
 
36
    FailedSmartServerResponse,
 
37
    SmartServerRequest,
 
38
    SuccessfulSmartServerResponse,
 
39
    )
 
40
from bzrlib.repository import _strip_NULL_ghosts, network_format_registry
 
41
from bzrlib import revision as _mod_revision
 
42
from bzrlib.versionedfile import NetworkRecordStream, record_to_fulltext_bytes
 
43
 
 
44
 
 
45
class SmartServerRepositoryRequest(SmartServerRequest):
 
46
    """Common base class for Repository requests."""
 
47
 
 
48
    def do(self, path, *args):
 
49
        """Execute a repository request.
 
50
 
 
51
        All Repository requests take a path to the repository as their first
 
52
        argument.  The repository must be at the exact path given by the
 
53
        client - no searching is done.
 
54
 
 
55
        The actual logic is delegated to self.do_repository_request.
 
56
 
 
57
        :param client_path: The path for the repository as received from the
 
58
            client.
 
59
        :return: A SmartServerResponse from self.do_repository_request().
 
60
        """
 
61
        transport = self.transport_from_client_path(path)
 
62
        bzrdir = BzrDir.open_from_transport(transport)
 
63
        # Save the repository for use with do_body.
 
64
        self._repository = bzrdir.open_repository()
 
65
        return self.do_repository_request(self._repository, *args)
 
66
 
 
67
    def do_repository_request(self, repository, *args):
 
68
        """Override to provide an implementation for a verb."""
 
69
        # No-op for verbs that take bodies (None as a result indicates a body
 
70
        # is expected)
 
71
        return None
 
72
 
 
73
    def recreate_search(self, repository, search_bytes, discard_excess=False):
 
74
        """Recreate a search from its serialised form.
 
75
 
 
76
        :param discard_excess: If True, and the search refers to data we don't
 
77
            have, just silently accept that fact - the verb calling
 
78
            recreate_search trusts that clients will look for missing things
 
79
            they expected and get it from elsewhere.
 
80
        """
 
81
        lines = search_bytes.split('\n')
 
82
        if lines[0] == 'ancestry-of':
 
83
            heads = lines[1:]
 
84
            search_result = graph.PendingAncestryResult(heads, repository)
 
85
            return search_result, None
 
86
        elif lines[0] == 'search':
 
87
            return self.recreate_search_from_recipe(repository, lines[1:],
 
88
                discard_excess=discard_excess)
 
89
        else:
 
90
            return (None, FailedSmartServerResponse(('BadSearch',)))
 
91
 
 
92
    def recreate_search_from_recipe(self, repository, lines,
 
93
        discard_excess=False):
 
94
        """Recreate a specific revision search (vs a from-tip search).
 
95
 
 
96
        :param discard_excess: If True, and the search refers to data we don't
 
97
            have, just silently accept that fact - the verb calling
 
98
            recreate_search trusts that clients will look for missing things
 
99
            they expected and get it from elsewhere.
 
100
        """
 
101
        start_keys = set(lines[0].split(' '))
 
102
        exclude_keys = set(lines[1].split(' '))
 
103
        revision_count = int(lines[2])
 
104
        repository.lock_read()
 
105
        try:
 
106
            search = repository.get_graph()._make_breadth_first_searcher(
 
107
                start_keys)
 
108
            while True:
 
109
                try:
 
110
                    next_revs = search.next()
 
111
                except StopIteration:
 
112
                    break
 
113
                search.stop_searching_any(exclude_keys.intersection(next_revs))
 
114
            search_result = search.get_result()
 
115
            if (not discard_excess and
 
116
                search_result.get_recipe()[3] != revision_count):
 
117
                # we got back a different amount of data than expected, this
 
118
                # gets reported as NoSuchRevision, because less revisions
 
119
                # indicates missing revisions, and more should never happen as
 
120
                # the excludes list considers ghosts and ensures that ghost
 
121
                # filling races are not a problem.
 
122
                return (None, FailedSmartServerResponse(('NoSuchRevision',)))
 
123
            return (search_result, None)
 
124
        finally:
 
125
            repository.unlock()
 
126
 
 
127
 
 
128
class SmartServerRepositoryReadLocked(SmartServerRepositoryRequest):
 
129
    """Calls self.do_readlocked_repository_request."""
 
130
 
 
131
    def do_repository_request(self, repository, *args):
 
132
        """Read lock a repository for do_readlocked_repository_request."""
 
133
        repository.lock_read()
 
134
        try:
 
135
            return self.do_readlocked_repository_request(repository, *args)
 
136
        finally:
 
137
            repository.unlock()
 
138
 
 
139
 
 
140
class SmartServerRepositoryGetParentMap(SmartServerRepositoryRequest):
 
141
    """Bzr 1.2+ - get parent data for revisions during a graph search."""
 
142
 
 
143
    no_extra_results = False
 
144
 
 
145
    def do_repository_request(self, repository, *revision_ids):
 
146
        """Get parent details for some revisions.
 
147
 
 
148
        All the parents for revision_ids are returned. Additionally up to 64KB
 
149
        of additional parent data found by performing a breadth first search
 
150
        from revision_ids is returned. The verb takes a body containing the
 
151
        current search state, see do_body for details.
 
152
 
 
153
        If 'include-missing:' is in revision_ids, ghosts encountered in the
 
154
        graph traversal for getting parent data are included in the result with
 
155
        a prefix of 'missing:'.
 
156
 
 
157
        :param repository: The repository to query in.
 
158
        :param revision_ids: The utf8 encoded revision_id to answer for.
 
159
        """
 
160
        self._revision_ids = revision_ids
 
161
        return None # Signal that we want a body.
 
162
 
 
163
    def do_body(self, body_bytes):
 
164
        """Process the current search state and perform the parent lookup.
 
165
 
 
166
        :return: A smart server response where the body contains an utf8
 
167
            encoded flattened list of the parents of the revisions (the same
 
168
            format as Repository.get_revision_graph) which has been bz2
 
169
            compressed.
 
170
        """
 
171
        repository = self._repository
 
172
        repository.lock_read()
 
173
        try:
 
174
            return self._do_repository_request(body_bytes)
 
175
        finally:
 
176
            repository.unlock()
 
177
 
 
178
    def _do_repository_request(self, body_bytes):
 
179
        repository = self._repository
 
180
        revision_ids = set(self._revision_ids)
 
181
        include_missing = 'include-missing:' in revision_ids
 
182
        if include_missing:
 
183
            revision_ids.remove('include-missing:')
 
184
        body_lines = body_bytes.split('\n')
 
185
        search_result, error = self.recreate_search_from_recipe(
 
186
            repository, body_lines)
 
187
        if error is not None:
 
188
            return error
 
189
        # TODO might be nice to start up the search again; but thats not
 
190
        # written or tested yet.
 
191
        client_seen_revs = set(search_result.get_keys())
 
192
        # Always include the requested ids.
 
193
        client_seen_revs.difference_update(revision_ids)
 
194
        lines = []
 
195
        repo_graph = repository.get_graph()
 
196
        result = {}
 
197
        queried_revs = set()
 
198
        size_so_far = 0
 
199
        next_revs = revision_ids
 
200
        first_loop_done = False
 
201
        while next_revs:
 
202
            queried_revs.update(next_revs)
 
203
            parent_map = repo_graph.get_parent_map(next_revs)
 
204
            current_revs = next_revs
 
205
            next_revs = set()
 
206
            for revision_id in current_revs:
 
207
                missing_rev = False
 
208
                parents = parent_map.get(revision_id)
 
209
                if parents is not None:
 
210
                    # adjust for the wire
 
211
                    if parents == (_mod_revision.NULL_REVISION,):
 
212
                        parents = ()
 
213
                    # prepare the next query
 
214
                    next_revs.update(parents)
 
215
                    encoded_id = revision_id
 
216
                else:
 
217
                    missing_rev = True
 
218
                    encoded_id = "missing:" + revision_id
 
219
                    parents = []
 
220
                if (revision_id not in client_seen_revs and
 
221
                    (not missing_rev or include_missing)):
 
222
                    # Client does not have this revision, give it to it.
 
223
                    # add parents to the result
 
224
                    result[encoded_id] = parents
 
225
                    # Approximate the serialized cost of this revision_id.
 
226
                    size_so_far += 2 + len(encoded_id) + sum(map(len, parents))
 
227
            # get all the directly asked for parents, and then flesh out to
 
228
            # 64K (compressed) or so. We do one level of depth at a time to
 
229
            # stay in sync with the client. The 250000 magic number is
 
230
            # estimated compression ratio taken from bzr.dev itself.
 
231
            if self.no_extra_results or (
 
232
                first_loop_done and size_so_far > 250000):
 
233
                next_revs = set()
 
234
                break
 
235
            # don't query things we've already queried
 
236
            next_revs.difference_update(queried_revs)
 
237
            first_loop_done = True
 
238
 
 
239
        # sorting trivially puts lexographically similar revision ids together.
 
240
        # Compression FTW.
 
241
        for revision, parents in sorted(result.items()):
 
242
            lines.append(' '.join((revision, ) + tuple(parents)))
 
243
 
 
244
        return SuccessfulSmartServerResponse(
 
245
            ('ok', ), bz2.compress('\n'.join(lines)))
 
246
 
 
247
 
 
248
class SmartServerRepositoryGetRevisionGraph(SmartServerRepositoryReadLocked):
 
249
 
 
250
    def do_readlocked_repository_request(self, repository, revision_id):
 
251
        """Return the result of repository.get_revision_graph(revision_id).
 
252
 
 
253
        Deprecated as of bzr 1.4, but supported for older clients.
 
254
 
 
255
        :param repository: The repository to query in.
 
256
        :param revision_id: The utf8 encoded revision_id to get a graph from.
 
257
        :return: A smart server response where the body contains an utf8
 
258
            encoded flattened list of the revision graph.
 
259
        """
 
260
        if not revision_id:
 
261
            revision_id = None
 
262
 
 
263
        lines = []
 
264
        graph = repository.get_graph()
 
265
        if revision_id:
 
266
            search_ids = [revision_id]
 
267
        else:
 
268
            search_ids = repository.all_revision_ids()
 
269
        search = graph._make_breadth_first_searcher(search_ids)
 
270
        transitive_ids = set()
 
271
        map(transitive_ids.update, list(search))
 
272
        parent_map = graph.get_parent_map(transitive_ids)
 
273
        revision_graph = _strip_NULL_ghosts(parent_map)
 
274
        if revision_id and revision_id not in revision_graph:
 
275
            # Note that we return an empty body, rather than omitting the body.
 
276
            # This way the client knows that it can always expect to find a body
 
277
            # in the response for this method, even in the error case.
 
278
            return FailedSmartServerResponse(('nosuchrevision', revision_id), '')
 
279
 
 
280
        for revision, parents in revision_graph.items():
 
281
            lines.append(' '.join((revision, ) + tuple(parents)))
 
282
 
 
283
        return SuccessfulSmartServerResponse(('ok', ), '\n'.join(lines))
 
284
 
 
285
 
 
286
class SmartServerRepositoryGetRevIdForRevno(SmartServerRepositoryReadLocked):
 
287
 
 
288
    def do_readlocked_repository_request(self, repository, revno,
 
289
            known_pair):
 
290
        """Find the revid for a given revno, given a known revno/revid pair.
 
291
        
 
292
        New in 1.17.
 
293
        """
 
294
        try:
 
295
            found_flag, result = repository.get_rev_id_for_revno(revno, known_pair)
 
296
        except errors.RevisionNotPresent, err:
 
297
            if err.revision_id != known_pair[1]:
 
298
                raise AssertionError(
 
299
                    'get_rev_id_for_revno raised RevisionNotPresent for '
 
300
                    'non-initial revision: ' + err.revision_id)
 
301
            return FailedSmartServerResponse(
 
302
                ('nosuchrevision', err.revision_id))
 
303
        if found_flag:
 
304
            return SuccessfulSmartServerResponse(('ok', result))
 
305
        else:
 
306
            earliest_revno, earliest_revid = result
 
307
            return SuccessfulSmartServerResponse(
 
308
                ('history-incomplete', earliest_revno, earliest_revid))
 
309
 
 
310
 
 
311
class SmartServerRequestHasRevision(SmartServerRepositoryRequest):
 
312
 
 
313
    def do_repository_request(self, repository, revision_id):
 
314
        """Return ok if a specific revision is in the repository at path.
 
315
 
 
316
        :param repository: The repository to query in.
 
317
        :param revision_id: The utf8 encoded revision_id to lookup.
 
318
        :return: A smart server response of ('ok', ) if the revision is
 
319
            present.
 
320
        """
 
321
        if repository.has_revision(revision_id):
 
322
            return SuccessfulSmartServerResponse(('yes', ))
 
323
        else:
 
324
            return SuccessfulSmartServerResponse(('no', ))
 
325
 
 
326
 
 
327
class SmartServerRepositoryGatherStats(SmartServerRepositoryRequest):
 
328
 
 
329
    def do_repository_request(self, repository, revid, committers):
 
330
        """Return the result of repository.gather_stats().
 
331
 
 
332
        :param repository: The repository to query in.
 
333
        :param revid: utf8 encoded rev id or an empty string to indicate None
 
334
        :param committers: 'yes' or 'no'.
 
335
 
 
336
        :return: A SmartServerResponse ('ok',), a encoded body looking like
 
337
              committers: 1
 
338
              firstrev: 1234.230 0
 
339
              latestrev: 345.700 3600
 
340
              revisions: 2
 
341
 
 
342
              But containing only fields returned by the gather_stats() call
 
343
        """
 
344
        if revid == '':
 
345
            decoded_revision_id = None
 
346
        else:
 
347
            decoded_revision_id = revid
 
348
        if committers == 'yes':
 
349
            decoded_committers = True
 
350
        else:
 
351
            decoded_committers = None
 
352
        stats = repository.gather_stats(decoded_revision_id, decoded_committers)
 
353
 
 
354
        body = ''
 
355
        if stats.has_key('committers'):
 
356
            body += 'committers: %d\n' % stats['committers']
 
357
        if stats.has_key('firstrev'):
 
358
            body += 'firstrev: %.3f %d\n' % stats['firstrev']
 
359
        if stats.has_key('latestrev'):
 
360
             body += 'latestrev: %.3f %d\n' % stats['latestrev']
 
361
        if stats.has_key('revisions'):
 
362
            body += 'revisions: %d\n' % stats['revisions']
 
363
        if stats.has_key('size'):
 
364
            body += 'size: %d\n' % stats['size']
 
365
 
 
366
        return SuccessfulSmartServerResponse(('ok', ), body)
 
367
 
 
368
 
 
369
class SmartServerRepositoryIsShared(SmartServerRepositoryRequest):
 
370
 
 
371
    def do_repository_request(self, repository):
 
372
        """Return the result of repository.is_shared().
 
373
 
 
374
        :param repository: The repository to query in.
 
375
        :return: A smart server response of ('yes', ) if the repository is
 
376
            shared, and ('no', ) if it is not.
 
377
        """
 
378
        if repository.is_shared():
 
379
            return SuccessfulSmartServerResponse(('yes', ))
 
380
        else:
 
381
            return SuccessfulSmartServerResponse(('no', ))
 
382
 
 
383
 
 
384
class SmartServerRepositoryLockWrite(SmartServerRepositoryRequest):
 
385
 
 
386
    def do_repository_request(self, repository, token=''):
 
387
        # XXX: this probably should not have a token.
 
388
        if token == '':
 
389
            token = None
 
390
        try:
 
391
            token = repository.lock_write(token=token)
 
392
        except errors.LockContention, e:
 
393
            return FailedSmartServerResponse(('LockContention',))
 
394
        except errors.UnlockableTransport:
 
395
            return FailedSmartServerResponse(('UnlockableTransport',))
 
396
        except errors.LockFailed, e:
 
397
            return FailedSmartServerResponse(('LockFailed',
 
398
                str(e.lock), str(e.why)))
 
399
        if token is not None:
 
400
            repository.leave_lock_in_place()
 
401
        repository.unlock()
 
402
        if token is None:
 
403
            token = ''
 
404
        return SuccessfulSmartServerResponse(('ok', token))
 
405
 
 
406
 
 
407
class SmartServerRepositoryGetStream(SmartServerRepositoryRequest):
 
408
 
 
409
    def do_repository_request(self, repository, to_network_name):
 
410
        """Get a stream for inserting into a to_format repository.
 
411
 
 
412
        :param repository: The repository to stream from.
 
413
        :param to_network_name: The network name of the format of the target
 
414
            repository.
 
415
        """
 
416
        self._to_format = network_format_registry.get(to_network_name)
 
417
        return None # Signal that we want a body.
 
418
 
 
419
    def do_body(self, body_bytes):
 
420
        repository = self._repository
 
421
        repository.lock_read()
 
422
        try:
 
423
            search_result, error = self.recreate_search(repository, body_bytes,
 
424
                discard_excess=True)
 
425
            if error is not None:
 
426
                repository.unlock()
 
427
                return error
 
428
            source = repository._get_source(self._to_format)
 
429
            stream = source.get_stream(search_result)
 
430
        except Exception:
 
431
            exc_info = sys.exc_info()
 
432
            try:
 
433
                # On non-error, unlocking is done by the body stream handler.
 
434
                repository.unlock()
 
435
            finally:
 
436
                raise exc_info[0], exc_info[1], exc_info[2]
 
437
        return SuccessfulSmartServerResponse(('ok',),
 
438
            body_stream=self.body_stream(stream, repository))
 
439
 
 
440
    def body_stream(self, stream, repository):
 
441
        byte_stream = _stream_to_byte_stream(stream, repository._format)
 
442
        try:
 
443
            for bytes in byte_stream:
 
444
                yield bytes
 
445
        except errors.RevisionNotPresent, e:
 
446
            # This shouldn't be able to happen, but as we don't buffer
 
447
            # everything it can in theory happen.
 
448
            repository.unlock()
 
449
            yield FailedSmartServerResponse(('NoSuchRevision', e.revision_id))
 
450
        else:
 
451
            repository.unlock()
 
452
 
 
453
 
 
454
def _stream_to_byte_stream(stream, src_format):
 
455
    """Convert a record stream to a self delimited byte stream."""
 
456
    pack_writer = pack.ContainerSerialiser()
 
457
    yield pack_writer.begin()
 
458
    yield pack_writer.bytes_record(src_format.network_name(), '')
 
459
    for substream_type, substream in stream:
 
460
        for record in substream:
 
461
            if record.storage_kind in ('chunked', 'fulltext'):
 
462
                serialised = record_to_fulltext_bytes(record)
 
463
            elif record.storage_kind == 'absent':
 
464
                raise ValueError("Absent factory for %s" % (record.key,))
 
465
            else:
 
466
                serialised = record.get_bytes_as(record.storage_kind)
 
467
            if serialised:
 
468
                # Some streams embed the whole stream into the wire
 
469
                # representation of the first record, which means that
 
470
                # later records have no wire representation: we skip them.
 
471
                yield pack_writer.bytes_record(serialised, [(substream_type,)])
 
472
    yield pack_writer.end()
 
473
 
 
474
 
 
475
def _byte_stream_to_stream(byte_stream):
 
476
    """Convert a byte stream into a format and a stream.
 
477
 
 
478
    :param byte_stream: A bytes iterator, as output by _stream_to_byte_stream.
 
479
    :return: (RepositoryFormat, stream_generator)
 
480
    """
 
481
    stream_decoder = pack.ContainerPushParser()
 
482
    def record_stream():
 
483
        """Closure to return the substreams."""
 
484
        # May have fully parsed records already.
 
485
        for record in stream_decoder.read_pending_records():
 
486
            record_names, record_bytes = record
 
487
            record_name, = record_names
 
488
            substream_type = record_name[0]
 
489
            substream = NetworkRecordStream([record_bytes])
 
490
            yield substream_type, substream.read()
 
491
        for bytes in byte_stream:
 
492
            stream_decoder.accept_bytes(bytes)
 
493
            for record in stream_decoder.read_pending_records():
 
494
                record_names, record_bytes = record
 
495
                record_name, = record_names
 
496
                substream_type = record_name[0]
 
497
                substream = NetworkRecordStream([record_bytes])
 
498
                yield substream_type, substream.read()
 
499
    for bytes in byte_stream:
 
500
        stream_decoder.accept_bytes(bytes)
 
501
        for record in stream_decoder.read_pending_records(max=1):
 
502
            record_names, src_format_name = record
 
503
            src_format = network_format_registry.get(src_format_name)
 
504
            return src_format, record_stream()
 
505
 
 
506
 
 
507
class SmartServerRepositoryUnlock(SmartServerRepositoryRequest):
 
508
 
 
509
    def do_repository_request(self, repository, token):
 
510
        try:
 
511
            repository.lock_write(token=token)
 
512
        except errors.TokenMismatch, e:
 
513
            return FailedSmartServerResponse(('TokenMismatch',))
 
514
        repository.dont_leave_lock_in_place()
 
515
        repository.unlock()
 
516
        return SuccessfulSmartServerResponse(('ok',))
 
517
 
 
518
 
 
519
class SmartServerRepositorySetMakeWorkingTrees(SmartServerRepositoryRequest):
 
520
 
 
521
    def do_repository_request(self, repository, str_bool_new_value):
 
522
        if str_bool_new_value == 'True':
 
523
            new_value = True
 
524
        else:
 
525
            new_value = False
 
526
        repository.set_make_working_trees(new_value)
 
527
        return SuccessfulSmartServerResponse(('ok',))
 
528
 
 
529
 
 
530
class SmartServerRepositoryTarball(SmartServerRepositoryRequest):
 
531
    """Get the raw repository files as a tarball.
 
532
 
 
533
    The returned tarball contains a .bzr control directory which in turn
 
534
    contains a repository.
 
535
 
 
536
    This takes one parameter, compression, which currently must be
 
537
    "", "gz", or "bz2".
 
538
 
 
539
    This is used to implement the Repository.copy_content_into operation.
 
540
    """
 
541
 
 
542
    def do_repository_request(self, repository, compression):
 
543
        tmp_dirname, tmp_repo = self._copy_to_tempdir(repository)
 
544
        try:
 
545
            controldir_name = tmp_dirname + '/.bzr'
 
546
            return self._tarfile_response(controldir_name, compression)
 
547
        finally:
 
548
            osutils.rmtree(tmp_dirname)
 
549
 
 
550
    def _copy_to_tempdir(self, from_repo):
 
551
        tmp_dirname = osutils.mkdtemp(prefix='tmpbzrclone')
 
552
        tmp_bzrdir = from_repo.bzrdir._format.initialize(tmp_dirname)
 
553
        tmp_repo = from_repo._format.initialize(tmp_bzrdir)
 
554
        from_repo.copy_content_into(tmp_repo)
 
555
        return tmp_dirname, tmp_repo
 
556
 
 
557
    def _tarfile_response(self, tmp_dirname, compression):
 
558
        temp = tempfile.NamedTemporaryFile()
 
559
        try:
 
560
            self._tarball_of_dir(tmp_dirname, compression, temp.file)
 
561
            # all finished; write the tempfile out to the network
 
562
            temp.seek(0)
 
563
            return SuccessfulSmartServerResponse(('ok',), temp.read())
 
564
            # FIXME: Don't read the whole thing into memory here; rather stream
 
565
            # it out from the file onto the network. mbp 20070411
 
566
        finally:
 
567
            temp.close()
 
568
 
 
569
    def _tarball_of_dir(self, dirname, compression, ofile):
 
570
        filename = os.path.basename(ofile.name)
 
571
        tarball = tarfile.open(fileobj=ofile, name=filename,
 
572
            mode='w|' + compression)
 
573
        try:
 
574
            # The tarball module only accepts ascii names, and (i guess)
 
575
            # packs them with their 8bit names.  We know all the files
 
576
            # within the repository have ASCII names so the should be safe
 
577
            # to pack in.
 
578
            dirname = dirname.encode(sys.getfilesystemencoding())
 
579
            # python's tarball module includes the whole path by default so
 
580
            # override it
 
581
            if not dirname.endswith('.bzr'):
 
582
                raise ValueError(dirname)
 
583
            tarball.add(dirname, '.bzr') # recursive by default
 
584
        finally:
 
585
            tarball.close()
 
586
 
 
587
 
 
588
class SmartServerRepositoryInsertStreamLocked(SmartServerRepositoryRequest):
 
589
    """Insert a record stream from a RemoteSink into a repository.
 
590
 
 
591
    This gets bytes pushed to it by the network infrastructure and turns that
 
592
    into a bytes iterator using a thread. That is then processed by
 
593
    _byte_stream_to_stream.
 
594
 
 
595
    New in 1.14.
 
596
    """
 
597
 
 
598
    def do_repository_request(self, repository, resume_tokens, lock_token):
 
599
        """StreamSink.insert_stream for a remote repository."""
 
600
        repository.lock_write(token=lock_token)
 
601
        self.do_insert_stream_request(repository, resume_tokens)
 
602
 
 
603
    def do_insert_stream_request(self, repository, resume_tokens):
 
604
        tokens = [token for token in resume_tokens.split(' ') if token]
 
605
        self.tokens = tokens
 
606
        self.repository = repository
 
607
        self.queue = Queue.Queue()
 
608
        self.insert_thread = threading.Thread(target=self._inserter_thread)
 
609
        self.insert_thread.start()
 
610
 
 
611
    def do_chunk(self, body_stream_chunk):
 
612
        self.queue.put(body_stream_chunk)
 
613
 
 
614
    def _inserter_thread(self):
 
615
        try:
 
616
            src_format, stream = _byte_stream_to_stream(
 
617
                self.blocking_byte_stream())
 
618
            self.insert_result = self.repository._get_sink().insert_stream(
 
619
                stream, src_format, self.tokens)
 
620
            self.insert_ok = True
 
621
        except:
 
622
            self.insert_exception = sys.exc_info()
 
623
            self.insert_ok = False
 
624
 
 
625
    def blocking_byte_stream(self):
 
626
        while True:
 
627
            bytes = self.queue.get()
 
628
            if bytes is StopIteration:
 
629
                return
 
630
            else:
 
631
                yield bytes
 
632
 
 
633
    def do_end(self):
 
634
        self.queue.put(StopIteration)
 
635
        if self.insert_thread is not None:
 
636
            self.insert_thread.join()
 
637
        if not self.insert_ok:
 
638
            exc_info = self.insert_exception
 
639
            raise exc_info[0], exc_info[1], exc_info[2]
 
640
        write_group_tokens, missing_keys = self.insert_result
 
641
        if write_group_tokens or missing_keys:
 
642
            # bzip needed? missing keys should typically be a small set.
 
643
            # Should this be a streaming body response ?
 
644
            missing_keys = sorted(missing_keys)
 
645
            bytes = bencode.bencode((write_group_tokens, missing_keys))
 
646
            self.repository.unlock()
 
647
            return SuccessfulSmartServerResponse(('missing-basis', bytes))
 
648
        else:
 
649
            self.repository.unlock()
 
650
            return SuccessfulSmartServerResponse(('ok', ))
 
651
 
 
652
 
 
653
class SmartServerRepositoryInsertStream(SmartServerRepositoryInsertStreamLocked):
 
654
    """Insert a record stream from a RemoteSink into an unlocked repository.
 
655
 
 
656
    This is the same as SmartServerRepositoryInsertStreamLocked, except it
 
657
    takes no lock_tokens; i.e. it works with an unlocked (or lock-free, e.g.
 
658
    like pack format) repository.
 
659
 
 
660
    New in 1.13.
 
661
    """
 
662
 
 
663
    def do_repository_request(self, repository, resume_tokens):
 
664
        """StreamSink.insert_stream for a remote repository."""
 
665
        repository.lock_write()
 
666
        self.do_insert_stream_request(repository, resume_tokens)
 
667
 
 
668