~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/merge3.py

  • Committer: Martin Pool
  • Date: 2005-03-14 07:07:24 UTC
  • Revision ID: mbp@sourcefrog.net-20050314070724-ba6c85db7d96c508
- add setup.py and install instructions
- rename main script to just bzr

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2005-2010 Canonical Ltd
2
 
#
3
 
# This program is free software; you can redistribute it and/or modify
4
 
# it under the terms of the GNU General Public License as published by
5
 
# the Free Software Foundation; either version 2 of the License, or
6
 
# (at your option) any later version.
7
 
#
8
 
# This program is distributed in the hope that it will be useful,
9
 
# but WITHOUT ANY WARRANTY; without even the implied warranty of
10
 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11
 
# GNU General Public License for more details.
12
 
#
13
 
# You should have received a copy of the GNU General Public License
14
 
# along with this program; if not, write to the Free Software
15
 
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
16
 
 
17
 
from __future__ import absolute_import
18
 
 
19
 
# mbp: "you know that thing where cvs gives you conflict markers?"
20
 
# s: "i hate that."
21
 
 
22
 
from bzrlib import (
23
 
    errors,
24
 
    patiencediff,
25
 
    textfile,
26
 
    )
27
 
 
28
 
 
29
 
def intersect(ra, rb):
30
 
    """Given two ranges return the range where they intersect or None.
31
 
 
32
 
    >>> intersect((0, 10), (0, 6))
33
 
    (0, 6)
34
 
    >>> intersect((0, 10), (5, 15))
35
 
    (5, 10)
36
 
    >>> intersect((0, 10), (10, 15))
37
 
    >>> intersect((0, 9), (10, 15))
38
 
    >>> intersect((0, 9), (7, 15))
39
 
    (7, 9)
40
 
    """
41
 
    # preconditions: (ra[0] <= ra[1]) and (rb[0] <= rb[1])
42
 
 
43
 
    sa = max(ra[0], rb[0])
44
 
    sb = min(ra[1], rb[1])
45
 
    if sa < sb:
46
 
        return sa, sb
47
 
    else:
48
 
        return None
49
 
 
50
 
 
51
 
def compare_range(a, astart, aend, b, bstart, bend):
52
 
    """Compare a[astart:aend] == b[bstart:bend], without slicing.
53
 
    """
54
 
    if (aend-astart) != (bend-bstart):
55
 
        return False
56
 
    for ia, ib in zip(xrange(astart, aend), xrange(bstart, bend)):
57
 
        if a[ia] != b[ib]:
58
 
            return False
59
 
    else:
60
 
        return True
61
 
 
62
 
 
63
 
 
64
 
 
65
 
class Merge3(object):
66
 
    """3-way merge of texts.
67
 
 
68
 
    Given BASE, OTHER, THIS, tries to produce a combined text
69
 
    incorporating the changes from both BASE->OTHER and BASE->THIS.
70
 
    All three will typically be sequences of lines."""
71
 
 
72
 
    def __init__(self, base, a, b, is_cherrypick=False, allow_objects=False):
73
 
        """Constructor.
74
 
 
75
 
        :param base: lines in BASE
76
 
        :param a: lines in A
77
 
        :param b: lines in B
78
 
        :param is_cherrypick: flag indicating if this merge is a cherrypick.
79
 
            When cherrypicking b => a, matches with b and base do not conflict.
80
 
        :param allow_objects: if True, do not require that base, a and b are
81
 
            plain Python strs.  Also prevents BinaryFile from being raised.
82
 
            Lines can be any sequence of comparable and hashable Python
83
 
            objects.
84
 
        """
85
 
        if not allow_objects:
86
 
            textfile.check_text_lines(base)
87
 
            textfile.check_text_lines(a)
88
 
            textfile.check_text_lines(b)
89
 
        self.base = base
90
 
        self.a = a
91
 
        self.b = b
92
 
        self.is_cherrypick = is_cherrypick
93
 
 
94
 
    def merge_lines(self,
95
 
                    name_a=None,
96
 
                    name_b=None,
97
 
                    name_base=None,
98
 
                    start_marker='<<<<<<<',
99
 
                    mid_marker='=======',
100
 
                    end_marker='>>>>>>>',
101
 
                    base_marker=None,
102
 
                    reprocess=False):
