~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/merge3.py

todo

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# Copyright (C) 2004, 2005 by Canonical Ltd
 
2
 
 
3
# This program is free software; you can redistribute it and/or modify
 
4
# it under the terms of the GNU General Public License as published by
 
5
# the Free Software Foundation; either version 2 of the License, or
 
6
# (at your option) any later version.
 
7
 
 
8
# This program is distributed in the hope that it will be useful,
 
9
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
10
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
11
# GNU General Public License for more details.
 
12
 
 
13
# You should have received a copy of the GNU General Public License
 
14
# along with this program; if not, write to the Free Software
 
15
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 
16
 
 
17
 
 
18
# mbp: "you know that thing where cvs gives you conflict markers?"
 
19
# s: "i hate that."
 
20
 
 
21
 
 
22
 
 
23
def intersect(ra, rb):
 
24
    """Given two ranges return the range where they intersect or None.
 
25
 
 
26
    >>> intersect((0, 10), (0, 6))
 
27
    (0, 6)
 
28
    >>> intersect((0, 10), (5, 15))
 
29
    (5, 10)
 
30
    >>> intersect((0, 10), (10, 15))
 
31
    >>> intersect((0, 9), (10, 15))
 
32
    >>> intersect((0, 9), (7, 15))
 
33
    (7, 9)
 
34
    """
 
35
    assert ra[0] <= ra[1]
 
36
    assert rb[0] <= rb[1]
 
37
    
 
38
    sa = max(ra[0], rb[0])
 
39
    sb = min(ra[1], rb[1])
 
40
    if sa < sb:
 
41
        return sa, sb
 
42
    else:
 
43
        return None
 
44
 
 
45
 
 
46
def compare_range(a, astart, aend, b, bstart, bend):
 
47
    """Compare a[astart:aend] == b[bstart:bend], without slicing.
 
48
    """
 
49
    if (aend-astart) != (bend-bstart):
 
50
        return False
 
51
    for ia, ib in zip(xrange(astart, aend), xrange(bstart, bend)):
 
52
        if a[ia] != b[ib]:
 
53
            return False
 
54
    else:
 
55
        return True
 
56
        
 
57
 
 
58
 
 
59
 
 
60
class Merge3(object):
 
61
    """3-way merge of texts.
 
62
 
 
63
    Given BASE, OTHER, THIS, tries to produce a combined text
 
64
    incorporating the changes from both BASE->OTHER and BASE->THIS.
 
65
    All three will typically be sequences of lines."""
 
66
    def __init__(self, base, a, b):
 
67
        self.base = base
 
68
        self.a = a
 
69
        self.b = b
 
70
        from difflib import SequenceMatcher
 
71
        self.a_ops = SequenceMatcher(None, base, a).get_opcodes()
 
72
        self.b_ops = SequenceMatcher(None, base, b).get_opcodes()
 
73
 
 
74
 
 
75
 
 
76
    def merge_lines(self,
 
77
                    name_a=None,
 
78
                    name_b=None,
 
79
                    start_marker='<<<<<<<',
 
80
                    mid_marker='=======',
 
81
                    end_marker='>>>>>>>',
 
82
                    show_base=False):
 
83
        """Return merge in cvs-like form.
 
84
        """
 
85
        if name_a:
 
86
            start_marker = start_marker + ' ' + name_a
 
87
        if name_b:
 
88
            end_marker = end_marker + ' ' + name_b
 
89
            
 
90
        for t in self.merge_regions():
 
91
            what = t[0]
 
92
            if what == 'unchanged':
 
93
                for i in range(t[1], t[2]):
 
94
                    yield self.base[i]
 
95
            elif what == 'a' or what == 'same':
 
96
                for i in range(t[1], t[2]):
 
97
                    yield self.a[i]
 
98
            elif what == 'b':
 
99
                for i in range(t[1], t[2]):
 
100
                    yield self.b[i]
 
101
            elif what == 'conflict':
 
102
                yield start_marker + '\n'
 
103
                for i in range(t[3], t[4]):
 
104
                    yield self.a[i]
 
105
                yield mid_marker + '\n'
 
106
                for i in range(t[5], t[6]):
 
107
                    yield self.b[i]
 
108
                yield end_marker + '\n'
 
109
            else:
 
110
                raise ValueError(what)
 
111
        
 
112
        
 
113
 
 
114
 
 
115
 
 
116
    def merge_annotated(self):
 
117
        """Return merge with conflicts, showing origin of lines.
 
118
 
 
119
        Most useful for debugging merge.        
 
120
        """
 
121
        for t in self.merge_regions():
 
122
            what = t[0]
 
123
            if what == 'unchanged':
 
124
                for i in range(t[1], t[2]):
 
125
                    yield 'u | ' + self.base[i]
 
126
            elif what == 'a' or what == 'same':
 
127
                for i in range(t[1], t[2]):
 
128
                    yield what[0] + ' | ' + self.a[i]
 
129
            elif what == 'b':
 
