~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/merge3.py

  • Committer: Martin Pool
  • Date: 2010-02-03 00:08:23 UTC
  • mto: This revision was merged to the branch mainline in revision 5002.
  • Revision ID: mbp@sourcefrog.net-20100203000823-fcyf2791xrl3fbfo
expand tabs

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# Copyright (C) 2004, 2005 Canonical Ltd
 
2
#
 
3
# This program is free software; you can redistribute it and/or modify
 
4
# it under the terms of the GNU General Public License as published by
 
5
# the Free Software Foundation; either version 2 of the License, or
 
6
# (at your option) any later version.
 
7
#
 
8
# This program is distributed in the hope that it will be useful,
 
9
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
10
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
11
# GNU General Public License for more details.
 
12
#
 
13
# You should have received a copy of the GNU General Public License
 
14
# along with this program; if not, write to the Free Software
 
15
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 
16
 
 
17
 
 
18
# mbp: "you know that thing where cvs gives you conflict markers?"
 
19
# s: "i hate that."
 
20
 
 
21
 
 
22
from bzrlib.errors import CantReprocessAndShowBase
 
23
import bzrlib.patiencediff
 
24
from bzrlib.textfile import check_text_lines
 
25
 
 
26
 
 
27
def intersect(ra, rb):
 
28
    """Given two ranges return the range where they intersect or None.
 
29
 
 
30
    >>> intersect((0, 10), (0, 6))
 
31
    (0, 6)
 
32
    >>> intersect((0, 10), (5, 15))
 
33
    (5, 10)
 
34
    >>> intersect((0, 10), (10, 15))
 
35
    >>> intersect((0, 9), (10, 15))
 
36
    >>> intersect((0, 9), (7, 15))
 
37
    (7, 9)
 
38
    """
 
39
    # preconditions: (ra[0] <= ra[1]) and (rb[0] <= rb[1])
 
40
 
 
41
    sa = max(ra[0], rb[0])
 
42
    sb = min(ra[1], rb[1])
 
43
    if sa < sb:
 
44
        return sa, sb
 
45
    else:
 
46
        return None
 
47
 
 
48
 
 
49
def compare_range(a, astart, aend, b, bstart, bend):
 
50
    """Compare a[astart:aend] == b[bstart:bend], without slicing.
 
51
    """
 
52
    if (aend-astart) != (bend-bstart):
 
53
        return False
 
54
    for ia, ib in zip(xrange(astart, aend), xrange(bstart, bend)):
 
55
        if a[ia] != b[ib]:
 
56
            return False
 
57
    else:
 
58
        return True
 
59
 
 
60
 
 
61
 
 
62
 
 
63
class Merge3(object):
 
64
    """3-way merge of texts.
 
65
 
 
66
    Given BASE, OTHER, THIS, tries to produce a combined text
 
67
    incorporating the changes from both BASE->OTHER and BASE->THIS.
 
68
    All three will typically be sequences of lines."""
 
69
    def __init__(self, base, a, b, is_cherrypick=False):
 
70
        check_text_lines(base)
 
71
        check_text_lines(a)
 
72
        check_text_lines(b)
 
73
        self.base = base
 
74
        self.a = a
 
75
        self.b = b
 
76
        self.is_cherrypick = is_cherrypick
 
77
 
 
78
    def merge_lines(self,
 
79
                    name_a=None,
 
80
                    name_b=None,
 
81
                    name_base=None,
 
82
                    start_marker='<<<<<<<',
 
83
                    mid_marker='=======',
 
84
                    end_marker='>>>>>>>',
 
85
                    base_marker=None,
 
86
                    reprocess=False):
 
87
        """Return merge in cvs-like form.
 
88
        """
 
89
        newline = '\n'
 
90
        if len(self.a) > 0:
 
91
            if self.a[0].endswith('\r\n'):
 
92
                newline = '\r\n'
 
93
            elif self.a[0].endswith('\r'):
 
94
                newline = '\r'
 
95
        if base_marker and reprocess:
 
