~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/merge3.py

  • Committer: Martin Pool
  • Date: 2005-06-11 01:33:22 UTC
  • Revision ID: mbp@sourcefrog.net-20050611013322-f12014bf65accd0c
- don't show progress bar unless completion is known

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2004, 2005 by Canonical Ltd
2
 
 
3
 
# This program is free software; you can redistribute it and/or modify
4
 
# it under the terms of the GNU General Public License as published by
5
 
# the Free Software Foundation; either version 2 of the License, or
6
 
# (at your option) any later version.
7
 
 
8
 
# This program is distributed in the hope that it will be useful,
9
 
# but WITHOUT ANY WARRANTY; without even the implied warranty of
10
 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11
 
# GNU General Public License for more details.
12
 
 
13
 
# You should have received a copy of the GNU General Public License
14
 
# along with this program; if not, write to the Free Software
15
 
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
16
 
 
17
 
 
18
 
# mbp: "you know that thing where cvs gives you conflict markers?"
19
 
# s: "i hate that."
20
 
 
21
 
 
22
 
from bzrlib.errors import CantReprocessAndShowBase
23
 
from bzrlib.patiencediff import SequenceMatcher
24
 
from bzrlib.textfile import check_text_lines
25
 
 
26
 
def intersect(ra, rb):
27
 
    """Given two ranges return the range where they intersect or None.
28
 
 
29
 
    >>> intersect((0, 10), (0, 6))
30
 
    (0, 6)
31
 
    >>> intersect((0, 10), (5, 15))
32
 
    (5, 10)
33
 
    >>> intersect((0, 10), (10, 15))
34
 
    >>> intersect((0, 9), (10, 15))
35
 
    >>> intersect((0, 9), (7, 15))
36
 
    (7, 9)
37
 
    """
38
 
    assert ra[0] <= ra[1]
39
 
    assert rb[0] <= rb[1]
40
 
    
41
 
    sa = max(ra[0], rb[0])
42
 
    sb = min(ra[1], rb[1])
43
 
    if sa < sb:
44
 
        return sa, sb
45
 
    else:
46
 
        return None
47
 
 
48
 
 
49
 
def compare_range(a, astart, aend, b, bstart, bend):
50
 
    """Compare a[astart:aend] == b[bstart:bend], without slicing.
51
 
    """
52
 
    if (aend-astart) != (bend-bstart):
53
 
        return False
54
 
    for ia, ib in zip(xrange(astart, aend), xrange(bstart, bend)):
55
 
        if a[ia] != b[ib]:
56
 
            return False
57
 
    else:
58
 
        return True
59
 
        
60
 
 
61
 
 
62
 
 
63
 
class Merge3(object):
64
 
    """3-way merge of texts.
65
 
 
66
 
    Given BASE, OTHER, THIS, tries to produce a combined text
67
 
    incorporating the changes from both BASE->OTHER and BASE->THIS.
68
 
    All three will typically be sequences of lines."""
69
 
    def __init__(self, base, a, b):
70
 
        check_text_lines(base)
71
 
        check_text_lines(a)
72
 
        check_text_lines(b)
73
 
        self.base = base
74
 
        self.a = a
75
 
        self.b = b
76
 
 
77
 
 
78
 
 
79
 
    def merge_lines(self,
80
 
                    name_a=None,
81
 
                    name_b=None,
82
 
                    name_base=None,
83
 
                    start_marker='<<<<<<<',
84
 
                    mid_marker='=======',
85
 
                    end_marker='>>>>>>>',
86
 
                    base_marker=None,
87
 
                    reprocess=False):
88
 
        """Return merge in cvs-like form.
89
 
        """
90
 
        if base_marker and reprocess:
91
 
            raise CantReprocessAndShowBase()
92
 
        if name_a:
93
 
            start_marker = start_marker + ' ' + name_a
94
 
        if name_b:
95
 
            end_marker = end_marker + ' ' + name_b
96
 
        if name_base and base_marker:
97
 
            base_marker = base_marker + ' ' + name_base
98
 
        merge_regions = self.merge_regions()
99
 
        if reprocess is True:
100
 
            merge_regions = self.reprocess_merge_regions(merge_regions)
101
 
        for t in merge_regions:
102
 
            what = t[0]
103
 
            if what == 'unchanged':
