~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/patches.py

  • Committer: Martin Pool
  • Date: 2010-02-03 00:08:23 UTC
  • mto: This revision was merged to the branch mainline in revision 5002.
  • Revision ID: mbp@sourcefrog.net-20100203000823-fcyf2791xrl3fbfo
expand tabs

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# Copyright (C) 2004 - 2006, 2008 Aaron Bentley, Canonical Ltd
 
2
# <aaron.bentley@utoronto.ca>
 
3
#
 
4
# This program is free software; you can redistribute it and/or modify
 
5
# it under the terms of the GNU General Public License as published by
 
6
# the Free Software Foundation; either version 2 of the License, or
 
7
# (at your option) any later version.
 
8
#
 
9
# This program is distributed in the hope that it will be useful,
 
10
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
11
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
12
# GNU General Public License for more details.
 
13
#
 
14
# You should have received a copy of the GNU General Public License
 
15
# along with this program; if not, write to the Free Software
 
16
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 
17
import re
 
18
 
 
19
 
 
20
binary_files_re = 'Binary files (.*) and (.*) differ\n'
 
21
 
 
22
 
 
23
class BinaryFiles(Exception):
 
24
 
 
25
    def __init__(self, orig_name, mod_name):
 
26
        self.orig_name = orig_name
 
27
        self.mod_name = mod_name
 
28
        Exception.__init__(self, 'Binary files section encountered.')
 
29
 
 
30
 
 
31
class PatchSyntax(Exception):
 
32
    def __init__(self, msg):
 
33
        Exception.__init__(self, msg)
 
34
 
 
35
 
 
36
class MalformedPatchHeader(PatchSyntax):
 
37
    def __init__(self, desc, line):
 
38
        self.desc = desc
 
39
        self.line = line
 
40
        msg = "Malformed patch header.  %s\n%r" % (self.desc, self.line)
 
41
        PatchSyntax.__init__(self, msg)
 
42
 
 
43
 
 
44
class MalformedHunkHeader(PatchSyntax):
 
45
    def __init__(self, desc, line):
 
46
        self.desc = desc
 
47
        self.line = line
 
48
        msg = "Malformed hunk header.  %s\n%r" % (self.desc, self.line)
 
49
        PatchSyntax.__init__(self, msg)
 
50
 
 
51
 
 
52
class MalformedLine(PatchSyntax):
 
53
    def __init__(self, desc, line):
 
54
        self.desc = desc
 
55
        self.line = line
 
56
        msg = "Malformed line.  %s\n%s" % (self.desc, self.line)
 
57
        PatchSyntax.__init__(self, msg)
 
58
 
 
59
 
 
60
class PatchConflict(Exception):
 
61
    def __init__(self, line_no, orig_line, patch_line):
 
62
        orig = orig_line.rstrip('\n')
 
63
        patch = str(patch_line).rstrip('\n')
 
64
        msg = 'Text contents mismatch at line %d.  Original has "%s",'\
 
65
            ' but patch says it should be "%s"' % (line_no, orig, patch)
 
66
        Exception.__init__(self, msg)
 
67
 
 
68
 
 
69
def get_patch_names(iter_lines):
 
70
    try:
 
71
        line = iter_lines.next()
 
72
        match = re.match(binary_files_re, line)
 
73
        if match is not None:
 
74
            raise BinaryFiles(match.group(1), match.group(2))
 
75
        if not line.startswith("--- "):
 
76
            raise MalformedPatchHeader("No orig name", line)
 
77
        else:
 
78
            orig_name = line[4:].rstrip("\n")
 
79
    except StopIteration:
 
80
        raise MalformedPatchHeader("No orig line", "")
 
81
    try:
 
82
        line = iter_lines.next()
 
83
        if not line.startswith("+++ "):
 
84
            raise PatchSyntax("No mod name")
 
85
        else:
 
86
            mod_name = line[4:].rstrip("\n")
 
87
    except StopIteration:
 
88
        raise MalformedPatchHeader("No mod line", "")
 
89
    return (orig_name, mod_name)
 
90
 
 
91
 
 
92
def parse_range(textrange):
 
93
    """Parse a patch range, handling the "1" special-case
 
94
 
 
95
    :param textrange: The text to parse
 
96
    :type textrange: str
 
97
    :return: the position and range, as a tuple
 
98
    :rtype: (int, int)
 
99
    """
 
100
    tmp = textrange.split(',')
 
101
    if len(tmp) == 1:
 
102
        pos = tmp[0]
 
103
        range = "1"
 
104
    else:
 
105
        (pos, range) = tmp
 
106
    pos = int(pos)
 
107
    range = int(range)
 
108
    return (pos, range)
 
