~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/patches.py

  • Committer: Martin Pool
  • Date: 2005-07-18 11:23:40 UTC
  • Revision ID: mbp@sourcefrog.net-20050718112340-4ffbfa3624bb6ef3
- weavebench should set random seed to make it reproducible

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2005-2010 Aaron Bentley, Canonical Ltd
2
 
# <aaron.bentley@utoronto.ca>
3
 
#
4
 
# This program is free software; you can redistribute it and/or modify
5
 
# it under the terms of the GNU General Public License as published by
6
 
# the Free Software Foundation; either version 2 of the License, or
7
 
# (at your option) any later version.
8
 
#
9
 
# This program is distributed in the hope that it will be useful,
10
 
# but WITHOUT ANY WARRANTY; without even the implied warranty of
11
 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12
 
# GNU General Public License for more details.
13
 
#
14
 
# You should have received a copy of the GNU General Public License
15
 
# along with this program; if not, write to the Free Software
16
 
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17
 
 
18
 
from __future__ import absolute_import
19
 
 
20
 
from bzrlib.errors import (
21
 
    BinaryFiles,
22
 
    MalformedHunkHeader,
23
 
    MalformedLine,
24
 
    MalformedPatchHeader,
25
 
    PatchConflict,
26
 
    PatchSyntax,
27
 
    )
28
 
 
29
 
import re
30
 
 
31
 
 
32
 
binary_files_re = 'Binary files (.*) and (.*) differ\n'
33
 
 
34
 
def get_patch_names(iter_lines):
35
 
    try:
36
 
        line = iter_lines.next()
37
 
        match = re.match(binary_files_re, line)
38
 
        if match is not None:
39
 
            raise BinaryFiles(match.group(1), match.group(2))
40
 
        if not line.startswith("--- "):
41
 
            raise MalformedPatchHeader("No orig name", line)
42
 
        else:
43
 
            orig_name = line[4:].rstrip("\n")
44
 
    except StopIteration:
45
 
        raise MalformedPatchHeader("No orig line", "")
46
 
    try:
47
 
        line = iter_lines.next()
48
 
        if not line.startswith("+++ "):
49
 
            raise PatchSyntax("No mod name")
50
 
        else:
51
 
            mod_name = line[4:].rstrip("\n")
52
 
    except StopIteration:
53
 
        raise MalformedPatchHeader("No mod line", "")
54
 
    return (orig_name, mod_name)
55
 
 
56
 
 
57
 
def parse_range(textrange):
58
 
    """Parse a patch range, handling the "1" special-case
59
 
 
60
 
    :param textrange: The text to parse
61
 
    :type textrange: str
62
 
    :return: the position and range, as a tuple
63
 
    :rtype: (int, int)
64
 
    """
65
 
    tmp = textrange.split(',')
66
 
    if len(tmp) == 1:
67
 
        pos = tmp[0]
68
 
        range = "1"
69
 
    else:
70
 
        (pos, range) = tmp
71
 
    pos = int(pos)
72
 
    range = int(range)
73
 
    return (pos, range)
74
 
 
75
 
 
76
 
def hunk_from_header(line):
77
 
    import re
78
 
    matches = re.match(r'\@\@ ([^@]*) \@\@( (.*))?\n', line)
79
 
    if matches is None:
80
 
        raise MalformedHunkHeader("Does not match format.", line)
81
 
    try:
82
 
        (orig, mod) = matches.group(1).split(" ")
83
 
    except (ValueError, IndexError), e:
84
 
        raise MalformedHunkHeader(str(e), line)
85
 
    if not orig.startswith('-') or not mod.startswith('+'):
86
 
        raise MalformedHunkHeader("Positions don't start with + or -.", line)
87
 
    try:
88
 
        (orig_pos, orig_range) = parse_range(orig[1:])
89
 
        (mod_pos, mod_range) = parse_range(mod[1:])
