~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/patches.py

  • Committer: Martin Pool
  • Date: 2005-09-13 23:08:19 UTC
  • Revision ID: mbp@sourcefrog.net-20050913230819-6ceae96050d32faa
ignore .bzr-shelf

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2004 - 2006 Aaron Bentley, Canonical Ltd
2
 
# <aaron.bentley@utoronto.ca>
3
 
#
4
 
# This program is free software; you can redistribute it and/or modify
5
 
# it under the terms of the GNU General Public License as published by
6
 
# the Free Software Foundation; either version 2 of the License, or
7
 
# (at your option) any later version.
8
 
#
9
 
# This program is distributed in the hope that it will be useful,
10
 
# but WITHOUT ANY WARRANTY; without even the implied warranty of
11
 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12
 
# GNU General Public License for more details.
13
 
#
14
 
# You should have received a copy of the GNU General Public License
15
 
# along with this program; if not, write to the Free Software
16
 
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
17
 
import re
18
 
 
19
 
 
20
 
class PatchSyntax(Exception):
21
 
    def __init__(self, msg):
22
 
        Exception.__init__(self, msg)
23
 
 
24
 
 
25
 
class MalformedPatchHeader(PatchSyntax):
26
 
    def __init__(self, desc, line):
27
 
        self.desc = desc
28
 
        self.line = line
29
 
        msg = "Malformed patch header.  %s\n%r" % (self.desc, self.line)
30
 
        PatchSyntax.__init__(self, msg)
31
 
 
32
 
 
33
 
class MalformedHunkHeader(PatchSyntax):
34
 
    def __init__(self, desc, line):
35
 
        self.desc = desc
36
 
        self.line = line
37
 
        msg = "Malformed hunk header.  %s\n%r" % (self.desc, self.line)
38
 
        PatchSyntax.__init__(self, msg)
39
 
 
40
 
 
41
 
class MalformedLine(PatchSyntax):
42
 
    def __init__(self, desc, line):
43
 
        self.desc = desc
44
 
        self.line = line
45
 
        msg = "Malformed line.  %s\n%s" % (self.desc, self.line)
46
 
        PatchSyntax.__init__(self, msg)
47
 
 
48
 
 
49
 
class PatchConflict(Exception):
50
 
    def __init__(self, line_no, orig_line, patch_line):
51
 
        orig = orig_line.rstrip('\n')
52
 
        patch = str(patch_line).rstrip('\n')
53
 
        msg = 'Text contents mismatch at line %d.  Original has "%s",'\
54
 
            ' but patch says it should be "%s"' % (line_no, orig, patch)
55
 
        Exception.__init__(self, msg)
56
 
 
57
 
 
58
 
def get_patch_names(iter_lines):
59
 
    try:
60
 
        line = iter_lines.next()
61
 
        if not line.startswith("--- "):
62
 
            raise MalformedPatchHeader("No orig name", line)
63
 
        else:
64
 
            orig_name = line[4:].rstrip("\n")
65
 
    except StopIteration:
66
 
        raise MalformedPatchHeader("No orig line", "")
67
 
    try:
68
 
        line = iter_lines.next()
69
 
        if not line.startswith("+++ "):
70
 
            raise PatchSyntax("No mod name")
71
 
        else:
72
 
            mod_name = line[4:].rstrip("\n")
73
 
    except StopIteration:
74
 
        raise MalformedPatchHeader("No mod line", "")
75
 
    return (orig_name, mod_name)
76
 
 
77
 
 
78
 
def parse_range(textrange):
79
 
    """Parse a patch range, handling the "1" special-case
80
 
 
81
 
    :param textrange: The text to parse
82
 
    :type textrange: str
83
 
    :return: the position and range, as a tuple
84
 
    :rtype: (int, int)
85
 
    """
86
 
    tmp = textrange.split(',')
87
 
    if len(tmp) == 1:
88
 
        pos = tmp[0]
89
 
        range = "1"
90
 
    else:
91
 
        (pos, range) = tmp
92
 
    pos = int(pos)
93
 
    range = int(range)
94
 
    return (pos, range)
95
 
 
96
 
 
97
 
def hunk_from_header(line):
98
 
    matches = re.match(r'\@\@ ([^@]*) \@\@( (.*))?\n', line)
