~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/smart/repository.py

  • Committer: Martin Pool
  • Date: 2005-09-15 08:37:41 UTC
  • Revision ID: mbp@sourcefrog.net-20050915083741-70d7550b97c7b580
- some updates for fetch/update function

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2006-2010 Canonical Ltd
2
 
#
3
 
# This program is free software; you can redistribute it and/or modify
4
 
# it under the terms of the GNU General Public License as published by
5
 
# the Free Software Foundation; either version 2 of the License, or
6
 
# (at your option) any later version.
7
 
#
8
 
# This program is distributed in the hope that it will be useful,
9
 
# but WITHOUT ANY WARRANTY; without even the implied warranty of
10
 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11
 
# GNU General Public License for more details.
12
 
#
13
 
# You should have received a copy of the GNU General Public License
14
 
# along with this program; if not, write to the Free Software
15
 
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
16
 
 
17
 
"""Server-side repository related request implmentations."""
18
 
 
19
 
import bz2
20
 
import os
21
 
import Queue
22
 
import sys
23
 
import tempfile
24
 
import threading
25
 
 
26
 
from bzrlib import (
27
 
    bencode,
28
 
    commands,
29
 
    errors,
30
 
    estimate_compressed_size,
31
 
    graph,
32
 
    osutils,
33
 
    pack,
34
 
    trace,
35
 
    ui,
36
 
    )
37
 
from bzrlib.bzrdir import BzrDir
38
 
from bzrlib.smart.request import (
39
 
    FailedSmartServerResponse,
40
 
    SmartServerRequest,
41
 
    SuccessfulSmartServerResponse,
42
 
    )
43
 
from bzrlib.repository import _strip_NULL_ghosts, network_format_registry
44
 
from bzrlib import revision as _mod_revision
45
 
from bzrlib.versionedfile import (
46
 
    NetworkRecordStream,
47
 
    record_to_fulltext_bytes,
48
 
    )
49
 
 
50
 
 
51
 
class SmartServerRepositoryRequest(SmartServerRequest):
52
 
    """Common base class for Repository requests."""
53
 
 
54
 
    def do(self, path, *args):
55
 
        """Execute a repository request.
56
 
 
57
 
        All Repository requests take a path to the repository as their first
58
 
        argument.  The repository must be at the exact path given by the
59
 
        client - no searching is done.
60
 
 
61
 
        The actual logic is delegated to self.do_repository_request.
62
 
 
63
 
        :param client_path: The path for the repository as received from the
64
 
            client.
65
 
        :return: A SmartServerResponse from self.do_repository_request().
66
 
        """
67
 
        transport = self.transport_from_client_path(path)
68
 
        bzrdir = BzrDir.open_from_transport(transport)
69
 
        # Save the repository for use with do_body.
70
 
        self._repository = bzrdir.open_repository()
71
 
        return self.do_repository_request(self._repository, *args)
72
 
 
73
 
    def do_repository_request(self, repository, *args):
74
 
        """Override to provide an implementation for a verb."""
75
 
        # No-op for verbs that take bodies (None as a result indicates a body
76
 
        # is expected)
77
 
        return None
78
 
 
79
 
    def recreate_search(self, repository, search_bytes, discard_excess=False):
80
 
        """Recreate a search from its serialised form.
81
 
 
82
 
        :param discard_excess: If True, and the search refers to data we don't
83
 
            have, just silently accept that fact - the verb calling
84
 
            recreate_search trusts that clients will look for missing things
85
 
            they expected and get it from elsewhere.
86
 
        """
87
 
        if search_bytes == 'everything':
88
 
            return graph.EverythingResult(repository), None
89
 
        lines = search_bytes.split('\n')
90
 
        if lines[0] == 'ancestry-of':
91
 
            heads = lines[1:]
92
 
            search_result = graph.PendingAncestryResult(heads, repository)
93
 
            return search_result, None
94
 
        elif lines[0] == 'search':
95
 
            return self.recreate_search_from_recipe(repository, lines[1:],
96
 
                discard_excess=discard_excess)
97
 
        else:
98
 
            return (None, FailedSmartServerResponse(('BadSearch',)))
99
 
 
100
 
    def recreate_search_from_recipe(self, repository, lines,
101
 
        discard_excess=False):
102
 
        """Recreate a specific revision search (vs a from-tip search).
103
 
 
104
 
        :param discard_excess: If True, and the search refers to data we don't
105
 
            have, just silently accept that fact - the verb calling
106
 
            recreate_search trusts that clients will look for missing things
107
 
            they expected and get it from elsewhere.
108
 
        """
109
 
        start_keys = set(lines[0].split(' '))
110
 
        exclude_keys = set(lines[1].split(' '))
111
 
        revision_count = int(lines[2])
112
 
        repository.lock_read()
113
 
        try:
114
 
            search = repository.get_graph()._make_breadth_first_searcher(
115
 
                start_keys)
