~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/patches.py

  • Committer: John Arbash Meinel
  • Author(s): Mark Hammond
  • Date: 2008-09-09 17:02:21 UTC
  • mto: This revision was merged to the branch mainline in revision 3697.
  • Revision ID: john@arbash-meinel.com-20080909170221-svim3jw2mrz0amp3
An updated transparent icon for bzr.

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# Copyright (C) 2004 - 2006 Aaron Bentley, Canonical Ltd
 
2
# <aaron.bentley@utoronto.ca>
 
3
#
 
4
# This program is free software; you can redistribute it and/or modify
 
5
# it under the terms of the GNU General Public License as published by
 
6
# the Free Software Foundation; either version 2 of the License, or
 
7
# (at your option) any later version.
 
8
#
 
9
# This program is distributed in the hope that it will be useful,
 
10
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
11
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
12
# GNU General Public License for more details.
 
13
#
 
14
# You should have received a copy of the GNU General Public License
 
15
# along with this program; if not, write to the Free Software
 
16
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 
17
import re
 
18
 
 
19
 
 
20
class PatchSyntax(Exception):
 
21
    def __init__(self, msg):
 
22
        Exception.__init__(self, msg)
 
23
 
 
24
 
 
25
class MalformedPatchHeader(PatchSyntax):
 
26
    def __init__(self, desc, line):
 
27
        self.desc = desc
 
28
        self.line = line
 
29
        msg = "Malformed patch header.  %s\n%r" % (self.desc, self.line)
 
30
        PatchSyntax.__init__(self, msg)
 
31
 
 
32
 
 
33
class MalformedHunkHeader(PatchSyntax):
 
34
    def __init__(self, desc, line):
 
35
        self.desc = desc
 
36
        self.line = line
 
37
        msg = "Malformed hunk header.  %s\n%r" % (self.desc, self.line)
 
38
        PatchSyntax.__init__(self, msg)
 
39
 
 
40
 
 
41
class MalformedLine(PatchSyntax):
 
42
    def __init__(self, desc, line):
 
43
        self.desc = desc
 
44
        self.line = line
 
45
        msg = "Malformed line.  %s\n%s" % (self.desc, self.line)
 
46
        PatchSyntax.__init__(self, msg)
 
47
 
 
48
 
 
49
class PatchConflict(Exception):
 
50
    def __init__(self, line_no, orig_line, patch_line):
 
51
        orig = orig_line.rstrip('\n')
 
52
        patch = str(patch_line).rstrip('\n')
 
53
        msg = 'Text contents mismatch at line %d.  Original has "%s",'\
 
54
            ' but patch says it should be "%s"' % (line_no, orig, patch)
 
55
        Exception.__init__(self, msg)
 
56
 
 
57
 
 
58
def get_patch_names(iter_lines):
 
59
    try:
 
60
        line = iter_lines.next()
 
61
        if not line.startswith("--- "):
 
62
            raise MalformedPatchHeader("No orig name", line)
 
63
        else:
 
64
            orig_name = line[4:].rstrip("\n")
 
65
    except StopIteration:
 
66
        raise MalformedPatchHeader("No orig line", "")
 
67
    try:
 
68
        line = iter_lines.next()
 
69
        if not line.startswith("+++ "):
 
70
            raise PatchSyntax("No mod name")
 
71
        else:
 
72
            mod_name = line[4:].rstrip("\n")
 
73
    except StopIteration:
 
74
        raise MalformedPatchHeader("No mod line", "")
 
75
    return (orig_name, mod_name)
 
76
 
 
77
 
 
78
def parse_range(textrange):
 
79
    """Parse a patch range, handling the "1" special-case
 
80
 
 
81
    :param textrange: The text to parse
 
82
    :type textrange: str
 
83
    :return: the position and range, as a tuple
 
84
    :rtype: (int, int)
 
85
    """
 
86
    tmp = textrange.split(',')
 
87
    if len(tmp) == 1:
 
88
        pos = tmp[0]
 
89
        range = "1"
 
90
    else:
 
91
        (pos, range) = tmp
 
92
    pos = int(pos)
 
93
    range = int(range)
 
94
    return (pos, range)
 
95
 
 
96
 
 
97
def hunk_from_header(line):
 
