~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/patches.py

  • Committer: Martin Pool
  • Date: 2005-08-03 14:16:04 UTC
  • Revision ID: mbp@sourcefrog.net-20050803141604-b69a03512e094f37
- better summary help screen

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2004 - 2006, 2008 Aaron Bentley, Canonical Ltd
2
 
# <aaron.bentley@utoronto.ca>
3
 
#
4
 
# This program is free software; you can redistribute it and/or modify
5
 
# it under the terms of the GNU General Public License as published by
6
 
# the Free Software Foundation; either version 2 of the License, or
7
 
# (at your option) any later version.
8
 
#
9
 
# This program is distributed in the hope that it will be useful,
10
 
# but WITHOUT ANY WARRANTY; without even the implied warranty of
11
 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12
 
# GNU General Public License for more details.
13
 
#
14
 
# You should have received a copy of the GNU General Public License
15
 
# along with this program; if not, write to the Free Software
16
 
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
17
 
 
18
 
 
19
 
class PatchSyntax(Exception):
20
 
    def __init__(self, msg):
21
 
        Exception.__init__(self, msg)
22
 
 
23
 
 
24
 
class MalformedPatchHeader(PatchSyntax):
25
 
    def __init__(self, desc, line):
26
 
        self.desc = desc
27
 
        self.line = line
28
 
        msg = "Malformed patch header.  %s\n%r" % (self.desc, self.line)
29
 
        PatchSyntax.__init__(self, msg)
30
 
 
31
 
 
32
 
class MalformedHunkHeader(PatchSyntax):
33
 
    def __init__(self, desc, line):
34
 
        self.desc = desc
35
 
        self.line = line
36
 
        msg = "Malformed hunk header.  %s\n%r" % (self.desc, self.line)
37
 
        PatchSyntax.__init__(self, msg)
38
 
 
39
 
 
40
 
class MalformedLine(PatchSyntax):
41
 
    def __init__(self, desc, line):
42
 
        self.desc = desc
43
 
        self.line = line
44
 
        msg = "Malformed line.  %s\n%s" % (self.desc, self.line)
45
 
        PatchSyntax.__init__(self, msg)
46
 
 
47
 
 
48
 
class PatchConflict(Exception):
49
 
    def __init__(self, line_no, orig_line, patch_line):
50
 
        orig = orig_line.rstrip('\n')
51
 
        patch = str(patch_line).rstrip('\n')
52
 
        msg = 'Text contents mismatch at line %d.  Original has "%s",'\
53
 
            ' but patch says it should be "%s"' % (line_no, orig, patch)
54
 
        Exception.__init__(self, msg)
55
 
 
56
 
 
57
 
def get_patch_names(iter_lines):
58
 
    try:
59
 
        line = iter_lines.next()
60
 
        if not line.startswith("--- "):
61
 
            raise MalformedPatchHeader("No orig name", line)
62
 
        else:
63
 
            orig_name = line[4:].rstrip("\n")
64
 
    except StopIteration:
65
 
        raise MalformedPatchHeader("No orig line", "")
66
 
    try:
67
 
        line = iter_lines.next()
68
 
        if not line.startswith("+++ "):
69
 
            raise PatchSyntax("No mod name")
70
 
        else:
71
 
            mod_name = line[4:].rstrip("\n")
72
 
    except StopIteration:
73
 
        raise MalformedPatchHeader("No mod line", "")
74
 
    return (orig_name, mod_name)
75
 
 
76
 
 
77
 
def parse_range(textrange):
78
 
    """Parse a patch range, handling the "1" special-case
79
 
 
80
 
    :param textrange: The text to parse
81
 
    :type textrange: str
82
 
    :return: the position and range, as a tuple
83
 
    :rtype: (int, int)
84
 
    """
85
 
    tmp = textrange.split(',')
86
 
    if len(tmp) == 1:
87
 
        pos = tmp[0]
88
 
        range = "1"
89
 
    else:
90
 
        (pos, range) = tmp
91
 
    pos = int(pos)
92
 
    range = int(range)
93
 
    return (pos, range)
94
 
 
95
 
 
96
 
def hunk_from_header(line):
97
 
    import re
98
 
    matches = re.match(r'\@\@ ([^@]*) \@\@( (.*))?\n', line)
99
 
    if matches is None:
100
 
        raise MalformedHunkHeader("Does not match format.", line)
101
 
    try:
102
 
        (orig, mod) = matches.group(1).split(" ")
103
 
    except (ValueError, IndexError), e:
104
 
        raise MalformedHunkHeader(str(e), line)
105
 
    if not orig.startswith('-') or not mod.startswith('+'):
106
 
        raise MalformedHunkHeader("Positions don't start with + or -.", line)
107
 
    try:
108
 
        (orig_pos, orig_range) = parse_range(orig[1:])
109
 
        (mod_pos, mod_range) = parse_range(mod[1:])
110
 
    except (ValueError, IndexError), e:
111
 
        raise MalformedHunkHeader(str(e), line)
112
 
    if mod_range < 0 or orig_range < 0:
113
 
        raise MalformedHunkHeader("Hunk range is negative", line)
114
 
    tail = matches.group(3)
115
 
    return Hunk(orig_pos, orig_range, mod_pos, mod_range, tail)
116
 
 
117
 
 
118
 
class HunkLine:
119
 
    def __init__(self, contents):
120
 
        self.contents = contents
121
 
 
122
 
    def get_str(self, leadchar):
123
 
        if self.contents == "\n" and leadchar == " " and False:
124
 
            return "\n"
125
 
        if not self.contents.endswith('\n'):
126
 
            terminator = '\n' + NO_NL
127
 
        else:
128
 
            terminator = ''
129
 
        return leadchar + self.contents + terminator
130
 
 
131
 
 
132
 
class ContextLine(HunkLine):
133
 
    def __init__(self, contents):
134
 
        HunkLine.__init__(self, contents)
135
 
 
136
 
    def __str__(self):
137
 
        return self.get_str(" ")
138
 
 
139
 
 
140
 
class InsertLine(HunkLine):
141
 
    def __init__(self, contents):
142
 
        HunkLine.__init__(self, contents)
143
 
 
144
 
    def __str__(self):
145
 
        return self.get_str("+")
146
 
 
147
 
 
148
 
class RemoveLine(HunkLine):
149
 
    def __init__(self, contents):
150
 
        HunkLine.__init__(self, contents)
151
 
 
152
 
    def __str__(self):
153
 
        return self.get_str("-")
154
 
 
155
 
NO_NL = '\\ No newline at end of file\n'
156
 
__pychecker__="no-returnvalues"
157
 
 
158
 
def parse_line(line):
159
 
    if line.startswith("\n"):
160
 
        return ContextLine(line)
161
 
    elif line.startswith(" "):
162
 
        return ContextLine(line[1:])
163
 
    elif line.startswith("+"):
164
 
        return InsertLine(line[1:])
165
 
    elif line.startswith("-"):
166
 
        return RemoveLine(line[1:])
167
 
    elif line == NO_NL:
168
 
        return NO_NL
169
 
    else:
170
 
        raise MalformedLine("Unknown line type", line)
171
 
__pychecker__=""
172
 
 
173
 
 
174
 
class Hunk:
175
 
    def __init__(self, orig_pos, orig_range, mod_pos, mod_range, tail=None):
176
 
        self.orig_pos = orig_pos
177
 
        self.orig_range = orig_range
178
 
        self.mod_pos = mod_pos
179
 
        self.mod_range = mod_range
180
 
        self.tail = tail
181
 
        self.lines = []
182
 
 
183
 
    def get_header(self):
184
 
        if self.tail is None:
185
 
            tail_str = ''
186
 
        else:
187
 
            tail_str = ' ' + self.tail
188
 
        return "@@ -%s +%s @@%s\n" % (self.range_str(self.orig_pos,
189
 
                                                     self.orig_range),
190
 
                                      self.range_str(self.mod_pos,
191
 
                                                     self.mod_range),
192
 
                                      tail_str)
193
 
 
194
 
    def range_str(self, pos, range):
195
 
        """Return a file range, special-casing for 1-line files.
196
 
 
197
 
        :param pos: The position in the file
198
 
        :type pos: int
199
 
        :range: The range in the file
200
 
        :type range: int
201
 
        :return: a string in the format 1,4 except when range == pos == 1
202
 
        """
203
 
        if range == 1:
204
 
            return "%i" % pos
205
 
        else:
206
 
            return "%i,%i" % (pos, range)
207
 
 
208
 
    def __str__(self):
209
 
        lines = [self.get_header()]
210
 
        for line in self.lines:
211
 
            lines.append(str(line))
