~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/merge3.py

  • Committer: Jelmer Vernooij
  • Date: 2012-03-30 18:16:07 UTC
  • mto: This revision was merged to the branch mainline in revision 6535.
  • Revision ID: jelmer@samba.org-20120330181607-xr5s4v7xyr1y0ob6
Add bzrlib.branchfmt.

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# Copyright (C) 2005-2010 Canonical Ltd
 
2
#
 
3
# This program is free software; you can redistribute it and/or modify
 
4
# it under the terms of the GNU General Public License as published by
 
5
# the Free Software Foundation; either version 2 of the License, or
 
6
# (at your option) any later version.
 
7
#
 
8
# This program is distributed in the hope that it will be useful,
 
9
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
10
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
11
# GNU General Public License for more details.
 
12
#
 
13
# You should have received a copy of the GNU General Public License
 
14
# along with this program; if not, write to the Free Software
 
15
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 
16
 
 
17
from __future__ import absolute_import
 
18
 
 
19
# mbp: "you know that thing where cvs gives you conflict markers?"
 
20
# s: "i hate that."
 
21
 
 
22
from bzrlib import (
 
23
    errors,
 
24
    patiencediff,
 
25
    textfile,
 
26
    )
 
27
 
 
28
 
 
29
def intersect(ra, rb):
 
30
    """Given two ranges return the range where they intersect or None.
 
31
 
 
32
    >>> intersect((0, 10), (0, 6))
 
33
    (0, 6)
 
34
    >>> intersect((0, 10), (5, 15))
 
35
    (5, 10)
 
36
    >>> intersect((0, 10), (10, 15))
 
37
    >>> intersect((0, 9), (10, 15))
 
38
    >>> intersect((0, 9), (7, 15))
 
39
    (7, 9)
 
40
    """
 
41
    # preconditions: (ra[0] <= ra[1]) and (rb[0] <= rb[1])
 
42
 
 
43
    sa = max(ra[0], rb[0])
 
44
    sb = min(ra[1], rb[1])
 
45
    if sa < sb:
 
46
        return sa, sb
 
47
    else:
 
48
        return None
 
49
 
 
50
 
 
51
def compare_range(a, astart, aend, b, bstart, bend):
 
52
    """Compare a[astart:aend] == b[bstart:bend], without slicing.
 
53
    """
 
54
    if (aend-astart) != (bend-bstart):
 
55
        return False
 
56
    for ia, ib in zip(xrange(astart, aend), xrange(bstart, bend)):
 
57
        if a[ia] != b[ib]:
 
58
            return False
 
59
    else:
 
60
        return True
 
61
 
 
62
 
 
63
 
 
64
 
 
65
class Merge3(object):
 
66
    """3-way merge of texts.
 
67
 
 
68
    Given BASE, OTHER, THIS, tries to produce a combined text
 
69
    incorporating the changes from both BASE->OTHER and BASE->THIS.
 
70
    All three will typically be sequences of lines."""
 
71
 
 
72
    def __init__(self, base, a, b, is_cherrypick=False, allow_objects=False):
 
73
        """Constructor.
 
74
 
 
75
        :param base: lines in BASE
 
76
        :param a: lines in A
 
77
        :param b: lines in B
 
78
        :param is_cherrypick: flag indicating if this merge is a cherrypick.
 
79
            When cherrypicking b => a, matches with b and base do not conflict.
 
80
        :param allow_objects: if True, do not require that base, a and b are
 
81
            plain Python strs.  Also prevents BinaryFile from being raised.
 
82
            Lines can be any sequence of comparable and hashable Python
 
83
            objects.
 
84
        """
 
85
        if not allow_objects:
 
86
            textfile.check_text_lines(base)
 
87
            textfile.check_text_lines(a)
 
88
            textfile.check_text_lines(b)
 
89
        self.base = base
 
90
        self.a = a
 
91
        self.b = b
 
92
        self.is_cherrypick = is_cherrypick
 
