~abentley/bzrtools/bzrtools.dev

« back to all changes in this revision

Viewing changes to patches.py

  • Committer: Michael Ellerman
  • Date: 2005-10-19 09:34:33 UTC
  • mto: (0.3.1 shelf-dev) (325.1.2 bzrtools)
  • mto: This revision was merged to the branch mainline in revision 246.
  • Revision ID: michael@ellerman.id.au-20051019093433-39720aedce6799e9
Updated patches.py to version from bzrtools-0.1.1

Show diffs side-by-side

added added

removed removed

Lines of Context:
24
24
    def __init__(self, desc, line):
25
25
        self.desc = desc
26
26
        self.line = line
27
 
        msg = "Malformed patch header.  %s\n%s" % (self.desc, self.line)
 
27
        msg = "Malformed patch header.  %s\n%r" % (self.desc, self.line)
28
28
        PatchSyntax.__init__(self, msg)
29
29
 
30
30
class MalformedHunkHeader(PatchSyntax):
31
31
    def __init__(self, desc, line):
32
32
        self.desc = desc
33
33
        self.line = line
34
 
        msg = "Malformed hunk header.  %s\n%s" % (self.desc, self.line)
 
34
        msg = "Malformed hunk header.  %s\n%r" % (self.desc, self.line)
35
35
        PatchSyntax.__init__(self, msg)
36
36
 
37
37
class MalformedLine(PatchSyntax):
44
44
def get_patch_names(iter_lines):
45
45
    try:
46
46
        line = iter_lines.next()
47
 
        if line.startswith("*** "):
48
 
            line = iter_lines.next()
49
47
        if not line.startswith("--- "):
50
48
            raise MalformedPatchHeader("No orig name", line)
51
49
        else:
108
106
    def get_str(self, leadchar):
109
107
        if self.contents == "\n" and leadchar == " " and False:
110
108
            return "\n"
111
 
        return leadchar + self.contents
 
109
        if not self.contents.endswith('\n'):
 
110
            terminator = '\n' + NO_NL
 
111
        else:
 
112
            terminator = ''
 
113
        return leadchar + self.contents + terminator
 
114
 
112
115
 
113
116
class ContextLine(HunkLine):
114
117
    def __init__(self, contents):
125
128
    def __str__(self):
126
129
        return self.get_str("+")
127
130
 
 
131
 
128
132
class RemoveLine(HunkLine):
129
133
    def __init__(self, contents):
130
134
        HunkLine.__init__(self, contents)
132
136
    def __str__(self):
133
137
        return self.get_str("-")
134
138
 
 
139
NO_NL = '\\ No newline at end of file\n'
135
140
__pychecker__="no-returnvalues"
 
141
 
136
142
def parse_line(line):
137
143
    if line.startswith("\n"):
138
144
        return ContextLine(line)
142
148
        return InsertLine(line[1:])
143
149
    elif line.startswith("-"):
144
150
        return RemoveLine(line[1:])
 
151
    elif line == NO_NL:
 
152
        return NO_NL
145
153
    else:
146
154
        raise MalformedLine("Unknown line type", line)
147
155
__pychecker__=""
210
218
def iter_hunks(iter_lines):
211
219
    hunk = None
212
220
    for line in iter_lines:
213
 
        if line.startswith("@@"):
 
221
        if line == "\n":
214
222
            if hunk is not None:
215
223
                yield hunk
216
 
            hunk = hunk_from_header(line)
217
 
        else:
218
 
            hunk.lines.append(parse_line(line))
219
 
 
 
224
                hunk = None
 
225
            continue
 
226
        if hunk is not None:
 
227
            yield hunk
 
228
        hunk = hunk_from_header(line)
 
229
        orig_size = 0
 
230
        mod_size = 0
 
231
        while orig_size < hunk.orig_range or mod_size < hunk.mod_range:
 
232
            hunk_line = parse_line(iter_lines.next())
 
233
            hunk.lines.append(hunk_line)
 
234
            if isinstance(hunk_line, (RemoveLine, ContextLine)):
 
235
                orig_size += 1
 
236
            if isinstance(hunk_line, (InsertLine, ContextLine)):
 
237
                mod_size += 1
220
238
    if hunk is not None:
221
239
        yield hunk
222
240
 
226
244
        self.newname = newname
227
245
        self.hunks = []
228
246
 
229
 
    def get_header(self):
230
 
        return "--- %s\n+++ %s\n" % (self.oldname, self.newname)
