~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/patches.py

  • Committer: Martin Pool
  • Date: 2005-04-21 01:30:22 UTC
  • Revision ID: mbp@sourcefrog.net-20050421013022-858a70691455e1dc
todo

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2005-2010 Aaron Bentley, Canonical Ltd
2
 
# <aaron.bentley@utoronto.ca>
3
 
#
4
 
# This program is free software; you can redistribute it and/or modify
5
 
# it under the terms of the GNU General Public License as published by
6
 
# the Free Software Foundation; either version 2 of the License, or
7
 
# (at your option) any later version.
8
 
#
9
 
# This program is distributed in the hope that it will be useful,
10
 
# but WITHOUT ANY WARRANTY; without even the implied warranty of
11
 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12
 
# GNU General Public License for more details.
13
 
#
14
 
# You should have received a copy of the GNU General Public License
15
 
# along with this program; if not, write to the Free Software
16
 
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17
 
import re
18
 
 
19
 
 
20
 
binary_files_re = 'Binary files (.*) and (.*) differ\n'
21
 
 
22
 
 
23
 
class BinaryFiles(Exception):
24
 
 
25
 
    def __init__(self, orig_name, mod_name):
26
 
        self.orig_name = orig_name
27
 
        self.mod_name = mod_name
28
 
        Exception.__init__(self, 'Binary files section encountered.')
29
 
 
30
 
 
31
 
class PatchSyntax(Exception):
32
 
    def __init__(self, msg):
33
 
        Exception.__init__(self, msg)
34
 
 
35
 
 
36
 
class MalformedPatchHeader(PatchSyntax):
37
 
    def __init__(self, desc, line):
38
 
        self.desc = desc
39
 
        self.line = line
40
 
        msg = "Malformed patch header.  %s\n%r" % (self.desc, self.line)
41
 
        PatchSyntax.__init__(self, msg)
42
 
 
43
 
 
44
 
class MalformedHunkHeader(PatchSyntax):
45
 
    def __init__(self, desc, line):
46
 
        self.desc = desc
47
 
        self.line = line
48
 
        msg = "Malformed hunk header.  %s\n%r" % (self.desc, self.line)
49
 
        PatchSyntax.__init__(self, msg)
50
 
 
51
 
 
52
 
class MalformedLine(PatchSyntax):
53
 
    def __init__(self, desc, line):
54
 
        self.desc = desc
55
 
        self.line = line
56
 
        msg = "Malformed line.  %s\n%s" % (self.desc, self.line)
57
 
        PatchSyntax.__init__(self, msg)
58
 
 
59
 
 
60
 
class PatchConflict(Exception):
61
 
    def __init__(self, line_no, orig_line, patch_line):
62
 
        orig = orig_line.rstrip('\n')
63
 
        patch = str(patch_line).rstrip('\n')
64
 
        msg = 'Text contents mismatch at line %d.  Original has "%s",'\
65
 
            ' but patch says it should be "%s"' % (line_no, orig, patch)
66
 
        Exception.__init__(self, msg)
67
 
 
68
 
 
69
 
def get_patch_names(iter_lines):
70
 
    try:
71
 
        line = iter_lines.next()
72
 
        match = re.match(binary_files_re, line)
73
 
        if match is not None:
74
 
            raise BinaryFiles(match.group(1), match.group(2))
75
 
        if not line.startswith("--- "):
76
 
            raise MalformedPatchHeader("No orig name", line)
77
 
        else:
78
 
            orig_name = line[4:].rstrip("\n")
79
 
    except StopIteration:
80
 
        raise MalformedPatchHeader("No orig line", "")
81
 
    try:
82
 
        line = iter_lines.next()
83
 
        if not line.startswith("+++ "):
84
 
            raise PatchSyntax("No mod name")
85
 
        else:
86
 
            mod_name = line[4:].rstrip("\n")
87
 
    except StopIteration:
88
 
        raise MalformedPatchHeader("No mod line", "")
89
 
    return (orig_name, mod_name)
90
 
 
91
 
 
92
 
def parse_range(textrange):
93
 
    """Parse a patch range, handling the "1" special-case
94
 
 
95
 
    :param textrange: The text to parse
96
 
    :type textrange: str
97
 
    :return: the position and range, as a tuple
98
 
    :rtype: (int, int)
99
 
    """