99
 
    if matches is None:
100
 
        raise MalformedHunkHeader("Does not match format.", line)
101
 
    try:
102
 
        (orig, mod) = matches.group(1).split(" ")
103
 
    except (ValueError, IndexError), e:
104
 
        raise MalformedHunkHeader(str(e), line)
105
 
    if not orig.startswith('-') or not mod.startswith('+'):
106
 
        raise MalformedHunkHeader("Positions don't start with + or -.", line)
107
 
    try:
108
 
        (orig_pos, orig_range) = parse_range(orig[1:])
109
 
        (mod_pos, mod_range) = parse_range(mod[1:])
110
 
    except (ValueError, IndexError), e:
111
 
        raise MalformedHunkHeader(str(e), line)
112
 
    if mod_range < 0 or orig_range < 0:
113
 
        raise MalformedHunkHeader("Hunk range is negative", line)
114
 
    tail = matches.group(3)
115
 
    return Hunk(orig_pos, orig_range, mod_pos, mod_range, tail)
116
 
 
117
 
 
118
 
class HunkLine:
119
 
    def __init__(self, contents):
120
 
        self.contents = contents
121
 
 
122
 
    def get_str(self, leadchar):
123
 
        if self.contents == "\n" and leadchar == " " and False:
124
 
            return "\n"
125
 
        if not self.contents.endswith('\n'):
126
 
            terminator = '\n' + NO_NL
127
 
        else:
128
 
            terminator = ''
129
 
        return leadchar + self.contents + terminator
130
 
 
131
 
 
132
 
class ContextLine(HunkLine):
133
 
    def __init__(self, contents):
134
 
        HunkLine.__init__(self, contents)
135
 
 
136
 
    def __str__(self):
137
 
        return self.get_str(" ")
138
 
 
139
 
 
140
 
class InsertLine(HunkLine):
141
 
    def __init__(self, contents):
142
 
        HunkLine.__init__(self, contents)
143
 
 
144
 
    def __str__(self):
145
 
        return self.get_str("+")
146
 
 
147
 
 
148
 
class RemoveLine(HunkLine):
149
 
    def __init__(self, contents):
150
 
        HunkLine.__init__(self, contents)
151
 
 
152
 
    def __str__(self):
153
 
        return self.get_str("-")
154
 
 
155
 
NO_NL = '\\ No newline at end of file\n'
156
 
__pychecker__="no-returnvalues"
157
 
 
158
 
def parse_line(line):
159
 
    if line.startswith("\n"):
160
 
        return ContextLine(line)
161
 
    elif line.startswith(" "):
162
 
        return ContextLine(line[1:])
163
 
    elif line.startswith("+"):
164
 
        return InsertLine(line[1:])
165
 
    elif line.startswith("-"):
166
 
        return RemoveLine(line[1:])
167
 
    elif line == NO_NL:
168
 
        return NO_NL
169
 
    else:
170
 
        raise MalformedLine("Unknown line type", line)
171
 
__pychecker__=""
172
 
 
173
 
 
174
 
class Hunk:
175
 
    def __init__(self, orig_pos, orig_range, mod_pos, mod_range, tail=None):
176
 
        self.orig_pos = orig_pos
177
 
        self.orig_range = orig_range
178
 
        self.mod_pos = mod_pos
179
 
        self.mod_range = mod_range
180
 
        self.tail = tail
181
 
        self.lines = []
182
 
 
183
 
    def get_header(self):
184
 
        if self.tail is None:
185
 
            tail_str = ''
186
 
        else:
187
 
            tail_str = ' ' + self.tail
188
 
        return "@@ -%s +%s @@%s\n" % (self.range_str(self.orig_pos,
189
 
                                                     self.orig_range),
190
 
                                      self.range_str(self.mod_pos,
191
 
                                                     self.mod_range),
192
 
                                      tail_str)
193
 
 
194
 
    def range_str(self, pos, range):
195
 
        """Return a file range, special-casing for 1-line files.
196
 
 
197
 
        :param pos: The position in the file
198
 
        :type pos: int
199
 
        :range: The range in the file
200
 
        :type range: int
201
 
        :return: a string in the format 1,4 except when range == pos == 1
202
 
        """
203
 
        if range == 1:
204
 
            return "%i" % pos
205
 
        else:
206
 
            return "%i,%i" % (pos, range)
207
 
 
208
 
    def __str__(self):
209
 
        lines = [self.get_header()]
210
 
        for line in self.lines:
211
 
            lines.append(str(line))
212
 
        return "".join(lines)
213
 
 
214
 
    def shift_to_mod(self, pos):
215
 
        if pos < self.orig_pos-1:
216
 
            return 0
217
 
        elif pos > self.orig_pos+self.orig_range:
218
 
            return self.mod_range - self.orig_range
219
 
        else:
220
 
            return self.shift_to_mod_lines(pos)
221
 
 
222
 
    def shift_to_mod_lines(self, pos):
223
 
        assert (pos >= self.orig_pos-1 and pos <= self.orig_pos+self.orig_range)
224
 
        position = self.orig_pos-1
225
 
        shift = 0
226
 
        for line in self.lines:
227
 
            if isinstance(line, InsertLine):
228
 
                shift += 1
229
 
            elif isinstance(line, RemoveLine):
230
 
                if position == pos:
231
 
                    return None
232
 
                shift -= 1
233
 
                position += 1
234
 
            elif isinstance(line, ContextLine):
235
 
                position += 1
236
 
            if position > pos:
237
 
                break
238
 
        return shift
239
 
 
240
 
 
241
 
def iter_hunks(iter_lines):
242
 
    hunk = None
243
 
    for line in iter_lines:
244
 
        if line == "\n":
245
 
            if hunk is not None:
246
 
                yield hunk
247
 
                hunk = None
248
 
            continue
249
 
        if hunk is not None:
250
 
            yield hunk
251
 
        hunk = hunk_from_header(line)
252
 
        orig_size = 0
253
 
        mod_size = 0
254
 
        while orig_size < hunk.orig_range or mod_size < hunk.mod_range:
255
 
            hunk_line = parse_line(iter_lines.next())
256
 
            hunk.lines.append(hunk_line)
257
 
            if isinstance(hunk_line, (RemoveLine, ContextLine)):
258
 
                orig_size += 1
259
 
            if isinstance(hunk_line, (InsertLine, ContextLine)):
260
 
                mod_size += 1
261
 
    if hunk is not None:
262
 
        yield hunk
263
 
 
264
 
 
265
 
class Patch:
266
 
    def __init__(self, oldname, newname):
267
 
        self.oldname = oldname
268
 
        self.newname = newname
269
 
        self.hunks = []
270
 
 
271
 
    def __str__(self):
272
 
        ret = self.get_header() 
273
 
        ret += "".join([str(h) for h in self.hunks])
274
 
        return ret
275
 
 
276
 
    def get_header(self):
277
 
        return "--- %s\n+++ %s\n" % (self.oldname, self.newname)
278
 
 
279
 
    def stats_str(self):
280
 
        """Return a string of patch statistics"""
281
 
        removes = 0
282
 
        inserts = 0
283
 
        for hunk in self.hunks:
284
 
            for line in hunk.lines:
285
 
                if isinstance(line, InsertLine):
286
 
                     inserts+=1;
287
 
                elif isinstance(line, RemoveLine):
288
 
                     removes+=1;
289
 
        return "%i inserts, %i removes in %i hunks" % \
290
 
            (inserts, removes, len(self.hunks))
291
 
 
292
 
    def pos_in_mod(self, position):
293
 
        newpos = position
294
 
        for hunk in self.hunks:
295
 
            shift = hunk.shift_to_mod(position)
296
 
            if shift is None:
297
 
                return None
298
 
            newpos += shift
299
 
        return newpos
300
 
            
301
 
    def iter_inserted(self):
302
 
        """Iteraties through inserted lines
303
 
        
304
 
        :return: Pair of line number, line
305
 
        :rtype: iterator of (int, InsertLine)
306
 
        """
307
 
        for hunk in self.hunks:
308
 
            pos = hunk.mod_pos - 1;
309
 
            for line in hunk.lines:
310
 
                if isinstance(line, InsertLine):
311
 
                    yield (pos, line)
312
 
                    pos += 1
313
 
                if isinstance(line, ContextLine):
314
 
                    pos += 1
315
 
 
316
 
 
317
 
def parse_patch(iter_lines):
318
 
    (orig_name, mod_name) = get_patch_names(iter_lines)
