~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/patches.py

  • Committer: Martin Pool
  • Date: 2005-04-28 07:24:55 UTC
  • Revision ID: mbp@sourcefrog.net-20050428072453-7b99afa993a1e549
todo

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2005-2010 Aaron Bentley, Canonical Ltd
2
 
# <aaron.bentley@utoronto.ca>
3
 
#
4
 
# This program is free software; you can redistribute it and/or modify
5
 
# it under the terms of the GNU General Public License as published by
6
 
# the Free Software Foundation; either version 2 of the License, or
7
 
# (at your option) any later version.
8
 
#
9
 
# This program is distributed in the hope that it will be useful,
10
 
# but WITHOUT ANY WARRANTY; without even the implied warranty of
11
 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12
 
# GNU General Public License for more details.
13
 
#
14
 
# You should have received a copy of the GNU General Public License
15
 
# along with this program; if not, write to the Free Software
16
 
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17
 
 
18
 
from __future__ import absolute_import
19
 
 
20
 
from bzrlib.errors import (
21
 
    BinaryFiles,
22
 
    MalformedHunkHeader,
23
 
    MalformedLine,
24
 
    MalformedPatchHeader,
25
 
    PatchConflict,
26
 
    PatchSyntax,
27
 
    )
28
 
 
29
 
import re
30
 
 
31
 
 
32
 
binary_files_re = 'Binary files (.*) and (.*) differ\n'
33
 
 
34
 
 
35
 
def get_patch_names(iter_lines):
36
 
    line = iter_lines.next()
37
 
    try:
38
 
        match = re.match(binary_files_re, line)
39
 
        if match is not None:
40
 
            raise BinaryFiles(match.group(1), match.group(2))
41
 
        if not line.startswith("--- "):
42
 
            raise MalformedPatchHeader("No orig name", line)
43
 
        else:
44
 
            orig_name = line[4:].rstrip("\n")
45
 
    except StopIteration:
46
 
        raise MalformedPatchHeader("No orig line", "")
47
 
    try:
48
 
        line = iter_lines.next()
49
 
        if not line.startswith("+++ "):
50
 
            raise PatchSyntax("No mod name")
51
 
        else:
52
 
            mod_name = line[4:].rstrip("\n")
53
 
    except StopIteration:
54
 
        raise MalformedPatchHeader("No mod line", "")
55
 
    return (orig_name, mod_name)
56
 
 
57
 
 
58
 
def parse_range(textrange):
59
 
    """Parse a patch range, handling the "1" special-case
60
 
 
61
 
    :param textrange: The text to parse
62
 
    :type textrange: str
63
 
    :return: the position and range, as a tuple
64
 
    :rtype: (int, int)
65
 
    """
66
 
    tmp = textrange.split(',')
67
 
    if len(tmp) == 1:
68
 
        pos = tmp[0]
69
 
        range = "1"
70
 
    else:
71
 
        (pos, range) = tmp
72
 
    pos = int(pos)
73
 
    range = int(range)
74
 
    return (pos, range)
75
 
 
76
 
 
77
 
def hunk_from_header(line):
78
 
    import re
79
 
    matches = re.match(r'\@\@ ([^@]*) \@\@( (.*))?\n', line)
80
 
    if matches is None:
81
 
        raise MalformedHunkHeader("Does not match format.", line)
82
 
    try:
83
 
        (orig, mod) = matches.group(1).split(" ")
84
 
    except (ValueError, IndexError), e:
85
 
        raise MalformedHunkHeader(str(e), line)
86
 
    if not orig.startswith('-') or not mod.startswith('+'):
87
 
        raise MalformedHunkHeader("Positions don't start with + or -.", line)
88
 
    try:
89
 
        (orig_pos, orig_range) = parse_range(orig[1:])
90
 
        (mod_pos, mod_range) = parse_range(mod[1:])
91
 
    except (ValueError, IndexError), e:
92
 
        raise MalformedHunkHeader(str(e), line)
93
 
    if mod_range < 0 or orig_range < 0:
94
 
        raise MalformedHunkHeader("Hunk range is negative", line)
95
 
    tail = matches.group(3)
