~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/patches.py

  • Committer: Martin Pool
  • Date: 2005-10-04 11:13:33 UTC
  • mto: (1185.13.3)
  • mto: This revision was merged to the branch mainline in revision 1403.
  • Revision ID: mbp@sourcefrog.net-20051004111332-f7b8a6bd41b9fe22
- tweak capture_tree formatting

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2004 - 2006 Aaron Bentley
2
 
# <aaron.bentley@utoronto.ca>
3
 
#
4
 
#    This program is free software; you can redistribute it and/or modify
5
 
#    it under the terms of the GNU General Public License as published by
6
 
#    the Free Software Foundation; either version 2 of the License, or
7
 
#    (at your option) any later version.
8
 
#
9
 
#    This program is distributed in the hope that it will be useful,
10
 
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
11
 
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12
 
#    GNU General Public License for more details.
13
 
#
14
 
#    You should have received a copy of the GNU General Public License
15
 
#    along with this program; if not, write to the Free Software
16
 
#    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
17
 
 
18
 
 
19
 
class PatchSyntax(Exception):
20
 
    def __init__(self, msg):
21
 
        Exception.__init__(self, msg)
22
 
 
23
 
 
24
 
class MalformedPatchHeader(PatchSyntax):
25
 
    def __init__(self, desc, line):
26
 
        self.desc = desc
27
 
        self.line = line
28
 
        msg = "Malformed patch header.  %s\n%r" % (self.desc, self.line)
29
 
        PatchSyntax.__init__(self, msg)
30
 
 
31
 
 
32
 
class MalformedHunkHeader(PatchSyntax):
33
 
    def __init__(self, desc, line):
34
 
        self.desc = desc
35
 
        self.line = line
36
 
        msg = "Malformed hunk header.  %s\n%r" % (self.desc, self.line)
37
 
        PatchSyntax.__init__(self, msg)
38
 
 
39
 
 
40
 
class MalformedLine(PatchSyntax):
41
 
    def __init__(self, desc, line):
42
 
        self.desc = desc
43
 
        self.line = line
44
 
        msg = "Malformed line.  %s\n%s" % (self.desc, self.line)
45
 
        PatchSyntax.__init__(self, msg)
46
 
 
47
 
 
48
 
class PatchConflict(Exception):
49
 
    def __init__(self, line_no, orig_line, patch_line):
50
 
        orig = orig_line.rstrip('\n')
51
 
        patch = str(patch_line).rstrip('\n')
52
 
        msg = 'Text contents mismatch at line %d.  Original has "%s",'\
53
 
            ' but patch says it should be "%s"' % (line_no, orig, patch)
54
 
        Exception.__init__(self, msg)
55
 
 
56
 
 
57
 
def get_patch_names(iter_lines):
58
 
    try:
59
 
        line = iter_lines.next()
60
 
        if not line.startswith("--- "):
61
 
            raise MalformedPatchHeader("No orig name", line)
62
 
        else:
63
 
            orig_name = line[4:].rstrip("\n")
64
 
    except StopIteration:
65
 
        raise MalformedPatchHeader("No orig line", "")
66
 
    try:
67
 
        line = iter_lines.next()
68
 
        if not line.startswith("+++ "):
69
 
            raise PatchSyntax("No mod name")
70
 
        else:
71
 
            mod_name = line[4:].rstrip("\n")
72
 
    except StopIteration:
73
 
        raise MalformedPatchHeader("No mod line", "")
74
 
    return (orig_name, mod_name)
75
 
 
76
 
 
77
 
def parse_range(textrange):
78
 
    """Parse a patch range, handling the "1" special-case
79
 
 
80
 
    :param textrange: The text to parse
81
 
    :type textrange: str
82
 
    :return: the position and range, as a tuple
83
 
    :rtype: (int, int)
84
 
    """
85
 
    tmp = textrange.split(',')
86
 
    if len(tmp) == 1:
87
 
        pos = tmp[0]
88
 
        range = "1"
89
 
    else:
90
 
        (pos, range) = tmp
91
 
    pos = int(pos)
92
 
    range = int(range)
93
 
    return (pos, range)