96
            raise CantReprocessAndShowBase()
 
97
        if name_a:
 
98
            start_marker = start_marker + ' ' + name_a
 
99
        if name_b:
 
100
            end_marker = end_marker + ' ' + name_b
 
101
        if name_base and base_marker:
 
102
            base_marker = base_marker + ' ' + name_base
 
103
        merge_regions = self.merge_regions()
 
104
        if reprocess is True:
 
105
            merge_regions = self.reprocess_merge_regions(merge_regions)
 
106
        for t in merge_regions:
 
107
            what = t[0]
 
108
            if what == 'unchanged':
 
109
                for i in range(t[1], t[2]):
 
110
                    yield self.base[i]
 
111
            elif what == 'a' or what == 'same':
 
112
                for i in range(t[1], t[2]):
 
113
                    yield self.a[i]
 
114
            elif what == 'b':
 
115
                for i in range(t[1], t[2]):
 
116
                    yield self.b[i]
 
117
            elif what == 'conflict':
 
118
                yield start_marker + newline
 
119
                for i in range(t[3], t[4]):
 
120
                    yield self.a[i]
 
121
                if base_marker is not None:
 
122
                    yield base_marker + newline
 
123
                    for i in range(t[1], t[2]):
 
124
                        yield self.base[i]
 
125
                yield mid_marker + newline
 
126
                for i in range(t[5], t[6]):
 
127
                    yield self.b[i]
 
128
                yield end_marker + newline
 
129
            else:
 
130
                raise ValueError(what)
 
131
 
 
132
    def merge_annotated(self):
 
133
        """Return merge with conflicts, showing origin of lines.
 
134
 
 
135
        Most useful for debugging merge.
 
136
        """
 
137
        for t in self.merge_regions():
 
138
            what = t[0]
 
139
            if what == 'unchanged':
 
140
                for i in range(t[1], t[2]):
 
141
                    yield 'u | ' + self.base[i]
 
142
            elif what == 'a' or what == 'same':
 
143
                for i in range(t[1], t[2]):
 
144
                    yield what[0] + ' | ' + self.a[i]
 
145
            elif what == 'b':
 
146
                for i in range(t[1], t[2]):
 
147
                    yield 'b | ' + self.b[i]
 
148
            elif what == 'conflict':
 
149
                yield '<<<<\n'
 
150
                for i in range(t[3], t[4]):
 
151
                    yield 'A | ' + self.a[i]
 
152
                yield '----\n'
 
153
                for i in range(t[5], t[6]):
 
154
                    yield 'B | ' + self.b[i]
 
155
                yield '>>>>\n'
 
156
            else:
 
157
                raise ValueError(what)
 
158
 
 
159
    def merge_groups(self):
 
160
        """Yield sequence of line groups.  Each one is a tuple:
 
161
 
 
162
        'unchanged', lines
 
163
             Lines unchanged from base
 
164
 
 
165
        'a', lines
 
166
             Lines taken from a
 
167
 
 
168
        'same', lines
 
169
             Lines taken from a (and equal to b)
 
170
 
 
171
        'b', lines
 
172
             Lines taken from b
 
173
 
 
174
        'conflict', base_lines, a_lines, b_lines
 
175
             Lines from base were changed to either a or b and conflict.
 
176
        """
 
177
        for t in self.merge_regions():
 
178
            what = t[0]
 
179
            if what == 'unchanged':
 
180
                yield what, self.base[t[1]:t[2]]
 
181
            elif what == 'a' or what == 'same':
 
182
                yield what, self.a[t[1]:t[2]]
 
183
            elif what == 'b':
 
184
                yield what, self.b[t[1]:t[2]]
 
185
            elif what == 'conflict':
 
186
                yield (what,
 
187
                       self.base[t[1]:t[2]],
 
188
                       self.a[t[3]:t[4]],
 
189
                       self.b[t[5]:t[6]])
 
190
            else:
 
191
                raise ValueError(what)
 
192
 
 
193
    def merge_regions(self):
 
