~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/merge_directive.py

(mbp) more integrated 0.15 fixes

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2007-2011 Canonical Ltd
 
1
# Copyright (C) 2007 Canonical Ltd
2
2
#
3
3
# This program is free software; you can redistribute it and/or modify
4
4
# it under the terms of the GNU General Public License as published by
12
12
#
13
13
# You should have received a copy of the GNU General Public License
14
14
# along with this program; if not, write to the Free Software
15
 
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
16
 
 
17
 
 
 
15
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 
16
 
 
17
 
 
18
from email import Message
18
19
from StringIO import StringIO
19
 
import re
20
20
 
21
 
from bzrlib import lazy_import
22
 
lazy_import.lazy_import(globals(), """
23
21
from bzrlib import (
24
22
    branch as _mod_branch,
25
23
    diff,
26
 
    email_message,
27
24
    errors,
28
25
    gpg,
29
 
    hooks,
30
 
    registry,
31
26
    revision as _mod_revision,
32
27
    rio,
33
28
    testament,
34
29
    timestamp,
35
 
    trace,
36
 
    )
37
 
from bzrlib.bundle import (
38
 
    serializer as bundle_serializer,
39
 
    )
40
 
""")
41
 
 
42
 
 
43
 
class MergeRequestBodyParams(object):
44
 
    """Parameter object for the merge_request_body hook."""
45
 
 
46
 
    def __init__(self, body, orig_body, directive, to, basename, subject,
47
 
                 branch, tree=None):
48
 
        self.body = body
49
 
        self.orig_body = orig_body
50
 
        self.directive = directive
51
 
        self.branch = branch
52
 
        self.tree = tree
53
 
        self.to = to
54
 
        self.basename = basename
55
 
        self.subject = subject
56
 
 
57
 
 
58
 
class MergeDirectiveHooks(hooks.Hooks):
59
 
    """Hooks for MergeDirective classes."""
60
 
 
61
 
    def __init__(self):
62
 
        hooks.Hooks.__init__(self, "bzrlib.merge_directive", "BaseMergeDirective.hooks")
63
 
        self.add_hook('merge_request_body',
64
 
            "Called with a MergeRequestBodyParams when a body is needed for"
65
 
            " a merge request.  Callbacks must return a body.  If more"
66
 
            " than one callback is registered, the output of one callback is"
67
 
            " provided to the next.", (1, 15, 0))
68
 
 
69
 
 
70
 
class BaseMergeDirective(object):
 
30
    )
 
31
from bzrlib.bundle import serializer as bundle_serializer
 
32
 
 
33
 
 
34
class MergeDirective(object):
 
35
 
71
36
    """A request to perform a merge into a branch.
72
37
 
73
 
    This is the base class that all merge directive implementations 
74
 
    should derive from.
75
 
 
76
 
    :cvar multiple_output_files: Whether or not this merge directive 
77
 
        stores a set of revisions in more than one file
 
38
    Designed to be serialized and mailed.  It provides all the information
 
39
    needed to perform a merge automatically, by providing at minimum a revision
 
40
    bundle or the location of a branch.
 
41
 
 
42
    The serialization format is robust against certain common forms of
 
43
    deterioration caused by mailing.
 
44
 
 
45
    The format is also designed to be patch-compatible.  If the directive
 
46
    includes a diff or revision bundle, it should be possible to apply it
 
47
    directly using the standard patch program.
78
48
    """
79
49
 
80
 
    hooks = MergeDirectiveHooks()
81
 
 
82
 
    multiple_output_files = False
 
50
    _format_string = 'Bazaar merge directive format 1'
83
51
 
84
52
    def __init__(self, revision_id, testament_sha1, time, timezone,
85
 
                 target_branch, patch=None, source_branch=None, message=None,
86
 
                 bundle=None):
 
53
                 target_branch, patch=None, patch_type=None,
 
54
                 source_branch=None, message=None):
87
55
        """Constructor.
88
56
 
89
57
        :param revision_id: The revision to merge
93
61
        :param timezone: The timezone offset
94
62
        :param target_branch: The branch to apply the merge to
95
63
        :param patch: The text of a diff or bundle
 
64
        :param patch_type: None, "diff" or "bundle", depending on the contents
 
65
            of patch
96
66
        :param source_branch: A public location to merge the revision from
97
67
        :param message: The message to use when committing this merge
98
68
        """
 