94
 
 
95
 
 
96
 
def hunk_from_header(line):
97
 
    if not line.startswith("@@") or not line.endswith("@@\n") \
98
 
        or not len(line) > 4:
99
 
        raise MalformedHunkHeader("Does not start and end with @@.", line)
100
 
    try:
101
 
        (orig, mod) = line[3:-4].split(" ")
102
 
    except Exception, e:
103
 
        raise MalformedHunkHeader(str(e), line)
104
 
    if not orig.startswith('-') or not mod.startswith('+'):
105
 
        raise MalformedHunkHeader("Positions don't start with + or -.", line)
106
 
    try:
107
 
        (orig_pos, orig_range) = parse_range(orig[1:])
108
 
        (mod_pos, mod_range) = parse_range(mod[1:])
109
 
    except Exception, e:
110
 
        raise MalformedHunkHeader(str(e), line)
111
 
    if mod_range < 0 or orig_range < 0:
112
 
        raise MalformedHunkHeader("Hunk range is negative", line)
113
 
    return Hunk(orig_pos, orig_range, mod_pos, mod_range)
114
 
 
115
 
 
116
 
class HunkLine:
117
 
    def __init__(self, contents):
118
 
        self.contents = contents
119
 
 
120
 
    def get_str(self, leadchar):
121
 
        if self.contents == "\n" and leadchar == " " and False:
122
 
            return "\n"
123
 
        if not self.contents.endswith('\n'):
124
 
            terminator = '\n' + NO_NL
125
 
        else:
126
 
            terminator = ''
127
 
        return leadchar + self.contents + terminator
128
 
 
129
 
 
130
 
class ContextLine(HunkLine):
131
 
    def __init__(self, contents):
132
 
        HunkLine.__init__(self, contents)
133
 
 
134
 
    def __str__(self):
135
 
        return self.get_str(" ")
136
 
 
137
 
 
138
 
class InsertLine(HunkLine):
139
 
    def __init__(self, contents):
140
 
        HunkLine.__init__(self, contents)
141
 
 
142
 
    def __str__(self):
143
 
        return self.get_str("+")
144
 
 
145
 
 
146
 
class RemoveLine(HunkLine):
147
 
    def __init__(self, contents):
148
 
        HunkLine.__init__(self, contents)
149
 
 
150
 
    def __str__(self):
151
 
        return self.get_str("-")
152
 
 
153
 
NO_NL = '\\ No newline at end of file\n'
154
 
__pychecker__="no-returnvalues"
155
 
 
156
 
def parse_line(line):
157
 
    if line.startswith("\n"):
158
 
        return ContextLine(line)
159
 
    elif line.startswith(" "):
160
 
        return ContextLine(line[1:])
161
 
    elif line.startswith("+"):
162
 
        return InsertLine(line[1:])
163
 
    elif line.startswith("-"):
164
 
        return RemoveLine(line[1:])
165
 
    elif line == NO_NL:
166
 
        return NO_NL
167
 
    else:
168
 
        raise MalformedLine("Unknown line type", line)
169
 
__pychecker__=""
170
 
 
171
 
 
172
 
class Hunk:
173
 
    def __init__(self, orig_pos, orig_range, mod_pos, mod_range):
174
 
        self.orig_pos = orig_pos
175
 
        self.orig_range = orig_range
176
 
        self.mod_pos = mod_pos
177
 
        self.mod_range = mod_range
178
 
        self.lines = []
179
 
 
180
 
    def get_header(self):
181
 
        return "@@ -%s +%s @@\n" % (self.range_str(self.orig_pos, 
182
 
                                                   self.orig_range),
183
 
                                    self.range_str(self.mod_pos, 
184
 
                                                   self.mod_range))
185
 
 
186
 
    def range_str(self, pos, range):
187
 
        """Return a file range, special-casing for 1-line files.
188
 
 
189
 
        :param pos: The position in the file
190
 
        :type pos: int
191
 
        :range: The range in the file
192
 
        :type range: int
193
 
        :return: a string in the format 1,4 except when range == pos == 1
194
 
        """
195
 
        if range == 1:
196
 
            return "%i" % pos