116
 
            while True:
117
 
                try:
118
 
                    next_revs = search.next()
119
 
                except StopIteration:
120
 
                    break
121
 
                search.stop_searching_any(exclude_keys.intersection(next_revs))
122
 
            search_result = search.get_result()
123
 
            if (not discard_excess and
124
 
                search_result.get_recipe()[3] != revision_count):
125
 
                # we got back a different amount of data than expected, this
126
 
                # gets reported as NoSuchRevision, because less revisions
127
 
                # indicates missing revisions, and more should never happen as
128
 
                # the excludes list considers ghosts and ensures that ghost
129
 
                # filling races are not a problem.
130
 
                return (None, FailedSmartServerResponse(('NoSuchRevision',)))
131
 
            return (search_result, None)
132
 
        finally:
133
 
            repository.unlock()
134
 
 
135
 
 
136
 
class SmartServerRepositoryReadLocked(SmartServerRepositoryRequest):
137
 
    """Calls self.do_readlocked_repository_request."""
138
 
 
139
 
    def do_repository_request(self, repository, *args):
140
 
        """Read lock a repository for do_readlocked_repository_request."""
141
 
        repository.lock_read()
142
 
        try:
143
 
            return self.do_readlocked_repository_request(repository, *args)
144
 
        finally:
145
 
            repository.unlock()
146
 
 
147
 
_lsprof_count = 0
148
 
 
149
 
class SmartServerRepositoryGetParentMap(SmartServerRepositoryRequest):
150
 
    """Bzr 1.2+ - get parent data for revisions during a graph search."""
151
 
 
152
 
    no_extra_results = False
153
 
 
154
 
    def do_repository_request(self, repository, *revision_ids):
155
 
        """Get parent details for some revisions.
156
 
 
157
 
        All the parents for revision_ids are returned. Additionally up to 64KB
158
 
        of additional parent data found by performing a breadth first search
159
 
        from revision_ids is returned. The verb takes a body containing the
160
 
        current search state, see do_body for details.
161
 
 
162
 
        If 'include-missing:' is in revision_ids, ghosts encountered in the
163
 
        graph traversal for getting parent data are included in the result with
164
 
        a prefix of 'missing:'.
165
 
 
166
 
        :param repository: The repository to query in.
167
 
        :param revision_ids: The utf8 encoded revision_id to answer for.
168
 
        """
169
 
        self._revision_ids = revision_ids
170
 
        return None # Signal that we want a body.
171
 
 
172
 
    def do_body(self, body_bytes):
173
 
        """Process the current search state and perform the parent lookup.
174
 
 
175
 
        :return: A smart server response where the body contains an utf8
176
 
            encoded flattened list of the parents of the revisions (the same
177
 
            format as Repository.get_revision_graph) which has been bz2
178
 
            compressed.
179
 
        """
180
 
        repository = self._repository
181
 
        repository.lock_read()
182
 
        try:
183
 
            return self._do_repository_request(body_bytes)
184
 
        finally:
185
 
            repository.unlock()
186
 
 
187
 
    def _expand_requested_revs(self, repo_graph, revision_ids, client_seen_revs,
188
 
                               include_missing, max_size=65536):
189
 
        result = {}
190
 
        queried_revs = set()
191
 
        estimator = estimate_compressed_size.ZLibEstimator(max_size)
192
 
        next_revs = revision_ids
193
 
        first_loop_done = False
194
 
        while next_revs:
195
 
            queried_revs.update(next_revs)
196
 
            parent_map = repo_graph.get_parent_map(next_revs)
197
 
            current_revs = next_revs
198
 
            next_revs = set()
199
 
            for revision_id in current_revs:
200
 
                missing_rev = False
201
 
                parents = parent_map.get(revision_id)
202
 
                if parents is not None:
203
 
                    # adjust for the wire
204
 
                    if parents == (_mod_revision.NULL_REVISION,):
205
 
                        parents = ()
206
 
                    # prepare the next query
207
 
                    next_revs.update(parents)
208
 
                    encoded_id = revision_id
209
 
                else:
210
 
                    missing_rev = True
211
 
                    encoded_id = "missing:" + revision_id
212
 
                    parents = []
213
 
                if (revision_id not in client_seen_revs and
214
 
                    (not missing_rev or include_missing)):
215
 
                    # Client does not have this revision, give it to it.
216
 
                    # add parents to the result
217
 
                    result[encoded_id] = parents
218
 
                    # Approximate the serialized cost of this revision_id.
219
 
                    line = '%s %s\n' % (encoded_id, ' '.join(parents))
220
 
                    estimator.add_content(line)
221
 
            # get all the directly asked for parents, and then flesh out to
222
 
            # 64K (compressed) or so. We do one level of depth at a time to
223
 
            # stay in sync with the client. The 250000 magic number is
224
 
            # estimated compression ratio taken from bzr.dev itself.