69
        assert patch_type in (None, 'diff', 'bundle')
 
70
        if patch_type != 'bundle' and source_branch is None:
 
71
            raise errors.NoMergeSource()
 
72
        if patch_type is not None and patch is None:
 
73
            raise errors.PatchMissing(patch_type)
99
74
        self.revision_id = revision_id
100
75
        self.testament_sha1 = testament_sha1
101
76
        self.time = time
102
77
        self.timezone = timezone
103
78
        self.target_branch = target_branch
104
79
        self.patch = patch
 
80
        self.patch_type = patch_type
105
81
        self.source_branch = source_branch
106
82
        self.message = message
107
83
 
 
84
    @classmethod
 
85
    def from_lines(klass, lines):
 
86
        """Deserialize a MergeRequest from an iterable of lines
 
87
 
 
88
        :param lines: An iterable of lines
 
89
        :return: a MergeRequest
 
90
        """
 
91
        line_iter = iter(lines)
 
92
        for line in line_iter:
 
93
            if line.startswith('# ' + klass._format_string):
 
94
                break
 
95
        else:
 
96
            if len(lines) > 0:
 
97
                raise errors.NotAMergeDirective(lines[0])
 
98
            else:
 
99
                raise errors.NotAMergeDirective('')
 
100
        stanza = rio.read_patch_stanza(line_iter)
 
101
        patch_lines = list(line_iter)
 
102
        if len(patch_lines) == 0:
 
103
            patch = None
 
104
            patch_type = None
 
105
        else:
 
106
            patch = ''.join(patch_lines)
 
107
            try:
 
108
                bundle_serializer.read_bundle(StringIO(patch))
 
109
            except errors.NotABundle:
 
110
                patch_type = 'diff'
 
111
            else:
 
112
                patch_type = 'bundle'
 
113
        time, timezone = timestamp.parse_patch_date(stanza.get('timestamp'))
 
114
        kwargs = {}
 
115
        for key in ('revision_id', 'testament_sha1', 'target_branch',
 
116
                    'source_branch', 'message'):
 
117
            try:
 
118
                kwargs[key] = stanza.get(key)
 
119
            except KeyError:
 
120
                pass
 
121
        kwargs['revision_id'] = kwargs['revision_id'].encode('utf-8')
 
122
        return MergeDirective(time=time, timezone=timezone,
 
123
                              patch_type=patch_type, patch=patch, **kwargs)
 
124
 
108
125
    def to_lines(self):
109
126
        """Serialize as a list of lines
110
127
 
111
128
        :return: a list of lines
112
129
        """
113
 
        raise NotImplementedError(self.to_lines)
114
 
 
115
 
    def to_files(self):
116
 
        """Serialize as a set of files.
117
 
 
118
 
        :return: List of tuples with filename and contents as lines
119
 
        """
120
 
        raise NotImplementedError(self.to_files)
121
 
 
122
 
    def get_raw_bundle(self):
123
 
        """Return the bundle for this merge directive.
124
 
 
125
 
        :return: bundle text or None if there is no bundle
126
 
        """
127
 
        return None
128
 
 
129
 
    def _to_lines(self, base_revision=False):
130
 
        """Serialize as a list of lines
131
 
 
132
 
        :return: a list of lines
133
 
        """
134
130
        time_str = timestamp.format_patch_date(self.time, self.timezone)