104
 
                for i in range(t[1], t[2]):
105
 
                    yield self.base[i]
106
 
            elif what == 'a' or what == 'same':
107
 
                for i in range(t[1], t[2]):
108
 
                    yield self.a[i]
109
 
            elif what == 'b':
110
 
                for i in range(t[1], t[2]):
111
 
                    yield self.b[i]
112
 
            elif what == 'conflict':
113
 
                yield start_marker + '\n'
114
 
                for i in range(t[3], t[4]):
115
 
                    yield self.a[i]
116
 
                if base_marker is not None:
117
 
                    yield base_marker + '\n'
118
 
                    for i in range(t[1], t[2]):
119
 
                        yield self.base[i]
120
 
                yield mid_marker + '\n'
121
 
                for i in range(t[5], t[6]):
122
 
                    yield self.b[i]
123
 
                yield end_marker + '\n'
124
 
            else:
125
 
                raise ValueError(what)
126
 
        
127
 
        
128
 
 
129
 
 
130
 
 
131
 
    def merge_annotated(self):
132
 
        """Return merge with conflicts, showing origin of lines.
133
 
 
134
 
        Most useful for debugging merge.        
135
 
        """
136
 
        for t in self.merge_regions():
137
 
            what = t[0]
138
 
            if what == 'unchanged':
139
 
                for i in range(t[1], t[2]):
140
 
                    yield 'u | ' + self.base[i]
141
 
            elif what == 'a' or what == 'same':
142
 
                for i in range(t[1], t[2]):
143
 
                    yield what[0] + ' | ' + self.a[i]
144
 
            elif what == 'b':
145
 
                for i in range(t[1], t[2]):
146
 
                    yield 'b | ' + self.b[i]
147
 
            elif what == 'conflict':
148
 
                yield '<<<<\n'
149
 
                for i in range(t[3], t[4]):
150
 
                    yield 'A | ' + self.a[i]
151
 
                yield '----\n'
152
 
                for i in range(t[5], t[6]):
153
 
                    yield 'B | ' + self.b[i]
154
 
                yield '>>>>\n'
155
 
            else:
156
 
                raise ValueError(what)
157
 
        
158
 
        
159
 
 
160
 
 
161
 
 
162
 
    def merge_groups(self):
163
 
        """Yield sequence of line groups.  Each one is a tuple:
164
 
 
165
 
        'unchanged', lines
166
 
             Lines unchanged from base
167
 
 
168
 
        'a', lines
169
 
             Lines taken from a
170
 
 
171
 
        'same', lines
172
 
             Lines taken from a (and equal to b)
173
 
 
174
 
        'b', lines
175
 
             Lines taken from b
176
 
 
177
 
        'conflict', base_lines, a_lines, b_lines
178
 
             Lines from base were changed to either a or b and conflict.
179
 
        """
180
 
        for t in self.merge_regions():
181
 
            what = t[0]
182
 
            if what == 'unchanged':
183
 
                yield what, self.base[t[1]:t[2]]
184
 
            elif what == 'a' or what == 'same':
185
 
                yield what, self.a[t[1]:t[2]]
186
 
            elif what == 'b':
187
 
                yield what, self.b[t[1]:t[2]]
188
 
            elif what == 'conflict':
189
 
                yield (what,
190
 
                       self.base[t[1]:t[2]],
191
 
                       self.a[t[3]:t[4]],
192
 
                       self.b[t[5]:t[6]])
193
 
            else:
194
 
                raise ValueError(what)
195
 
 
196
 
 
197
 
    def merge_regions(self):
198
 
        """Return sequences of matching and conflicting regions.
199
 
 
200
 
        This returns tuples, where the first value says what kind we
201
 
        have:
202
 
 
203
 
        'unchanged', start, end
204
 
             Take a region of base[start:end]
205
 
 
206
 
        'same', astart, aend
207
 
             b and a are different from base but give the same result
208
 
 
209
 
        'a', start, end
210
 
             Non-clashing insertion from a[start:end]
211
 
 
212
 
        Method is as follows:
213
 
 
214
 
        The two sequences align only on regions which match the base
215
 
        and both descendents.  These are found by doing a two-way diff
216
 
        of each one against the base, and then finding the
217
 
        intersections between those regions.  These "sync regions"
218
 
        are by definition unchanged in both and easily dealt with.
219
 
 
220
 
        The regions in between can be in any of three cases:
221
 
        conflicted, or changed on only one side.
222
 
        """
