~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/merge3.py

  • Committer: Martin Pool
  • Date: 2005-06-24 09:03:20 UTC
  • Revision ID: mbp@sourcefrog.net-20050624090320-c34ea3e5aa81d01d
- write new working inventory using AtomicFile

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2004, 2005 Canonical Ltd
2
 
#
3
 
# This program is free software; you can redistribute it and/or modify
4
 
# it under the terms of the GNU General Public License as published by
5
 
# the Free Software Foundation; either version 2 of the License, or
6
 
# (at your option) any later version.
7
 
#
8
 
# This program is distributed in the hope that it will be useful,
9
 
# but WITHOUT ANY WARRANTY; without even the implied warranty of
10
 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11
 
# GNU General Public License for more details.
12
 
#
13
 
# You should have received a copy of the GNU General Public License
14
 
# along with this program; if not, write to the Free Software
15
 
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
16
 
 
17
 
 
18
 
# mbp: "you know that thing where cvs gives you conflict markers?"
19
 
# s: "i hate that."
20
 
 
21
 
 
22
 
from bzrlib.errors import CantReprocessAndShowBase
23
 
import bzrlib.patiencediff
24
 
from bzrlib.textfile import check_text_lines
25
 
 
26
 
 
27
 
def intersect(ra, rb):
28
 
    """Given two ranges return the range where they intersect or None.
29
 
 
30
 
    >>> intersect((0, 10), (0, 6))
31
 
    (0, 6)
32
 
    >>> intersect((0, 10), (5, 15))
33
 
    (5, 10)
34
 
    >>> intersect((0, 10), (10, 15))
35
 
    >>> intersect((0, 9), (10, 15))
36
 
    >>> intersect((0, 9), (7, 15))
37
 
    (7, 9)
38
 
    """
39
 
    # preconditions: (ra[0] <= ra[1]) and (rb[0] <= rb[1])
40
 
 
41
 
    sa = max(ra[0], rb[0])
42
 
    sb = min(ra[1], rb[1])
43
 
    if sa < sb:
44
 
        return sa, sb
45
 
    else:
46
 
        return None
47
 
 
48
 
 
49
 
def compare_range(a, astart, aend, b, bstart, bend):
50
 
    """Compare a[astart:aend] == b[bstart:bend], without slicing.
51
 
    """
52
 
    if (aend-astart) != (bend-bstart):
53
 
        return False
54
 
    for ia, ib in zip(xrange(astart, aend), xrange(bstart, bend)):
55
 
        if a[ia] != b[ib]:
56
 
            return False
57
 
    else:
58
 
        return True
59
 
 
60
 
 
61
 
 
62
 
 
63
 
class Merge3(object):
64
 
    """3-way merge of texts.
65
 
 
66
 
    Given BASE, OTHER, THIS, tries to produce a combined text
67
 
    incorporating the changes from both BASE->OTHER and BASE->THIS.
68
 
    All three will typically be sequences of lines."""
69
 
    def __init__(self, base, a, b, is_cherrypick=False):
70
 
        check_text_lines(base)
71
 
        check_text_lines(a)
72
 
        check_text_lines(b)
73
 
        self.base = base
74
 
        self.a = a
75
 
        self.b = b
76
 
        self.is_cherrypick = is_cherrypick
77
 
 
78
 
    def merge_lines(self,
79
 
                    name_a=None,
80
 
                    name_b=None,
81
 
                    name_base=None,
82
 
                    start_marker='<<<<<<<',
83
 
                    mid_marker='=======',
84
 
                    end_marker='>>>>>>>',
85
 
                    base_marker=None,
86
 
                    reprocess=False):
87
 
        """Return merge in cvs-like form.
88
 
        """
89
 
        newline = '\n'
90
 
        if len(self.a) > 0:
91
 
            if self.a[0].endswith('\r\n'):
92
 
                newline = '\r\n'
93
 
            elif self.a[0].endswith('\r'):
94
 
                newline = '\r'
95
 
        if base_marker and reprocess:
