~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/merge3.py

  • Committer: Martin Pool
  • Date: 2005-07-04 08:06:51 UTC
  • Revision ID: mbp@sourcefrog.net-20050704080651-6ecec49164359e48
- track pending-merges

- unit tests for this

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2004, 2005 Canonical Ltd
2
 
#
3
 
# This program is free software; you can redistribute it and/or modify
4
 
# it under the terms of the GNU General Public License as published by
5
 
# the Free Software Foundation; either version 2 of the License, or
6
 
# (at your option) any later version.
7
 
#
8
 
# This program is distributed in the hope that it will be useful,
9
 
# but WITHOUT ANY WARRANTY; without even the implied warranty of
10
 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11
 
# GNU General Public License for more details.
12
 
#
13
 
# You should have received a copy of the GNU General Public License
14
 
# along with this program; if not, write to the Free Software
15
 
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
16
 
 
17
 
 
18
 
# mbp: "you know that thing where cvs gives you conflict markers?"
19
 
# s: "i hate that."
20
 
 
21
 
 
22
 
from bzrlib.errors import CantReprocessAndShowBase
23
 
import bzrlib.patiencediff
24
 
from bzrlib.textfile import check_text_lines
25
 
 
26
 
 
27
 
def intersect(ra, rb):
28
 
    """Given two ranges return the range where they intersect or None.
29
 
 
30
 
    >>> intersect((0, 10), (0, 6))
31
 
    (0, 6)
32
 
    >>> intersect((0, 10), (5, 15))
33
 
    (5, 10)
34
 
    >>> intersect((0, 10), (10, 15))
35
 
    >>> intersect((0, 9), (10, 15))
36
 
    >>> intersect((0, 9), (7, 15))
37
 
    (7, 9)
38
 
    """
39
 
    assert ra[0] <= ra[1]
40
 
    assert rb[0] <= rb[1]
41
 
    
42
 
    sa = max(ra[0], rb[0])
43
 
    sb = min(ra[1], rb[1])
44
 
    if sa < sb:
45
 
        return sa, sb
46
 
    else:
47
 
        return None
48
 
 
49
 
 
50
 
def compare_range(a, astart, aend, b, bstart, bend):
51
 
    """Compare a[astart:aend] == b[bstart:bend], without slicing.
52
 
    """
53
 
    if (aend-astart) != (bend-bstart):
54
 
        return False
55
 
    for ia, ib in zip(xrange(astart, aend), xrange(bstart, bend)):
56
 
        if a[ia] != b[ib]:
57
 
            return False
58
 
    else:
59
 
        return True
60
 
        
61
 
 
62
 
 
63
 
 
64
 
class Merge3(object):
65
 
    """3-way merge of texts.
66
 
 
67
 
    Given BASE, OTHER, THIS, tries to produce a combined text
68
 
    incorporating the changes from both BASE->OTHER and BASE->THIS.
69
 
    All three will typically be sequences of lines."""
70
 
    def __init__(self, base, a, b):
71
 
        check_text_lines(base)
72
 
        check_text_lines(a)
73
 
        check_text_lines(b)
74
 
        self.base = base
75
 
        self.a = a
76
 
        self.b = b
77
 
 
78
 
 
79
 
 
80
 
    def merge_lines(self,
81
 
                    name_a=None,
82
 
                    name_b=None,
83
 
                    name_base=None,
84
 
                    start_marker='<<<<<<<',
85
 
                    mid_marker='=======',
86
 
                    end_marker='>>>>>>>',
87
 
                    base_marker=None,
88
 
                    reprocess=False):
89
 
        """Return merge in cvs-like form.
90
 
        """
91
 
        newline = '\n'
92
 
        if len(self.a) > 0:
93
 
            if self.a[0].endswith('\r\n'):
94
 
                newline = '\r\n'
95
 
            elif self.a[0].endswith('\r'):
96
 
                newline = '\r'
97
 
        if base_marker and reprocess:
98
 
            raise CantReprocessAndShowBase()
99
 
        if name_a:
100
 
            start_marker = start_marker + ' ' + name_a
101
 
        if name_b:
102
 
            end_marker = end_marker + ' ' + name_b
103
 
        if name_base and base_marker:
104
 
            base_marker = base_marker + ' ' + name_base
105
 
        merge_regions = self.merge_regions()
106
 
        if reprocess is True:
107
 
            merge_regions = self.reprocess_merge_regions(merge_regions)
