~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/merge3.py

  • Committer: Martin Packman
  • Date: 2011-12-08 19:00:14 UTC
  • mto: This revision was merged to the branch mainline in revision 6359.
  • Revision ID: martin.packman@canonical.com-20111208190014-mi8jm6v7jygmhb0r
Use --include-duplicates for make update-pot which already combines multiple msgid strings prettily

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# Copyright (C) 2005-2010 Canonical Ltd
 
2
#
 
3
# This program is free software; you can redistribute it and/or modify
 
4
# it under the terms of the GNU General Public License as published by
 
5
# the Free Software Foundation; either version 2 of the License, or
 
6
# (at your option) any later version.
 
7
#
 
8
# This program is distributed in the hope that it will be useful,
 
9
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
10
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
11
# GNU General Public License for more details.
 
12
#
 
13
# You should have received a copy of the GNU General Public License
 
14
# along with this program; if not, write to the Free Software
 
15
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 
16
 
 
17
 
 
18
# mbp: "you know that thing where cvs gives you conflict markers?"
 
19
# s: "i hate that."
 
20
 
 
21
from bzrlib import (
 
22
    errors,
 
23
    patiencediff,
 
24
    textfile,
 
25
    )
 
26
 
 
27
 
 
28
def intersect(ra, rb):
 
29
    """Given two ranges return the range where they intersect or None.
 
30
 
 
31
    >>> intersect((0, 10), (0, 6))
 
32
    (0, 6)
 
33
    >>> intersect((0, 10), (5, 15))
 
34
    (5, 10)
 
35
    >>> intersect((0, 10), (10, 15))
 
36
    >>> intersect((0, 9), (10, 15))
 
37
    >>> intersect((0, 9), (7, 15))
 
38
    (7, 9)
 
39
    """
 
40
    # preconditions: (ra[0] <= ra[1]) and (rb[0] <= rb[1])
 
41
 
 
42
    sa = max(ra[0], rb[0])
 
43
    sb = min(ra[1], rb[1])
 
44
    if sa < sb:
 
45
        return sa, sb
 
46
    else:
 
47
        return None
 
48
 
 
49
 
 
50
def compare_range(a, astart, aend, b, bstart, bend):
 
51
    """Compare a[astart:aend] == b[bstart:bend], without slicing.
 
52
    """
 
53
    if (aend-astart) != (bend-bstart):
 
54
        return False
 
55
    for ia, ib in zip(xrange(astart, aend), xrange(bstart, bend)):
 
56
        if a[ia] != b[ib]:
 
57
            return False
 
58
    else:
 
59
        return True
 
60
 
 
61
 
 
62
 
 
63
 
 
64
class Merge3(object):
 
65
    """3-way merge of texts.
 
66
 
 
67
    Given BASE, OTHER, THIS, tries to produce a combined text
 
68
    incorporating the changes from both BASE->OTHER and BASE->THIS.
 
69
    All three will typically be sequences of lines."""
 
70
 
 
71
    def __init__(self, base, a, b, is_cherrypick=False, allow_objects=False):
 
72
        """Constructor.
 
73
 
 
74
        :param base: lines in BASE
 
75
        :param a: lines in A
 
76
        :param b: lines in B
 
77
        :param is_cherrypick: flag indicating if this merge is a cherrypick.
 
78
            When cherrypicking b => a, matches with b and base do not conflict.
 
79
        :param allow_objects: if True, do not require that base, a and b are
 
80
            plain Python strs.  Also prevents BinaryFile from being raised.
 
81
            Lines can be any sequence of comparable and hashable Python
 
82
            objects.
 
83
        """
 
84
        if not allow_objects:
 
85
            textfile.check_text_lines(base)
 
86
            textfile.check_text_lines(a)
 
87
            textfile.check_text_lines(b)
 
88
        self.base = base
 
89
        self.a = a
 
90
        self.b = b
 
91
        self.is_cherrypick = is_cherrypick
 
92
 
 
93
    def merge_lines(self,
 
94
                    name_a=None,
 
95
                    name_b=None,
 
96
                    name_base=None,
 
97
                    start_marker='<<<<<<<',
 
98
                    mid_marker='=======',
 
99
                    end_marker='>>>>>>>',
 
100
                    base_marker=None,
 
101
                    reprocess=False):
 