225
 
            if self.no_extra_results or (first_loop_done and estimator.full()):
226
 
                trace.mutter('size: %d, z_size: %d'
227
 
                             % (estimator._uncompressed_size_added,
228
 
                                estimator._compressed_size_added))
229
 
                next_revs = set()
230
 
                break
231
 
            # don't query things we've already queried
232
 
            next_revs = next_revs.difference(queried_revs)
233
 
            first_loop_done = True
234
 
        return result
235
 
 
236
 
    def _do_repository_request(self, body_bytes):
237
 
        repository = self._repository
238
 
        revision_ids = set(self._revision_ids)
239
 
        include_missing = 'include-missing:' in revision_ids
240
 
        if include_missing:
241
 
            revision_ids.remove('include-missing:')
242
 
        body_lines = body_bytes.split('\n')
243
 
        search_result, error = self.recreate_search_from_recipe(
244
 
            repository, body_lines)
245
 
        if error is not None:
246
 
            return error
247
 
        # TODO might be nice to start up the search again; but thats not
248
 
        # written or tested yet.
249
 
        client_seen_revs = set(search_result.get_keys())
250
 
        # Always include the requested ids.
251
 
        client_seen_revs.difference_update(revision_ids)
252
 
 
253
 
        repo_graph = repository.get_graph()
254
 
        result = self._expand_requested_revs(repo_graph, revision_ids,
255
 
                                             client_seen_revs, include_missing)
256
 
 
257
 
        # sorting trivially puts lexographically similar revision ids together.
258
 
        # Compression FTW.
259
 
        lines = []
260
 
        for revision, parents in sorted(result.items()):
261
 
            lines.append(' '.join((revision, ) + tuple(parents)))
262
 
 
263
 
        return SuccessfulSmartServerResponse(
264
 
            ('ok', ), bz2.compress('\n'.join(lines)))
265
 
 
266
 
 
267
 
class SmartServerRepositoryGetRevisionGraph(SmartServerRepositoryReadLocked):
268
 
 
269
 
    def do_readlocked_repository_request(self, repository, revision_id):
270
 
        """Return the result of repository.get_revision_graph(revision_id).
271
 
 
272
 
        Deprecated as of bzr 1.4, but supported for older clients.
273
 
 
274
 
        :param repository: The repository to query in.
275
 
        :param revision_id: The utf8 encoded revision_id to get a graph from.
276
 
        :return: A smart server response where the body contains an utf8
277
 
            encoded flattened list of the revision graph.
278
 
        """
279
 
        if not revision_id:
280
 
            revision_id = None
281
 
 
282
 
        lines = []
283
 
        graph = repository.get_graph()
284
 
        if revision_id:
285
 
            search_ids = [revision_id]
286
 
        else:
287
 
            search_ids = repository.all_revision_ids()
288
 
        search = graph._make_breadth_first_searcher(search_ids)
289
 
        transitive_ids = set()
290
 
        map(transitive_ids.update, list(search))
291
 
        parent_map = graph.get_parent_map(transitive_ids)
292
 
        revision_graph = _strip_NULL_ghosts(parent_map)
293
 
        if revision_id and revision_id not in revision_graph:
294
 
            # Note that we return an empty body, rather than omitting the body.
295
 
            # This way the client knows that it can always expect to find a body
296
 
            # in the response for this method, even in the error case.
297
 
            return FailedSmartServerResponse(('nosuchrevision', revision_id), '')
298
 
 
299
 
        for revision, parents in revision_graph.items():
300
 
            lines.append(' '.join((revision, ) + tuple(parents)))
301
 
 
302
 
        return SuccessfulSmartServerResponse(('ok', ), '\n'.join(lines))
303
 
 
304
 
 
305
 
class SmartServerRepositoryGetRevIdForRevno(SmartServerRepositoryReadLocked):
306
 
 
307
 
    def do_readlocked_repository_request(self, repository, revno,
308
 
            known_pair):
309
 
        """Find the revid for a given revno, given a known revno/revid pair.
310
 
        
311
 
        New in 1.17.
312
 
        """
313
 
        try:
314
 
            found_flag, result = repository.get_rev_id_for_revno(revno, known_pair)
315
 
        except errors.RevisionNotPresent, err:
316
 
            if err.revision_id != known_pair[1]:
317
 
                raise AssertionError(
318
 
                    'get_rev_id_for_revno raised RevisionNotPresent for '
319
 
                    'non-initial revision: ' + err.revision_id)
320
 
            return FailedSmartServerResponse(
321
 
                ('nosuchrevision', err.revision_id))
322
 
        if found_flag:
323
 
            return SuccessfulSmartServerResponse(('ok', result))
324
 
        else:
325
 
            earliest_revno, earliest_revid = result
326
 
            return SuccessfulSmartServerResponse(
327
 
                ('history-incomplete', earliest_revno, earliest_revid))