90
 
    except (ValueError, IndexError), e:
91
 
        raise MalformedHunkHeader(str(e), line)
92
 
    if mod_range < 0 or orig_range < 0:
93
 
        raise MalformedHunkHeader("Hunk range is negative", line)
94
 
    tail = matches.group(3)
95
 
    return Hunk(orig_pos, orig_range, mod_pos, mod_range, tail)
96
 
 
97
 
 
98
 
class HunkLine:
99
 
    def __init__(self, contents):
100
 
        self.contents = contents
101
 
 
102
 
    def get_str(self, leadchar):
103
 
        if self.contents == "\n" and leadchar == " " and False:
104
 
            return "\n"
105
 
        if not self.contents.endswith('\n'):
106
 
            terminator = '\n' + NO_NL
107
 
        else:
108
 
            terminator = ''
109
 
        return leadchar + self.contents + terminator
110
 
 
111
 
 
112
 
class ContextLine(HunkLine):
113
 
    def __init__(self, contents):
114
 
        HunkLine.__init__(self, contents)
115
 
 
116
 
    def __str__(self):
117
 
        return self.get_str(" ")
118
 
 
119
 
 
120
 
class InsertLine(HunkLine):
121
 
    def __init__(self, contents):
122
 
        HunkLine.__init__(self, contents)
123
 
 
124
 
    def __str__(self):
125
 
        return self.get_str("+")
126
 
 
127
 
 
128
 
class RemoveLine(HunkLine):
129
 
    def __init__(self, contents):
130
 
        HunkLine.__init__(self, contents)
131
 
 
132
 
    def __str__(self):
133
 
        return self.get_str("-")
134
 
 
135
 
NO_NL = '\\ No newline at end of file\n'
136
 
__pychecker__="no-returnvalues"
137
 
 
138
 
def parse_line(line):
139
 
    if line.startswith("\n"):
140
 
        return ContextLine(line)
141
 
    elif line.startswith(" "):
142
 
        return ContextLine(line[1:])
143
 
    elif line.startswith("+"):
144
 
        return InsertLine(line[1:])
145
 
    elif line.startswith("-"):
146
 
        return RemoveLine(line[1:])
147
 
    else:
148
 
        raise MalformedLine("Unknown line type", line)
149
 
__pychecker__=""
150
 
 
151
 
 
152
 
class Hunk:
153
 
    def __init__(self, orig_pos, orig_range, mod_pos, mod_range, tail=None):
154
 
        self.orig_pos = orig_pos
155
 
        self.orig_range = orig_range
156
 
        self.mod_pos = mod_pos
157
 
        self.mod_range = mod_range
158
 
        self.tail = tail
159
 
        self.lines = []
160
 
 
161
 
    def get_header(self):
162
 
        if self.tail is None:
163
 
            tail_str = ''
164
 
        else:
165
 
            tail_str = ' ' + self.tail
166
 
        return "@@ -%s +%s @@%s\n" % (self.range_str(self.orig_pos,
167
 
                                                     self.orig_range),
168
 
                                      self.range_str(self.mod_pos,
169
 
                                                     self.mod_range),
170
 
                                      tail_str)
171
 
 
172
 
    def range_str(self, pos, range):
173
 
        """Return a file range, special-casing for 1-line files.
174
 
 
175
 
        :param pos: The position in the file
176
 
        :type pos: int
177
 
        :range: The range in the file
178
 
        :type range: int
179
 
        :return: a string in the format 1,4 except when range == pos == 1
180
 
        """
181
 
        if range == 1:
182
 
            return "%i" % pos
183
 
        else:
184
 
            return "%i,%i" % (pos, range)
185
 
 
186
 
    def __str__(self):
187
 
        lines = [self.get_header()]
188
 
        for line in self.lines:
189
 
            lines.append(str(line))
190
 
        return "".join(lines)
191
 
 
192
 
    def shift_to_mod(self, pos):
