~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/tests/test_patches_data/mod-3

Restore test_patches_data

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
First line change
 
2
# Copyright (C) 2004, 2005 Aaron Bentley
 
3
# <aaron.bentley@utoronto.ca>
 
4
#
 
5
#    This program is free software; you can redistribute it and/or modify
 
6
#    it under the terms of the GNU General Public License as published by
 
7
#    the Free Software Foundation; either version 2 of the License, or
 
8
#    (at your option) any later version.
 
9
#
 
10
#    This program is distributed in the hope that it will be useful,
 
11
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
 
12
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
13
#    GNU General Public License for more details.
 
14
#
 
15
#    You should have received a copy of the GNU General Public License
 
16
#    along with this program; if not, write to the Free Software
 
17
#    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 
18
 
 
19
class PatchSyntax(Exception):
 
20
    def __init__(self, msg):
 
21
        Exception.__init__(self, msg)
 
22
 
 
23
 
 
24
class MalformedPatchHeader(PatchSyntax):
 
25
    def __init__(self, desc, line):
 
26
        self.desc = desc
 
27
        self.line = line
 
28
        msg = "Malformed patch header.  %s\n%r" % (self.desc, self.line)
 
29
        PatchSyntax.__init__(self, msg)
 
30
 
 
31
class MalformedHunkHeader(PatchSyntax):
 
32
    def __init__(self, desc, line):
 
33
        self.desc = desc
 
34
        self.line = line
 
35
        msg = "Malformed hunk header.  %s\n%r" % (self.desc, self.line)
 
36
        PatchSyntax.__init__(self, msg)
 
37
 
 
38
class MalformedLine(PatchSyntax):
 
39
    def __init__(self, desc, line):
 
40
        self.desc = desc
 
41
        self.line = line
 
42
        msg = "Malformed line.  %s\n%s" % (self.desc, self.line)
 
43
        PatchSyntax.__init__(self, msg)
 
44
 
 
45
def get_patch_names(iter_lines):
 
46
    try:
 
47
        line = iter_lines.next()
 
48
        if not line.startswith("--- "):
 
49
            raise MalformedPatchHeader("No orig name", line)
 
50
        else:
 
51
            orig_name = line[4:].rstrip("\n")
 
52
    except StopIteration:
 
53
        raise MalformedPatchHeader("No orig line", "")
 
54
    try:
 
55
        line = iter_lines.next()
 
56
        if not line.startswith("+++ "):
 
57
            raise PatchSyntax("No mod name")
 
58
        else:
 
59
            mod_name = line[4:].rstrip("\n")
 
60
    except StopIteration:
 
61
        raise MalformedPatchHeader("No mod line", "")
 
62
    return (orig_name, mod_name)
 
63
 
 
64
def parse_range(textrange):
 
65
    """Parse a patch range, handling the "1" special-case
 
66
 
 
67
    :param textrange: The text to parse
 
68
    :type textrange: str
 
69
    :return: the position and range, as a tuple
 
70
    :rtype: (int, int)
 
71
    """
 
72
    tmp = textrange.split(',')
 
73
    if len(tmp) == 1:
 
74
        pos = tmp[0]
 
75
        range = "1"
 
76
    else:
 
77
        (pos, range) = tmp
 
78
    pos = int(pos)
 
79
    range = int(range)
 
80
    return (pos, range)
 
81
 
 
82
 
 
83
def hunk_from_header(line):
 
84
    if not line.startswith("@@") or not line.endswith("@@\n") \
 
85
        or not len(line) > 4:
 
86
        raise MalformedHunkHeader("Does not start and end with @@.", line)
 
87
    try:
 
88
        (orig, mod) = line[3:-4].split(" ")
 
89
    except Exception, e:
 
90
        raise MalformedHunkHeader(str(e), line)
 
91
    if not orig.startswith('-') or not mod.startswith('+'):
 
92
        raise MalformedHunkHeader("Positions don't start with + or -.", line)
 
93
    try:
 