328
 
 
329
 
 
330
 
class SmartServerRequestHasRevision(SmartServerRepositoryRequest):
331
 
 
332
 
    def do_repository_request(self, repository, revision_id):
333
 
        """Return ok if a specific revision is in the repository at path.
334
 
 
335
 
        :param repository: The repository to query in.
336
 
        :param revision_id: The utf8 encoded revision_id to lookup.
337
 
        :return: A smart server response of ('ok', ) if the revision is
338
 
            present.
339
 
        """
340
 
        if repository.has_revision(revision_id):
341
 
            return SuccessfulSmartServerResponse(('yes', ))
342
 
        else:
343
 
            return SuccessfulSmartServerResponse(('no', ))
344
 
 
345
 
 
346
 
class SmartServerRepositoryGatherStats(SmartServerRepositoryRequest):
347
 
 
348
 
    def do_repository_request(self, repository, revid, committers):
349
 
        """Return the result of repository.gather_stats().
350
 
 
351
 
        :param repository: The repository to query in.
352
 
        :param revid: utf8 encoded rev id or an empty string to indicate None
353
 
        :param committers: 'yes' or 'no'.
354
 
 
355
 
        :return: A SmartServerResponse ('ok',), a encoded body looking like
356
 
              committers: 1
357
 
              firstrev: 1234.230 0
358
 
              latestrev: 345.700 3600
359
 
              revisions: 2
360
 
 
361
 
              But containing only fields returned by the gather_stats() call
362
 
        """
363
 
        if revid == '':
364
 
            decoded_revision_id = None
365
 
        else:
366
 
            decoded_revision_id = revid
367
 
        if committers == 'yes':
368
 
            decoded_committers = True
369
 
        else:
370
 
            decoded_committers = None
371
 
        stats = repository.gather_stats(decoded_revision_id, decoded_committers)
372
 
 
373
 
        body = ''
374
 
        if stats.has_key('committers'):
375
 
            body += 'committers: %d\n' % stats['committers']
376
 
        if stats.has_key('firstrev'):
377
 
            body += 'firstrev: %.3f %d\n' % stats['firstrev']
378
 
        if stats.has_key('latestrev'):
379
 
             body += 'latestrev: %.3f %d\n' % stats['latestrev']
380
 
        if stats.has_key('revisions'):
381
 
            body += 'revisions: %d\n' % stats['revisions']
382
 
        if stats.has_key('size'):
383
 
            body += 'size: %d\n' % stats['size']
384
 
 
385
 
        return SuccessfulSmartServerResponse(('ok', ), body)
386
 
 
387
 
 
388
 
class SmartServerRepositoryIsShared(SmartServerRepositoryRequest):
389
 
 
390
 
    def do_repository_request(self, repository):
391
 
        """Return the result of repository.is_shared().
392
 
 
393
 
        :param repository: The repository to query in.
394
 
        :return: A smart server response of ('yes', ) if the repository is
395
 
            shared, and ('no', ) if it is not.
396
 
        """
397
 
        if repository.is_shared():
398
 
            return SuccessfulSmartServerResponse(('yes', ))
399
 
        else:
400
 
            return SuccessfulSmartServerResponse(('no', ))
401
 
 
402
 
 
403
 
class SmartServerRepositoryLockWrite(SmartServerRepositoryRequest):
404
 
 
405
 
    def do_repository_request(self, repository, token=''):
406
 
        # XXX: this probably should not have a token.
407
 
        if token == '':
408
 
            token = None
409
 
        try:
410
 
            token = repository.lock_write(token=token).repository_token
411
 
        except errors.LockContention, e:
412
 
            return FailedSmartServerResponse(('LockContention',))
413
 
        except errors.UnlockableTransport:
414
 
            return FailedSmartServerResponse(('UnlockableTransport',))
415
 
        except errors.LockFailed, e:
416
 
            return FailedSmartServerResponse(('LockFailed',
417
 
                str(e.lock), str(e.why)))
418
 
        if token is not None:
419
 
            repository.leave_lock_in_place()
420
 
        repository.unlock()
421
 
        if token is None:
422
 
            token = ''
423
 
        return SuccessfulSmartServerResponse(('ok', token))
424
 
 
425
 
 
426
 
class SmartServerRepositoryGetStream(SmartServerRepositoryRequest):
427
 
 
428
 
    def do_repository_request(self, repository, to_network_name):
429
 
        """Get a stream for inserting into a to_format repository.
430
 
 
431
 
        The request body is 'search_bytes', a description of the revisions
432
 
        being requested.
433
 
 
434
 
        In 2.3 this verb added support for search_bytes == 'everything'.  Older
435
 
        implementations will respond with a BadSearch error, and clients should
436
 
        catch this and fallback appropriately.
437
 
 
438
 
        :param repository: The repository to stream from.
439
 
        :param to_network_name: The network name of the format of the target
440
 
            repository.
441
 
        """
