~abentley/bzrtools/bzrtools.dev

« back to all changes in this revision

Viewing changes to patches.py

  • Committer: Aaron Bentley
  • Date: 2005-09-22 23:30:34 UTC
  • Revision ID: aaron.bentley@utoronto.ca-20050922233034-0bb63c8bef90f19a
Updated NEWS

Show diffs side-by-side

added added

removed removed

Lines of Context:
14
14
#    You should have received a copy of the GNU General Public License
15
15
#    along with this program; if not, write to the Free Software
16
16
#    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 
17
 
17
18
class PatchSyntax(Exception):
    """Base class for errors found while parsing patch text.

    The message is handed unchanged to ``Exception``.
    """

    def __init__(self, msg):
        # Explicit base call (not super()) so this also works when
        # Exception is a classic class, as on old Python 2 releases.
        Exception.__init__(self, msg)
23
24
    def __init__(self, desc, line):
24
25
        self.desc = desc
25
26
        self.line = line
26
 
        msg = "Malformed patch header.  %s\n%s" % (self.desc, self.line)
 
27
        msg = "Malformed patch header.  %s\n%r" % (self.desc, self.line)
27
28
        PatchSyntax.__init__(self, msg)
28
29
 
29
30
class MalformedHunkHeader(PatchSyntax):
    """Raised when a hunk header line cannot be parsed.

    ``desc`` describes the problem and ``line`` is the offending text.  Both
    are stored as attributes and combined into the exception message.
    """

    def __init__(self, desc, line):
        self.desc, self.line = desc, line
        # %r keeps whitespace and newlines in the bad line visible.
        text = "Malformed hunk header.  %s\n%r" % (desc, line)
        PatchSyntax.__init__(self, text)
35
36
 
36
37
class MalformedLine(PatchSyntax):
105
106
    def get_str(self, leadchar):
106
107
        if self.contents == "\n" and leadchar == " " and False:
107
108
            return "\n"
108
 
        return leadchar + self.contents
 
109
        if not self.contents.endswith('\n'):
 
110
            terminator = '\n' + NO_NL
 
111
        else:
 
112
            terminator = ''
 
113
        return leadchar + self.contents + terminator
 
114
 
109
115
 
110
116
class ContextLine(HunkLine):
111
117
    def __init__(self, contents):
130
136
    def __str__(self):
131
137
        return self.get_str("-")
132
138
 
 
139
# Marker line used in patch text to flag that the line before it has no
# trailing newline; emitted by get_str and consumed by iter_lines_handle_nl.
NO_NL = '\\ No newline at end of file\n'
133
140
# NOTE(review): presumably an option string for the pychecker linter; it is
# set back to "" right after parse_line.
__pychecker__="no-returnvalues"
 
141
 
134
142
def parse_line(line):
135
143
    if line.startswith("\n"):
136
144
        return ContextLine(line)
140
148
        return InsertLine(line[1:])
141
149
    elif line.startswith("-"):
142
150
        return RemoveLine(line[1:])
 
151
    elif line == NO_NL:
 
152
        return NO_NL
143
153
    else:
144
154
        raise MalformedLine("Unknown line type", line)
145
155
# Clear the option string set just before parse_line (presumably read by the
# pychecker linter).
__pychecker__=""
208
218
def iter_hunks(iter_lines):
    """Yield the hunks of a single-file patch, in order.

    ``iter_lines`` supplies the lines that follow the "---"/"+++" header.
    A line that is exactly a newline ends the current hunk.  Any other line
    seen between hunks is parsed as a hunk header by ``hunk_from_header``.
    The hunk then consumes body lines until its original and modified line
    counts reach the ranges given in that header.

    Raises PatchSyntax if the input ends before a hunk's body is complete.
    """
    # The inner loop pulls body lines with .next(); make that work for plain
    # lists too (iter() of an iterator returns the iterator itself).
    iter_lines = iter(iter_lines)
    hunk = None
    for line in iter_lines:
        if line == "\n":
            if hunk is not None:
                yield hunk
                hunk = None
            continue
        if hunk is not None:
            yield hunk
        hunk = hunk_from_header(line)
        orig_size = 0
        mod_size = 0
        while orig_size < hunk.orig_range or mod_size < hunk.mod_range:
            try:
                body_line = iter_lines.next()
            except StopIteration:
                # Letting StopIteration escape would quietly end this
                # generator and drop the incomplete hunk.
                raise PatchSyntax("Patch ends in the middle of a hunk.\n%r"
                                  % line)
            hunk_line = parse_line(body_line)
            hunk.lines.append(hunk_line)
            # Context lines count toward both the original and modified side.
            if isinstance(hunk_line, (RemoveLine, ContextLine)):
                orig_size += 1
            if isinstance(hunk_line, (InsertLine, ContextLine)):
                mod_size += 1
    if hunk is not None:
        yield hunk
