~abentley/bzrtools/bzrtools.dev

« back to all changes in this revision

Viewing changes to fai/pylon/patches.py

  • Committer: Robert Collins
  • Date: 2005-09-14 11:27:20 UTC
  • mto: (147.2.6) (364.1.3 bzrtools)
  • mto: This revision was merged to the branch mainline in revision 324.
  • Revision ID: robertc@robertcollins.net-20050914112720-c66a21de86eafa6e
trim fai cribbage

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
import util
2
 
import sys
3
 
class PatchSyntax(Exception):
4
 
    def __init__(self, msg):
5
 
        Exception.__init__(self, msg)
6
 
 
7
 
 
8
 
class MalformedPatchHeader(PatchSyntax):
9
 
    def __init__(self, desc, line):
10
 
        self.desc = desc
11
 
        self.line = line
12
 
        msg = "Malformed patch header.  %s\n%s" % (self.desc, self.line)
13
 
        PatchSyntax.__init__(self, msg)
14
 
 
15
 
class MalformedHunkHeader(PatchSyntax):
16
 
    def __init__(self, desc, line):
17
 
        self.desc = desc
18
 
        self.line = line
19
 
        msg = "Malformed hunk header.  %s\n%s" % (self.desc, self.line)
20
 
        PatchSyntax.__init__(self, msg)
21
 
 
22
 
class MalformedLine(PatchSyntax):
23
 
    def __init__(self, desc, line):
24
 
        self.desc = desc
25
 
        self.line = line
26
 
        msg = "Malformed line.  %s\n%s" % (self.desc, self.line)
27
 
        PatchSyntax.__init__(self, msg)
28
 
 
29
 
def get_patch_names(iter_lines):
30
 
    try:
31
 
        line = iter_lines.next()
32
 
        if not line.startswith("--- "):
33
 
            raise MalformedPatchHeader("No orig name", line)
34
 
        else:
35
 
            orig_name = line[4:].rstrip("\n")
36
 
    except StopIteration:
37
 
        raise MalformedPatchHeader("No orig line", "")
38
 
    try:
39
 
        line = iter_lines.next()
40
 
        if not line.startswith("+++ "):
41
 
            raise PatchSyntax("No mod name")
42
 
        else:
43
 
            mod_name = line[4:].rstrip("\n")
44
 
    except StopIteration:
45
 
        raise MalformedPatchHeader("No mod line", "")
46
 
    return (orig_name, mod_name)
47
 
 
48
 
def parse_range(textrange):
49
 
    """Parse a patch range, handling the "1" special-case
50
 
 
51
 
    :param textrange: The text to parse
52
 
    :type textrange: str
53
 
    :return: the position and range, as a tuple
54
 
    :rtype: (int, int)
55
 
    """
56
 
    tmp = textrange.split(',')
57
 
    if len(tmp) == 1:
58
 
        pos = tmp[0]
59
 
        range = "1"
60
 
    else:
61
 
        (pos, range) = tmp
62
 
    pos = int(pos)
63
 
    range = int(range)
64
 
    return (pos, range)
65
 
 
66
 
 
67
 
def hunk_from_header(line):
68
 
    if not line.startswith("@@") or not line.endswith("@@\n") \
69
 
        or not len(line) > 4:
70
 
        raise MalformedHunkHeader("Does not start and end with @@.", line)
71
 
    try:
72
 
        (orig, mod) = line[3:-4].split(" ")
73
 
    except Exception, e:
74
 
        raise MalformedHunkHeader(str(e), line)
75
 
    if not orig.startswith('-') or not mod.startswith('+'):
76
 
        raise MalformedHunkHeader("Positions don't start with + or -.", line)
77
 
    try:
78
 
        (orig_pos, orig_range) = parse_range(orig[1:])
79
 
        (mod_pos, mod_range) = parse_range(mod[1:])
80
 
    except Exception, e:
81
 
        raise MalformedHunkHeader(str(e), line)
82
 
    if mod_range < 0 or orig_range < 0:
83
 
        raise MalformedHunkHeader("Hunk range is negative", line)
84
 
    return Hunk(orig_pos, orig_range, mod_pos, mod_range)
85
 
 
86
 
 
87
 
class HunkLine:
88
 
    def __init__(self, contents):
89
 
        self.contents = contents
90
 
 
91
 
    def get_str(self, leadchar):
92
 
        if self.contents == "\n" and leadchar == " " and False:
93
 
            return "\n"
94
 
        return leadchar + self.contents
95
 
 
96
 
class ContextLine(HunkLine):
97
 
    def __init__(self, contents):
