~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/patches.py

  • Committer: Martin Pool
  • Date: 2005-06-06 11:53:29 UTC
  • Revision ID: mbp@sourcefrog.net-20050606115329-1596352add25bffd
- merge aaron's updated merge/pull code

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2004 - 2006 Aaron Bentley, Canonical Ltd
2
 
# <aaron.bentley@utoronto.ca>
3
 
#
4
 
# This program is free software; you can redistribute it and/or modify
5
 
# it under the terms of the GNU General Public License as published by
6
 
# the Free Software Foundation; either version 2 of the License, or
7
 
# (at your option) any later version.
8
 
#
9
 
# This program is distributed in the hope that it will be useful,
10
 
# but WITHOUT ANY WARRANTY; without even the implied warranty of
11
 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12
 
# GNU General Public License for more details.
13
 
#
14
 
# You should have received a copy of the GNU General Public License
15
 
# along with this program; if not, write to the Free Software
16
 
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
17
 
 
18
 
 
19
 
class PatchSyntax(Exception):
20
 
    def __init__(self, msg):
21
 
        Exception.__init__(self, msg)
22
 
 
23
 
 
24
 
class MalformedPatchHeader(PatchSyntax):
25
 
    def __init__(self, desc, line):
26
 
        self.desc = desc
27
 
        self.line = line
28
 
        msg = "Malformed patch header.  %s\n%r" % (self.desc, self.line)
29
 
        PatchSyntax.__init__(self, msg)
30
 
 
31
 
 
32
 
class MalformedHunkHeader(PatchSyntax):
33
 
    def __init__(self, desc, line):
34
 
        self.desc = desc
35
 
        self.line = line
36
 
        msg = "Malformed hunk header.  %s\n%r" % (self.desc, self.line)
37
 
        PatchSyntax.__init__(self, msg)
38
 
 
39
 
 
40
 
class MalformedLine(PatchSyntax):
41
 
    def __init__(self, desc, line):
42
 
        self.desc = desc
43
 
        self.line = line
44
 
        msg = "Malformed line.  %s\n%s" % (self.desc, self.line)
45
 
        PatchSyntax.__init__(self, msg)
46
 
 
47
 
 
48
 
class PatchConflict(Exception):
49
 
    def __init__(self, line_no, orig_line, patch_line):
50
 
        orig = orig_line.rstrip('\n')
51
 
        patch = str(patch_line).rstrip('\n')
52
 
        msg = 'Text contents mismatch at line %d.  Original has "%s",'\
53
 
            ' but patch says it should be "%s"' % (line_no, orig, patch)
54
 
        Exception.__init__(self, msg)
55
 
 
56
 
 
57
 
def get_patch_names(iter_lines):
58
 
    try:
59
 
        line = iter_lines.next()
60
 
        if not line.startswith("--- "):
61
 
            raise MalformedPatchHeader("No orig name", line)
62
 
        else:
63
 
            orig_name = line[4:].rstrip("\n")
64
 
    except StopIteration:
65
 
        raise MalformedPatchHeader("No orig line", "")
66
 
    try:
67
 
        line = iter_lines.next()
68
 
        if not line.startswith("+++ "):
69
 
            raise PatchSyntax("No mod name")
70
 
        else:
71
 
            mod_name = line[4:].rstrip("\n")
72
 
    except StopIteration:
73
 
        raise MalformedPatchHeader("No mod line", "")
74
 
    return (orig_name, mod_name)
75
 
 
76
 
 
77
 
def parse_range(textrange):
78
 
    """Parse a patch range, handling the "1" special-case
79
 
 
80
 
    :param textrange: The text to parse
81
 
    :type textrange: str
82
 
    :return: the position and range, as a tuple
83
 
    :rtype: (int, int)
84
 
    """
85
 
    tmp = textrange.split(',')
86
 
    if len(tmp) == 1:
87
 
        pos = tmp[0]
88
 
        range = "1"
89
 
    else:
90
 
        (pos, range) = tmp
91
 
    pos = int(pos)
92
 
    range = int(range)
93
 
    return (pos, range)
94
 
 
95
 
 
96
 
def hunk_from_header(line):
97
 
    if not line.startswith("@@") or not line.endswith("@@\n") \