96
 
    return Hunk(orig_pos, orig_range, mod_pos, mod_range, tail)
97
 
 
98
 
 
99
 
class HunkLine:
100
 
    def __init__(self, contents):
101
 
        self.contents = contents
102
 
 
103
 
    def get_str(self, leadchar):
104
 
        if self.contents == "\n" and leadchar == " " and False:
105
 
            return "\n"
106
 
        if not self.contents.endswith('\n'):
107
 
            terminator = '\n' + NO_NL
108
 
        else:
109
 
            terminator = ''
110
 
        return leadchar + self.contents + terminator
111
 
 
112
 
 
113
 
class ContextLine(HunkLine):
114
 
    def __init__(self, contents):
115
 
        HunkLine.__init__(self, contents)
116
 
 
117
 
    def __str__(self):
118
 
        return self.get_str(" ")
119
 
 
120
 
 
121
 
class InsertLine(HunkLine):
122
 
    def __init__(self, contents):
123
 
        HunkLine.__init__(self, contents)
124
 
 
125
 
    def __str__(self):
126
 
        return self.get_str("+")
127
 
 
128
 
 
129
 
class RemoveLine(HunkLine):
130
 
    def __init__(self, contents):
131
 
        HunkLine.__init__(self, contents)
132
 
 
133
 
    def __str__(self):
134
 
        return self.get_str("-")
135
 
 
136
 
NO_NL = '\\ No newline at end of file\n'
137
 
__pychecker__="no-returnvalues"
138
 
 
139
 
def parse_line(line):
140
 
    if line.startswith("\n"):
141
 
        return ContextLine(line)
142
 
    elif line.startswith(" "):
143
 
        return ContextLine(line[1:])
144
 
    elif line.startswith("+"):
145
 
        return InsertLine(line[1:])
146
 
    elif line.startswith("-"):
147
 
        return RemoveLine(line[1:])
148
 
    else:
149
 
        raise MalformedLine("Unknown line type", line)
150
 
__pychecker__=""
151
 
 
152
 
 
153
 
class Hunk:
154
 
    def __init__(self, orig_pos, orig_range, mod_pos, mod_range, tail=None):
155
 
        self.orig_pos = orig_pos
156
 
        self.orig_range = orig_range
157
 
        self.mod_pos = mod_pos
158
 
        self.mod_range = mod_range
159
 
        self.tail = tail
160
 
        self.lines = []
161
 
 
162
 
    def get_header(self):
163
 
        if self.tail is None:
164
 
            tail_str = ''
165
 
        else:
166
 
            tail_str = ' ' + self.tail
167
 
        return "@@ -%s +%s @@%s\n" % (self.range_str(self.orig_pos,
168
 
                                                     self.orig_range),
169
 
                                      self.range_str(self.mod_pos,
170
 
                                                     self.mod_range),
171
 
                                      tail_str)
172
 
 
173
 
    def range_str(self, pos, range):
174
 
        """Return a file range, special-casing for 1-line files.
175
 
 
176
 
        :param pos: The position in the file
177
 
        :type pos: int
178
 
        :range: The range in the file
179
 
        :type range: int
180
 
        :return: a string in the format 1,4 except when range == pos == 1
181
 
        """
182
 
        if range == 1:
183
 
            return "%i" % pos
184
 
        else:
185
 
            return "%i,%i" % (pos, range)
186
 
 
187
 
    def __str__(self):
188
 
        lines = [self.get_header()]
189
 
        for line in self.lines:
190
 
            lines.append(str(line))
191
 
        return "".join(lines)
192
 
 
193
 
    def shift_to_mod(self, pos):
194
 
        if pos < self.orig_pos-1:
195
 
            return 0
196
 
        elif pos > self.orig_pos+self.orig_range:
197
 
            return self.mod_range - self.orig_range
198
 
        else:
199
 
            return self.shift_to_mod_lines(pos)
200
 
 
201
 
    def shift_to_mod_lines(self, pos):
202
 
        position = self.orig_pos-1
203
 
        shift = 0
204
 
        for line in self.lines:
205
 
            if isinstance(line, InsertLine):
206
 
                shift += 1
207
 
            elif isinstance(line, RemoveLine):