194
        """Return sequences of matching and conflicting regions.
 
195
 
 
196
        This returns tuples, where the first value says what kind we
 
197
        have:
 
198
 
 
199
        'unchanged', start, end
 
200
             Take a region of base[start:end]
 
201
 
 
202
        'same', astart, aend
 
203
             b and a are different from base but give the same result
 
204
 
 
205
        'a', start, end
 
206
             Non-clashing insertion from a[start:end]
 
207
 
 
208
        Method is as follows:
 
209
 
 
210
        The two sequences align only on regions which match the base
 
211
        and both descendents.  These are found by doing a two-way diff
 
212
        of each one against the base, and then finding the
 
213
        intersections between those regions.  These "sync regions"
 
214
        are by definition unchanged in both and easily dealt with.
 
215
 
 
216
        The regions in between can be in any of three cases:
 
217
        conflicted, or changed on only one side.
 
218
        """
 
219
 
 
220
        # section a[0:ia] has been disposed of, etc
 
221
        iz = ia = ib = 0
 
222
 
 
223
        for zmatch, zend, amatch, aend, bmatch, bend in self.find_sync_regions():
 
224
            matchlen = zend - zmatch
 
225
            # invariants:
 
226
            #   matchlen >= 0
 
227
            #   matchlen == (aend - amatch)
 
228
            #   matchlen == (bend - bmatch)
 
229
            len_a = amatch - ia
 
230
            len_b = bmatch - ib
 
231
            len_base = zmatch - iz
 
232
            # invariants:
 
233
            # assert len_a >= 0
 
234
            # assert len_b >= 0
 
235
            # assert len_base >= 0
 
236
 
 
237
            #print 'unmatched a=%d, b=%d' % (len_a, len_b)
 
238
 
 
239
            if len_a or len_b:
 
240
                # try to avoid actually slicing the lists
 
241
                same = compare_range(self.a, ia, amatch,
 
242
                                     self.b, ib, bmatch)
 
243
 
 
244
                if same:
 
245
                    yield 'same', ia, amatch
 
246
                else:
 
247
                    equal_a = compare_range(self.a, ia, amatch,
 
248
                                            self.base, iz, zmatch)
 
249
                    equal_b = compare_range(self.b, ib, bmatch,
 
250
                                            self.base, iz, zmatch)
 
251
                    if equal_a and not equal_b:
 
252
                        yield 'b', ib, bmatch
 
253
                    elif equal_b and not equal_a:
 
254
                        yield 'a', ia, amatch
 
255
                    elif not equal_a and not equal_b:
 
256
                        if self.is_cherrypick:
 
257
                            for node in self._refine_cherrypick_conflict(
 
258
                                                    iz, zmatch, ia, amatch,
 
259
                                                    ib, bmatch):
 
260
                                yield node
 
261
                        else:
 
262
                            yield 'conflict', iz, zmatch, ia, amatch, ib, bmatch
 
263
                    else:
 
264
                        raise AssertionError("can't handle a=b=base but unmatched")
 
265
 
 
266
                ia = amatch
 
267
                ib = bmatch
 
268
            iz = zmatch
 
269
 
 
270
            # if the same part of the base was deleted on both sides
 
271
            # that's OK, we can just skip it.
 
272
 
 
273
            if matchlen > 0:
 
274
                # invariants:
 
275
                # assert ia == amatch
 
276
                # assert ib == bmatch
 
277
                # assert iz == zmatch
 
278
 
 
279
                yield 'unchanged', zmatch, zend
 
280
                iz = zend
 
281
                ia = aend
 
282
                ib = bend
 
283
 
 
284
    def _refine_cherrypick_conflict(self, zstart, zend, astart, aend, bstart, bend):
 
285
        """When cherrypicking b => a, ignore matches with b and base."""
 
286
        # Do not emit regions which match, only regions which do not match
 
287
        matches = bzrlib.patiencediff.PatienceSequenceMatcher(None,
 
288
            self.base[zstart:zend], self.b[bstart:bend]).get_matching_blocks()
 