231
 
 
232
247
    def __str__(self):
233
 
        ret =  self.get_header()
 
248
        ret = self.get_header() 
234
249
        ret += "".join([str(h) for h in self.hunks])
235
250
        return ret
236
251
 
 
252
    def get_header(self):
 
253
        return "--- %s\n+++ %s\n" % (self.oldname, self.newname)
 
254
 
237
255
    def stats_str(self):
238
256
        """Return a string of patch statistics"""
239
257
        removes = 0
278
296
        patch.hunks.append(hunk)
279
297
    return patch
280
298
 
281
 
def parse_patches(lines):
282
 
    def parse(lines):
283
 
        if len(lines) > 0:
284
 
            return [ parse_patch(lines.__iter__()) ]
285
 
        else:
286
 
            return []
287
299
 
288
 
    iter_lines = lines.__iter__()
289
 
    patches = []
 
300
def iter_file_patch(iter_lines):
290
301
    saved_lines = []
291
 
    while True:
292
 
        try: line = iter_lines.next()
293
 
        except StopIteration:
294
 
            patches.extend(parse(saved_lines))
295
 
            break
296
 
 
 
302
    for line in iter_lines:
297
303
        if line.startswith('*** '):
298
 
            patches.extend(parse(saved_lines))
 
304
            continue
 
305
        if line.startswith('==='):
 
306
            continue
 
307
        elif line.startswith('--- '):
 
308
            if len(saved_lines) > 0:
 
309
                yield saved_lines
299
310
            saved_lines = []
300
 
            continue
301
 
        elif line.startswith('--- ') and len(saved_lines) > 1:
302
 
            patches.extend(parse(saved_lines))
303
 
            saved_lines = [ line ]
304
 
            continue
305
 
 
306
311
        saved_lines.append(line)
307
 
 
308
 
    return patches
 
312
    if len(saved_lines) > 0:
 
313
        yield saved_lines
 
314
 
 
315
 
 
316
def iter_lines_handle_nl(iter_lines):
 
317
    """
 
318
    Iterates through lines, ensuring that lines that originally had no
 
319
    terminating \n are produced without one.  This transformation may be
 
320
    applied at any point up until hunk line parsing, and is safe to apply
 
321
    repeatedly.
 
322
    """
 
323
    last_line = None
 
324
    for line in iter_lines:
 
325
        if line == NO_NL:
 
326
            assert last_line.endswith('\n')
 
327
            last_line = last_line[:-1]
 
328
            line = None
 
329
        if last_line is not None:
 
330
            yield last_line
 
331
        last_line = line
 
332
    if last_line is not None:
 
333
        yield last_line
 
334
 
 
335
 
 
336
def parse_patches(iter_lines):
 
337
    iter_lines = iter_lines_handle_nl(iter_lines)
 
338
    return [parse_patch(f.__iter__()) for f in iter_file_patch(iter_lines)]
 
339
 
 
340
 
 
341
def difference_index(atext, btext):
 
342
    """Find the index of the first character that differs between two texts
 
343
 
 
344
    :param atext: The first text
 
345
    :type atext: str
 
346
    :param btext: The second text
 
347
    :type btext: str
 
348
    :return: The index, or None if there are no differences within the range
 
349
    :rtype: int or NoneType
 
350
    """
 
351
    length = len(atext)
 
352
    if len(btext) < length:
 
353
        length = len(btext)
 
354
    for i in range(length):
 
355
        if atext[i] != btext[i]:
 
356
            return i;
 
357
    return None
 
358
 
 
359
class PatchConflict(Exception):
 
360
    def __init__(self, line_no, orig_line, patch_line):
 
361
        orig = orig_line.rstrip('\n')
 
362
        patch = str(patch_line).rstrip('\n')
 
363
        msg = 'Text contents mismatch at line %d.  Original has "%s",'\
 
364
            ' but patch says it should be "%s"' % (line_no, orig, patch)
 
365
        Exception.__init__(self, msg)
 
366
 
 
367
 
 
368
def iter_patched(orig_lines, patch_lines):
 
369
    """Iterate through a series of lines with a patch applied.
 
370
    This handles a single file, and does exact, not fuzzy patching.
 
371
    """
 
372
    if orig_lines is not None:
 
373
        orig_lines = orig_lines.__iter__()
 
374
    seen_patch = []
 
375
    patch_lines = iter_lines_handle_nl(patch_lines.__iter__())
 