96
 
            raise CantReprocessAndShowBase()
97
 
        if name_a:
98
 
            start_marker = start_marker + ' ' + name_a
99
 
        if name_b:
100
 
            end_marker = end_marker + ' ' + name_b
101
 
        if name_base and base_marker:
102
 
            base_marker = base_marker + ' ' + name_base
103
 
        merge_regions = self.merge_regions()
104
 
        if reprocess is True:
105
 
            merge_regions = self.reprocess_merge_regions(merge_regions)
106
 
        for t in merge_regions:
107
 
            what = t[0]
108
 
            if what == 'unchanged':
109
 
                for i in range(t[1], t[2]):
110
 
                    yield self.base[i]
111
 
            elif what == 'a' or what == 'same':
112
 
                for i in range(t[1], t[2]):
113
 
                    yield self.a[i]
114
 
            elif what == 'b':
115
 
                for i in range(t[1], t[2]):
116
 
                    yield self.b[i]
117
 
            elif what == 'conflict':
118
 
                yield start_marker + newline
119
 
                for i in range(t[3], t[4]):
120
 
                    yield self.a[i]
121
 
                if base_marker is not None:
122
 
                    yield base_marker + newline
123
 
                    for i in range(t[1], t[2]):
124
 
                        yield self.base[i]
125
 
                yield mid_marker + newline
126
 
                for i in range(t[5], t[6]):
127
 
                    yield self.b[i]
128
 
                yield end_marker + newline
129
 
            else:
130
 
                raise ValueError(what)
131
 
 
132
 
    def merge_annotated(self):
133
 
        """Return merge with conflicts, showing origin of lines.
134
 
 
135
 
        Most useful for debugging merge.
136
 
        """
137
 
        for t in self.merge_regions():
138
 
            what = t[0]
139
 
            if what == 'unchanged':
140
 
                for i in range(t[1], t[2]):
141
 
                    yield 'u | ' + self.base[i]
142
 
            elif what == 'a' or what == 'same':
143
 
                for i in range(t[1], t[2]):
144
 
                    yield what[0] + ' | ' + self.a[i]
145
 
            elif what == 'b':
146
 
                for i in range(t[1], t[2]):
147
 
                    yield 'b | ' + self.b[i]
148
 
            elif what == 'conflict':
149
 
                yield '<<<<\n'
150
 
                for i in range(t[3], t[4]):
151
 
                    yield 'A | ' + self.a[i]
152
 
                yield '----\n'
153
 
                for i in range(t[5], t[6]):
154
 
                    yield 'B | ' + self.b[i]
155
 
                yield '>>>>\n'
156
 
            else:
157
 
                raise ValueError(what)
158
 
 
159
 
    def merge_groups(self):
160
 
        """Yield sequence of line groups.  Each one is a tuple:
161
 
 
162
 
        'unchanged', lines
163
 
             Lines unchanged from base
164
 
 
165
 
        'a', lines
166
 
             Lines taken from a
167
 
 
168
 
        'same', lines
169
 
             Lines taken from a (and equal to b)
170
 
 
171
 
        'b', lines
172
 
             Lines taken from b
173
 
 
174
 
        'conflict', base_lines, a_lines, b_lines
175
 
             Lines from base were changed to either a or b and conflict.
176
 
        """
177
 
        for t in self.merge_regions():
178
 
            what = t[0]
179
 
            if what == 'unchanged':
180
 
                yield what, self.base[t[1]:t[2]]
181
 
            elif what == 'a' or what == 'same':
182
 
                yield what, self.a[t[1]:t[2]]
183
 
            elif what == 'b':
184
 
                yield what, self.b[t[1]:t[2]]
185
 
            elif what == 'conflict':
186
 
                yield (what,
187
 
                       self.base[t[1]:t[2]],
188
 
                       self.a[t[3]:t[4]],
189
 
                       self.b[t[5]:t[6]])
190
 
            else:
191
 
                raise ValueError(what)
192
 
 
193
 
    def merge_regions(self):