98
 
        or not len(line) > 4:
99
 
        raise MalformedHunkHeader("Does not start and end with @@.", line)
100
 
    try:
101
 
        (orig, mod) = line[3:-4].split(" ")
102
 
    except Exception, e:
103
 
        raise MalformedHunkHeader(str(e), line)
104
 
    if not orig.startswith('-') or not mod.startswith('+'):
105
 
        raise MalformedHunkHeader("Positions don't start with + or -.", line)
106
 
    try:
107
 
        (orig_pos, orig_range) = parse_range(orig[1:])
108
 
        (mod_pos, mod_range) = parse_range(mod[1:])
109
 
    except Exception, e:
110
 
        raise MalformedHunkHeader(str(e), line)
111
 
    if mod_range < 0 or orig_range < 0:
112
 
        raise MalformedHunkHeader("Hunk range is negative", line)
113
 
    return Hunk(orig_pos, orig_range, mod_pos, mod_range)
114
 
 
115
 
 
116
 
class HunkLine:
117
 
    def __init__(self, contents):
118
 
        self.contents = contents
119
 
 
120
 
    def get_str(self, leadchar):
121
 
        if self.contents == "\n" and leadchar == " " and False:
122
 
            return "\n"
123
 
        if not self.contents.endswith('\n'):
124
 
            terminator = '\n' + NO_NL
125
 
        else:
126
 
            terminator = ''
127
 
        return leadchar + self.contents + terminator
128
 
 
129
 
 
130
 
class ContextLine(HunkLine):
131
 
    def __init__(self, contents):
132
 
        HunkLine.__init__(self, contents)
133
 
 
134
 
    def __str__(self):
135
 
        return self.get_str(" ")
136
 
 
137
 
 
138
 
class InsertLine(HunkLine):
139
 
    def __init__(self, contents):
140
 
        HunkLine.__init__(self, contents)
141
 
 
142
 
    def __str__(self):
143
 
        return self.get_str("+")
144
 
 
145
 
 
146
 
class RemoveLine(HunkLine):
147
 
    def __init__(self, contents):
148
 
        HunkLine.__init__(self, contents)
149
 
 
150
 
    def __str__(self):
151
 
        return self.get_str("-")
152
 
 
153
 
NO_NL = '\\ No newline at end of file\n'
154
 
__pychecker__="no-returnvalues"
155
 
 
156
 
def parse_line(line):
157
 
    if line.startswith("\n"):
158
 
        return ContextLine(line)
159
 
    elif line.startswith(" "):
160
 
        return ContextLine(line[1:])
161
 
    elif line.startswith("+"):
162
 
        return InsertLine(line[1:])
163
 
    elif line.startswith("-"):
164
 
        return RemoveLine(line[1:])
165
 
    elif line == NO_NL:
166
 
        return NO_NL
167
 
    else:
168
 
        raise MalformedLine("Unknown line type", line)
169
 
__pychecker__=""
170
 
 
171
 
 
172
 
class Hunk:
173
 
    def __init__(self, orig_pos, orig_range, mod_pos, mod_range):
174
 
        self.orig_pos = orig_pos
175
 
        self.orig_range = orig_range
176
 
        self.mod_pos = mod_pos
177
 
        self.mod_range = mod_range
178
 
        self.lines = []
179
 
 
180
 
    def get_header(self):
181
 
        return "@@ -%s +%s @@\n" % (self.range_str(self.orig_pos, 
182
 
                                                   self.orig_range),
183
 
                                    self.range_str(self.mod_pos, 
184
 
                                                   self.mod_range))
185
 
 
186
 
    def range_str(self, pos, range):
187
 
        """Return a file range, special-casing for 1-line files.
188
 
 
189
 
        :param pos: The position in the file
190
 
        :type pos: int
191
 
        :range: The range in the file
192
 
        :type range: int
193
 
        :return: a string in the format 1,4 except when range == pos == 1
194
 
        """
195
 
        if range == 1:
196
 
            return "%i" % pos
197
 
        else:
198
 
            return "%i,%i" % (pos, range)
199
 
 
200
 
    def __str__(self):
201
 
        lines = [self.get_header()]
