~bzr-pqm/bzr/bzr.dev

0.5.107 by John Arbash Meinel
Added a few more test patches.
1
First line change
2
# Copyright (C) 2004, 2005 Aaron Bentley
3
# <aaron.bentley@utoronto.ca>
4
#
5
#    This program is free software; you can redistribute it and/or modify
6
#    it under the terms of the GNU General Public License as published by
7
#    the Free Software Foundation; either version 2 of the License, or
8
#    (at your option) any later version.
9
#
10
#    This program is distributed in the hope that it will be useful,
11
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
12
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13
#    GNU General Public License for more details.
14
#
15
#    You should have received a copy of the GNU General Public License
16
#    along with this program; if not, write to the Free Software
17
#    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
18
19
class PatchSyntax(Exception):
    """Base class for every error raised while parsing patch text.

    :param msg: the human-readable description, returned by ``str(err)``
    """

    def __init__(self, msg):
        super(PatchSyntax, self).__init__(msg)
22
23
24
class MalformedPatchHeader(PatchSyntax):
    """The ``---`` / ``+++`` file-name header of a patch could not be read.

    ``desc`` says what was wrong; ``line`` is the offending input line
    (its repr is shown in the message so whitespace is visible).
    """

    def __init__(self, desc, line):
        self.desc, self.line = desc, line
        PatchSyntax.__init__(
            self, "Malformed patch header.  %s\n%r" % (desc, line))
30
31
class MalformedHunkHeader(PatchSyntax):
    """A ``@@ -a,b +c,d @@`` hunk header line could not be parsed.

    ``desc`` says what was wrong; ``line`` is the offending input line
    (its repr is shown in the message so whitespace is visible).
    """

    def __init__(self, desc, line):
        self.desc, self.line = desc, line
        PatchSyntax.__init__(
            self, "Malformed hunk header.  %s\n%r" % (desc, line))
37
38
class MalformedLine(PatchSyntax):
    """A patch line could not be parsed.

    ``desc`` says what was wrong; ``line`` is the offending input line.
    """

    def __init__(self, desc, line):
        self.desc = desc
        self.line = line
        # %r (like the other Malformed* errors) makes leading whitespace and
        # the trailing newline of the offending line visible in the message.
        msg = "Malformed line.  %s\n%r" % (self.desc, self.line)
        PatchSyntax.__init__(self, msg)
44
45
def get_patch_names(iter_lines):
    """Read the ``---`` / ``+++`` header of a single-file patch.

    Consumes up to two lines from ``iter_lines``.

    :param iter_lines: iterator positioned at the ``---`` line
    :return: (orig_name, mod_name): the text after each 4-character prefix,
        with trailing newlines removed
    :raises MalformedPatchHeader: if either line is missing or lacks its
        expected prefix
    """
    try:
        line = next(iter_lines)
    except StopIteration:
        raise MalformedPatchHeader("No orig line", "")
    if not line.startswith("--- "):
        raise MalformedPatchHeader("No orig name", line)
    orig_name = line[4:].rstrip("\n")

    try:
        line = next(iter_lines)
    except StopIteration:
        raise MalformedPatchHeader("No mod line", "")
    if not line.startswith("+++ "):
        # Same specific error as the "---" check (it subclasses PatchSyntax,
        # so callers catching PatchSyntax still work).
        raise MalformedPatchHeader("No mod name", line)
    mod_name = line[4:].rstrip("\n")
    return (orig_name, mod_name)
63
64
def parse_range(textrange):
    """Parse a hunk-header range such as ``"34,11"`` or ``"7"``.

    A range written without a comma denotes a single line, so its length
    defaults to 1.

    :param textrange: The text to parse
    :type textrange: str
    :return: the position and range, as a tuple
    :rtype: (int, int)
    :raises ValueError: if the text is not one or two comma-separated ints
    """
    pieces = textrange.split(',')
    if len(pieces) == 1:
        pieces.append("1")
    start, length = pieces
    return (int(start), int(length))
81
82
 