220
240
 
225
245
        self.hunks = []
226
246
 
227
247
    def __str__(self):
228
 
        ret =  "--- %s\n+++ %s\n" % (self.oldname, self.newname) 
 
248
        ret = self.get_header() 
229
249
        ret += "".join([str(h) for h in self.hunks])
230
250
        return ret
231
251
 
 
252
    def get_header(self):
 
253
        return "--- %s\n+++ %s\n" % (self.oldname, self.newname)
 
254
 
232
255
    def stats_str(self):
233
256
        """Return a string of patch statistics"""
234
257
        removes = 0
279
302
    for line in iter_lines:
280
303
        if line.startswith('*** '):
281
304
            continue
 
305
        if line.startswith('==='):
 
306
            continue
282
307
        elif line.startswith('--- '):
283
308
            if len(saved_lines) > 0:
284
309
                yield saved_lines
288
313
        yield saved_lines
289
314
 
290
315
 
 
316
def iter_lines_handle_nl(iter_lines):
    """
    Iterates through lines, ensuring that lines that originally had no
    terminating newline are produced without one.  This transformation may
    be applied at any point up until hunk line parsing, and is safe to apply
    repeatedly.

    Each NO_NL marker line is dropped from the output, and the newline is
    removed from the line that preceded it.  Raises AssertionError if a
    marker is not preceded by a newline-terminated line.
    """
    last_line = None
    for line in iter_lines:
        if line == NO_NL:
            # Checked explicitly rather than with ``assert``.  Under
            # ``python -O`` an assert is skipped, and the slice below would
            # then chop a real character instead of a newline.
            if last_line is None or not last_line.endswith('\n'):
                raise AssertionError(
                    "No-newline marker does not follow a newline-terminated"
                    " line: %r" % (last_line,))
            last_line = last_line[:-1]
            line = None
        # Lines are held back one step so a following marker can still
        # strip the newline from them.
        if last_line is not None:
            yield last_line
        last_line = line
    if last_line is not None:
        yield last_line
 
334
 
 
335
 
291
336
def parse_patches(iter_lines):
    """Parse a multi-file patch into a list of per-file patch objects.

    No-newline markers are first folded into the lines before them.  The
    input is then split into per-file chunks by ``iter_file_patch``, and each
    chunk is handed to ``parse_patch``.
    """
    normalized = iter_lines_handle_nl(iter_lines)
    patches = []
    for file_lines in iter_file_patch(normalized):
        patches.append(parse_patch(iter(file_lines)))
    return patches
293
339
 
294
340
 
310
356
            return i;
311
357
    return None
312
358
 
313
 
 
314
 
def test():
315
 
    import unittest
316
 
    class PatchesTester(unittest.TestCase):
317
 
        def testValidPatchHeader(self):
318
 
            """Parse a valid patch header"""
319
 
            lines = "--- orig/commands.py\n+++ mod/dommands.py\n".split('\n')
320
 
            (orig, mod) = get_patch_names(lines.__iter__())
321
 
            assert(orig == "orig/commands.py")
322
 
            assert(mod == "mod/dommands.py")
323
 
 
324
 
        def testInvalidPatchHeader(self):
325
 
            """Parse an invalid patch header"""
326
 
            lines = "-- orig/commands.py\n+++ mod/dommands.py".split('\n')
327
 
            self.assertRaises(MalformedPatchHeader, get_patch_names,
328
 
                              lines.__iter__())
329
 
 
330
 
        def testValidHunkHeader(self):
331
 
            """Parse a valid hunk header"""
332
 
            header = "@@ -34,11 +50,6 @@\n"
333
 
            hunk = hunk_from_header(header);
334
 
            assert (hunk.orig_pos == 34)
335
 
            assert (hunk.orig_range == 11)
336
 
            assert (hunk.mod_pos == 50)
337
 
            assert (hunk.mod_range == 6)
338
 
            assert (str(hunk) == header)
339
 
 
340
 
        def testValidHunkHeader2(self):
341
 
            """Parse a tricky, valid hunk header"""
342
 
            header = "@@ -1 +0,0 @@\n"
343
 
            hunk = hunk_from_header(header);
344
 
            assert (hunk.orig_pos == 1)
