~abentley/bzrtools/bzrtools.dev

« back to all changes in this revision

Viewing changes to patches.py

  • Committer: Aaron Bentley
  • Date: 2006-04-12 01:27:15 UTC
  • mto: This revision was merged to the branch mainline in revision 362.
  • Revision ID: aaron.bentley@utoronto.ca-20060412012715-b8e33bae91b660a0
Tweak switch command

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# Copyright (C) 2004, 2005 Aaron Bentley
 
2
# <aaron.bentley@utoronto.ca>
 
3
#
 
4
#    This program is free software; you can redistribute it and/or modify
 
5
#    it under the terms of the GNU General Public License as published by
 
6
#    the Free Software Foundation; either version 2 of the License, or
 
7
#    (at your option) any later version.
 
8
#
 
9
#    This program is distributed in the hope that it will be useful,
 
10
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
 
11
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
12
#    GNU General Public License for more details.
 
13
#
 
14
#    You should have received a copy of the GNU General Public License
 
15
#    along with this program; if not, write to the Free Software
 
16
#    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 
17
 
 
18
class PatchSyntax(Exception):
 
19
    def __init__(self, msg):
 
20
        Exception.__init__(self, msg)
 
21
 
 
22
 
 
23
class MalformedPatchHeader(PatchSyntax):
 
24
    def __init__(self, desc, line):
 
25
        self.desc = desc
 
26
        self.line = line
 
27
        msg = "Malformed patch header.  %s\n%r" % (self.desc, self.line)
 
28
        PatchSyntax.__init__(self, msg)
 
29
 
 
30
class MalformedHunkHeader(PatchSyntax):
 
31
    def __init__(self, desc, line):
 
32
        self.desc = desc
 
33
        self.line = line
 
34
        msg = "Malformed hunk header.  %s\n%r" % (self.desc, self.line)
 
35
        PatchSyntax.__init__(self, msg)
 
36
 
 
37
class MalformedLine(PatchSyntax):
 
38
    def __init__(self, desc, line):
 
39
        self.desc = desc
 
40
        self.line = line
 
41
        msg = "Malformed line.  %s\n%s" % (self.desc, self.line)
 
42
        PatchSyntax.__init__(self, msg)
 
43
 
 
44
def get_patch_names(iter_lines):
 
45
    try:
 
46
        line = iter_lines.next()
 
47
        if not line.startswith("--- "):
 
48
            raise MalformedPatchHeader("No orig name", line)
 
49
        else:
 
50
            orig_name = line[4:].rstrip("\n")
 
51
    except StopIteration:
 
52
        raise MalformedPatchHeader("No orig line", "")
 
53
    try:
 
54
        line = iter_lines.next()
 
55
        if not line.startswith("+++ "):
 
56
            raise PatchSyntax("No mod name")
 
57
        else:
 
58
            mod_name = line[4:].rstrip("\n")
 
59
    except StopIteration:
 
60
        raise MalformedPatchHeader("No mod line", "")
 
61
    return (orig_name, mod_name)
 
62
 
 
63
def parse_range(textrange):
 
64
    """Parse a patch range, handling the "1" special-case
 
65
 
 
66
    :param textrange: The text to parse
 
67
    :type textrange: str
 
68
    :return: the position and range, as a tuple
 
69
    :rtype: (int, int)
 
70
    """
 
71
    tmp = textrange.split(',')
 
72
    if len(tmp) == 1:
 
73
        pos = tmp[0]
 
74
        range = "1"
 
75
    else:
 
76
        (pos, range) = tmp
 
77
    pos = int(pos)
 
78
    range = int(range)
 
79
    return (pos, range)
 
80
 
 
81
 
 
82
def hunk_from_header(line):
 
83
    if not line.startswith("@@") or not line.endswith("@@\n") \
 
84
        or not len(line) > 4:
 
85
        raise MalformedHunkHeader("Does not start and end with @@.", line)
 
86
    try:
 
87
        (orig, mod) = line[3:-4].split(" ")
 
88
    except Exception, e:
 
