~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/merge3.py

  • Committer: Martin Pool
  • Date: 2005-06-20 04:57:46 UTC
  • Revision ID: mbp@sourcefrog.net-20050620045746-bbb1976f0af52b94
- correctly set parent list when committing first
  revision to a branch

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2005-2010 Canonical Ltd
2
 
#
3
 
# This program is free software; you can redistribute it and/or modify
4
 
# it under the terms of the GNU General Public License as published by
5
 
# the Free Software Foundation; either version 2 of the License, or
6
 
# (at your option) any later version.
7
 
#
8
 
# This program is distributed in the hope that it will be useful,
9
 
# but WITHOUT ANY WARRANTY; without even the implied warranty of
10
 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11
 
# GNU General Public License for more details.
12
 
#
13
 
# You should have received a copy of the GNU General Public License
14
 
# along with this program; if not, write to the Free Software
15
 
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
16
 
 
17
 
 
18
 
# mbp: "you know that thing where cvs gives you conflict markers?"
19
 
# s: "i hate that."
20
 
 
21
 
from bzrlib import (
22
 
    errors,
23
 
    patiencediff,
24
 
    textfile,
25
 
    )
26
 
 
27
 
 
28
 
def intersect(ra, rb):
29
 
    """Given two ranges return the range where they intersect or None.
30
 
 
31
 
    >>> intersect((0, 10), (0, 6))
32
 
    (0, 6)
33
 
    >>> intersect((0, 10), (5, 15))
34
 
    (5, 10)
35
 
    >>> intersect((0, 10), (10, 15))
36
 
    >>> intersect((0, 9), (10, 15))
37
 
    >>> intersect((0, 9), (7, 15))
38
 
    (7, 9)
39
 
    """
40
 
    # preconditions: (ra[0] <= ra[1]) and (rb[0] <= rb[1])
41
 
 
42
 
    sa = max(ra[0], rb[0])
43
 
    sb = min(ra[1], rb[1])
44
 
    if sa < sb:
45
 
        return sa, sb
46
 
    else:
47
 
        return None
48
 
 
49
 
 
50
 
def compare_range(a, astart, aend, b, bstart, bend):
51
 
    """Compare a[astart:aend] == b[bstart:bend], without slicing.
52
 
    """
53
 
    if (aend-astart) != (bend-bstart):
54
 
        return False
55
 
    for ia, ib in zip(xrange(astart, aend), xrange(bstart, bend)):
56
 
        if a[ia] != b[ib]:
57
 
            return False
58
 
    else:
59
 
        return True
60
 
 
61
 
 
62
 
 
63
 
 
64
 
class Merge3(object):
65
 
    """3-way merge of texts.
66
 
 
67
 
    Given BASE, OTHER, THIS, tries to produce a combined text
68
 
    incorporating the changes from both BASE->OTHER and BASE->THIS.
69
 
    All three will typically be sequences of lines."""
70
 
 
71
 
    def __init__(self, base, a, b, is_cherrypick=False, allow_objects=False):
72
 
        """Constructor.
73
 
 
74
 
        :param base: lines in BASE
75
 
        :param a: lines in A
76
 
        :param b: lines in B
77
 
        :param is_cherrypick: flag indicating if this merge is a cherrypick.
78
 
            When cherrypicking b => a, matches with b and base do not conflict.
79
 
        :param allow_objects: if True, do not require that base, a and b are
80
 
            plain Python strs.  Also prevents BinaryFile from being raised.
81
 
            Lines can be any sequence of comparable and hashable Python
82
 
            objects.
83
 
        """
84
 
        if not allow_objects:
85
 
            textfile.check_text_lines(base)
86
 
            textfile.check_text_lines(a)
87
 
            textfile.check_text_lines(b)
88
 
        self.base = base
89
 
        self.a = a
90
 
        self.b = b
91
 
        self.is_cherrypick = is_cherrypick
92
 
 
93
 
    def merge_lines(self,
94
 
                    name_a=None,
95
 
                    name_b=None,
96
 
                    name_base=None,
97
 
                    start_marker='<<<<<<<',
98
 
                    mid_marker='=======',
99
 
                    end_marker='>>>>>>>',
100
 
                    base_marker=None,
101
 
                    reprocess=False):