108
 
        for t in merge_regions:
109
 
            what = t[0]
110
 
            if what == 'unchanged':
111
 
                for i in range(t[1], t[2]):
112
 
                    yield self.base[i]
113
 
            elif what == 'a' or what == 'same':
114
 
                for i in range(t[1], t[2]):
115
 
                    yield self.a[i]
116
 
            elif what == 'b':
117
 
                for i in range(t[1], t[2]):
118
 
                    yield self.b[i]
119
 
            elif what == 'conflict':
120
 
                yield start_marker + newline
121
 
                for i in range(t[3], t[4]):
122
 
                    yield self.a[i]
123
 
                if base_marker is not None:
124
 
                    yield base_marker + newline
125
 
                    for i in range(t[1], t[2]):
126
 
                        yield self.base[i]
127
 
                yield mid_marker + newline
128
 
                for i in range(t[5], t[6]):
129
 
                    yield self.b[i]
130
 
                yield end_marker + newline
131
 
            else:
132
 
                raise ValueError(what)
133
 
        
134
 
        
135
 
 
136
 
 
137
 
 
138
 
    def merge_annotated(self):
139
 
        """Return merge with conflicts, showing origin of lines.
140
 
 
141
 
        Most useful for debugging merge.        
142
 
        """
143
 
        for t in self.merge_regions():
144
 
            what = t[0]
145
 
            if what == 'unchanged':
146
 
                for i in range(t[1], t[2]):
147
 
                    yield 'u | ' + self.base[i]
148
 
            elif what == 'a' or what == 'same':
149
 
                for i in range(t[1], t[2]):
150
 
                    yield what[0] + ' | ' + self.a[i]
151
 
            elif what == 'b':
152
 
                for i in range(t[1], t[2]):
153
 
                    yield 'b | ' + self.b[i]
154
 
            elif what == 'conflict':
155
 
                yield '<<<<\n'
156
 
                for i in range(t[3], t[4]):
157
 
                    yield 'A | ' + self.a[i]
158
 
                yield '----\n'
159
 
                for i in range(t[5], t[6]):
160
 
                    yield 'B | ' + self.b[i]
161
 
                yield '>>>>\n'
162
 
            else:
163
 
                raise ValueError(what)
164
 
        
165
 
        
166
 
 
167
 
 
168
 
 
169
 
    def merge_groups(self):
170
 
        """Yield sequence of line groups.  Each one is a tuple:
171
 
 
172
 
        'unchanged', lines
173
 
             Lines unchanged from base
174
 
 
175
 
        'a', lines
176
 
             Lines taken from a
177
 
 
178
 
        'same', lines
179
 
             Lines taken from a (and equal to b)
180
 
 
181
 
        'b', lines
182
 
             Lines taken from b
183
 
 
184
 
        'conflict', base_lines, a_lines, b_lines
185
 
             Lines from base were changed to either a or b and conflict.
186
 
        """
187
 
        for t in self.merge_regions():
188
 
            what = t[0]
189
 
            if what == 'unchanged':
190
 
                yield what, self.base[t[1]:t[2]]
191
 
            elif what == 'a' or what == 'same':
192
 
                yield what, self.a[t[1]:t[2]]
193
 
            elif what == 'b':
194
 
                yield what, self.b[t[1]:t[2]]
195
 
            elif what == 'conflict':
196
 
                yield (what,
197
 
                       self.base[t[1]:t[2]],
198
 
                       self.a[t[3]:t[4]],
199
 
                       self.b[t[5]:t[6]])
200
 
            else:
201
 
                raise ValueError(what)
202
 
 
203
 
 
204
 
    def merge_regions(self):
205
 
        """Return sequences of matching and conflicting regions.
206
 
 
207
 
        This returns tuples, where the first value says what kind we
208
 
        have:
209
 
 
210
 
        'unchanged', start, end
211
 
             Take a region of base[start:end]
212
 
 
213
 
        'same', astart, aend
214
 
             b and a are different from base but give the same result
215
 
 
216
 
        'a', start, end
217
 
             Non-clashing insertion from a[start:end]
218
 
 
219
 
        Method is as follows:
220
 
 
221
 
        The two sequences align only on regions which match the base
222
 
        and both descendents.  These are found by doing a two-way diff
223
 
        of each one against the base, and then finding the
224
 
        intersections between those regions.  These "sync regions"
225
 
        are by definition unchanged in both and easily dealt with.
226
 
 
227
 
        The regions in between can be in any of three cases:
228
 
        conflicted, or changed on only one side.
229
 
        """