376
    get_patch_names(patch_lines)
 
377
    line_no = 1
 
378
    for hunk in iter_hunks(patch_lines):
 
379
        while line_no < hunk.orig_pos:
 
380
            orig_line = orig_lines.next()
 
381
            yield orig_line
 
382
            line_no += 1
 
383
        for hunk_line in hunk.lines:
 
384
            seen_patch.append(str(hunk_line))
 
385
            if isinstance(hunk_line, InsertLine):
 
386
                yield hunk_line.contents
 
387
            elif isinstance(hunk_line, (ContextLine, RemoveLine)):
 
388
                orig_line = orig_lines.next()
 
389
                if orig_line != hunk_line.contents:
 
390
                    raise PatchConflict(line_no, orig_line, "".join(seen_patch))
 
391
                if isinstance(hunk_line, ContextLine):
 
392
                    yield orig_line
 
393
                else:
 
394
                    assert isinstance(hunk_line, RemoveLine)
 
395
                line_no += 1
 
396
                    
 
397
import unittest
 
398
import os.path
 
399
class PatchesTester(unittest.TestCase):
 
400
    def datafile(self, filename):
 
401
        data_path = os.path.join(os.path.dirname(__file__), "testdata", 
 
402
                                 filename)
 
403
        return file(data_path, "rb")
 
404
 
 
405
    def testValidPatchHeader(self):
 
406
        """Parse a valid patch header"""
 
407
        lines = "--- orig/commands.py\n+++ mod/dommands.py\n".split('\n')
 
408
        (orig, mod) = get_patch_names(lines.__iter__())
 
409
        assert(orig == "orig/commands.py")
 
410
        assert(mod == "mod/dommands.py")
 
411
 
 
412
    def testInvalidPatchHeader(self):
 
413
        """Parse an invalid patch header"""
 
414
        lines = "-- orig/commands.py\n+++ mod/dommands.py".split('\n')
 
415
        self.assertRaises(MalformedPatchHeader, get_patch_names,
 
416
                          lines.__iter__())
 
417
 
 
418
    def testValidHunkHeader(self):
 
419
        """Parse a valid hunk header"""
 
420
        header = "@@ -34,11 +50,6 @@\n"
 
421
        hunk = hunk_from_header(header);
 
422
        assert (hunk.orig_pos == 34)
 
423
        assert (hunk.orig_range == 11)
 
424
        assert (hunk.mod_pos == 50)
 
425
        assert (hunk.mod_range == 6)
 
426
        assert (str(hunk) == header)
 
427
 
 
428
    def testValidHunkHeader2(self):
 
429
        """Parse a tricky, valid hunk header"""
 
430
        header = "@@ -1 +0,0 @@\n"
 
431
        hunk = hunk_from_header(header);
 
432
        assert (hunk.orig_pos == 1)
 
433
        assert (hunk.orig_range == 1)
 
434
        assert (hunk.mod_pos == 0)
 
435
        assert (hunk.mod_range == 0)
 
436
        assert (str(hunk) == header)
 
437
 
 
438
    def makeMalformed(self, header):
 
439
        self.assertRaises(MalformedHunkHeader, hunk_from_header, header)
 
440
 
 
441
    def testInvalidHeader(self):
 
442
        """Parse an invalid hunk header"""
 
443
        self.makeMalformed(" -34,11 +50,6 \n")
 
444
        self.makeMalformed("@@ +50,6 -34,11 @@\n")
 
445
        self.makeMalformed("@@ -34,11 +50,6 @@")
 
446
        self.makeMalformed("@@ -34.5,11 +50,6 @@\n")
 
447
        self.makeMalformed("@@-34,11 +50,6@@\n")
 
448
        self.makeMalformed("@@ 34,11 50,6 @@\n")
 
449
        self.makeMalformed("@@ -34,11 @@\n")
 
450
        self.makeMalformed("@@ -34,11 +50,6.5 @@\n")
 
451
        self.makeMalformed("@@ -34,11 +50,-6 @@\n")
 
452
 
 
453
    def lineThing(self,text, type):
 
454
        line = parse_line(text)
 
455
        assert(isinstance(line, type))
 
456
        assert(str(line)==text)
 
457
 
 
458
    def makeMalformedLine(self, text):
 
459
        self.assertRaises(MalformedLine, parse_line, text)
 
460
 
 
461
    def testValidLine(self):
 
462
        """Parse a valid hunk line"""
 