102
        """Return merge in cvs-like form.
 
103
        """
 
104
        newline = '\n'
 
105
        if len(self.a) > 0:
 
106
            if self.a[0].endswith('\r\n'):
 
107
                newline = '\r\n'
 
108
            elif self.a[0].endswith('\r'):
 
109
                newline = '\r'
 
110
        if base_marker and reprocess:
 
111
            raise errors.CantReprocessAndShowBase()
 
112
        if name_a:
 
113
            start_marker = start_marker + ' ' + name_a
 
114
        if name_b:
 
115
            end_marker = end_marker + ' ' + name_b
 
116
        if name_base and base_marker:
 
117
            base_marker = base_marker + ' ' + name_base
 
118
        merge_regions = self.merge_regions()
 
119
        if reprocess is True:
 
120
            merge_regions = self.reprocess_merge_regions(merge_regions)
 
121
        for t in merge_regions:
 
122
            what = t[0]
 
123
            if what == 'unchanged':
 
124
                for i in range(t[1], t[2]):
 
125
                    yield self.base[i]
 
126
            elif what == 'a' or what == 'same':
 
127
                for i in range(t[1], t[2]):
 
128
                    yield self.a[i]
 
129
            elif what == 'b':
 
130
                for i in range(t[1], t[2]):
 
131
                    yield self.b[i]
 
132
            elif what == 'conflict':
 
133
                yield start_marker + newline
 
134
                for i in range(t[3], t[4]):
 
135
                    yield self.a[i]
 
136
                if base_marker is not None:
 
137
                    yield base_marker + newline
 
138
                    for i in range(t[1], t[2]):
 
139
                        yield self.base[i]
 
140
                yield mid_marker + newline
 
141
                for i in range(t[5], t[6]):
 
142
                    yield self.b[i]
 
143
                yield end_marker + newline
 
144
            else:
 
145
                raise ValueError(what)
 
146
 
 
147
    def merge_annotated(self):
 
148
        """Return merge with conflicts, showing origin of lines.
 
149
 
 
150
        Most useful for debugging merge.
 
151
        """
 
152
        for t in self.merge_regions():
 
153
            what = t[0]
 
154
            if what == 'unchanged':
 
155
                for i in range(t[1], t[2]):
 
156
                    yield 'u | ' + self.base[i]
 
157
            elif what == 'a' or what == 'same':
 
158
                for i in range(t[1], t[2]):
 
159
                    yield what[0] + ' | ' + self.a[i]
 
160
            elif what == 'b':
 
161
                for i in range(t[1], t[2]):
 
162
                    yield 'b | ' + self.b[i]
 
163
            elif what == 'conflict':
 
164
                yield '<<<<\n'
 
165
                for i in range(t[3], t[4]):
 
166
                    yield 'A | ' + self.a[i]
 
167
                yield '----\n'
 
168
                for i in range(t[5], t[6]):
 
169
                    yield 'B | ' + self.b[i]
 
170
                yield '>>>>\n'
 
171
            else:
 
172
                raise ValueError(what)
 
173
 
 
174
    def merge_groups(self):
 
175
        """Yield sequence of line groups.  Each one is a tuple:
 
176
 
 
177
        'unchanged', lines
 
178
             Lines unchanged from base
 
179
 
 
180
        'a', lines
 
181
             Lines taken from a
 
182
 
 
183
        'same', lines
 
184
             Lines taken from a (and equal to b)
 
185
 
 
186
        'b', lines
 
187
             Lines taken from b
 
188
 
 
189
        'conflict', base_lines, a_lines, b_lines
 
190
             Lines from base were changed to either a or b and conflict.
 
191
        """
 
192
        for t in self.merge_regions():
 
193
            what = t[0]
 
194
            if what == 'unchanged':
 
195
                yield what, self.base[t[1]:t[2]]
 
196
            elif what == 'a' or what == 'same':
 
197
                yield what, self.a[t[1]:t[2]]
 