135
131
        stanza = rio.Stanza(revision_id=self.revision_id, timestamp=time_str,
136
132
                            target_branch=self.target_branch,
138
134
        for key in ('source_branch', 'message'):
139
135
            if self.__dict__[key] is not None:
140
136
                stanza.add(key, self.__dict__[key])
141
 
        if base_revision:
142
 
            stanza.add('base_revision_id', self.base_revision_id)
143
137
        lines = ['# ' + self._format_string + '\n']
144
138
        lines.extend(rio.to_patch_lines(stanza))
145
139
        lines.append('# \n')
 
140
        if self.patch is not None:
 
141
            lines.extend(self.patch.splitlines(True))
146
142
        return lines
147
143
 
148
 
    def write_to_directory(self, path):
149
 
        """Write this merge directive to a series of files in a directory.
150
 
 
151
 
        :param path: Filesystem path to write to
152
 
        """
153
 
        raise NotImplementedError(self.write_to_directory)
154
 
 
155
 
    @classmethod
156
 
    def from_objects(klass, repository, revision_id, time, timezone,
157
 
                 target_branch, patch_type='bundle',
158
 
                 local_target_branch=None, public_branch=None, message=None):
159
 
        """Generate a merge directive from various objects
160
 
 
161
 
        :param repository: The repository containing the revision
162
 
        :param revision_id: The revision to merge
163
 
        :param time: The POSIX timestamp of the date the request was issued.
164
 
        :param timezone: The timezone of the request
165
 
        :param target_branch: The url of the branch to merge into
166
 
        :param patch_type: 'bundle', 'diff' or None, depending on the type of
167
 
            patch desired.
168
 
        :param local_target_branch: a local copy of the target branch
169
 
        :param public_branch: location of a public branch containing the target
170
 
            revision.
171
 
        :param message: Message to use when committing the merge
172
 
        :return: The merge directive
173
 
 
174
 
        The public branch is always used if supplied.  If the patch_type is
175
 
        not 'bundle', the public branch must be supplied, and will be verified.
176
 
 
177
 
        If the message is not supplied, the message from revision_id will be
178
 
        used for the commit.
179
 
        """
180
 
        t_revision_id = revision_id
181
 
        if revision_id == _mod_revision.NULL_REVISION:
182
 
            t_revision_id = None
183
 
        t = testament.StrictTestament3.from_revision(repository, t_revision_id)
184
 
        submit_branch = _mod_branch.Branch.open(target_branch)
185
 
        if submit_branch.get_public_branch() is not None:
186
 
            target_branch = submit_branch.get_public_branch()
187
 
        if patch_type is None:
188
 
            patch = None
189
 
        else:
190
 
            submit_revision_id = submit_branch.last_revision()
191
 
            submit_revision_id = _mod_revision.ensure_null(submit_revision_id)
192
 
            repository.fetch(submit_branch.repository, submit_revision_id)
193
 
            graph = repository.get_graph()
194
 
            ancestor_id = graph.find_unique_lca(revision_id,
195
 
                                                submit_revision_id)
196
 
            type_handler = {'bundle': klass._generate_bundle,
197
 
                            'diff': klass._generate_diff,
198
 
                            None: lambda x, y, z: None }
199
 
            patch = type_handler[patch_type](repository, revision_id,
200
 
                                             ancestor_id)
201
 
 
202
 
        if public_branch is not None and patch_type != 'bundle':
203
 
            public_branch_obj = _mod_branch.Branch.open(public_branch)
204
 
            if not public_branch_obj.repository.has_revision(revision_id):
205
 
                raise errors.PublicBranchOutOfDate(public_branch,
206
 
                                                   revision_id)
207
 
 
208
 
        return klass(revision_id, t.as_sha1(), time, timezone, target_branch,
209
 
            patch, patch_type, public_branch, message)
210
 
 
211
 
    def get_disk_name(self, branch):
212
 
        """Generate a suitable basename for storing this directive on disk
213
 
 
214
 
        :param branch: The Branch this merge directive was generated fro
215
 
        :return: A string
216
 
        """
217
 
        revno, revision_id = branch.last_revision_info()
218
 
        if self.revision_id == revision_id:
219
 
            revno = [revno]
220
 
        else:
221
 
            revno = branch.get_revision_id_to_revno_map().get(self.revision_id,
222
 
                ['merge'])
223
 
        nick = re.sub('(\W+)', '-', branch.nick).strip('-')
224
 
        return '%s-%s' % (nick, '.'.join(str(n) for n in revno))
225
 
 
226
 
    @staticmethod
227
 
    def _generate_diff(repository, revision_id, ancestor_id):
228
 
        tree_1 = repository.revision_tree(ancestor_id)
229
 
        tree_2 = repository.revision_tree(revision_id)
230
 
        s = StringIO()
231
 
        diff.show_diff_trees(tree_1, tree_2, s, old_label='', new_label='')
232
 
        return s.getvalue()
233
 
 
234
 
    @staticmethod
235
 
    def _generate_bundle(repository, revision_id, ancestor_id):
236
 
        s = StringIO()
237
 
        bundle_serializer.write_bundle(repository, revision_id,
238
 
                                       ancestor_id, s)
239
 
        return s.getvalue()
240
 
 
241
144
    def to_signed(self, branch):
242
145
        """Serialize as a signed string.
243
146
 
257
160
        :return: an email message
258
161
        """
259
162
        mail_from = branch.get_config().username()
 
163
        message = Message.Message()
 
164
        message['To'] = mail_to
 
165
        message['From'] = mail_from
260
166
        if self.message is not None:
261
 
            subject = self.message
 
167
            message['Subject'] = self.message
262
168
        else:
263
169
            revision = branch.repository.get_revision(self.revision_id)
264
 
            subject = revision.message
 
170
            message['Subject'] = revision.message
265
171
        if sign:
266
172
            body = self.to_signed(branch)
267
173
        else:
268
174
            body = ''.join(self.to_lines())
269
 
        message = email_message.EmailMessage(mail_from, mail_to, subject,
270
 
            body)
 
175
        message.set_payload(body)
271
176
        return message
272
177
 
273
 
    def install_revisions(self, target_repo):
274
 
        """Install revisions and return the target revision"""
275
 
        if not target_repo.has_revision(self.revision_id):
276
 
            if self.patch_type == 'bundle':
277
 
                info = bundle_serializer.read_bundle(
278
 
                    StringIO(self.get_raw_bundle()))
279
 
                # We don't use the bundle's target revision, because
280
 
                # MergeDirective.revision_id is authoritative.
281
 
                try:
282
 
                    info.install_revisions(target_repo, stream_input=False)
283
 
                except errors.RevisionNotPresent:
284
 
                    # At least one dependency isn't present.  Try installing
285
 
                    # missing revisions from the submit branch
286
 
                    try:
287
 
                        submit_branch = \
288
 
                            _mod_branch.Branch.open(self.target_branch)
289
 
                    except errors.NotBranchError:
290
 
                        raise errors.TargetNotBranch(self.target_branch)
291
 
                    missing_revisions = []
292
 
                    bundle_revisions = set(r.revision_id for r in
293
 
                                           info.real_revisions)
294
 
                    for revision in info.real_revisions:
295
 
                        for parent_id in revision.parent_ids:
296
 
                            if (parent_id not in bundle_revisions and
297
 
                                not target_repo.has_revision(parent_id)):
298
 
                                missing_revisions.append(parent_id)
299
 
                    # reverse missing revisions to try to get heads first
300
 
                    unique_missing = []
301
 
                    unique_missing_set = set()
302
 
                    for revision in reversed(missing_revisions):
303
 
                        if revision in unique_missing_set:
304
 
                            continue
305
 
                        unique_missing.append(revision)
306
 
                        unique_missing_set.add(revision)
307
 
                    for missing_revision in unique_missing:
308
 
                        target_repo.fetch(submit_branch.repository,
309
 
                                          missing_revision)
310
 
                    info.install_revisions(target_repo, stream_input=False)
311
 
            else:
312
 
                source_branch = _mod_branch.Branch.open(self.source_branch)
313
 
                target_repo.fetch(source_branch.repository, self.revision_id)
314
 
        return self.revision_id
315
 
 
316
 
    def compose_merge_request(self, mail_client, to, body, branch, tree=None):
317
 
        """Compose a request to merge this directive.
318
 
 
319
 
        :param mail_client: The mail client to use for composing this request.
320
 
        :param to: The address to compose the request to.
321
 
        :param branch: The Branch that was used to produce this directive.
322
 
        :param tree: The Tree (if any) for the Branch used to produce this
323
 
            directive.
324
 
        """
325
 
        basename = self.get_disk_name(branch)
326
 
        subject = '[MERGE] '
327
 
        if self.message is not None:
328
 
            subject += self.message
329
 
        else:
330
 
            revision = branch.repository.get_revision(self.revision_id)
331
 
            subject += revision.get_summary()
332
 
        if getattr(mail_client, 'supports_body', False):
333
 
            orig_body = body
334
 
            for hook in self.hooks['merge_request_body']:
335
 
                params = MergeRequestBodyParams(body, orig_body, self,
336
 
                                                to, basename, subject, branch,
337
 
                                                tree)
338
 
                body = hook(params)
339
 
        elif len(self.hooks['merge_request_body']) > 0:
340
 
            trace.warning('Cannot run merge_request_body hooks because mail'
341
 
                          ' client %s does not support message bodies.',
342
 
                        mail_client.__class__.__name__)
343
 
        mail_client.compose_merge_request(to, subject,
344
 
                                          ''.join(self.to_lines()),
345
 
                                          basename, body)
346
 
 
347
 
 
348
 
class MergeDirective(BaseMergeDirective):
349
 
 
350
 
    """A request to perform a merge into a branch.
351
 
 
352
 
    Designed to be serialized and mailed.  It provides all the information
353
 
    needed to perform a merge automatically, by providing at minimum a revision
354
 
    bundle or the location of a branch.
355
 
 
356
 
    The serialization format is robust against certain common forms of
357
 
    deterioration caused by mailing.
358
 
 
359
 
    The format is also designed to be patch-compatible.  If the directive
360
 
    includes a diff or revision bundle, it should be possible to apply it
361
 
    directly using the standard patch program.
362
 
    """
363
 
 
364
 
    _format_string = 'Bazaar merge directive format 1'
365
 
 
366
 
    def __init__(self, revision_id, testament_sha1, time, timezone,
367
 
                 target_branch, patch=None, patch_type=None,
368
 
                 source_branch=None, message=None, bundle=None):
369
 
        """Constructor.
370
 
 
371
 
        :param revision_id: The revision to merge
372
 
        :param testament_sha1: The sha1 of the testament of the revision to
373
 
            merge.
374
 
        :param time: The current POSIX timestamp time
375
 
        :param timezone: The timezone offset
376
 
        :param target_branch: The branch to apply the merge to
377
 
        :param patch: The text of a diff or bundle
378
 
        :param patch_type: None, "diff" or "bundle", depending on the contents
379
 
            of patch
380
 
        :param source_branch: A public location to merge the revision from
381
 
        :param message: The message to use when committing this merge
382
 
        """
383
 
        BaseMergeDirective.__init__(self, revision_id, testament_sha1, time,
384
 
            timezone, target_branch, patch, source_branch, message)
385
 
        if patch_type not in (None, 'diff', 'bundle'):
386
 
            raise ValueError(patch_type)
387
 
        if patch_type != 'bundle' and source_branch is None:
388
 
            raise errors.NoMergeSource()
389
 
        if patch_type is not None and patch is None:
390
 
            raise errors.PatchMissing(patch_type)
391
 
        self.patch_type = patch_type
392
 
 
393
 
    def clear_payload(self):
394
 
        self.patch = None
395
 
        self.patch_type = None
396
 
 
397
 
    def get_raw_bundle(self):
398
 
        return self.bundle
399
 
 
400
 
    def _bundle(self):
401
 
        if self.patch_type == 'bundle':
402
 
            return self.patch
403
 
        else:
404
 
            return None
405
 
 
406
 
    bundle = property(_bundle)
407
 
 
408
 
    @classmethod
409
 
    def from_lines(klass, lines):
410
 
        """Deserialize a MergeRequest from an iterable of lines
411
 
 
412
 
        :param lines: An iterable of lines
413
 
        :return: a MergeRequest
414
 
        """
415
 
        line_iter = iter(lines)
416
 
        firstline = ""
417
 
        for line in line_iter:
418
 
            if line.startswith('# Bazaar merge directive format '):
419
 
                return _format_registry.get(line[2:].rstrip())._from_lines(
420
 
                    line_iter)
421
 
            firstline = firstline or line.strip()
422
 
        raise errors.NotAMergeDirective(firstline)
423
 
 
424
 
    @classmethod
425
 
    def _from_lines(klass, line_iter):
426
 
        stanza = rio.read_patch_stanza(line_iter)
427
 
        patch_lines = list(line_iter)
428
 
        if len(patch_lines) == 0:
429
 
            patch = None
430
 
            patch_type = None
431
 
        else:
432
 
            patch = ''.join(patch_lines)
433
 
            try:
434
 
                bundle_serializer.read_bundle(StringIO(patch))
435
 
            except (errors.NotABundle, errors.BundleNotSupported,
436
 
                    errors.BadBundle):
437
 
                patch_type = 'diff'
438
 
            else:
439
 
                patch_type = 'bundle'
440
 
        time, timezone = timestamp.parse_patch_date(stanza.get('timestamp'))
441
 
        kwargs = {}
442
 
        for key in ('revision_id', 'testament_sha1', 'target_branch',
443
 
                    'source_branch', 'message'):
444
 
            try:
445
 
                kwargs[key] = stanza.get(key)
446
 
            except KeyError:
447
 
                pass
448
 
        kwargs['revision_id'] = kwargs['revision_id'].encode('utf-8')
449
 
        return MergeDirective(time=time, timezone=timezone,
450
 
                              patch_type=patch_type, patch=patch, **kwargs)
451
 
 
452
 
    def to_lines(self):
453
 
        lines = self._to_lines()
454
 
        if self.patch is not None:
455
 
            lines.extend(self.patch.splitlines(True))
456
 
        return lines
457
 
 
458
 
    @staticmethod
459
 
    def _generate_bundle(repository, revision_id, ancestor_id):
460
 
        s = StringIO()
461
 
        bundle_serializer.write_bundle(repository, revision_id,
462
 
                                       ancestor_id, s, '0.9')
463
 
        return s.getvalue()
464
 
 
465
 
    def get_merge_request(self, repository):
466
 
        """Provide data for performing a merge
467
 
 
468
 
        Returns suggested base, suggested target, and patch verification status
469
 
        """
470
 
        return None, self.revision_id, 'inapplicable'
471
 
 
472
 
 
473
 
class MergeDirective2(BaseMergeDirective):
474
 
 
475
 
    _format_string = 'Bazaar merge directive format 2 (Bazaar 0.90)'
476
 
 
477
 
    def __init__(self, revision_id, testament_sha1, time, timezone,
478
 
                 target_branch, patch=None, source_branch=None, message=None,
479
 
                 bundle=None, base_revision_id=None):
480
 
        if source_branch is None and bundle is None:
481
 
            raise errors.NoMergeSource()
482
 
        BaseMergeDirective.__init__(self, revision_id, testament_sha1, time,
483
 
            timezone, target_branch, patch, source_branch, message)
484
 
        self.bundle = bundle
485
 
        self.base_revision_id = base_revision_id
486
 
 
487
 
    def _patch_type(self):
488
 
        if self.bundle is not None:
489
 
            return 'bundle'
490
 
        elif self.patch is not None:
491
 
            return 'diff'
492
 
        else:
493
 
            return None
494
 
 
495
 
    patch_type = property(_patch_type)
496
 
 
497
 
    def clear_payload(self):
498
 
        self.patch = None
499
 
        self.bundle = None
500
 
 
501
 
    def get_raw_bundle(self):
502
 
        if self.bundle is None:
503
 
            return None
504
 
        else:
505
 
            return self.bundle.decode('base-64')
506
 
 
507
 
    @classmethod
508
 
    def _from_lines(klass, line_iter):
509
 
        stanza = rio.read_patch_stanza(line_iter)
510
 
        patch = None
511
 
        bundle = None
512
 
        try:
513
 
            start = line_iter.next()
514
 
        except StopIteration:
515
 
            pass
516
 
        else:
517
 
            if start.startswith('# Begin patch'):
518
 
                patch_lines = []
519
 
                for line in line_iter:
520
 
                    if line.startswith('# Begin bundle'):
521
 
                        start = line
522
 
                        break
523
 
                    patch_lines.append(line)
524
 
                else:
525
 
                    start = None
526
 
                patch = ''.join(patch_lines)
527
 
            if start is not None:
528
 
                if start.startswith('# Begin bundle'):
529
 
                    bundle = ''.join(line_iter)
530
 
                else:
531
 
                    raise errors.IllegalMergeDirectivePayload(start)
532
 
        time, timezone = timestamp.parse_patch_date(stanza.get('timestamp'))
533
 
        kwargs = {}
534
 
        for key in ('revision_id', 'testament_sha1', 'target_branch',
535
 
                    'source_branch', 'message', 'base_revision_id'):
536
 
            try:
537
 
                kwargs[key] = stanza.get(key)
538
 
            except KeyError:
539
 
                pass
540
 
        kwargs['revision_id'] = kwargs['revision_id'].encode('utf-8')
541
 
        kwargs['base_revision_id'] =\
542
 
            kwargs['base_revision_id'].encode('utf-8')
543
 
        return klass(time=time, timezone=timezone, patch=patch, bundle=bundle,
544
 
                     **kwargs)
545
 
 
546
 
    def to_lines(self):
547
 
        lines = self._to_lines(base_revision=True)
548
 
        if self.patch is not None:
549
 
            lines.append('# Begin patch\n')
550
 
            lines.extend(self.patch.splitlines(True))
551
 
        if self.bundle is not None:
552
 
            lines.append('# Begin bundle\n')
553
 
            lines.extend(self.bundle.splitlines(True))
554
 
        return lines
555
 
 
556
178
    @classmethod
557
179
    def from_objects(klass, repository, revision_id, time, timezone,
558
 
                 target_branch, include_patch=True, include_bundle=True,
559
 
                 local_target_branch=None, public_branch=None, message=None,
560
 
                 base_revision_id=None):
 
180
                 target_branch, patch_type='bundle',
 
181
                 local_target_branch=None, public_branch=None, message=None):