89
        raise MalformedHunkHeader(str(e), line)
 
90
    if not orig.startswith('-') or not mod.startswith('+'):
 
91
        raise MalformedHunkHeader("Positions don't start with + or -.", line)
 
92
    try:
 
93
        (orig_pos, orig_range) = parse_range(orig[1:])
 
94
        (mod_pos, mod_range) = parse_range(mod[1:])
 
95
    except Exception, e:
 
96
        raise MalformedHunkHeader(str(e), line)
 
97
    if mod_range < 0 or orig_range < 0:
 
98
        raise MalformedHunkHeader("Hunk range is negative", line)
 
99
    return Hunk(orig_pos, orig_range, mod_pos, mod_range)
 
100
 
 
101
 
 
102
class HunkLine:
 
103
    def __init__(self, contents):
 
104
        self.contents = contents
 
105
 
 
106
    def get_str(self, leadchar):
 
107
        if self.contents == "\n" and leadchar == " " and False:
 
108
            return "\n"
 
109
        if not self.contents.endswith('\n'):
 
110
            terminator = '\n' + NO_NL
 
111
        else:
 
112
            terminator = ''
 
113
        return leadchar + self.contents + terminator
 
114
 
 
115
 
 
116
class ContextLine(HunkLine):
 
117
    def __init__(self, contents):
 
118
        HunkLine.__init__(self, contents)
 
119
 
 
120
    def __str__(self):
 
121
        return self.get_str(" ")
 
122
 
 
123
 
 
124
class InsertLine(HunkLine):
 
125
    def __init__(self, contents):
 
126
        HunkLine.__init__(self, contents)
 
127
 
 
128
    def __str__(self):
 
129
        return self.get_str("+")
 
130
 
 
131
 
 
132
class RemoveLine(HunkLine):
 
133
    def __init__(self, contents):
 
134
        HunkLine.__init__(self, contents)
 
135
 
 
136
    def __str__(self):
 
137
        return self.get_str("-")
 
138
 
 
139
NO_NL = '\\ No newline at end of file\n'
 
140
__pychecker__="no-returnvalues"
 
141
 
 
142
def parse_line(line):
 
143
    if line.startswith("\n"):
 
144
        return ContextLine(line)
 
145
    elif line.startswith(" "):
 
146
        return ContextLine(line[1:])
 
147
    elif line.startswith("+"):
 
148
        return InsertLine(line[1:])
 
149
    elif line.startswith("-"):
 
150
        return RemoveLine(line[1:])
 
151
    elif line == NO_NL:
 
152
        return NO_NL
 
153
    else:
 
154
        raise MalformedLine("Unknown line type", line)
 
155
__pychecker__=""
 
156
 
 
157
 
 
158
class Hunk:
 
159
    def __init__(self, orig_pos, orig_range, mod_pos, mod_range):
 
160
        self.orig_pos = orig_pos
 
161
        self.orig_range = orig_range
 
162
        self.mod_pos = mod_pos
 
163
        self.mod_range = mod_range
 
164
        self.lines = []
 
165
 
 
166
    def get_header(self):
 
167
        return "@@ -%s +%s @@\n" % (self.range_str(self.orig_pos, 
 
168
                                                   self.orig_range),
 
169
                                    self.range_str(self.mod_pos, 
 
170
                                                   self.mod_range))
 
171
 
 
172
    def range_str(self, pos, range):
 
173
        """Return a file range, special-casing for 1-line files.
 
174
 
 
175
        :param pos: The position in the file
 
176
        :type pos: int
 
177
        :range: The range in the file
 
178
        :type range: int
 
179
        :return: a string in the format 1,4 except when range == pos == 1
 
180
        """
 
181
        if range == 1:
 
182
            return "%i" % pos
 
183
        else:
 
184
            return "%i,%i" % (pos, range)
 
185
 
 
186
    def __str__(self):
 
187
        lines = [self.get_header()]
 
188
        for line in self.lines:
 
189
            lines.append(str(line))
 
190
        return "".join(lines)
 
191
 
 
192
    def shift_to_mod(self, pos):
 