223
 
 
224
 
        # section a[0:ia] has been disposed of, etc
225
 
        iz = ia = ib = 0
226
 
        
227
 
        for zmatch, zend, amatch, aend, bmatch, bend in self.find_sync_regions():
228
 
            #print 'match base [%d:%d]' % (zmatch, zend)
229
 
            
230
 
            matchlen = zend - zmatch
231
 
            assert matchlen >= 0
232
 
            assert matchlen == (aend - amatch)
233
 
            assert matchlen == (bend - bmatch)
234
 
            
235
 
            len_a = amatch - ia
236
 
            len_b = bmatch - ib
237
 
            len_base = zmatch - iz
238
 
            assert len_a >= 0
239
 
            assert len_b >= 0
240
 
            assert len_base >= 0
241
 
 
242
 
            #print 'unmatched a=%d, b=%d' % (len_a, len_b)
243
 
 
244
 
            if len_a or len_b:
245
 
                # try to avoid actually slicing the lists
246
 
                equal_a = compare_range(self.a, ia, amatch,
247
 
                                        self.base, iz, zmatch)
248
 
                equal_b = compare_range(self.b, ib, bmatch,
249
 
                                        self.base, iz, zmatch)
250
 
                same = compare_range(self.a, ia, amatch,
251
 
                                     self.b, ib, bmatch)
252
 
 
253
 
                if same:
254
 
                    yield 'same', ia, amatch
255
 
                elif equal_a and not equal_b:
256
 
                    yield 'b', ib, bmatch
257
 
                elif equal_b and not equal_a:
258
 
                    yield 'a', ia, amatch
259
 
                elif not equal_a and not equal_b:
260
 
                    yield 'conflict', iz, zmatch, ia, amatch, ib, bmatch
261
 
                else:
262
 
                    raise AssertionError("can't handle a=b=base but unmatched")
263
 
 
264
 
                ia = amatch
265
 
                ib = bmatch
266
 
            iz = zmatch
267
 
 
268
 
            # if the same part of the base was deleted on both sides
269
 
            # that's OK, we can just skip it.
270
 
 
271
 
                
272
 
            if matchlen > 0:
273
 
                assert ia == amatch
274
 
                assert ib == bmatch
275
 
                assert iz == zmatch
276
 
                
277
 
                yield 'unchanged', zmatch, zend
278
 
                iz = zend
279
 
                ia = aend
280
 
                ib = bend
281
 
    
282
 
 
283
 
    def reprocess_merge_regions(self, merge_regions):
284
 
        """Where there are conflict regions, remove the agreed lines.
285
 
 
286
 
        Lines where both A and B have made the same changes are 
287
 
        eliminated.
288
 
        """
289
 
        for region in merge_regions:
290
 
            if region[0] != "conflict":
291
 
                yield region
292
 
                continue
293
 
            type, iz, zmatch, ia, amatch, ib, bmatch = region
294
 
            a_region = self.a[ia:amatch]
295
 
            b_region = self.b[ib:bmatch]
296
 
            matches = SequenceMatcher(None, a_region, 
297
 
                                      b_region).get_matching_blocks()
298
 
            next_a = ia
299
 
            next_b = ib
300
 
            for region_ia, region_ib, region_len in matches[:-1]:
301
 
                region_ia += ia
302
 
                region_ib += ib
303
 
                reg = self.mismatch_region(next_a, region_ia, next_b,
304
 
                                           region_ib)
305
 
                if reg is not None:
306
 
                    yield reg
307
 
                yield 'same', region_ia, region_len+region_ia
308
 
                next_a = region_ia + region_len
309
 
                next_b = region_ib + region_len
310
 
            reg = self.mismatch_region(next_a, amatch, next_b, bmatch)
311
 
            if reg is not None:
312
 
                yield reg
313
 
 
314
 
 
315
 
    @staticmethod
316
 
    def mismatch_region(next_a, region_ia,  next_b, region_ib):
317
 
        if next_a < region_ia or next_b < region_ib:
318
 
            return 'conflict', None, None, next_a, region_ia, next_b, region_ib
319
 
            
320
 
 
321
 
    def find_sync_regions(self):
