~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/merge3.py

  • Committer: Canonical.com Patch Queue Manager
  • Date: 2006-05-21 23:35:20 UTC
  • mfrom: (1721.1.1 integration)
  • Revision ID: pqm@pqm.ubuntu.com-20060521233520-7f8a0248d93bde80
Some tests for bzr ignored. (Robert Collins).

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# Copyright (C) 2004, 2005 by Canonical Ltd
 
2
 
 
3
# This program is free software; you can redistribute it and/or modify
 
4
# it under the terms of the GNU General Public License as published by
 
5
# the Free Software Foundation; either version 2 of the License, or
 
6
# (at your option) any later version.
 
7
 
 
8
# This program is distributed in the hope that it will be useful,
 
9
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
10
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
11
# GNU General Public License for more details.
 
12
 
 
13
# You should have received a copy of the GNU General Public License
 
14
# along with this program; if not, write to the Free Software
 
15
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 
16
 
 
17
 
 
18
# mbp: "you know that thing where cvs gives you conflict markers?"
 
19
# s: "i hate that."
 
20
 
 
21
 
 
22
from difflib import SequenceMatcher
 
23
 
 
24
from bzrlib.errors import CantReprocessAndShowBase
 
25
from bzrlib.textfile import check_text_lines
 
26
 
 
27
def intersect(ra, rb):
 
28
    """Given two ranges return the range where they intersect or None.
 
29
 
 
30
    >>> intersect((0, 10), (0, 6))
 
31
    (0, 6)
 
32
    >>> intersect((0, 10), (5, 15))
 
33
    (5, 10)
 
34
    >>> intersect((0, 10), (10, 15))
 
35
    >>> intersect((0, 9), (10, 15))
 
36
    >>> intersect((0, 9), (7, 15))
 
37
    (7, 9)
 
38
    """
 
39
    assert ra[0] <= ra[1]
 
40
    assert rb[0] <= rb[1]
 
41
    
 
42
    sa = max(ra[0], rb[0])
 
43
    sb = min(ra[1], rb[1])
 
44
    if sa < sb:
 
45
        return sa, sb
 
46
    else:
 
47
        return None
 
48
 
 
49
 
 
50
def compare_range(a, astart, aend, b, bstart, bend):
 
51
    """Compare a[astart:aend] == b[bstart:bend], without slicing.
 
52
    """
 
53
    if (aend-astart) != (bend-bstart):
 
54
        return False
 
55
    for ia, ib in zip(xrange(astart, aend), xrange(bstart, bend)):
 
56
        if a[ia] != b[ib]:
 
57
            return False
 
58
    else:
 
59
        return True
 
60
        
 
61
 
 
62
 
 
63
 
 
64
class Merge3(object):
 
65
    """3-way merge of texts.
 
66
 
 
67
    Given BASE, OTHER, THIS, tries to produce a combined text
 
68
    incorporating the changes from both BASE->OTHER and BASE->THIS.
 
69
    All three will typically be sequences of lines."""
 
70
    def __init__(self, base, a, b):
 
71
        check_text_lines(base)
 
72
        check_text_lines(a)
 
73
        check_text_lines(b)
 
74
        self.base = base
 
75
        self.a = a
 
76
        self.b = b
 
77
 
 
78
 
 
79
 
 
80
    def merge_lines(self,
 
81
                    name_a=None,
 
82
                    name_b=None,
 
83
                    name_base=None,
 
84
                    start_marker='<<<<<<<',
 
85
                    mid_marker='=======',
 
86
                    end_marker='>>>>>>>',
 
87
                    base_marker=None,
 
88
                    reprocess=False):
 
89
        """Return merge in cvs-like form.
 
90
        """
 
91
        if base_marker and reprocess:
 
92
            raise CantReprocessAndShowBase()
 
93
        if name_a:
 
94
            start_marker = start_marker + ' ' + name_a
 
95
        if name_b:
 
96
            end_marker = end_marker + ' ' + name_b
 
97
        if name_base and base_marker:
 
98
            base_marker = base_marker + ' ' + name_base
 
99
        merge_regions = self.merge_regions()
 
100
        if reprocess is True:
 
101
            merge_regions = self.reprocess_merge_regions(merge_regions)
 
102
        for t in merge_regions:
 
103
            what = t[0]
 
104
            if what == 'unchanged':
 
105
                for i in range(t[1], t[2]):
 