561
182
        """Generate a merge directive from various objects
562
183
 
563
184
        :param repository: The repository containing the revision
565
186
        :param time: The POSIX timestamp of the date the request was issued.
566
187
        :param timezone: The timezone of the request
567
188
        :param target_branch: The url of the branch to merge into
568
 
        :param include_patch: If true, include a preview patch
569
 
        :param include_bundle: If true, include a bundle
 
189
        :param patch_type: 'bundle', 'diff' or None, depending on the type of
 
190
            patch desired.
570
191
        :param local_target_branch: a local copy of the target branch
571
192
        :param public_branch: location of a public branch containing the target
572
193
            revision.
573
194
        :param message: Message to use when committing the merge
574
195
        :return: The merge directive
575
196
 
576
 
        The public branch is always used if supplied.  If no bundle is
577
 
        included, the public branch must be supplied, and will be verified.
 
197
        The public branch is always used if supplied.  If the patch_type is
 
198
        not 'bundle', the public branch must be supplied, and will be verified.
578
199
 
579
200
        If the message is not supplied, the message from revision_id will be
580
201
        used for the commit.
581
202
        """
582
 
        locked = []
583
 
        try:
584
 
            repository.lock_write()
585
 
            locked.append(repository)
586
 
            t_revision_id = revision_id