109
 
 
110
 
 
111
def hunk_from_header(line):
 
112
    import re
 
113
    matches = re.match(r'\@\@ ([^@]*) \@\@( (.*))?\n', line)
 
114
    if matches is None:
 
115
        raise MalformedHunkHeader("Does not match format.", line)
 
116
    try:
 
117
        (orig, mod) = matches.group(1).split(" ")
 
118
    except (ValueError, IndexError), e:
 
119
        raise MalformedHunkHeader(str(e), line)
 
120
    if not orig.startswith('-') or not mod.startswith('+'):
 
121
        raise MalformedHunkHeader("Positions don't start with + or -.", line)
 
122
    try:
 
123
        (orig_pos, orig_range) = parse_range(orig[1:])
 
124
        (mod_pos, mod_range) = parse_range(mod[1:])
 
125
    except (ValueError, IndexError), e:
 
126
        raise MalformedHunkHeader(str(e), line)
 
127
    if mod_range < 0 or orig_range < 0:
 
128
        raise MalformedHunkHeader("Hunk range is negative", line)
 
129
    tail = matches.group(3)
 
130
    return Hunk(orig_pos, orig_range, mod_pos, mod_range, tail)
 
131
 
 
132
 
 
133
class HunkLine:
 
134
    def __init__(self, contents):
 
135
        self.contents = contents
 
136
 
 
137
    def get_str(self, leadchar):
 
138
        if self.contents == "\n" and leadchar == " " and False:
 
139
            return "\n"
 
140
        if not self.contents.endswith('\n'):
 
141
            terminator = '\n' + NO_NL
 
142
        else:
 
143
            terminator = ''
 
144
        return leadchar + self.contents + terminator
 
145
 
 
146
 
 
147
class ContextLine(HunkLine):
 
148
    def __init__(self, contents):
 
149
        HunkLine.__init__(self, contents)
 
150
 
 
151
    def __str__(self):
 
152
        return self.get_str(" ")
 
153
 
 
154
 
 
155
class InsertLine(HunkLine):
 
156
    def __init__(self, contents):
 
157
        HunkLine.__init__(self, contents)
 
158
 
 
159
    def __str__(self):
 
160
        return self.get_str("+")
 
161
 
 
162
 
 
163
class RemoveLine(HunkLine):
 
164
    def __init__(self, contents):
 
165
        HunkLine.__init__(self, contents)
 
166
 
 
167
    def __str__(self):
 
168
        return self.get_str("-")
 
169
 
 
170
NO_NL = '\\ No newline at end of file\n'
 
171
__pychecker__="no-returnvalues"
 
172
 
 
173
def parse_line(line):
 
174
    if line.startswith("\n"):
 
175
        return ContextLine(line)
 
176
    elif line.startswith(" "):
 
177
        return ContextLine(line[1:])
 
178
    elif line.startswith("+"):
 
179
        return InsertLine(line[1:])
 
180
    elif line.startswith("-"):
 
181
        return RemoveLine(line[1:])
 
182
    else:
 
183
        raise MalformedLine("Unknown line type", line)
 
184
__pychecker__=""
 
185
 
 
186
 
 
187
class Hunk:
 
188
    def __init__(self, orig_pos, orig_range, mod_pos, mod_range, tail=None):
 
189
        self.orig_pos = orig_pos
 
190
        self.orig_range = orig_range
 
191
        self.mod_pos = mod_pos
 
192
        self.mod_range = mod_range
 
193
        self.tail = tail
 
194
        self.lines = []
 
195
 
 
196
    def get_header(self):
 
197
        if self.tail is None:
 
198
            tail_str = ''
 
199
        else:
 
200
            tail_str = ' ' + self.tail
 
201
        return "@@ -%s +%s @@%s\n" % (self.range_str(self.orig_pos,
 
202
                                                     self.orig_range),
 
203
                                      self.range_str(self.mod_pos,
 
204
                                                     self.mod_range),
 
205
                                      tail_str)
 
206
 
 
207
    def range_str(self, pos, range):
 
208
        """Return a file range, special-casing for 1-line files.
 
209
 
 
210
        :param pos: The position in the file
 
211
        :type pos: int
 
212
        :range: The range in the file
 
213
        :type range: int
 
214
        :return: a string in the format 1,4 except when range == pos == 1
 
215
        """
 
216
        if range == 1:
 
217
            return "%i" % pos
 
218
        else:
 
219
            return "%i,%i" % (pos, range)
 
220
 
 
221
    def __str__(self):
 
222
        lines = [self.get_header()]
 
223
        for line in self.lines:
 
224
            lines.append(str(line))
 
225
        return "".join(lines)
 
226
 
 
227
    def shift_to_mod(self, pos):
 