197
 
        else:
198
 
            return "%i,%i" % (pos, range)
199
 
 
200
 
    def __str__(self):
201
 
        lines = [self.get_header()]
202
 
        for line in self.lines:
203
 
            lines.append(str(line))
204
 
        return "".join(lines)
205
 
 
206
 
    def shift_to_mod(self, pos):
207
 
        if pos < self.orig_pos-1:
208
 
            return 0
209
 
        elif pos > self.orig_pos+self.orig_range:
210
 
            return self.mod_range - self.orig_range
211
 
        else:
212
 
            return self.shift_to_mod_lines(pos)
213
 
 
214
 
    def shift_to_mod_lines(self, pos):
215
 
        assert (pos >= self.orig_pos-1 and pos <= self.orig_pos+self.orig_range)
216
 
        position = self.orig_pos-1
217
 
        shift = 0
218
 
        for line in self.lines:
219
 
            if isinstance(line, InsertLine):
220
 
                shift += 1
221
 
            elif isinstance(line, RemoveLine):
222
 
                if position == pos:
223
 
                    return None
224
 
                shift -= 1
225
 
                position += 1
226
 
            elif isinstance(line, ContextLine):
227
 
                position += 1
228
 
            if position > pos:
229
 
                break
230
 
        return shift
231
 
 
232
 
 
233
 
def iter_hunks(iter_lines):
234
 
    hunk = None
235
 
    for line in iter_lines:
236
 
        if line == "\n":
237
 
            if hunk is not None:
238
 
                yield hunk
239
 
                hunk = None
240
 
            continue
241
 
        if hunk is not None:
242
 
            yield hunk
243
 
        hunk = hunk_from_header(line)
244
 
        orig_size = 0
245
 
        mod_size = 0
246
 
        while orig_size < hunk.orig_range or mod_size < hunk.mod_range:
247
 
            hunk_line = parse_line(iter_lines.next())
248
 
            hunk.lines.append(hunk_line)
249
 
            if isinstance(hunk_line, (RemoveLine, ContextLine)):
250
 
                orig_size += 1
251
 
            if isinstance(hunk_line, (InsertLine, ContextLine)):
252
 
                mod_size += 1
253
 
    if hunk is not None:
254
 
        yield hunk
255
 
 
256
 
 
257
 
class Patch:
258
 
    def __init__(self, oldname, newname):
259
 
        self.oldname = oldname
260
 
        self.newname = newname
261
 
        self.hunks = []
262
 
 
263
 
    def __str__(self):
264
 
        ret = self.get_header() 
265
 
        ret += "".join([str(h) for h in self.hunks])
266
 
        return ret
267
 
 
268
 
    def get_header(self):
269
 
        return "--- %s\n+++ %s\n" % (self.oldname, self.newname)
270
 
 
271
 
    def stats_str(self):
272
 
        """Return a string of patch statistics"""
273
 
        removes = 0
274
 
        inserts = 0
275
 
        for hunk in self.hunks:
276
 
            for line in hunk.lines:
277
 
                if isinstance(line, InsertLine):
278
 
                     inserts+=1;
279
 
                elif isinstance(line, RemoveLine):
280
 
                     removes+=1;
281
 
        return "%i inserts, %i removes in %i hunks" % \
282
 
            (inserts, removes, len(self.hunks))
283
 
 
284
 
    def pos_in_mod(self, position):
285
 
        newpos = position
286
 
        for hunk in self.hunks:
287
 
            shift = hunk.shift_to_mod(position)
288
 
            if shift is None:
289
 
                return None
290
 
            newpos += shift
291
 
        return newpos
292
 
            
293
 
    def iter_inserted(self):
294
 
        """Iteraties through inserted lines
295
 
        
296
 
        :return: Pair of line number, line
297
 
        :rtype: iterator of (int, InsertLine)
298
 
        """
299
 
        for hunk in self.hunks:
300
 
            pos = hunk.mod_pos - 1;
301
 
            for line in hunk.lines:
302
 
                if isinstance(line, InsertLine):
303
 
                    yield (pos, line)
304
 
                    pos += 1