463
        self.lineThing(" hello\n", ContextLine)
 
464
        self.lineThing("+hello\n", InsertLine)
 
465
        self.lineThing("-hello\n", RemoveLine)
 
466
    
 
467
    def testMalformedLine(self):
 
468
        """Parse invalid hunk lines"""
 
469
        self.makeMalformedLine("hello\n")
 
470
    
 
471
    def compare_parsed(self, patchtext):
 
472
        lines = patchtext.splitlines(True)
 
473
        patch = parse_patch(lines.__iter__())
 
474
        pstr = str(patch)
 
475
        i = difference_index(patchtext, pstr)
 
476
        if i is not None:
 
477
            print "%i: \"%s\" != \"%s\"" % (i, patchtext[i], pstr[i])
 
478
        self.assertEqual (patchtext, str(patch))
 
479
 
 
480
    def testAll(self):
 
481
        """Test parsing a whole patch"""
 
482
        patchtext = """--- orig/commands.py
 
483
+++ mod/commands.py
 
484
@@ -1337,7 +1337,8 @@
 
485
 
 
486
     def set_title(self, command=None):
 
487
         try:
 
488
-            version = self.tree.tree_version.nonarch
 
489
+            version = pylon.alias_or_version(self.tree.tree_version, self.tree,
 
490
+                                             full=False)
 
491
         except:
 
492
             version = "[no version]"
 
493
         if command is None:
 
494
@@ -1983,7 +1984,11 @@
 
495
                                          version)
 
496
         if len(new_merges) > 0:
 
497
             if cmdutil.prompt("Log for merge"):
 
498
-                mergestuff = cmdutil.log_for_merge(tree, comp_version)
 
499
+                if cmdutil.prompt("changelog for merge"):
 
500
+                    mergestuff = "Patches applied:\\n"
 
501
+                    mergestuff += pylon.changelog_for_merge(new_merges)
 
502
+                else:
 
503
+                    mergestuff = cmdutil.log_for_merge(tree, comp_version)
 
504
                 log.description += mergestuff
 
505
         log.save()
 
506
     try:
 
507
"""
 
508
        self.compare_parsed(patchtext)
 
509
 
 
510
    def testInit(self):
 
511
        """Handle patches missing half the position, range tuple"""
 
512
        patchtext = \
 
513
"""--- orig/__init__.py
 
514
+++ mod/__init__.py
 
515
@@ -1 +1,2 @@
 
516
 __docformat__ = "restructuredtext en"
 
517
+__doc__ = An alternate Arch commandline interface
 
518
"""
 
519
        self.compare_parsed(patchtext)
 
520
        
 
521
 
 
522
 
 
523
    def testLineLookup(self):
 
524
        import sys
 
525
        """Make sure we can accurately look up mod line from orig"""
 
526
        patch = parse_patch(self.datafile("diff"))
 
527
        orig = list(self.datafile("orig"))
 
528
        mod = list(self.datafile("mod"))
 
529
        removals = []
 
530
        for i in range(len(orig)):
 
531
            mod_pos = patch.pos_in_mod(i)
 
532
            if mod_pos is None:
 
533
                removals.append(orig[i])
 
534
                continue
 
535
            assert(mod[mod_pos]==orig[i])
 
536
        rem_iter = removals.__iter__()
 
537
        for hunk in patch.hunks:
 
538
            for line in hunk.lines:
 
539
                if isinstance(line, RemoveLine):
 
540
                    next = rem_iter.next()
 
541
                    if line.contents != next:
 
542
                        sys.stdout.write(" orig:%spatch:%s" % (next,
 
543
                                         line.contents))
 
544
                    assert(line.contents == next)
 
545
        self.assertRaises(StopIteration, rem_iter.next)
 
546
 
 
547
    def testFirstLineRenumber(self):
 
548
        """Make sure we handle lines at the beginning of the hunk"""
 
549
        patch = parse_patch(self.datafile("insert_top.patch"))
 
550
        assert (patch.pos_in_mod(0)==1)
 
551
 
 
552
def test():
 
553
    patchesTestSuite = unittest.makeSuite(PatchesTester,'test')
 
554
    runner = unittest.TextTestRunner(verbosity=0)
 
555
    return runner.run(patchesTestSuite)
 
556
    
 
557
 
 
558
if __name__ == "__main__":
 
559
    test()
 
560
# arch-tag: d1541a25-eac5-4de9-a476-08a7cecd5683