106
                    yield self.base[i]
 
107
            elif what == 'a' or what == 'same':
 
108
                for i in range(t[1], t[2]):
 
109
                    yield self.a[i]
 
110
            elif what == 'b':
 
111
                for i in range(t[1], t[2]):
 
112
                    yield self.b[i]
 
113
            elif what == 'conflict':
 
114
                yield start_marker + '\n'
 
115
                for i in range(t[3], t[4]):
 
116
                    yield self.a[i]
 
117
                if base_marker is not None:
 
118
                    yield base_marker + '\n'
 
119
                    for i in range(t[1], t[2]):
 
120
                        yield self.base[i]
 
121
                yield mid_marker + '\n'
 
122
                for i in range(t[5], t[6]):
 
123
                    yield self.b[i]
 
124
                yield end_marker + '\n'
 
125
            else:
 
126
                raise ValueError(what)
 
127
        
 
128
        
 
129
 
 
130
 
 
131
 
 
132
    def merge_annotated(self):
 
133
        """Return merge with conflicts, showing origin of lines.
 
134
 
 
135
        Most useful for debugging merge.        
 
136
        """
 
137
        for t in self.merge_regions():
 
138
            what = t[0]
 
139
            if what == 'unchanged':
 
140
                for i in range(t[1], t[2]):
 
141
                    yield 'u | ' + self.base[i]
 
142
            elif what == 'a' or what == 'same':
 
143
                for i in range(t[1], t[2]):
 
144
                    yield what[0] + ' | ' + self.a[i]
 
145
            elif what == 'b':
 
146
                for i in range(t[1], t[2]):
 
147
                    yield 'b | ' + self.b[i]
 
148
            elif what == 'conflict':
 
149
                yield '<<<<\n'
 
150
                for i in range(t[3], t[4]):
 
151
                    yield 'A | ' + self.a[i]
 
152
                yield '----\n'
 
153
                for i in range(t[5], t[6]):
 
154
                    yield 'B | ' + self.b[i]
 
155
                yield '>>>>\n'
 
156
            else:
 
157
                raise ValueError(what)
 
158
        
 
159
        
 
160
 
 
161
 
 
162
 
 
163
    def merge_groups(self):
 
164
        """Yield sequence of line groups.  Each one is a tuple:
 
165
 
 
166
        'unchanged', lines
 
167
             Lines unchanged from base
 
168
 
 
169
        'a', lines
 
170
             Lines taken from a
 
171
 
 
172
        'same', lines
 
173
             Lines taken from a (and equal to b)
 
174
 
 
175
        'b', lines
 
176
             Lines taken from b
 
177
 
 
178
        'conflict', base_lines, a_lines, b_lines
 
179
             Lines from base were changed to either a or b and conflict.
 
180
        """
 
181
        for t in self.merge_regions():
 
182
            what = t[0]
 
183
            if what == 'unchanged':
 
184
                yield what, self.base[t[1]:t[2]]
 
185
            elif what == 'a' or what == 'same':
 
186
                yield what, self.a[t[1]:t[2]]
 
187
            elif what == 'b':
 
188
                yield what, self.b[t[1]:t[2]]
 
189
            elif what == 'conflict':
 
190
                yield (what,
 
191
                       self.base[t[1]:t[2]],
 
192
                       self.a[t[3]:t[4]],
 
193
                       self.b[t[5]:t[6]])
 
194
            else:
 
195
                raise ValueError(what)
 
196
 
 
197
 
 
198
    def merge_regions(self):
 
199
        """Return sequences of matching and conflicting regions.
 
200
 
 
201
        This returns tuples, where the first value says what kind we
 
202
        have:
 
203
 
 
204
        'unchanged', start, end
 
205
             Take a region of base[start:end]
 
206
 
 
207
        'same', astart, aend
 
208
             b and a are different from base but give the same result
 
209
 
 
210
        'a', start, end
 
211
             Non-clashing insertion from a[start:end]
 
212
 
 
213
        Method is as follows:
 
214
 
 
215
        The two sequences align only on regions which match the base
 
216
        and both descendents.  These are found by doing a two-way diff
 
217
        of each one against the base, and then finding the
 
218
        intersections between those regions.  These "sync regions"
 
219
        are by definition unchanged in both and easily dealt with.
 
220
 
 
221
        The regions in between can be in any of three cases:
 
222
        conflicted, or changed on only one side.
 
223
        """
 