587
 
            if revision_id == 'null:':
588
 
                t_revision_id = None
589
 
            t = testament.StrictTestament3.from_revision(repository,
590
 
                t_revision_id)
591
 
            submit_branch = _mod_branch.Branch.open(target_branch)
592
 
            submit_branch.lock_read()
593
 
            locked.append(submit_branch)
594
 
            if submit_branch.get_public_branch() is not None:
595
 
                target_branch = submit_branch.get_public_branch()
 
203
        t = testament.StrictTestament3.from_revision(repository, revision_id)
 
204
        submit_branch = _mod_branch.Branch.open(target_branch)
 
205
        if submit_branch.get_public_branch() is not None:
 
206
            target_branch = submit_branch.get_public_branch()
 
207
        if patch_type is None:
 
208
            patch = None
 
209
        else:
596
210
            submit_revision_id = submit_branch.last_revision()
597
 
            submit_revision_id = _mod_revision.ensure_null(submit_revision_id)
598
 
            graph = repository.get_graph(submit_branch.repository)
599
 
            ancestor_id = graph.find_unique_lca(revision_id,
600
 
                                                submit_revision_id)
601
 
            if base_revision_id is None:
602
 
                base_revision_id = ancestor_id
603
 
            if (include_patch, include_bundle) != (False, False):
