~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/smart/repository.py

  • Committer: Canonical.com Patch Queue Manager
  • Date: 2010-05-21 04:57:17 UTC
  • mfrom: (5244.1.1 trunk)
  • Revision ID: pqm@pqm.ubuntu.com-20100521045717-qe9khoe3xia0fqwm
(lifeless) Merge from 2.1,
 fix for closing fd's when a specific file is supplied to status/commit etc.
 (Robert Collins)

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# Copyright (C) 2006-2010 Canonical Ltd
 
2
#
 
3
# This program is free software; you can redistribute it and/or modify
 
4
# it under the terms of the GNU General Public License as published by
 
5
# the Free Software Foundation; either version 2 of the License, or
 
6
# (at your option) any later version.
 
7
#
 
8
# This program is distributed in the hope that it will be useful,
 
9
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
10
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
11
# GNU General Public License for more details.
 
12
#
 
13
# You should have received a copy of the GNU General Public License
 
14
# along with this program; if not, write to the Free Software
 
15
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 
16
 
 
17
"""Server-side repository related request implmentations."""
 
18
 
 
19
import bz2
 
20
import os
 
21
import Queue
 
22
import sys
 
23
import tempfile
 
24
import threading
 
25
 
 
26
from bzrlib import (
 
27
    bencode,
 
28
    errors,
 
29
    graph,
 
30
    osutils,
 
31
    pack,
 
32
    ui,
 
33
    versionedfile,
 
34
    )
 
35
from bzrlib.bzrdir import BzrDir
 
36
from bzrlib.smart.request import (
 
37
    FailedSmartServerResponse,
 
38
    SmartServerRequest,
 
39
    SuccessfulSmartServerResponse,
 
40
    )
 
41
from bzrlib.repository import _strip_NULL_ghosts, network_format_registry
 
42
from bzrlib.recordcounter import RecordCounter
 
43
from bzrlib import revision as _mod_revision
 
44
from bzrlib.versionedfile import (
 
45
    NetworkRecordStream,
 
46
    record_to_fulltext_bytes,
 
47
    )
 
48
 
 
49
 
 
50
class SmartServerRepositoryRequest(SmartServerRequest):
 
51
    """Common base class for Repository requests."""
 
52
 
 
53
    def do(self, path, *args):
 
54
        """Execute a repository request.
 
55
 
 
56
        All Repository requests take a path to the repository as their first
 
57
        argument.  The repository must be at the exact path given by the
 
58
        client - no searching is done.
 
59
 
 
60
        The actual logic is delegated to self.do_repository_request.
 
61
 
 
62
        :param client_path: The path for the repository as received from the
 
63
            client.
 
64
        :return: A SmartServerResponse from self.do_repository_request().
 
65
        """
 
66
        transport = self.transport_from_client_path(path)
 
67
        bzrdir = BzrDir.open_from_transport(transport)
 
68
        # Save the repository for use with do_body.
 
69
        self._repository = bzrdir.open_repository()
 
70
        return self.do_repository_request(self._repository, *args)
 
71
 
 
72
    def do_repository_request(self, repository, *args):
 
73
        """Override to provide an implementation for a verb."""
 
74
        # No-op for verbs that take bodies (None as a result indicates a body
 
75
        # is expected)
 
76
        return None
 
77
 
 
78
    def recreate_search(self, repository, search_bytes, discard_excess=False):
 
79
        """Recreate a search from its serialised form.
 
80
 
 
81
        :param discard_excess: If True, and the search refers to data we don't
 
82
            have, just silently accept that fact - the verb calling
 
83
            recreate_search trusts that clients will look for missing things
 
84
            they expected and get it from elsewhere.
 
85
        """
 
86
        lines = search_bytes.split('\n')
 
87
        if lines[0] == 'ancestry-of':
 
88
            heads = lines[1:]
 
89
            search_result = graph.PendingAncestryResult(heads, repository)
 
90
            return search_result, None
 
91
        elif lines[0] == 'search':
 
92
            return self.recreate_search_from_recipe(repository, lines[1:],
 
93
                discard_excess=discard_excess)
 
94
        else:
 
95
            return (None, FailedSmartServerResponse(('BadSearch',)))
 
96
 
 
97
    def recreate_search_from_recipe(self, repository, lines,
 
98
        discard_excess=False):
 
99
        """Recreate a specific revision search (vs a from-tip search).
 
100
 
 
101
        :param discard_excess: If True, and the search refers to data we don't
 
102
            have, just silently accept that fact - the verb calling
 
103
            recreate_search trusts that clients will look for missing things
 
104
            they expected and get it from elsewhere.
 
105
        """
 
106
        start_keys = set(lines[0].split(' '))
 
107
        exclude_keys = set(lines[1].split(' '))
 
