~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/merge3.py

[merge] jelmer

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# Copyright (C) 2004, 2005 by Canonical Ltd
 
2
 
 
3
# This program is free software; you can redistribute it and/or modify
 
4
# it under the terms of the GNU General Public License as published by
 
5
# the Free Software Foundation; either version 2 of the License, or
 
6
# (at your option) any later version.
 
7
 
 
8
# This program is distributed in the hope that it will be useful,
 
9
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
10
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
11
# GNU General Public License for more details.
 
12
 
 
13
# You should have received a copy of the GNU General Public License
 
14
# along with this program; if not, write to the Free Software
 
15
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 
16
 
 
17
 
 
18
# mbp: "you know that thing where cvs gives you conflict markers?"
 
19
# s: "i hate that."
 
20
 
 
21
 
 
22
from difflib import SequenceMatcher
 
23
from bzrlib.errors import CantReprocessAndShowBase
 
24
 
 
25
def intersect(ra, rb):
 
26
    """Given two ranges return the range where they intersect or None.
 
27
 
 
28
    >>> intersect((0, 10), (0, 6))
 
29
    (0, 6)
 
30
    >>> intersect((0, 10), (5, 15))
 
31
    (5, 10)
 
32
    >>> intersect((0, 10), (10, 15))
 
33
    >>> intersect((0, 9), (10, 15))
 
34
    >>> intersect((0, 9), (7, 15))
 
35
    (7, 9)
 
36
    """
 
37
    assert ra[0] <= ra[1]
 
38
    assert rb[0] <= rb[1]
 
39
    
 
40
    sa = max(ra[0], rb[0])
 
41
    sb = min(ra[1], rb[1])
 
42
    if sa < sb:
 
43
        return sa, sb
 
44
    else:
 
45
        return None
 
46
 
 
47
 
 
48
def compare_range(a, astart, aend, b, bstart, bend):
 
49
    """Compare a[astart:aend] == b[bstart:bend], without slicing.
 
50
    """
 
51
    if (aend-astart) != (bend-bstart):
 
52
        return False
 
53
    for ia, ib in zip(xrange(astart, aend), xrange(bstart, bend)):
 
54
        if a[ia] != b[ib]:
 
55
            return False
 
56
    else:
 
57
        return True
 
58
        
 
59
 
 
60
 
 
61
 
 
62
class Merge3(object):
 
63
    """3-way merge of texts.
 
64
 
 
65
    Given BASE, OTHER, THIS, tries to produce a combined text
 
66
    incorporating the changes from both BASE->OTHER and BASE->THIS.
 
67
    All three will typically be sequences of lines."""
 
68
    def __init__(self, base, a, b):
 
69
        self.base = base
 
70
        self.a = a
 
71
        self.b = b
 
72
 
 
73
 
 
74
 
 
75
    def merge_lines(self,
 
76
                    name_a=None,
 
77
                    name_b=None,
 
78
                    name_base=None,
 
79
                    start_marker='<<<<<<<',
 
80
                    mid_marker='=======',
 
81
                    end_marker='>>>>>>>',
 
82
                    base_marker=None,
 
83
                    reprocess=False):
 
84
        """Return merge in cvs-like form.
 
85
        """
 
86
        if base_marker and reprocess:
 
87
            raise CantReprocessAndShowBase()
 
88
        if name_a:
 
89
            start_marker = start_marker + ' ' + name_a
 
90
        if name_b:
 
91
            end_marker = end_marker + ' ' + name_b
 
92
        if name_base and base_marker:
 
93
            base_marker = base_marker + ' ' + name_base
 
94
        merge_regions = self.merge_regions()
 
95
        if reprocess is True:
 
96
            merge_regions = self.reprocess_merge_regions(merge_regions)
 
97
        for t in merge_regions:
 
98
            what = t[0]
 
99
            if what == 'unchanged':
 
100
                for i in range(t[1], t[2]):
 
101
                    yield self.base[i]
 
102
            elif what == 'a' or what == 'same':
 
103
                for i in range(t[1], t[2]):
 
104
                    yield self.a[i]
 
105
            elif what == 'b':
 
106
                for i in range(t[1], t[2]):
 
107
                    yield self.b[i]
 
