~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/merge3.py

  • Committer: Martin Pool
  • Date: 2005-05-05 06:38:18 UTC
  • Revision ID: mbp@sourcefrog.net-20050505063818-3eb3260343878325
- do upload CHANGELOG to web server, even though it's autogenerated

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2004, 2005 by Canonical Ltd
2
 
 
3
 
# This program is free software; you can redistribute it and/or modify
4
 
# it under the terms of the GNU General Public License as published by
5
 
# the Free Software Foundation; either version 2 of the License, or
6
 
# (at your option) any later version.
7
 
 
8
 
# This program is distributed in the hope that it will be useful,
9
 
# but WITHOUT ANY WARRANTY; without even the implied warranty of
10
 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11
 
# GNU General Public License for more details.
12
 
 
13
 
# You should have received a copy of the GNU General Public License
14
 
# along with this program; if not, write to the Free Software
15
 
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
16
 
 
17
 
 
18
 
# mbp: "you know that thing where cvs gives you conflict markers?"
19
 
# s: "i hate that."
20
 
 
21
 
 
22
 
from difflib import SequenceMatcher
23
 
 
24
 
from bzrlib.errors import CantReprocessAndShowBase
25
 
from bzrlib.textfile import check_text_lines
26
 
 
27
 
def intersect(ra, rb):
28
 
    """Given two ranges return the range where they intersect or None.
29
 
 
30
 
    >>> intersect((0, 10), (0, 6))
31
 
    (0, 6)
32
 
    >>> intersect((0, 10), (5, 15))
33
 
    (5, 10)
34
 
    >>> intersect((0, 10), (10, 15))
35
 
    >>> intersect((0, 9), (10, 15))
36
 
    >>> intersect((0, 9), (7, 15))
37
 
    (7, 9)
38
 
    """
39
 
    assert ra[0] <= ra[1]
40
 
    assert rb[0] <= rb[1]
41
 
    
42
 
    sa = max(ra[0], rb[0])
43
 
    sb = min(ra[1], rb[1])
44
 
    if sa < sb:
45
 
        return sa, sb
46
 
    else:
47
 
        return None
48
 
 
49
 
 
50
 
def compare_range(a, astart, aend, b, bstart, bend):
51
 
    """Compare a[astart:aend] == b[bstart:bend], without slicing.
52
 
    """
53
 
    if (aend-astart) != (bend-bstart):
54
 
        return False
55
 
    for ia, ib in zip(xrange(astart, aend), xrange(bstart, bend)):
56
 
        if a[ia] != b[ib]:
57
 
            return False
58
 
    else:
59
 
        return True
60
 
        
61
 
 
62
 
 
63
 
 
64
 
class Merge3(object):
65
 
    """3-way merge of texts.
66
 
 
67
 
    Given BASE, OTHER, THIS, tries to produce a combined text
68
 
    incorporating the changes from both BASE->OTHER and BASE->THIS.
69
 
    All three will typically be sequences of lines."""
70
 
    def __init__(self, base, a, b):
71
 
        check_text_lines(base)
72
 
        check_text_lines(a)
73
 
        check_text_lines(b)
74
 
        self.base = base
75
 
        self.a = a
76
 
        self.b = b
77
 
 
78
 
 
79
 
 
80
 
    def merge_lines(self,
81
 
                    name_a=None,
82
 
                    name_b=None,
83
 
                    name_base=None,
84
 
                    start_marker='<<<<<<<',
85
 
                    mid_marker='=======',
86
 
                    end_marker='>>>>>>>',
87
 
                    base_marker=None,
88
 
                    reprocess=False):
89
 
        """Return merge in cvs-like form.
90
 
        """
91
 
        if base_marker and reprocess:
92
 
            raise CantReprocessAndShowBase()
93
 
        if name_a:
94
 
            start_marker = start_marker + ' ' + name_a
95
 
        if name_b:
96
 
            end_marker = end_marker + ' ' + name_b
97
 
        if name_base and base_marker:
98
 
            base_marker = base_marker + ' ' + name_base
99
 
        merge_regions = self.merge_regions()
100
 
        if reprocess is True:
101
 
            merge_regions = self.reprocess_merge_regions(merge_regions)
102
 
        for t in merge_regions:
103
 
            what = t[0]
104
 
            if what == 'unchanged':
105
 
                for i in range(t[1], t[2]):