108
        revision_count = int(lines[2])
 
109
        repository.lock_read()
 
110
        try:
 
111
            search = repository.get_graph()._make_breadth_first_searcher(
 
112
                start_keys)
 
113
            while True:
 
114
                try:
 
115
                    next_revs = search.next()
 
116
                except StopIteration:
 
117
                    break
 
118
                search.stop_searching_any(exclude_keys.intersection(next_revs))
 
119
            search_result = search.get_result()
 
120
            if (not discard_excess and
 
121
                search_result.get_recipe()[3] != revision_count):
 
122
                # we got back a different amount of data than expected, this
 
123
                # gets reported as NoSuchRevision, because less revisions
 
124
                # indicates missing revisions, and more should never happen as
 
125
                # the excludes list considers ghosts and ensures that ghost
 
126
                # filling races are not a problem.
 
127
                return (None, FailedSmartServerResponse(('NoSuchRevision',)))
 
128
            return (search_result, None)
 
129
        finally:
 
130
            repository.unlock()
 
131
 
 
132
 
 
133
class SmartServerRepositoryReadLocked(SmartServerRepositoryRequest):
 
134
    """Calls self.do_readlocked_repository_request."""
 
135
 
 
136
    def do_repository_request(self, repository, *args):
 
137
        """Read lock a repository for do_readlocked_repository_request."""
 
138
        repository.lock_read()
 
139
        try:
 
140
            return self.do_readlocked_repository_request(repository, *args)
 
141
        finally:
 
142
            repository.unlock()
 
143
 
 
144
 
 
145
class SmartServerRepositoryGetParentMap(SmartServerRepositoryRequest):
 
146
    """Bzr 1.2+ - get parent data for revisions during a graph search."""
 
147
 
 
148
    no_extra_results = False
 
149
 
 
150
    def do_repository_request(self, repository, *revision_ids):
 
151
        """Get parent details for some revisions.
 
152
 
 
153
        All the parents for revision_ids are returned. Additionally up to 64KB
 
154
        of additional parent data found by performing a breadth first search
 
155
        from revision_ids is returned. The verb takes a body containing the
 
156
        current search state, see do_body for details.
 
157
 
 
158
        If 'include-missing:' is in revision_ids, ghosts encountered in the
 
159
        graph traversal for getting parent data are included in the result with
 
160
        a prefix of 'missing:'.
 
161
 
 
162
        :param repository: The repository to query in.
 
163
        :param revision_ids: The utf8 encoded revision_id to answer for.
 
164
        """
 
165
        self._revision_ids = revision_ids
 
166
        return None # Signal that we want a body.
 
167
 
 
168
    def do_body(self, body_bytes):
 
169
        """Process the current search state and perform the parent lookup.
 
170
 
 
171
        :return: A smart server response where the body contains an utf8
 
172
            encoded flattened list of the parents of the revisions (the same
 
173
            format as Repository.get_revision_graph) which has been bz2
 
174
            compressed.
 
175
        """
 
176
        repository = self._repository
 
177
        repository.lock_read()
 
178
        try:
 
179
            return self._do_repository_request(body_bytes)
 
180
        finally:
 
181
            repository.unlock()
 
182
 
 
183
    def _do_repository_request(self, body_bytes):
 
184
        repository = self._repository
 
185
        revision_ids = set(self._revision_ids)
 
186
        include_missing = 'include-missing:' in revision_ids
 
187
        if include_missing:
 
188
            revision_ids.remove('include-missing:')
 
189
        body_lines = body_bytes.split('\n')
 
190
        search_result, error = self.recreate_search_from_recipe(
 
191
            repository, body_lines)
 
192
        if error is not None:
 
193
            return error
 
194
        # TODO might be nice to start up the search again; but thats not
 
195
        # written or tested yet.
 
196
        client_seen_revs = set(search_result.get_keys())
 
197
        # Always include the requested ids.
 
198
        client_seen_revs.difference_update(revision_ids)
 
199
        lines = []
 
200
        repo_graph = repository.get_graph()
 
201
        result = {}
 
202
        queried_revs = set()
 
203
        size_so_far = 0
 
204
        next_revs = revision_ids
 
205
        first_loop_done = False
 
206
        while next_revs:
 
207
            queried_revs.update(next_revs)
 
208
            parent_map = repo_graph.get_parent_map(next_revs)
 
209
            current_revs = next_revs
 
210
            next_revs = set()
 
211
            for revision_id in current_revs:
 
212
                missing_rev = False
 
213
                parents = parent_map.get(revision_id)
 
214
                if parents is not None:
 
215
                    # adjust for the wire
 
216
                    if parents == (_mod_revision.NULL_REVISION,):
 
217
                        parents = ()
 
218
                    # prepare the next query
 