98
 
        HunkLine.__init__(self, contents)
99
 
 
100
 
    def __str__(self):
101
 
        return self.get_str(" ")
102
 
 
103
 
 
104
 
class InsertLine(HunkLine):
105
 
    def __init__(self, contents):
106
 
        HunkLine.__init__(self, contents)
107
 
 
108
 
    def __str__(self):
109
 
        return self.get_str("+")
110
 
 
111
 
 
112
 
class RemoveLine(HunkLine):
113
 
    def __init__(self, contents):
114
 
        HunkLine.__init__(self, contents)
115
 
 
116
 
    def __str__(self):
117
 
        return self.get_str("-")
118
 
 
119
 
__pychecker__="no-returnvalues"
120
 
def parse_line(line):
121
 
    if line.startswith("\n"):
122
 
        return ContextLine(line)
123
 
    elif line.startswith(" "):
124
 
        return ContextLine(line[1:])
125
 
    elif line.startswith("+"):
126
 
        return InsertLine(line[1:])
127
 
    elif line.startswith("-"):
128
 
        return RemoveLine(line[1:])
129
 
    else:
130
 
        raise MalformedLine("Unknown line type", line)
131
 
__pychecker__=""
132
 
 
133
 
 
134
 
class Hunk:
135
 
    def __init__(self, orig_pos, orig_range, mod_pos, mod_range):
136
 
        self.orig_pos = orig_pos
137
 
        self.orig_range = orig_range
138
 
        self.mod_pos = mod_pos
139
 
        self.mod_range = mod_range
140
 
        self.lines = []
141
 
 
142
 
    def get_header(self):
143
 
        return "@@ -%s +%s @@\n" % (self.range_str(self.orig_pos, 
144
 
                                                   self.orig_range),
145
 
                                    self.range_str(self.mod_pos, 
146
 
                                                   self.mod_range))
147
 
 
148
 
    def range_str(self, pos, range):
149
 
        """Return a file range, special-casing for 1-line files.
150
 
 
151
 
        :param pos: The position in the file
152
 
        :type pos: int
153
 
        :range: The range in the file
154
 
        :type range: int
155
 
        :return: a string in the format 1,4 except when range == pos == 1
156
 
        """
157
 
        if range == 1:
158
 
            return "%i" % pos
159
 
        else:
160
 
            return "%i,%i" % (pos, range)
161
 
 
162
 
    def __str__(self):
163
 
        lines = [self.get_header()]
164
 
        for line in self.lines:
165
 
            lines.append(str(line))
166
 
        return "".join(lines)
167
 
 
168
 
    def shift_to_mod(self, pos):
169
 
        if pos < self.orig_pos-1:
170
 
            return 0
171
 
        elif pos > self.orig_pos+self.orig_range:
172
 
            return self.mod_range - self.orig_range
173
 
        else:
174
 
            return self.shift_to_mod_lines(pos)
175
 
 
176
 
    def shift_to_mod_lines(self, pos):
177
 
        assert (pos >= self.orig_pos-1 and pos <= self.orig_pos+self.orig_range)
178
 
        position = self.orig_pos-1
179
 
        shift = 0
180
 
        for line in self.lines:
181
 
            if isinstance(line, InsertLine):
182
 
                shift += 1
183
 
            elif isinstance(line, RemoveLine):
184
 
                if position == pos:
185
 
                    return None
186
 
                shift -= 1
187
 
                position += 1
188
 
            elif isinstance(line, ContextLine):
189
 
                position += 1
190
 
            if position > pos:
191
 
                break
192
 
        return shift
193
 
 
194
 
def iter_hunks(iter_lines):
195
 
    hunk = None
196
 
    for line in iter_lines:
197
 
        if line.startswith("@@"):
198
 
            if hunk is not None:
199
 
                yield hunk
200
 
            hunk = hunk_from_header(line)
201
 
        else:
202
 
            hunk.lines.append(parse_line(line))
203
 
 
204
 
    if hunk is not None:
205
 
        yield hunk
206
 
 
207
 
class Patch:
208
 
    def __init__(self, oldname, newname):
209
 
        self.oldname = oldname
210
 
        self.newname = newname
211
 
        self.hunks = []
212
 
 
213
 
    def __str__(self):
214
 
        ret =  "--- %s\n+++ %s\n" % (self.oldname, self.newname) 
215
 
        ret += "".join([str(h) for h in self.hunks])
216
 
        return ret
217
 
 
218
 
    def stats_str(self):
219
 
        """Return a string of patch statistics"""
220
 
        removes = 0