208
 
                if position == pos:
209
 
                    return None
210
 
                shift -= 1
211
 
                position += 1
212
 
            elif isinstance(line, ContextLine):
213
 
                position += 1
214
 
            if position > pos:
215
 
                break
216
 
        return shift
217
 
 
218
 
 
219
 
def iter_hunks(iter_lines, allow_dirty=False):
220
 
    '''
221
 
    :arg iter_lines: iterable of lines to parse for hunks
222
 
    :kwarg allow_dirty: If True, when we encounter something that is not
223
 
        a hunk header when we're looking for one, assume the rest of the lines
224
 
        are not part of the patch (comments or other junk).  Default False
225
 
    '''
226
 
    hunk = None
227
 
    for line in iter_lines:
228
 
        if line == "\n":
229
 
            if hunk is not None:
230
 
                yield hunk
231
 
                hunk = None
232
 
            continue
233
 
        if hunk is not None:
234
 
            yield hunk
235
 
        try:
236
 
            hunk = hunk_from_header(line)
237
 
        except MalformedHunkHeader:
238
 
            if allow_dirty:
239
 
                # If the line isn't a hunk header, then we've reached the end
240
 
                # of this patch and there's "junk" at the end.  Ignore the
241
 
                # rest of this patch.
242
 
                return
243
 
            raise
244
 
        orig_size = 0
245
 
        mod_size = 0
246
 
        while orig_size < hunk.orig_range or mod_size < hunk.mod_range:
247
 
            hunk_line = parse_line(iter_lines.next())
248
 
            hunk.lines.append(hunk_line)
249
 
            if isinstance(hunk_line, (RemoveLine, ContextLine)):
250
 
                orig_size += 1
251
 
            if isinstance(hunk_line, (InsertLine, ContextLine)):
252
 
                mod_size += 1
253
 
    if hunk is not None:
254
 
        yield hunk
255
 
 
256
 
 
257
 
class BinaryPatch(object):
258
 
    def __init__(self, oldname, newname):
259
 
        self.oldname = oldname
260
 
        self.newname = newname
261
 
 
262
 
    def __str__(self):
263
 
        return 'Binary files %s and %s differ\n' % (self.oldname, self.newname)
264
 
 
265
 
 
266
 
class Patch(BinaryPatch):
267
 
 
268
 
    def __init__(self, oldname, newname):
269
 
        BinaryPatch.__init__(self, oldname, newname)
270
 
        self.hunks = []
271
 
 
272
 
    def __str__(self):
273
 
        ret = self.get_header()
274
 
        ret += "".join([str(h) for h in self.hunks])
275
 
        return ret
276
 
 
277
 
    def get_header(self):
278
 
        return "--- %s\n+++ %s\n" % (self.oldname, self.newname)
279
 
 
280
 
    def stats_values(self):
281
 
        """Calculate the number of inserts and removes."""
282
 
        removes = 0
283
 
        inserts = 0
284
 
        for hunk in self.hunks:
285
 
            for line in hunk.lines:
286
 
                if isinstance(line, InsertLine):
287
 
                     inserts+=1;
288
 
                elif isinstance(line, RemoveLine):
289
 
                     removes+=1;
290
 
        return (inserts, removes, len(self.hunks))
291
 
 
292
 
    def stats_str(self):
293
 
        """Return a string of patch statistics"""
294
 
        return "%i inserts, %i removes in %i hunks" % \
295
 
            self.stats_values()
296
 
 
297
 
    def pos_in_mod(self, position):
298
 
        newpos = position
299
 
        for hunk in self.hunks:
300
 
            shift = hunk.shift_to_mod(position)
301
 
            if shift is None:
302
 
                return None
303
 
            newpos += shift
304
 
        return newpos
305
 
 
306
 
    def iter_inserted(self):
307
 
        """Iteraties through inserted lines
308
 
 
309
 
        :return: Pair of line number, line
310
 
        :rtype: iterator of (int, InsertLine)
311
 
        """
312
 
        for hunk in self.hunks:
313
 
            pos = hunk.mod_pos - 1;
314
 
            for line in hunk.lines:
315
 
                if isinstance(line, InsertLine):