193
 
        if pos < self.orig_pos-1:
194
 
            return 0
195
 
        elif pos > self.orig_pos+self.orig_range:
196
 
            return self.mod_range - self.orig_range
197
 
        else:
198
 
            return self.shift_to_mod_lines(pos)
199
 
 
200
 
    def shift_to_mod_lines(self, pos):
201
 
        position = self.orig_pos-1
202
 
        shift = 0
203
 
        for line in self.lines:
204
 
            if isinstance(line, InsertLine):
205
 
                shift += 1
206
 
            elif isinstance(line, RemoveLine):
207
 
                if position == pos:
208
 
                    return None
209
 
                shift -= 1
210
 
                position += 1
211
 
            elif isinstance(line, ContextLine):
212
 
                position += 1
213
 
            if position > pos:
214
 
                break
215
 
        return shift
216
 
 
217
 
 
218
 
def iter_hunks(iter_lines, allow_dirty=False):
219
 
    '''
220
 
    :arg iter_lines: iterable of lines to parse for hunks
221
 
    :kwarg allow_dirty: If True, when we encounter something that is not
222
 
        a hunk header when we're looking for one, assume the rest of the lines
223
 
        are not part of the patch (comments or other junk).  Default False
224
 
    '''
225
 
    hunk = None
226
 
    for line in iter_lines:
227
 
        if line == "\n":
228
 
            if hunk is not None:
229
 
                yield hunk
230
 
                hunk = None
231
 
            continue
232
 
        if hunk is not None:
233
 
            yield hunk
234
 
        try:
235
 
            hunk = hunk_from_header(line)
236
 
        except MalformedHunkHeader:
237
 
            if allow_dirty:
238
 
                # If the line isn't a hunk header, then we've reached the end
239
 
                # of this patch and there's "junk" at the end.  Ignore the
240
 
                # rest of this patch.
241
 
                return
242
 
            raise
243
 
        orig_size = 0
244
 
        mod_size = 0
245
 
        while orig_size < hunk.orig_range or mod_size < hunk.mod_range:
246
 
            hunk_line = parse_line(iter_lines.next())
247
 
            hunk.lines.append(hunk_line)
248
 
            if isinstance(hunk_line, (RemoveLine, ContextLine)):
249
 
                orig_size += 1
250
 
            if isinstance(hunk_line, (InsertLine, ContextLine)):
251
 
                mod_size += 1
252
 
    if hunk is not None:
253
 
        yield hunk
254
 
 
255
 
 
256
 
class BinaryPatch(object):
257
 
    def __init__(self, oldname, newname):
258
 
        self.oldname = oldname
259
 
        self.newname = newname
260
 
 
261
 
    def __str__(self):
262
 
        return 'Binary files %s and %s differ\n' % (self.oldname, self.newname)
263
 
 
264
 
 
265
 
class Patch(BinaryPatch):
266
 
 
267
 
    def __init__(self, oldname, newname):
268
 
        BinaryPatch.__init__(self, oldname, newname)
269
 
        self.hunks = []
270
 
 
271
 
    def __str__(self):
272
 
        ret = self.get_header()
273
 
        ret += "".join([str(h) for h in self.hunks])
274
 
        return ret
275
 
 
276
 
    def get_header(self):
277
 
        return "--- %s\n+++ %s\n" % (self.oldname, self.newname)
278
 
 
279
 
    def stats_values(self):
280
 
        """Calculate the number of inserts and removes."""
281
 
        removes = 0
282
 
        inserts = 0
283
 
        for hunk in self.hunks:
284
 
            for line in hunk.lines:
285
 
                if isinstance(line, InsertLine):
286
 
                     inserts+=1;
287
 
                elif isinstance(line, RemoveLine):
288
 
                     removes+=1;
289
 
        return (inserts, removes, len(self.hunks))
290
 
 
291
 
    def stats_str(self):
292
 
        """Return a string of patch statistics"""