108
            elif what == 'conflict':
 
109
                yield start_marker + '\n'
 
110
                for i in range(t[3], t[4]):
 
111
                    yield self.a[i]
 
112
                if base_marker is not None:
 
113
                    yield base_marker + '\n'
 
114
                    for i in range(t[1], t[2]):
 
115
                        yield self.base[i]
 
116
                yield mid_marker + '\n'
 
117
                for i in range(t[5], t[6]):
 
118
                    yield self.b[i]
 
119
                yield end_marker + '\n'
 
120
            else:
 
121
                raise ValueError(what)
 
122
        
 
123
        
 
124
 
 
125
 
 
126
 
 
127
    def merge_annotated(self):
 
128
        """Return merge with conflicts, showing origin of lines.
 
129
 
 
130
        Most useful for debugging merge.        
 
131
        """
 
132
        for t in self.merge_regions():
 
133
            what = t[0]
 
134
            if what == 'unchanged':
 
135
                for i in range(t[1], t[2]):
 
136
                    yield 'u | ' + self.base[i]
 
137
            elif what == 'a' or what == 'same':
 
138
                for i in range(t[1], t[2]):
 
139
                    yield what[0] + ' | ' + self.a[i]
 
140
            elif what == 'b':
 
141
                for i in range(t[1], t[2]):
 
142
                    yield 'b | ' + self.b[i]
 
143
            elif what == 'conflict':
 
144
                yield '<<<<\n'
 
145
                for i in range(t[3], t[4]):
 
146
                    yield 'A | ' + self.a[i]
 
147
                yield '----\n'
 
148
                for i in range(t[5], t[6]):
 
149
                    yield 'B | ' + self.b[i]
 
150
                yield '>>>>\n'
 
151
            else:
 
152
                raise ValueError(what)
 
153
        
 
154
        
 
155
 
 
156
 
 
157
 
 
158
    def merge_groups(self):
 
159
        """Yield sequence of line groups.  Each one is a tuple:
 
160
 
 
161
        'unchanged', lines
 
162
             Lines unchanged from base
 
163
 
 
164
        'a', lines
 
165
             Lines taken from a
 
166
 
 
167
        'same', lines
 
168
             Lines taken from a (and equal to b)
 
169
 
 
170
        'b', lines
 
171
             Lines taken from b
 
172
 
 
173
        'conflict', base_lines, a_lines, b_lines
 
174
             Lines from base were changed to either a or b and conflict.
 
175
        """
 
176
        for t in self.merge_regions():
 
177
            what = t[0]
 
178
            if what == 'unchanged':
 
179
                yield what, self.base[t[1]:t[2]]
 
180
            elif what == 'a' or what == 'same':
 
181
                yield what, self.a[t[1]:t[2]]
 
182
            elif what == 'b':
 
183
                yield what, self.b[t[1]:t[2]]
 
184
            elif what == 'conflict':
 
185
                yield (what,
 
186
                       self.base[t[1]:t[2]],
 
187
                       self.a[t[3]:t[4]],
 
188
                       self.b[t[5]:t[6]])
 
189
            else:
 
190
                raise ValueError(what)
 
191
 
 
192
 
 
193
    def merge_regions(self):
 
194
        """Return sequences of matching and conflicting regions.
 
195
 
 
196
        This returns tuples, where the first value says what kind we
 
197
        have:
 
198
 
 
199
        'unchanged', start, end
 
200
             Take a region of base[start:end]
 
201
 
 
202
        'same', astart, aend
 
203
             b and a are different from base but give the same result
 
204
 
 
205
        'a', start, end
 
206
             Non-clashing insertion from a[start:end]
 
207
 
 
208
        Method is as follows:
 
209
 
 
210
        The two sequences align only on regions which match the base
 
211
        and both descendents.  These are found by doing a two-way diff
 
212
        of each one against the base, and then finding the
 
213
        intersections between those regions.  These "sync regions"
 
214
        are by definition unchanged in both and easily dealt with.
 
215
 
 
216
        The regions in between can be in any of three cases:
 
217
        conflicted, or changed on only one side.
 
218
        """
 
219
 
 
220
        # section a[0:ia] has been disposed of, etc
 
