~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/patches.py

  • Committer: Martin Pool
  • Date: 2005-05-05 07:00:17 UTC
  • Revision ID: mbp@sourcefrog.net-20050505070017-6af6a766fc558dc2
todo

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2004 - 2006, 2008 Aaron Bentley, Canonical Ltd
2
 
# <aaron.bentley@utoronto.ca>
3
 
#
4
 
# This program is free software; you can redistribute it and/or modify
5
 
# it under the terms of the GNU General Public License as published by
6
 
# the Free Software Foundation; either version 2 of the License, or
7
 
# (at your option) any later version.
8
 
#
9
 
# This program is distributed in the hope that it will be useful,
10
 
# but WITHOUT ANY WARRANTY; without even the implied warranty of
11
 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12
 
# GNU General Public License for more details.
13
 
#
14
 
# You should have received a copy of the GNU General Public License
15
 
# along with this program; if not, write to the Free Software
16
 
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17
 
import re
18
 
 
19
 
 
20
 
binary_files_re = 'Binary files (.*) and (.*) differ\n'
21
 
 
22
 
 
23
 
class BinaryFiles(Exception):
24
 
 
25
 
    def __init__(self, orig_name, mod_name):
26
 
        self.orig_name = orig_name
27
 
        self.mod_name = mod_name
28
 
        Exception.__init__(self, 'Binary files section encountered.')
29
 
 
30
 
 
31
 
class PatchSyntax(Exception):
32
 
    def __init__(self, msg):
33
 
        Exception.__init__(self, msg)
34
 
 
35
 
 
36
 
class MalformedPatchHeader(PatchSyntax):
37
 
    def __init__(self, desc, line):
38
 
        self.desc = desc
39
 
        self.line = line
40
 
        msg = "Malformed patch header.  %s\n%r" % (self.desc, self.line)
41
 
        PatchSyntax.__init__(self, msg)
42
 
 
43
 
 
44
 
class MalformedHunkHeader(PatchSyntax):
45
 
    def __init__(self, desc, line):
46
 
        self.desc = desc
47
 
        self.line = line
48
 
        msg = "Malformed hunk header.  %s\n%r" % (self.desc, self.line)
49
 
        PatchSyntax.__init__(self, msg)
50
 
 
51
 
 
52
 
class MalformedLine(PatchSyntax):
53
 
    def __init__(self, desc, line):
54
 
        self.desc = desc
55
 
        self.line = line
56
 
        msg = "Malformed line.  %s\n%s" % (self.desc, self.line)
57
 
        PatchSyntax.__init__(self, msg)
58
 
 
59
 
 
60
 
class PatchConflict(Exception):
61
 
    def __init__(self, line_no, orig_line, patch_line):
62
 
        orig = orig_line.rstrip('\n')
63
 
        patch = str(patch_line).rstrip('\n')
64
 
        msg = 'Text contents mismatch at line %d.  Original has "%s",'\
65
 
            ' but patch says it should be "%s"' % (line_no, orig, patch)
66
 
        Exception.__init__(self, msg)
67
 
 
68
 
 
69
 
def get_patch_names(iter_lines):
70
 
    try:
71
 
        line = iter_lines.next()
72
 
        match = re.match(binary_files_re, line)
73
 
        if match is not None:
74
 
            raise BinaryFiles(match.group(1), match.group(2))
75
 
        if not line.startswith("--- "):
76
 
            raise MalformedPatchHeader("No orig name", line)
77
 
        else:
78
 
            orig_name = line[4:].rstrip("\n")
79
 
    except StopIteration:
80
 
        raise MalformedPatchHeader("No orig line", "")
81
 
    try:
82
 
        line = iter_lines.next()
83
 
        if not line.startswith("+++ "):
84
 
            raise PatchSyntax("No mod name")
85
 
        else:
86
 
            mod_name = line[4:].rstrip("\n")
87
 
    except StopIteration:
88
 
        raise MalformedPatchHeader("No mod line", "")
89
 
    return (orig_name, mod_name)
90
 
 
91
 
 
92
 
def parse_range(textrange):
93
 
    """Parse a patch range, handling the "1" special-case
94
 
 
95
 
    :param textrange: The text to parse
96
 
    :type textrange: str
97
 
    :return: the position and range, as a tuple
98
 
    :rtype: (int, int)
99
 
    """
100
 
    tmp = textrange.split(',')
101
 
    if len(tmp) == 1:
102
 
        pos = tmp[0]
103
 
        range = "1"
104
 
    else:
105
 
        (pos, range) = tmp
106
 
    pos = int(pos)
107
 
    range = int(range)
108
 
    return (pos, range)
109
 
 
110
 
 
111
 