442
 
        self._to_format = network_format_registry.get(to_network_name)
443
 
        if self._should_fake_unknown():
444
 
            return FailedSmartServerResponse(
445
 
                ('UnknownMethod', 'Repository.get_stream'))
446
 
        return None # Signal that we want a body.
447
 
 
448
 
    def _should_fake_unknown(self):
449
 
        """Return True if we should return UnknownMethod to the client.
450
 
        
451
 
        This is a workaround for bugs in pre-1.19 clients that claim to
452
 
        support receiving streams of CHK repositories.  The pre-1.19 client
453
 
        expects inventory records to be serialized in the format defined by
454
 
        to_network_name, but in pre-1.19 (at least) that format definition
455
 
        tries to use the xml5 serializer, which does not correctly handle
456
 
        rich-roots.  After 1.19 the client can also accept inventory-deltas
457
 
        (which avoids this issue), and those clients will use the
458
 
        Repository.get_stream_1.19 verb instead of this one.
459
 
        So: if this repository is CHK, and the to_format doesn't match,
460
 
        we should just fake an UnknownSmartMethod error so that the client
461
 
        will fallback to VFS, rather than sending it a stream we know it
462
 
        cannot handle.
463
 
        """
464
 
        from_format = self._repository._format
465
 
        to_format = self._to_format
466
 
        if not from_format.supports_chks:
467
 
            # Source not CHK: that's ok
468
 
            return False
469
 
        if (to_format.supports_chks and
470
 
            from_format.repository_class is to_format.repository_class and
471
 
            from_format._serializer == to_format._serializer):
472
 
            # Source is CHK, but target matches: that's ok
473
 
            # (e.g. 2a->2a, or CHK2->2a)
474
 
            return False
475
 
        # Source is CHK, and target is not CHK or incompatible CHK.  We can't
476
 
        # generate a compatible stream.
477
 
        return True
478
 
 
479
 
    def do_body(self, body_bytes):
480
 
        repository = self._repository
481
 
        repository.lock_read()
482
 
        try:
483
 
            search_result, error = self.recreate_search(repository, body_bytes,
484
 
                discard_excess=True)
485
 
            if error is not None:
486
 
                repository.unlock()
487
 
                return error
488
 
            source = repository._get_source(self._to_format)
489
 
            stream = source.get_stream(search_result)
490
 
        except Exception:
491
 
            exc_info = sys.exc_info()
492
 
            try:
493
 
                # On non-error, unlocking is done by the body stream handler.
494
 
                repository.unlock()
495
 
            finally:
496
 
                raise exc_info[0], exc_info[1], exc_info[2]
497
 
        return SuccessfulSmartServerResponse(('ok',),
498
 
            body_stream=self.body_stream(stream, repository))
499
 
 
500
 
    def body_stream(self, stream, repository):
501
 
        byte_stream = _stream_to_byte_stream(stream, repository._format)
502
 
        try:
503
 
            for bytes in byte_stream:
504
 
                yield bytes
505
 
        except errors.RevisionNotPresent, e:
506
 
            # This shouldn't be able to happen, but as we don't buffer
507
 
            # everything it can in theory happen.
508
 
            repository.unlock()
509
 
            yield FailedSmartServerResponse(('NoSuchRevision', e.revision_id))
510
 
        else:
511
 
            repository.unlock()
512
 
 
513
 
 
514
 
class SmartServerRepositoryGetStream_1_19(SmartServerRepositoryGetStream):
515
 
    """The same as Repository.get_stream, but will return stream CHK formats to
516
 
    clients.
517
 
 
518
 
    See SmartServerRepositoryGetStream._should_fake_unknown.
519
 
    
520
 
    New in 1.19.
521
 
    """
522
 
 
523
 
    def _should_fake_unknown(self):
524
 
        """Returns False; we don't need to workaround bugs in 1.19+ clients."""
525
 
        return False
526
 
 
527
 
 
528
 
def _stream_to_byte_stream(stream, src_format):
529
 
    """Convert a record stream to a self delimited byte stream."""
530
 
    pack_writer = pack.ContainerSerialiser()
531
 
    yield pack_writer.begin()
532
 
    yield pack_writer.bytes_record(src_format.network_name(), '')
533
 
    for substream_type, substream in stream:
534
 
        for record in substream:
535
 
            if record.storage_kind in ('chunked', 'fulltext'):
536
 
                serialised = record_to_fulltext_bytes(record)
537
 
            elif record.storage_kind == 'absent':
538
 
                raise ValueError("Absent factory for %s" % (record.key,))
539
 
            else:
540
 
                serialised = record.get_bytes_as(record.storage_kind)
541
 
            if serialised:
542
 
                # Some streams embed the whole stream into the wire
543
 
                # representation of the first record, which means that
544
 
                # later records have no wire representation: we skip them.