94
        (orig_pos, orig_range) = parse_range(orig[1:])
 
95
        (mod_pos, mod_range) = parse_range(mod[1:])
 
96
    except Exception, e:
 
97
        raise MalformedHunkHeader(str(e), line)
 
98
    if mod_range < 0 or orig_range < 0:
 
99
        raise MalformedHunkHeader("Hunk range is negative", line)
 
100
    return Hunk(orig_pos, orig_range, mod_pos, mod_range)
 
101
 
 
102
 
 
103
class HunkLine:
 
104
    def __init__(self, contents):
 
105
        self.contents = contents
 
106
 
 
107
    def get_str(self, leadchar):
 
108
        if self.contents == "\n" and leadchar == " " and False:
 
109
            return "\n"
 
110
        if not self.contents.endswith('\n'):
 
111
            terminator = '\n' + NO_NL
 
112
        else:
 
113
            terminator = ''
 
114
        return leadchar + self.contents + terminator
 
115
 
 
116
 
 
117
class ContextLine(HunkLine):
 
118
    def __init__(self, contents):
 
119
        HunkLine.__init__(self, contents)
 
120
 
 
121
    def __str__(self):
 
122
        return self.get_str(" ")
 
123
 
 
124
 
 
125
class InsertLine(HunkLine):
 
126
    def __init__(self, contents):
 
127
        HunkLine.__init__(self, contents)
 
128
 
 
129
    def __str__(self):
 
130
        return self.get_str("+")
 
131
 
 
132
 
 
133
class RemoveLine(HunkLine):
 
134
    def __init__(self, contents):
 
135
        HunkLine.__init__(self, contents)
 
136
 
 
137
    def __str__(self):
 
138
        return self.get_str("-")
 
139
 
 
140
NO_NL = '\\ No newline at end of file\n'
 
141
__pychecker__="no-returnvalues"
 
142
 
 
143
def parse_line(line):
 
144
    if line.startswith("\n"):
 
145
        return ContextLine(line)
 
146
    elif line.startswith(" "):
 
147
        return ContextLine(line[1:])
 
148
    elif line.startswith("+"):
 
149
        return InsertLine(line[1:])
 
150
    elif line.startswith("-"):
 
151
        return RemoveLine(line[1:])
 
152
    elif line == NO_NL:
 
153
        return NO_NL
 
154
    else:
 
155
        raise MalformedLine("Unknown line type", line)
 
156
__pychecker__=""
 
157
 
 
158
 
 
159
class Hunk:
 
160
    def __init__(self, orig_pos, orig_range, mod_pos, mod_range):
 
161
        self.orig_pos = orig_pos
 
162
        self.orig_range = orig_range
 
163
        self.mod_pos = mod_pos
 
164
        self.mod_range = mod_range
 
165
        self.lines = []
 
166
 
 
167
    def get_header(self):
 
168
        return "@@ -%s +%s @@\n" % (self.range_str(self.orig_pos, 
 
169
                                                   self.orig_range),
 
170
                                    self.range_str(self.mod_pos, 
 
171
                                                   self.mod_range))
 
172
 
 
173
    def range_str(self, pos, range):
 
174
        """Return a file range, special-casing for 1-line files.
 
175
 
 
176
        :param pos: The position in the file
 
177
        :type pos: int
 
178
        :range: The range in the file
 
179
        :type range: int
 
180
        :return: a string in the format 1,4 except when range == pos == 1
 
181
        """
 
182
        if range == 1:
 
183
            return "%i" % pos
 
184
        else:
 
185
            return "%i,%i" % (pos, range)
 
186
 
 
187
    def __str__(self):
 
188
        lines = [self.get_header()]
 
189
        for line in self.lines:
 
190
            lines.append(str(line))
 
191
        return "".join(lines)
 
192
 
 
193
    def shift_to_mod(self, pos):
 
194
        if pos < self.orig_pos-1:
 
195
            return 0
 
196
        elif pos > self.orig_pos+self.orig_range:
 