193
        if pos < self.orig_pos-1:
 
194
            return 0
 
195
        elif pos > self.orig_pos+self.orig_range:
 
196
            return self.mod_range - self.orig_range
 
197
        else:
 
198
            return self.shift_to_mod_lines(pos)
 
199
 
 
200
    def shift_to_mod_lines(self, pos):
 
201
        assert (pos >= self.orig_pos-1 and pos <= self.orig_pos+self.orig_range)
 
202
        position = self.orig_pos-1
 
203
        shift = 0
 
204
        for line in self.lines:
 
205
            if isinstance(line, InsertLine):
 
206
                shift += 1
 
207
            elif isinstance(line, RemoveLine):
 
208
                if position == pos:
 
209
                    return None
 
210
                shift -= 1
 
211
                position += 1
 
212
            elif isinstance(line, ContextLine):
 
213
                position += 1
 
214
            if position > pos:
 
215
                break
 
216
        return shift
 
217
 
 
218
def iter_hunks(iter_lines):
 
219
    hunk = None
 
220
    for line in iter_lines:
 
221
        if line == "\n":
 
222
            if hunk is not None:
 
223
                yield hunk
 
224
                hunk = None
 
225
            continue
 
226
        if hunk is not None:
 
227
            yield hunk
 
228
        hunk = hunk_from_header(line)
 
229
        orig_size = 0
 
230
        mod_size = 0
 
231
        while orig_size < hunk.orig_range or mod_size < hunk.mod_range:
 
232
            hunk_line = parse_line(iter_lines.next())
 
233
            hunk.lines.append(hunk_line)
 
234
            if isinstance(hunk_line, (RemoveLine, ContextLine)):
 
235
                orig_size += 1
 
236
            if isinstance(hunk_line, (InsertLine, ContextLine)):
 
237
                mod_size += 1
 
238
    if hunk is not None:
 
239
        yield hunk
 
240
 
 
241
class Patch:
 
242
    def __init__(self, oldname, newname):
 
243
        self.oldname = oldname
 
244
        self.newname = newname
 
245
        self.hunks = []
 
246
 
 
247
    def __str__(self):
 
248
        ret = self.get_header() 
 
249
        ret += "".join([str(h) for h in self.hunks])
 
250
        return ret
 
251
 
 
252
    def get_header(self):
 
253
        return "--- %s\n+++ %s\n" % (self.oldname, self.newname)
 
254
 
 
255
    def stats_str(self):
 
256
        """Return a string of patch statistics"""
 
257
        removes = 0
 
258
        inserts = 0
 
259
        for hunk in self.hunks:
 
260
            for line in hunk.lines:
 
261
                if isinstance(line, InsertLine):
 
262
                     inserts+=1;
 
263
                elif isinstance(line, RemoveLine):
 
264
                     removes+=1;
 
265
        return "%i inserts, %i removes in %i hunks" % \
 
266
            (inserts, removes, len(self.hunks))
 
267
 
 
268
    def pos_in_mod(self, position):
 
269
        newpos = position
 
270
        for hunk in self.hunks:
 
271
            shift = hunk.shift_to_mod(position)
 
272
            if shift is None:
 
273
                return None
 
274
            newpos += shift
 
275
        return newpos
 
276
            
 
277
    def iter_inserted(self):
 
278
        """Iteraties through inserted lines
 
279
        
 
280
        :return: Pair of line number, line
 
281
        :rtype: iterator of (int, InsertLine)
 
282
        """
 
283
        for hunk in self.hunks:
 
284
            pos = hunk.mod_pos - 1;
 
285
            for line in hunk.lines:
 
286
                if isinstance(line, InsertLine):
 
287
                    yield (pos, line)
 
288
                    pos += 1
 
289
                if isinstance(line, ContextLine):
 
290
                    pos += 1
 
291
 
 
292
def parse_patch(iter_lines):
 
293
    (orig_name, mod_name) = get_patch_names(iter_lines)
 
294
    patch = Patch(orig_name, mod_name)
 