221
 
        inserts = 0
222
 
        for hunk in self.hunks:
223
 
            for line in hunk.lines:
224
 
                if isinstance(line, InsertLine):
225
 
                     inserts+=1;
226
 
                elif isinstance(line, RemoveLine):
227
 
                     removes+=1;
228
 
        return "%i inserts, %i removes in %i hunks" % \
229
 
            (inserts, removes, len(self.hunks))
230
 
 
231
 
    def pos_in_mod(self, position):
232
 
        newpos = position
233
 
        for hunk in self.hunks:
234
 
            shift = hunk.shift_to_mod(position)
235
 
            if shift is None:
236
 
                return None
237
 
            newpos += shift
238
 
        return newpos
239
 
            
240
 
    def iter_inserted(self):
241
 
        """Iteraties through inserted lines
242
 
        
243
 
        :return: Pair of line number, line
244
 
        :rtype: iterator of (int, InsertLine)
245
 
        """
246
 
        for hunk in self.hunks:
247
 
            pos = hunk.mod_pos - 1;
248
 
            for line in hunk.lines:
249
 
                if isinstance(line, InsertLine):
250
 
                    yield (pos, line)
251
 
                    pos += 1
252
 
                if isinstance(line, ContextLine):
253
 
                    pos += 1
254
 
 
255
 
def parse_patch(iter_lines):
256
 
    (orig_name, mod_name) = get_patch_names(iter_lines)
257
 
    patch = Patch(orig_name, mod_name)
258
 
    for hunk in iter_hunks(iter_lines):
259
 
        patch.hunks.append(hunk)
260
 
    return patch
261
 
 
262
 
if __name__ == "__main__":
263
 
    import unittest
264
 
    class PatchesTester(unittest.TestCase):
265
 
        def testValidPatchHeader(self):
266
 
            """Parse a valid patch header"""
267
 
            lines = "--- orig/commands.py\n+++ mod/dommands.py\n".split('\n')
268
 
            (orig, mod) = get_patch_names(lines.__iter__())
269
 
            assert(orig == "orig/commands.py")
270
 
            assert(mod == "mod/dommands.py")
271
 
 
272
 
        def testInvalidPatchHeader(self):
273
 
            """Parse an invalid patch header"""
274
 
            lines = "-- orig/commands.py\n+++ mod/dommands.py".split('\n')
275
 
            self.assertRaises(MalformedPatchHeader, get_patch_names,
276
 
                              lines.__iter__())
277
 
 
278
 
        def testValidHunkHeader(self):
279
 
            """Parse a valid hunk header"""
280
 
            header = "@@ -34,11 +50,6 @@\n"
281
 
            hunk = hunk_from_header(header);
282
 
            assert (hunk.orig_pos == 34)
283
 
            assert (hunk.orig_range == 11)
284
 
            assert (hunk.mod_pos == 50)
285
 
            assert (hunk.mod_range == 6)
286
 
            assert (str(hunk) == header)
287
 
 
288
 
        def testValidHunkHeader2(self):
289
 
            """Parse a tricky, valid hunk header"""
290
 
            header = "@@ -1 +0,0 @@\n"
291
 
            hunk = hunk_from_header(header);
292
 
            assert (hunk.orig_pos == 1)
293
 
            assert (hunk.orig_range == 1)
294
 
            assert (hunk.mod_pos == 0)
295
 
            assert (hunk.mod_range == 0)
296
 
            assert (str(hunk) == header)
297
 
 
298
 
        def makeMalformed(self, header):
299
 
            self.assertRaises(MalformedHunkHeader, hunk_from_header, header)
300
 
 
301
 
        def testInvalidHeader(self):
302
 
            """Parse an invalid hunk header"""
303
 
            self.makeMalformed(" -34,11 +50,6 \n")
304
 
            self.makeMalformed("@@ +50,6 -34,11 @@\n")
305
 
            self.makeMalformed("@@ -34,11 +50,6 @@")
306
 
            self.makeMalformed("@@ -34.5,11 +50,6 @@\n")
307
 
            self.makeMalformed("@@-34,11 +50,6@@\n")
308
 
            self.makeMalformed("@@ 34,11 50,6 @@\n")
309
 
            self.makeMalformed("@@ -34,11 @@\n")
310
 
            self.makeMalformed("@@ -34,11 +50,6.5 @@\n")
311
 
            self.makeMalformed("@@ -34,11 +50,-6 @@\n")
312
 
 
313
 
        def lineThing(self,text, type):
314
 
            line = parse_line(text)
315
 
            assert(isinstance(line, type))