197
            return self.mod_range - self.orig_range
 
198
        else:
 
199
            return self.shift_to_mod_lines(pos)
 
200
 
 
201
    def shift_to_mod_lines(self, pos):
 
202
        assert (pos >= self.orig_pos-1 and pos <= self.orig_pos+self.orig_range)
 
203
        position = self.orig_pos-1
 
204
        shift = 0
 
205
        for line in self.lines:
 
206
            if isinstance(line, InsertLine):
 
207
                shift += 1
 
208
            elif isinstance(line, RemoveLine):
 
209
                if position == pos:
 
210
                    return None
 
211
                shift -= 1
 
212
                position += 1
 
213
            elif isinstance(line, ContextLine):
 
214
                position += 1
 
215
            if position > pos:
 
216
                break
 
217
        return shift
 
218
 
 
219
def iter_hunks(iter_lines):
 
220
    hunk = None
 
221
    for line in iter_lines:
 
222
        if line == "\n":
 
223
            if hunk is not None:
 
224
                yield hunk
 
225
                hunk = None
 
226
            continue
 
227
        if hunk is not None:
 
228
            yield hunk
 
229
        hunk = hunk_from_header(line)
 
230
        orig_size = 0
 
231
        mod_size = 0
 
232
        while orig_size < hunk.orig_range or mod_size < hunk.mod_range:
 
233
            hunk_line = parse_line(iter_lines.next())
 
234
            hunk.lines.append(hunk_line)
 
235
            if isinstance(hunk_line, (RemoveLine, ContextLine)):
 
236
                orig_size += 1
 
237
            if isinstance(hunk_line, (InsertLine, ContextLine)):
 
238
                mod_size += 1
 
239
    if hunk is not None:
 
240
        yield hunk
 
241
 
 
242
class Patch:
 
243
    def __init__(self, oldname, newname):
 
244
        self.oldname = oldname
 
245
        self.newname = newname
 
246
        self.hunks = []
 
247
 
 
248
    def __str__(self):
 
249
        ret = self.get_header() 
 
250
        ret += "".join([str(h) for h in self.hunks])
 
251
        return ret
 
252
 
 
253
    def get_header(self):
 
254
        return "--- %s\n+++ %s\n" % (self.oldname, self.newname)
 
255
 
 
256
    def stats_str(self):
 
257
        """Return a string of patch statistics"""
 
258
        removes = 0
 
259
        inserts = 0
 
260
        for hunk in self.hunks:
 
261
            for line in hunk.lines:
 
262
                if isinstance(line, InsertLine):
 
263
                     inserts+=1;
 
264
                elif isinstance(line, RemoveLine):
 
265
                     removes+=1;
 
266
        return "%i inserts, %i removes in %i hunks" % \
 
267
            (inserts, removes, len(self.hunks))
 
268
 
 
269
    def pos_in_mod(self, position):
 
270
        newpos = position
 
271
        for hunk in self.hunks:
 
272
            shift = hunk.shift_to_mod(position)
 
273
            if shift is None:
 
274
                return None
 
275
            newpos += shift
 
276
        return newpos
 
277
            
 
278
    def iter_inserted(self):
 
279
        """Iteraties through inserted lines
 
280
        
 
281
        :return: Pair of line number, line
 
282
        :rtype: iterator of (int, InsertLine)
 
283
        """
 
284
        for hunk in self.hunks:
 
285
            pos = hunk.mod_pos - 1;
 
286
            for line in hunk.lines:
 
287
                if isinstance(line, InsertLine):
 
288
                    yield (pos, line)
 
289
                    pos += 1
 
290
                if isinstance(line, ContextLine):
 
291
                    pos += 1
 
292
 
 
293
def parse_patch(iter_lines):
 
294
    (orig_name, mod_name) = get_patch_names(iter_lines)
 
295
    patch = Patch(orig_name, mod_name)
 
296
    for hunk in iter_hunks(iter_lines):
 
297
        patch.hunks.append(hunk)
 
298
    return patch
 