106
 
                    yield self.base[i]
107
 
            elif what == 'a' or what == 'same':
108
 
                for i in range(t[1], t[2]):
109
 
                    yield self.a[i]
110
 
            elif what == 'b':
111
 
                for i in range(t[1], t[2]):
112
 
                    yield self.b[i]
113
 
            elif what == 'conflict':
114
 
                yield start_marker + '\n'
115
 
                for i in range(t[3], t[4]):
116
 
                    yield self.a[i]
117
 
                if base_marker is not None:
118
 
                    yield base_marker + '\n'
119
 
                    for i in range(t[1], t[2]):
120
 
                        yield self.base[i]
121
 
                yield mid_marker + '\n'
122
 
                for i in range(t[5], t[6]):
123
 
                    yield self.b[i]
124
 
                yield end_marker + '\n'
125
 
            else:
126
 
                raise ValueError(what)
127
 
        
128
 
        
129
 
 
130
 
 
131
 
 
132
 
    def merge_annotated(self):
133
 
        """Return merge with conflicts, showing origin of lines.
134
 
 
135
 
        Most useful for debugging merge.        
136
 
        """
137
 
        for t in self.merge_regions():
138
 
            what = t[0]
139
 
            if what == 'unchanged':
140
 
                for i in range(t[1], t[2]):
141
 
                    yield 'u | ' + self.base[i]
142
 
            elif what == 'a' or what == 'same':
143
 
                for i in range(t[1], t[2]):
144
 
                    yield what[0] + ' | ' + self.a[i]
145
 
            elif what == 'b':
146
 
                for i in range(t[1], t[2]):
147
 
                    yield 'b | ' + self.b[i]
148
 
            elif what == 'conflict':
149
 
                yield '<<<<\n'
150
 
                for i in range(t[3], t[4]):
151
 
                    yield 'A | ' + self.a[i]
152
 
                yield '----\n'
153
 
                for i in range(t[5], t[6]):
154
 
                    yield 'B | ' + self.b[i]
155
 
                yield '>>>>\n'
156
 
            else:
157
 
                raise ValueError(what)
158
 
        
159
 
        
160
 
 
161
 
 
162
 
 
163
 
    def merge_groups(self):
164
 
        """Yield sequence of line groups.  Each one is a tuple:
165
 
 
166
 
        'unchanged', lines
167
 
             Lines unchanged from base
168
 
 
169
 
        'a', lines
170
 
             Lines taken from a
171
 
 
172
 
        'same', lines
173
 
             Lines taken from a (and equal to b)
174
 
 
175
 
        'b', lines
176
 
             Lines taken from b
177
 
 
178
 
        'conflict', base_lines, a_lines, b_lines
179
 
             Lines from base were changed to either a or b and conflict.
180
 
        """
181
 
        for t in self.merge_regions():
182
 
            what = t[0]
183
 
            if what == 'unchanged':
184
 
                yield what, self.base[t[1]:t[2]]
185
 
            elif what == 'a' or what == 'same':
186
 
                yield what, self.a[t[1]:t[2]]
187
 
            elif what == 'b':
188
 
                yield what, self.b[t[1]:t[2]]
189
 
            elif what == 'conflict':
190
 
                yield (what,
191
 
                       self.base[t[1]:t[2]],
192
 
                       self.a[t[3]:t[4]],
193
 
                       self.b[t[5]:t[6]])
194
 
            else:
195
 
                raise ValueError(what)
196
 
 
197
 
 
198
 
    def merge_regions(self):
199
 
        """Return sequences of matching and conflicting regions.
200
 
 
201
 
        This returns tuples, where the first value says what kind we
202
 
        have:
203
 
 
204
 
        'unchanged', start, end
205
 
             Take a region of base[start:end]
206
 
 
207
 
        'same', astart, aend
208
 
             b and a are different from base but give the same result
209
 
 
210
 
        'a', start, end
211
 
             Non-clashing insertion from a[start:end]
212
 
 
213
 
        Method is as follows:
214
 
 
215
 
        The two sequences align only on regions which match the base
216
 
        and both descendents.  These are found by doing a two-way diff
217
 
        of each one against the base, and then finding the
218
 
        intersections between those regions.  These "sync regions"
219
 
        are by definition unchanged in both and easily dealt with.
220
 
 
221
 
        The regions in between can be in any of three cases:
222
 
        conflicted, or changed on only one side.
223
 
        """