194
 
        """Return sequences of matching and conflicting regions.
195
 
 
196
 
        This returns tuples, where the first value says what kind we
197
 
        have:
198
 
 
199
 
        'unchanged', start, end
200
 
             Take a region of base[start:end]
201
 
 
202
 
        'same', astart, aend
203
 
             b and a are different from base but give the same result
204
 
 
205
 
        'a', start, end
206
 
             Non-clashing insertion from a[start:end]
207
 
 
208
 
        Method is as follows:
209
 
 
210
 
        The two sequences align only on regions which match the base
211
 
        and both descendents.  These are found by doing a two-way diff
212
 
        of each one against the base, and then finding the
213
 
        intersections between those regions.  These "sync regions"
214
 
        are by definition unchanged in both and easily dealt with.
215
 
 
216
 
        The regions in between can be in any of three cases:
217
 
        conflicted, or changed on only one side.
218
 
        """
219
 
 
220
 
        # section a[0:ia] has been disposed of, etc
221
 
        iz = ia = ib = 0
222
 
 
223
 
        for zmatch, zend, amatch, aend, bmatch, bend in self.find_sync_regions():
224
 
            matchlen = zend - zmatch
225
 
            # invariants:
226
 
            #   matchlen >= 0
227
 
            #   matchlen == (aend - amatch)
228
 
            #   matchlen == (bend - bmatch)
229
 
            len_a = amatch - ia
230
 
            len_b = bmatch - ib
231
 
            len_base = zmatch - iz
232
 
            # invariants:
233
 
            # assert len_a >= 0
234
 
            # assert len_b >= 0
235
 
            # assert len_base >= 0
236
 
 
237
 
            #print 'unmatched a=%d, b=%d' % (len_a, len_b)
238
 
 
239
 
            if len_a or len_b:
240
 
                # try to avoid actually slicing the lists
241
 
                same = compare_range(self.a, ia, amatch,
242
 
                                     self.b, ib, bmatch)
243
 
 
244
 
                if same:
245
 
                    yield 'same', ia, amatch
246
 
                else:
247
 
                    equal_a = compare_range(self.a, ia, amatch,
248
 
                                            self.base, iz, zmatch)
249
 
                    equal_b = compare_range(self.b, ib, bmatch,
250
 
                                            self.base, iz, zmatch)
251
 
                    if equal_a and not equal_b:
252
 
                        yield 'b', ib, bmatch
253
 
                    elif equal_b and not equal_a:
254
 
                        yield 'a', ia, amatch
255
 
                    elif not equal_a and not equal_b:
256
 
                        if self.is_cherrypick:
257
 
                            for node in self._refine_cherrypick_conflict(
258
 
                                                    iz, zmatch, ia, amatch,
259
 
                                                    ib, bmatch):
260
 
                                yield node
261
 
                        else:
262
 
                            yield 'conflict', iz, zmatch, ia, amatch, ib, bmatch
263
 
                    else:
264
 
                        raise AssertionError("can't handle a=b=base but unmatched")
265
 
 
266
 
                ia = amatch
267
 
                ib = bmatch
268
 
            iz = zmatch
269
 
 
270
 
            # if the same part of the base was deleted on both sides
271
 
            # that's OK, we can just skip it.
272
 
 
273
 
            if matchlen > 0:
274
 
                # invariants:
275
 
                # assert ia == amatch
276
 
                # assert ib == bmatch
277
 
                # assert iz == zmatch
278
 
 
279
 
                yield 'unchanged', zmatch, zend
280
 
                iz = zend
281
 
                ia = aend
282
 
                ib = bend
283
 
 
284
 
    def _refine_cherrypick_conflict(self, zstart, zend, astart, aend, bstart, bend):
285
 
        """When cherrypicking b => a, ignore matches with b and base."""
286
 
        # Do not emit regions which match, only regions which do not match
287
 
        matches = bzrlib.patiencediff.PatienceSequenceMatcher(None,
288
 
            self.base[zstart:zend], self.b[bstart:bend]).get_matching_blocks()