202
 
        for line in self.lines:
203
 
            lines.append(str(line))
204
 
        return "".join(lines)
205
 
 
206
 
    def shift_to_mod(self, pos):
207
 
        if pos < self.orig_pos-1:
208
 
            return 0
209
 
        elif pos > self.orig_pos+self.orig_range:
210
 
            return self.mod_range - self.orig_range
211
 
        else:
212
 
            return self.shift_to_mod_lines(pos)
213
 
 
214
 
    def shift_to_mod_lines(self, pos):
215
 
        assert (pos >= self.orig_pos-1 and pos <= self.orig_pos+self.orig_range)
216
 
        position = self.orig_pos-1
217
 
        shift = 0
218
 
        for line in self.lines:
219
 
            if isinstance(line, InsertLine):
220
 
                shift += 1
221
 
            elif isinstance(line, RemoveLine):
222
 
                if position == pos:
223
 
                    return None
224
 
                shift -= 1
225
 
                position += 1
226
 
            elif isinstance(line, ContextLine):
227
 
                position += 1
228
 
            if position > pos:
229
 
                break
230
 
        return shift
231
 
 
232
 
 
233
 
def iter_hunks(iter_lines):
234
 
    hunk = None
235
 
    for line in iter_lines:
236
 
        if line == "\n":
237
 
            if hunk is not None:
238
 
                yield hunk
239
 
                hunk = None
240
 
            continue
241
 
        if hunk is not None:
242
 
            yield hunk
243
 
        hunk = hunk_from_header(line)
244
 
        orig_size = 0
245
 
        mod_size = 0
246
 
        while orig_size < hunk.orig_range or mod_size < hunk.mod_range:
247
 
            hunk_line = parse_line(iter_lines.next())
248
 
            hunk.lines.append(hunk_line)
249
 
            if isinstance(hunk_line, (RemoveLine, ContextLine)):
250
 
                orig_size += 1
251
 
            if isinstance(hunk_line, (InsertLine, ContextLine)):
252
 
                mod_size += 1
253
 
    if hunk is not None:
254
 
        yield hunk
255
 
 
256
 
 
257
 
class Patch:
258
 
    def __init__(self, oldname, newname):
259
 
        self.oldname = oldname
260
 
        self.newname = newname
261
 
        self.hunks = []
262
 
 
263
 
    def __str__(self):
264
 
        ret = self.get_header() 
265
 
        ret += "".join([str(h) for h in self.hunks])
266
 
        return ret
267
 
 
268
 
    def get_header(self):
269
 
        return "--- %s\n+++ %s\n" % (self.oldname, self.newname)
270
 
 
271
 
    def stats_str(self):
272
 
        """Return a string of patch statistics"""
273
 
        removes = 0
274
 
        inserts = 0
275
 
        for hunk in self.hunks:
276
 
            for line in hunk.lines:
277
 
                if isinstance(line, InsertLine):
278
 
                     inserts+=1;
279
 
                elif isinstance(line, RemoveLine):
280
 
                     removes+=1;
281
 
        return "%i inserts, %i removes in %i hunks" % \
282
 
            (inserts, removes, len(self.hunks))
283
 
 
284
 
    def pos_in_mod(self, position):
285
 
        newpos = position
286
 
        for hunk in self.hunks:
287
 
            shift = hunk.shift_to_mod(position)
288
 
            if shift is None:
289
 
                return None
290
 
            newpos += shift
291
 
        return newpos
292
 
            
293
 
    def iter_inserted(self):
294
 
        """Iteraties through inserted lines
295
 
        
296
 
        :return: Pair of line number, line
297
 
        :rtype: iterator of (int, InsertLine)
298
 
        """
299
 
        for hunk in self.hunks:
300
 
            pos = hunk.mod_pos - 1;
301
 
            for line in hunk.lines:
302
 
                if isinstance(line, InsertLine):
303
 
                    yield (pos, line)
304
 
                    pos += 1
305
 
                if isinstance(line, ContextLine):
306
 
                    pos += 1
307
 
 
308
 
 
309
 
def parse_patch(iter_lines):
310
 
    (orig_name, mod_name) = get_patch_names(iter_lines)
311
 
    patch = Patch(orig_name, mod_name)
