~abentley/bzrtools/bzrtools.dev

« back to all changes in this revision

Viewing changes to patches.py

  • Committer: Aaron Bentley
  • Date: 2005-06-15 15:25:13 UTC
  • Revision ID: abentley@panoramicfeedback.com-20050615152512-e2afe3f794604a12
Added Michael Ellerman's shelf/unshelf

Show diffs side-by-side

added added

removed removed

Lines of Context:
14
14
#    You should have received a copy of the GNU General Public License
15
15
#    along with this program; if not, write to the Free Software
16
16
#    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
17
 
 
 
17
import sys
 
18
import progress
18
19
class PatchSyntax(Exception):
19
20
    def __init__(self, msg):
20
21
        Exception.__init__(self, msg)
24
25
    def __init__(self, desc, line):
25
26
        self.desc = desc
26
27
        self.line = line
27
 
        msg = "Malformed patch header.  %s\n%r" % (self.desc, self.line)
 
28
        msg = "Malformed patch header.  %s\n%s" % (self.desc, self.line)
28
29
        PatchSyntax.__init__(self, msg)
29
30
 
30
31
class MalformedHunkHeader(PatchSyntax):
31
32
    def __init__(self, desc, line):
32
33
        self.desc = desc
33
34
        self.line = line
34
 
        msg = "Malformed hunk header.  %s\n%r" % (self.desc, self.line)
 
35
        msg = "Malformed hunk header.  %s\n%s" % (self.desc, self.line)
35
36
        PatchSyntax.__init__(self, msg)
36
37
 
37
38
class MalformedLine(PatchSyntax):
106
107
    def get_str(self, leadchar):
107
108
        if self.contents == "\n" and leadchar == " " and False:
108
109
            return "\n"
109
 
        if not self.contents.endswith('\n'):
110
 
            terminator = '\n' + NO_NL
111
 
        else:
112
 
            terminator = ''
113
 
        return leadchar + self.contents + terminator
114
 
 
 
110
        return leadchar + self.contents
115
111
 
116
112
class ContextLine(HunkLine):
117
113
    def __init__(self, contents):
136
132
    def __str__(self):
137
133
        return self.get_str("-")
138
134
 
139
 
NO_NL = '\\ No newline at end of file\n'
140
135
__pychecker__="no-returnvalues"
141
 
 
142
136
def parse_line(line):
143
137
    if line.startswith("\n"):
144
138
        return ContextLine(line)
148
142
        return InsertLine(line[1:])
149
143
    elif line.startswith("-"):
150
144
        return RemoveLine(line[1:])
151
 
    elif line == NO_NL:
152
 
        return NO_NL
153
145
    else:
154
146
        raise MalformedLine("Unknown line type", line)
155
147
__pychecker__=""
218
210
def iter_hunks(iter_lines):
219
211
    hunk = None
220
212
    for line in iter_lines:
221
 
        if line == "\n":
 
213
        if line.startswith("@@"):
222
214
            if hunk is not None:
223
215
                yield hunk
224
 
                hunk = None
225
 
            continue
226
 
        if hunk is not None:
227
 
            yield hunk
228
 
        hunk = hunk_from_header(line)
229
 
        orig_size = 0
230
 
        mod_size = 0
231
 
        while orig_size < hunk.orig_range or mod_size < hunk.mod_range:
232
 
            hunk_line = parse_line(iter_lines.next())
233
 
            hunk.lines.append(hunk_line)
234
 
            if isinstance(hunk_line, (RemoveLine, ContextLine)):
235
 
                orig_size += 1
236
 
            if isinstance(hunk_line, (InsertLine, ContextLine)):
237
 
                mod_size += 1
 
216
            hunk = hunk_from_header(line)
 
217
        else:
 
218
            hunk.lines.append(parse_line(line))
 
219
 
238
220
    if hunk is not None:
239
221
        yield hunk
240
222
 
245
227
        self.hunks = []
246
228
 
247
229
    def __str__(self):
248
 
        ret = self.get_header() 
 
230
        ret =  "--- %s\n+++ %s\n" % (self.oldname, self.newname) 
249
231
        ret += "".join([str(h) for h in self.hunks])
250
232
        return ret
251
233
 
252
 
    def get_header(self):
253
 
        return "--- %s\n+++ %s\n" % (self.oldname, self.newname)
254
 
 
255
234
    def stats_str(self):
256
235
        """Return a string of patch statistics"""