198
            elif what == 'b':
 
199
                yield what, self.b[t[1]:t[2]]
 
200
            elif what == 'conflict':
 
201
                yield (what,
 
202
                       self.base[t[1]:t[2]],
 
203
                       self.a[t[3]:t[4]],
 
204
                       self.b[t[5]:t[6]])
 
205
            else:
 
206
                raise ValueError(what)
 
207
 
 
208
    def merge_regions(self):
 
209
        """Return sequences of matching and conflicting regions.
 
210
 
 
211
        This returns tuples, where the first value says what kind we
 
212
        have:
 
213
 
 
214
        'unchanged', start, end
 
215
             Take a region of base[start:end]
 
216
 
 
217
        'same', astart, aend
 
218
             b and a are different from base but give the same result
 
219
 
 
220
        'a', start, end
 
221
             Non-clashing insertion from a[start:end]
 
222
 
 
223
        Method is as follows:
 
224
 
 
225
        The two sequences align only on regions which match the base
 
226
        and both descendents.  These are found by doing a two-way diff
 
227
        of each one against the base, and then finding the
 
228
        intersections between those regions.  These "sync regions"
 
229
        are by definition unchanged in both and easily dealt with.
 
230
 
 
231
        The regions in between can be in any of three cases:
 
232
        conflicted, or changed on only one side.
 
233
        """
 
234
 
 
235
        # section a[0:ia] has been disposed of, etc
 
236
        iz = ia = ib = 0
 
237
 
 
238
        for zmatch, zend, amatch, aend, bmatch, bend in self.find_sync_regions():
 
239
            matchlen = zend - zmatch
 
240
            # invariants:
 
241
            #   matchlen >= 0
 
242
            #   matchlen == (aend - amatch)
 
243
            #   matchlen == (bend - bmatch)
 
244
            len_a = amatch - ia
 
245
            len_b = bmatch - ib
 
246
            len_base = zmatch - iz
 
247
            # invariants:
 
248
            # assert len_a >= 0
 
249
            # assert len_b >= 0
 
250
            # assert len_base >= 0
 
251
 
 
252
            #print 'unmatched a=%d, b=%d' % (len_a, len_b)
 
253
 
 
254
            if len_a or len_b:
 
255
                # try to avoid actually slicing the lists
 
256
                same = compare_range(self.a, ia, amatch,
 
257
                                     self.b, ib, bmatch)
 
258
 
 
259
                if same:
 
260
                    yield 'same', ia, amatch
 
261
                else:
 
262
                    equal_a = compare_range(self.a, ia, amatch,
 
263
                                            self.base, iz, zmatch)
 
264
                    equal_b = compare_range(self.b, ib, bmatch,
 
265
                                            self.base, iz, zmatch)
 
266
                    if equal_a and not equal_b:
 
267
                        yield 'b', ib, bmatch
 
268
                    elif equal_b and not equal_a:
 
269
                        yield 'a', ia, amatch
 
270
                    elif not equal_a and not equal_b:
 
271
                        if self.is_cherrypick:
 
272
                            for node in self._refine_cherrypick_conflict(
 
273
                                                    iz, zmatch, ia, amatch,
 
274
                                                    ib, bmatch):
 
275
                                yield node
 
276
                        else:
 
277
                            yield 'conflict', iz, zmatch, ia, amatch, ib, bmatch
 
278
                    else:
 
279
                        raise AssertionError("can't handle a=b=base but unmatched")
 
280
 
 
281
                ia = amatch
 
282
                ib = bmatch
 
283
            iz = zmatch
 
284
 
 
285
            # if the same part of the base was deleted on both sides
 
286
            # that's OK, we can just skip it.
 
287
 
 
288
            if matchlen > 0:
 
289
                # invariants:
 
290
                # assert ia == amatch
 
291
                # assert ib == bmatch
 
292
                # assert iz == zmatch
 
293
 
 
294
                yield 'unchanged', zmatch, zend
 
295
                iz = zend
 
296
                ia = aend
 
297
                ib = bend
 
298
 
 
299
    def _refine_cherrypick_conflict(self, zstart, zend, astart, aend, bstart, bend):
 