295
    for hunk in iter_hunks(iter_lines):
 
296
        patch.hunks.append(hunk)
 
297
    return patch
 
298
 
 
299
 
 
300
def iter_file_patch(iter_lines):
 
301
    saved_lines = []
 
302
    for line in iter_lines:
 
303
        if line.startswith('*** '):
 
304
            continue
 
305
        if line.startswith('#'):
 
306
            continue
 
307
        if line.startswith('==='):
 
308
            continue
 
309
        elif line.startswith('--- '):
 
310
            if len(saved_lines) > 0:
 
311
                yield saved_lines
 
312
            saved_lines = []
 
313
        saved_lines.append(line)
 
314
    if len(saved_lines) > 0:
 
315
        yield saved_lines
 
316
 
 
317
 
 
318
def iter_lines_handle_nl(iter_lines):
 
319
    """
 
320
    Iterates through lines, ensuring that lines that originally had no
 
321
    terminating \n are produced without one.  This transformation may be
 
322
    applied at any point up until hunk line parsing, and is safe to apply
 
323
    repeatedly.
 
324
    """
 
325
    last_line = None
 
326
    for line in iter_lines:
 
327
        if line == NO_NL:
 
328
            assert last_line.endswith('\n')
 
329
            last_line = last_line[:-1]
 
330
            line = None
 
331
        if last_line is not None:
 
332
            yield last_line
 
333
        last_line = line
 
334
    if last_line is not None:
 
335
        yield last_line
 
336
 
 
337
 
 
338
def parse_patches(iter_lines):
 
339
    iter_lines = iter_lines_handle_nl(iter_lines)
 
340
    return [parse_patch(f.__iter__()) for f in iter_file_patch(iter_lines)]
 
341
 
 
342
 
 
343
def difference_index(atext, btext):
 
344
    """Find the indext of the first character that differs betweeen two texts
 
345
 
 
346
    :param atext: The first text
 
347
    :type atext: str
 
348
    :param btext: The second text
 
349
    :type str: str
 
350
    :return: The index, or None if there are no differences within the range
 
351
    :rtype: int or NoneType
 
352
    """
 
353
    length = len(atext)
 
354
    if len(btext) < length:
 
355
        length = len(btext)
 
356
    for i in range(length):
 
357
        if atext[i] != btext[i]:
 
358
            return i;
 
359
    return None
 
360
 
 
361
class PatchConflict(Exception):
 
362
    def __init__(self, line_no, orig_line, patch_line):
 
363
        orig = orig_line.rstrip('\n')
 
364
        patch = str(patch_line).rstrip('\n')
 
365
        msg = 'Text contents mismatch at line %d.  Original has "%s",'\
 
366
            ' but patch says it should be "%s"' % (line_no, orig, patch)
 
367
        Exception.__init__(self, msg)
 
368
 
 
369
 
 
370
def iter_patched(orig_lines, patch_lines):
 
371
    """Iterate through a series of lines with a patch applied.
 
372
    This handles a single file, and does exact, not fuzzy patching.
 
373
    """
 
374
    if orig_lines is not None:
 
375
        orig_lines = orig_lines.__iter__()
 
376
    seen_patch = []
 
377
    patch_lines = iter_lines_handle_nl(patch_lines.__iter__())
 
378
    get_patch_names(patch_lines)
 
379
    line_no = 1
 
380
    for hunk in iter_hunks(patch_lines):
 
381
        while line_no < hunk.orig_pos:
 
382
            orig_line = orig_lines.next()
 
383
            yield orig_line
 
384
            line_no += 1
 
385
        for hunk_line in hunk.lines:
 
386
            seen_patch.append(str(hunk_line))
 
387
            if isinstance(hunk_line, InsertLine):
 
388
                yield hunk_line.contents
 
389
            elif isinstance(hunk_line, (ContextLine, RemoveLine)):
 
390
                orig_line = orig_lines.next()
 
391
                if orig_line != hunk_line.contents:
 
392
                    raise PatchConflict(line_no, orig_line, "".join(seen_patch))
 