130
                for i in range(t[1], t[2]):
 
131
                    yield 'b | ' + self.b[i]
 
132
            elif what == 'conflict':
 
133
                yield '<<<<\n'
 
134
                for i in range(t[3], t[4]):
 
135
                    yield 'A | ' + self.a[i]
 
136
                yield '----\n'
 
137
                for i in range(t[5], t[6]):
 
138
                    yield 'B | ' + self.b[i]
 
139
                yield '>>>>\n'
 
140
            else:
 
141
                raise ValueError(what)
 
142
        
 
143
        
 
144
 
 
145
 
 
146
 
 
147
    def merge_groups(self):
 
148
        """Yield sequence of line groups.  Each one is a tuple:
 
149
 
 
150
        'unchanged', lines
 
151
             Lines unchanged from base
 
152
 
 
153
        'a', lines
 
154
             Lines taken from a
 
155
 
 
156
        'same', lines
 
157
             Lines taken from a (and equal to b)
 
158
 
 
159
        'b', lines
 
160
             Lines taken from b
 
161
 
 
162
        'conflict', base_lines, a_lines, b_lines
 
163
             Lines from base were changed to either a or b and conflict.
 
164
        """
 
165
        for t in self.merge_regions():
 
166
            what = t[0]
 
167
            if what == 'unchanged':
 
168
                yield what, self.base[t[1]:t[2]]
 
169
            elif what == 'a' or what == 'same':
 
170
                yield what, self.a[t[1]:t[2]]
 
171
            elif what == 'b':
 
172
                yield what, self.b[t[1]:t[2]]
 
173
            elif what == 'conflict':
 
174
                yield (what,
 
175
                       self.base[t[1]:t[2]],
 
176
                       self.a[t[3]:t[4]],
 
177
                       self.b[t[5]:t[6]])
 
178
            else:
 
179
                raise ValueError(what)
 
180
 
 
181
 
 
182
    def merge_regions(self):
 
183
        """Return sequences of matching and conflicting regions.
 
184
 
 
185
        This returns tuples, where the first value says what kind we
 
186
        have:
 
187
 
 
188
        'unchanged', start, end
 
189
             Take a region of base[start:end]
 
190
 
 
191
        'same', astart, aend
 
192
             b and a are different from base but give the same result
 
193
 
 
194
        'a', start, end
 
195
             Non-clashing insertion from a[start:end]
 
196
 
 
197
        Method is as follows:
 
198
 
 
199
        The two sequences align only on regions which match the base
 
200
        and both descendents.  These are found by doing a two-way diff
 
201
        of each one against the base, and then finding the
 
202
        intersections between those regions.  These "sync regions"
 
203
        are by definition unchanged in both and easily dealt with.
 
204
 
 
205
        The regions in between can be in any of three cases:
 
206
        conflicted, or changed on only one side.
 
207
        """
 
208
 
 
209
        # section a[0:ia] has been disposed of, etc
 
210
        iz = ia = ib = 0
 
211
        
 
212
        for zmatch, zend, amatch, aend, bmatch, bend in self.find_sync_regions():
 
213
            #print 'match base [%d:%d]' % (zmatch, zend)
 
214
            
 
215
            matchlen = zend - zmatch
 
216
            assert matchlen >= 0
 
217
            assert matchlen == (aend - amatch)
 
218
            assert matchlen == (bend - bmatch)
 
219
            
 
220
            len_a = amatch - ia
 
221
            len_b = bmatch - ib
 
222
            len_base = zmatch - iz
 
223
            assert len_a >= 0
 
224
            assert len_b >= 0
 
225
            assert len_base >= 0
 
226
 
 
227
            #print 'unmatched a=%d, b=%d' % (len_a, len_b)
 
228
 
 
229
            if len_a or len_b:
 
230
                # try to avoid actually slicing the lists
 
231
                equal_a = compare_range(self.a, ia, amatch,
 
232
                                        self.base, iz, zmatch)
 
233
                equal_b = compare_range(self.b, ib, bmatch,
 
234
                                        self.base, iz, zmatch)
 
235
                same = compare_range(self.a, ia, amatch,
 
236
                                     self.b, ib, bmatch)
 
237
 
 
238
                if same:
 
239
                    yield 'same', ia, amatch
 
240
                elif equal_a and not equal_b:
 
241
                    yield 'b', ib, bmatch
 
242
                elif equal_b and not equal_a:
 
243
                    yield 'a', ia, amatch
 
244
                elif not equal_a and not equal_b:
 
245
                    yield 'conflict', iz, zmatch, ia, amatch, ib, bmatch
 
246
                else:
 
247
                    raise AssertionError("can't handle a=b=base but unmatched")
 
248
 
 
249
                ia = amatch
 
250
                ib = bmatch
 
251
            iz = zmatch
 
252
 
 
253
            # if the same part of the base was deleted on both sides
 
254
            # that's OK, we can just skip it.
 
255
 
 
256
                
 
257
            if matchlen > 0:
 