604
 
                repository.fetch(submit_branch.repository, submit_revision_id)
605
 
            if include_patch:
 
211
            repository.fetch(submit_branch.repository, submit_revision_id)
 
212
            ancestor_id = _mod_revision.common_ancestor(revision_id,
 
213
                                                        submit_revision_id,
 
214
                                                        repository)
 
215
            type_handler = {'bundle': klass._generate_bundle,
 
216
                            'diff': klass._generate_diff,
 
217
                            None: lambda x, y, z: None }
 
218
            patch = type_handler[patch_type](repository, revision_id,
 
219
                                             ancestor_id)
 
220
            if patch_type == 'bundle':
 
221
                s = StringIO()
 
222
                bundle_serializer.write_bundle(repository, revision_id,
 
223
                                               ancestor_id, s)
 
224
                patch = s.getvalue()
 
225
            elif patch_type == 'diff':
606
226
                patch = klass._generate_diff(repository, revision_id,
607
 
                                             base_revision_id)
608
 
            else:
609
 
                patch = None
610
 
 
611
 
            if include_bundle:
612
 
                bundle = klass._generate_bundle(repository, revision_id,
613
 
                    ancestor_id).encode('base-64')
614
 
            else:
615
 
                bundle = None
616
 
 
617
 
            if public_branch is not None and not include_bundle:
 