83
def hunk_from_header(line):
    """Build an empty Hunk from a unified-diff hunk header line.

    :param line: a header such as ``"@@ -34,11 +50,6 @@\\n"``
    :return: a Hunk with the parsed positions and ranges and no lines yet
    :raises MalformedHunkHeader: if the line is not a well-formed header
    """
    # Require the exact "@@ " / " @@\n" delimiters; checking only for "@@"
    # let arbitrary characters sit where the separating spaces belong.
    if not line.startswith("@@ ") or not line.endswith(" @@\n"):
        raise MalformedHunkHeader("Does not start and end with @@.", line)
    try:
        (orig, mod) = line[3:-4].split(" ")
    except ValueError as e:
        # Wrong number of space-separated fields between the delimiters.
        raise MalformedHunkHeader(str(e), line)
    if not orig.startswith('-') or not mod.startswith('+'):
        raise MalformedHunkHeader("Positions don't start with + or -.", line)
    try:
        (orig_pos, orig_range) = parse_range(orig[1:])
        (mod_pos, mod_range) = parse_range(mod[1:])
    except ValueError as e:
        raise MalformedHunkHeader(str(e), line)
    if mod_range < 0 or orig_range < 0:
        raise MalformedHunkHeader("Hunk range is negative", line)
    return Hunk(orig_pos, orig_range, mod_pos, mod_range)
101
102
103
class HunkLine:
    """Base class for one line of a hunk body.

    ``contents`` is the line's text without its one-character diff prefix.
    It normally ends with a newline, but not for a file's final line that
    had no terminating newline.
    """

    def __init__(self, contents):
        self.contents = contents

    def get_str(self, leadchar):
        """Render the line with ``leadchar`` as its diff prefix.

        A line lacking a trailing newline is followed by the NO_NL marker
        line, so the rendered text round-trips through the parser.
        """
        if not self.contents.endswith('\n'):
            terminator = '\n' + NO_NL
        else:
            terminator = ''
        return leadchar + self.contents + terminator
115
116
117
class ContextLine(HunkLine):
    """A hunk line present in both versions; rendered with a leading space."""

    def __str__(self):
        return self.get_str(" ")
123
124
125
class InsertLine(HunkLine):
    """A hunk line added by the patch; rendered with a leading "+"."""

    def __str__(self):
        return self.get_str("+")
131
132
133
class RemoveLine(HunkLine):
    """A hunk line deleted by the patch; rendered with a leading "-"."""

    def __str__(self):
        return self.get_str("-")
139
140
# Marker line written after (and recognized following) a hunk line that has
# no trailing newline; see HunkLine.get_str and iter_lines_handle_nl.
NO_NL = '\\ No newline at end of file\n'
141
__pychecker__="no-returnvalues"

def parse_line(line):
    """Turn one raw hunk-body line into a HunkLine subclass instance.

    The first character selects the type: " " for context, "+" for an
    insertion, "-" for a removal; the prefix is stripped from the contents.
    A line starting with a bare newline is treated as context whose contents
    is the whole line (presumably a context line whose single leading space
    was lost).  The NO_NL marker is returned unchanged as a plain string,
    presumably why pychecker's return-value check is disabled around here.

    :raises MalformedLine: for any other kind of line
    """
    if line.startswith("\n"):
        return ContextLine(line)
    kinds = {" ": ContextLine, "+": InsertLine, "-": RemoveLine}
    kind = kinds.get(line[:1])
    if kind is not None:
        return kind(line[1:])
    if line == NO_NL:
        return NO_NL
    raise MalformedLine("Unknown line type", line)