102
 
        """Return merge in cvs-like form.
103
 
        """
104
 
        newline = '\n'
105
 
        if len(self.a) > 0:
106
 
            if self.a[0].endswith('\r\n'):
107
 
                newline = '\r\n'
108
 
            elif self.a[0].endswith('\r'):
109
 
                newline = '\r'
110
 
        if base_marker and reprocess:
111
 
            raise errors.CantReprocessAndShowBase()
112
 
        if name_a:
113
 
            start_marker = start_marker + ' ' + name_a
114
 
        if name_b:
115
 
            end_marker = end_marker + ' ' + name_b
116
 
        if name_base and base_marker:
117
 
            base_marker = base_marker + ' ' + name_base
118
 
        merge_regions = self.merge_regions()
119
 
        if reprocess is True:
120
 
            merge_regions = self.reprocess_merge_regions(merge_regions)
121
 
        for t in merge_regions:
122
 
            what = t[0]
123
 
            if what == 'unchanged':
124
 
                for i in range(t[1], t[2]):
125
 
                    yield self.base[i]
126
 
            elif what == 'a' or what == 'same':
127
 
                for i in range(t[1], t[2]):
128
 
                    yield self.a[i]
129
 
            elif what == 'b':
130
 
                for i in range(t[1], t[2]):
131
 
                    yield self.b[i]
132
 
            elif what == 'conflict':
133
 
                yield start_marker + newline
134
 
                for i in range(t[3], t[4]):
135
 
                    yield self.a[i]
136
 
                if base_marker is not None:
137
 
                    yield base_marker + newline
138
 
                    for i in range(t[1], t[2]):
139
 
                        yield self.base[i]
140
 
                yield mid_marker + newline
141
 
                for i in range(t[5], t[6]):
142
 
                    yield self.b[i]
143
 
                yield end_marker + newline
144
 
            else:
145
 
                raise ValueError(what)
146
 
 
147
 
    def merge_annotated(self):
148
 
        """Return merge with conflicts, showing origin of lines.
149
 
 
150
 
        Most useful for debugging merge.
151
 
        """
152
 
        for t in self.merge_regions():
153
 
            what = t[0]
154
 
            if what == 'unchanged':
155
 
                for i in range(t[1], t[2]):
156
 
                    yield 'u | ' + self.base[i]
157
 
            elif what == 'a' or what == 'same':
158
 
                for i in range(t[1], t[2]):
159
 
                    yield what[0] + ' | ' + self.a[i]
160
 
            elif what == 'b':
161
 
                for i in range(t[1], t[2]):
162
 
                    yield 'b | ' + self.b[i]
163
 
            elif what == 'conflict':
164
 
                yield '<<<<\n'
165
 
                for i in range(t[3], t[4]):
166
 
                    yield 'A | ' + self.a[i]
167
 
                yield '----\n'
168
 
                for i in range(t[5], t[6]):
169
 
                    yield 'B | ' + self.b[i]
170
 
                yield '>>>>\n'
171
 
            else:
172
 
                raise ValueError(what)
173
 
 
174
 
    def merge_groups(self):
175
 
        """Yield sequence of line groups.  Each one is a tuple:
176
 
 
177
 
        'unchanged', lines
178
 
             Lines unchanged from base
179
 
 
180
 
        'a', lines
181
 
             Lines taken from a
182
 
 
183
 
        'same', lines
184
 
             Lines taken from a (and equal to b)
185
 
 
186
 
        'b', lines
187
 
             Lines taken from b
188
 
 
189
 
        'conflict', base_lines, a_lines, b_lines
190
 
             Lines from base were changed to either a or b and conflict.
191
 
        """
192
 
        for t in self.merge_regions():
193
 
            what = t[0]
194
 
            if what == 'unchanged':
195
 
                yield what, self.base[t[1]:t[2]]
196
 
            elif what == 'a' or what == 'same':
197
 
                yield what, self.a[t[1]:t[2]]
198
 
            elif what == 'b':