219
                    next_revs.update(parents)
 
220
                    encoded_id = revision_id
 
221
                else:
 
222
                    missing_rev = True
 
223
                    encoded_id = "missing:" + revision_id
 
224
                    parents = []
 
225
                if (revision_id not in client_seen_revs and
 
226
                    (not missing_rev or include_missing)):
 
227
                    # Client does not have this revision, give it to it.
 
228
                    # add parents to the result
 
229
                    result[encoded_id] = parents
 
230
                    # Approximate the serialized cost of this revision_id.
 
231
                    size_so_far += 2 + len(encoded_id) + sum(map(len, parents))
 
232
            # get all the directly asked for parents, and then flesh out to
 
233
            # 64K (compressed) or so. We do one level of depth at a time to
 
234
            # stay in sync with the client. The 250000 magic number is
 
235
            # estimated compression ratio taken from bzr.dev itself.
 
236
            if self.no_extra_results or (
 
237
                first_loop_done and size_so_far > 250000):
 
238
                next_revs = set()
 
239
                break
 
240
            # don't query things we've already queried
 
241
            next_revs.difference_update(queried_revs)
 
242
            first_loop_done = True
 
243
 
 
244
        # sorting trivially puts lexographically similar revision ids together.
 
245
        # Compression FTW.
 
246
        for revision, parents in sorted(result.items()):
 
247
            lines.append(' '.join((revision, ) + tuple(parents)))
 
248
 
 
249
        return SuccessfulSmartServerResponse(
 
250
            ('ok', ), bz2.compress('\n'.join(lines)))
 
251
 
 
252
 
 
253
class SmartServerRepositoryGetRevisionGraph(SmartServerRepositoryReadLocked):
 
254
 
 
255
    def do_readlocked_repository_request(self, repository, revision_id):
 
256
        """Return the result of repository.get_revision_graph(revision_id).
 
257
 
 
258
        Deprecated as of bzr 1.4, but supported for older clients.
 
259
 
 
260
        :param repository: The repository to query in.
 
261
        :param revision_id: The utf8 encoded revision_id to get a graph from.
 
262
        :return: A smart server response where the body contains an utf8
 
263
            encoded flattened list of the revision graph.
 
264
        """
 
265
        if not revision_id:
 
266
            revision_id = None
 
267
 
 
268
        lines = []
 
269
        graph = repository.get_graph()
 
270
        if revision_id:
 
271
            search_ids = [revision_id]
 
272
        else:
 
273
            search_ids = repository.all_revision_ids()
 
274
        search = graph._make_breadth_first_searcher(search_ids)
 
275
        transitive_ids = set()
 
276
        map(transitive_ids.update, list(search))
 
277
        parent_map = graph.get_parent_map(transitive_ids)
 
278
        revision_graph = _strip_NULL_ghosts(parent_map)
 
279
        if revision_id and revision_id not in revision_graph:
 
280
            # Note that we return an empty body, rather than omitting the body.
 
281
            # This way the client knows that it can always expect to find a body
 
282
            # in the response for this method, even in the error case.
 
283
            return FailedSmartServerResponse(('nosuchrevision', revision_id), '')
 
284
 
 
285
        for revision, parents in revision_graph.items():
 
286
            lines.append(' '.join((revision, ) + tuple(parents)))
 
287
 
 
288
        return SuccessfulSmartServerResponse(('ok', ), '\n'.join(lines))
 
289
 
 
290
 
 
291
class SmartServerRepositoryGetRevIdForRevno(SmartServerRepositoryReadLocked):
 
292
 
 
293
    def do_readlocked_repository_request(self, repository, revno,
 
294
            known_pair):
 
295
        """Find the revid for a given revno, given a known revno/revid pair.
 
296
        
 
297
        New in 1.17.
 
298
        """
 
299
        try:
 
300
            found_flag, result = repository.get_rev_id_for_revno(revno, known_pair)
 
301
        except errors.RevisionNotPresent, err:
 
302
            if err.revision_id != known_pair[1]:
 
303
                raise AssertionError(
 
304
                    'get_rev_id_for_revno raised RevisionNotPresent for '
 
305
                    'non-initial revision: ' + err.revision_id)
 
306
            return FailedSmartServerResponse(
 
307
                ('nosuchrevision', err.revision_id))
 
308
        if found_flag:
 
309
            return SuccessfulSmartServerResponse(('ok', result))
 
310
        else:
 
311
            earliest_revno, earliest_revid = result
 
312
            return SuccessfulSmartServerResponse(
 
313
                ('history-incomplete', earliest_revno, earliest_revid))
 
314
 
 
315
 
 
316
class SmartServerRequestHasRevision(SmartServerRepositoryRequest):
 
317
 
 
318
    def do_repository_request(self, repository, revision_id):
 