393
                if isinstance(hunk_line, ContextLine):
 
394
                    yield orig_line
 
395
                else:
 
396
                    assert isinstance(hunk_line, RemoveLine)
 
397
                line_no += 1
 
398
                    
 
399
import unittest
 
400
import os.path
 
401
class PatchesTester(unittest.TestCase):
 
402
    def datafile(self, filename):
 
403
        data_path = os.path.join(os.path.dirname(__file__), "testdata", 
 
404
                                 filename)
 
405
        return file(data_path, "rb")
 
406
 
 
407
    def testValidPatchHeader(self):
 
408
        """Parse a valid patch header"""
 
409
        lines = "--- orig/commands.py\n+++ mod/dommands.py\n".split('\n')
 
410
        (orig, mod) = get_patch_names(lines.__iter__())
 
411
        assert(orig == "orig/commands.py")
 
412
        assert(mod == "mod/dommands.py")
 
413
 
 
414
    def testInvalidPatchHeader(self):
 
415
        """Parse an invalid patch header"""
 
416
        lines = "-- orig/commands.py\n+++ mod/dommands.py".split('\n')
 
417
        self.assertRaises(MalformedPatchHeader, get_patch_names,
 
418
                          lines.__iter__())
 
419
 
 
420
    def testValidHunkHeader(self):
 
421
        """Parse a valid hunk header"""
 
422
        header = "@@ -34,11 +50,6 @@\n"
 
423
        hunk = hunk_from_header(header);
 
424
        assert (hunk.orig_pos == 34)
 
425
        assert (hunk.orig_range == 11)
 
426
        assert (hunk.mod_pos == 50)
 
427
        assert (hunk.mod_range == 6)
 
428
        assert (str(hunk) == header)
 
429
 
 
430
    def testValidHunkHeader2(self):
 
431
        """Parse a tricky, valid hunk header"""
 
432
        header = "@@ -1 +0,0 @@\n"
 
433
        hunk = hunk_from_header(header);
 
434
        assert (hunk.orig_pos == 1)
 
435
        assert (hunk.orig_range == 1)
 
436
        assert (hunk.mod_pos == 0)
 
437
        assert (hunk.mod_range == 0)
 
438
        assert (str(hunk) == header)
 
439
 
 
440
    def makeMalformed(self, header):
 
441
        self.assertRaises(MalformedHunkHeader, hunk_from_header, header)
 
442
 
 
443
    def testInvalidHeader(self):
 
444
        """Parse an invalid hunk header"""
 
445
        self.makeMalformed(" -34,11 +50,6 \n")
 
446
        self.makeMalformed("@@ +50,6 -34,11 @@\n")
 
447
        self.makeMalformed("@@ -34,11 +50,6 @@")
 
448
        self.makeMalformed("@@ -34.5,11 +50,6 @@\n")
 
449
        self.makeMalformed("@@-34,11 +50,6@@\n")
 
450
        self.makeMalformed("@@ 34,11 50,6 @@\n")
 
451
        self.makeMalformed("@@ -34,11 @@\n")
 
452
        self.makeMalformed("@@ -34,11 +50,6.5 @@\n")
 
453
        self.makeMalformed("@@ -34,11 +50,-6 @@\n")
 
454
 
 
455
    def lineThing(self,text, type):
 
456
        line = parse_line(text)
 
457
        assert(isinstance(line, type))
 
458
        assert(str(line)==text)
 
459
 
 
460
    def makeMalformedLine(self, text):
 
461
        self.assertRaises(MalformedLine, parse_line, text)
 
462
 
 
463
    def testValidLine(self):
 
464
        """Parse a valid hunk line"""
 
465
        self.lineThing(" hello\n", ContextLine)
 
466
        self.lineThing("+hello\n", InsertLine)
 
467
        self.lineThing("-hello\n", RemoveLine)
 
468
    
 
469
    def testMalformedLine(self):
 
470
        """Parse invalid valid hunk lines"""
 
471
        self.makeMalformedLine("hello\n")
 
472
    
 
473
    def compare_parsed(self, patchtext):
 