293
 
        return "%i inserts, %i removes in %i hunks" % \
294
 
            self.stats_values()
295
 
 
296
 
    def pos_in_mod(self, position):
297
 
        newpos = position
298
 
        for hunk in self.hunks:
299
 
            shift = hunk.shift_to_mod(position)
300
 
            if shift is None:
301
 
                return None
302
 
            newpos += shift
303
 
        return newpos
304
 
 
305
 
    def iter_inserted(self):
306
 
        """Iteraties through inserted lines
307
 
 
308
 
        :return: Pair of line number, line
309
 
        :rtype: iterator of (int, InsertLine)
310
 
        """
311
 
        for hunk in self.hunks:
312
 
            pos = hunk.mod_pos - 1;
313
 
            for line in hunk.lines:
314
 
                if isinstance(line, InsertLine):
315
 
                    yield (pos, line)
316
 
                    pos += 1
317
 
                if isinstance(line, ContextLine):
318
 
                    pos += 1
319
 
 
320
 
 
321
 
def parse_patch(iter_lines, allow_dirty=False):
322
 
    '''
323
 
    :arg iter_lines: iterable of lines to parse
324
 
    :kwarg allow_dirty: If True, allow the patch to have trailing junk.
325
 
        Default False
326
 
    '''
327
 
    iter_lines = iter_lines_handle_nl(iter_lines)
328
 
    try:
329
 
        (orig_name, mod_name) = get_patch_names(iter_lines)
330
 
    except BinaryFiles, e:
331
 
        return BinaryPatch(e.orig_name, e.mod_name)
332
 
    else:
333
 
        patch = Patch(orig_name, mod_name)
334
 
        for hunk in iter_hunks(iter_lines, allow_dirty):
335
 
            patch.hunks.append(hunk)
336
 
        return patch
337
 
 
338
 
 
339
 
def iter_file_patch(iter_lines, allow_dirty=False):
340
 
    '''
341
 
    :arg iter_lines: iterable of lines to parse for patches
342
 
    :kwarg allow_dirty: If True, allow comments and other non-patch text
343
 
        before the first patch.  Note that the algorithm here can only find
344
 
        such text before any patches have been found.  Comments after the
345
 
        first patch are stripped away in iter_hunks() if it is also passed
346
 
        allow_dirty=True.  Default False.
347
 
    '''
348
 
    ### FIXME: Docstring is not quite true.  We allow certain comments no
349
 
    # matter what, If they startwith '===', '***', or '#' Someone should
350
 
    # reexamine this logic and decide if we should include those in
351
 
    # allow_dirty or restrict those to only being before the patch is found
352
 
    # (as allow_dirty does).
353
 
    regex = re.compile(binary_files_re)
354
 
    saved_lines = []
355
 
    orig_range = 0
356
 
    beginning = True
357
 
    for line in iter_lines:
358
 
        if line.startswith('=== ') or line.startswith('*** '):
359
 
            continue
360
 
        if line.startswith('#'):
361
 
            continue
362
 
        elif orig_range > 0:
363
 
            if line.startswith('-') or line.startswith(' '):
364
 
                orig_range -= 1
365
 
        elif line.startswith('--- ') or regex.match(line):
366
 
            if allow_dirty and beginning:
367
 
                # Patches can have "junk" at the beginning
368
 
                # Stripping junk from the end of patches is handled when we
369
 
                # parse the patch
370
 
                beginning = False
371
 
            elif len(saved_lines) > 0:
372
 
                yield saved_lines
373
 
            saved_lines = []
374
 
        elif line.startswith('@@'):
375
 
            hunk = hunk_from_header(line)
376
 
            orig_range = hunk.orig_range
377
 
        saved_lines.append(line)
378
 
    if len(saved_lines) > 0:
379
 
        yield saved_lines
380
 
 
381
 
 
382
 