289
        last_base_idx = 0
 
290
        last_b_idx = 0
 
291
        last_b_idx = 0
 
292
        yielded_a = False
 
293
        for base_idx, b_idx, match_len in matches:
 
294
            conflict_z_len = base_idx - last_base_idx
 
295
            conflict_b_len = b_idx - last_b_idx
 
296
            if conflict_b_len == 0: # There are no lines in b which conflict,
 
297
                                    # so skip it
 
298
                pass
 
299
            else:
 
300
                if yielded_a:
 
301
                    yield ('conflict',
 
302
                           zstart + last_base_idx, zstart + base_idx,
 
303
                           aend, aend, bstart + last_b_idx, bstart + b_idx)
 
304
                else:
 
305
                    # The first conflict gets the a-range
 
306
                    yielded_a = True
 
307
                    yield ('conflict', zstart + last_base_idx, zstart +
 
308
                    base_idx,
 
309
                           astart, aend, bstart + last_b_idx, bstart + b_idx)
 
310
            last_base_idx = base_idx + match_len
 
311
            last_b_idx = b_idx + match_len
 
312
        if last_base_idx != zend - zstart or last_b_idx != bend - bstart:
 
313
            if yielded_a:
 
314
                yield ('conflict', zstart + last_base_idx, zstart + base_idx,
 
315
                       aend, aend, bstart + last_b_idx, bstart + b_idx)
 
316
            else:
 
317
                # The first conflict gets the a-range
 
318
                yielded_a = True
 
319
                yield ('conflict', zstart + last_base_idx, zstart + base_idx,
 
320
                       astart, aend, bstart + last_b_idx, bstart + b_idx)
 
321
        if not yielded_a:
 
322
            yield ('conflict', zstart, zend, astart, aend, bstart, bend)
 
323
 
 
324
    def reprocess_merge_regions(self, merge_regions):
 
325
        """Where there are conflict regions, remove the agreed lines.
 
326
 
 
327
        Lines where both A and B have made the same changes are
 
328
        eliminated.
 
329
        """
 
330
        for region in merge_regions:
 
331
            if region[0] != "conflict":
 
332
                yield region
 
333
                continue
 
334
            type, iz, zmatch, ia, amatch, ib, bmatch = region
 
335
            a_region = self.a[ia:amatch]
 
336
            b_region = self.b[ib:bmatch]
 
337
            matches = bzrlib.patiencediff.PatienceSequenceMatcher(
 
338
                    None, a_region, b_region).get_matching_blocks()
 
339
            next_a = ia
 
340
            next_b = ib
 
341
            for region_ia, region_ib, region_len in matches[:-1]:
 
342
                region_ia += ia
 
343
                region_ib += ib
 
344
                reg = self.mismatch_region(next_a, region_ia, next_b,
 
345
                                           region_ib)
 
346
                if reg is not None:
 
347
                    yield reg
 
348
                yield 'same', region_ia, region_len+region_ia
 
349
                next_a = region_ia + region_len
 
350
                next_b = region_ib + region_len
 
351
            reg = self.mismatch_region(next_a, amatch, next_b, bmatch)
 
352
            if reg is not None:
 
353
                yield reg
 
354
 
 
355
    @staticmethod
 
356
    def mismatch_region(next_a, region_ia,  next_b, region_ib):
 
357
        if next_a < region_ia or next_b < region_ib:
 
358
            return 'conflict', None, None, next_a, region_ia, next_b, region_ib
 
359
 
 
360
    def find_sync_regions(self):
 
361
        """Return a list of sync regions, where both descendents match the base.
 
362
 
 
363
        Generates a list of (base1, base2, a1, a2, b1, b2).  There is
 
364
        always a zero-length sync region at the end of all the files.
 
365
        """
 
366
 
 
367
        ia = ib = 0
 
368
        amatches = bzrlib.patiencediff.PatienceSequenceMatcher(
 
369
                None, self.base, self.a).get_matching_blocks()
 