224
 
 
225
        # section a[0:ia] has been disposed of, etc
 
226
        iz = ia = ib = 0
 
227
        
 
228
        for zmatch, zend, amatch, aend, bmatch, bend in self.find_sync_regions():
 
229
            #print 'match base [%d:%d]' % (zmatch, zend)
 
230
            
 
231
            matchlen = zend - zmatch
 
232
            assert matchlen >= 0
 
233
            assert matchlen == (aend - amatch)
 
234
            assert matchlen == (bend - bmatch)
 
235
            
 
236
            len_a = amatch - ia
 
237
            len_b = bmatch - ib
 
238
            len_base = zmatch - iz
 
239
            assert len_a >= 0
 
240
            assert len_b >= 0
 
241
            assert len_base >= 0
 
242
 
 
243
            #print 'unmatched a=%d, b=%d' % (len_a, len_b)
 
244
 
 
245
            if len_a or len_b:
 
246
                # try to avoid actually slicing the lists
 
247
                equal_a = compare_range(self.a, ia, amatch,
 
248
                                        self.base, iz, zmatch)
 
249
                equal_b = compare_range(self.b, ib, bmatch,
 
250
                                        self.base, iz, zmatch)
 
251
                same = compare_range(self.a, ia, amatch,
 
252
                                     self.b, ib, bmatch)
 
253
 
 
254
                if same:
 
255
                    yield 'same', ia, amatch
 
256
                elif equal_a and not equal_b:
 
257
                    yield 'b', ib, bmatch
 
258
                elif equal_b and not equal_a:
 
259
                    yield 'a', ia, amatch
 
260
                elif not equal_a and not equal_b:
 
261
                    yield 'conflict', iz, zmatch, ia, amatch, ib, bmatch
 
262
                else:
 
263
                    raise AssertionError("can't handle a=b=base but unmatched")
 
264
 
 
265
                ia = amatch
 
266
                ib = bmatch
 
267
            iz = zmatch
 
268
 
 
269
            # if the same part of the base was deleted on both sides
 
270
            # that's OK, we can just skip it.
 
271
 
 
272
                
 
273
            if matchlen > 0:
 
274
                assert ia == amatch
 
275
                assert ib == bmatch
 
276
                assert iz == zmatch
 
277
                
 
278
                yield 'unchanged', zmatch, zend
 
279
                iz = zend
 
280
                ia = aend
 
281
                ib = bend
 
282
    
 
283
 
 
284
    def reprocess_merge_regions(self, merge_regions):
 
285
        """Where there are conflict regions, remove the agreed lines.
 
286
 
 
287
        Lines where both A and B have made the same changes are 
 
288
        eliminated.
 
289
        """
 
290
        for region in merge_regions:
 
291
            if region[0] != "conflict":
 
292
                yield region
 
293
                continue
 
294
            type, iz, zmatch, ia, amatch, ib, bmatch = region
 
295
            a_region = self.a[ia:amatch]
 
296
            b_region = self.b[ib:bmatch]
 
297
            matches = SequenceMatcher(None, a_region, 
 
298
                                      b_region).get_matching_blocks()
 
299
            next_a = ia
 
300
            next_b = ib
 
301
            for region_ia, region_ib, region_len in matches[:-1]:
 
302
                region_ia += ia
 
303
                region_ib += ib
 
304
                reg = self.mismatch_region(next_a, region_ia, next_b,
 
305
                                           region_ib)
 
306
                if reg is not None:
 
307
                    yield reg
 
308
                yield 'same', region_ia, region_len+region_ia
 
309
                next_a = region_ia + region_len
 
310
                next_b = region_ib + region_len
 
311
            reg = self.mismatch_region(next_a, amatch, next_b, bmatch)
 
312
            if reg is not None:
 
313
                yield reg
 
314
 
 
315
 
 
316
    @staticmethod
 
317
    def mismatch_region(next_a, region_ia,  next_b, region_ib):
 
318
        if next_a < region_ia or next_b < region_ib:
 
319
            return 'conflict', None, None, next_a, region_ia, next_b, region_ib
 
320
            
 
321
 
 
322
    def find_sync_regions(self):
 
323
        """Return a list of sync regions, where both descendents match the base.
 
324
 
 
325
        Generates a list of (base1, base2, a1, a2, b1, b2).  There is
 
326
        always a zero-length sync region at the end of all the files.
 
327
        """
 