299
 
 
300
 
 
301
def iter_file_patch(iter_lines):
 
302
    saved_lines = []
 
303
    for line in iter_lines:
 
304
        if line.startswith('=== '):
 
305
            continue
 
306
        elif line.startswith('--- '):
 
307
            if len(saved_lines) > 0:
 
308
                yield saved_lines
 
309
            saved_lines = []
 
310
        saved_lines.append(line)
 
311
    if len(saved_lines) > 0:
 
312
        yield saved_lines
 
313
 
 
314
 
 
315
def iter_lines_handle_nl(iter_lines):
 
316
    """
 
317
    Iterates through lines, ensuring that lines that originally had no
 
318
    terminating \n are produced without one.  This transformation may be
 
319
    applied at any point up until hunk line parsing, and is safe to apply
 
320
    repeatedly.
 
321
    """
 
322
    last_line = None
 
323
    for line in iter_lines:
 
324
        if line == NO_NL:
 
325
            assert last_line.endswith('\n')
 
326
            last_line = last_line[:-1]
 
327
            line = None
 
328
        if last_line is not None:
 
329
            yield last_line
 
330
        last_line = line
 
331
    if last_line is not None:
 
332
        yield last_line
 
333
 
 
334
 
 
335
def parse_patches(iter_lines):
 
336
    iter_lines = iter_lines_handle_nl(iter_lines)
 
337
    return [parse_patch(f.__iter__()) for f in iter_file_patch(iter_lines)]
 
338
 
 
339
 
 
340
def difference_index(atext, btext):
 
341
    """Find the indext of the first character that differs betweeen two texts
 
342
 
 
343
    :param atext: The first text
 
344
    :type atext: str
 
345
    :param btext: The second text
 
346
    :type str: str
 
347
    :return: The index, or None if there are no differences within the range
 
348
    :rtype: int or NoneType
 
349
    """
 
350
    length = len(atext)
 
351
    if len(btext) < length:
 
352
        length = len(btext)
 
353
    for i in range(length):
 
354
        if atext[i] != btext[i]:
 
355
            return i;
 
356
    return None
 
357
 
 
358
class PatchConflict(Exception):
 
359
    def __init__(self, line_no, orig_line, patch_line):
 
360
        orig = orig_line.rstrip('\n')
 
361
        patch = str(patch_line).rstrip('\n')
 
362
        msg = 'Text contents mismatch at line %d.  Original has "%s",'\
 
363
            ' but patch says it should be "%s"' % (line_no, orig, patch)
 
364
        Exception.__init__(self, msg)
 
365
 
 
366
 
 
367
def iter_patched(orig_lines, patch_lines):
 
368
    """Iterate through a series of lines with a patch applied.
 
369
    This handles a single file, and does exact, not fuzzy patching.
 
370
    """
 
371
    if orig_lines is not None:
 
372
        orig_lines = orig_lines.__iter__()
 
373
    seen_patch = []
 
374
    patch_lines = iter_lines_handle_nl(patch_lines.__iter__())
 
375
    get_patch_names(patch_lines)
 
376
    line_no = 1
 
377
    for hunk in iter_hunks(patch_lines):
 
378
        while line_no < hunk.orig_pos:
 
379
            orig_line = orig_lines.next()
 
380
            yield orig_line
 
381
            line_no += 1
 
382
        for hunk_line in hunk.lines:
 
383
            seen_patch.append(str(hunk_line))
 
384
            if isinstance(hunk_line, InsertLine):
 
385
                yield hunk_line.contents
 
386
            elif isinstance(hunk_line, (ContextLine, RemoveLine)):
 
387
                orig_line = orig_lines.next()
 
388
                if orig_line != hunk_line.contents:
 
389
                    raise PatchConflict(line_no, orig_line, "".join(seen_patch))
 
390
                if isinstance(hunk_line, ContextLine):
 
391
                    yield orig_line
 
392
                else:
 
393
                    assert isinstance(hunk_line, RemoveLine)
 