319
        """Return ok if a specific revision is in the repository at path.
 
320
 
 
321
        :param repository: The repository to query in.
 
322
        :param revision_id: The utf8 encoded revision_id to lookup.
 
323
        :return: A smart server response of ('ok', ) if the revision is
 
324
            present.
 
325
        """
 
326
        if repository.has_revision(revision_id):
 
327
            return SuccessfulSmartServerResponse(('yes', ))
 
328
        else:
 
329
            return SuccessfulSmartServerResponse(('no', ))
 
330
 
 
331
 
 
332
class SmartServerRepositoryGatherStats(SmartServerRepositoryRequest):
 
333
 
 
334
    def do_repository_request(self, repository, revid, committers):
 
335
        """Return the result of repository.gather_stats().
 
336
 
 
337
        :param repository: The repository to query in.
 
338
        :param revid: utf8 encoded rev id or an empty string to indicate None
 
339
        :param committers: 'yes' or 'no'.
 
340
 
 
341
        :return: A SmartServerResponse ('ok',), a encoded body looking like
 
342
              committers: 1
 
343
              firstrev: 1234.230 0
 
344
              latestrev: 345.700 3600
 
345
              revisions: 2
 
346
 
 
347
              But containing only fields returned by the gather_stats() call
 
348
        """
 
349
        if revid == '':
 
350
            decoded_revision_id = None
 
351
        else:
 
352
            decoded_revision_id = revid
 
353
        if committers == 'yes':
 
354
            decoded_committers = True
 
355
        else:
 
356
            decoded_committers = None
 
357
        stats = repository.gather_stats(decoded_revision_id, decoded_committers)
 
358
 
 
359
        body = ''
 
360
        if stats.has_key('committers'):
 
361
            body += 'committers: %d\n' % stats['committers']
 
362
        if stats.has_key('firstrev'):
 
363
            body += 'firstrev: %.3f %d\n' % stats['firstrev']
 
364
        if stats.has_key('latestrev'):
 
365
             body += 'latestrev: %.3f %d\n' % stats['latestrev']
 
366
        if stats.has_key('revisions'):
 
367
            body += 'revisions: %d\n' % stats['revisions']
 
368
        if stats.has_key('size'):
 
369
            body += 'size: %d\n' % stats['size']
 
370
 
 
371
        return SuccessfulSmartServerResponse(('ok', ), body)
 
372
 
 
373
 
 
374
class SmartServerRepositoryIsShared(SmartServerRepositoryRequest):
 
375
 
 
376
    def do_repository_request(self, repository):
 
377
        """Return the result of repository.is_shared().
 
378
 
 
379
        :param repository: The repository to query in.
 
380
        :return: A smart server response of ('yes', ) if the repository is
 
381
            shared, and ('no', ) if it is not.
 
382
        """
 
383
        if repository.is_shared():
 
384
            return SuccessfulSmartServerResponse(('yes', ))
 
385
        else:
 
386
            return SuccessfulSmartServerResponse(('no', ))
 
387
 
 
388
 
 
389
class SmartServerRepositoryLockWrite(SmartServerRepositoryRequest):
 
390
 
 
391
    def do_repository_request(self, repository, token=''):
 
392
        # XXX: this probably should not have a token.
 
393
        if token == '':
 
394
            token = None
 
395
        try:
 
396
            token = repository.lock_write(token=token).repository_token
 
397
        except errors.LockContention, e:
 
398
            return FailedSmartServerResponse(('LockContention',))
 
399
        except errors.UnlockableTransport:
 
400
            return FailedSmartServerResponse(('UnlockableTransport',))
 
401
        except errors.LockFailed, e:
 
402
            return FailedSmartServerResponse(('LockFailed',
 
403
                str(e.lock), str(e.why)))
 
404
        if token is not None:
 
405
            repository.leave_lock_in_place()
 
406
        repository.unlock()
 
407
        if token is None:
 
408
            token = ''
 
409
        return SuccessfulSmartServerResponse(('ok', token))
 
410
 
 
411
 
 
412
class SmartServerRepositoryGetStream(SmartServerRepositoryRequest):
 
413
 
 
414
    def do_repository_request(self, repository, to_network_name):
 
415
        """Get a stream for inserting into a to_format repository.
 
416
 
 
417
        :param repository: The repository to stream from.
 
418
        :param to_network_name: The network name of the format of the target
 
419
            repository.
 
420
        """
 
421
        self._to_format = network_format_registry.get(to_network_name)
 
422
        if self._should_fake_unknown():
 
423
            return FailedSmartServerResponse(
 
424
                ('UnknownMethod', 'Repository.get_stream'))
 
425
        return None # Signal that we want a body.
 
426
 
 
427
    def _should_fake_unknown(self):
 