289
 
        last_base_idx = 0
290
 
        last_b_idx = 0
291
 
        last_b_idx = 0
292
 
        yielded_a = False
293
 
        for base_idx, b_idx, match_len in matches:
294
 
            conflict_z_len = base_idx - last_base_idx
295
 
            conflict_b_len = b_idx - last_b_idx
296
 
            if conflict_b_len == 0: # There are no lines in b which conflict,
297
 
                                    # so skip it
298
 
                pass
299
 
            else:
300
 
                if yielded_a:
301
 
                    yield ('conflict',
302
 
                           zstart + last_base_idx, zstart + base_idx,
303
 
                           aend, aend, bstart + last_b_idx, bstart + b_idx)
304
 
                else:
305
 
                    # The first conflict gets the a-range
306
 
                    yielded_a = True
307
 
                    yield ('conflict', zstart + last_base_idx, zstart +
308
 
                    base_idx,
309
 
                           astart, aend, bstart + last_b_idx, bstart + b_idx)
310
 
            last_base_idx = base_idx + match_len
311
 
            last_b_idx = b_idx + match_len
312
 
        if last_base_idx != zend - zstart or last_b_idx != bend - bstart:
313
 
            if yielded_a:
314
 
                yield ('conflict', zstart + last_base_idx, zstart + base_idx,
315
 
                       aend, aend, bstart + last_b_idx, bstart + b_idx)
316
 
            else:
317
 
                # The first conflict gets the a-range
318
 
                yielded_a = True
319
 
                yield ('conflict', zstart + last_base_idx, zstart + base_idx,
320
 
                       astart, aend, bstart + last_b_idx, bstart + b_idx)
321
 
        if not yielded_a:
322
 
            yield ('conflict', zstart, zend, astart, aend, bstart, bend)
323
 
 
324
 
    def reprocess_merge_regions(self, merge_regions):
325
 
        """Where there are conflict regions, remove the agreed lines.
326
 
 
327
 
        Lines where both A and B have made the same changes are
328
 
        eliminated.
329
 
        """
330
 
        for region in merge_regions:
331
 
            if region[0] != "conflict":
332
 
                yield region
333
 
                continue
334
 
            type, iz, zmatch, ia, amatch, ib, bmatch = region
335
 
            a_region = self.a[ia:amatch]
336
 
            b_region = self.b[ib:bmatch]
337
 
            matches = bzrlib.patiencediff.PatienceSequenceMatcher(
338
 
                    None, a_region, b_region).get_matching_blocks()
339
 
            next_a = ia
340
 
            next_b = ib
341
 
            for region_ia, region_ib, region_len in matches[:-1]:
342
 
                region_ia += ia
343
 
                region_ib += ib
344
 
                reg = self.mismatch_region(next_a, region_ia, next_b,
345
 
                                           region_ib)
346
 
                if reg is not None:
347
 
                    yield reg
348
 
                yield 'same', region_ia, region_len+region_ia
349
 
                next_a = region_ia + region_len
350
 
                next_b = region_ib + region_len
351
 
            reg = self.mismatch_region(next_a, amatch, next_b, bmatch)
352
 
            if reg is not None:
353
 
                yield reg
354
 
 
355
 
    @staticmethod
356
 
    def mismatch_region(next_a, region_ia,  next_b, region_ib):
357
 
        if next_a < region_ia or next_b < region_ib:
358
 
            return 'conflict', None, None, next_a, region_ia, next_b, region_ib
359
 
 
360
 
    def find_sync_regions(self):
361
 
        """Return a list of sync regions, where both descendents match the base.
362
 
 
363
 
        Generates a list of (base1, base2, a1, a2, b1, b2).  There is
364
 
        always a zero-length sync region at the end of all the files.
365
 
        """
366
 
 
367
 
        ia = ib = 0
368
 
        amatches = bzrlib.patiencediff.PatienceSequenceMatcher(
369
 
                None, self.base, self.a).get_matching_blocks()