312
 
    for hunk in iter_hunks(iter_lines):
313
 
        patch.hunks.append(hunk)
314
 
    return patch
315
 
 
316
 
 
317
 
def iter_file_patch(iter_lines):
318
 
    saved_lines = []
319
 
    orig_range = 0
320
 
    for line in iter_lines:
321
 
        if line.startswith('=== ') or line.startswith('*** '):
322
 
            continue
323
 
        if line.startswith('#'):
324
 
            continue
325
 
        elif orig_range > 0:
326
 
            if line.startswith('-') or line.startswith(' '):
327
 
                orig_range -= 1
328
 
        elif line.startswith('--- '):
329
 
            if len(saved_lines) > 0:
330
 
                yield saved_lines
331
 
            saved_lines = []
332
 
        elif line.startswith('@@'):
333
 
            hunk = hunk_from_header(line)
334
 
            orig_range = hunk.orig_range
335
 
        saved_lines.append(line)
336
 
    if len(saved_lines) > 0:
337
 
        yield saved_lines
338
 
 
339
 
 
340
 
def iter_lines_handle_nl(iter_lines):
341
 
    """
342
 
    Iterates through lines, ensuring that lines that originally had no
343
 
    terminating \n are produced without one.  This transformation may be
344
 
    applied at any point up until hunk line parsing, and is safe to apply
345
 
    repeatedly.
346
 
    """
347
 
    last_line = None
348
 
    for line in iter_lines:
349
 
        if line == NO_NL:
350
 
            assert last_line.endswith('\n')
351
 
            last_line = last_line[:-1]
352
 
            line = None
353
 
        if last_line is not None:
354
 
            yield last_line
355
 
        last_line = line
356
 
    if last_line is not None:
357
 
        yield last_line
358
 
 
359
 
 
360
 
def parse_patches(iter_lines):
361
 
    iter_lines = iter_lines_handle_nl(iter_lines)
362
 
    return [parse_patch(f.__iter__()) for f in iter_file_patch(iter_lines)]
363
 
 
364
 
 
365
 
def difference_index(atext, btext):
366
 
    """Find the indext of the first character that differs between two texts
367
 
 
368
 
    :param atext: The first text
369
 
    :type atext: str
370
 
    :param btext: The second text
371
 
    :type str: str
372
 
    :return: The index, or None if there are no differences within the range
373
 
    :rtype: int or NoneType
374
 
    """
375
 
    length = len(atext)
376
 
    if len(btext) < length:
377
 
        length = len(btext)
378
 
    for i in range(length):
379
 
        if atext[i] != btext[i]:
380
 
            return i;
381
 
    return None
382
 
 
383
 
 
384
 
def iter_patched(orig_lines, patch_lines):
385
 
    """Iterate through a series of lines with a patch applied.
386
 
    This handles a single file, and does exact, not fuzzy patching.
387
 
    """
388
 
    if orig_lines is not None:
389
 
        orig_lines = orig_lines.__iter__()
390
 
    seen_patch = []
391
 
    patch_lines = iter_lines_handle_nl(patch_lines.__iter__())
392
 
    get_patch_names(patch_lines)
393
 
    line_no = 1
394
 
    for hunk in iter_hunks(patch_lines):
395
 
        while line_no < hunk.orig_pos:
396
 
            orig_line = orig_lines.next()
397
 
            yield orig_line
398
 
            line_no += 1
399
 
        for hunk_line in hunk.lines:
400
 
            seen_patch.append(str(hunk_line))
401
 
            if isinstance(hunk_line, InsertLine):
402
 
                yield hunk_line.contents
403
 
            elif isinstance(hunk_line, (ContextLine, RemoveLine)):
404
 
                orig_line = orig_lines.next()
405
 
                if orig_line != hunk_line.contents:
406
 
                    raise PatchConflict(line_no, orig_line, "".join(seen_patch))
407
 
                if isinstance(hunk_line, ContextLine):
408
 
                    yield orig_line
409
 
                else:
410
 
                    assert isinstance(hunk_line, RemoveLine)
411
 
                line_no += 1
412
 
    if orig_lines is not None:
413
 
        for line in orig_lines:
414
 
            yield line