98
    matches = re.match(r'\@\@ ([^@]*) \@\@( (.*))?\n', line)
 
99
    if matches is None:
 
100
        raise MalformedHunkHeader("Does not match format.", line)
 
101
    try:
 
102
        (orig, mod) = matches.group(1).split(" ")
 
103
    except (ValueError, IndexError), e:
 
104
        raise MalformedHunkHeader(str(e), line)
 
105
    if not orig.startswith('-') or not mod.startswith('+'):
 
106
        raise MalformedHunkHeader("Positions don't start with + or -.", line)
 
107
    try:
 
108
        (orig_pos, orig_range) = parse_range(orig[1:])
 
109
        (mod_pos, mod_range) = parse_range(mod[1:])
 
110
    except (ValueError, IndexError), e:
 
111
        raise MalformedHunkHeader(str(e), line)
 
112
    if mod_range < 0 or orig_range < 0:
 
113
        raise MalformedHunkHeader("Hunk range is negative", line)
 
114
    tail = matches.group(3)
 
115
    return Hunk(orig_pos, orig_range, mod_pos, mod_range, tail)
 
116
 
 
117
 
 
118
class HunkLine:
 
119
    def __init__(self, contents):
 
120
        self.contents = contents
 
121
 
 
122
    def get_str(self, leadchar):
 
123
        if self.contents == "\n" and leadchar == " " and False:
 
124
            return "\n"
 
125
        if not self.contents.endswith('\n'):
 
126
            terminator = '\n' + NO_NL
 
127
        else:
 
128
            terminator = ''
 
129
        return leadchar + self.contents + terminator
 
130
 
 
131
 
 
132
class ContextLine(HunkLine):
 
133
    def __init__(self, contents):
 
134
        HunkLine.__init__(self, contents)
 
135
 
 
136
    def __str__(self):
 
137
        return self.get_str(" ")
 
138
 
 
139
 
 
140
class InsertLine(HunkLine):
 
141
    def __init__(self, contents):
 
142
        HunkLine.__init__(self, contents)
 
143
 
 
144
    def __str__(self):
 
145
        return self.get_str("+")
 
146
 
 
147
 
 
148
class RemoveLine(HunkLine):
 
149
    def __init__(self, contents):
 
150
        HunkLine.__init__(self, contents)
 
151
 
 
152
    def __str__(self):
 
153
        return self.get_str("-")
 
154
 
 
155
NO_NL = '\\ No newline at end of file\n'
 
156
__pychecker__="no-returnvalues"
 
157
 
 
158
def parse_line(line):
 
159
    if line.startswith("\n"):
 
160
        return ContextLine(line)
 
161
    elif line.startswith(" "):
 
162
        return ContextLine(line[1:])
 
163
    elif line.startswith("+"):
 
164
        return InsertLine(line[1:])
 
165
    elif line.startswith("-"):
 
166
        return RemoveLine(line[1:])
 
167
    elif line == NO_NL:
 
168
        return NO_NL
 
169
    else:
 
170
        raise MalformedLine("Unknown line type", line)
 
171
__pychecker__=""
 
172
 
 
173
 
 
174
class Hunk:
 
175
    def __init__(self, orig_pos, orig_range, mod_pos, mod_range, tail=None):
 
176
        self.orig_pos = orig_pos
 
177
        self.orig_range = orig_range
 
178
        self.mod_pos = mod_pos
 
179
        self.mod_range = mod_range
 
180
        self.tail = tail
 
181
        self.lines = []
 
182
 
 
183
    def get_header(self):
 
184
        if self.tail is None:
 
185
            tail_str = ''
 
186
        else:
 
187
            tail_str = ' ' + self.tail
 
188
        return "@@ -%s +%s @@%s\n" % (self.range_str(self.orig_pos,
 
189
                                                     self.orig_range),
 
190
                                      self.range_str(self.mod_pos,
 
191
                                                     self.mod_range),
 
192
                                      tail_str)
 
193
 
 
194
    def range_str(self, pos, range):
 
195
        """Return a file range, special-casing for 1-line files.
 
196
 
 
197
        :param pos: The position in the file
 
198
        :type pos: int
 
199
        :range: The range in the file
 
200
        :type range: int
 
201
        :return: a string in the format 1,4 except when range == pos == 1
 
202
        """
 
203
        if range == 1:
 