394
                line_no += 1
 
395
    for line in orig_lines:
 
396
        yield line
 
397
                    
 
398
import unittest
 
399
import os.path
 
400
class PatchesTester(unittest.TestCase):
 
401
    def datafile(self, filename):
 
402
        data_path = os.path.join(os.path.dirname(__file__), "testdata", 
 
403
                                 filename)
 
404
        return file(data_path, "rb")
 
405
 
 
406
    def testValidPatchHeader(self):
 
407
        """Parse a valid patch header"""
 
408
        lines = "--- orig/commands.py\n+++ mod/dommands.py\n".split('\n')
 
409
        (orig, mod) = get_patch_names(lines.__iter__())
 
410
        assert(orig == "orig/commands.py")
 
411
        assert(mod == "mod/dommands.py")
 
412
 
 
413
    def testInvalidPatchHeader(self):
 
414
        """Parse an invalid patch header"""
 
415
        lines = "-- orig/commands.py\n+++ mod/dommands.py".split('\n')
 
416
        self.assertRaises(MalformedPatchHeader, get_patch_names,
 
417
                          lines.__iter__())
 
418
 
 
419
    def testValidHunkHeader(self):
 
420
        """Parse a valid hunk header"""
 
421
        header = "@@ -34,11 +50,6 @@\n"
 
422
        hunk = hunk_from_header(header);
 
423
        assert (hunk.orig_pos == 34)
 
424
        assert (hunk.orig_range == 11)
 
425
        assert (hunk.mod_pos == 50)
 
426
        assert (hunk.mod_range == 6)
 
427
        assert (str(hunk) == header)
 
428
 
 
429
    def testValidHunkHeader2(self):
 
430
        """Parse a tricky, valid hunk header"""
 
431
        header = "@@ -1 +0,0 @@\n"
 
432
        hunk = hunk_from_header(header);
 
433
        assert (hunk.orig_pos == 1)
 
434
        assert (hunk.orig_range == 1)
 
435
        assert (hunk.mod_pos == 0)
 
436
        assert (hunk.mod_range == 0)
 
437
        assert (str(hunk) == header)
 
438
 
 
439
    def makeMalformed(self, header):
 
440
        self.assertRaises(MalformedHunkHeader, hunk_from_header, header)
 
441
 
 
442
    def testInvalidHeader(self):
 
443
        """Parse an invalid hunk header"""
 
444
        self.makeMalformed(" -34,11 +50,6 \n")
 
445
        self.makeMalformed("@@ +50,6 -34,11 @@\n")
 
446
        self.makeMalformed("@@ -34,11 +50,6 @@")
 
447
        self.makeMalformed("@@ -34.5,11 +50,6 @@\n")
 
448
        self.makeMalformed("@@-34,11 +50,6@@\n")
 
449
        self.makeMalformed("@@ 34,11 50,6 @@\n")
 
450
        self.makeMalformed("@@ -34,11 @@\n")
 
451
        self.makeMalformed("@@ -34,11 +50,6.5 @@\n")
 
452
        self.makeMalformed("@@ -34,11 +50,-6 @@\n")
 
453
 
 
454
    def lineThing(self,text, type):
 
455
        line = parse_line(text)
 
456
        assert(isinstance(line, type))
 
457
        assert(str(line)==text)
 
458
 
 
459
    def makeMalformedLine(self, text):
 
460
        self.assertRaises(MalformedLine, parse_line, text)
 
461
 
 
462
    def testValidLine(self):
 
463
        """Parse a valid hunk line"""
 
464
        self.lineThing(" hello\n", ContextLine)
 
465
        self.lineThing("+hello\n", InsertLine)
 
466
        self.lineThing("-hello\n", RemoveLine)
 
467
    
 
468
    def testMalformedLine(self):
 
469
        """Parse invalid valid hunk lines"""
 
470
        self.makeMalformedLine("hello\n")
 
471
    
 
472
    def compare_parsed(self, patchtext):
 
473
        lines = patchtext.splitlines(True)
 