93
 
 
94
    def merge_lines(self,
 
95
                    name_a=None,
 
96
                    name_b=None,
 
97
                    name_base=None,
 
98
                    start_marker='<<<<<<<',
 
99
                    mid_marker='=======',
 
100
                    end_marker='>>>>>>>',
 
101
                    base_marker=None,
 
102
                    reprocess=False):
 
103
        """Return merge in cvs-like form.
 
104
        """
 
105
        newline = '\n'
 
106
        if len(self.a) > 0:
 
107
            if self.a[0].endswith('\r\n'):
 
108
                newline = '\r\n'
 
109
            elif self.a[0].endswith('\r'):
 
110
                newline = '\r'
 
111
        if base_marker and reprocess:
 
112
            raise errors.CantReprocessAndShowBase()
 
113
        if name_a:
 
114
            start_marker = start_marker + ' ' + name_a
 
115
        if name_b:
 
116
            end_marker = end_marker + ' ' + name_b
 
117
        if name_base and base_marker:
 
118
            base_marker = base_marker + ' ' + name_base
 
119
        merge_regions = self.merge_regions()
 
120
        if reprocess is True:
 
121
            merge_regions = self.reprocess_merge_regions(merge_regions)
 
122
        for t in merge_regions:
 
123
            what = t[0]
 
124
            if what == 'unchanged':
 
125
                for i in range(t[1], t[2]):
 
126
                    yield self.base[i]
 
127
            elif what == 'a' or what == 'same':
 
128
                for i in range(t[1], t[2]):
 
129
                    yield self.a[i]
 
130
            elif what == 'b':
 
131
                for i in range(t[1], t[2]):
 
132
                    yield self.b[i]
 
133
            elif what == 'conflict':
 
134
                yield start_marker + newline
 
135
                for i in range(t[3], t[4]):
 
136
                    yield self.a[i]
 
137
                if base_marker is not None:
 
138
                    yield base_marker + newline
 
139
                    for i in range(t[1], t[2]):
 
140
                        yield self.base[i]
 
141
                yield mid_marker + newline
 
142
                for i in range(t[5], t[6]):
 
143
                    yield self.b[i]
 
144
                yield end_marker + newline
 
145
            else:
 
146
                raise ValueError(what)
 
147
 
 
148
    def merge_annotated(self):
 
149
        """Return merge with conflicts, showing origin of lines.
 
150
 
 
151
        Most useful for debugging merge.
 
152
        """
 
153
        for t in self.merge_regions():
 
154
            what = t[0]
 
155
            if what == 'unchanged':
 
156
                for i in range(t[1], t[2]):
 
157
                    yield 'u | ' + self.base[i]
 
158
            elif what == 'a' or what == 'same':
 
159
                for i in range(t[1], t[2]):
 
160
                    yield what[0] + ' | ' + self.a[i]
 
161
            elif what == 'b':
 
162
                for i in range(t[1], t[2]):
 
163
                    yield 'b | ' + self.b[i]
 
164
            elif what == 'conflict':
 
165
                yield '<<<<\n'
 
166
                for i in range(t[3], t[4]):
 
167
                    yield 'A | ' + self.a[i]
 
168
                yield '----\n'
 
169
                for i in range(t[5], t[6]):
 
170
                    yield 'B | ' + self.b[i]
 
171
                yield '>>>>\n'
 
172
            else:
 
173
                raise ValueError(what)
 
174
 
 
175
    def merge_groups(self):
 
176
        """Yield sequence of line groups.  Each one is a tuple:
 
177
 
 
178
        'unchanged', lines
 
179
             Lines unchanged from base
 
180
 
 
181
        'a', lines
 
182
             Lines taken from a
 
183
 
 
184
        'same', lines
 
185
             Lines taken from a (and equal to b)
 
186
 
 
187
        'b', lines
 
188
             Lines taken from b
 
189
 
 
190
        'conflict', base_lines, a_lines, b_lines
 
191
             Lines from base were changed to either a or b and conflict.
 
192
        """
 
193
        for t in self.merge_regions():
 
194
            what = t[0]
 
195
            if what == 'unchanged':
 
196
                yield what, self.base[t[1]:t[2]]
 
197
            elif what == 'a' or what == 'same':
 
198
                yield what, self.a[t[1]:t[2]]
 