322
 
        """Return a list of sync regions, where both descendents match the base.
323
 
 
324
 
        Generates a list of (base1, base2, a1, a2, b1, b2).  There is
325
 
        always a zero-length sync region at the end of all the files.
326
 
        """
327
 
 
328
 
        ia = ib = 0
329
 
        amatches = SequenceMatcher(None, self.base, self.a).get_matching_blocks()
330
 
        bmatches = SequenceMatcher(None, self.base, self.b).get_matching_blocks()
331
 
        len_a = len(amatches)
332
 
        len_b = len(bmatches)
333
 
 
334
 
        sl = []
335
 
 
336
 
        while ia < len_a and ib < len_b:
337
 
            abase, amatch, alen = amatches[ia]
338
 
            bbase, bmatch, blen = bmatches[ib]
339
 
 
340
 
            # there is an unconflicted block at i; how long does it
341
 
            # extend?  until whichever one ends earlier.
342
 
            i = intersect((abase, abase+alen), (bbase, bbase+blen))
343
 
            if i:
344
 
                intbase = i[0]
345
 
                intend = i[1]
346
 
                intlen = intend - intbase
347
 
 
348
 
                # found a match of base[i[0], i[1]]; this may be less than
349
 
                # the region that matches in either one
350
 
                assert intlen <= alen
351
 
                assert intlen <= blen
352
 
                assert abase <= intbase
353
 
                assert bbase <= intbase
354
 
 
355
 
                asub = amatch + (intbase - abase)
356
 
                bsub = bmatch + (intbase - bbase)
357
 
                aend = asub + intlen
358
 
                bend = bsub + intlen
359
 
 
360
 
                assert self.base[intbase:intend] == self.a[asub:aend], \
361
 
                       (self.base[intbase:intend], self.a[asub:aend])
362
 
 
363
 
                assert self.base[intbase:intend] == self.b[bsub:bend]
364
 
 
365
 
                sl.append((intbase, intend,
366
 
                           asub, aend,
367
 
                           bsub, bend))
368
 
 
369
 
            # advance whichever one ends first in the base text
370
 
            if (abase + alen) < (bbase + blen):
371
 
                ia += 1
372
 
            else:
373
 
                ib += 1
374
 
            
375
 
        intbase = len(self.base)
376
 
        abase = len(self.a)
377
 
        bbase = len(self.b)
378
 
        sl.append((intbase, intbase, abase, abase, bbase, bbase))
379
 
 
380
 
        return sl
381
 
 
382
 
 
383
 
 
384
 
    def find_unconflicted(self):
385
 
        """Return a list of ranges in base that are not conflicted."""
386
 
        am = SequenceMatcher(None, self.base, self.a).get_matching_blocks()
387
 
        bm = SequenceMatcher(None, self.base, self.b).get_matching_blocks()
388
 
 
389
 
        unc = []
390
 
 
391
 
        while am and bm:
392
 
            # there is an unconflicted block at i; how long does it
393
 
            # extend?  until whichever one ends earlier.
394
 
            a1 = am[0][0]
395
 
            a2 = a1 + am[0][2]
396
 
            b1 = bm[0][0]
397
 
            b2 = b1 + bm[0][2]
398
 
            i = intersect((a1, a2), (b1, b2))
399
 
            if i:
400
 
                unc.append(i)
401
 
 
402
 
            if a2 < b2:
403
 
                del am[0]
404
 
            else:
405
 
                del bm[0]
406
 
                
407
 
        return unc
408
 
 
409
 
 
410
 
def main(argv):
411
 
    # as for diff3 and meld the syntax is "MINE BASE OTHER"
412
 
    a = file(argv[1], 'rt').readlines()
413
 
    base = file(argv[2], 'rt').readlines()
414
 
    b = file(argv[3], 'rt').readlines()
415
 
 
416
 
    m3 = Merge3(base, a, b)
417
 
 
418
 
    #for sr in m3.find_sync_regions():
419
 
    #    print sr
420
 
 
421
 
    # sys.stdout.writelines(m3.merge_lines(name_a=argv[1], name_b=argv[3]))
422
 
    sys.stdout.writelines(m3.merge_annotated())
423
 
 
424
 
 
425
 
if __name__ == '__main__':
426
 
    import sys
427
 
    sys.exit(main(sys.argv))