428
        """Return True if we should return UnknownMethod to the client.
 
429
        
 
430
        This is a workaround for bugs in pre-1.19 clients that claim to
 
431
        support receiving streams of CHK repositories.  The pre-1.19 client
 
432
        expects inventory records to be serialized in the format defined by
 
433
        to_network_name, but in pre-1.19 (at least) that format definition
 
434
        tries to use the xml5 serializer, which does not correctly handle
 
435
        rich-roots.  After 1.19 the client can also accept inventory-deltas
 
436
        (which avoids this issue), and those clients will use the
 
437
        Repository.get_stream_1.19 verb instead of this one.
 
438
        So: if this repository is CHK, and the to_format doesn't match,
 
439
        we should just fake an UnknownSmartMethod error so that the client
 
440
        will fallback to VFS, rather than sending it a stream we know it
 
441
        cannot handle.
 
442
        """
 
443
        from_format = self._repository._format
 
444
        to_format = self._to_format
 
445
        if not from_format.supports_chks:
 
446
            # Source not CHK: that's ok
 
447
            return False
 
448
        if (to_format.supports_chks and
 
449
            from_format.repository_class is to_format.repository_class and
 
450
            from_format._serializer == to_format._serializer):
 
451
            # Source is CHK, but target matches: that's ok
 
452
            # (e.g. 2a->2a, or CHK2->2a)
 
453
            return False
 
454
        # Source is CHK, and target is not CHK or incompatible CHK.  We can't
 
455
        # generate a compatible stream.
 
456
        return True
 
457
 
 
458
    def do_body(self, body_bytes):
 
459
        repository = self._repository
 
460
        repository.lock_read()
 
461
        try:
 
462
            search_result, error = self.recreate_search(repository, body_bytes,
 
463
                discard_excess=True)
 
464
            if error is not None:
 
465
                repository.unlock()
 
466
                return error
 
467
            source = repository._get_source(self._to_format)
 
468
            stream = source.get_stream(search_result)
 
469
        except Exception:
 
470
            exc_info = sys.exc_info()
 
471
            try:
 
472
                # On non-error, unlocking is done by the body stream handler.
 
473
                repository.unlock()
 
474
            finally:
 
475
                raise exc_info[0], exc_info[1], exc_info[2]
 
476
        return SuccessfulSmartServerResponse(('ok',),
 
477
            body_stream=self.body_stream(stream, repository))
 
478
 
 
479
    def body_stream(self, stream, repository):
 
480
        byte_stream = _stream_to_byte_stream(stream, repository._format)
 
481
        try:
 
482
            for bytes in byte_stream:
 
483
                yield bytes
 
484
        except errors.RevisionNotPresent, e:
 
485
            # This shouldn't be able to happen, but as we don't buffer
 
486
            # everything it can in theory happen.
 
487
            repository.unlock()
 
488
            yield FailedSmartServerResponse(('NoSuchRevision', e.revision_id))
 
489
        else:
 
490
            repository.unlock()
 
491
 
 
492
 
 
493
class SmartServerRepositoryGetStream_1_19(SmartServerRepositoryGetStream):
 
494
 
 
495
    def _should_fake_unknown(self):
 
496
        """Returns False; we don't need to workaround bugs in 1.19+ clients."""
 
497
        return False
 
498
 
 
499
 
 
500
def _stream_to_byte_stream(stream, src_format):
 
501
    """Convert a record stream to a self delimited byte stream."""
 
502
    pack_writer = pack.ContainerSerialiser()
 
503
    yield pack_writer.begin()
 
504
    yield pack_writer.bytes_record(src_format.network_name(), '')
 
505
    for substream_type, substream in stream:
 
506
        for record in substream:
 
507
            if record.storage_kind in ('chunked', 'fulltext'):
 
508
                serialised = record_to_fulltext_bytes(record)
 
509
            elif record.storage_kind == 'inventory-delta':
 
510
                serialised = record_to_inventory_delta_bytes(record)
 
511
            elif record.storage_kind == 'absent':
 
512
                raise ValueError("Absent factory for %s" % (record.key,))
 
513
            else:
 
514
                serialised = record.get_bytes_as(record.storage_kind)
 
515
            if serialised:
 
516
                # Some streams embed the whole stream into the wire
 
517
                # representation of the first record, which means that
 
518
                # later records have no wire representation: we skip them.
 
519
                yield pack_writer.bytes_record(serialised, [(substream_type,)])
 
520
    yield pack_writer.end()
 
521
 
 
522
 
 
523
class _ByteStreamDecoder(object):
 