199
 
                yield what, self.b[t[1]:t[2]]
200
 
            elif what == 'conflict':
201
 
                yield (what,
202
 
                       self.base[t[1]:t[2]],
203
 
                       self.a[t[3]:t[4]],
204
 
                       self.b[t[5]:t[6]])
205
 
            else:
206
 
                raise ValueError(what)
207
 
 
208
 
    def merge_regions(self):
209
 
        """Return sequences of matching and conflicting regions.
210
 
 
211
 
        This returns tuples, where the first value says what kind we
212
 
        have:
213
 
 
214
 
        'unchanged', start, end
215
 
             Take a region of base[start:end]
216
 
 
217
 
        'same', astart, aend
218
 
             b and a are different from base but give the same result
219
 
 
220
 
        'a', start, end
221
 
             Non-clashing insertion from a[start:end]
222
 
 
223
 
        Method is as follows:
224
 
 
225
 
        The two sequences align only on regions which match the base
226
 
        and both descendents.  These are found by doing a two-way diff
227
 
        of each one against the base, and then finding the
228
 
        intersections between those regions.  These "sync regions"
229
 
        are by definition unchanged in both and easily dealt with.
230
 
 
231
 
        The regions in between can be in any of three cases:
232
 
        conflicted, or changed on only one side.
233
 
        """
234
 
 
235
 
        # section a[0:ia] has been disposed of, etc
236
 
        iz = ia = ib = 0
237
 
 
238
 
        for zmatch, zend, amatch, aend, bmatch, bend in self.find_sync_regions():
239
 
            matchlen = zend - zmatch
240
 
            # invariants:
241
 
            #   matchlen >= 0
242
 
            #   matchlen == (aend - amatch)
243
 
            #   matchlen == (bend - bmatch)
244
 
            len_a = amatch - ia
245
 
            len_b = bmatch - ib
246
 
            len_base = zmatch - iz
247
 
            # invariants:
248
 
            # assert len_a >= 0
249
 
            # assert len_b >= 0
250
 
            # assert len_base >= 0
251
 
 
252
 
            #print 'unmatched a=%d, b=%d' % (len_a, len_b)
253
 
 
254
 
            if len_a or len_b:
255
 
                # try to avoid actually slicing the lists
256
 
                same = compare_range(self.a, ia, amatch,
257
 
                                     self.b, ib, bmatch)
258
 
 
259
 
                if same:
260
 
                    yield 'same', ia, amatch
261
 
                else:
262
 
                    equal_a = compare_range(self.a, ia, amatch,
263
 
                                            self.base, iz, zmatch)
264
 
                    equal_b = compare_range(self.b, ib, bmatch,
265
 
                                            self.base, iz, zmatch)
266
 
                    if equal_a and not equal_b:
267
 
                        yield 'b', ib, bmatch
268
 
                    elif equal_b and not equal_a:
269
 
                        yield 'a', ia, amatch
270
 
                    elif not equal_a and not equal_b:
271
 
                        if self.is_cherrypick:
272
 
                            for node in self._refine_cherrypick_conflict(
273
 
                                                    iz, zmatch, ia, amatch,
274
 
                                                    ib, bmatch):
275
 
                                yield node
276
 
                        else:
277
 
                            yield 'conflict', iz, zmatch, ia, amatch, ib, bmatch
278
 
                    else:
279
 
                        raise AssertionError("can't handle a=b=base but unmatched")
280
 
 
281
 
                ia = amatch
282
 
                ib = bmatch
283
 
            iz = zmatch
284
 
 
285
 
            # if the same part of the base was deleted on both sides
286
 
            # that's OK, we can just skip it.
287
 
 
288
 
            if matchlen > 0:
289
 
                # invariants:
290
 
                # assert ia == amatch
291
 
                # assert ib == bmatch
292
 
                # assert iz == zmatch
293
 
 
294
 
                yield 'unchanged', zmatch, zend
295
 
                iz = zend
296
 
                ia = aend
297
 
                ib = bend
298
 
 
299
 
    def _refine_cherrypick_conflict(self, zstart, zend, astart, aend, bstart, bend):