224
 
 
225
 
        # section a[0:ia] has been disposed of, etc
226
 
        iz = ia = ib = 0
227
 
        
228
 
        for zmatch, zend, amatch, aend, bmatch, bend in self.find_sync_regions():
229
 
            #print 'match base [%d:%d]' % (zmatch, zend)
230
 
            
231
 
            matchlen = zend - zmatch
232
 
            assert matchlen >= 0
233
 
            assert matchlen == (aend - amatch)
234
 
            assert matchlen == (bend - bmatch)
235
 
            
236
 
            len_a = amatch - ia
237
 
            len_b = bmatch - ib
238
 
            len_base = zmatch - iz
239
 
            assert len_a >= 0
240
 
            assert len_b >= 0
241
 
            assert len_base >= 0
242
 
 
243
 
            #print 'unmatched a=%d, b=%d' % (len_a, len_b)
244
 
 
245
 
            if len_a or len_b:
246
 
                # try to avoid actually slicing the lists
247
 
                equal_a = compare_range(self.a, ia, amatch,
248
 
                                        self.base, iz, zmatch)
249
 
                equal_b = compare_range(self.b, ib, bmatch,
250
 
                                        self.base, iz, zmatch)
251
 
                same = compare_range(self.a, ia, amatch,
252
 
                                     self.b, ib, bmatch)
253
 
 
254
 
                if same:
255
 
                    yield 'same', ia, amatch
256
 
                elif equal_a and not equal_b:
257
 
                    yield 'b', ib, bmatch
258
 
                elif equal_b and not equal_a:
259
 
                    yield 'a', ia, amatch
260
 
                elif not equal_a and not equal_b:
261
 
                    yield 'conflict', iz, zmatch, ia, amatch, ib, bmatch
262
 
                else:
263
 
                    raise AssertionError("can't handle a=b=base but unmatched")
264
 
 
265
 
                ia = amatch
266
 
                ib = bmatch
267
 
            iz = zmatch
268
 
 
269
 
            # if the same part of the base was deleted on both sides
270
 
            # that's OK, we can just skip it.
271
 
 
272
 
                
273
 
            if matchlen > 0:
274
 
                assert ia == amatch
275
 
                assert ib == bmatch
276
 
                assert iz == zmatch
277
 
                
278
 
                yield 'unchanged', zmatch, zend
279
 
                iz = zend
280
 
                ia = aend
281
 
                ib = bend
282
 
    
283
 
 
284
 
    def reprocess_merge_regions(self, merge_regions):
285
 
        """Where there are conflict regions, remove the agreed lines.
286
 
 
287
 
        Lines where both A and B have made the same changes are 
288
 
        eliminated.
289
 
        """
290
 
        for region in merge_regions:
291
 
            if region[0] != "conflict":
292
 
                yield region
293
 
                continue
294
 
            type, iz, zmatch, ia, amatch, ib, bmatch = region
295
 
            a_region = self.a[ia:amatch]
296
 
            b_region = self.b[ib:bmatch]
297
 
            matches = SequenceMatcher(None, a_region, 
298
 
                                      b_region).get_matching_blocks()
299
 
            next_a = ia
300
 
            next_b = ib
301
 
            for region_ia, region_ib, region_len in matches[:-1]:
302
 
                region_ia += ia
303
 
                region_ib += ib
304
 
                reg = self.mismatch_region(next_a, region_ia, next_b,
305
 
                                           region_ib)
306
 
                if reg is not None:
307
 
                    yield reg
308
 
                yield 'same', region_ia, region_len+region_ia
309
 
                next_a = region_ia + region_len
310
 
                next_b = region_ib + region_len
311
 
            reg = self.mismatch_region(next_a, amatch, next_b, bmatch)
312
 
            if reg is not None:
313
 
                yield reg
314
 
 
315
 
 
316
 
    @staticmethod
317
 
    def mismatch_region(next_a, region_ia,  next_b, region_ib):
318
 
        if next_a < region_ia or next_b < region_ib:
319
 
            return 'conflict', None, None, next_a, region_ia, next_b, region_ib
320
 
            
321
 
 
322
 
    def find_sync_regions(self):
323
 
        """Return a list of sync regions, where both descendents match the base.
324
 
 
325
 
        Generates a list of (base1, base2, a1, a2, b1, b2).  There is
326
 
        always a zero-length sync region at the end of all the files.
327
 
        """