199
            elif what == 'b':
 
200
                yield what, self.b[t[1]:t[2]]
 
201
            elif what == 'conflict':
 
202
                yield (what,
 
203
                       self.base[t[1]:t[2]],
 
204
                       self.a[t[3]:t[4]],
 
205
                       self.b[t[5]:t[6]])
 
206
            else:
 
207
                raise ValueError(what)
 
208
 
 
209
    def merge_regions(self):
 
210
        """Return sequences of matching and conflicting regions.
 
211
 
 
212
        This returns tuples, where the first value says what kind we
 
213
        have:
 
214
 
 
215
        'unchanged', start, end
 
216
             Take a region of base[start:end]
 
217
 
 
218
        'same', astart, aend
 
219
             b and a are different from base but give the same result
 
220
 
 
221
        'a', start, end
 
222
             Non-clashing insertion from a[start:end]
 
223
 
 
224
        Method is as follows:
 
225
 
 
226
        The two sequences align only on regions which match the base
 
227
        and both descendents.  These are found by doing a two-way diff
 
228
        of each one against the base, and then finding the
 
229
        intersections between those regions.  These "sync regions"
 
230
        are by definition unchanged in both and easily dealt with.
 
231
 
 
232
        The regions in between can be in any of three cases:
 
233
        conflicted, or changed on only one side.
 
234
        """
 
235
 
 
236
        # section a[0:ia] has been disposed of, etc
 
237
        iz = ia = ib = 0
 
238
 
 
239
        for zmatch, zend, amatch, aend, bmatch, bend in self.find_sync_regions():
 
240
            matchlen = zend - zmatch
 
241
            # invariants:
 
242
            #   matchlen >= 0
 
243
            #   matchlen == (aend - amatch)
 
244
            #   matchlen == (bend - bmatch)
 
245
            len_a = amatch - ia
 
246
            len_b = bmatch - ib
 
247
            len_base = zmatch - iz
 
248
            # invariants:
 
249
            # assert len_a >= 0
 
250
            # assert len_b >= 0
 
251
            # assert len_base >= 0
 
252
 
 
253
            #print 'unmatched a=%d, b=%d' % (len_a, len_b)
 
254
 
 
255
            if len_a or len_b:
 
256
                # try to avoid actually slicing the lists
 
257
                same = compare_range(self.a, ia, amatch,
 
258
                                     self.b, ib, bmatch)
 
259
 
 
260
                if same:
 
261
                    yield 'same', ia, amatch
 
262
                else:
 
263
                    equal_a = compare_range(self.a, ia, amatch,
 
264
                                            self.base, iz, zmatch)
 
265
                    equal_b = compare_range(self.b, ib, bmatch,
 
266
                                            self.base, iz, zmatch)
 
267
                    if equal_a and not equal_b:
 
268
                        yield 'b', ib, bmatch
 
269
                    elif equal_b and not equal_a:
 
270
                        yield 'a', ia, amatch
 
271
                    elif not equal_a and not equal_b:
 
272
                        if self.is_cherrypick:
 
273
                            for node in self._refine_cherrypick_conflict(
 
274
                                                    iz, zmatch, ia, amatch,
 
275
                                                    ib, bmatch):
 
276
                                yield node
 
277
                        else:
 
278
                            yield 'conflict', iz, zmatch, ia, amatch, ib, bmatch
 
279
                    else:
 
280
                        raise AssertionError("can't handle a=b=base but unmatched")
 
281
 
 
282
                ia = amatch
 
283
                ib = bmatch
 
284
            iz = zmatch
 
285
 
 
286
            # if the same part of the base was deleted on both sides
 
287
            # that's OK, we can just skip it.
 
288
 
 
289
            if matchlen > 0:
 
290
                # invariants:
 
291
                # assert ia == amatch
 
292
                # assert ib == bmatch
 
293
                # assert iz == zmatch
 
294
 
 
295
                yield 'unchanged', zmatch, zend
 
296
                iz = zend
 
297
                ia = aend
 
298
                ib = bend
 
299
 
 
300
    def _refine_cherrypick_conflict(self, zstart, zend, astart, aend, bstart, bend):
 