212
 
        return "".join(lines)
213
 
 
214
 
    def shift_to_mod(self, pos):
215
 
        if pos < self.orig_pos-1:
216
 
            return 0
217
 
        elif pos > self.orig_pos+self.orig_range:
218
 
            return self.mod_range - self.orig_range
219
 
        else:
220
 
            return self.shift_to_mod_lines(pos)
221
 
 
222
 
    def shift_to_mod_lines(self, pos):
223
 
        position = self.orig_pos-1
224
 
        shift = 0
225
 
        for line in self.lines:
226
 
            if isinstance(line, InsertLine):
227
 
                shift += 1
228
 
            elif isinstance(line, RemoveLine):
229
 
                if position == pos:
230
 
                    return None
231
 
                shift -= 1
232
 
                position += 1
233
 
            elif isinstance(line, ContextLine):
234
 
                position += 1
235
 
            if position > pos:
236
 
                break
237
 
        return shift
238
 
 
239
 
 
240
 
def iter_hunks(iter_lines):
241
 
    hunk = None
242
 
    for line in iter_lines:
243
 
        if line == "\n":
244
 
            if hunk is not None:
245
 
                yield hunk
246
 
                hunk = None
247
 
            continue
248
 
        if hunk is not None:
249
 
            yield hunk
250
 
        hunk = hunk_from_header(line)
251
 
        orig_size = 0
252
 
        mod_size = 0
253
 
        while orig_size < hunk.orig_range or mod_size < hunk.mod_range:
254
 
            hunk_line = parse_line(iter_lines.next())
255
 
            hunk.lines.append(hunk_line)
256
 
            if isinstance(hunk_line, (RemoveLine, ContextLine)):
257
 
                orig_size += 1
258
 
            if isinstance(hunk_line, (InsertLine, ContextLine)):
259
 
                mod_size += 1
260
 
    if hunk is not None:
261
 
        yield hunk
262
 
 
263
 
 
264
 
class Patch:
265
 
    def __init__(self, oldname, newname):
266
 
        self.oldname = oldname
267
 
        self.newname = newname
268
 
        self.hunks = []
269
 
 
270
 
    def __str__(self):
271
 
        ret = self.get_header() 
272
 
        ret += "".join([str(h) for h in self.hunks])
273
 
        return ret
274
 
 
275
 
    def get_header(self):
276
 
        return "--- %s\n+++ %s\n" % (self.oldname, self.newname)
277
 
 
278
 
    def stats_str(self):
279
 
        """Return a string of patch statistics"""
280
 
        removes = 0
281
 
        inserts = 0
282
 
        for hunk in self.hunks:
283
 
            for line in hunk.lines:
284
 
                if isinstance(line, InsertLine):
285
 
                     inserts+=1;
286
 
                elif isinstance(line, RemoveLine):
287
 
                     removes+=1;
288
 
        return "%i inserts, %i removes in %i hunks" % \
289
 
            (inserts, removes, len(self.hunks))
290
 
 
291
 
    def pos_in_mod(self, position):
292
 
        newpos = position
293
 
        for hunk in self.hunks:
294
 
            shift = hunk.shift_to_mod(position)
295
 
            if shift is None:
296
 
                return None
297
 
            newpos += shift
298
 
        return newpos
299
 
            
300
 
    def iter_inserted(self):
301
 
        """Iteraties through inserted lines
302
 
        
303
 
        :return: Pair of line number, line
304
 
        :rtype: iterator of (int, InsertLine)
305
 
        """
306
 
        for hunk in self.hunks:
307
 
            pos = hunk.mod_pos - 1;
308
 
            for line in hunk.lines:
309
 
                if isinstance(line, InsertLine):
310
 
                    yield (pos, line)
311
 
                    pos += 1
312
 
                if isinstance(line, ContextLine):
313
 
                    pos += 1
314
 
 
315
 
 
316
 
def parse_patch(iter_lines):
317
 
    (orig_name, mod_name) = get_patch_names(iter_lines)
318
 
    patch = Patch(orig_name, mod_name)
319
 
    for hunk in iter_hunks(iter_lines):
320
 
        patch.hunks.append(hunk)
321
 
    return patch
322
 
 
323
 
 
324
 
def iter_file_patch(iter_lines):
325
 
    saved_lines = []
326
 
    orig_range = 0
327
 
    for line in iter_lines:
328
 
        if line.startswith('=== ') or line.startswith('*** '):
329
 
            continue
330
 
        if line.startswith('#'):
331
 
            continue
332
 
        elif orig_range > 0:
333
 
            if line.startswith('-') or line.startswith(' '):
334
 
                orig_range -= 1
335
 
        elif line.startswith('--- '):
336
 
            if len(saved_lines) > 0:
337
 
                yield saved_lines
338
 
            saved_lines = []
339
 
        elif line.startswith('@@'):
340
 
            hunk = hunk_from_header(line)
341
 
            orig_range = hunk.orig_range
342
 
        saved_lines.append(line)
343
 
    if len(saved_lines) > 0:
344
 
        yield saved_lines
345
 
 
346
 
 
347
 
def iter_lines_handle_nl(iter_lines):
348
 
    """
349
 
    Iterates through lines, ensuring that lines that originally had no
350
 
    terminating \n are produced without one.  This transformation may be
351
 
    applied at any point up until hunk line parsing, and is safe to apply
352
 
    repeatedly.
353
 
    """
354
 
    last_line = None
355
 
    for line in iter_lines:
356
 
        if line == NO_NL:
357
 
            if not last_line.endswith('\n'):
358
 
                raise AssertionError()
359
 
            last_line = last_line[:-1]
360
 
            line = None
361
 
        if last_line is not None:
362
 
            yield last_line
363
 
        last_line = line
364
 
    if last_line is not None:
365
 
        yield last_line
366
 
 
367
 
 
368
 
def parse_patches(iter_lines):
369
 
    iter_lines = iter_lines_handle_nl(iter_lines)
370
 
    return [parse_patch(f.__iter__()) for f in iter_file_patch(iter_lines)]
371
 
 
372
 
 
373
 
def difference_index(atext, btext):
374
 
    """Find the indext of the first character that differs between two texts
375
 
 
376
 
    :param atext: The first text
377
 
    :type atext: str
378
 
    :param btext: The second text
379
 
    :type str: str
380
 
    :return: The index, or None if there are no differences within the range
381
 
    :rtype: int or NoneType
382
 
    """
383
 
    length = len(atext)
384
 
    if len(btext) < length:
385
 
        length = len(btext)
386
 
    for i in range(length):
387
 
        if atext[i] != btext[i]:
388
 
            return i;
389
 
    return None
390
 
 
391
 
 
392
 
def iter_patched(orig_lines, patch_lines):
393
 
    """Iterate through a series of lines with a patch applied.
394
 
    This handles a single file, and does exact, not fuzzy patching.
395
 
    """
396
 
    patch_lines = iter_lines_handle_nl(iter(patch_lines))
397
 
    get_patch_names(patch_lines)
398
 
    return iter_patched_from_hunks(orig_lines, iter_hunks(patch_lines))
399
 
 
400
 
 
401
 
def iter_patched_from_hunks(orig_lines, hunks):
402
 
    """Iterate through a series of lines with a patch applied.
403
 
    This handles a single file, and does exact, not fuzzy patching.
404
 
 
405
 
    :param orig_lines: The unpatched lines.
406
 
    :param hunks: An iterable of Hunk instances.
407
 
    """
408
 
    seen_patch = []
409
 
    line_no = 1
410
 
    if orig_lines is not None:
411
 
        orig_lines = iter(orig_lines)
412
 
    for hunk in hunks:
413
 
        while line_no < hunk.orig_pos:
414
 
            orig_line = orig_lines.next()
415
 
            yield orig_line
416
 
            line_no += 1
417
 
        for hunk_line in hunk.lines:
418
 
            seen_patch.append(str(hunk_line))
419
 
            if isinstance(hunk_line, InsertLine):
420
 
                yield hunk_line.contents
421
 
            elif isinstance(hunk_line, (ContextLine, RemoveLine)):
422
 
                orig_line = orig_lines.next()
423
 
                if orig_line != hunk_line.contents:
424
 
                    raise PatchConflict(line_no, orig_line, "".join(seen_patch))
425
 
                if isinstance(hunk_line, ContextLine):
426
 
                    yield orig_line
427
 
                else:
428
 
                    if not isinstance(hunk_line, RemoveLine):
429
 
                        raise AssertionError(hunk_line)
430
 
                line_no += 1
431
 
    if orig_lines is not None:
432
 
        for line in orig_lines:
433
 
            yield line