__pychecker__=""
157
158
159
class Hunk:
    """One ``@@ -a,b +c,d @@`` section of a unified diff.

    ``orig_pos``/``mod_pos`` are the positions from the header (1-based, as
    ``orig_pos - 1`` is used as the first 0-based line below);
    ``orig_range``/``mod_range`` are the line counts.  ``lines`` collects the
    parsed body lines (filled in by the caller).
    """

    def __init__(self, orig_pos, orig_range, mod_pos, mod_range):
        self.orig_pos = orig_pos
        self.orig_range = orig_range
        self.mod_pos = mod_pos
        self.mod_range = mod_range
        self.lines = []

    def get_header(self):
        """Return the ``@@ ... @@`` header line, newline-terminated."""
        return "@@ -%s +%s @@\n" % (self.range_str(self.orig_pos,
                                                   self.orig_range),
                                    self.range_str(self.mod_pos,
                                                   self.mod_range))

    def range_str(self, pos, range):
        """Return a file range in hunk-header form.

        The length is omitted whenever it is exactly 1 (whatever ``pos``
        is), mirroring how parse_range reads a bare position as a one-line
        range.

        :param pos: The position in the file
        :type pos: int
        :param range: The number of lines in the range
        :type range: int
        :return: ``"pos"`` when range == 1, otherwise ``"pos,range"``
        :rtype: str
        """
        if range == 1:
            return "%i" % pos
        return "%i,%i" % (pos, range)

    def __str__(self):
        return self.get_header() + "".join([str(line) for line in self.lines])

    def shift_to_mod(self, pos):
        """Return how far this hunk moves 0-based original line ``pos``.

        :return: 0 for lines well before the hunk, the hunk's net size change
            for lines well after it, the exact shift for lines within it, or
            None if the hunk removes that line
        """
        if pos < self.orig_pos-1:
            return 0
        elif pos > self.orig_pos+self.orig_range:
            return self.mod_range - self.orig_range
        else:
            return self.shift_to_mod_lines(pos)

    def shift_to_mod_lines(self, pos):
        """Compute the shift for a ``pos`` inside this hunk by walking its
        lines; returns None if ``pos`` is one of the removed lines."""
        assert (pos >= self.orig_pos-1 and pos <= self.orig_pos+self.orig_range)
        position = self.orig_pos-1
        shift = 0
        for line in self.lines:
            if isinstance(line, InsertLine):
                shift += 1
            elif isinstance(line, RemoveLine):
                if position == pos:
                    return None
                shift -= 1
                position += 1
            elif isinstance(line, ContextLine):
                position += 1
            # Lines after pos cannot affect where pos lands.
            if position > pos:
                break
        return shift
218
219
def iter_hunks(iter_lines):
    """Yield the Hunk objects parsed from ``iter_lines``.

    ``iter_lines`` must be an iterator positioned just after the patch's
    ``---``/``+++`` header.  Outside a hunk body, blank lines are skipped and
    every other line must be a hunk header; each hunk's body is read until
    the line counts declared in its header are satisfied.

    :raises MalformedHunkHeader: if a line expected to be a header is not one
    :raises MalformedLine: if a body line has an unknown prefix
    :raises PatchSyntax: if the input ends before a hunk body is complete
    """
    hunk = None
    for line in iter_lines:
        if line == "\n":
            if hunk is not None:
                yield hunk
                hunk = None
            continue
        if hunk is not None:
            yield hunk
        hunk = hunk_from_header(line)
        orig_size = 0
        mod_size = 0
        while orig_size < hunk.orig_range or mod_size < hunk.mod_range:
            try:
                body_line = next(iter_lines)
            except StopIteration:
                # Letting StopIteration escape a generator silently dropped
                # the truncated hunk (and is a RuntimeError under PEP 479).
                raise PatchSyntax("Patch ends inside hunk %r"
                                  % hunk.get_header())
            hunk_line = parse_line(body_line)
            hunk.lines.append(hunk_line)
            if isinstance(hunk_line, (RemoveLine, ContextLine)):
                orig_size += 1
            if isinstance(hunk_line, (InsertLine, ContextLine)):
                mod_size += 1
    if hunk is not None:
        yield hunk
241
242
class Patch:
    """A parsed single-file patch: the two file names and their hunks."""

    def __init__(self, oldname, newname):
        self.oldname = oldname
        self.newname = newname
        self.hunks = []

    def __str__(self):
        return self.get_header() + "".join([str(h) for h in self.hunks])

    def get_header(self):
        """Return the ``---``/``+++`` header lines, newline-terminated."""
        return "--- %s\n+++ %s\n" % (self.oldname, self.newname)

    def stats_str(self):
        """Return a string of patch statistics"""
        removes = 0
        inserts = 0
        for hunk in self.hunks:
            for line in hunk.lines:
                if isinstance(line, InsertLine):
                    inserts += 1
                elif isinstance(line, RemoveLine):
                    removes += 1
        return "%i inserts, %i removes in %i hunks" % \
            (inserts, removes, len(self.hunks))

    def pos_in_mod(self, position):
        """Map 0-based original line ``position`` into the modified file.

        :return: the corresponding 0-based line in the modified file, or
            None if some hunk removes that line
        """
        newpos = position
        for hunk in self.hunks:
            shift = hunk.shift_to_mod(position)
            if shift is None:
                return None
            newpos += shift
        return newpos

    def iter_inserted(self):
        """Iterate through inserted lines.

        :return: Pairs of (0-based line number in the modified file, line)
        :rtype: iterator of (int, InsertLine)
        """
        for hunk in self.hunks:
            pos = hunk.mod_pos - 1
            for line in hunk.lines:
                if isinstance(line, InsertLine):
                    yield (pos, line)
                    pos += 1
                elif isinstance(line, ContextLine):
                    pos += 1