204
            return "%i" % pos
 
205
        else:
 
206
            return "%i,%i" % (pos, range)
 
207
 
 
208
    def __str__(self):
 
209
        lines = [self.get_header()]
 
210
        for line in self.lines:
 
211
            lines.append(str(line))
 
212
        return "".join(lines)
 
213
 
 
214
    def shift_to_mod(self, pos):
 
215
        if pos < self.orig_pos-1:
 
216
            return 0
 
217
        elif pos > self.orig_pos+self.orig_range:
 
218
            return self.mod_range - self.orig_range
 
219
        else:
 
220
            return self.shift_to_mod_lines(pos)
 
221
 
 
222
    def shift_to_mod_lines(self, pos):
 
223
        position = self.orig_pos-1
 
224
        shift = 0
 
225
        for line in self.lines:
 
226
            if isinstance(line, InsertLine):
 
227
                shift += 1
 
228
            elif isinstance(line, RemoveLine):
 
229
                if position == pos:
 
230
                    return None
 
231
                shift -= 1
 
232
                position += 1
 
233
            elif isinstance(line, ContextLine):
 
234
                position += 1
 
235
            if position > pos:
 
236
                break
 
237
        return shift
 
238
 
 
239
 
 
240
def iter_hunks(iter_lines):
 
241
    hunk = None
 
242
    for line in iter_lines:
 
243
        if line == "\n":
 
244
            if hunk is not None:
 
245
                yield hunk
 
246
                hunk = None
 
247
            continue
 
248
        if hunk is not None:
 
249
            yield hunk
 
250
        hunk = hunk_from_header(line)
 
251
        orig_size = 0
 
252
        mod_size = 0
 
253
        while orig_size < hunk.orig_range or mod_size < hunk.mod_range:
 
254
            hunk_line = parse_line(iter_lines.next())
 
255
            hunk.lines.append(hunk_line)
 
256
            if isinstance(hunk_line, (RemoveLine, ContextLine)):
 
257
                orig_size += 1
 
258
            if isinstance(hunk_line, (InsertLine, ContextLine)):
 
259
                mod_size += 1
 
260
    if hunk is not None:
 
261
        yield hunk
 
262
 
 
263
 
 
264
class Patch:
 
265
    def __init__(self, oldname, newname):
 
266
        self.oldname = oldname
 
267
        self.newname = newname
 
268
        self.hunks = []
 
269
 
 
270
    def __str__(self):
 
271
        ret = self.get_header() 
 
272
        ret += "".join([str(h) for h in self.hunks])
 
273
        return ret
 
274
 
 
275
    def get_header(self):
 
276
        return "--- %s\n+++ %s\n" % (self.oldname, self.newname)
 
277
 
 
278
    def stats_str(self):
 
279
        """Return a string of patch statistics"""
 
280
        removes = 0
 
281
        inserts = 0
 
282
        for hunk in self.hunks:
 
283
            for line in hunk.lines:
 
284
                if isinstance(line, InsertLine):
 
285
                     inserts+=1;
 
286
                elif isinstance(line, RemoveLine):
 
287
                     removes+=1;
 
288
        return "%i inserts, %i removes in %i hunks" % \
 
289
            (inserts, removes, len(self.hunks))
 
290
 
 
291
    def pos_in_mod(self, position):
 
292
        newpos = position
 
293
        for hunk in self.hunks:
 
294
            shift = hunk.shift_to_mod(position)
 
295
            if shift is None:
 
296
                return None
 
297
            newpos += shift
 
298
        return newpos
 
299
            
 
300
    def iter_inserted(self):
 
301
        """Iteraties through inserted lines
 
302
        
 
303
        :return: Pair of line number, line
 
304
        :rtype: iterator of (int, InsertLine)
 
305
        """
 
306
        for hunk in self.hunks:
 
307
            pos = hunk.mod_pos - 1;
 
308
            for line in hunk.lines:
 
309
                if isinstance(line, InsertLine):
 
310
                    yield (pos, line)
 
311
                    pos += 1
 
312
                if isinstance(line, ContextLine):
 
313
                    pos += 1
 
314
 
 
315
 
 
316
def parse_patch(iter_lines):
 
317
    (orig_name, mod_name) = get_patch_names(iter_lines)
 
318
    patch = Patch(orig_name, mod_name)
 