221
        iz = ia = ib = 0
 
222
        
 
223
        for zmatch, zend, amatch, aend, bmatch, bend in self.find_sync_regions():
 
224
            #print 'match base [%d:%d]' % (zmatch, zend)
 
225
            
 
226
            matchlen = zend - zmatch
 
227
            assert matchlen >= 0
 
228
            assert matchlen == (aend - amatch)
 
229
            assert matchlen == (bend - bmatch)
 
230
            
 
231
            len_a = amatch - ia
 
232
            len_b = bmatch - ib
 
233
            len_base = zmatch - iz
 
234
            assert len_a >= 0
 
235
            assert len_b >= 0
 
236
            assert len_base >= 0
 
237
 
 
238
            #print 'unmatched a=%d, b=%d' % (len_a, len_b)
 
239
 
 
240
            if len_a or len_b:
 
241
                # try to avoid actually slicing the lists
 
242
                equal_a = compare_range(self.a, ia, amatch,
 
243
                                        self.base, iz, zmatch)
 
244
                equal_b = compare_range(self.b, ib, bmatch,
 
245
                                        self.base, iz, zmatch)
 
246
                same = compare_range(self.a, ia, amatch,
 
247
                                     self.b, ib, bmatch)
 
248
 
 
249
                if same:
 
250
                    yield 'same', ia, amatch
 
251
                elif equal_a and not equal_b:
 
252
                    yield 'b', ib, bmatch
 
253
                elif equal_b and not equal_a:
 
254
                    yield 'a', ia, amatch
 
255
                elif not equal_a and not equal_b:
 
256
                    yield 'conflict', iz, zmatch, ia, amatch, ib, bmatch
 
257
                else:
 
258
                    raise AssertionError("can't handle a=b=base but unmatched")
 
259
 
 
260
                ia = amatch
 
261
                ib = bmatch
 
262
            iz = zmatch
 
263
 
 
264
            # if the same part of the base was deleted on both sides
 
265
            # that's OK, we can just skip it.
 
266
 
 
267
                
 
268
            if matchlen > 0:
 
269
                assert ia == amatch
 
270
                assert ib == bmatch
 
271
                assert iz == zmatch
 
272
                
 
273
                yield 'unchanged', zmatch, zend
 
274
                iz = zend
 
275
                ia = aend
 
276
                ib = bend
 
277
    
 
278
 
 
279
    def reprocess_merge_regions(self, merge_regions):
 
280
        for region in merge_regions:
 
281
            if region[0] != "conflict":
 
282
                yield region
 
283
                continue
 
284
            type, iz, zmatch, ia, amatch, ib, bmatch = region
 
285
            a_region = self.a[ia:amatch]
 
286
            b_region = self.b[ib:bmatch]
 
287
            matches = SequenceMatcher(None, a_region, 
 
288
                                      b_region).get_matching_blocks()
 
289
            next_a = ia
 
290
            next_b = ib
 
291
            for region_ia, region_ib, region_len in matches[:-1]:
 
292
                region_ia += ia
 
293
                region_ib += ib
 
294
                reg = self.mismatch_region(next_a, region_ia, next_b,
 
295
                                           region_ib)
 
296
                if reg is not None:
 
297
                    yield reg
 
298
                yield 'same', region_ia, region_len+region_ia
 
299
                next_a = region_ia + region_len
 
300
                next_b = region_ib + region_len
 
301
            reg = self.mismatch_region(next_a, amatch, next_b, bmatch)
 
302
            if reg is not None:
 
303
                yield reg
 
304
 
 
305
 
 
306
    @staticmethod
 
307
    def mismatch_region(next_a, region_ia,  next_b, region_ib):
 
308
        if next_a < region_ia or next_b < region_ib:
 
309
            return 'conflict', None, None, next_a, region_ia, next_b, region_ib
 
310
            
 
311
 
 
312
    def find_sync_regions(self):
 
313
        """Return a list of sync regions, where both descendents match the base.
 
314
 
 
315
        Generates a list of (base1, base2, a1, a2, b1, b2).  There is
 
316
        always a zero-length sync region at the end of all the files.
 
317
        """
 
318
 
 
319
        ia = ib = 0
 