228
        if pos < self.orig_pos-1:
 
229
            return 0
 
230
        elif pos > self.orig_pos+self.orig_range:
 
231
            return self.mod_range - self.orig_range
 
232
        else:
 
233
            return self.shift_to_mod_lines(pos)
 
234
 
 
235
    def shift_to_mod_lines(self, pos):
 
236
        position = self.orig_pos-1
 
237
        shift = 0
 
238
        for line in self.lines:
 
239
            if isinstance(line, InsertLine):
 
240
                shift += 1
 
241
            elif isinstance(line, RemoveLine):
 
242
                if position == pos:
 
243
                    return None
 
244
                shift -= 1
 
245
                position += 1
 
246
            elif isinstance(line, ContextLine):
 
247
                position += 1
 
248
            if position > pos:
 
249
                break
 
250
        return shift
 
251
 
 
252
 
 
253
def iter_hunks(iter_lines):
 
254
    hunk = None
 
255
    for line in iter_lines:
 
256
        if line == "\n":
 
257
            if hunk is not None:
 
258
                yield hunk
 
259
                hunk = None
 
260
            continue
 
261
        if hunk is not None:
 
262
            yield hunk
 
263
        hunk = hunk_from_header(line)
 
264
        orig_size = 0
 
265
        mod_size = 0
 
266
        while orig_size < hunk.orig_range or mod_size < hunk.mod_range:
 
267
            hunk_line = parse_line(iter_lines.next())
 
268
            hunk.lines.append(hunk_line)
 
269
            if isinstance(hunk_line, (RemoveLine, ContextLine)):
 
270
                orig_size += 1
 
271
            if isinstance(hunk_line, (InsertLine, ContextLine)):
 
272
                mod_size += 1
 
273
    if hunk is not None:
 
274
        yield hunk
 
275
 
 
276
 
 
277
class BinaryPatch(object):
 
278
    def __init__(self, oldname, newname):
 
279
        self.oldname = oldname
 
280
        self.newname = newname
 
281
 
 
282
    def __str__(self):
 
283
        return 'Binary files %s and %s differ\n' % (self.oldname, self.newname)
 
284
 
 
285
 
 
286
class Patch(BinaryPatch):
 
287
 
 
288
    def __init__(self, oldname, newname):
 
289
        BinaryPatch.__init__(self, oldname, newname)
 
290
        self.hunks = []
 
291
 
 
292
    def __str__(self):
 
293
        ret = self.get_header()
 
294
        ret += "".join([str(h) for h in self.hunks])
 
295
        return ret
 
296
 
 
297
    def get_header(self):
 
298
        return "--- %s\n+++ %s\n" % (self.oldname, self.newname)
 
299
 
 
300
    def stats_values(self):
 
301
        """Calculate the number of inserts and removes."""
 
302
        removes = 0
 
303
        inserts = 0
 
304
        for hunk in self.hunks:
 
305
            for line in hunk.lines:
 
306
                if isinstance(line, InsertLine):
 
307
                     inserts+=1;
 
308
                elif isinstance(line, RemoveLine):
 
309
                     removes+=1;
 
310
        return (inserts, removes, len(self.hunks))
 
311
 
 
312
    def stats_str(self):
 
313
        """Return a string of patch statistics"""
 
314
        return "%i inserts, %i removes in %i hunks" % \
 
315
            self.stats_values()
 
316
 
 
317
    def pos_in_mod(self, position):
 
318
        newpos = position
 
319
        for hunk in self.hunks:
 
320
            shift = hunk.shift_to_mod(position)
 
321
            if shift is None:
 
322
                return None
 
323
            newpos += shift
 
324
        return newpos
 
325
 
 
326
    def iter_inserted(self):
 
327
        """Iteraties through inserted lines
 
328
 
 
329
        :return: Pair of line number, line
 
330
        :rtype: iterator of (int, InsertLine)
 
331
        """
 
332
        for hunk in self.hunks:
 
333
            pos = hunk.mod_pos - 1;
 
334
            for line in hunk.lines:
 
335
                if isinstance(line, InsertLine):
 
336
                    yield (pos, line)
 
337
                    pos += 1
 
338
                if isinstance(line, ContextLine):
 
339
                    pos += 1
 
340
 
 
341
 
 
342
def parse_patch(iter_lines):
 
343
    iter_lines = iter_lines_handle_nl(iter_lines)
 
344
    try:
 
345
        (orig_name, mod_name) = get_patch_names(iter_lines)
 
346
    except BinaryFiles, e:
 
347
        return BinaryPatch(e.orig_name, e.mod_name)
 
348
    else:
 
349
        patch = Patch(orig_name, mod_name)
 