474
        lines = patchtext.splitlines(True)
 
475
        patch = parse_patch(lines.__iter__())
 
476
        pstr = str(patch)
 
477
        i = difference_index(patchtext, pstr)
 
478
        if i is not None:
 
479
            print "%i: \"%s\" != \"%s\"" % (i, patchtext[i], pstr[i])
 
480
        self.assertEqual (patchtext, str(patch))
 
481
 
 
482
    def testAll(self):
 
483
        """Test parsing a whole patch"""
 
484
        patchtext = """--- orig/commands.py
 
485
+++ mod/commands.py
 
486
@@ -1337,7 +1337,8 @@
 
487
 
 
488
     def set_title(self, command=None):
 
489
         try:
 
490
-            version = self.tree.tree_version.nonarch
 
491
+            version = pylon.alias_or_version(self.tree.tree_version, self.tree,
 
492
+                                             full=False)
 
493
         except:
 
494
             version = "[no version]"
 
495
         if command is None:
 
496
@@ -1983,7 +1984,11 @@
 
497
                                          version)
 
498
         if len(new_merges) > 0:
 
499
             if cmdutil.prompt("Log for merge"):
 
500
-                mergestuff = cmdutil.log_for_merge(tree, comp_version)
 
501
+                if cmdutil.prompt("changelog for merge"):
 
502
+                    mergestuff = "Patches applied:\\n"
 
503
+                    mergestuff += pylon.changelog_for_merge(new_merges)
 
504
+                else:
 
505
+                    mergestuff = cmdutil.log_for_merge(tree, comp_version)
 
506
                 log.description += mergestuff
 
507
         log.save()
 
508
     try:
 
509
"""
 
510
        self.compare_parsed(patchtext)
 
511
 
 
512
    def testInit(self):
 
513
        """Handle patches missing half the position, range tuple"""
 
514
        patchtext = \
 
515
"""--- orig/__init__.py
 
516
+++ mod/__init__.py
 
517
@@ -1 +1,2 @@
 
518
 __docformat__ = "restructuredtext en"
 
519
+__doc__ = An alternate Arch commandline interface
 
520
"""
 
521
        self.compare_parsed(patchtext)
 
522
        
 
523
 
 
524
 
 
525
    def testLineLookup(self):
 
526
        import sys
 
527
        """Make sure we can accurately look up mod line from orig"""
 
528
        patch = parse_patch(self.datafile("diff"))
 
529
        orig = list(self.datafile("orig"))
 
530
        mod = list(self.datafile("mod"))
 
531
        removals = []
 
532
        for i in range(len(orig)):
 
533
            mod_pos = patch.pos_in_mod(i)
 
534
            if mod_pos is None:
 
535
                removals.append(orig[i])
 
536
                continue
 
537
            assert(mod[mod_pos]==orig[i])
 
538
        rem_iter = removals.__iter__()
 
539
        for hunk in patch.hunks:
 
540
            for line in hunk.lines:
 
541
                if isinstance(line, RemoveLine):
 
542
                    next = rem_iter.next()
 
543
                    if line.contents != next:
 
544
                        sys.stdout.write(" orig:%spatch:%s" % (next,
 
545
                                         line.contents))
 
546
                    assert(line.contents == next)
 
547
        self.assertRaises(StopIteration, rem_iter.next)
 
548
 
 
549
    def testFirstLineRenumber(self):
 
550
        """Make sure we handle lines at the beginning of the hunk"""
 
551
        patch = parse_patch(self.datafile("insert_top.patch"))
 
552
        assert (patch.pos_in_mod(0)==1)
 
553
 
 
554
def test():
 
555
    patchesTestSuite = unittest.makeSuite(PatchesTester,'test')
 
556
    runner = unittest.TextTestRunner(verbosity=0)
 
557
    return runner.run(patchesTestSuite)
 
558
    
 
559
 
 
560
if __name__ == "__main__":
 
561
    test()
 
562
# arch-tag: d1541a25-eac5-4de9-a476-08a7cecd5683