345
 
            assert (hunk.orig_range == 1)
346
 
            assert (hunk.mod_pos == 0)
347
 
            assert (hunk.mod_range == 0)
348
 
            assert (str(hunk) == header)
349
 
 
350
 
        def makeMalformed(self, header):
351
 
            self.assertRaises(MalformedHunkHeader, hunk_from_header, header)
352
 
 
353
 
        def testInvalidHeader(self):
354
 
            """Parse an invalid hunk header"""
355
 
            self.makeMalformed(" -34,11 +50,6 \n")
356
 
            self.makeMalformed("@@ +50,6 -34,11 @@\n")
357
 
            self.makeMalformed("@@ -34,11 +50,6 @@")
358
 
            self.makeMalformed("@@ -34.5,11 +50,6 @@\n")
359
 
            self.makeMalformed("@@-34,11 +50,6@@\n")
360
 
            self.makeMalformed("@@ 34,11 50,6 @@\n")
361
 
            self.makeMalformed("@@ -34,11 @@\n")
362
 
            self.makeMalformed("@@ -34,11 +50,6.5 @@\n")
363
 
            self.makeMalformed("@@ -34,11 +50,-6 @@\n")
364
 
 
365
 
        def lineThing(self,text, type):
366
 
            line = parse_line(text)
367
 
            assert(isinstance(line, type))
368
 
            assert(str(line)==text)
369
 
 
370
 
        def makeMalformedLine(self, text):
371
 
            self.assertRaises(MalformedLine, parse_line, text)
372
 
 
373
 
        def testValidLine(self):
374
 
            """Parse a valid hunk line"""
375
 
            self.lineThing(" hello\n", ContextLine)
376
 
            self.lineThing("+hello\n", InsertLine)
377
 
            self.lineThing("-hello\n", RemoveLine)
378
 
        
379
 
        def testMalformedLine(self):
380
 
            """Parse invalid valid hunk lines"""
381
 
            self.makeMalformedLine("hello\n")
382
 
        
383
 
        def compare_parsed(self, patchtext):
384
 
            lines = patchtext.splitlines(True)
385
 
            patch = parse_patch(lines.__iter__())
386
 
            pstr = str(patch)
387
 
            i = difference_index(patchtext, pstr)
388
 
            if i is not None:
389
 
                print "%i: \"%s\" != \"%s\"" % (i, patchtext[i], pstr[i])
390
 
            assert (patchtext == str(patch))
391
 
 
392
 
        def testAll(self):
393
 
            """Test parsing a whole patch"""