316
 
            assert(str(line)==text)
317
 
 
318
 
        def makeMalformedLine(self, text):
319
 
            self.assertRaises(MalformedLine, parse_line, text)
320
 
 
321
 
        def testValidLine(self):
322
 
            """Parse a valid hunk line"""
323
 
            self.lineThing(" hello\n", ContextLine)
324
 
            self.lineThing("+hello\n", InsertLine)
325
 
            self.lineThing("-hello\n", RemoveLine)
326
 
        
327
 
        def testMalformedLine(self):
328
 
            """Parse invalid valid hunk lines"""
329
 
            self.makeMalformedLine("hello\n")
330
 
        
331
 
        def compare_parsed(self, patchtext):
332
 
            lines = patchtext.splitlines(True)
333
 
            patch = parse_patch(lines.__iter__())
334
 
            pstr = str(patch)
335
 
            i = util.difference_index(patchtext, pstr)
336
 
            if i is not None:
337
 
                print "%i: \"%s\" != \"%s\"" % (i, patchtext[i], pstr[i])
338
 
            assert (patchtext == str(patch))
339
 
 
340
 
        def testAll(self):
341
 
            """Test parsing a whole patch"""
342
 
            patchtext = """--- orig/commands.py
343
 
+++ mod/commands.py
344
 
@@ -1337,7 +1337,8 @@
345
 
 
346
 
     def set_title(self, command=None):
347
 
         try:
348
 
-            version = self.tree.tree_version.nonarch
349
 
+            version = pylon.alias_or_version(self.tree.tree_version, self.tree,
350
 
+                                             full=False)
351
 
         except:
352
 
             version = "[no version]"
353
 
         if command is None:
354
 
@@ -1983,7 +1984,11 @@
355
 
                                          version)
356
 
         if len(new_merges) > 0:
357
 
             if cmdutil.prompt("Log for merge"):
358
 
-                mergestuff = cmdutil.log_for_merge(tree, comp_version)
359
 
+                if cmdutil.prompt("changelog for merge"):
360
 
+                    mergestuff = "Patches applied:\\n"
361
 
+                    mergestuff += pylon.changelog_for_merge(new_merges)
362
 
+                else:
363
 
+                    mergestuff = cmdutil.log_for_merge(tree, comp_version)
364
 
                 log.description += mergestuff
365
 
         log.save()
366
 
     try:
367
 
"""
368
 
            self.compare_parsed(patchtext)
369
 
 
370
 
        def testInit(self):
371
 
            """Handle patches missing half the position, range tuple"""
372
 
            patchtext = \
373
 
"""--- orig/__init__.py
374
 
+++ mod/__init__.py
375
 
@@ -1 +1,2 @@
376
 
 __docformat__ = "restructuredtext en"
377
 
+__doc__ = An alternate Arch commandline interface"""
378
 
            self.compare_parsed(patchtext)
379
 
            
380
 
 
381
 
 
382
 
        def testLineLookup(self):
383
 
            """Make sure we can accurately look up mod line from orig"""
384
 
            patch = parse_patch(open("testdata/diff"))
385
 
            orig = list(open("testdata/orig"))
386
 
            mod = list(open("testdata/mod"))
387
 
            removals = []
388
 
            for i in range(len(orig)):
389
 
                mod_pos = patch.pos_in_mod(i)
390
 
                if mod_pos is None:
391
 
                    removals.append(orig[i])
392
 
                    continue
393
 
                assert(mod[mod_pos]==orig[i])
394
 
            rem_iter = removals.__iter__()
395
 
            for hunk in patch.hunks:
396
 
                for line in hunk.lines:
397
 
                    if isinstance(line, RemoveLine):
398
 
                        next = rem_iter.next()
399
 
                        if line.contents != next:
400
 
                            sys.stdout.write(" orig:%spatch:%s" % (next,
401
 
                                             line.contents))
402
 
                        assert(line.contents == next)
403
 
            self.assertRaises(StopIteration, rem_iter.next)
404
 
 
405
 
        def testFirstLineRenumber(self):
406
 
            """Make sure we handle lines at the beginning of the hunk"""
407
 
            patch = parse_patch(open("testdata/insert_top.patch"))
408
 
            assert (patch.pos_in_mod(0)==1)
409
 
    
410
 
            
411
 
    patchesTestSuite = unittest.makeSuite(PatchesTester,'test')
412
 
    runner = unittest.TextTestRunner()
413
 
    runner.run(patchesTestSuite)
414
 
    
415
 
 
416
 
# arch-tag: d1541a25-eac5-4de9-a476-08a7cecd5683