320
        amatches = SequenceMatcher(None, self.base, self.a).get_matching_blocks()
 
321
        bmatches = SequenceMatcher(None, self.base, self.b).get_matching_blocks()
 
322
        len_a = len(amatches)
 
323
        len_b = len(bmatches)
 
324
 
 
325
        sl = []
 
326
 
 
327
        while ia < len_a and ib < len_b:
 
328
            abase, amatch, alen = amatches[ia]
 
329
            bbase, bmatch, blen = bmatches[ib]
 
330
 
 
331
            # there is an unconflicted block at i; how long does it
 
332
            # extend?  until whichever one ends earlier.
 
333
            i = intersect((abase, abase+alen), (bbase, bbase+blen))
 
334
            if i:
 
335
                intbase = i[0]
 
336
                intend = i[1]
 
337
                intlen = intend - intbase
 
338
 
 
339
                # found a match of base[i[0], i[1]]; this may be less than
 
340
                # the region that matches in either one
 
341
                assert intlen <= alen
 
342
                assert intlen <= blen
 
343
                assert abase <= intbase
 
344
                assert bbase <= intbase
 
345
 
 
346
                asub = amatch + (intbase - abase)
 
347
                bsub = bmatch + (intbase - bbase)
 
348
                aend = asub + intlen
 
349
                bend = bsub + intlen
 
350
 
 
351
                assert self.base[intbase:intend] == self.a[asub:aend], \
 
352
                       (self.base[intbase:intend], self.a[asub:aend])
 
353
 
 
354
                assert self.base[intbase:intend] == self.b[bsub:bend]
 
355
 
 
356
                sl.append((intbase, intend,
 
357
                           asub, aend,
 
358
                           bsub, bend))
 
359
 
 
360
            # advance whichever one ends first in the base text
 
361
            if (abase + alen) < (bbase + blen):
 
362
                ia += 1
 
363
            else:
 
364
                ib += 1
 
365
            
 
366
        intbase = len(self.base)
 
367
        abase = len(self.a)
 
368
        bbase = len(self.b)
 
369
        sl.append((intbase, intbase, abase, abase, bbase, bbase))
 
370
 
 
371
        return sl
 
372
 
 
373
 
 
374
 
 
375
    def find_unconflicted(self):
 
376
        """Return a list of ranges in base that are not conflicted."""
 
377
 
 
378
        import re
 
379
 
 
380
        # don't sync-up on lines containing only blanks or pounds
 
381
        junk_re = re.compile(r'^[ \t#]*$')
 
382
        
 
383
        am = SequenceMatcher(junk_re.match, self.base, self.a).get_matching_blocks()
 
384
        bm = SequenceMatcher(junk_re.match, self.base, self.b).get_matching_blocks()
 
385
 
 
386
        unc = []
 
387
 
 
388
        while am and bm:
 
389
            # there is an unconflicted block at i; how long does it
 
390
            # extend?  until whichever one ends earlier.
 
391
            a1 = am[0][0]
 
392
            a2 = a1 + am[0][2]
 
393
            b1 = bm[0][0]
 
394
            b2 = b1 + bm[0][2]
 
395
            i = intersect((a1, a2), (b1, b2))
 
396
            if i:
 
397
                unc.append(i)
 
398
 
 
399
            if a2 < b2:
 
400
                del am[0]
 
401
            else:
 
402
                del bm[0]
 
403
                
 
404
        return unc
 
405
 
 
406
 
 
407
def main(argv):
 
408
    # as for diff3 and meld the syntax is "MINE BASE OTHER"
 
409
    a = file(argv[1], 'rt').readlines()
 
410
    base = file(argv[2], 'rt').readlines()
 
411
    b = file(argv[3], 'rt').readlines()
 
412
 
 
413
    m3 = Merge3(base, a, b)
 
414
 
 
415
    #for sr in m3.find_sync_regions():
 
416
    #    print sr
 
417
 
 
418
    # sys.stdout.writelines(m3.merge_lines(name_a=argv[1], name_b=argv[3]))
 
419
    sys.stdout.writelines(m3.merge_annotated())
 
420
 
 
421
 
 
422
if __name__ == '__main__':
 
423
    import sys
 
424
    sys.exit(main(sys.argv))