524
    """Helper for _byte_stream_to_stream.
 
525
 
 
526
    The expected usage of this class is via the function _byte_stream_to_stream
 
527
    which creates a _ByteStreamDecoder, pops off the stream format and then
 
528
    yields the output of record_stream(), the main entry point to
 
529
    _ByteStreamDecoder.
 
530
 
 
531
    Broadly this class has to unwrap two layers of iterators:
 
532
    (type, substream)
 
533
    (substream details)
 
534
 
 
535
    This is complicated by wishing to return type, iterator_for_type, but
 
536
    getting the data for iterator_for_type when we find out type: we can't
 
537
    simply pass a generator down to the NetworkRecordStream parser, instead
 
538
    we have a little local state to seed each NetworkRecordStream instance,
 
539
    and gather the type that we'll be yielding.
 
540
 
 
541
    :ivar byte_stream: The byte stream being decoded.
 
542
    :ivar stream_decoder: A pack parser used to decode the bytestream
 
543
    :ivar current_type: The current type, used to join adjacent records of the
 
544
        same type into a single stream.
 
545
    :ivar first_bytes: The first bytes to give the next NetworkRecordStream.
 
546
    """
 
547
 
 
548
    def __init__(self, byte_stream, record_counter):
 
549
        """Create a _ByteStreamDecoder."""
 
550
        self.stream_decoder = pack.ContainerPushParser()
 
551
        self.current_type = None
 
552
        self.first_bytes = None
 
553
        self.byte_stream = byte_stream
 
554
        self._record_counter = record_counter
 
555
        self.key_count = 0
 
556
 
 
557
    def iter_stream_decoder(self):
 
558
        """Iterate the contents of the pack from stream_decoder."""
 
559
        # dequeue pending items
 
560
        for record in self.stream_decoder.read_pending_records():
 
561
            yield record
 
562
        # Pull bytes of the wire, decode them to records, yield those records.
 
563
        for bytes in self.byte_stream:
 
564
            self.stream_decoder.accept_bytes(bytes)
 
565
            for record in self.stream_decoder.read_pending_records():
 
566
                yield record
 
567
 
 
568
    def iter_substream_bytes(self):
 
569
        if self.first_bytes is not None:
 
570
            yield self.first_bytes
 
571
            # If we run out of pack records, single the outer layer to stop.
 
572
            self.first_bytes = None
 
573
        for record in self.iter_pack_records:
 
574
            record_names, record_bytes = record
 
575
            record_name, = record_names
 
576
            substream_type = record_name[0]
 
577
            if substream_type != self.current_type:
 
578
                # end of a substream, seed the next substream.
 
579
                self.current_type = substream_type
 
580
                self.first_bytes = record_bytes
 
581
                return
 
582
            yield record_bytes
 
583
 
 
584
    def record_stream(self):
 
585
        """Yield substream_type, substream from the byte stream."""
 
586
        def wrap_and_count(pb, rc, substream):
 
587
            """Yield records from stream while showing progress."""
 
588
            counter = 0
 
589
            if rc:
 
590
                if self.current_type != 'revisions' and self.key_count != 0:
 
591
                    # As we know the number of revisions now (in self.key_count)
 
592
                    # we can setup and use record_counter (rc).
 
593
                    if not rc.is_initialized():
 
594
                        rc.setup(self.key_count, self.key_count)
 
595
            for record in substream.read():
 
596
                if rc:
 
597
                    if rc.is_initialized() and counter == rc.STEP:
 
598
                        rc.increment(counter)
 
599
                        pb.update('Estimate', rc.current, rc.max)
 
600
                        counter = 0
 
601
                    if self.current_type == 'revisions':
 
602
                        # Total records is proportional to number of revs
 
603
                        # to fetch. With remote, we used self.key_count to
 
604
                        # track the number of revs. Once we have the revs
 
605
                        # counts in self.key_count, the progress bar changes
 
606
                        # from 'Estimating..' to 'Estimate' above.
 
607
                        self.key_count += 1
 
608
                        if counter == rc.STEP:
 
609
                            pb.update('Estimating..', self.key_count)
 
610
                            counter = 0
 
611
                counter += 1
 
612
                yield record
 
613
 
 
614
        self.seed_state()
 
615
        pb = ui.ui_factory.nested_progress_bar()
 
616
        rc = self._record_counter
 
617
        # Make and consume sub generators, one per substream type:
 
618
        while self.first_bytes is not None:
 
619
            substream = NetworkRecordStream(self.iter_substream_bytes())
 
620
            # after substream is fully consumed, self.current_type is set to
 
621
            # the next type, and self.first_bytes is set to the matching bytes.
 
622
            yield self.current_type, wrap_and_count(pb, rc, substream)
 
623
        if rc:
 
624
            pb.update('Done', rc.max, rc.max)
 
625
        pb.finished()
 
626
 
 
627
    def seed_state(self):
 