def hunk_from_header(line):
112
 
    import re
113
 
    matches = re.match(r'\@\@ ([^@]*) \@\@( (.*))?\n', line)
114
 
    if matches is None:
115
 
        raise MalformedHunkHeader("Does not match format.", line)
116
 
    try:
117
 
        (orig, mod) = matches.group(1).split(" ")
118
 
    except (ValueError, IndexError), e:
119
 
        raise MalformedHunkHeader(str(e), line)
120
 
    if not orig.startswith('-') or not mod.startswith('+'):
121
 
        raise MalformedHunkHeader("Positions don't start with + or -.", line)
122
 
    try:
123
 
        (orig_pos, orig_range) = parse_range(orig[1:])
124
 
        (mod_pos, mod_range) = parse_range(mod[1:])
125
 
    except (ValueError, IndexError), e:
126
 
        raise MalformedHunkHeader(str(e), line)
127
 
    if mod_range < 0 or orig_range < 0:
128
 
        raise MalformedHunkHeader("Hunk range is negative", line)
129
 
    tail = matches.group(3)
130
 
    return Hunk(orig_pos, orig_range, mod_pos, mod_range, tail)
131
 
 
132
 
 
133
 
class HunkLine:
134
 
    def __init__(self, contents):
135
 
        self.contents = contents
136
 
 
137
 
    def get_str(self, leadchar):
138
 
        if self.contents == "\n" and leadchar == " " and False:
139
 
            return "\n"
140
 
        if not self.contents.endswith('\n'):
141
 
            terminator = '\n' + NO_NL
142
 
        else:
143
 
            terminator = ''
144
 
        return leadchar + self.contents + terminator
145
 
 
146
 
 
147
 
class ContextLine(HunkLine):
148
 
    def __init__(self, contents):
149
 
        HunkLine.__init__(self, contents)
150
 
 
151
 
    def __str__(self):
152
 
        return self.get_str(" ")
153
 
 
154
 
 
155
 
class InsertLine(HunkLine):
156
 
    def __init__(self, contents):
157
 
        HunkLine.__init__(self, contents)
158
 
 
159
 
    def __str__(self):
160
 
        return self.get_str("+")
161
 
 
162
 
 
163
 
class RemoveLine(HunkLine):
164
 
    def __init__(self, contents):
165
 
        HunkLine.__init__(self, contents)
166
 
 
167
 
    def __str__(self):
168
 
        return self.get_str("-")
169
 
 
170
 
NO_NL = '\\ No newline at end of file\n'
171
 
__pychecker__="no-returnvalues"
172
 
 
173
 
def parse_line(line):
174
 
    if line.startswith("\n"):
175
 
        return ContextLine(line)
176
 
    elif line.startswith(" "):
177
 
        return ContextLine(line[1:])
178
 
    elif line.startswith("+"):
179
 
        return InsertLine(line[1:])
180
 
    elif line.startswith("-"):
181
 
        return RemoveLine(line[1:])
182
 
    else:
183
 
        raise MalformedLine("Unknown line type", line)
184
 
__pychecker__=""
185
 
 
186
 
 
187
 
class Hunk:
188
 
    def __init__(self, orig_pos, orig_range, mod_pos, mod_range, tail=None):
189
 
        self.orig_pos = orig_pos
190
 
        self.orig_range = orig_range
191
 
        self.mod_pos = mod_pos
192
 
        self.mod_range = mod_range
193
 
        self.tail = tail
194
 
        self.lines = []
195
 
 
196
 
    def get_header(self):
197
 
        if self.tail is None:
198
 
            tail_str = ''
199
 
        else:
200
 
            tail_str = ' ' + self.tail
201
 
        return "@@ -%s +%s @@%s\n" % (self.range_str(self.orig_pos,
202
 
                                                     self.orig_range),
203
 
                                      self.range_str(self.mod_pos,
204
 
                                                     self.mod_range),
205
 
                                      tail_str)
206
 
 
207
 
    def range_str(self, pos, range):
208
 
        """Return a file range, special-casing for 1-line files.
209
 
 
210
 
        :param pos: The position in the file
211
 
        :type pos: int
212
 
        :range: The range in the file
213
 
        :type range: int
214
 
        :return: a string in the format 1,4 except when range == pos == 1
215
 
        """
216
 
        if range == 1:
217
 
            return "%i" % pos
218
 
        else:
219
 
            return "%i,%i" % (pos, range)
220
 
 
221
 
    def __str__(self):
222
 
        lines = [self.get_header()]
223
 
        for line in self.lines:
224
 
            lines.append(str(line))
225
 
        return "".join(lines)
226
 
 
227
 
    def shift_to_mod(self, pos):