301
        """When cherrypicking b => a, ignore matches with b and base."""
 
302
        # Do not emit regions which match, only regions which do not match
 
303
        matches = patiencediff.PatienceSequenceMatcher(None,
 
304
            self.base[zstart:zend], self.b[bstart:bend]).get_matching_blocks()
 
305
        last_base_idx = 0
 
306
        last_b_idx = 0
 
307
        last_b_idx = 0
 
308
        yielded_a = False
 
309
        for base_idx, b_idx, match_len in matches:
 
310
            conflict_z_len = base_idx - last_base_idx
 
311
            conflict_b_len = b_idx - last_b_idx
 
312
            if conflict_b_len == 0: # There are no lines in b which conflict,
 
313
                                    # so skip it
 
314
                pass
 
315
            else:
 
316
                if yielded_a:
 
317
                    yield ('conflict',
 
318
                           zstart + last_base_idx, zstart + base_idx,
 
319
                           aend, aend, bstart + last_b_idx, bstart + b_idx)
 
320
                else:
 
321
                    # The first conflict gets the a-range
 
322
                    yielded_a = True
 
323
                    yield ('conflict', zstart + last_base_idx, zstart +
 
324
                    base_idx,
 
325
                           astart, aend, bstart + last_b_idx, bstart + b_idx)
 
326
            last_base_idx = base_idx + match_len
 
327
            last_b_idx = b_idx + match_len
 
328
        if last_base_idx != zend - zstart or last_b_idx != bend - bstart:
 
329
            if yielded_a:
 
330
                yield ('conflict', zstart + last_base_idx, zstart + base_idx,
 
331
                       aend, aend, bstart + last_b_idx, bstart + b_idx)
 
332
            else:
 
333
                # The first conflict gets the a-range
 
334
                yielded_a = True
 
335
                yield ('conflict', zstart + last_base_idx, zstart + base_idx,
 
336
                       astart, aend, bstart + last_b_idx, bstart + b_idx)
 
337
        if not yielded_a:
 
338
            yield ('conflict', zstart, zend, astart, aend, bstart, bend)
 
339
 
 
340
    def reprocess_merge_regions(self, merge_regions):
 
341
        """Where there are conflict regions, remove the agreed lines.
 
342
 
 
343
        Lines where both A and B have made the same changes are
 
344
        eliminated.
 
345
        """
 
346
        for region in merge_regions:
 
347
            if region[0] != "conflict":
 
348
                yield region
 
349
                continue
 
350
            type, iz, zmatch, ia, amatch, ib, bmatch = region
 
351
            a_region = self.a[ia:amatch]
 
352
            b_region = self.b[ib:bmatch]
 
353
            matches = patiencediff.PatienceSequenceMatcher(
 
354
                    None, a_region, b_region).get_matching_blocks()
 
355
            next_a = ia
 
356
            next_b = ib
 
357
            for region_ia, region_ib, region_len in matches[:-1]:
 
358
                region_ia += ia
 
359
                region_ib += ib
 
360
                reg = self.mismatch_region(next_a, region_ia, next_b,
 
361
                                           region_ib)
 
362
                if reg is not None:
 
363
                    yield reg
 
364
                yield 'same', region_ia, region_len+region_ia
 
365
                next_a = region_ia + region_len
 
366
                next_b = region_ib + region_len
 
367
            reg = self.mismatch_region(next_a, amatch, next_b, bmatch)
 
368
            if reg is not None:
 
369
                yield reg
 
370
 
 
371
    @staticmethod
 
372
    def mismatch_region(next_a, region_ia,  next_b, region_ib):
 
373
        if next_a < region_ia or next_b < region_ib:
 
374
            return 'conflict', None, None, next_a, region_ia, next_b, region_ib
 
375
 
 
376
    def find_sync_regions(self):
 
377
        """Return a list of sync regions, where both descendents match the base.
 
378
 
 
379
        Generates a list of (base1, base2, a1, a2, b1, b2).  There is
 
380
        always a zero-length sync region at the end of all the files.
 
381
        """
 
382
 
 
383
        ia = ib = 0
 