316
 
                    yield (pos, line)
317
 
                    pos += 1
318
 
                if isinstance(line, ContextLine):
319
 
                    pos += 1
320
 
 
321
 
def parse_patch(iter_lines, allow_dirty=False):
322
 
    '''
323
 
    :arg iter_lines: iterable of lines to parse
324
 
    :kwarg allow_dirty: If True, allow the patch to have trailing junk.
325
 
        Default False
326
 
    '''
327
 
    iter_lines = iter_lines_handle_nl(iter_lines)
328
 
    try:
329
 
        (orig_name, mod_name) = get_patch_names(iter_lines)
330
 
    except BinaryFiles, e:
331
 
        return BinaryPatch(e.orig_name, e.mod_name)
332
 
    else:
333
 
        patch = Patch(orig_name, mod_name)
334
 
        for hunk in iter_hunks(iter_lines, allow_dirty):
335
 
            patch.hunks.append(hunk)
336
 
        return patch
337
 
 
338
 
 
339
 
def iter_file_patch(iter_lines, allow_dirty=False, keep_dirty=False):
340
 
    '''
341
 
    :arg iter_lines: iterable of lines to parse for patches
342
 
    :kwarg allow_dirty: If True, allow comments and other non-patch text
343
 
        before the first patch.  Note that the algorithm here can only find
344
 
        such text before any patches have been found.  Comments after the
345
 
        first patch are stripped away in iter_hunks() if it is also passed
346
 
        allow_dirty=True.  Default False.
347
 
    '''
348
 
    ### FIXME: Docstring is not quite true.  We allow certain comments no
349
 
    # matter what, If they startwith '===', '***', or '#' Someone should
350
 
    # reexamine this logic and decide if we should include those in
351
 
    # allow_dirty or restrict those to only being before the patch is found
352
 
    # (as allow_dirty does).
353
 
    regex = re.compile(binary_files_re)
354
 
    saved_lines = []
355
 
    dirty_head = []
356
 
    orig_range = 0
357
 
    beginning = True
358
 
 
359
 
    for line in iter_lines:
360
 
        if line.startswith('=== '):
361
 
            if len(saved_lines) > 0:
362
 
                if keep_dirty and len(dirty_head) > 0:
363
 
                    yield {'saved_lines': saved_lines,
364
 
                           'dirty_head': dirty_head}
365
 
                    dirty_head = []
366
 
                else:
367
 
                    yield saved_lines
368
 
                saved_lines = []
369
 
            dirty_head.append(line)
370
 
            continue
371
 
        if line.startswith('*** '):
372
 
            continue
373
 
        if line.startswith('#'):
374
 
            continue
375
 
        elif orig_range > 0:
376
 
            if line.startswith('-') or line.startswith(' '):
377
 
                orig_range -= 1
378
 
        elif line.startswith('--- ') or regex.match(line):
379
 
            if allow_dirty and beginning:
380
 
                # Patches can have "junk" at the beginning
381
 
                # Stripping junk from the end of patches is handled when we
382
 
                # parse the patch
383
 
                beginning = False
384
 
            elif len(saved_lines) > 0:
385
 
                if keep_dirty and len(dirty_head) > 0:
386
 
                    yield {'saved_lines': saved_lines,
387
 
                           'dirty_head': dirty_head}
388
 
                    dirty_head = []
389
 
                else:
390
 
                    yield saved_lines
391
 
            saved_lines = []
392
 
        elif line.startswith('@@'):
393
 
            hunk = hunk_from_header(line)
394
 
            orig_range = hunk.orig_range
395
 
        saved_lines.append(line)
396
 
    if len(saved_lines) > 0:
397
 
        if keep_dirty and len(dirty_head) > 0:
398
 
            yield {'saved_lines': saved_lines,
399
 
                   'dirty_head': dirty_head}
400
 
        else:
401
 
            yield saved_lines
402
 
 
403
 
 
404
 
def iter_lines_handle_nl(iter_lines):
405
 
    """
406
 
    Iterates through lines, ensuring that lines that originally had no
407
 
    terminating \n are produced without one.  This transformation may be
408
 
    applied at any point up until hunk line parsing, and is safe to apply
409
 
    repeatedly.
410
 
    """