228
 
        if pos < self.orig_pos-1:
229
 
            return 0
230
 
        elif pos > self.orig_pos+self.orig_range:
231
 
            return self.mod_range - self.orig_range
232
 
        else:
233
 
            return self.shift_to_mod_lines(pos)
234
 
 
235
 
    def shift_to_mod_lines(self, pos):
236
 
        position = self.orig_pos-1
237
 
        shift = 0
238
 
        for line in self.lines:
239
 
            if isinstance(line, InsertLine):
240
 
                shift += 1
241
 
            elif isinstance(line, RemoveLine):
242
 
                if position == pos:
243
 
                    return None
244
 
                shift -= 1
245
 
                position += 1
246
 
            elif isinstance(line, ContextLine):
247
 
                position += 1
248
 
            if position > pos:
249
 
                break
250
 
        return shift
251
 
 
252
 
 
253
 
def iter_hunks(iter_lines):
254
 
    hunk = None
255
 
    for line in iter_lines:
256
 
        if line == "\n":
257
 
            if hunk is not None:
258
 
                yield hunk
259
 
                hunk = None
260
 
            continue
261
 
        if hunk is not None:
262
 
            yield hunk
263
 
        hunk = hunk_from_header(line)
264
 
        orig_size = 0
265
 
        mod_size = 0
266
 
        while orig_size < hunk.orig_range or mod_size < hunk.mod_range:
267
 
            hunk_line = parse_line(iter_lines.next())
268
 
            hunk.lines.append(hunk_line)
269
 
            if isinstance(hunk_line, (RemoveLine, ContextLine)):
270
 
                orig_size += 1
271
 
            if isinstance(hunk_line, (InsertLine, ContextLine)):
272
 
                mod_size += 1
273
 
    if hunk is not None:
274
 
        yield hunk
275
 
 
276
 
 
277
 
class BinaryPatch(object):
278
 
    def __init__(self, oldname, newname):
279
 
        self.oldname = oldname
280
 
        self.newname = newname
281
 
 
282
 
    def __str__(self):
283
 
        return 'Binary files %s and %s differ\n' % (self.oldname, self.newname)
284
 
 
285
 
 
286
 
class Patch(BinaryPatch):
287
 
 
288
 
    def __init__(self, oldname, newname):
289
 
        BinaryPatch.__init__(self, oldname, newname)
290
 
        self.hunks = []
291
 
 
292
 
    def __str__(self):
293
 
        ret = self.get_header()
294
 
        ret += "".join([str(h) for h in self.hunks])
295
 
        return ret
296
 
 
297
 
    def get_header(self):
298
 
        return "--- %s\n+++ %s\n" % (self.oldname, self.newname)
299
 
 
300
 
    def stats_values(self):
301
 
        """Calculate the number of inserts and removes."""
302
 
        removes = 0
303
 
        inserts = 0
304
 
        for hunk in self.hunks:
305
 
            for line in hunk.lines:
306
 
                if isinstance(line, InsertLine):
307
 
                     inserts+=1;
308
 
                elif isinstance(line, RemoveLine):
309
 
                     removes+=1;
310
 
        return (inserts, removes, len(self.hunks))
311
 
 
312
 
    def stats_str(self):
313
 
        """Return a string of patch statistics"""
314
 
        return "%i inserts, %i removes in %i hunks" % \
315
 
            self.stats_values()
316
 
 
317
 
    def pos_in_mod(self, position):
318
 
        newpos = position
319
 
        for hunk in self.hunks:
320
 
            shift = hunk.shift_to_mod(position)
321
 
            if shift is None:
322
 
                return None
323
 
            newpos += shift
324
 
        return newpos
325
 
 
326
 
    def iter_inserted(self):
327
 
        """Iteraties through inserted lines
328
 
 
329
 
        :return: Pair of line number, line
330
 
        :rtype: iterator of (int, InsertLine)
331
 
        """
332
 
        for hunk in self.hunks:
333
 
            pos = hunk.mod_pos - 1;
334
 
            for line in hunk.lines:
335
 
                if isinstance(line, InsertLine):
336
 
                    yield (pos, line)
337
 
                    pos += 1
338
 
                if isinstance(line, ContextLine):
339
 
                    pos += 1
340
 
 
341
 
 
342
 
def parse_patch(iter_lines):
343
 
    iter_lines = iter_lines_handle_nl(iter_lines)
344
 
    try:
345
 
        (orig_name, mod_name) = get_patch_names(iter_lines)
346
 
    except BinaryFiles, e:
347
 
        return BinaryPatch(e.orig_name, e.mod_name)
348
 
    else:
349
 
        patch = Patch(orig_name, mod_name)