230
 
 
231
 
        # section a[0:ia] has been disposed of, etc
232
 
        iz = ia = ib = 0
233
 
        
234
 
        for zmatch, zend, amatch, aend, bmatch, bend in self.find_sync_regions():
235
 
            #print 'match base [%d:%d]' % (zmatch, zend)
236
 
            
237
 
            matchlen = zend - zmatch
238
 
            assert matchlen >= 0
239
 
            assert matchlen == (aend - amatch)
240
 
            assert matchlen == (bend - bmatch)
241
 
            
242
 
            len_a = amatch - ia
243
 
            len_b = bmatch - ib
244
 
            len_base = zmatch - iz
245
 
            assert len_a >= 0
246
 
            assert len_b >= 0
247
 
            assert len_base >= 0
248
 
 
249
 
            #print 'unmatched a=%d, b=%d' % (len_a, len_b)
250
 
 
251
 
            if len_a or len_b:
252
 
                # try to avoid actually slicing the lists
253
 
                equal_a = compare_range(self.a, ia, amatch,
254
 
                                        self.base, iz, zmatch)
255
 
                equal_b = compare_range(self.b, ib, bmatch,
256
 
                                        self.base, iz, zmatch)
257
 
                same = compare_range(self.a, ia, amatch,
258
 
                                     self.b, ib, bmatch)
259
 
 
260
 
                if same:
261
 
                    yield 'same', ia, amatch
262
 
                elif equal_a and not equal_b:
263
 
                    yield 'b', ib, bmatch
264
 
                elif equal_b and not equal_a:
265
 
                    yield 'a', ia, amatch
266
 
                elif not equal_a and not equal_b:
267
 
                    yield 'conflict', iz, zmatch, ia, amatch, ib, bmatch
268
 
                else:
269
 
                    raise AssertionError("can't handle a=b=base but unmatched")
270
 
 
271
 
                ia = amatch
272
 
                ib = bmatch
273
 
            iz = zmatch
274
 
 
275
 
            # if the same part of the base was deleted on both sides
276
 
            # that's OK, we can just skip it.
277
 
 
278
 
                
279
 
            if matchlen > 0:
280
 
                assert ia == amatch
281
 
                assert ib == bmatch
282
 
                assert iz == zmatch
283
 
                
284
 
                yield 'unchanged', zmatch, zend
285
 
                iz = zend
286
 
                ia = aend
287
 
                ib = bend
288
 
    
289
 
 
290
 
    def reprocess_merge_regions(self, merge_regions):
291
 
        """Where there are conflict regions, remove the agreed lines.
292
 
 
293
 
        Lines where both A and B have made the same changes are 
294
 
        eliminated.
295
 
        """
296
 
        for region in merge_regions:
297
 
            if region[0] != "conflict":
298
 
                yield region
299
 
                continue
300
 
            type, iz, zmatch, ia, amatch, ib, bmatch = region
301
 
            a_region = self.a[ia:amatch]
302
 
            b_region = self.b[ib:bmatch]
303
 
            matches = bzrlib.patiencediff.PatienceSequenceMatcher(
304
 
                    None, a_region, b_region).get_matching_blocks()
305
 
            next_a = ia
306
 
            next_b = ib
307
 
            for region_ia, region_ib, region_len in matches[:-1]:
308
 
                region_ia += ia
309
 
                region_ib += ib
310
 
                reg = self.mismatch_region(next_a, region_ia, next_b,
311
 
                                           region_ib)
312
 
                if reg is not None:
313
 
                    yield reg
314
 
                yield 'same', region_ia, region_len+region_ia
315
 
                next_a = region_ia + region_len
316
 
                next_b = region_ib + region_len
317
 
            reg = self.mismatch_region(next_a, amatch, next_b, bmatch)
318
 
            if reg is not None:
319
 
                yield reg
320
 
 
321
 
 
322
 
    @staticmethod
323
 
    def mismatch_region(next_a, region_ia,  next_b, region_ib):
324
 
        if next_a < region_ia or next_b < region_ib:
325
 
            return 'conflict', None, None, next_a, region_ia, next_b, region_ib
326
 
            
327
 
 
328
 
    def find_sync_regions(self):