394
 
            patchtext = """--- orig/commands.py
 
359
class PatchConflict(Exception):
    """Raised when the original text does not match what a patch expects.

    ``line_no`` is the line number in the original text.  ``orig_line`` is
    the text found there, and ``patch_line`` is the patch text that should
    have matched (anything accepted by ``str()``).  Trailing newlines are
    stripped from both before they are quoted in the message.
    """

    def __init__(self, line_no, orig_line, patch_line):
        found = orig_line.rstrip('\n')
        expected = str(patch_line).rstrip('\n')
        template = ('Text contents mismatch at line %d.  Original has "%s",'
                    ' but patch says it should be "%s"')
        Exception.__init__(self, template % (line_no, found, expected))
 
366
 
 
367
 
 
368
def iter_patched(orig_lines, patch_lines):
    """Iterate through a series of lines with a patch applied.
    This handles a single file, and does exact, not fuzzy patching.

    ``orig_lines`` is an iterable of the original file's lines.  It may be
    None only if no hunk needs to read original lines.  ``patch_lines`` is
    the text of a single-file patch, including its "---"/"+++" header.
    Original lines after the last hunk are passed through unchanged.

    Raises PatchConflict when a context or removed line in the patch does
    not match the corresponding original line.
    """
    if orig_lines is not None:
        orig_lines = iter(orig_lines)
    seen_patch = []
    patch_lines = iter_lines_handle_nl(iter(patch_lines))
    # Skip past the "---"/"+++" file header; the names are not needed here.
    get_patch_names(patch_lines)
    line_no = 1
    for hunk in iter_hunks(patch_lines):
        # Copy untouched original lines up to the start of this hunk.
        while line_no < hunk.orig_pos:
            orig_line = orig_lines.next()
            yield orig_line
            line_no += 1
        for hunk_line in hunk.lines:
            seen_patch.append(str(hunk_line))
            if isinstance(hunk_line, InsertLine):
                yield hunk_line.contents
            elif isinstance(hunk_line, (ContextLine, RemoveLine)):
                orig_line = orig_lines.next()
                if orig_line != hunk_line.contents:
                    raise PatchConflict(line_no, orig_line, "".join(seen_patch))
                if isinstance(hunk_line, ContextLine):
                    yield orig_line
                else:
                    assert isinstance(hunk_line, RemoveLine)
                line_no += 1
    # Lines after the final hunk are unchanged by the patch; without this
    # loop the patched output would silently lose the end of the file.
    if orig_lines is not None:
        for orig_line in orig_lines:
            yield orig_line
 
396
                    
 
397
import unittest
 
398
import os.path
 
399
class PatchesTester(unittest.TestCase):
 
400
    def datafile(self, filename):
 
401
        data_path = os.path.join(os.path.dirname(__file__), "testdata", 
 
402
                                 filename)
 
403
        return file(data_path, "rb")
 
404
 
 
405
    def testValidPatchHeader(self):
 
406
        """Parse a valid patch header"""
 
407
        lines = "--- orig/commands.py\n+++ mod/dommands.py\n".split('\n')
 
408
        (orig, mod) = get_patch_names(lines.__iter__())
 
409
        assert(orig == "orig/commands.py")
 
410
        assert(mod == "mod/dommands.py")
 
411
 
 
412
    def testInvalidPatchHeader(self):
 
413
        """Parse an invalid patch header"""
 
414
        lines = "-- orig/commands.py\n+++ mod/dommands.py".split('\n')
 
415
        self.assertRaises(MalformedPatchHeader, get_patch_names,
 
416
                          lines.__iter__())
 
417
 
 
418
    def testValidHunkHeader(self):
 
419
        """Parse a valid hunk header"""
 
420
        header = "@@ -34,11 +50,6 @@\n"
 
421
        hunk = hunk_from_header(header);
 
422
        assert (hunk.orig_pos == 34)
 
423
        assert (hunk.orig_range == 11)
 
424
        assert (hunk.mod_pos == 50)
 
425
        assert (hunk.mod_range == 6)
 
426
        assert (str(hunk) == header)
 
427
 
 
428
    def testValidHunkHeader2(self):
 
429
        """Parse a tricky, valid hunk header"""
 
430
        header = "@@ -1 +0,0 @@\n"
 
431
        hunk = hunk_from_header(header);
 
432
        assert (hunk.orig_pos == 1)
 
433
        assert (hunk.orig_range == 1)
 
434
        assert (hunk.mod_pos == 0)
 
435
        assert (hunk.mod_range == 0)
 
436
        assert (str(hunk) == header)
 
437
 
 
438
    def makeMalformed(self, header):
 
439
        self.assertRaises(MalformedHunkHeader, hunk_from_header, header)
 
440
 
 
441
    def testInvalidHeader(self):
 
442
        """Parse an invalid hunk header"""
 
443
        self.makeMalformed(" -34,11 +50,6 \n")
 
444
        self.makeMalformed("@@ +50,6 -34,11 @@\n")
 
445
        self.makeMalformed("@@ -34,11 +50,6 @@")
 
446
        self.makeMalformed("@@ -34.5,11 +50,6 @@\n")
 
447
        self.makeMalformed("@@-34,11 +50,6@@\n")
 
448
        self.makeMalformed("@@ 34,11 50,6 @@\n")
 
449
        self.makeMalformed("@@ -34,11 @@\n")
 
450
        self.makeMalformed("@@ -34,11 +50,6.5 @@\n")
 
451
        self.makeMalformed("@@ -34,11 +50,-6 @@\n")
 
452
 
 
453
    def lineThing(self,text, type):
 
454
        line = parse_line(text)
 
455
        assert(isinstance(line, type))
 
456
        assert(str(line)==text)
 
457
 
 
458
    def makeMalformedLine(self, text):
 
459
        self.assertRaises(MalformedLine, parse_line, text)
 
460
 
 
461
    def testValidLine(self):
 
462
        """Parse a valid hunk line"""
 
463
        self.lineThing(" hello\n", ContextLine)
 
464
        self.lineThing("+hello\n", InsertLine)
 
465
        self.lineThing("-hello\n", RemoveLine)
 
466
    
 
467
    def testMalformedLine(self):
 
468
        """Parse invalid valid hunk lines"""
 
469
        self.makeMalformedLine("hello\n")
 
470
    
 
471
    def compare_parsed(self, patchtext):
 
472
        lines = patchtext.splitlines(True)
 
473
        patch = parse_patch(lines.__iter__())
 
474
        pstr = str(patch)
 
475
        i = difference_index(patchtext, pstr)
 
476
        if i is not None:
 
477
            print "%i: \"%s\" != \"%s\"" % (i, patchtext[i], pstr[i])
 
478
        self.assertEqual (patchtext, str(patch))
 
479
 
 
480
    def testAll(self):
 
481
        """Test parsing a whole patch"""
 
482
        patchtext = """--- orig/commands.py