292
293
def parse_patch(iter_lines):
    """Parse a single-file patch into a Patch object.

    :param iter_lines: iterator over the patch lines, starting at ``---``
    :return: a Patch holding both file names and every hunk
    """
    names = get_patch_names(iter_lines)
    patch = Patch(*names)
    patch.hunks.extend(iter_hunks(iter_lines))
    return patch
299
300
301
def iter_file_patch(iter_lines):
    """Split the lines of a multi-file patch into one list per file.

    Lines starting with ``=== `` are dropped.  Each ``--- `` line begins a
    new chunk; any lines seen before the first one form their own chunk.
    Every yielded chunk is a fresh, non-empty list.
    """
    current = []
    for line in iter_lines:
        if line.startswith('=== '):
            continue
        if line.startswith('--- ') and current:
            yield current
            current = []
        current.append(line)
    if current:
        yield current
313
314
315
def iter_lines_handle_nl(iter_lines):
    """Iterate through lines, folding in "no newline at end of file" markers.

    When a NO_NL marker line is seen, the line before it is produced without
    its terminating newline and the marker itself is dropped.  This
    transformation may be applied at any point up until hunk line parsing,
    and is safe to apply repeatedly (no markers survive the first pass).

    :param iter_lines: iterable of patch lines
    :raises MalformedLine: if a marker is not preceded by a
        newline-terminated line
    """
    last_line = None
    for line in iter_lines:
        if line == NO_NL:
            # Explicit check rather than an assert: asserts vanish under -O,
            # and a leading or doubled marker used to hit None and crash
            # with AttributeError.
            if last_line is None or not last_line.endswith('\n'):
                raise MalformedLine("No-newline marker without a preceding"
                                    " newline-terminated line", line)
            last_line = last_line[:-1]
            line = None
        if last_line is not None:
            yield last_line
        last_line = line
    if last_line is not None:
        yield last_line
333
334
335
def parse_patches(iter_lines):
    """Parse a (possibly multi-file) patch into a list of Patch objects.

    No-newline markers are folded in first, then the lines are split per
    file and each file's chunk is parsed in turn.
    """
    patches = []
    for file_lines in iter_file_patch(iter_lines_handle_nl(iter_lines)):
        patches.append(parse_patch(iter(file_lines)))
    return patches
338
339
340
def difference_index(atext, btext):
    """Find the index of the first character that differs between two texts.

    Only the overlapping prefix is compared, so a text that is merely a
    prefix of the other counts as having no difference.

    :param atext: The first text
    :type atext: str
    :param btext: The second text
    :type btext: str
    :return: The index, or None if there are no differences within the range
    :rtype: int or NoneType
    """
    for index, (a_char, b_char) in enumerate(zip(atext, btext)):
        if a_char != b_char:
            return index
    return None
357
358
class PatchConflict(Exception):
    """The original text does not match what a patch expects at a line.

    :param line_no: line number in the original where the mismatch occurred
    :param orig_line: the original's text at that line
    :param patch_line: the patch text expected there (anything str()-able)
    """

    def __init__(self, line_no, orig_line, patch_line):
        template = ('Text contents mismatch at line %d.  Original has "%s",'
                    ' but patch says it should be "%s"')
        Exception.__init__(self, template % (line_no,
                                             orig_line.rstrip('\n'),
                                             str(patch_line).rstrip('\n')))