329
 
        """Return a list of sync regions, where both descendents match the base.
330
 
 
331
 
        Generates a list of (base1, base2, a1, a2, b1, b2).  There is
332
 
        always a zero-length sync region at the end of all the files.
333
 
        """
334
 
 
335
 
        ia = ib = 0
336
 
        amatches = bzrlib.patiencediff.PatienceSequenceMatcher(
337
 
                None, self.base, self.a).get_matching_blocks()
338
 
        bmatches = bzrlib.patiencediff.PatienceSequenceMatcher(
339
 
                None, self.base, self.b).get_matching_blocks()
340
 
        len_a = len(amatches)
341
 
        len_b = len(bmatches)
342
 
 
343
 
        sl = []
344
 
 
345
 
        while ia < len_a and ib < len_b:
346
 
            abase, amatch, alen = amatches[ia]
347
 
            bbase, bmatch, blen = bmatches[ib]
348
 
 
349
 
            # there is an unconflicted block at i; how long does it
350
 
            # extend?  until whichever one ends earlier.
351
 
            i = intersect((abase, abase+alen), (bbase, bbase+blen))
352
 
            if i:
353
 
                intbase = i[0]
354
 
                intend = i[1]
355
 
                intlen = intend - intbase
356
 
 
357
 
                # found a match of base[i[0], i[1]]; this may be less than
358
 
                # the region that matches in either one
359
 
                assert intlen <= alen
360
 
                assert intlen <= blen
361
 
                assert abase <= intbase
362
 
                assert bbase <= intbase
363
 
 
364
 
                asub = amatch + (intbase - abase)
365
 
                bsub = bmatch + (intbase - bbase)
366
 
                aend = asub + intlen
367
 
                bend = bsub + intlen
368
 
 
369
 
                assert self.base[intbase:intend] == self.a[asub:aend], \
370
 
                       (self.base[intbase:intend], self.a[asub:aend])
371
 
 
372
 
                assert self.base[intbase:intend] == self.b[bsub:bend]
373
 
 
374
 
                sl.append((intbase, intend,
375
 
                           asub, aend,
376
 
                           bsub, bend))
377
 
 
378
 
            # advance whichever one ends first in the base text
379
 
            if (abase + alen) < (bbase + blen):
380
 
                ia += 1
381
 
            else:
382
 
                ib += 1
383
 
            
384
 
        intbase = len(self.base)
385
 
        abase = len(self.a)
386
 
        bbase = len(self.b)
387
 
        sl.append((intbase, intbase, abase, abase, bbase, bbase))
388
 
 
389
 
        return sl
390
 
 
391
 
 
392
 
 
393
 
    def find_unconflicted(self):
394
 
        """Return a list of ranges in base that are not conflicted."""
395
 
        am = bzrlib.patiencediff.PatienceSequenceMatcher(
396
 
                None, self.base, self.a).get_matching_blocks()
397
 
        bm = bzrlib.patiencediff.PatienceSequenceMatcher(
398
 
                None, self.base, self.b).get_matching_blocks()
399
 
 
400
 
        unc = []
401
 
 
402
 
        while am and bm:
403
 
            # there is an unconflicted block at i; how long does it
404
 
            # extend?  until whichever one ends earlier.
405
 
            a1 = am[0][0]
406
 
            a2 = a1 + am[0][2]
407
 
            b1 = bm[0][0]
408
 
            b2 = b1 + bm[0][2]
409
 
            i = intersect((a1, a2), (b1, b2))
410
 
            if i:
411
 
                unc.append(i)
412
 
 
413
 
            if a2 < b2:
414
 
                del am[0]
415
 
            else:
416
 
                del bm[0]
417
 
                
418
 
        return unc
419
 
 
420
 
 
421
 
def main(argv):
422
 
    # as for diff3 and meld the syntax is "MINE BASE OTHER"
423
 
    a = file(argv[1], 'rt').readlines()
424
 
    base = file(argv[2], 'rt').readlines()
425
 
    b = file(argv[3], 'rt').readlines()
426
 
 
427
 
    m3 = Merge3(base, a, b)
428
 
 
429
 
    #for sr in m3.find_sync_regions():
430
 
    #    print sr
431
 
 
432
 
    # sys.stdout.writelines(m3.merge_lines(name_a=argv[1], name_b=argv[3]))
433
 
    sys.stdout.writelines(m3.merge_annotated())
434
 
 
435
 
 
436
 
if __name__ == '__main__':
437
 
    import sys
438
 
    sys.exit(main(sys.argv))