~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/merge3.py

Forgot to add the test case

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# Copyright (C) 2004, 2005 by Canonical Ltd
 
2
 
 
3
# This program is free software; you can redistribute it and/or modify
 
4
# it under the terms of the GNU General Public License as published by
 
5
# the Free Software Foundation; either version 2 of the License, or
 
6
# (at your option) any later version.
 
7
 
 
8
# This program is distributed in the hope that it will be useful,
 
9
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
10
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
11
# GNU General Public License for more details.
 
12
 
 
13
# You should have received a copy of the GNU General Public License
 
14
# along with this program; if not, write to the Free Software
 
15
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 
16
 
 
17
 
 
18
# mbp: "you know that thing where cvs gives you conflict markers?"
 
19
# s: "i hate that."
 
20
 
 
21
 
 
22
from difflib import SequenceMatcher
 
23
from bzrlib.errors import CantReprocessAndShowBase
 
24
 
 
25
def intersect(ra, rb):
 
26
    """Given two ranges return the range where they intersect or None.
 
27
 
 
28
    >>> intersect((0, 10), (0, 6))
 
29
    (0, 6)
 
30
    >>> intersect((0, 10), (5, 15))
 
31
    (5, 10)
 
32
    >>> intersect((0, 10), (10, 15))
 
33
    >>> intersect((0, 9), (10, 15))
 
34
    >>> intersect((0, 9), (7, 15))
 
35
    (7, 9)
 
36
    """
 
37
    assert ra[0] <= ra[1]
 
38
    assert rb[0] <= rb[1]
 
39
    
 
40
    sa = max(ra[0], rb[0])
 
41
    sb = min(ra[1], rb[1])
 
42
    if sa < sb:
 
43
        return sa, sb
 
44
    else:
 
45
        return None
 
46
 
 
47
 
 
48
def compare_range(a, astart, aend, b, bstart, bend):
 
49
    """Compare a[astart:aend] == b[bstart:bend], without slicing.
 
50
    """
 
51
    if (aend-astart) != (bend-bstart):
 
52
        return False
 
53
    for ia, ib in zip(xrange(astart, aend), xrange(bstart, bend)):
 
54
        if a[ia] != b[ib]:
 
55
            return False
 
56
    else:
 
57
        return True
 
58
        
 
59
 
 
60
 
 
61
 
 
62
class Merge3(object):
 
63
    """3-way merge of texts.
 
64
 
 
65
    Given BASE, OTHER, THIS, tries to produce a combined text
 
66
    incorporating the changes from both BASE->OTHER and BASE->THIS.
 
67
    All three will typically be sequences of lines."""
 
68
    def __init__(self, base, a, b):
 
69
        self.base = base
 
70
        self.a = a
 
71
        self.b = b
 
72
 
 
73
 
 
74
 
 
75
    def merge_lines(self,
 
76
                    name_a=None,
 
77
                    name_b=None,
 
78
                    name_base=None,
 
79
                    start_marker='<<<<<<<',
 
80
                    mid_marker='=======',
 
81
                    end_marker='>>>>>>>',
 
82
                    base_marker=None,
 
83
                    reprocess=False):
 
84
        """Return merge in cvs-like form.
 
85
        """
 
86
        if base_marker and reprocess:
 
87
            raise CantReprocessAndShowBase()
 
88
        if name_a:
 
89
            start_marker = start_marker + ' ' + name_a
 
90
        if name_b:
 
91
            end_marker = end_marker + ' ' + name_b
 
92
        if name_base and base_marker:
 
93
            base_marker = base_marker + ' ' + name_base
 
94
        merge_regions = self.merge_regions()
 
95
        if reprocess is True:
 
96
            merge_regions = self.reprocess_merge_regions(merge_regions)
 
97
        for t in merge_regions:
 
98
            what = t[0]
 
99
            if what == 'unchanged':
 
100
                for i in range(t[1], t[2]):
 
101
                    yield self.base[i]
 
102
            elif what == 'a' or what == 'same':
 
103
                for i in range(t[1], t[2]):
 
104
                    yield self.a[i]
 
105
            elif what == 'b':
 
106
                for i in range(t[1], t[2]):
 
107
                    yield self.b[i]
 
108
            elif what == 'conflict':
 