257
236
        removes = 0
302
281
    for line in iter_lines:
303
282
        if line.startswith('*** '):
304
283
            continue
305
 
        if line.startswith('==='):
306
 
            continue
307
284
        elif line.startswith('--- '):
308
285
            if len(saved_lines) > 0:
309
286
                yield saved_lines
313
290
        yield saved_lines
314
291
 
315
292
 
316
 
def iter_lines_handle_nl(iter_lines):
317
 
    """
318
 
    Iterates through lines, ensuring that lines that originally had no
319
 
    terminating \n are produced without one.  This transformation may be
320
 
    applied at any point up until hunk line parsing, and is safe to apply
321
 
    repeatedly.
322
 
    """
323
 
    last_line = None
324
 
    for line in iter_lines:
325
 
        if line == NO_NL:
326
 
            assert last_line.endswith('\n')
327
 
            last_line = last_line[:-1]
328
 
            line = None
329
 
        if last_line is not None:
330
 
            yield last_line
331
 
        last_line = line
332
 
    if last_line is not None:
333
 
        yield last_line
334
 
 
335
 
 
336
293
def parse_patches(iter_lines):
337
 
    iter_lines = iter_lines_handle_nl(iter_lines)
338
294
    return [parse_patch(f.__iter__()) for f in iter_file_patch(iter_lines)]
339
295
 
340
296
 
 
297
class AnnotateLine:
 
298
    """A line associated with the log that produced it"""
 
299
    def __init__(self, text, log=None):
 
300
        self.text = text
 
301
        self.log = log
 
302
 
 
303
class CantGetRevisionData(Exception):
 
304
    def __init__(self, revision):
 
305
        Exception.__init__(self, "Can't get data for revision %s" % revision)
 
306
        
 
307
def annotate_file2(file_lines, anno_iter):
 
308
    for result in iter_annotate_file(file_lines, anno_iter):
 
309
        pass
 
310
    return result
 
311
 
 
312
        
 
313
def iter_annotate_file(file_lines, anno_iter):
 
314
    lines = [AnnotateLine(f) for f in file_lines]
 
315
    patches = []
 
316
    try:
 
317
        for result in anno_iter:
 
318
            if isinstance(result, progress.Progress):
 
319
                yield result
 
320
                continue
 
321
            log, iter_inserted, patch = result
 
322
            for (num, line) in iter_inserted:
 
323
                old_num = num
 
324
 
 
325
                for cur_patch in patches:
 
326
                    num = cur_patch.pos_in_mod(num)
 
327
                    if num == None: 
 
328
                        break
 
329
 
 
330
                if num >= len(lines):
 
331
                    continue
 
332
                if num is not None and lines[num].log is None:
 
333
                    lines[num].log = log
 
334
            patches=[patch]+patches
 
335
    except CantGetRevisionData:
 
336
        pass
 
337
    yield lines
 
338
 
 
339
 
341
340
def difference_index(atext, btext):
342
341
    """Find the indext of the first character that differs betweeen two texts
343
342
 
356
355
            return i;
357
356
    return None
358
357
 
359
 
class PatchConflict(Exception):
360
 
    def __init__(self, line_no, orig_line, patch_line):
361
 
        orig = orig_line.rstrip('\n')
362
 
        patch = str(patch_line).rstrip('\n')
363
 
        msg = 'Text contents mismatch at line %d.  Original has "%s",'\
364
 
            ' but patch says it should be "%s"' % (line_no, orig, patch)
365
 
        Exception.__init__(self, msg)
366
 
 
367
 
 
368
 
def iter_patched(orig_lines, patch_lines):
369
 
    """Iterate through a series of lines with a patch applied.
370
 
    This handles a single file, and does exact, not fuzzy patching.