328
 
 
329
 
        ia = ib = 0
330
 
        amatches = SequenceMatcher(None, self.base, self.a).get_matching_blocks()
331
 
        bmatches = SequenceMatcher(None, self.base, self.b).get_matching_blocks()
332
 
        len_a = len(amatches)
333
 
        len_b = len(bmatches)
334
 
 
335
 
        sl = []
336
 
 
337
 
        while ia < len_a and ib < len_b:
338
 
            abase, amatch, alen = amatches[ia]
339
 
            bbase, bmatch, blen = bmatches[ib]
340
 
 
341
 
            # there is an unconflicted block at i; how long does it
342
 
            # extend?  until whichever one ends earlier.
343
 
            i = intersect((abase, abase+alen), (bbase, bbase+blen))
344
 
            if i:
345
 
                intbase = i[0]
346
 
                intend = i[1]
347
 
                intlen = intend - intbase
348
 
 
349
 
                # found a match of base[i[0], i[1]]; this may be less than
350
 
                # the region that matches in either one
351
 
                assert intlen <= alen
352
 
                assert intlen <= blen
353
 
                assert abase <= intbase
354
 
                assert bbase <= intbase
355
 
 
356
 
                asub = amatch + (intbase - abase)
357
 
                bsub = bmatch + (intbase - bbase)
358
 
                aend = asub + intlen
359
 
                bend = bsub + intlen
360
 
 
361
 
                assert self.base[intbase:intend] == self.a[asub:aend], \
362
 
                       (self.base[intbase:intend], self.a[asub:aend])
363
 
 
364
 
                assert self.base[intbase:intend] == self.b[bsub:bend]
365
 
 
366
 
                sl.append((intbase, intend,
367
 
                           asub, aend,
368
 
                           bsub, bend))
369
 
 
370
 
            # advance whichever one ends first in the base text
371
 
            if (abase + alen) < (bbase + blen):
372
 
                ia += 1
373
 
            else:
374
 
                ib += 1
375
 
            
376
 
        intbase = len(self.base)
377
 
        abase = len(self.a)
378
 
        bbase = len(self.b)
379
 
        sl.append((intbase, intbase, abase, abase, bbase, bbase))
380
 
 
381
 
        return sl
382
 
 
383
 
 
384
 
 
385
 
    def find_unconflicted(self):
386
 
        """Return a list of ranges in base that are not conflicted."""
387
 
 
388
 
        import re
389
 
 
390
 
        # don't sync-up on lines containing only blanks or pounds
391
 
        junk_re = re.compile(r'^[ \t#]*$')
392
 
        
393
 
        am = SequenceMatcher(junk_re.match, self.base, self.a).get_matching_blocks()
394
 
        bm = SequenceMatcher(junk_re.match, self.base, self.b).get_matching_blocks()
395
 
 
396
 
        unc = []
397
 
 
398
 
        while am and bm:
399
 
            # there is an unconflicted block at i; how long does it
400
 
            # extend?  until whichever one ends earlier.
401
 
            a1 = am[0][0]
402
 
            a2 = a1 + am[0][2]
403
 
            b1 = bm[0][0]
404
 
            b2 = b1 + bm[0][2]
405
 
            i = intersect((a1, a2), (b1, b2))
406
 
            if i:
407
 
                unc.append(i)
408
 
 
409
 
            if a2 < b2:
410
 
                del am[0]
411
 
            else:
412
 
                del bm[0]
413
 
                
414
 
        return unc
415
 
 
416
 
 
417
 
def main(argv):
418
 
    # as for diff3 and meld the syntax is "MINE BASE OTHER"
419
 
    a = file(argv[1], 'rt').readlines()
420
 
    base = file(argv[2], 'rt').readlines()
421
 
    b = file(argv[3], 'rt').readlines()
422
 
 
423
 
    m3 = Merge3(base, a, b)
424
 
 
425
 
    #for sr in m3.find_sync_regions():
426
 
    #    print sr
427
 
 
428
 
    # sys.stdout.writelines(m3.merge_lines(name_a=argv[1], name_b=argv[3]))
429
 
    sys.stdout.writelines(m3.merge_annotated())
430
 
 
431
 
 
432
 
if __name__ == '__main__':
433
 
    import sys
434
 
    sys.exit(main(sys.argv))