305
 
                if isinstance(line, ContextLine):
306
 
                    pos += 1
307
 
 
308
 
 
309
 
def parse_patch(iter_lines):
310
 
    (orig_name, mod_name) = get_patch_names(iter_lines)
311
 
    patch = Patch(orig_name, mod_name)
312
 
    for hunk in iter_hunks(iter_lines):
313
 
        patch.hunks.append(hunk)
314
 
    return patch
315
 
 
316
 
 
317
 
def iter_file_patch(iter_lines):
318
 
    saved_lines = []
319
 
    for line in iter_lines:
320
 
        if line.startswith('=== ') or line.startswith('*** '):
321
 
            continue
322
 
        if line.startswith('#'):
323
 
            continue
324
 
        elif line.startswith('--- '):
325
 
            if len(saved_lines) > 0:
326
 
                yield saved_lines
327
 
            saved_lines = []
328
 
        saved_lines.append(line)
329
 
    if len(saved_lines) > 0:
330
 
        yield saved_lines
331
 
 
332
 
 
333
 
def iter_lines_handle_nl(iter_lines):
334
 
    """
335
 
    Iterates through lines, ensuring that lines that originally had no
336
 
    terminating \n are produced without one.  This transformation may be
337
 
    applied at any point up until hunk line parsing, and is safe to apply
338
 
    repeatedly.
339
 
    """
340
 
    last_line = None
341
 
    for line in iter_lines:
342
 
        if line == NO_NL:
343
 
            assert last_line.endswith('\n')
344
 
            last_line = last_line[:-1]
345
 
            line = None
346
 
        if last_line is not None:
347
 
            yield last_line
348
 
        last_line = line
349
 
    if last_line is not None:
350
 
        yield last_line
351
 
 
352
 
 
353
 
def parse_patches(iter_lines):
354
 
    iter_lines = iter_lines_handle_nl(iter_lines)
355
 
    return [parse_patch(f.__iter__()) for f in iter_file_patch(iter_lines)]
356
 
 
357
 
 
358
 
def difference_index(atext, btext):
359
 
    """Find the indext of the first character that differs between two texts
360
 
 
361
 
    :param atext: The first text
362
 
    :type atext: str
363
 
    :param btext: The second text
364
 
    :type str: str
365
 
    :return: The index, or None if there are no differences within the range
366
 
    :rtype: int or NoneType
367
 
    """
368
 
    length = len(atext)
369
 
    if len(btext) < length:
370
 
        length = len(btext)
371
 
    for i in range(length):
372
 
        if atext[i] != btext[i]:
373
 
            return i;
374
 
    return None
375
 
 
376
 
 
377
 
def iter_patched(orig_lines, patch_lines):
378
 
    """Iterate through a series of lines with a patch applied.
379
 
    This handles a single file, and does exact, not fuzzy patching.
380
 
    """
381
 
    if orig_lines is not None:
382
 
        orig_lines = orig_lines.__iter__()
383
 
    seen_patch = []
384
 
    patch_lines = iter_lines_handle_nl(patch_lines.__iter__())
385
 
    get_patch_names(patch_lines)
386
 
    line_no = 1
387
 
    for hunk in iter_hunks(patch_lines):
388
 
        while line_no < hunk.orig_pos:
389
 
            orig_line = orig_lines.next()
390
 
            yield orig_line
391
 
            line_no += 1
392
 
        for hunk_line in hunk.lines:
393
 
            seen_patch.append(str(hunk_line))
394
 
            if isinstance(hunk_line, InsertLine):
395
 
                yield hunk_line.contents
396
 
            elif isinstance(hunk_line, (ContextLine, RemoveLine)):
397
 
                orig_line = orig_lines.next()
398
 
                if orig_line != hunk_line.contents:
399
 
                    raise PatchConflict(line_no, orig_line, "".join(seen_patch))
400
 
                if isinstance(hunk_line, ContextLine):
401
 
                    yield orig_line
402
 
                else:
403
 
                    assert isinstance(hunk_line, RemoveLine)
404
 
                line_no += 1
405
 
    if orig_lines is not None:
406
 
        for line in orig_lines:
407
 
            yield line