227
                                             ancestor_id)
 
228
 
 
229
            if public_branch is not None and patch_type != 'bundle':
618
230
                public_branch_obj = _mod_branch.Branch.open(public_branch)
619
 
                public_branch_obj.lock_read()
620
 
                locked.append(public_branch_obj)
621
 
                if not public_branch_obj.repository.has_revision(
622
 
                    revision_id):
 
231
                if not public_branch_obj.repository.has_revision(revision_id):
623
232
                    raise errors.PublicBranchOutOfDate(public_branch,
624
233
                                                       revision_id)
625
 
            testament_sha1 = t.as_sha1()
626
 
        finally:
627
 
            for entry in reversed(locked):
628
 
                entry.unlock()
629
 
        return klass(revision_id, testament_sha1, time, timezone,
630
 
            target_branch, patch, public_branch, message, bundle,
631
 
            base_revision_id)
632
 
 
633
 
    def _verify_patch(self, repository):
634
 
        calculated_patch = self._generate_diff(repository, self.revision_id,
635
 
                                               self.base_revision_id)
636
 
        # Convert line-endings to UNIX
637
 
        stored_patch = re.sub('\r\n?', '\n', self.patch)
638
 
        calculated_patch = re.sub('\r\n?', '\n', calculated_patch)
639
 
        # Strip trailing whitespace