300
        """When cherrypicking b => a, ignore matches with b and base."""
 
301
        # Do not emit regions which match, only regions which do not match
 
302
        matches = patiencediff.PatienceSequenceMatcher(None,
 
303
            self.base[zstart:zend], self.b[bstart:bend]).get_matching_blocks()
 
304
        last_base_idx = 0
 
305
        last_b_idx = 0
 
306
        last_b_idx = 0
 
307
        yielded_a = False
 
308
        for base_idx, b_idx, match_len in matches:
 
309
            conflict_z_len = base_idx - last_base_idx
 
310
            conflict_b_len = b_idx - last_b_idx
 
311
            if conflict_b_len == 0: # There are no lines in b which conflict,
 
312
                                    # so skip it
 
313
                pass
 
314
            else:
 
315
                if yielded_a:
 
316
                    yield ('conflict',
 
317
                           zstart + last_base_idx, zstart + base_idx,
 
318
                           aend, aend, bstart + last_b_idx, bstart + b_idx)
 
319
                else:
 
320
                    # The first conflict gets the a-range
 
321
                    yielded_a = True
 
322
                    yield ('conflict', zstart + last_base_idx, zstart +
 
323
                    base_idx,
 
324
                           astart, aend, bstart + last_b_idx, bstart + b_idx)
 
325
            last_base_idx = base_idx + match_len
 
326
            last_b_idx = b_idx + match_len
 
327
        if last_base_idx != zend - zstart or last_b_idx != bend - bstart:
 
328
            if yielded_a:
 
329
                yield ('conflict', zstart + last_base_idx, zstart + base_idx,
 
330
                       aend, aend, bstart + last_b_idx, bstart + b_idx)
 
331
            else:
 
332
                # The first conflict gets the a-range
 
333
                yielded_a = True
 
334
                yield ('conflict', zstart + last_base_idx, zstart + base_idx,
 
335
                       astart, aend, bstart + last_b_idx, bstart + b_idx)
 
336
        if not yielded_a:
 
337
            yield ('conflict', zstart, zend, astart, aend, bstart, bend)
 
338
 
 
339
    def reprocess_merge_regions(self, merge_regions):
 
340
        """Where there are conflict regions, remove the agreed lines.
 
341
 
 
342
        Lines where both A and B have made the same changes are
 
343
        eliminated.
 
344
        """
 
345
        for region in merge_regions:
 
346
            if region[0] != "conflict":
 
347
                yield region
 
348
                continue
 
349
            type, iz, zmatch, ia, amatch, ib, bmatch = region
 
350
            a_region = self.a[ia:amatch]
 
351
            b_region = self.b[ib:bmatch]
 
352
            matches = patiencediff.PatienceSequenceMatcher(
 
353
                    None, a_region, b_region).get_matching_blocks()
 
354
            next_a = ia
 
355
            next_b = ib
 
356
            for region_ia, region_ib, region_len in matches[:-1]:
 
357
                region_ia += ia
 
358
                region_ib += ib
 
359
                reg = self.mismatch_region(next_a, region_ia, next_b,
 
360
                                           region_ib)
 
361
                if reg is not None:
 
362
                    yield reg
 
363
                yield 'same', region_ia, region_len+region_ia
 
364
                next_a = region_ia + region_len
 
365
                next_b = region_ib + region_len
 
366
            reg = self.mismatch_region(next_a, amatch, next_b, bmatch)
 
367
            if reg is not None:
 
368
                yield reg
 
369
 
 
370
    @staticmethod
 
371
    def mismatch_region(next_a, region_ia,  next_b, region_ib):
 
372
        if next_a < region_ia or next_b < region_ib:
 
373
            return 'conflict', None, None, next_a, region_ia, next_b, region_ib
 
374
 
 
375
    def find_sync_regions(self):
 
376
        """Return a list of sync regions, where both descendents match the base.
 
377
 
 
378
        Generates a list of (base1, base2, a1, a2, b1, b2).  There is
 
379
        always a zero-length sync region at the end of all the files.
 
380
        """
 
381
 
 
382
        ia = ib = 0
 