545
 
                yield pack_writer.bytes_record(serialised, [(substream_type,)])
546
 
    yield pack_writer.end()
547
 
 
548
 
 
549
 
class _ByteStreamDecoder(object):
550
 
    """Helper for _byte_stream_to_stream.
551
 
 
552
 
    The expected usage of this class is via the function _byte_stream_to_stream
553
 
    which creates a _ByteStreamDecoder, pops off the stream format and then
554
 
    yields the output of record_stream(), the main entry point to
555
 
    _ByteStreamDecoder.
556
 
 
557
 
    Broadly this class has to unwrap two layers of iterators:
558
 
    (type, substream)
559
 
    (substream details)
560
 
 
561
 
    This is complicated by wishing to return type, iterator_for_type, but
562
 
    getting the data for iterator_for_type when we find out type: we can't
563
 
    simply pass a generator down to the NetworkRecordStream parser, instead
564
 
    we have a little local state to seed each NetworkRecordStream instance,
565
 
    and gather the type that we'll be yielding.
566
 
 
567
 
    :ivar byte_stream: The byte stream being decoded.
568
 
    :ivar stream_decoder: A pack parser used to decode the bytestream
569
 
    :ivar current_type: The current type, used to join adjacent records of the
570
 
        same type into a single stream.
571
 
    :ivar first_bytes: The first bytes to give the next NetworkRecordStream.
572
 
    """
573
 
 
574
 
    def __init__(self, byte_stream, record_counter):
575
 
        """Create a _ByteStreamDecoder."""
576
 
        self.stream_decoder = pack.ContainerPushParser()
577
 
        self.current_type = None
578
 
        self.first_bytes = None
579
 
        self.byte_stream = byte_stream
580
 
        self._record_counter = record_counter
581
 
        self.key_count = 0
582
 
 
583
 
    def iter_stream_decoder(self):
584
 
        """Iterate the contents of the pack from stream_decoder."""
585
 
        # dequeue pending items
586
 
        for record in self.stream_decoder.read_pending_records():
587
 
            yield record
588
 
        # Pull bytes of the wire, decode them to records, yield those records.
589
 
        for bytes in self.byte_stream:
590
 
            self.stream_decoder.accept_bytes(bytes)
591
 
            for record in self.stream_decoder.read_pending_records():
592
 
                yield record
593
 
 
594
 
    def iter_substream_bytes(self):
595
 
        if self.first_bytes is not None:
596
 
            yield self.first_bytes
597
 
            # If we run out of pack records, single the outer layer to stop.
598
 
            self.first_bytes = None
599
 
        for record in self.iter_pack_records:
600
 
            record_names, record_bytes = record
601
 
            record_name, = record_names
602
 
            substream_type = record_name[0]
603
 
            if substream_type != self.current_type:
604
 
                # end of a substream, seed the next substream.
605
 
                self.current_type = substream_type
606
 
                self.first_bytes = record_bytes
607
 
                return
608
 
            yield record_bytes
609
 
 
610
 
    def record_stream(self):
611
 
        """Yield substream_type, substream from the byte stream."""
612
 
        def wrap_and_count(pb, rc, substream):
613
 
            """Yield records from stream while showing progress."""
614
 
            counter = 0
615
 
            if rc:
616
 
                if self.current_type != 'revisions' and self.key_count != 0:
617
 
                    # As we know the number of revisions now (in self.key_count)
618
 
                    # we can setup and use record_counter (rc).
619
 
                    if not rc.is_initialized():
620
 
                        rc.setup(self.key_count, self.key_count)
621
 
            for record in substream.read():
622
 
                if rc:
623
 
                    if rc.is_initialized() and counter == rc.STEP:
624
 
                        rc.increment(counter)
625
 
                        pb.update('Estimate', rc.current, rc.max)
626
 
                        counter = 0
627
 
                    if self.current_type == 'revisions':
628
 
                        # Total records is proportional to number of revs
629
 
                        # to fetch. With remote, we used self.key_count to
630
 
                        # track the number of revs. Once we have the revs
631
 
                        # counts in self.key_count, the progress bar changes
632
 
                        # from 'Estimating..' to 'Estimate' above.
633
 
                        self.key_count += 1
634
 
                        if counter == rc.STEP:
635
 
                            pb.update('Estimating..', self.key_count)
636
 
                            counter = 0
637
 
                counter += 1
638
 
                yield record
639
 
 
640
 
        self.seed_state()
641
 
        pb = ui.ui_factory.nested_progress_bar()
642
 
        rc = self._record_counter
643
 
        # Make and consume sub generators, one per substream type:
644
 
        while self.first_bytes is not None:
645
 
            substream = NetworkRecordStream(self.iter_substream_bytes())
646
 
            # after substream is fully consumed, self.current_type is set to
647
 
            # the next type, and self.first_bytes is set to the matching bytes.