474
        patch = parse_patch(lines.__iter__())
 
475
        pstr = str(patch)
 
476
        i = difference_index(patchtext, pstr)
 
477
        if i is not None:
 
478
            print "%i: \"%s\" != \"%s\"" % (i, patchtext[i], pstr[i])
 
479
        self.assertEqual (patchtext, str(patch))
 
480
 
 
481
    def testAll(self):
 
482
        """Test parsing a whole patch"""
 
483
        patchtext = """--- orig/commands.py
 
484
+++ mod/commands.py
 
485
@@ -1337,7 +1337,8 @@
 
486
 
 
487
     def set_title(self, command=None):
 
488
         try:
 
489
-            version = self.tree.tree_version.nonarch
 
490
+            version = pylon.alias_or_version(self.tree.tree_version, self.tree,
 
491
+                                             full=False)
 
492
         except:
 
493
             version = "[no version]"
 
494
         if command is None:
 
495
@@ -1983,7 +1984,11 @@
 
496
                                          version)
 
497
         if len(new_merges) > 0:
 
498
             if cmdutil.prompt("Log for merge"):
 
499
-                mergestuff = cmdutil.log_for_merge(tree, comp_version)
 
500
+                if cmdutil.prompt("changelog for merge"):
 
501
+                    mergestuff = "Patches applied:\\n"
 
502
+                    mergestuff += pylon.changelog_for_merge(new_merges)
 
503
+                else:
 
504
+                    mergestuff = cmdutil.log_for_merge(tree, comp_version)
 
505
                 log.description += mergestuff
 
506
         log.save()
 
507
     try:
 
508
"""
 
509
        self.compare_parsed(patchtext)
 
510
 
 
511
    def testInit(self):
 
512
        """Handle patches missing half the position, range tuple"""
 
513
        patchtext = \
 
514
"""--- orig/__init__.py
 
515
+++ mod/__init__.py
 
516
@@ -1 +1,2 @@
 
517
 __docformat__ = "restructuredtext en"
 
518
+__doc__ = An alternate Arch commandline interface
 
519
"""
 
520
        self.compare_parsed(patchtext)
 
521
        
 
522
 
 
523
 
 
524
    def testLineLookup(self):
 
525
        import sys
 
526
        """Make sure we can accurately look up mod line from orig"""
 
527
        patch = parse_patch(self.datafile("diff"))
 
528
        orig = list(self.datafile("orig"))
 
529
        mod = list(self.datafile("mod"))
 
530
        removals = []
 
531
        for i in range(len(orig)):
 
532
            mod_pos = patch.pos_in_mod(i)
 
533
            if mod_pos is None:
 
534
                removals.append(orig[i])
 
535
                continue
 
536
            assert(mod[mod_pos]==orig[i])
 
537
        rem_iter = removals.__iter__()
 
538
        for hunk in patch.hunks:
 
539
            for line in hunk.lines:
 
540
                if isinstance(line, RemoveLine):
 
541
                    next = rem_iter.next()
 
542
                    if line.contents != next:
 
543
                        sys.stdout.write(" orig:%spatch:%s" % (next,
 
544
                                         line.contents))
 
545
                    assert(line.contents == next)
 
546
        self.assertRaises(StopIteration, rem_iter.next)
 
547
 
 
548
    def testFirstLineRenumber(self):
 
549
        """Make sure we handle lines at the beginning of the hunk"""
 
550
        patch = parse_patch(self.datafile("insert_top.patch"))
 
551
        assert (patch.pos_in_mod(0)==1)
 
552
 
 
553
def test():
 
554
    patchesTestSuite = unittest.makeSuite(PatchesTester,'test')
 
555
    runner = unittest.TextTestRunner(verbosity=0)
 
556
    return runner.run(patchesTestSuite)
 
557
    
 
558
 
 
559
if __name__ == "__main__":
 
560
    test()
 
561
# arch-tag: d1541a25-eac5-4de9-a476-08a7cecd5683