100
 
    tmp = textrange.split(',')
101
 
    if len(tmp) == 1:
102
 
        pos = tmp[0]
103
 
        range = "1"
104
 
    else:
105
 
        (pos, range) = tmp
106
 
    pos = int(pos)
107
 
    range = int(range)
108
 
    return (pos, range)
109
 
 
110
 
 
111
 
def hunk_from_header(line):
112
 
    import re
113
 
    matches = re.match(r'\@\@ ([^@]*) \@\@( (.*))?\n', line)
114
 
    if matches is None:
115
 
        raise MalformedHunkHeader("Does not match format.", line)
116
 
    try:
117
 
        (orig, mod) = matches.group(1).split(" ")
118
 
    except (ValueError, IndexError), e:
119
 
        raise MalformedHunkHeader(str(e), line)
120
 
    if not orig.startswith('-') or not mod.startswith('+'):
121
 
        raise MalformedHunkHeader("Positions don't start with + or -.", line)
122
 
    try:
123
 
        (orig_pos, orig_range) = parse_range(orig[1:])
124
 
        (mod_pos, mod_range) = parse_range(mod[1:])
125
 
    except (ValueError, IndexError), e:
126
 
        raise MalformedHunkHeader(str(e), line)
127
 
    if mod_range < 0 or orig_range < 0:
128
 
        raise MalformedHunkHeader("Hunk range is negative", line)
129
 
    tail = matches.group(3)
130
 
    return Hunk(orig_pos, orig_range, mod_pos, mod_range, tail)
131
 
 
132
 
 
133
 
class HunkLine:
134
 
    def __init__(self, contents):
135
 
        self.contents = contents
136
 
 
137
 
    def get_str(self, leadchar):
138
 
        if self.contents == "\n" and leadchar == " " and False:
139
 
            return "\n"
140
 
        if not self.contents.endswith('\n'):
141
 
            terminator = '\n' + NO_NL
142
 
        else:
143
 
            terminator = ''
144
 
        return leadchar + self.contents + terminator
145
 
 
146
 
 
147
 
class ContextLine(HunkLine):
148
 
    def __init__(self, contents):
149
 
        HunkLine.__init__(self, contents)
150
 
 
151
 
    def __str__(self):
152
 
        return self.get_str(" ")
153
 
 
154
 
 
155
 
class InsertLine(HunkLine):
156
 
    def __init__(self, contents):
157
 
        HunkLine.__init__(self, contents)
158
 
 
159
 
    def __str__(self):
160
 
        return self.get_str("+")
161
 
 
162
 
 
163
 
class RemoveLine(HunkLine):
164
 
    def __init__(self, contents):
165
 
        HunkLine.__init__(self, contents)
166
 
 
167
 
    def __str__(self):
168
 
        return self.get_str("-")
169
 
 
170
 
NO_NL = '\\ No newline at end of file\n'
171
 
__pychecker__="no-returnvalues"
172
 
 
173
 
def parse_line(line):
174
 
    if line.startswith("\n"):
175
 
        return ContextLine(line)
176
 
    elif line.startswith(" "):
177
 
        return ContextLine(line[1:])
178
 
    elif line.startswith("+"):
179
 
        return InsertLine(line[1:])
180
 
    elif line.startswith("-"):
181
 
        return RemoveLine(line[1:])
182
 
    else:
183
 
        raise MalformedLine("Unknown line type", line)
184
 
__pychecker__=""
185
 
 
186
 
 
187
 
class Hunk:
188
 
    def __init__(self, orig_pos, orig_range, mod_pos, mod_range, tail=None):
189
 
        self.orig_pos = orig_pos
190
 
        self.orig_range = orig_range
191
 
        self.mod_pos = mod_pos
192
 
        self.mod_range = mod_range
193
 
        self.tail = tail
194
 
        self.lines = []
195
 
 
196
 
    def get_header(self):
197
 
        if self.tail is None:
198
 
            tail_str = ''
199
 
        else:
200
 
            tail_str = ' ' + self.tail