370
        bmatches = bzrlib.patiencediff.PatienceSequenceMatcher(
 
371
                None, self.base, self.b).get_matching_blocks()
 
372
        len_a = len(amatches)
 
373
        len_b = len(bmatches)
 
374
 
 
375
        sl = []
 
376
 
 
377
        while ia < len_a and ib < len_b:
 
378
            abase, amatch, alen = amatches[ia]
 
379
            bbase, bmatch, blen = bmatches[ib]
 
380
 
 
381
            # there is an unconflicted block at i; how long does it
 
382
            # extend?  until whichever one ends earlier.
 
383
            i = intersect((abase, abase+alen), (bbase, bbase+blen))
 
384
            if i:
 
385
                intbase = i[0]
 
386
                intend = i[1]
 
387
                intlen = intend - intbase
 
388
 
 
389
                # found a match of base[i[0], i[1]]; this may be less than
 
390
                # the region that matches in either one
 
391
                # assert intlen <= alen
 
392
                # assert intlen <= blen
 
393
                # assert abase <= intbase
 
394
                # assert bbase <= intbase
 
395
 
 
396
                asub = amatch + (intbase - abase)
 
397
                bsub = bmatch + (intbase - bbase)
 
398
                aend = asub + intlen
 
399
                bend = bsub + intlen
 
400
 
 
401
                # assert self.base[intbase:intend] == self.a[asub:aend], \
 
402
                #       (self.base[intbase:intend], self.a[asub:aend])
 
403
                # assert self.base[intbase:intend] == self.b[bsub:bend]
 
404
 
 
405
                sl.append((intbase, intend,
 
406
                           asub, aend,
 
407
                           bsub, bend))
 
408
            # advance whichever one ends first in the base text
 
409
            if (abase + alen) < (bbase + blen):
 
410
                ia += 1
 
411
            else:
 
412
                ib += 1
 
413
 
 
414
        intbase = len(self.base)
 
415
        abase = len(self.a)
 
416
        bbase = len(self.b)
 
417
        sl.append((intbase, intbase, abase, abase, bbase, bbase))
 
418
 
 
419
        return sl
 
420
 
 
421
    def find_unconflicted(self):
 
422
        """Return a list of ranges in base that are not conflicted."""
 
423
        am = bzrlib.patiencediff.PatienceSequenceMatcher(
 
424
                None, self.base, self.a).get_matching_blocks()
 
425
        bm = bzrlib.patiencediff.PatienceSequenceMatcher(
 
426
                None, self.base, self.b).get_matching_blocks()
 
427
 
 
428
        unc = []
 
429
 
 
430
        while am and bm:
 
431
            # there is an unconflicted block at i; how long does it
 
432
            # extend?  until whichever one ends earlier.
 
433
            a1 = am[0][0]
 
434
            a2 = a1 + am[0][2]
 
435
            b1 = bm[0][0]
 
436
            b2 = b1 + bm[0][2]
 
437
            i = intersect((a1, a2), (b1, b2))
 
438
            if i:
 
439
                unc.append(i)
 
440
 
 
441
            if a2 < b2:
 
442
                del am[0]
 
443
            else:
 
444
                del bm[0]
 
445
 
 
446
        return unc
 
447
 
 
448
 
 
449
def main(argv):
 
450
    # as for diff3 and meld the syntax is "MINE BASE OTHER"
 
451
    a = file(argv[1], 'rt').readlines()
 
452
    base = file(argv[2], 'rt').readlines()
 
453
    b = file(argv[3], 'rt').readlines()
 
454
 
 
455
    m3 = Merge3(base, a, b)
 
456
 
 
457
    #for sr in m3.find_sync_regions():
 
458
    #    print sr
 
459
 
 
460
    # sys.stdout.writelines(m3.merge_lines(name_a=argv[1], name_b=argv[3]))
 
461
    sys.stdout.writelines(m3.merge_annotated())
 
462
 
 
463
 
 
464
if __name__ == '__main__':
 
465
    import sys
 
466
    sys.exit(main(sys.argv))