109
                yield start_marker + '\n'
 
110
                for i in range(t[3], t[4]):
 
111
                    yield self.a[i]
 
112
                if base_marker is not None:
 
113
                    yield base_marker + '\n'
 
114
                    for i in range(t[1], t[2]):
 
115
                        yield self.base[i]
 
116
                yield mid_marker + '\n'
 
117
                for i in range(t[5], t[6]):
 
118
                    yield self.b[i]
 
119
                yield end_marker + '\n'
 
120
            else:
 
121
                raise ValueError(what)
 
122
        
 
123
        
 
124
 
 
125
 
 
126
 
 
127
    def merge_annotated(self):
 
128
        """Return merge with conflicts, showing origin of lines.
 
129
 
 
130
        Most useful for debugging merge.        
 
131
        """
 
132
        for t in self.merge_regions():
 
133
            what = t[0]
 
134
            if what == 'unchanged':
 
135
                for i in range(t[1], t[2]):
 
136
                    yield 'u | ' + self.base[i]
 
137
            elif what == 'a' or what == 'same':
 
138
                for i in range(t[1], t[2]):
 
139
                    yield what[0] + ' | ' + self.a[i]
 
140
            elif what == 'b':
 
141
                for i in range(t[1], t[2]):
 
142
                    yield 'b | ' + self.b[i]
 
143
            elif what == 'conflict':
 
144
                yield '<<<<\n'
 
145
                for i in range(t[3], t[4]):
 
146
                    yield 'A | ' + self.a[i]
 
147
                yield '----\n'
 
148
                for i in range(t[5], t[6]):
 
149
                    yield 'B | ' + self.b[i]
 
150
                yield '>>>>\n'
 
151
            else:
 
152
                raise ValueError(what)
 
153
        
 
154
        
 
155
 
 
156
 
 
157
 
 
158
    def merge_groups(self):
 
159
        """Yield sequence of line groups.  Each one is a tuple:
 
160
 
 
161
        'unchanged', lines
 
162
             Lines unchanged from base
 
163
 
 
164
        'a', lines
 
165
             Lines taken from a
 
166
 
 
167
        'same', lines
 
168
             Lines taken from a (and equal to b)
 
169
 
 
170
        'b', lines
 
171
             Lines taken from b
 
172
 
 
173
        'conflict', base_lines, a_lines, b_lines
 
174
             Lines from base were changed to either a or b and conflict.
 
175
        """
 
176
        for t in self.merge_regions():
 
177
            what = t[0]
 
178
            if what == 'unchanged':
 
179
                yield what, self.base[t[1]:t[2]]
 
180
            elif what == 'a' or what == 'same':
 
181
                yield what, self.a[t[1]:t[2]]
 
182
            elif what == 'b':
 
183
                yield what, self.b[t[1]:t[2]]
 
184
            elif what == 'conflict':
 
185
                yield (what,
 
186
                       self.base[t[1]:t[2]],
 
187
                       self.a[t[3]:t[4]],
 
188
                       self.b[t[5]:t[6]])
 
189
            else:
 
190
                raise ValueError(what)
 
191
 
 
192
 
 
193
    def merge_regions(self):
 
194
        """Return sequences of matching and conflicting regions.
 
195
 
 
196
        This returns tuples, where the first value says what kind we
 
197
        have:
 
198
 
 
199
        'unchanged', start, end
 
200
             Take a region of base[start:end]
 
201
 
 
202
        'same', astart, aend
 
203
             b and a are different from base but give the same result
 
204
 
 
205
        'a', start, end
 
206
             Non-clashing insertion from a[start:end]
 
207
 
 
208
        Method is as follows:
 
209
 
 
210
        The two sequences align only on regions which match the base
 
211
        and both descendents.  These are found by doing a two-way diff
 
212
        of each one against the base, and then finding the
 
213
        intersections between those regions.  These "sync regions"
 
214
        are by definition unchanged in both and easily dealt with.
 
215
 
 
216
        The regions in between can be in any of three cases:
 
217
        conflicted, or changed on only one side.
 
218
        """
 
219
 
 
220
        # section a[0:ia] has been disposed of, etc
 
221
        iz = ia = ib = 0
 
222
        
 
223
        for zmatch, zend, amatch, aend, bmatch, bend in self.find_sync_regions():
 