201
 
        return "@@ -%s +%s @@%s\n" % (self.range_str(self.orig_pos,
202
 
                                                     self.orig_range),
203
 
                                      self.range_str(self.mod_pos,
204
 
                                                     self.mod_range),
205
 
                                      tail_str)
206
 
 
207
 
    def range_str(self, pos, range):
208
 
        """Return a file range, special-casing for 1-line files.
209
 
 
210
 
        :param pos: The position in the file
211
 
        :type pos: int
212
 
        :range: The range in the file
213
 
        :type range: int
214
 
        :return: a string in the format 1,4 except when range == pos == 1
215
 
        """
216
 
        if range == 1:
217
 
            return "%i" % pos
218
 
        else:
219
 
            return "%i,%i" % (pos, range)
220
 
 
221
 
    def __str__(self):
222
 
        lines = [self.get_header()]
223
 
        for line in self.lines:
224
 
            lines.append(str(line))
225
 
        return "".join(lines)
226
 
 
227
 
    def shift_to_mod(self, pos):
228
 
        if pos < self.orig_pos-1:
229
 
            return 0
230
 
        elif pos > self.orig_pos+self.orig_range:
231
 
            return self.mod_range - self.orig_range
232
 
        else:
233
 
            return self.shift_to_mod_lines(pos)
234
 
 
235
 
    def shift_to_mod_lines(self, pos):
236
 
        position = self.orig_pos-1
237
 
        shift = 0
238
 
        for line in self.lines:
239
 
            if isinstance(line, InsertLine):
240
 
                shift += 1
241
 
            elif isinstance(line, RemoveLine):
242
 
                if position == pos:
243
 
                    return None
244
 
                shift -= 1
245
 
                position += 1
246
 
            elif isinstance(line, ContextLine):
247
 
                position += 1
248
 
            if position > pos:
249
 
                break
250
 
        return shift
251
 
 
252
 
 
253
 
def iter_hunks(iter_lines, allow_dirty=False):
254
 
    '''
255
 
    :arg iter_lines: iterable of lines to parse for hunks
256
 
    :kwarg allow_dirty: If True, when we encounter something that is not
257
 
        a hunk header when we're looking for one, assume the rest of the lines
258
 
        are not part of the patch (comments or other junk).  Default False
259
 
    '''
260
 
    hunk = None
261
 
    for line in iter_lines:
262
 
        if line == "\n":
263
 
            if hunk is not None:
264
 
                yield hunk
265
 
                hunk = None
266
 
            continue
267
 
        if hunk is not None:
268
 
            yield hunk
269
 
        try:
270
 
            hunk = hunk_from_header(line)
271
 
        except MalformedHunkHeader:
272
 
            if allow_dirty:
273
 
                # If the line isn't a hunk header, then we've reached the end
274
 
                # of this patch and there's "junk" at the end.  Ignore the
275
 
                # rest of this patch.
276
 
                return
277
 
            raise
278
 
        orig_size = 0
279
 
        mod_size = 0
280
 
        while orig_size < hunk.orig_range or mod_size < hunk.mod_range:
281
 
            hunk_line = parse_line(iter_lines.next())
282
 
            hunk.lines.append(hunk_line)
283
 
            if isinstance(hunk_line, (RemoveLine, ContextLine)):
284
 
                orig_size += 1
285
 
            if isinstance(hunk_line, (InsertLine, ContextLine)):
286
 
                mod_size += 1
287
 
    if hunk is not None:
288
 
        yield hunk
289
 
 
290
 
 
291
 
class BinaryPatch(object):
292
 
    def __init__(self, oldname, newname):
293
 
        self.oldname = oldname
294
 
        self.newname = newname
295
 
 
296
 
    def __str__(self):
297
 
        return 'Binary files %s and %s differ\n' % (self.oldname, self.newname)
298
 
 
299
 
 
300
 
class Patch(BinaryPatch):
301
 
 
302
 
    def __init__(self, oldname, newname):
303
 
        BinaryPatch.__init__(self, oldname, newname)
304
 
        self.hunks = []
305
 
 
306
 
    def __str__(self):
307
 
        ret = self.get_header()
308
 
        ret += "".join([str(h) for h in self.hunks])
309
 
        return ret
310
 
 
311
 
    def get_header(self):