def iter_lines_handle_nl(iter_lines):
383
 
    """
384
 
    Iterates through lines, ensuring that lines that originally had no
385
 
    terminating \n are produced without one.  This transformation may be
386
 
    applied at any point up until hunk line parsing, and is safe to apply
387
 
    repeatedly.
388
 
    """
389
 
    last_line = None
390
 
    for line in iter_lines:
391
 
        if line == NO_NL:
392
 
            if not last_line.endswith('\n'):
393
 
                raise AssertionError()
394
 
            last_line = last_line[:-1]
395
 
            line = None
396
 
        if last_line is not None:
397
 
            yield last_line
398
 
        last_line = line
399
 
    if last_line is not None:
400
 
        yield last_line
401
 
 
402
 
 
403
 
def parse_patches(iter_lines, allow_dirty=False):
404
 
    '''
405
 
    :arg iter_lines: iterable of lines to parse for patches
406
 
    :kwarg allow_dirty: If True, allow text that's not part of the patch at
407
 
        selected places.  This includes comments before and after a patch
408
 
        for instance.  Default False.
409
 
    '''
410
 
    return [parse_patch(f.__iter__(), allow_dirty) for f in
411
 
                        iter_file_patch(iter_lines, allow_dirty)]
412
 
 
413
 
 
414
 
def difference_index(atext, btext):
415
 
    """Find the indext of the first character that differs between two texts
416
 
 
417
 
    :param atext: The first text
418
 
    :type atext: str
419
 
    :param btext: The second text
420
 
    :type str: str
421
 
    :return: The index, or None if there are no differences within the range
422
 
    :rtype: int or NoneType
423
 
    """
424
 
    length = len(atext)
425
 
    if len(btext) < length:
426
 
        length = len(btext)
427
 
    for i in range(length):
428
 
        if atext[i] != btext[i]:
429
 
            return i;
430
 
    return None
431
 
 
432
 
 
433
 
def iter_patched(orig_lines, patch_lines):
434
 
    """Iterate through a series of lines with a patch applied.
435
 
    This handles a single file, and does exact, not fuzzy patching.
436
 
    """
437
 
    patch_lines = iter_lines_handle_nl(iter(patch_lines))
438
 
    get_patch_names(patch_lines)
439
 
    return iter_patched_from_hunks(orig_lines, iter_hunks(patch_lines))
440
 
 
441
 
 
442
 
def iter_patched_from_hunks(orig_lines, hunks):
443
 
    """Iterate through a series of lines with a patch applied.
444
 
    This handles a single file, and does exact, not fuzzy patching.
445
 
 
446
 
    :param orig_lines: The unpatched lines.
447
 
    :param hunks: An iterable of Hunk instances.
448
 
    """
449
 
    seen_patch = []
450
 
    line_no = 1
451
 
    if orig_lines is not None:
452
 
        orig_lines = iter(orig_lines)
453
 
    for hunk in hunks:
454
 
        while line_no < hunk.orig_pos:
455
 
            orig_line = orig_lines.next()
456
 
            yield orig_line
457
 
            line_no += 1
458
 
        for hunk_line in hunk.lines:
459
 
            seen_patch.append(str(hunk_line))
460
 
            if isinstance(hunk_line, InsertLine):
461
 
                yield hunk_line.contents
462
 
            elif isinstance(hunk_line, (ContextLine, RemoveLine)):
463
 
                orig_line = orig_lines.next()
464
 
                if orig_line != hunk_line.contents:
465
 
                    raise PatchConflict(line_no, orig_line, "".join(seen_patch))
466
 
                if isinstance(hunk_line, ContextLine):
467
 
                    yield orig_line
468
 
                else:
469
 
                    if not isinstance(hunk_line, RemoveLine):
470
 
                        raise AssertionError(hunk_line)
471
 
                line_no += 1
472
 
    if orig_lines is not None:
473
 
        for line in orig_lines:
474
 
            yield line