371
 
    """
372
 
    if orig_lines is not None:
373
 
        orig_lines = orig_lines.__iter__()
374
 
    seen_patch = []
375
 
    patch_lines = iter_lines_handle_nl(patch_lines.__iter__())
376
 
    get_patch_names(patch_lines)
377
 
    line_no = 1
378
 
    for hunk in iter_hunks(patch_lines):
379
 
        while line_no < hunk.orig_pos:
380
 
            orig_line = orig_lines.next()
381
 
            yield orig_line
382
 
            line_no += 1
383
 
        for hunk_line in hunk.lines:
384
 
            seen_patch.append(str(hunk_line))
385
 
            if isinstance(hunk_line, InsertLine):
386
 
                yield hunk_line.contents
387
 
            elif isinstance(hunk_line, (ContextLine, RemoveLine)):
388
 
                orig_line = orig_lines.next()
389
 
                if orig_line != hunk_line.contents:
390
 
                    raise PatchConflict(line_no, orig_line, "".join(seen_patch))
391
 
                if isinstance(hunk_line, ContextLine):
392
 
                    yield orig_line
393
 
                else:
394
 
                    assert isinstance(hunk_line, RemoveLine)
395
 
                line_no += 1
396
 
                    
397
 
import unittest
398
 
import os.path
399
 
class PatchesTester(unittest.TestCase):
400
 
    def datafile(self, filename):
401
 
        data_path = os.path.join(os.path.dirname(__file__), "testdata", 
402
 
                                 filename)
403
 
        return file(data_path, "rb")
404
 
 
405
 
    def testValidPatchHeader(self):
406
 
        """Parse a valid patch header"""
407
 
        lines = "--- orig/commands.py\n+++ mod/dommands.py\n".split('\n')
408
 
        (orig, mod) = get_patch_names(lines.__iter__())
409
 
        assert(orig == "orig/commands.py")
410
 
        assert(mod == "mod/dommands.py")
411
 
 
412
 
    def testInvalidPatchHeader(self):
413
 
        """Parse an invalid patch header"""
414
 
        lines = "-- orig/commands.py\n+++ mod/dommands.py".split('\n')
415
 
        self.assertRaises(MalformedPatchHeader, get_patch_names,
416
 
                          lines.__iter__())
417
 
 
418
 
    def testValidHunkHeader(self):
419
 
        """Parse a valid hunk header"""
420
 
        header = "@@ -34,11 +50,6 @@\n"
421
 
        hunk = hunk_from_header(header);
422
 
        assert (hunk.orig_pos == 34)
423
 
        assert (hunk.orig_range == 11)
424
 
        assert (hunk.mod_pos == 50)
425
 
        assert (hunk.mod_range == 6)
426
 
        assert (str(hunk) == header)
427
 
 
428
 
    def testValidHunkHeader2(self):
429
 
        """Parse a tricky, valid hunk header"""
430
 
        header = "@@ -1 +0,0 @@\n"
431
 
        hunk = hunk_from_header(header);
432
 
        assert (hunk.orig_pos == 1)
433
 
        assert (hunk.orig_range == 1)
434
 
        assert (hunk.mod_pos == 0)
435
 
        assert (hunk.mod_range == 0)
436
 
        assert (str(hunk) == header)
437
 
 
438
 
    def makeMalformed(self, header):
439
 
        self.assertRaises(MalformedHunkHeader, hunk_from_header, header)
440
 
 
441
 
    def testInvalidHeader(self):
442
 
        """Parse an invalid hunk header"""
443
 
        self.makeMalformed(" -34,11 +50,6 \n")
444
 
        self.makeMalformed("@@ +50,6 -34,11 @@\n")
445
 
        self.makeMalformed("@@ -34,11 +50,6 @@")
446
 
        self.makeMalformed("@@ -34.5,11 +50,6 @@\n")
447
 
        self.makeMalformed("@@-34,11 +50,6@@\n")
448
 
        self.makeMalformed("@@ 34,11 50,6 @@\n")
449
 
        self.makeMalformed("@@ -34,11 @@\n")
450
 
        self.makeMalformed("@@ -34,11 +50,6.5 @@\n")
451
 
        self.makeMalformed("@@ -34,11 +50,-6 @@\n")
452
 
 
453
 
    def lineThing(self,text, type):
454
 
        line = parse_line(text)
455
 
        assert(isinstance(line, type))
456
 
        assert(str(line)==text)
457
 
 
458
 
    def makeMalformedLine(self, text):
459
 
        self.assertRaises(MalformedLine, parse_line, text)
460
 
 
461
 
    def testValidLine(self):
462
 
        """Parse a valid hunk line"""
463
 
        self.lineThing(" hello\n", ContextLine)
464
 
        self.lineThing("+hello\n", InsertLine)
465
 
        self.lineThing("-hello\n", RemoveLine)
466
 
    
467
 
    def testMalformedLine(self):
468
 
        """Parse invalid valid hunk lines"""
469
 
        self.makeMalformedLine("hello\n")
470
 
    
471
 
    def compare_parsed(self, patchtext):
472
 
        lines = patchtext.splitlines(True)
473
 
        patch = parse_patch(lines.__iter__())
474
 
        pstr = str(patch)
475
 
        i = difference_index(patchtext, pstr)
476
 
        if i is not None:
477
 
            print "%i: \"%s\" != \"%s\"" % (i, patchtext[i], pstr[i])
478
 
        self.assertEqual (patchtext, str(patch))
479
 
 
480
 
    def testAll(self):
481
 
        """Test parsing a whole patch"""
482
 
        patchtext = """--- orig/commands.py
 