300
 
        """When cherrypicking b => a, ignore matches with b and base."""
301
 
        # Do not emit regions which match, only regions which do not match
302
 
        matches = patiencediff.PatienceSequenceMatcher(None,
303
 
            self.base[zstart:zend], self.b[bstart:bend]).get_matching_blocks()
304
 
        last_base_idx = 0
305
 
        last_b_idx = 0
306
 
        last_b_idx = 0
307
 
        yielded_a = False
308
 
        for base_idx, b_idx, match_len in matches:
309
 
            conflict_z_len = base_idx - last_base_idx
310
 
            conflict_b_len = b_idx - last_b_idx
311
 
            if conflict_b_len == 0: # There are no lines in b which conflict,
312
 
                                    # so skip it
313
 
                pass
314
 
            else:
315
 
                if yielded_a:
316
 
                    yield ('conflict',
317
 
                           zstart + last_base_idx, zstart + base_idx,
318
 
                           aend, aend, bstart + last_b_idx, bstart + b_idx)
319
 
                else:
320
 
                    # The first conflict gets the a-range
321
 
                    yielded_a = True
322
 
                    yield ('conflict', zstart + last_base_idx, zstart +
323
 
                    base_idx,
324
 
                           astart, aend, bstart + last_b_idx, bstart + b_idx)
325
 
            last_base_idx = base_idx + match_len
326
 
            last_b_idx = b_idx + match_len
327
 
        if last_base_idx != zend - zstart or last_b_idx != bend - bstart:
328
 
            if yielded_a:
329
 
                yield ('conflict', zstart + last_base_idx, zstart + base_idx,
330
 
                       aend, aend, bstart + last_b_idx, bstart + b_idx)
331
 
            else:
332
 
                # The first conflict gets the a-range
333
 
                yielded_a = True
334
 
                yield ('conflict', zstart + last_base_idx, zstart + base_idx,
335
 
                       astart, aend, bstart + last_b_idx, bstart + b_idx)
336
 
        if not yielded_a:
337
 
            yield ('conflict', zstart, zend, astart, aend, bstart, bend)
338
 
 
339
 
    def reprocess_merge_regions(self, merge_regions):
340
 
        """Where there are conflict regions, remove the agreed lines.
341
 
 
342
 
        Lines where both A and B have made the same changes are
343
 
        eliminated.
344
 
        """
345
 
        for region in merge_regions:
346
 
            if region[0] != "conflict":
347
 
                yield region
348
 
                continue
349
 
            type, iz, zmatch, ia, amatch, ib, bmatch = region
350
 
            a_region = self.a[ia:amatch]
351
 
            b_region = self.b[ib:bmatch]
352
 
            matches = patiencediff.PatienceSequenceMatcher(
353
 
                    None, a_region, b_region).get_matching_blocks()
354
 
            next_a = ia
355
 
            next_b = ib
356
 
            for region_ia, region_ib, region_len in matches[:-1]:
357
 
                region_ia += ia
358
 
                region_ib += ib
359
 
                reg = self.mismatch_region(next_a, region_ia, next_b,
360
 
                                           region_ib)
361
 
                if reg is not None:
362
 
                    yield reg
363
 
                yield 'same', region_ia, region_len+region_ia
364
 
                next_a = region_ia + region_len
365
 
                next_b = region_ib + region_len
366
 
            reg = self.mismatch_region(next_a, amatch, next_b, bmatch)
367
 
            if reg is not None:
368
 
                yield reg
369
 
 
370
 
    @staticmethod
371
 
    def mismatch_region(next_a, region_ia,  next_b, region_ib):
372
 
        if next_a < region_ia or next_b < region_ib:
373
 
            return 'conflict', None, None, next_a, region_ia, next_b, region_ib
374
 
 
375
 
    def find_sync_regions(self):
376
 
        """Return a list of sync regions, where both descendents match the base.
377
 
 
378
 
        Generates a list of (base1, base2, a1, a2, b1, b2).  There is
379
 
        always a zero-length sync region at the end of all the files.
380
 
        """
381
 
 
382
 
        ia = ib = 0