628
        """Prepare the _ByteStreamDecoder to decode from the pack stream."""
 
629
        # Set a single generator we can use to get data from the pack stream.
 
630
        self.iter_pack_records = self.iter_stream_decoder()
 
631
        # Seed the very first subiterator with content; after this each one
 
632
        # seeds the next.
 
633
        list(self.iter_substream_bytes())
 
634
 
 
635
 
 
636
def _byte_stream_to_stream(byte_stream, record_counter=None):
 
637
    """Convert a byte stream into a format and a stream.
 
638
 
 
639
    :param byte_stream: A bytes iterator, as output by _stream_to_byte_stream.
 
640
    :return: (RepositoryFormat, stream_generator)
 
641
    """
 
642
    decoder = _ByteStreamDecoder(byte_stream, record_counter)
 
643
    for bytes in byte_stream:
 
644
        decoder.stream_decoder.accept_bytes(bytes)
 
645
        for record in decoder.stream_decoder.read_pending_records(max=1):
 
646
            record_names, src_format_name = record
 
647
            src_format = network_format_registry.get(src_format_name)
 
648
            return src_format, decoder.record_stream()
 
649
 
 
650
 
 
651
class SmartServerRepositoryUnlock(SmartServerRepositoryRequest):
 
652
 
 
653
    def do_repository_request(self, repository, token):
 
654
        try:
 
655
            repository.lock_write(token=token)
 
656
        except errors.TokenMismatch, e:
 
657
            return FailedSmartServerResponse(('TokenMismatch',))
 
658
        repository.dont_leave_lock_in_place()
 
659
        repository.unlock()
 
660
        return SuccessfulSmartServerResponse(('ok',))
 
661
 
 
662
 
 
663
class SmartServerRepositorySetMakeWorkingTrees(SmartServerRepositoryRequest):
 
664
 
 
665
    def do_repository_request(self, repository, str_bool_new_value):
 
666
        if str_bool_new_value == 'True':
 
667
            new_value = True
 
668
        else:
 
669
            new_value = False
 
670
        repository.set_make_working_trees(new_value)
 
671
        return SuccessfulSmartServerResponse(('ok',))
 
672
 
 
673
 
 
674
class SmartServerRepositoryTarball(SmartServerRepositoryRequest):
 
675
    """Get the raw repository files as a tarball.
 
676
 
 
677
    The returned tarball contains a .bzr control directory which in turn
 
678
    contains a repository.
 
679
 
 
680
    This takes one parameter, compression, which currently must be
 
681
    "", "gz", or "bz2".
 
682
 
 
683
    This is used to implement the Repository.copy_content_into operation.
 
684
    """
 
685
 
 
686
    def do_repository_request(self, repository, compression):
 
687
        tmp_dirname, tmp_repo = self._copy_to_tempdir(repository)
 
688
        try:
 
689
            controldir_name = tmp_dirname + '/.bzr'
 
690
            return self._tarfile_response(controldir_name, compression)
 
691
        finally:
 
692
            osutils.rmtree(tmp_dirname)
 
693
 
 
694
    def _copy_to_tempdir(self, from_repo):
 
695
        tmp_dirname = osutils.mkdtemp(prefix='tmpbzrclone')
 
696
        tmp_bzrdir = from_repo.bzrdir._format.initialize(tmp_dirname)
 
697
        tmp_repo = from_repo._format.initialize(tmp_bzrdir)
 
698
        from_repo.copy_content_into(tmp_repo)
 
699
        return tmp_dirname, tmp_repo
 
700
 
 
701
    def _tarfile_response(self, tmp_dirname, compression):
 
702
        temp = tempfile.NamedTemporaryFile()
 
703
        try:
 
704
            self._tarball_of_dir(tmp_dirname, compression, temp.file)
 
705
            # all finished; write the tempfile out to the network
 
706
            temp.seek(0)
 
707
            return SuccessfulSmartServerResponse(('ok',), temp.read())
 
708
            # FIXME: Don't read the whole thing into memory here; rather stream
 
709
            # it out from the file onto the network. mbp 20070411
 
710
        finally:
 
711
            temp.close()
 
712
 
 
713
    def _tarball_of_dir(self, dirname, compression, ofile):
 
714
        import tarfile
 
715
        filename = os.path.basename(ofile.name)
 
716
        tarball = tarfile.open(fileobj=ofile, name=filename,
 
717
            mode='w|' + compression)
 
718
        try:
 
719
            # The tarball module only accepts ascii names, and (i guess)
 
720
            # packs them with their 8bit names.  We know all the files
 
721
            # within the repository have ASCII names so the should be safe
 
722
            # to pack in.
 
723
            dirname = dirname.encode(sys.getfilesystemencoding())
 
724
            # python's tarball module includes the whole path by default so
 