350
        for hunk in iter_hunks(iter_lines):
 
351
            patch.hunks.append(hunk)
 
352
        return patch
 
353
 
 
354
 
 
355
def iter_file_patch(iter_lines):
 
356
    regex = re.compile(binary_files_re)
 
357
    saved_lines = []
 
358
    orig_range = 0
 
359
    for line in iter_lines:
 
360
        if line.startswith('=== ') or line.startswith('*** '):
 
361
            continue
 
362
        if line.startswith('#'):
 
363
            continue
 
364
        elif orig_range > 0:
 
365
            if line.startswith('-') or line.startswith(' '):
 
366
                orig_range -= 1
 
367
        elif line.startswith('--- ') or regex.match(line):
 
368
            if len(saved_lines) > 0:
 
369
                yield saved_lines
 
370
            saved_lines = []
 
371
        elif line.startswith('@@'):
 
372
            hunk = hunk_from_header(line)
 
373
            orig_range = hunk.orig_range
 
374
        saved_lines.append(line)
 
375
    if len(saved_lines) > 0:
 
376
        yield saved_lines
 
377
 
 
378
 
 
379
def iter_lines_handle_nl(iter_lines):
 
380
    """
 
381
    Iterates through lines, ensuring that lines that originally had no
 
382
    terminating \n are produced without one.  This transformation may be
 
383
    applied at any point up until hunk line parsing, and is safe to apply
 
384
    repeatedly.
 
385
    """
 
386
    last_line = None
 
387
    for line in iter_lines:
 
388
        if line == NO_NL:
 
389
            if not last_line.endswith('\n'):
 
390
                raise AssertionError()
 
391
            last_line = last_line[:-1]
 
392
            line = None
 
393
        if last_line is not None:
 
394
            yield last_line
 
395
        last_line = line
 
396
    if last_line is not None:
 
397
        yield last_line
 
398
 
 
399
 
 
400
def parse_patches(iter_lines):
 
401
    return [parse_patch(f.__iter__()) for f in iter_file_patch(iter_lines)]
 
402
 
 
403
 
 
404
def difference_index(atext, btext):
 
405
    """Find the indext of the first character that differs between two texts
 
406
 
 
407
    :param atext: The first text
 
408
    :type atext: str
 
409
    :param btext: The second text
 
410
    :type str: str
 
411
    :return: The index, or None if there are no differences within the range
 
412
    :rtype: int or NoneType
 
413
    """
 
414
    length = len(atext)
 
415
    if len(btext) < length:
 
416
        length = len(btext)
 
417
    for i in range(length):
 
418
        if atext[i] != btext[i]:
 
419
            return i;
 
420
    return None
 
421
 
 
422
 
 
423
def iter_patched(orig_lines, patch_lines):
 
424
    """Iterate through a series of lines with a patch applied.
 
425
    This handles a single file, and does exact, not fuzzy patching.
 
426
    """
 
427
    patch_lines = iter_lines_handle_nl(iter(patch_lines))
 
428
    get_patch_names(patch_lines)
 
429
    return iter_patched_from_hunks(orig_lines, iter_hunks(patch_lines))
 
430
 
 
431
 
 
432
def iter_patched_from_hunks(orig_lines, hunks):
 
433
    """Iterate through a series of lines with a patch applied.
 
434
    This handles a single file, and does exact, not fuzzy patching.
 
435
 
 
436
    :param orig_lines: The unpatched lines.
 
437
    :param hunks: An iterable of Hunk instances.
 
438
    """
 
439
    seen_patch = []
 
440
    line_no = 1
 
441
    if orig_lines is not None:
 
442
        orig_lines = iter(orig_lines)
 
443
    for hunk in hunks:
 
444
        while line_no < hunk.orig_pos:
 
445
            orig_line = orig_lines.next()
 
446
            yield orig_line
 
447
            line_no += 1
 
448
        for hunk_line in hunk.lines:
 
449
            seen_patch.append(str(hunk_line))
 
450
            if isinstance(hunk_line, InsertLine):
 
451
                yield hunk_line.contents
 
452
            elif isinstance(hunk_line, (ContextLine, RemoveLine)):
 
453
                orig_line = orig_lines.next()
 
454
                if orig_line != hunk_line.contents:
 
455
                    raise PatchConflict(line_no, orig_line, "".join(seen_patch))
 
456
                if isinstance(hunk_line, ContextLine):
 
457
                    yield orig_line
 
458
                else:
 
459
                    if not isinstance(hunk_line, RemoveLine):
 
460
                        raise AssertionError(hunk_line)
 
461
                line_no += 1
 
462
    if orig_lines is not None:
 
463
        for line in orig_lines:
 
464
            yield line