384
        amatches = patiencediff.PatienceSequenceMatcher(
 
385
                None, self.base, self.a).get_matching_blocks()
 
386
        bmatches = patiencediff.PatienceSequenceMatcher(
 
387
                None, self.base, self.b).get_matching_blocks()
 
388
        len_a = len(amatches)
 
389
        len_b = len(bmatches)
 
390
 
 
391
        sl = []
 
392
 
 
393
        while ia < len_a and ib < len_b:
 
394
            abase, amatch, alen = amatches[ia]
 
395
            bbase, bmatch, blen = bmatches[ib]
 
396
 
 
397
            # there is an unconflicted block at i; how long does it
 
398
            # extend?  until whichever one ends earlier.
 
399
            i = intersect((abase, abase+alen), (bbase, bbase+blen))
 
400
            if i:
 
401
                intbase = i[0]
 
402
                intend = i[1]
 
403
                intlen = intend - intbase
 
404
 
 
405
                # found a match of base[i[0], i[1]]; this may be less than
 
406
                # the region that matches in either one
 
407
                # assert intlen <= alen
 
408
                # assert intlen <= blen
 
409
                # assert abase <= intbase
 
410
                # assert bbase <= intbase
 
411
 
 
412
                asub = amatch + (intbase - abase)
 
413
                bsub = bmatch + (intbase - bbase)
 
414
                aend = asub + intlen
 
415
                bend = bsub + intlen
 
416
 
 
417
                # assert self.base[intbase:intend] == self.a[asub:aend], \
 
418
                #       (self.base[intbase:intend], self.a[asub:aend])
 
419
                # assert self.base[intbase:intend] == self.b[bsub:bend]
 
420
 
 
421
                sl.append((intbase, intend,
 
422
                           asub, aend,
 
423
                           bsub, bend))
 
424
            # advance whichever one ends first in the base text
 
425
            if (abase + alen) < (bbase + blen):
 
426
                ia += 1
 
427
            else:
 
428
                ib += 1
 
429
 
 
430
        intbase = len(self.base)
 
431
        abase = len(self.a)
 
432
        bbase = len(self.b)
 
433
        sl.append((intbase, intbase, abase, abase, bbase, bbase))
 
434
 
 
435
        return sl
 
436
 
 
437
    def find_unconflicted(self):
 
438
        """Return a list of ranges in base that are not conflicted."""
 
439
        am = patiencediff.PatienceSequenceMatcher(
 
440
                None, self.base, self.a).get_matching_blocks()
 
441
        bm = patiencediff.PatienceSequenceMatcher(
 
442
                None, self.base, self.b).get_matching_blocks()
 
443
 
 
444
        unc = []
 
445
 
 
446
        while am and bm:
 
447
            # there is an unconflicted block at i; how long does it
 
448
            # extend?  until whichever one ends earlier.
 
449
            a1 = am[0][0]
 
450
            a2 = a1 + am[0][2]
 
451
            b1 = bm[0][0]
 
452
            b2 = b1 + bm[0][2]
 
453
            i = intersect((a1, a2), (b1, b2))
 
454
            if i:
 
455
                unc.append(i)
 
456
 
 
457
            if a2 < b2:
 
458
                del am[0]
 
459
            else:
 
460
                del bm[0]
 
461
 
 
462
        return unc
 
463
 
 
464
 
 
465
def main(argv):
 
466
    # as for diff3 and meld the syntax is "MINE BASE OTHER"
 
467
    a = file(argv[1], 'rt').readlines()
 
468
    base = file(argv[2], 'rt').readlines()
 
469
    b = file(argv[3], 'rt').readlines()
 
470
 
 
471
    m3 = Merge3(base, a, b)
 
472
 
 
473
    #for sr in m3.find_sync_regions():
 
474
    #    print sr
 
475
 
 
476
    # sys.stdout.writelines(m3.merge_lines(name_a=argv[1], name_b=argv[3]))
 
477
    sys.stdout.writelines(m3.merge_annotated())
 
478
 
 
479
 
 
480
if __name__ == '__main__':
 
481
    import sys
 
482
    sys.exit(main(sys.argv))