319
    for hunk in iter_hunks(iter_lines):
 
320
        patch.hunks.append(hunk)
 
321
    return patch
 
322
 
 
323
 
 
324
def iter_file_patch(iter_lines):
 
325
    saved_lines = []
 
326
    orig_range = 0
 
327
    for line in iter_lines:
 
328
        if line.startswith('=== ') or line.startswith('*** '):
 
329
            continue
 
330
        if line.startswith('#'):
 
331
            continue
 
332
        elif orig_range > 0:
 
333
            if line.startswith('-') or line.startswith(' '):
 
334
                orig_range -= 1
 
335
        elif line.startswith('--- '):
 
336
            if len(saved_lines) > 0:
 
337
                yield saved_lines
 
338
            saved_lines = []
 
339
        elif line.startswith('@@'):
 
340
            hunk = hunk_from_header(line)
 
341
            orig_range = hunk.orig_range
 
342
        saved_lines.append(line)
 
343
    if len(saved_lines) > 0:
 
344
        yield saved_lines
 
345
 
 
346
 
 
347
def iter_lines_handle_nl(iter_lines):
 
348
    """
 
349
    Iterates through lines, ensuring that lines that originally had no
 
350
    terminating \n are produced without one.  This transformation may be
 
351
    applied at any point up until hunk line parsing, and is safe to apply
 
352
    repeatedly.
 
353
    """
 
354
    last_line = None
 
355
    for line in iter_lines:
 
356
        if line == NO_NL:
 
357
            if not last_line.endswith('\n'):
 
358
                raise AssertionError()
 
359
            last_line = last_line[:-1]
 
360
            line = None
 
361
        if last_line is not None:
 
362
            yield last_line
 
363
        last_line = line
 
364
    if last_line is not None:
 
365
        yield last_line
 
366
 
 
367
 
 
368
def parse_patches(iter_lines):
 
369
    iter_lines = iter_lines_handle_nl(iter_lines)
 
370
    return [parse_patch(f.__iter__()) for f in iter_file_patch(iter_lines)]
 
371
 
 
372
 
 
373
def difference_index(atext, btext):
 
374
    """Find the indext of the first character that differs between two texts
 
375
 
 
376
    :param atext: The first text
 
377
    :type atext: str
 
378
    :param btext: The second text
 
379
    :type str: str
 
380
    :return: The index, or None if there are no differences within the range
 
381
    :rtype: int or NoneType
 
382
    """
 
383
    length = len(atext)
 
384
    if len(btext) < length:
 
385
        length = len(btext)
 
386
    for i in range(length):
 
387
        if atext[i] != btext[i]:
 
388
            return i;
 
389
    return None
 
390
 
 
391
 
 
392
def iter_patched(orig_lines, patch_lines):
 
393
    """Iterate through a series of lines with a patch applied.
 
394
    This handles a single file, and does exact, not fuzzy patching.
 
395
    """
 
396
    if orig_lines is not None:
 
397
        orig_lines = orig_lines.__iter__()
 
398
    seen_patch = []
 
399
    patch_lines = iter_lines_handle_nl(patch_lines.__iter__())
 
400
    get_patch_names(patch_lines)
 
401
    line_no = 1
 
402
    for hunk in iter_hunks(patch_lines):
 
403
        while line_no < hunk.orig_pos:
 
404
            orig_line = orig_lines.next()
 
405
            yield orig_line
 
406
            line_no += 1
 
407
        for hunk_line in hunk.lines:
 
408
            seen_patch.append(str(hunk_line))
 
409
            if isinstance(hunk_line, InsertLine):
 
410
                yield hunk_line.contents
 
411
            elif isinstance(hunk_line, (ContextLine, RemoveLine)):
 
412
                orig_line = orig_lines.next()
 
413
                if orig_line != hunk_line.contents:
 
414
                    raise PatchConflict(line_no, orig_line, "".join(seen_patch))
 
415
                if isinstance(hunk_line, ContextLine):
 
416
                    yield orig_line
 
417
                else:
 
418
                    if not isinstance(hunk_line, RemoveLine):
 
419
                        raise AssertionError(hunk_line)
 
420
                line_no += 1
 
421
    if orig_lines is not None:
 
422
        for line in orig_lines:
 
423
            yield line