648
 
            yield self.current_type, wrap_and_count(pb, rc, substream)
649
 
        if rc:
650
 
            pb.update('Done', rc.max, rc.max)
651
 
        pb.finished()
652
 
 
653
 
    def seed_state(self):
654
 
        """Prepare the _ByteStreamDecoder to decode from the pack stream."""
655
 
        # Set a single generator we can use to get data from the pack stream.
656
 
        self.iter_pack_records = self.iter_stream_decoder()
657
 
        # Seed the very first subiterator with content; after this each one
658
 
        # seeds the next.
659
 
        list(self.iter_substream_bytes())
660
 
 
661
 
 
662
 
def _byte_stream_to_stream(byte_stream, record_counter=None):
663
 
    """Convert a byte stream into a format and a stream.
664
 
 
665
 
    :param byte_stream: A bytes iterator, as output by _stream_to_byte_stream.
666
 
    :return: (RepositoryFormat, stream_generator)
667
 
    """
668
 
    decoder = _ByteStreamDecoder(byte_stream, record_counter)
669
 
    for bytes in byte_stream:
670
 
        decoder.stream_decoder.accept_bytes(bytes)
671
 
        for record in decoder.stream_decoder.read_pending_records(max=1):
672
 
            record_names, src_format_name = record
673
 
            src_format = network_format_registry.get(src_format_name)
674
 
            return src_format, decoder.record_stream()
675
 
 
676
 
 
677
 
class SmartServerRepositoryUnlock(SmartServerRepositoryRequest):
678
 
 
679
 
    def do_repository_request(self, repository, token):
680
 
        try:
681
 
            repository.lock_write(token=token)
682
 
        except errors.TokenMismatch, e:
683
 
            return FailedSmartServerResponse(('TokenMismatch',))
684
 
        repository.dont_leave_lock_in_place()
685
 
        repository.unlock()
686
 
        return SuccessfulSmartServerResponse(('ok',))
687
 
 
688
 
 
689
 
class SmartServerRepositorySetMakeWorkingTrees(SmartServerRepositoryRequest):
690
 
 
691
 
    def do_repository_request(self, repository, str_bool_new_value):
692
 
        if str_bool_new_value == 'True':
693
 
            new_value = True
694
 
        else:
695
 
            new_value = False
696
 
        repository.set_make_working_trees(new_value)
697
 
        return SuccessfulSmartServerResponse(('ok',))
698
 
 
699
 
 
700
 
class SmartServerRepositoryTarball(SmartServerRepositoryRequest):
701
 
    """Get the raw repository files as a tarball.
702
 
 
703
 
    The returned tarball contains a .bzr control directory which in turn
704
 
    contains a repository.
705
 
 
706
 
    This takes one parameter, compression, which currently must be
707
 
    "", "gz", or "bz2".
708
 
 
709
 
    This is used to implement the Repository.copy_content_into operation.
710
 
    """
711
 
 
712
 
    def do_repository_request(self, repository, compression):
713
 
        tmp_dirname, tmp_repo = self._copy_to_tempdir(repository)
714
 
        try:
715
 
            controldir_name = tmp_dirname + '/.bzr'
716
 
            return self._tarfile_response(controldir_name, compression)
717
 
        finally:
718
 
            osutils.rmtree(tmp_dirname)
719
 
 
720
 
    def _copy_to_tempdir(self, from_repo):
721
 
        tmp_dirname = osutils.mkdtemp(prefix='tmpbzrclone')
722
 
        tmp_bzrdir = from_repo.bzrdir._format.initialize(tmp_dirname)
723
 
        tmp_repo = from_repo._format.initialize(tmp_bzrdir)
724
 
        from_repo.copy_content_into(tmp_repo)
725
 
        return tmp_dirname, tmp_repo
726
 
 
727
 
    def _tarfile_response(self, tmp_dirname, compression):
728
 
        temp = tempfile.NamedTemporaryFile()
729
 
        try:
730
 
            self._tarball_of_dir(tmp_dirname, compression, temp.file)
731
 
            # all finished; write the tempfile out to the network
732
 
            temp.seek(0)
733
 
            return SuccessfulSmartServerResponse(('ok',), temp.read())
734
 
            # FIXME: Don't read the whole thing into memory here; rather stream
735
 
            # it out from the file onto the network. mbp 20070411
736
 
        finally:
737
 
            temp.close()
738
 
 
739
 
    def _tarball_of_dir(self, dirname, compression, ofile):
740
 
        import tarfile
741
 
        filename = os.path.basename(ofile.name)
742
 
        tarball = tarfile.open(fileobj=ofile, name=filename,
743
 
            mode='w|' + compression)
744
 
        try:
745
 
            # The tarball module only accepts ascii names, and (i guess)
746
 
            # packs them with their 8bit names.  We know all the files
747
 
            # within the repository have ASCII names so the should be safe
748
 
            # to pack in.