103
 
        """Return merge in cvs-like form.
104
 
        """
105
 
        newline = '\n'
106
 
        if len(self.a) > 0:
107
 
            if self.a[0].endswith('\r\n'):
108
 
                newline = '\r\n'
109
 
            elif self.a[0].endswith('\r'):
110
 
                newline = '\r'
111
 
        if base_marker and reprocess:
112
 
            raise errors.CantReprocessAndShowBase()
113
 
        if name_a:
114
 
            start_marker = start_marker + ' ' + name_a
115
 
        if name_b:
116
 
            end_marker = end_marker + ' ' + name_b
117
 
        if name_base and base_marker:
118
 
            base_marker = base_marker + ' ' + name_base
119
 
        merge_regions = self.merge_regions()
120
 
        if reprocess is True:
121
 
            merge_regions = self.reprocess_merge_regions(merge_regions)
122
 
        for t in merge_regions:
123
 
            what = t[0]
124
 
            if what == 'unchanged':
125
 
                for i in range(t[1], t[2]):
126
 
                    yield self.base[i]
127
 
            elif what == 'a' or what == 'same':
128
 
                for i in range(t[1], t[2]):
129
 
                    yield self.a[i]
130
 
            elif what == 'b':
131
 
                for i in range(t[1], t[2]):
132
 
                    yield self.b[i]
133
 
            elif what == 'conflict':
134
 
                yield start_marker + newline
135
 
                for i in range(t[3], t[4]):
136
 
                    yield self.a[i]
137
 
                if base_marker is not None:
138
 
                    yield base_marker + newline
139
 
                    for i in range(t[1], t[2]):
140
 
                        yield self.base[i]
141
 
                yield mid_marker + newline
142
 
                for i in range(t[5], t[6]):
143
 
                    yield self.b[i]
144
 
                yield end_marker + newline
145
 
            else:
146
 
                raise ValueError(what)
147
 
 
148
 
    def merge_annotated(self):
149
 
        """Return merge with conflicts, showing origin of lines.
150
 
 
151
 
        Most useful for debugging merge.
152
 
        """
153
 
        for t in self.merge_regions():
154
 
            what = t[0]
155
 
            if what == 'unchanged':
156
 
                for i in range(t[1], t[2]):
157
 
                    yield 'u | ' + self.base[i]
158
 
            elif what == 'a' or what == 'same':
159
 
                for i in range(t[1], t[2]):
160
 
                    yield what[0] + ' | ' + self.a[i]
161
 
            elif what == 'b':
162
 
                for i in range(t[1], t[2]):
163
 
                    yield 'b | ' + self.b[i]
164
 
            elif what == 'conflict':
165
 
                yield '<<<<\n'
166
 
                for i in range(t[3], t[4]):
167
 
                    yield 'A | ' + self.a[i]
168
 
                yield '----\n'
169
 
                for i in range(t[5], t[6]):
170
 
                    yield 'B | ' + self.b[i]
171
 
                yield '>>>>\n'
172
 
            else:
173
 
                raise ValueError(what)
174
 
 
175
 
    def merge_groups(self):
176
 
        """Yield sequence of line groups.  Each one is a tuple:
177
 
 
178
 
        'unchanged', lines
179
 
             Lines unchanged from base
180
 
 
181
 
        'a', lines
182
 
             Lines taken from a
183
 
 
184
 
        'same', lines
185
 
             Lines taken from a (and equal to b)
186
 
 
187
 
        'b', lines
188
 
             Lines taken from b
189
 
 
190
 
        'conflict', base_lines, a_lines, b_lines
191
 
             Lines from base were changed to either a or b and conflict.
192
 
        """
193
 
        for t in self.merge_regions():
194
 
            what = t[0]
195
 
            if what == 'unchanged':
196
 
                yield what, self.base[t[1]:t[2]]
197
 
            elif what == 'a' or what == 'same':
198
 
                yield what, self.a[t[1]:t[2]]