312
 
        return "--- %s\n+++ %s\n" % (self.oldname, self.newname)
313
 
 
314
 
    def stats_values(self):
315
 
        """Calculate the number of inserts and removes."""
316
 
        removes = 0
317
 
        inserts = 0
318
 
        for hunk in self.hunks:
319
 
            for line in hunk.lines:
320
 
                if isinstance(line, InsertLine):
321
 
                     inserts+=1;
322
 
                elif isinstance(line, RemoveLine):
323
 
                     removes+=1;
324
 
        return (inserts, removes, len(self.hunks))
325
 
 
326
 
    def stats_str(self):
327
 
        """Return a string of patch statistics"""
328
 
        return "%i inserts, %i removes in %i hunks" % \
329
 
            self.stats_values()
330
 
 
331
 
    def pos_in_mod(self, position):
332
 
        newpos = position
333
 
        for hunk in self.hunks:
334
 
            shift = hunk.shift_to_mod(position)
335
 
            if shift is None:
336
 
                return None
337
 
            newpos += shift
338
 
        return newpos
339
 
 
340
 
    def iter_inserted(self):
341
 
        """Iteraties through inserted lines
342
 
 
343
 
        :return: Pair of line number, line
344
 
        :rtype: iterator of (int, InsertLine)
345
 
        """
346
 
        for hunk in self.hunks:
347
 
            pos = hunk.mod_pos - 1;
348
 
            for line in hunk.lines:
349
 
                if isinstance(line, InsertLine):
350
 
                    yield (pos, line)
351
 
                    pos += 1
352
 
                if isinstance(line, ContextLine):
353
 
                    pos += 1
354
 
 
355
 
 
356
 
def parse_patch(iter_lines, allow_dirty=False):
357
 
    '''
358
 
    :arg iter_lines: iterable of lines to parse
359
 
    :kwarg allow_dirty: If True, allow the patch to have trailing junk.
360
 
        Default False
361
 
    '''
362
 
    iter_lines = iter_lines_handle_nl(iter_lines)
363
 
    try:
364
 
        (orig_name, mod_name) = get_patch_names(iter_lines)
365
 
    except BinaryFiles, e:
366
 
        return BinaryPatch(e.orig_name, e.mod_name)
367
 
    else:
368
 
        patch = Patch(orig_name, mod_name)
369
 
        for hunk in iter_hunks(iter_lines, allow_dirty):
370
 
            patch.hunks.append(hunk)
371
 
        return patch
372
 
 
373
 
 
374
 
def iter_file_patch(iter_lines, allow_dirty=False):
375
 
    '''
376
 
    :arg iter_lines: iterable of lines to parse for patches
377
 
    :kwarg allow_dirty: If True, allow comments and other non-patch text
378
 
        before the first patch.  Note that the algorithm here can only find
379
 
        such text before any patches have been found.  Comments after the
380
 
        first patch are stripped away in iter_hunks() if it is also passed
381
 
        allow_dirty=True.  Default False.
382
 
    '''
383
 
    ### FIXME: Docstring is not quite true.  We allow certain comments no
384
 
    # matter what, If they startwith '===', '***', or '#' Someone should
385
 
    # reexamine this logic and decide if we should include those in
386
 
    # allow_dirty or restrict those to only being before the patch is found
387
 
    # (as allow_dirty does).
388
 
    regex = re.compile(binary_files_re)
389
 
    saved_lines = []
390
 
    orig_range = 0
391
 
    beginning = True
392
 
    for line in iter_lines:
393
 
        if line.startswith('=== ') or line.startswith('*** '):
394
 
            continue
395
 
        if line.startswith('#'):
396
 
            continue
397
 
        elif orig_range > 0:
398
 
            if line.startswith('-') or line.startswith(' '):
399
 
                orig_range -= 1
400
 
        elif line.startswith('--- ') or regex.match(line):
401
 
            if allow_dirty and beginning:
402
 
                # Patches can have "junk" at the beginning
403
 
                # Stripping junk from the end of patches is handled when we
404
 
                # parse the patch
405
 
                beginning = False
406
 
            elif len(saved_lines) > 0:
407
 
                yield saved_lines
408
 
            saved_lines = []
409
 
        elif line.startswith('@@'):