749
 
            dirname = dirname.encode(sys.getfilesystemencoding())
750
 
            # python's tarball module includes the whole path by default so
751
 
            # override it
752
 
            if not dirname.endswith('.bzr'):
753
 
                raise ValueError(dirname)
754
 
            tarball.add(dirname, '.bzr') # recursive by default
755
 
        finally:
756
 
            tarball.close()
757
 
 
758
 
 
759
 
class SmartServerRepositoryInsertStreamLocked(SmartServerRepositoryRequest):
760
 
    """Insert a record stream from a RemoteSink into a repository.
761
 
 
762
 
    This gets bytes pushed to it by the network infrastructure and turns that
763
 
    into a bytes iterator using a thread. That is then processed by
764
 
    _byte_stream_to_stream.
765
 
 
766
 
    New in 1.14.
767
 
    """
768
 
 
769
 
    def do_repository_request(self, repository, resume_tokens, lock_token):
770
 
        """StreamSink.insert_stream for a remote repository."""
771
 
        repository.lock_write(token=lock_token)
772
 
        self.do_insert_stream_request(repository, resume_tokens)
773
 
 
774
 
    def do_insert_stream_request(self, repository, resume_tokens):
775
 
        tokens = [token for token in resume_tokens.split(' ') if token]
776
 
        self.tokens = tokens
777
 
        self.repository = repository
778
 
        self.queue = Queue.Queue()
779
 
        self.insert_thread = threading.Thread(target=self._inserter_thread)
780
 
        self.insert_thread.start()
781
 
 
782
 
    def do_chunk(self, body_stream_chunk):
783
 
        self.queue.put(body_stream_chunk)
784
 
 
785
 
    def _inserter_thread(self):
786
 
        try:
787
 
            src_format, stream = _byte_stream_to_stream(
788
 
                self.blocking_byte_stream())
789
 
            self.insert_result = self.repository._get_sink().insert_stream(
790
 
                stream, src_format, self.tokens)
791
 
            self.insert_ok = True
792
 
        except:
793
 
            self.insert_exception = sys.exc_info()
794
 
            self.insert_ok = False
795
 
 
796
 
    def blocking_byte_stream(self):
797
 
        while True:
798
 
            bytes = self.queue.get()
799
 
            if bytes is StopIteration:
800
 
                return
801
 
            else:
802
 
                yield bytes
803
 
 
804
 
    def do_end(self):
805
 
        self.queue.put(StopIteration)
806
 
        if self.insert_thread is not None:
807
 
            self.insert_thread.join()
808
 
        if not self.insert_ok:
809
 
            exc_info = self.insert_exception
810
 
            raise exc_info[0], exc_info[1], exc_info[2]
811
 
        write_group_tokens, missing_keys = self.insert_result
812
 
        if write_group_tokens or missing_keys:
813
 
            # bzip needed? missing keys should typically be a small set.
814
 
            # Should this be a streaming body response ?
815
 
            missing_keys = sorted(missing_keys)
816
 
            bytes = bencode.bencode((write_group_tokens, missing_keys))
817
 
            self.repository.unlock()
818
 
            return SuccessfulSmartServerResponse(('missing-basis', bytes))
819
 
        else:
820
 
            self.repository.unlock()
821
 
            return SuccessfulSmartServerResponse(('ok', ))
822
 
 
823
 
 
824
 
class SmartServerRepositoryInsertStream_1_19(SmartServerRepositoryInsertStreamLocked):
825
 
    """Insert a record stream from a RemoteSink into a repository.
826
 
 
827
 
    Same as SmartServerRepositoryInsertStreamLocked, except:
828
 
     - the lock token argument is optional
829
 
     - servers that implement this verb accept 'inventory-delta' records in the
830
 
       stream.
831
 
 
832
 
    New in 1.19.
833
 
    """
834
 
 
835
 
    def do_repository_request(self, repository, resume_tokens, lock_token=None):
836
 
        """StreamSink.insert_stream for a remote repository."""
837
 
        SmartServerRepositoryInsertStreamLocked.do_repository_request(
838
 
            self, repository, resume_tokens, lock_token)
839
 
 
840
 
 
841
 
class SmartServerRepositoryInsertStream(SmartServerRepositoryInsertStreamLocked):
842
 
    """Insert a record stream from a RemoteSink into an unlocked repository.
843
 
 
844
 
    This is the same as SmartServerRepositoryInsertStreamLocked, except it
845
 
    takes no lock_tokens; i.e. it works with an unlocked (or lock-free, e.g.
846
 
    like pack format) repository.
847
 
 
848
 
    New in 1.13.
849
 
    """
850
 
 
851
 
    def do_repository_request(self, repository, resume_tokens):
852
 
        """StreamSink.insert_stream for a remote repository."""
853
 
        repository.lock_write()
854
 
        self.do_insert_stream_request(repository, resume_tokens)
855
 
 
856