258
                assert ia == amatch
 
259
                assert ib == bmatch
 
260
                assert iz == zmatch
 
261
                
 
262
                yield 'unchanged', zmatch, zend
 
263
                iz = zend
 
264
                ia = aend
 
265
                ib = bend
 
266
        
 
267
 
 
268
        
 
269
    def find_sync_regions(self):
 
270
        """Return a list of sync regions, where both descendents match the base.
 
271
 
 
272
        Generates a list of (base1, base2, a1, a2, b1, b2).  There is
 
273
        always a zero-length sync region at the end of all the files.
 
274
        """
 
275
        from difflib import SequenceMatcher
 
276
 
 
277
        ia = ib = 0
 
278
        amatches = SequenceMatcher(None, self.base, self.a).get_matching_blocks()
 
279
        bmatches = SequenceMatcher(None, self.base, self.b).get_matching_blocks()
 
280
        len_a = len(amatches)
 
281
        len_b = len(bmatches)
 
282
 
 
283
        sl = []
 
284
 
 
285
        while ia < len_a and ib < len_b:
 
286
            abase, amatch, alen = amatches[ia]
 
287
            bbase, bmatch, blen = bmatches[ib]
 
288
 
 
289
            # there is an unconflicted block at i; how long does it
 
290
            # extend?  until whichever one ends earlier.
 
291
            i = intersect((abase, abase+alen), (bbase, bbase+blen))
 
292
            if i:
 
293
                intbase = i[0]
 
294
                intend = i[1]
 
295
                intlen = intend - intbase
 
296
 
 
297
                # found a match of base[i[0], i[1]]; this may be less than
 
298
                # the region that matches in either one
 
299
                assert intlen <= alen
 
300
                assert intlen <= blen
 
301
                assert abase <= intbase
 
302
                assert bbase <= intbase
 
303
 
 
304
                asub = amatch + (intbase - abase)
 
305
                bsub = bmatch + (intbase - bbase)
 
306
                aend = asub + intlen
 
307
                bend = bsub + intlen
 
308
 
 
309
                assert self.base[intbase:intend] == self.a[asub:aend], \
 
310
                       (self.base[intbase:intend], self.a[asub:aend])
 
311
 
 
312
                assert self.base[intbase:intend] == self.b[bsub:bend]
 
313
 
 
314
                sl.append((intbase, intend,
 
315
                           asub, aend,
 
316
                           bsub, bend))
 
317
 
 
318
            # advance whichever one ends first in the base text
 
319
            if (abase + alen) < (bbase + blen):
 
320
                ia += 1
 
321
            else:
 
322
                ib += 1
 
323
            
 
324
        intbase = len(self.base)
 
325
        abase = len(self.a)
 
326
        bbase = len(self.b)
 
327
        sl.append((intbase, intbase, abase, abase, bbase, bbase))
 
328
 
 
329
        return sl
 
330
 
 
331
 
 
332
 
 
333
    def find_unconflicted(self):
 
334
        """Return a list of ranges in base that are not conflicted."""
 
335
        from difflib import SequenceMatcher
 
336
 
 
337
        import re
 
338
 
 
339
        # don't sync-up on lines containing only blanks or pounds
 
340
        junk_re = re.compile(r'^[ \t#]*$')
 
341
        
 
342
        am = SequenceMatcher(junk_re.match, self.base, self.a).get_matching_blocks()
 
343
        bm = SequenceMatcher(junk_re.match, self.base, self.b).get_matching_blocks()
 
344
 
 
345
        unc = []
 
346
 
 
347
        while am and bm:
 
348
            # there is an unconflicted block at i; how long does it
 
349
            # extend?  until whichever one ends earlier.
 
350
            a1 = am[0][0]
 
351
            a2 = a1 + am[0][2]
 
352
            b1 = bm[0][0]
 
353
            b2 = b1 + bm[0][2]
 
354
            i = intersect((a1, a2), (b1, b2))
 
355
            if i:
 
356
                unc.append(i)
 
357
 
 
358
            if a2 < b2:
 
359
                del am[0]
 
360
            else:
 
361
                del bm[0]
 
362
                
 
363
        return unc
 
364
 
 
365
 
 
366
def main(argv):
 
367
    # as for diff3 and meld the syntax is "MINE BASE OTHER"
 
368
    a = file(argv[1], 'rt').readlines()
 
369
    base = file(argv[2], 'rt').readlines()
 
370
    b = file(argv[3], 'rt').readlines()
 
371
 
 
372
    m3 = Merge3(base, a, b)
 
373
 
 
374
    #for sr in m3.find_sync_regions():
 
375
    #    print sr
 
376
 
 
377
    # sys.stdout.writelines(m3.merge_lines(name_a=argv[1], name_b=argv[3]))
 
378
    sys.stdout.writelines(m3.merge_annotated())
 
379
 
 
380
 
 
381
if __name__ == '__main__':
 
382
    import sys
 
383
    sys.exit(main(sys.argv))