224
            #print 'match base [%d:%d]' % (zmatch, zend)
 
225
            
 
226
            matchlen = zend - zmatch
 
227
            assert matchlen >= 0
 
228
            assert matchlen == (aend - amatch)
 
229
            assert matchlen == (bend - bmatch)
 
230
            
 
231
            len_a = amatch - ia
 
232
            len_b = bmatch - ib
 
233
            len_base = zmatch - iz
 
234
            assert len_a >= 0
 
235
            assert len_b >= 0
 
236
            assert len_base >= 0
 
237
 
 
238
            #print 'unmatched a=%d, b=%d' % (len_a, len_b)
 
239
 
 
240
            if len_a or len_b:
 
241
                # try to avoid actually slicing the lists
 
242
                equal_a = compare_range(self.a, ia, amatch,
 
243
                                        self.base, iz, zmatch)
 
244
                equal_b = compare_range(self.b, ib, bmatch,
 
245
                                        self.base, iz, zmatch)
 
246
                same = compare_range(self.a, ia, amatch,
 
247
                                     self.b, ib, bmatch)
 
248
 
 
249
                if same:
 
250
                    yield 'same', ia, amatch
 
251
                elif equal_a and not equal_b:
 
252
                    yield 'b', ib, bmatch
 
253
                elif equal_b and not equal_a:
 
254
                    yield 'a', ia, amatch
 
255
                elif not equal_a and not equal_b:
 
256
                    yield 'conflict', iz, zmatch, ia, amatch, ib, bmatch
 
257
                else:
 
258
                    raise AssertionError("can't handle a=b=base but unmatched")
 
259
 
 
260
                ia = amatch
 
261
                ib = bmatch
 
262
            iz = zmatch
 
263
 
 
264
            # if the same part of the base was deleted on both sides
 
265
            # that's OK, we can just skip it.
 
266
 
 
267
                
 
268
            if matchlen > 0:
 
269
                assert ia == amatch
 
270
                assert ib == bmatch
 
271
                assert iz == zmatch
 
272
                
 
273
                yield 'unchanged', zmatch, zend
 
274
                iz = zend
 
275
                ia = aend
 
276
                ib = bend
 
277
    
 
278
 
 
279
    def reprocess_merge_regions(self, merge_regions):
 
280
        """Where there are conflict regions, remove the agreed lines.
 
281
 
 
282
        Lines where both A and B have made the same changes are 
 
283
        eliminated.
 
284
        """
 
285
        for region in merge_regions:
 
286
            if region[0] != "conflict":
 
287
                yield region
 
288
                continue
 
289
            type, iz, zmatch, ia, amatch, ib, bmatch = region
 
290
            a_region = self.a[ia:amatch]
 
291
            b_region = self.b[ib:bmatch]
 
292
            matches = SequenceMatcher(None, a_region, 
 
293
                                      b_region).get_matching_blocks()
 
294
            next_a = ia
 
295
            next_b = ib
 
296
            for region_ia, region_ib, region_len in matches[:-1]:
 
297
                region_ia += ia
 
298
                region_ib += ib
 
299
                reg = self.mismatch_region(next_a, region_ia, next_b,
 
300
                                           region_ib)
 
301
                if reg is not None:
 
302
                    yield reg
 
303
                yield 'same', region_ia, region_len+region_ia
 
304
                next_a = region_ia + region_len
 
305
                next_b = region_ib + region_len
 
306
            reg = self.mismatch_region(next_a, amatch, next_b, bmatch)
 
307
            if reg is not None:
 
308
                yield reg
 
309
 
 
310
 
 
311
    @staticmethod
 
312
    def mismatch_region(next_a, region_ia,  next_b, region_ib):
 
313
        if next_a < region_ia or next_b < region_ib:
 
314
            return 'conflict', None, None, next_a, region_ia, next_b, region_ib
 
315
            
 
316
 
 
317
    def find_sync_regions(self):
 
318
        """Return a list of sync regions, where both descendents match the base.
 
319
 
 
320
        Generates a list of (base1, base2, a1, a2, b1, b2).  There is
 
321
        always a zero-length sync region at the end of all the files.
 
322
        """
 
323
 
 
324
        ia = ib = 0
 