725
            # override it
 
726
            if not dirname.endswith('.bzr'):
 
727
                raise ValueError(dirname)
 
728
            tarball.add(dirname, '.bzr') # recursive by default
 
729
        finally:
 
730
            tarball.close()
 
731
 
 
732
 
 
733
class SmartServerRepositoryInsertStreamLocked(SmartServerRepositoryRequest):
 
734
    """Insert a record stream from a RemoteSink into a repository.
 
735
 
 
736
    This gets bytes pushed to it by the network infrastructure and turns that
 
737
    into a bytes iterator using a thread. That is then processed by
 
738
    _byte_stream_to_stream.
 
739
 
 
740
    New in 1.14.
 
741
    """
 
742
 
 
743
    def do_repository_request(self, repository, resume_tokens, lock_token):
 
744
        """StreamSink.insert_stream for a remote repository."""
 
745
        repository.lock_write(token=lock_token)
 
746
        self.do_insert_stream_request(repository, resume_tokens)
 
747
 
 
748
    def do_insert_stream_request(self, repository, resume_tokens):
 
749
        tokens = [token for token in resume_tokens.split(' ') if token]
 
750
        self.tokens = tokens
 
751
        self.repository = repository
 
752
        self.queue = Queue.Queue()
 
753
        self.insert_thread = threading.Thread(target=self._inserter_thread)
 
754
        self.insert_thread.start()
 
755
 
 
756
    def do_chunk(self, body_stream_chunk):
 
757
        self.queue.put(body_stream_chunk)
 
758
 
 
759
    def _inserter_thread(self):
 
760
        try:
 
761
            src_format, stream = _byte_stream_to_stream(
 
762
                self.blocking_byte_stream())
 
763
            self.insert_result = self.repository._get_sink().insert_stream(
 
764
                stream, src_format, self.tokens)
 
765
            self.insert_ok = True
 
766
        except:
 
767
            self.insert_exception = sys.exc_info()
 
768
            self.insert_ok = False
 
769
 
 
770
    def blocking_byte_stream(self):
 
771
        while True:
 
772
            bytes = self.queue.get()
 
773
            if bytes is StopIteration:
 
774
                return
 
775
            else:
 
776
                yield bytes
 
777
 
 
778
    def do_end(self):
 
779
        self.queue.put(StopIteration)
 
780
        if self.insert_thread is not None:
 
781
            self.insert_thread.join()
 
782
        if not self.insert_ok:
 
783
            exc_info = self.insert_exception
 
784
            raise exc_info[0], exc_info[1], exc_info[2]
 
785
        write_group_tokens, missing_keys = self.insert_result
 
786
        if write_group_tokens or missing_keys:
 
787
            # bzip needed? missing keys should typically be a small set.
 
788
            # Should this be a streaming body response ?
 
789
            missing_keys = sorted(missing_keys)
 
790
            bytes = bencode.bencode((write_group_tokens, missing_keys))
 
791
            self.repository.unlock()
 
792
            return SuccessfulSmartServerResponse(('missing-basis', bytes))
 
793
        else:
 
794
            self.repository.unlock()
 
795
            return SuccessfulSmartServerResponse(('ok', ))
 
796
 
 
797
 
 
798
class SmartServerRepositoryInsertStream_1_19(SmartServerRepositoryInsertStreamLocked):
 
799
    """Insert a record stream from a RemoteSink into a repository.
 
800
 
 
801
    Same as SmartServerRepositoryInsertStreamLocked, except:
 
802
     - the lock token argument is optional
 
803
     - servers that implement this verb accept 'inventory-delta' records in the
 
804
       stream.
 
805
 
 
806
    New in 1.19.
 
807
    """
 
808
 
 
809
    def do_repository_request(self, repository, resume_tokens, lock_token=None):
 
810
        """StreamSink.insert_stream for a remote repository."""
 
811
        SmartServerRepositoryInsertStreamLocked.do_repository_request(
 
812
            self, repository, resume_tokens, lock_token)
 
813
 
 
814
 
 
815
class SmartServerRepositoryInsertStream(SmartServerRepositoryInsertStreamLocked):
 
816
    """Insert a record stream from a RemoteSink into an unlocked repository.
 
817
 
 
818
    This is the same as SmartServerRepositoryInsertStreamLocked, except it
 
819
    takes no lock_tokens; i.e. it works with an unlocked (or lock-free, e.g.
 
820
    like pack format) repository.
 
821
 
 
822
    New in 1.13.
 
823
    """
 
824
 
 
825
    def do_repository_request(self, repository, resume_tokens):
 
826
        """StreamSink.insert_stream for a remote repository."""
 
827
        repository.lock_write()
 
828
        self.do_insert_stream_request(repository, resume_tokens)
 
829
 
 
830