328
 
 
329
        ia = ib = 0
 
330
        amatches = SequenceMatcher(None, self.base, self.a).get_matching_blocks()
 
331
        bmatches = SequenceMatcher(None, self.base, self.b).get_matching_blocks()
 
332
        len_a = len(amatches)
 
333
        len_b = len(bmatches)
 
334
 
 
335
        sl = []
 
336
 
 
337
        while ia < len_a and ib < len_b:
 
338
            abase, amatch, alen = amatches[ia]
 
339
            bbase, bmatch, blen = bmatches[ib]
 
340
 
 
341
            # there is an unconflicted block at i; how long does it
 
342
            # extend?  until whichever one ends earlier.
 
343
            i = intersect((abase, abase+alen), (bbase, bbase+blen))
 
344
            if i:
 
345
                intbase = i[0]
 
346
                intend = i[1]
 
347
                intlen = intend - intbase
 
348
 
 
349
                # found a match of base[i[0], i[1]]; this may be less than
 
350
                # the region that matches in either one
 
351
                assert intlen <= alen
 
352
                assert intlen <= blen
 
353
                assert abase <= intbase
 
354
                assert bbase <= intbase
 
355
 
 
356
                asub = amatch + (intbase - abase)
 
357
                bsub = bmatch + (intbase - bbase)
 
358
                aend = asub + intlen
 
359
                bend = bsub + intlen
 
360
 
 
361
                assert self.base[intbase:intend] == self.a[asub:aend], \
 
362
                       (self.base[intbase:intend], self.a[asub:aend])
 
363
 
 
364
                assert self.base[intbase:intend] == self.b[bsub:bend]
 
365
 
 
366
                sl.append((intbase, intend,
 
367
                           asub, aend,
 
368
                           bsub, bend))
 
369
 
 
370
            # advance whichever one ends first in the base text
 
371
            if (abase + alen) < (bbase + blen):
 
372
                ia += 1
 
373
            else:
 
374
                ib += 1
 
375
            
 
376
        intbase = len(self.base)
 
377
        abase = len(self.a)
 
378
        bbase = len(self.b)
 
379
        sl.append((intbase, intbase, abase, abase, bbase, bbase))
 
380
 
 
381
        return sl
 
382
 
 
383
 
 
384
 
 
385
    def find_unconflicted(self):
 
386
        """Return a list of ranges in base that are not conflicted."""
 
387
 
 
388
        import re
 
389
 
 
390
        # don't sync-up on lines containing only blanks or pounds
 
391
        junk_re = re.compile(r'^[ \t#]*$')
 
392
        
 
393
        am = SequenceMatcher(junk_re.match, self.base, self.a).get_matching_blocks()
 
394
        bm = SequenceMatcher(junk_re.match, self.base, self.b).get_matching_blocks()
 
395
 
 
396
        unc = []
 
397
 
 
398
        while am and bm:
 
399
            # there is an unconflicted block at i; how long does it
 
400
            # extend?  until whichever one ends earlier.
 
401
            a1 = am[0][0]
 
402
            a2 = a1 + am[0][2]
 
403
            b1 = bm[0][0]
 
404
            b2 = b1 + bm[0][2]
 
405
            i = intersect((a1, a2), (b1, b2))
 
406
            if i:
 
407
                unc.append(i)
 
408
 
 
409
            if a2 < b2:
 
410
                del am[0]
 
411
            else:
 
412
                del bm[0]
 
413
                
 
414
        return unc
 
415
 
 
416
 
 
417
def main(argv):
 
418
    # as for diff3 and meld the syntax is "MINE BASE OTHER"
 
419
    a = file(argv[1], 'rt').readlines()
 
420
    base = file(argv[2], 'rt').readlines()
 
421
    b = file(argv[3], 'rt').readlines()
 
422
 
 
423
    m3 = Merge3(base, a, b)
 
424
 
 
425
    #for sr in m3.find_sync_regions():
 
426
    #    print sr
 
427
 
 
428
    # sys.stdout.writelines(m3.merge_lines(name_a=argv[1], name_b=argv[3]))
 
429
    sys.stdout.writelines(m3.merge_annotated())
 
430
 
 
431
 
 
432
if __name__ == '__main__':
 
433
    import sys
 
434
    sys.exit(main(sys.argv))