365
366
367
def iter_patched(orig_lines, patch_lines):
    """Iterate through a series of lines with a patch applied.

    This handles a single file, and does exact, not fuzzy patching.

    :param orig_lines: iterable of the original file's lines, or None to
        treat the original as empty
    :param patch_lines: iterable of the single-file patch's lines, starting
        with its ``---``/``+++`` header
    :return: iterator over the lines of the patched file
    :raises PatchConflict: if a context or removed line does not match the
        original, or the original ends before the patch says it should
    """
    if orig_lines is None:
        # Treat a missing original as empty; the final copy loop below used
        # to iterate None and fail with TypeError.
        orig_lines = iter([])
    else:
        orig_lines = iter(orig_lines)
    seen_patch = []
    patch_lines = iter_lines_handle_nl(iter(patch_lines))
    get_patch_names(patch_lines)
    line_no = 1
    for hunk in iter_hunks(patch_lines):
        # Copy the unchanged original lines that precede this hunk.
        while line_no < hunk.orig_pos:
            orig_line = next(orig_lines, None)
            if orig_line is None:
                # Original too short; a bare StopIteration here would end the
                # generator silently (or be a RuntimeError under PEP 479).
                raise PatchConflict(line_no, "", "".join(seen_patch))
            yield orig_line
            line_no += 1
        for hunk_line in hunk.lines:
            seen_patch.append(str(hunk_line))
            if isinstance(hunk_line, InsertLine):
                yield hunk_line.contents
            elif isinstance(hunk_line, (ContextLine, RemoveLine)):
                orig_line = next(orig_lines, None)
                # None (original exhausted) never equals the expected text.
                if orig_line != hunk_line.contents:
                    raise PatchConflict(line_no, orig_line or "",
                                        "".join(seen_patch))
                if isinstance(hunk_line, ContextLine):
                    yield orig_line
                else:
                    assert isinstance(hunk_line, RemoveLine)
                line_no += 1
    # Everything after the last hunk is unchanged.
    for line in orig_lines:
        yield line
397
                    
398
import unittest
399
import os.path
400
class PatchesTester(unittest.TestCase):
    """Self-tests for the patch parsing and patch-applying code above."""

    def datafile(self, filename):
        """Return an iterator over the lines of ``testdata/<filename>``.

        The file lives in a ``testdata`` directory next to this module.  Its
        lines are read eagerly so the handle is closed before returning
        instead of being leaked.
        """
        data_path = os.path.join(os.path.dirname(__file__), "testdata",
                                 filename)
        with open(data_path, "rb") as data_file:
            return iter(data_file.readlines())

    def testValidPatchHeader(self):
        """Parse a valid patch header"""
        lines = "--- orig/commands.py\n+++ mod/dommands.py\n".split('\n')
        (orig, mod) = get_patch_names(iter(lines))
        self.assertEqual(orig, "orig/commands.py")
        self.assertEqual(mod, "mod/dommands.py")

    def testInvalidPatchHeader(self):
        """Parse an invalid patch header"""
        lines = "-- orig/commands.py\n+++ mod/dommands.py".split('\n')
        self.assertRaises(MalformedPatchHeader, get_patch_names,
                          iter(lines))

    def testValidHunkHeader(self):
        """Parse a valid hunk header"""
        header = "@@ -34,11 +50,6 @@\n"
        hunk = hunk_from_header(header)
        self.assertEqual(hunk.orig_pos, 34)
        self.assertEqual(hunk.orig_range, 11)
        self.assertEqual(hunk.mod_pos, 50)
        self.assertEqual(hunk.mod_range, 6)
        self.assertEqual(str(hunk), header)

    def testValidHunkHeader2(self):
        """Parse a tricky, valid hunk header"""
        header = "@@ -1 +0,0 @@\n"
        hunk = hunk_from_header(header)
        self.assertEqual(hunk.orig_pos, 1)
        self.assertEqual(hunk.orig_range, 1)
        self.assertEqual(hunk.mod_pos, 0)
        self.assertEqual(hunk.mod_range, 0)
        self.assertEqual(str(hunk), header)

    def makeMalformed(self, header):
        """Assert that ``header`` is rejected as a hunk header."""
        self.assertRaises(MalformedHunkHeader, hunk_from_header, header)

    def testInvalidHeader(self):
        """Parse an invalid hunk header"""
        self.makeMalformed(" -34,11 +50,6 \n")
        self.makeMalformed("@@ +50,6 -34,11 @@\n")
        self.makeMalformed("@@ -34,11 +50,6 @@")
        self.makeMalformed("@@ -34.5,11 +50,6 @@\n")
        self.makeMalformed("@@-34,11 +50,6@@\n")
        self.makeMalformed("@@ 34,11 50,6 @@\n")
        self.makeMalformed("@@ -34,11 @@\n")
        self.makeMalformed("@@ -34,11 +50,6.5 @@\n")
        self.makeMalformed("@@ -34,11 +50,-6 @@\n")

    def lineThing(self, text, line_type):
        """Assert ``text`` parses to a ``line_type`` and prints back as-is."""
        line = parse_line(text)
        self.assertTrue(isinstance(line, line_type))
        self.assertEqual(str(line), text)

    def makeMalformedLine(self, text):
        """Assert that ``text`` is rejected as a hunk body line."""
        self.assertRaises(MalformedLine, parse_line, text)

    def testValidLine(self):
        """Parse a valid hunk line"""
        self.lineThing(" hello\n", ContextLine)
        self.lineThing("+hello\n", InsertLine)
        self.lineThing("-hello\n", RemoveLine)

    def testMalformedLine(self):
        """Parse invalid hunk lines"""
        self.makeMalformedLine("hello\n")

    def compare_parsed(self, patchtext):
        """Assert that parsing ``patchtext`` and printing it is lossless."""
        patch = parse_patch(iter(patchtext.splitlines(True)))
        pstr = str(patch)
        i = difference_index(patchtext, pstr)
        msg = None
        if i is not None:
            msg = "%i: %r != %r" % (i, patchtext[i], pstr[i])
        self.assertEqual(patchtext, pstr, msg)

    def testAll(self):
        """Test parsing a whole patch"""
        patchtext = """--- orig/commands.py
+++ mod/commands.py
@@ -1337,7 +1337,8 @@
 
     def set_title(self, command=None):
         try:
-            version = self.tree.tree_version.nonarch
+            version = pylon.alias_or_version(self.tree.tree_version, self.tree,
+                                             full=False)
         except:
             version = "[no version]"
         if command is None:
@@ -1983,7 +1984,11 @@
                                          version)
         if len(new_merges) > 0:
             if cmdutil.prompt("Log for merge"):
-                mergestuff = cmdutil.log_for_merge(tree, comp_version)
+                if cmdutil.prompt("changelog for merge"):
+                    mergestuff = "Patches applied:\\n"
+                    mergestuff += pylon.changelog_for_merge(new_merges)
+                else:
+                    mergestuff = cmdutil.log_for_merge(tree, comp_version)
                 log.description += mergestuff
         log.save()
     try:
"""
        self.compare_parsed(patchtext)

    def testInit(self):
        """Handle patches missing half the position, range tuple"""
        patchtext = \