319
 
    patch = Patch(orig_name, mod_name)
320
 
    for hunk in iter_hunks(iter_lines):
321
 
        patch.hunks.append(hunk)
322
 
    return patch
323
 
 
324
 
 
325
 
def iter_file_patch(iter_lines):
326
 
    saved_lines = []
327
 
    orig_range = 0
328
 
    for line in iter_lines:
329
 
        if line.startswith('=== ') or line.startswith('*** '):
330
 
            continue
331
 
        if line.startswith('#'):
332
 
            continue
333
 
        elif orig_range > 0:
334
 
            if line.startswith('-') or line.startswith(' '):
335
 
                orig_range -= 1
336
 
        elif line.startswith('--- '):
337
 
            if len(saved_lines) > 0:
338
 
                yield saved_lines
339
 
            saved_lines = []
340
 
        elif line.startswith('@@'):
341
 
            hunk = hunk_from_header(line)
342
 
            orig_range = hunk.orig_range
343
 
        saved_lines.append(line)
344
 
    if len(saved_lines) > 0:
345
 
        yield saved_lines
346
 
 
347
 
 
348
 
def iter_lines_handle_nl(iter_lines):
349
 
    """
350
 
    Iterates through lines, ensuring that lines that originally had no
351
 
    terminating \n are produced without one.  This transformation may be
352
 
    applied at any point up until hunk line parsing, and is safe to apply
353
 
    repeatedly.
354
 
    """
355
 
    last_line = None
356
 
    for line in iter_lines:
357
 
        if line == NO_NL:
358
 
            assert last_line.endswith('\n')
359
 
            last_line = last_line[:-1]
360
 
            line = None
361
 
        if last_line is not None:
362
 
            yield last_line
363
 
        last_line = line
364
 
    if last_line is not None:
365
 
        yield last_line
366
 
 
367
 
 
368
 
def parse_patches(iter_lines):
369
 
    iter_lines = iter_lines_handle_nl(iter_lines)
370
 
    return [parse_patch(f.__iter__()) for f in iter_file_patch(iter_lines)]
371
 
 
372
 
 
373
 
def difference_index(atext, btext):
374
 
    """Find the indext of the first character that differs between two texts
375
 
 
376
 
    :param atext: The first text
377
 
    :type atext: str
378
 
    :param btext: The second text
379
 
    :type str: str
380
 
    :return: The index, or None if there are no differences within the range
381
 
    :rtype: int or NoneType
382
 
    """
383
 
    length = len(atext)
384
 
    if len(btext) < length:
385
 
        length = len(btext)
386
 
    for i in range(length):
387
 
        if atext[i] != btext[i]:
388
 
            return i;
389
 
    return None
390
 
 
391
 
 
392
 
def iter_patched(orig_lines, patch_lines):
393
 
    """Iterate through a series of lines with a patch applied.
394
 
    This handles a single file, and does exact, not fuzzy patching.
395
 
    """
396
 
    if orig_lines is not None:
397
 
        orig_lines = orig_lines.__iter__()
398
 
    seen_patch = []
399
 
    patch_lines = iter_lines_handle_nl(patch_lines.__iter__())
400
 
    get_patch_names(patch_lines)
401
 
    line_no = 1
402
 
    for hunk in iter_hunks(patch_lines):
403
 
        while line_no < hunk.orig_pos:
404
 
            orig_line = orig_lines.next()
405
 
            yield orig_line
406
 
            line_no += 1
407
 
        for hunk_line in hunk.lines:
408
 
            seen_patch.append(str(hunk_line))
409
 
            if isinstance(hunk_line, InsertLine):
410
 
                yield hunk_line.contents
411
 
            elif isinstance(hunk_line, (ContextLine, RemoveLine)):
412
 
                orig_line = orig_lines.next()
413
 
                if orig_line != hunk_line.contents:
414
 
                    raise PatchConflict(line_no, orig_line, "".join(seen_patch))
415
 
                if isinstance(hunk_line, ContextLine):
416
 
                    yield orig_line
417
 
                else:
418
 
                    assert isinstance(hunk_line, RemoveLine)
419
 
                line_no += 1
420
 
    if orig_lines is not None:
421
 
        for line in orig_lines:
422
 
            yield line