383
 
        amatches = patiencediff.PatienceSequenceMatcher(
384
 
                None, self.base, self.a).get_matching_blocks()
385
 
        bmatches = patiencediff.PatienceSequenceMatcher(
386
 
                None, self.base, self.b).get_matching_blocks()
387
 
        len_a = len(amatches)
388
 
        len_b = len(bmatches)
389
 
 
390
 
        sl = []
391
 
 
392
 
        while ia < len_a and ib < len_b:
393
 
            abase, amatch, alen = amatches[ia]
394
 
            bbase, bmatch, blen = bmatches[ib]
395
 
 
396
 
            # there is an unconflicted block at i; how long does it
397
 
            # extend?  until whichever one ends earlier.
398
 
            i = intersect((abase, abase+alen), (bbase, bbase+blen))
399
 
            if i:
400
 
                intbase = i[0]
401
 
                intend = i[1]
402
 
                intlen = intend - intbase
403
 
 
404
 
                # found a match of base[i[0], i[1]]; this may be less than
405
 
                # the region that matches in either one
406
 
                # assert intlen <= alen
407
 
                # assert intlen <= blen
408
 
                # assert abase <= intbase
409
 
                # assert bbase <= intbase
410
 
 
411
 
                asub = amatch + (intbase - abase)
412
 
                bsub = bmatch + (intbase - bbase)
413
 
                aend = asub + intlen
414
 
                bend = bsub + intlen
415
 
 
416
 
                # assert self.base[intbase:intend] == self.a[asub:aend], \
417
 
                #       (self.base[intbase:intend], self.a[asub:aend])
418
 
                # assert self.base[intbase:intend] == self.b[bsub:bend]
419
 
 
420
 
                sl.append((intbase, intend,
421
 
                           asub, aend,
422
 
                           bsub, bend))
423
 
            # advance whichever one ends first in the base text
424
 
            if (abase + alen) < (bbase + blen):
425
 
                ia += 1
426
 
            else:
427
 
                ib += 1
428
 
 
429
 
        intbase = len(self.base)
430
 
        abase = len(self.a)
431
 
        bbase = len(self.b)
432
 
        sl.append((intbase, intbase, abase, abase, bbase, bbase))
433
 
 
434
 
        return sl
435
 
 
436
 
    def find_unconflicted(self):
437
 
        """Return a list of ranges in base that are not conflicted."""
438
 
        am = patiencediff.PatienceSequenceMatcher(
439
 
                None, self.base, self.a).get_matching_blocks()
440
 
        bm = patiencediff.PatienceSequenceMatcher(
441
 
                None, self.base, self.b).get_matching_blocks()
442
 
 
443
 
        unc = []
444
 
 
445
 
        while am and bm:
446
 
            # there is an unconflicted block at i; how long does it
447
 
            # extend?  until whichever one ends earlier.
448
 
            a1 = am[0][0]
449
 
            a2 = a1 + am[0][2]
450
 
            b1 = bm[0][0]
451
 
            b2 = b1 + bm[0][2]
452
 
            i = intersect((a1, a2), (b1, b2))
453
 
            if i:
454
 
                unc.append(i)
455
 
 
456
 
            if a2 < b2:
457
 
                del am[0]
458
 
            else:
459
 
                del bm[0]
460
 
 
461
 
        return unc
462
 
 
463
 
 
464
 
def main(argv):
465
 
    # as for diff3 and meld the syntax is "MINE BASE OTHER"
466
 
    a = file(argv[1], 'rt').readlines()
467
 
    base = file(argv[2], 'rt').readlines()
468
 
    b = file(argv[3], 'rt').readlines()
469
 
 
470
 
    m3 = Merge3(base, a, b)
471
 
 
472
 
    #for sr in m3.find_sync_regions():
473
 
    #    print sr
474
 
 
475
 
    # sys.stdout.writelines(m3.merge_lines(name_a=argv[1], name_b=argv[3]))
476
 
    sys.stdout.writelines(m3.merge_annotated())
477
 
 
478
 
 
479
 
if __name__ == '__main__':
480
 
    import sys
481
 
    sys.exit(main(sys.argv))