383
        amatches = patiencediff.PatienceSequenceMatcher(
 
384
                None, self.base, self.a).get_matching_blocks()
 
385
        bmatches = patiencediff.PatienceSequenceMatcher(
 
386
                None, self.base, self.b).get_matching_blocks()
 
387
        len_a = len(amatches)
 
388
        len_b = len(bmatches)
 
389
 
 
390
        sl = []
 
391
 
 
392
        while ia < len_a and ib < len_b:
 
393
            abase, amatch, alen = amatches[ia]
 
394
            bbase, bmatch, blen = bmatches[ib]
 
395
 
 
396
            # there is an unconflicted block at i; how long does it
 
397
            # extend?  until whichever one ends earlier.
 
398
            i = intersect((abase, abase+alen), (bbase, bbase+blen))
 
399
            if i:
 
400
                intbase = i[0]
 
401
                intend = i[1]
 
402
                intlen = intend - intbase
 
403
 
 
404
                # found a match of base[i[0], i[1]]; this may be less than
 
405
                # the region that matches in either one
 
406
                # assert intlen <= alen
 
407
                # assert intlen <= blen
 
408
                # assert abase <= intbase
 
409
                # assert bbase <= intbase
 
410
 
 
411
                asub = amatch + (intbase - abase)
 
412
                bsub = bmatch + (intbase - bbase)
 
413
                aend = asub + intlen
 
414
                bend = bsub + intlen
 
415
 
 
416
                # assert self.base[intbase:intend] == self.a[asub:aend], \
 
417
                #       (self.base[intbase:intend], self.a[asub:aend])
 
418
                # assert self.base[intbase:intend] == self.b[bsub:bend]
 
419
 
 
420
                sl.append((intbase, intend,
 
421
                           asub, aend,
 
422
                           bsub, bend))
 
423
            # advance whichever one ends first in the base text
 
424
            if (abase + alen) < (bbase + blen):
 
425
                ia += 1
 
426
            else:
 
427
                ib += 1
 
428
 
 
429
        intbase = len(self.base)
 
430
        abase = len(self.a)
 
431
        bbase = len(self.b)
 
432
        sl.append((intbase, intbase, abase, abase, bbase, bbase))
 
433
 
 
434
        return sl
 
435
 
 
436
    def find_unconflicted(self):
 
437
        """Return a list of ranges in base that are not conflicted."""
 
438
        am = patiencediff.PatienceSequenceMatcher(
 
439
                None, self.base, self.a).get_matching_blocks()
 
440
        bm = patiencediff.PatienceSequenceMatcher(
 
441
                None, self.base, self.b).get_matching_blocks()
 
442
 
 
443
        unc = []
 
444
 
 
445
        while am and bm:
 
446
            # there is an unconflicted block at i; how long does it
 
447
            # extend?  until whichever one ends earlier.
 
448
            a1 = am[0][0]
 
449
            a2 = a1 + am[0][2]
 
450
            b1 = bm[0][0]
 
451
            b2 = b1 + bm[0][2]
 
452
            i = intersect((a1, a2), (b1, b2))
 
453
            if i:
 
454
                unc.append(i)
 
455
 
 
456
            if a2 < b2:
 
457
                del am[0]
 
458
            else:
 
459
                del bm[0]
 
460
 
 
461
        return unc
 
462
 
 
463
 
 
464
def main(argv):
 
465
    # as for diff3 and meld the syntax is "MINE BASE OTHER"
 
466
    a = file(argv[1], 'rt').readlines()
 
467
    base = file(argv[2], 'rt').readlines()
 
468
    b = file(argv[3], 'rt').readlines()
 
469
 
 
470
    m3 = Merge3(base, a, b)
 
471
 
 
472
    #for sr in m3.find_sync_regions():
 
473
    #    print sr
 
474
 
 
475
    # sys.stdout.writelines(m3.merge_lines(name_a=argv[1], name_b=argv[3]))
 
476
    sys.stdout.writelines(m3.merge_annotated())
 
477
 
 
478
 
 
479
if __name__ == '__main__':
 
480
    import sys
 
481
    sys.exit(main(sys.argv))