640
 
        calculated_patch = re.sub(' *\n', '\n', calculated_patch)
641
 
        stored_patch = re.sub(' *\n', '\n', stored_patch)
642
 
        return (calculated_patch == stored_patch)
643
 
 
644
 
    def get_merge_request(self, repository):
645
 
        """Provide data for performing a merge
646
 
 
647
 
        Returns suggested base, suggested target, and patch verification status
648
 
        """
649
 
        verified = self._maybe_verify(repository)
650
 
        return self.base_revision_id, self.revision_id, verified
651
 
 
652
 
    def _maybe_verify(self, repository):
653
 
        if self.patch is not None:
654
 
            if self._verify_patch(repository):
655
 
                return 'verified'
656
 
            else:
657
 
                return 'failed'
658
 
        else:
659
 
            return 'inapplicable'
660
 
 
661
 
 
662
 
class MergeDirectiveFormatRegistry(registry.Registry):
663
 
 
664
 
    def register(self, directive, format_string=None):
665
 
        if format_string is None:
666
 
            format_string = directive._format_string
667
 
        registry.Registry.register(self, format_string, directive)
668
 
 
669
 
 
670
 
_format_registry = MergeDirectiveFormatRegistry()
671
 
_format_registry.register(MergeDirective)
672
 
_format_registry.register(MergeDirective2)
673
 
# 0.19 never existed.  It got renamed to 0.90.  But by that point, there were
674
 
# already merge directives in the wild that used 0.19. Registering with the old
675
 
# format string to retain compatibility with those merge directives.
676
 
_format_registry.register(MergeDirective2,
677
 
                          'Bazaar merge directive format 2 (Bazaar 0.19)')
 
234
 
 
235
        return MergeDirective(revision_id, t.as_sha1(), time, timezone,
 
236
                              target_branch, patch, patch_type, public_branch,
 
237
                              message)
 
238
 
 
239
    @staticmethod
 
240
    def _generate_diff(repository, revision_id, ancestor_id):
 
241
        tree_1 = repository.revision_tree(ancestor_id)
 
242
        tree_2 = repository.revision_tree(revision_id)
 
243
        s = StringIO()
 
244
        diff.show_diff_trees(tree_1, tree_2, s, old_label='', new_label='')
 
245
        return s.getvalue()
 
246
 
 
247
    @staticmethod
 
248
    def _generate_bundle(repository, revision_id, ancestor_id):
 
249
        s = StringIO()
 
250
        bundle_serializer.write_bundle(repository, revision_id,
 
251
                                       ancestor_id, s)
 
252
        return s.getvalue()