395
483
+++ mod/commands.py
396
484
@@ -1337,7 +1337,8 @@
397
485
 
417
505
         log.save()
418
506
     try:
419
507
"""
420
 
            self.compare_parsed(patchtext)
 
508
        self.compare_parsed(patchtext)
421
509
 
422
 
        def testInit(self):
423
 
            """Handle patches missing half the position, range tuple"""
424
 
            patchtext = \
 
510
    def testInit(self):
 
511
        """Handle patches missing half the position, range tuple"""
 
512
        patchtext = \
425
513
"""--- orig/__init__.py
426
514
+++ mod/__init__.py
427
515
@@ -1 +1,2 @@
428
516
 __docformat__ = "restructuredtext en"
429
 
+__doc__ = An alternate Arch commandline interface"""
430
 
            self.compare_parsed(patchtext)
431
 
            
432
 
 
433
 
 
434
 
        def testLineLookup(self):
435
 
            import sys
436
 
            """Make sure we can accurately look up mod line from orig"""
437
 
            patch = parse_patch(open("testdata/diff"))
438
 
            orig = list(open("testdata/orig"))
439
 
            mod = list(open("testdata/mod"))
440
 
            removals = []
441
 
            for i in range(len(orig)):
442
 
                mod_pos = patch.pos_in_mod(i)
443
 
                if mod_pos is None:
444
 
                    removals.append(orig[i])
445
 
                    continue
446
 
                assert(mod[mod_pos]==orig[i])
447
 
            rem_iter = removals.__iter__()
448
 
            for hunk in patch.hunks:
449
 
                for line in hunk.lines:
450
 
                    if isinstance(line, RemoveLine):
451
 
                        next = rem_iter.next()
452
 
                        if line.contents != next:
453
 
                            sys.stdout.write(" orig:%spatch:%s" % (next,
454
 
                                             line.contents))
455
 
                        assert(line.contents == next)
456
 
            self.assertRaises(StopIteration, rem_iter.next)
457
 
 
458
 
        def testFirstLineRenumber(self):
459
 
            """Make sure we handle lines at the beginning of the hunk"""
460
 
            patch = parse_patch(open("testdata/insert_top.patch"))
461
 
            assert (patch.pos_in_mod(0)==1)
462
 
    
463
 
            
 
517
+__doc__ = An alternate Arch commandline interface
 
518
"""
 
519
        self.compare_parsed(patchtext)
 
520
        
 
521
 
 
522
 
 
523
    def testLineLookup(self):
 
524
        import sys
 
525
        """Make sure we can accurately look up mod line from orig"""
 
526
        patch = parse_patch(self.datafile("diff"))
 
527
        orig = list(self.datafile("orig"))
 
528
        mod = list(self.datafile("mod"))
 
529
        removals = []
 
530
        for i in range(len(orig)):
 
531
            mod_pos = patch.pos_in_mod(i)
 
532
            if mod_pos is None:
 
533
                removals.append(orig[i])
 
534
                continue
 
535
            assert(mod[mod_pos]==orig[i])
 
536
        rem_iter = removals.__iter__()
 
537
        for hunk in patch.hunks:
 
538
            for line in hunk.lines:
 
539
                if isinstance(line, RemoveLine):
 
540
                    next = rem_iter.next()
 
541
                    if line.contents != next:
 
542
                        sys.stdout.write(" orig:%spatch:%s" % (next,
 
543
                                         line.contents))
 
544
                    assert(line.contents == next)
 
545
        self.assertRaises(StopIteration, rem_iter.next)
 
546
 
 
547
    def testFirstLineRenumber(self):
 
548
        """Make sure we handle lines at the beginning of the hunk"""
 
549
        patch = parse_patch(self.datafile("insert_top.patch"))
 
550
        assert (patch.pos_in_mod(0)==1)
 
551
 
 
552
def test():
    """Run every ``test*`` method of PatchesTester with a quiet text runner.

    Returns the unittest result object produced by the run.
    """
    suite = unittest.makeSuite(PatchesTester, 'test')
    quiet_runner = unittest.TextTestRunner(verbosity=0)
    return quiet_runner.run(suite)