410
 
            hunk = hunk_from_header(line)
411
 
            orig_range = hunk.orig_range
412
 
        saved_lines.append(line)
413
 
    if len(saved_lines) > 0:
414
 
        yield saved_lines
415
 
 
416
 
 
417
 
def iter_lines_handle_nl(iter_lines):
418
 
    """
419
 
    Iterates through lines, ensuring that lines that originally had no
420
 
    terminating \n are produced without one.  This transformation may be
421
 
    applied at any point up until hunk line parsing, and is safe to apply
422
 
    repeatedly.
423
 
    """
424
 
    last_line = None
425
 
    for line in iter_lines:
426
 
        if line == NO_NL:
427
 
            if not last_line.endswith('\n'):
428
 
                raise AssertionError()
429
 
            last_line = last_line[:-1]
430
 
            line = None
431
 
        if last_line is not None:
432
 
            yield last_line
433
 
        last_line = line
434
 
    if last_line is not None:
435
 
        yield last_line
436
 
 
437
 
 
438
 
def parse_patches(iter_lines, allow_dirty=False):
439
 
    '''
440
 
    :arg iter_lines: iterable of lines to parse for patches
441
 
    :kwarg allow_dirty: If True, allow text that's not part of the patch at
442
 
        selected places.  This includes comments before and after a patch
443
 
        for instance.  Default False.
444
 
    '''
445
 
    return [parse_patch(f.__iter__(), allow_dirty) for f in
446
 
                        iter_file_patch(iter_lines, allow_dirty)]
447
 
 
448
 
 
449
 
def difference_index(atext, btext):
450
 
    """Find the indext of the first character that differs between two texts
451
 
 
452
 
    :param atext: The first text
453
 
    :type atext: str
454
 
    :param btext: The second text
455
 
    :type str: str
456
 
    :return: The index, or None if there are no differences within the range
457
 
    :rtype: int or NoneType
458
 
    """
459
 
    length = len(atext)
460
 
    if len(btext) < length:
461
 
        length = len(btext)
462
 
    for i in range(length):
463
 
        if atext[i] != btext[i]:
464
 
            return i;
465
 
    return None
466
 
 
467
 
 
468
 
def iter_patched(orig_lines, patch_lines):
469
 
    """Iterate through a series of lines with a patch applied.
470
 
    This handles a single file, and does exact, not fuzzy patching.
471
 
    """
472
 
    patch_lines = iter_lines_handle_nl(iter(patch_lines))
473
 
    get_patch_names(patch_lines)
474
 
    return iter_patched_from_hunks(orig_lines, iter_hunks(patch_lines))
475
 
 
476
 
 
477
 
def iter_patched_from_hunks(orig_lines, hunks):
478
 
    """Iterate through a series of lines with a patch applied.
479
 
    This handles a single file, and does exact, not fuzzy patching.
480
 
 
481
 
    :param orig_lines: The unpatched lines.
482
 
    :param hunks: An iterable of Hunk instances.
483
 
    """
484
 
    seen_patch = []
485
 
    line_no = 1
486
 
    if orig_lines is not None:
487
 
        orig_lines = iter(orig_lines)
488
 
    for hunk in hunks:
489
 
        while line_no < hunk.orig_pos:
490
 
            orig_line = orig_lines.next()
491
 
            yield orig_line
492
 
            line_no += 1
493
 
        for hunk_line in hunk.lines:
494
 
            seen_patch.append(str(hunk_line))
495
 
            if isinstance(hunk_line, InsertLine):
496
 
                yield hunk_line.contents
497
 
            elif isinstance(hunk_line, (ContextLine, RemoveLine)):
498
 
                orig_line = orig_lines.next()
499
 
                if orig_line != hunk_line.contents:
500
 
                    raise PatchConflict(line_no, orig_line, "".join(seen_patch))
501
 
                if isinstance(hunk_line, ContextLine):
502
 
                    yield orig_line
503
 
                else:
504
 
                    if not isinstance(hunk_line, RemoveLine):
505
 
                        raise AssertionError(hunk_line)
506
 
                line_no += 1
507
 
    if orig_lines is not None:
508
 
        for line in orig_lines:
509
 
            yield line