325
        amatches = SequenceMatcher(None, self.base, self.a).get_matching_blocks()
 
326
        bmatches = SequenceMatcher(None, self.base, self.b).get_matching_blocks()
 
327
        len_a = len(amatches)
 
328
        len_b = len(bmatches)
 
329
 
 
330
        sl = []
 
331
 
 
332
        while ia < len_a and ib < len_b:
 
333
            abase, amatch, alen = amatches[ia]
 
334
            bbase, bmatch, blen = bmatches[ib]
 
335
 
 
336
            # there is an unconflicted block at i; how long does it
 
337
            # extend?  until whichever one ends earlier.
 
338
            i = intersect((abase, abase+alen), (bbase, bbase+blen))
 
339
            if i:
 
340
                intbase = i[0]
 
341
                intend = i[1]
 
342
                intlen = intend - intbase
 
343
 
 
344
                # found a match of base[i[0], i[1]]; this may be less than
 
345
                # the region that matches in either one
 
346
                assert intlen <= alen
 
347
                assert intlen <= blen
 
348
                assert abase <= intbase
 
349
                assert bbase <= intbase
 
350
 
 
351
                asub = amatch + (intbase - abase)
 
352
                bsub = bmatch + (intbase - bbase)
 
353
                aend = asub + intlen
 
354
                bend = bsub + intlen
 
355
 
 
356
                assert self.base[intbase:intend] == self.a[asub:aend], \
 
357
                       (self.base[intbase:intend], self.a[asub:aend])
 
358
 
 
359
                assert self.base[intbase:intend] == self.b[bsub:bend]
 
360
 
 
361
                sl.append((intbase, intend,
 
362
                           asub, aend,
 
363
                           bsub, bend))
 
364
 
 
365
            # advance whichever one ends first in the base text
 
366
            if (abase + alen) < (bbase + blen):
 
367
                ia += 1
 
368
            else:
 
369
                ib += 1
 
370
            
 
371
        intbase = len(self.base)
 
372
        abase = len(self.a)
 
373
        bbase = len(self.b)
 
374
        sl.append((intbase, intbase, abase, abase, bbase, bbase))
 
375
 
 
376
        return sl
 
377
 
 
378
 
 
379
 
 
380
    def find_unconflicted(self):
 
381
        """Return a list of ranges in base that are not conflicted."""
 
382
 
 
383
        import re
 
384
 
 
385
        # don't sync-up on lines containing only blanks or pounds
 
386
        junk_re = re.compile(r'^[ \t#]*$')
 
387
        
 
388
        am = SequenceMatcher(junk_re.match, self.base, self.a).get_matching_blocks()
 
389
        bm = SequenceMatcher(junk_re.match, self.base, self.b).get_matching_blocks()
 
390
 
 
391
        unc = []
 
392
 
 
393
        while am and bm:
 
394
            # there is an unconflicted block at i; how long does it
 
395
            # extend?  until whichever one ends earlier.
 
396
            a1 = am[0][0]
 
397
            a2 = a1 + am[0][2]
 
398
            b1 = bm[0][0]
 
399
            b2 = b1 + bm[0][2]
 
400
            i = intersect((a1, a2), (b1, b2))
 
401
            if i:
 
402
                unc.append(i)
 
403
 
 
404
            if a2 < b2:
 
405
                del am[0]
 
406
            else:
 
407
                del bm[0]
 
408
                
 
409
        return unc
 
410
 
 
411
 
 
412
def main(argv):
 
413
    # as for diff3 and meld the syntax is "MINE BASE OTHER"
 
414
    a = file(argv[1], 'rt').readlines()
 
415
    base = file(argv[2], 'rt').readlines()
 
416
    b = file(argv[3], 'rt').readlines()
 
417
 
 
418
    m3 = Merge3(base, a, b)
 
419
 
 
420
    #for sr in m3.find_sync_regions():
 
421
    #    print sr
 
422
 
 
423
    # sys.stdout.writelines(m3.merge_lines(name_a=argv[1], name_b=argv[3]))
 
424
    sys.stdout.writelines(m3.merge_annotated())
 
425
 
 
426
 
 
427
if __name__ == '__main__':
 
428
    import sys
 
429
    sys.exit(main(sys.argv))