199
 
            elif what == 'b':
200
 
                yield what, self.b[t[1]:t[2]]
201
 
            elif what == 'conflict':
202
 
                yield (what,
203
 
                       self.base[t[1]:t[2]],
204
 
                       self.a[t[3]:t[4]],
205
 
                       self.b[t[5]:t[6]])
206
 
            else:
207
 
                raise ValueError(what)
208
 
 
209
 
    def merge_regions(self):
210
 
        """Return sequences of matching and conflicting regions.
211
 
 
212
 
        This returns tuples, where the first value says what kind we
213
 
        have:
214
 
 
215
 
        'unchanged', start, end
216
 
             Take a region of base[start:end]
217
 
 
218
 
        'same', astart, aend
219
 
             b and a are different from base but give the same result
220
 
 
221
 
        'a', start, end
222
 
             Non-clashing insertion from a[start:end]
223
 
 
224
 
        Method is as follows:
225
 
 
226
 
        The two sequences align only on regions which match the base
227
 
        and both descendents.  These are found by doing a two-way diff
228
 
        of each one against the base, and then finding the
229
 
        intersections between those regions.  These "sync regions"
230
 
        are by definition unchanged in both and easily dealt with.
231
 
 
232
 
        The regions in between can be in any of three cases:
233
 
        conflicted, or changed on only one side.
234
 
        """
235
 
 
236
 
        # section a[0:ia] has been disposed of, etc
237
 
        iz = ia = ib = 0
238
 
 
239
 
        for zmatch, zend, amatch, aend, bmatch, bend in self.find_sync_regions():
240
 
            matchlen = zend - zmatch
241
 
            # invariants:
242
 
            #   matchlen >= 0
243
 
            #   matchlen == (aend - amatch)
244
 
            #   matchlen == (bend - bmatch)
245
 
            len_a = amatch - ia
246
 
            len_b = bmatch - ib
247
 
            len_base = zmatch - iz
248
 
            # invariants:
249
 
            # assert len_a >= 0
250
 
            # assert len_b >= 0
251
 
            # assert len_base >= 0
252
 
 
253
 
            #print 'unmatched a=%d, b=%d' % (len_a, len_b)
254
 
 
255
 
            if len_a or len_b:
256
 
                # try to avoid actually slicing the lists
257
 
                same = compare_range(self.a, ia, amatch,
258
 
                                     self.b, ib, bmatch)
259
 
 
260
 
                if same:
261
 
                    yield 'same', ia, amatch
262
 
                else:
263
 
                    equal_a = compare_range(self.a, ia, amatch,
264
 
                                            self.base, iz, zmatch)
265
 
                    equal_b = compare_range(self.b, ib, bmatch,
266
 
                                            self.base, iz, zmatch)
267
 
                    if equal_a and not equal_b:
268
 
                        yield 'b', ib, bmatch
269
 
                    elif equal_b and not equal_a:
270
 
                        yield 'a', ia, amatch
271
 
                    elif not equal_a and not equal_b:
272
 
                        if self.is_cherrypick:
273
 
                            for node in self._refine_cherrypick_conflict(
274
 
                                                    iz, zmatch, ia, amatch,
275
 
                                                    ib, bmatch):
276
 
                                yield node
277
 
                        else:
278
 
                            yield 'conflict', iz, zmatch, ia, amatch, ib, bmatch
279
 
                    else:
280
 
                        raise AssertionError("can't handle a=b=base but unmatched")
281
 
 
282
 
                ia = amatch
283
 
                ib = bmatch
284
 
            iz = zmatch
285
 
 
286
 
            # if the same part of the base was deleted on both sides
287
 
            # that's OK, we can just skip it.
288
 
 
289
 
            if matchlen > 0:
290
 
                # invariants:
291
 
                # assert ia == amatch
292
 
                # assert ib == bmatch
293
 
                # assert iz == zmatch
294
 
 
295
 
                yield 'unchanged', zmatch, zend
296
 
                iz = zend
297
 
                ia = aend
298
 
                ib = bend
299
 
 
300
 
    def _refine_cherrypick_conflict(self, zstart, zend, astart, aend, bstart, bend):