370
 
        bmatches = bzrlib.patiencediff.PatienceSequenceMatcher(
371
 
                None, self.base, self.b).get_matching_blocks()
372
 
        len_a = len(amatches)
373
 
        len_b = len(bmatches)
374
 
 
375
 
        sl = []
376
 
 
377
 
        while ia < len_a and ib < len_b:
378
 
            abase, amatch, alen = amatches[ia]
379
 
            bbase, bmatch, blen = bmatches[ib]
380
 
 
381
 
            # there is an unconflicted block at i; how long does it
382
 
            # extend?  until whichever one ends earlier.
383
 
            i = intersect((abase, abase+alen), (bbase, bbase+blen))
384
 
            if i:
385
 
                intbase = i[0]
386
 
                intend = i[1]
387
 
                intlen = intend - intbase
388
 
 
389
 
                # found a match of base[i[0], i[1]]; this may be less than
390
 
                # the region that matches in either one
391
 
                # assert intlen <= alen
392
 
                # assert intlen <= blen
393
 
                # assert abase <= intbase
394
 
                # assert bbase <= intbase
395
 
 
396
 
                asub = amatch + (intbase - abase)
397
 
                bsub = bmatch + (intbase - bbase)
398
 
                aend = asub + intlen
399
 
                bend = bsub + intlen
400
 
 
401
 
                # assert self.base[intbase:intend] == self.a[asub:aend], \
402
 
                #       (self.base[intbase:intend], self.a[asub:aend])
403
 
                # assert self.base[intbase:intend] == self.b[bsub:bend]
404
 
 
405
 
                sl.append((intbase, intend,
406
 
                           asub, aend,
407
 
                           bsub, bend))
408
 
            # advance whichever one ends first in the base text
409
 
            if (abase + alen) < (bbase + blen):
410
 
                ia += 1
411
 
            else:
412
 
                ib += 1
413
 
 
414
 
        intbase = len(self.base)
415
 
        abase = len(self.a)
416
 
        bbase = len(self.b)
417
 
        sl.append((intbase, intbase, abase, abase, bbase, bbase))
418
 
 
419
 
        return sl
420
 
 
421
 
    def find_unconflicted(self):
422
 
        """Return a list of ranges in base that are not conflicted."""
423
 
        am = bzrlib.patiencediff.PatienceSequenceMatcher(
424
 
                None, self.base, self.a).get_matching_blocks()
425
 
        bm = bzrlib.patiencediff.PatienceSequenceMatcher(
426
 
                None, self.base, self.b).get_matching_blocks()
427
 
 
428
 
        unc = []
429
 
 
430
 
        while am and bm:
431
 
            # there is an unconflicted block at i; how long does it
432
 
            # extend?  until whichever one ends earlier.
433
 
            a1 = am[0][0]
434
 
            a2 = a1 + am[0][2]
435
 
            b1 = bm[0][0]
436
 
            b2 = b1 + bm[0][2]
437
 
            i = intersect((a1, a2), (b1, b2))
438
 
            if i:
439
 
                unc.append(i)
440
 
 
441
 
            if a2 < b2:
442
 
                del am[0]
443
 
            else:
444
 
                del bm[0]
445
 
 
446
 
        return unc
447
 
 
448
 
 
449
 
def main(argv):
450
 
    # as for diff3 and meld the syntax is "MINE BASE OTHER"
451
 
    a = file(argv[1], 'rt').readlines()
452
 
    base = file(argv[2], 'rt').readlines()
453
 
    b = file(argv[3], 'rt').readlines()
454
 
 
455
 
    m3 = Merge3(base, a, b)
456
 
 
457
 
    #for sr in m3.find_sync_regions():
458
 
    #    print sr
459
 
 
460
 
    # sys.stdout.writelines(m3.merge_lines(name_a=argv[1], name_b=argv[3]))
461
 
    sys.stdout.writelines(m3.merge_annotated())
462
 
 
463
 
 
464
 
if __name__ == '__main__':
465
 
    import sys
466
 
    sys.exit(main(sys.argv))