"""--- orig/__init__.py
+++ mod/__init__.py
@@ -1 +1,2 @@
 __docformat__ = "restructuredtext en"
+__doc__ = An alternate Arch commandline interface
"""
        self.compare_parsed(patchtext)

    def testLineLookup(self):
        """Make sure we can accurately look up mod line from orig"""
        patch = parse_patch(self.datafile("diff"))
        orig = list(self.datafile("orig"))
        mod = list(self.datafile("mod"))
        removals = []
        for i, orig_line in enumerate(orig):
            mod_pos = patch.pos_in_mod(i)
            if mod_pos is None:
                removals.append(orig_line)
                continue
            self.assertEqual(mod[mod_pos], orig_line)
        # Every line reported as removed must match the patch's removals,
        # in order, with nothing left over.
        rem_iter = iter(removals)
        for hunk in patch.hunks:
            for line in hunk.lines:
                if isinstance(line, RemoveLine):
                    removed = next(rem_iter)
                    self.assertEqual(line.contents, removed)
        self.assertRaises(StopIteration, next, rem_iter)

    def testFirstLineRenumber(self):
        """Make sure we handle lines at the beginning of the hunk"""
        patch = parse_patch(self.datafile("insert_top.patch"))
        self.assertEqual(patch.pos_in_mod(0), 1)
553
def test():
    """Run the PatchesTester suite quietly and return the unittest result.

    Uses TestLoader instead of unittest.makeSuite, which is deprecated and
    was removed in Python 3.13; the default "test" method prefix matches.
    """
    patchesTestSuite = unittest.TestLoader().loadTestsFromTestCase(
        PatchesTester)
    runner = unittest.TextTestRunner(verbosity=0)
    return runner.run(patchesTestSuite)
557
    
558
559
if __name__ == "__main__":
    # Executed directly as a script: run the module's self-tests.
    test()
561
# arch-tag: d1541a25-eac5-4de9-a476-08a7cecd5683