301
 
        """When cherrypicking b => a, ignore matches with b and base."""
302
 
        # Do not emit regions which match, only regions which do not match
303
 
        matches = patiencediff.PatienceSequenceMatcher(None,
304
 
            self.base[zstart:zend], self.b[bstart:bend]).get_matching_blocks()
305
 
        last_base_idx = 0
306
 
        last_b_idx = 0
307
 
        last_b_idx = 0
308
 
        yielded_a = False
309
 
        for base_idx, b_idx, match_len in matches:
310
 
            conflict_z_len = base_idx - last_base_idx
311
 
            conflict_b_len = b_idx - last_b_idx
312
 
            if conflict_b_len == 0: # There are no lines in b which conflict,
313
 
                                    # so skip it
314
 
                pass
315
 
            else:
316
 
                if yielded_a:
317
 
                    yield ('conflict',
318
 
                           zstart + last_base_idx, zstart + base_idx,
319
 
                           aend, aend, bstart + last_b_idx, bstart + b_idx)
320
 
                else:
321
 
                    # The first conflict gets the a-range
322
 
                    yielded_a = True
323
 
                    yield ('conflict', zstart + last_base_idx, zstart +
324
 
                    base_idx,
325
 
                           astart, aend, bstart + last_b_idx, bstart + b_idx)
326
 
            last_base_idx = base_idx + match_len
327
 
            last_b_idx = b_idx + match_len
328
 
        if last_base_idx != zend - zstart or last_b_idx != bend - bstart:
329
 
            if yielded_a:
330
 
                yield ('conflict', zstart + last_base_idx, zstart + base_idx,
331
 
                       aend, aend, bstart + last_b_idx, bstart + b_idx)
332
 
            else:
333
 
                # The first conflict gets the a-range
334
 
                yielded_a = True
335
 
                yield ('conflict', zstart + last_base_idx, zstart + base_idx,
336
 
                       astart, aend, bstart + last_b_idx, bstart + b_idx)
337
 
        if not yielded_a:
338
 
            yield ('conflict', zstart, zend, astart, aend, bstart, bend)
339
 
 
340
 
    def reprocess_merge_regions(self, merge_regions):
341
 
        """Where there are conflict regions, remove the agreed lines.
342
 
 
343
 
        Lines where both A and B have made the same changes are
344
 
        eliminated.
345
 
        """
346
 
        for region in merge_regions:
347
 
            if region[0] != "conflict":
348
 
                yield region
349
 
                continue
350
 
            type, iz, zmatch, ia, amatch, ib, bmatch = region
351
 
            a_region = self.a[ia:amatch]
352
 
            b_region = self.b[ib:bmatch]
353
 
            matches = patiencediff.PatienceSequenceMatcher(
354
 
                    None, a_region, b_region).get_matching_blocks()
355
 
            next_a = ia
356
 
            next_b = ib
357
 
            for region_ia, region_ib, region_len in matches[:-1]:
358
 
                region_ia += ia
359
 
                region_ib += ib
360
 
                reg = self.mismatch_region(next_a, region_ia, next_b,
361
 
                                           region_ib)
362
 
                if reg is not None:
363
 
                    yield reg
364
 
                yield 'same', region_ia, region_len+region_ia
365
 
                next_a = region_ia + region_len
366
 
                next_b = region_ib + region_len
367
 
            reg = self.mismatch_region(next_a, amatch, next_b, bmatch)
368
 
            if reg is not None:
369
 
                yield reg
370
 
 
371
 
    @staticmethod
372
 
    def mismatch_region(next_a, region_ia,  next_b, region_ib):
373
 
        if next_a < region_ia or next_b < region_ib:
374
 
            return 'conflict', None, None, next_a, region_ia, next_b, region_ib
375
 
 
376
 
    def find_sync_regions(self):
377
 
        """Return a list of sync regions, where both descendents match the base.
378
 
 
379
 
        Generates a list of (base1, base2, a1, a2, b1, b2).  There is
380
 
        always a zero-length sync region at the end of all the files.
381
 
        """
382
 
 
383
 
        ia = ib = 0