358
 
 
359
def test():
 
360
    import unittest
 
361
    class PatchesTester(unittest.TestCase):
 
362
        def testValidPatchHeader(self):
 
363
            """Parse a valid patch header"""
 
364
            lines = "--- orig/commands.py\n+++ mod/dommands.py\n".split('\n')
 
365
            (orig, mod) = get_patch_names(lines.__iter__())
 
366
            assert(orig == "orig/commands.py")
 
367
            assert(mod == "mod/dommands.py")
 
368
 
 
369
        def testInvalidPatchHeader(self):
 
370
            """Parse an invalid patch header"""
 
371
            lines = "-- orig/commands.py\n+++ mod/dommands.py".split('\n')
 
372
            self.assertRaises(MalformedPatchHeader, get_patch_names,
 
373
                              lines.__iter__())
 
374
 
 
375
        def testValidHunkHeader(self):
 
376
            """Parse a valid hunk header"""
 
377
            header = "@@ -34,11 +50,6 @@\n"
 
378
            hunk = hunk_from_header(header);
 
379
            assert (hunk.orig_pos == 34)
 
380
            assert (hunk.orig_range == 11)
 
381
            assert (hunk.mod_pos == 50)
 
382
            assert (hunk.mod_range == 6)
 
383
            assert (str(hunk) == header)
 
384
 
 
385
        def testValidHunkHeader2(self):
 
386
            """Parse a tricky, valid hunk header"""
 
387
            header = "@@ -1 +0,0 @@\n"
 
388
            hunk = hunk_from_header(header);
 
389
            assert (hunk.orig_pos == 1)
 
390
            assert (hunk.orig_range == 1)
 
391
            assert (hunk.mod_pos == 0)
 
392
            assert (hunk.mod_range == 0)
 
393
            assert (str(hunk) == header)
 
394
 
 
395
        def makeMalformed(self, header):
 
396
            self.assertRaises(MalformedHunkHeader, hunk_from_header, header)
 
397
 
 
398
        def testInvalidHeader(self):
 
399
            """Parse an invalid hunk header"""
 
400
            self.makeMalformed(" -34,11 +50,6 \n")
 
401
            self.makeMalformed("@@ +50,6 -34,11 @@\n")
 
402
            self.makeMalformed("@@ -34,11 +50,6 @@")
 
403
            self.makeMalformed("@@ -34.5,11 +50,6 @@\n")
 
404
            self.makeMalformed("@@-34,11 +50,6@@\n")
 
405
            self.makeMalformed("@@ 34,11 50,6 @@\n")
 
406
            self.makeMalformed("@@ -34,11 @@\n")
 
407
            self.makeMalformed("@@ -34,11 +50,6.5 @@\n")
 
408
            self.makeMalformed("@@ -34,11 +50,-6 @@\n")
 
409
 
 
410
        def lineThing(self,text, type):
 
411
            line = parse_line(text)
 
412
            assert(isinstance(line, type))
 
413
            assert(str(line)==text)
 
414
 
 
415
        def makeMalformedLine(self, text):
 
416
            self.assertRaises(MalformedLine, parse_line, text)
 
417
 
 
418
        def testValidLine(self):
 
419
            """Parse a valid hunk line"""
 
420
            self.lineThing(" hello\n", ContextLine)
 
421
            self.lineThing("+hello\n", InsertLine)
 
422
            self.lineThing("-hello\n", RemoveLine)
 
423
        
 
424
        def testMalformedLine(self):
 
425
            """Parse invalid valid hunk lines"""
 
426
            self.makeMalformedLine("hello\n")
 
427
        
 
428
        def compare_parsed(self, patchtext):
 
429
            lines = patchtext.splitlines(True)
 
430
            patch = parse_patch(lines.__iter__())
 
431
            pstr = str(patch)
 