350
 
        for hunk in iter_hunks(iter_lines):
351
 
            patch.hunks.append(hunk)
352
 
        return patch
353
 
 
354
 
 
355
 
def iter_file_patch(iter_lines):
356
 
    regex = re.compile(binary_files_re)
357
 
    saved_lines = []
358
 
    orig_range = 0
359
 
    for line in iter_lines:
360
 
        if line.startswith('=== ') or line.startswith('*** '):
361
 
            continue
362
 
        if line.startswith('#'):
363
 
            continue
364
 
        elif orig_range > 0:
365
 
            if line.startswith('-') or line.startswith(' '):
366
 
                orig_range -= 1
367
 
        elif line.startswith('--- ') or regex.match(line):
368
 
            if len(saved_lines) > 0:
369
 
                yield saved_lines
370
 
            saved_lines = []
371
 
        elif line.startswith('@@'):
372
 
            hunk = hunk_from_header(line)
373
 
            orig_range = hunk.orig_range
374
 
        saved_lines.append(line)
375
 
    if len(saved_lines) > 0:
376
 
        yield saved_lines
377
 
 
378
 
 
379
 
def iter_lines_handle_nl(iter_lines):
380
 
    """
381
 
    Iterates through lines, ensuring that lines that originally had no
382
 
    terminating \n are produced without one.  This transformation may be
383
 
    applied at any point up until hunk line parsing, and is safe to apply
384
 
    repeatedly.
385
 
    """
386
 
    last_line = None
387
 
    for line in iter_lines:
388
 
        if line == NO_NL:
389
 
            if not last_line.endswith('\n'):
390
 
                raise AssertionError()
391
 
            last_line = last_line[:-1]
392
 
            line = None
393
 
        if last_line is not None:
394
 
            yield last_line
395
 
        last_line = line
396
 
    if last_line is not None:
397
 
        yield last_line
398
 
 
399
 
 
400
 
def parse_patches(iter_lines):
401
 
    return [parse_patch(f.__iter__()) for f in iter_file_patch(iter_lines)]
402
 
 
403
 
 
404
 
def difference_index(atext, btext):
405
 
    """Find the indext of the first character that differs between two texts
406
 
 
407
 
    :param atext: The first text
408
 
    :type atext: str
409
 
    :param btext: The second text
410
 
    :type str: str
411
 
    :return: The index, or None if there are no differences within the range
412
 
    :rtype: int or NoneType
413
 
    """
414
 
    length = len(atext)
415
 
    if len(btext) < length:
416
 
        length = len(btext)
417
 
    for i in range(length):
418
 
        if atext[i] != btext[i]:
419
 
            return i;
420
 
    return None
421
 
 
422
 
 
423
 
def iter_patched(orig_lines, patch_lines):
424
 
    """Iterate through a series of lines with a patch applied.
425
 
    This handles a single file, and does exact, not fuzzy patching.
426
 
    """
427
 
    patch_lines = iter_lines_handle_nl(iter(patch_lines))
428
 
    get_patch_names(patch_lines)
429
 
    return iter_patched_from_hunks(orig_lines, iter_hunks(patch_lines))
430
 
 
431
 
 
432
 
def iter_patched_from_hunks(orig_lines, hunks):
433
 
    """Iterate through a series of lines with a patch applied.
434
 
    This handles a single file, and does exact, not fuzzy patching.
435
 
 
436
 
    :param orig_lines: The unpatched lines.
437
 
    :param hunks: An iterable of Hunk instances.
438
 
    """
439
 
    seen_patch = []
440
 
    line_no = 1
441
 
    if orig_lines is not None:
442
 
        orig_lines = iter(orig_lines)
443
 
    for hunk in hunks:
444
 
        while line_no < hunk.orig_pos:
445
 
            orig_line = orig_lines.next()
446
 
            yield orig_line
447
 
            line_no += 1
448
 
        for hunk_line in hunk.lines:
449
 
            seen_patch.append(str(hunk_line))
450
 
            if isinstance(hunk_line, InsertLine):
451
 
                yield hunk_line.contents
452
 
            elif isinstance(hunk_line, (ContextLine, RemoveLine)):
453
 
                orig_line = orig_lines.next()
454
 
                if orig_line != hunk_line.contents:
455
 
                    raise PatchConflict(line_no, orig_line, "".join(seen_patch))
456
 
                if isinstance(hunk_line, ContextLine):
457
 
                    yield orig_line
458
 
                else:
459
 
                    if not isinstance(hunk_line, RemoveLine):
460
 
                        raise AssertionError(hunk_line)
461
 
                line_no += 1
462
 
    if orig_lines is not None:
463
 
        for line in orig_lines:
464
 
            yield line