384
 
        amatches = patiencediff.PatienceSequenceMatcher(
385
 
                None, self.base, self.a).get_matching_blocks()
386
 
        bmatches = patiencediff.PatienceSequenceMatcher(
387
 
                None, self.base, self.b).get_matching_blocks()
388
 
        len_a = len(amatches)
389
 
        len_b = len(bmatches)
390
 
 
391
 
        sl = []
392
 
 
393
 
        while ia < len_a and ib < len_b:
394
 
            abase, amatch, alen = amatches[ia]
395
 
            bbase, bmatch, blen = bmatches[ib]
396
 
 
397
 
            # there is an unconflicted block at i; how long does it
398
 
            # extend?  until whichever one ends earlier.
399
 
            i = intersect((abase, abase+alen), (bbase, bbase+blen))
400
 
            if i:
401
 
                intbase = i[0]
402
 
                intend = i[1]
403
 
                intlen = intend - intbase
404
 
 
405
 
                # found a match of base[i[0], i[1]]; this may be less than
406
 
                # the region that matches in either one
407
 
                # assert intlen <= alen
408
 
                # assert intlen <= blen
409
 
                # assert abase <= intbase
410
 
                # assert bbase <= intbase
411
 
 
412
 
                asub = amatch + (intbase - abase)
413
 
                bsub = bmatch + (intbase - bbase)
414
 
                aend = asub + intlen
415
 
                bend = bsub + intlen
416
 
 
417
 
                # assert self.base[intbase:intend] == self.a[asub:aend], \
418
 
                #       (self.base[intbase:intend], self.a[asub:aend])
419
 
                # assert self.base[intbase:intend] == self.b[bsub:bend]
420
 
 
421
 
                sl.append((intbase, intend,
422
 
                           asub, aend,
423
 
                           bsub, bend))
424
 
            # advance whichever one ends first in the base text
425
 
            if (abase + alen) < (bbase + blen):
426
 
                ia += 1
427
 
            else:
428
 
                ib += 1
429
 
 
430
 
        intbase = len(self.base)
431
 
        abase = len(self.a)
432
 
        bbase = len(self.b)
433
 
        sl.append((intbase, intbase, abase, abase, bbase, bbase))
434
 
 
435
 
        return sl
436
 
 
437
 
    def find_unconflicted(self):
438
 
        """Return a list of ranges in base that are not conflicted."""
439
 
        am = patiencediff.PatienceSequenceMatcher(
440
 
                None, self.base, self.a).get_matching_blocks()
441
 
        bm = patiencediff.PatienceSequenceMatcher(
442
 
                None, self.base, self.b).get_matching_blocks()
443
 
 
444
 
        unc = []
445
 
 
446
 
        while am and bm:
447
 
            # there is an unconflicted block at i; how long does it
448
 
            # extend?  until whichever one ends earlier.
449
 
            a1 = am[0][0]
450
 
            a2 = a1 + am[0][2]
451
 
            b1 = bm[0][0]
452
 
            b2 = b1 + bm[0][2]
453
 
            i = intersect((a1, a2), (b1, b2))
454
 
            if i:
455
 
                unc.append(i)
456
 
 
457
 
            if a2 < b2:
458
 
                del am[0]
459
 
            else:
460
 
                del bm[0]
461
 
 
462
 
        return unc
463
 
 
464
 
 
465
 
def main(argv):
466
 
    # as for diff3 and meld the syntax is "MINE BASE OTHER"
467
 
    a = file(argv[1], 'rt').readlines()
468
 
    base = file(argv[2], 'rt').readlines()
469
 
    b = file(argv[3], 'rt').readlines()
470
 
 
471
 
    m3 = Merge3(base, a, b)
472
 
 
473
 
    #for sr in m3.find_sync_regions():
474
 
    #    print sr
475
 
 
476
 
    # sys.stdout.writelines(m3.merge_lines(name_a=argv[1], name_b=argv[3]))
477
 
    sys.stdout.writelines(m3.merge_annotated())
478
 
 
479
 
 
480
 
if __name__ == '__main__':
481
 
    import sys
482
 
    sys.exit(main(sys.argv))