432
            i = difference_index(patchtext, pstr)
 
433
            if i is not None:
 
434
                print "%i: \"%s\" != \"%s\"" % (i, patchtext[i], pstr[i])
 
435
            assert (patchtext == str(patch))
 
436
 
 
437
        def testAll(self):
 
438
            """Test parsing a whole patch"""
 
439
            patchtext = """--- orig/commands.py
483
440
+++ mod/commands.py
484
441
@@ -1337,7 +1337,8 @@
485
442
 
505
462
         log.save()
506
463
     try:
507
464
"""
508
 
        self.compare_parsed(patchtext)
 
465
            self.compare_parsed(patchtext)
509
466
 
510
 
    def testInit(self):
511
 
        """Handle patches missing half the position, range tuple"""
512
 
        patchtext = \
 
467
        def testInit(self):
 
468
            """Handle patches missing half the position, range tuple"""
 
469
            patchtext = \
513
470
"""--- orig/__init__.py
514
471
+++ mod/__init__.py
515
472
@@ -1 +1,2 @@
516
473
 __docformat__ = "restructuredtext en"
517
 
+__doc__ = An alternate Arch commandline interface
518
 
"""
519
 
        self.compare_parsed(patchtext)
520
 
        
521
 
 
522
 
 
523
 
    def testLineLookup(self):
524
 
        import sys
525
 
        """Make sure we can accurately look up mod line from orig"""
526
 
        patch = parse_patch(self.datafile("diff"))
527
 
        orig = list(self.datafile("orig"))
528
 
        mod = list(self.datafile("mod"))
529
 
        removals = []
530
 
        for i in range(len(orig)):
531
 
            mod_pos = patch.pos_in_mod(i)
532
 
            if mod_pos is None:
533
 
                removals.append(orig[i])
534
 
                continue
535
 
            assert(mod[mod_pos]==orig[i])
536
 
        rem_iter = removals.__iter__()
537
 
        for hunk in patch.hunks:
538
 
            for line in hunk.lines:
539
 
                if isinstance(line, RemoveLine):
540
 
                    next = rem_iter.next()
541
 
                    if line.contents != next:
542
 
                        sys.stdout.write(" orig:%spatch:%s" % (next,
543
 
                                         line.contents))
544
 
                    assert(line.contents == next)
545
 
        self.assertRaises(StopIteration, rem_iter.next)
546
 
 
547
 
    def testFirstLineRenumber(self):
548
 
        """Make sure we handle lines at the beginning of the hunk"""
549
 
        patch = parse_patch(self.datafile("insert_top.patch"))
550
 
        assert (patch.pos_in_mod(0)==1)
551
 
 
552
 
def test():
 
474
+__doc__ = An alternate Arch commandline interface"""
 
475
            self.compare_parsed(patchtext)
 
476
            
 
477
 
 
478
 
 
479
        def testLineLookup(self):
 
480
            """Make sure we can accurately look up mod line from orig"""
 
481
            patch = parse_patch(open("testdata/diff"))
 
482
            orig = list(open("testdata/orig"))
 
483
            mod = list(open("testdata/mod"))
 
484
            removals = []
 
485
            for i in range(len(orig)):
 
486
                mod_pos = patch.pos_in_mod(i)
 
487
                if mod_pos is None:
 
488
                    removals.append(orig[i])
 
489
                    continue
 
490
                assert(mod[mod_pos]==orig[i])
 
491
            rem_iter = removals.__iter__()
 
492
            for hunk in patch.hunks:
 
493
                for line in hunk.lines:
 
494
                    if isinstance(line, RemoveLine):
 
495
                        next = rem_iter.next()
 
496
                        if line.contents != next:
 
497
                            sys.stdout.write(" orig:%spatch:%s" % (next,
 
498
                                             line.contents))
 
499
                        assert(line.contents == next)
 
500
            self.assertRaises(StopIteration, rem_iter.next)
 
501
 
 
502
        def testFirstLineRenumber(self):
 
503
            """Make sure we handle lines at the beginning of the hunk"""
 
504
            patch = parse_patch(open("testdata/insert_top.patch"))
 
505
            assert (patch.pos_in_mod(0)==1)
 
506
    
 
507
            
553
508
    patchesTestSuite = unittest.makeSuite(PatchesTester,'test')
554
509
    runner = unittest.TextTestRunner(verbosity=0)
555
510
    return runner.run(patchesTestSuite)