411
 
    last_line = None
412
 
    for line in iter_lines:
413
 
        if line == NO_NL:
414
 
            if not last_line.endswith('\n'):
415
 
                raise AssertionError()
416
 
            last_line = last_line[:-1]
417
 
            line = None
418
 
        if last_line is not None:
419
 
            yield last_line
420
 
        last_line = line
421
 
    if last_line is not None:
422
 
        yield last_line
423
 
 
424
 
 
425
 
def parse_patches(iter_lines, allow_dirty=False, keep_dirty=False):
426
 
    '''
427
 
    :arg iter_lines: iterable of lines to parse for patches
428
 
    :kwarg allow_dirty: If True, allow text that's not part of the patch at
429
 
        selected places.  This includes comments before and after a patch
430
 
        for instance.  Default False.
431
 
    :kwarg keep_dirty: If True, returns a dict of patches with dirty headers.
432
 
        Default False.
433
 
    '''
434
 
    patches = []
435
 
    for patch_lines in iter_file_patch(iter_lines, allow_dirty, keep_dirty):
436
 
        if 'dirty_head' in patch_lines:
437
 
            patches.append({'patch': parse_patch(
438
 
                patch_lines['saved_lines'], allow_dirty),
439
 
                            'dirty_head': patch_lines['dirty_head']})
440
 
        else:
441
 
            patches.append(parse_patch(patch_lines, allow_dirty))
442
 
    return patches
443
 
 
444
 
 
445
 
def difference_index(atext, btext):
446
 
    """Find the indext of the first character that differs between two texts
447
 
 
448
 
    :param atext: The first text
449
 
    :type atext: str
450
 
    :param btext: The second text
451
 
    :type str: str
452
 
    :return: The index, or None if there are no differences within the range
453
 
    :rtype: int or NoneType
454
 
    """
455
 
    length = len(atext)
456
 
    if len(btext) < length:
457
 
        length = len(btext)
458
 
    for i in range(length):
459
 
        if atext[i] != btext[i]:
460
 
            return i;
461
 
    return None
462
 
 
463
 
 
464
 
def iter_patched(orig_lines, patch_lines):
465
 
    """Iterate through a series of lines with a patch applied.
466
 
    This handles a single file, and does exact, not fuzzy patching.
467
 
    """
468
 
    patch_lines = iter_lines_handle_nl(iter(patch_lines))
469
 
    get_patch_names(patch_lines)
470
 
    return iter_patched_from_hunks(orig_lines, iter_hunks(patch_lines))
471
 
 
472
 
 
473
 
def iter_patched_from_hunks(orig_lines, hunks):
474
 
    """Iterate through a series of lines with a patch applied.
475
 
    This handles a single file, and does exact, not fuzzy patching.
476
 
 
477
 
    :param orig_lines: The unpatched lines.
478
 
    :param hunks: An iterable of Hunk instances.
479
 
    """
480
 
    seen_patch = []
481
 
    line_no = 1
482
 
    if orig_lines is not None:
483
 
        orig_lines = iter(orig_lines)
484
 
    for hunk in hunks:
485
 
        while line_no < hunk.orig_pos:
486
 
            orig_line = orig_lines.next()
487
 
            yield orig_line
488
 
            line_no += 1
489
 
        for hunk_line in hunk.lines:
490
 
            seen_patch.append(str(hunk_line))
491
 
            if isinstance(hunk_line, InsertLine):
492
 
                yield hunk_line.contents
493
 
            elif isinstance(hunk_line, (ContextLine, RemoveLine)):
494
 
                orig_line = orig_lines.next()
495
 
                if orig_line != hunk_line.contents:
496
 
                    raise PatchConflict(line_no, orig_line, "".join(seen_patch))
497
 
                if isinstance(hunk_line, ContextLine):
498
 
                    yield orig_line
499
 
                else:
500
 
                    if not isinstance(hunk_line, RemoveLine):
501
 
                        raise AssertionError(hunk_line)
502
 
                line_no += 1
503
 
    if orig_lines is not None:
504
 
        for line in orig_lines:
505
 
            yield line