~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/tests/matchers.py

  • Committer: Vincent Ladeuil
  • Date: 2013-05-25 17:12:43 UTC
  • mto: (6437.77.1 2.5)
  • mto: This revision was merged to the branch mainline in revision 6577.
  • Revision ID: v.ladeuil+lp@free.fr-20130525171243-au0073fnspecl3kg
Empty arguments in EDITOR are now properly preserved

Show diffs side-by-side

added added

removed removed

Lines of Context:
27
27
"""
28
28
 
29
29
__all__ = [
 
30
    'HasLayout',
 
31
    'MatchesAncestry',
 
32
    'ContainsNoVfsCalls',
30
33
    'ReturnsUnlockable',
 
34
    'RevisionHistoryMatches',
31
35
    ]
32
36
 
33
 
from testtools.matchers import Mismatch, Matcher
 
37
from bzrlib import (
 
38
    osutils,
 
39
    revision as _mod_revision,
 
40
    )
 
41
from bzrlib import lazy_import
 
42
lazy_import.lazy_import(globals(),
 
43
"""
 
44
from bzrlib.smart.request import request_handlers as smart_request_handlers
 
45
from bzrlib.smart import vfs
 
46
""")
 
47
 
 
48
from testtools.matchers import Equals, Mismatch, Matcher
34
49
 
35
50
 
36
51
class ReturnsUnlockable(Matcher):
48
63
        self.lockable_thing = lockable_thing
49
64
 
50
65
    def __str__(self):
51
 
        return ('ReturnsUnlockable(lockable_thing=%s)' % 
 
66
        return ('ReturnsUnlockable(lockable_thing=%s)' %
52
67
            self.lockable_thing)
53
68
 
54
69
    def match(self, lock_method):
66
81
 
67
82
    def describe(self):
68
83
        return "%s is locked" % self.lockable_thing
 
84
 
 
85
 
 
86
class _AncestryMismatch(Mismatch):
 
87
    """Ancestry matching mismatch."""
 
88
 
 
89
    def __init__(self, tip_revision, got, expected):
 
90
        self.tip_revision = tip_revision
 
91
        self.got = got
 
92
        self.expected = expected
 
93
 
 
94
    def describe(self):
 
95
        return "mismatched ancestry for revision %r was %r, expected %r" % (
 
96
            self.tip_revision, self.got, self.expected)
 
97
 
 
98
 
 
99
class MatchesAncestry(Matcher):
 
100
    """A matcher that checks the ancestry of a particular revision.
 
101
 
 
102
    :ivar graph: Graph in which to check the ancestry
 
103
    :ivar revision_id: Revision id of the revision
 
104
    """
 
105
 
 
106
    def __init__(self, repository, revision_id):
 
107
        Matcher.__init__(self)
 
108
        self.repository = repository
 
109
        self.revision_id = revision_id
 
110
 
 
111
    def __str__(self):
 
112
        return ('MatchesAncestry(repository=%r, revision_id=%r)' % (
 
113
            self.repository, self.revision_id))
 
114
 
 
115
    def match(self, expected):
 
116
        self.repository.lock_read()
 
117
        try:
 
118
            graph = self.repository.get_graph()
 
119
            got = [r for r, p in graph.iter_ancestry([self.revision_id])]
 
120
            if _mod_revision.NULL_REVISION in got:
 
121
                got.remove(_mod_revision.NULL_REVISION)
 
122
        finally:
 
123
            self.repository.unlock()
 
124
        if sorted(got) != sorted(expected):
 
125
            return _AncestryMismatch(self.revision_id, sorted(got),
 
126
                sorted(expected))
 
127
 
 
128
 
 
129
class HasLayout(Matcher):
 
130
    """A matcher that checks if a tree has a specific layout.
 
131
 
 
132
    :ivar entries: List of expected entries, as (path, file_id) pairs.
 
133
    """
 
134
 
 
135
    def __init__(self, entries):
 
136
        Matcher.__init__(self)
 
137
        self.entries = entries
 
138
 
 
139
    def get_tree_layout(self, tree):
 
140
        """Get the (path, file_id) pairs for the current tree."""
 
141
        tree.lock_read()
 
142
        try:
 
143
            for path, ie in tree.iter_entries_by_dir():
 
144
                if ie.parent_id is None:
 
145
                    yield (u"", ie.file_id)
 
146
                else:
 
147
                    yield (path+ie.kind_character(), ie.file_id)
 
148
        finally:
 
149
            tree.unlock()
 
150
 
 
151
    @staticmethod
 
152
    def _strip_unreferenced_directories(entries):
 
153
        """Strip all directories that don't (in)directly contain any files.
 
154
 
 
155
        :param entries: List of path strings or (path, ie) tuples to process
 
156
        """
 
157
        directories = []
 
158
        for entry in entries:
 
159
            if isinstance(entry, basestring):
 
160
                path = entry
 
161
            else:
 
162
                path = entry[0]
 
163
            if not path or path[-1] == "/":
 
164
                # directory
 
165
                directories.append((path, entry))
 
166
            else:
 
167
                # Yield the referenced parent directories
 
168
                for dirpath, direntry in directories:
 
169
                    if osutils.is_inside(dirpath, path):
 
170
                        yield direntry
 
171
                directories = []
 
172
                yield entry
 
173
 
 
174
    def __str__(self):
 
175
        return 'HasLayout(%r)' % self.entries
 
176
 
 
177
    def match(self, tree):
 
178
        actual = list(self.get_tree_layout(tree))
 
179
        if self.entries and isinstance(self.entries[0], basestring):
 
180
            actual = [path for (path, fileid) in actual]
 
181
        if not tree.has_versioned_directories():
 
182
            entries = list(self._strip_unreferenced_directories(self.entries))
 
183
        else:
 
184
            entries = self.entries
 
185
        return Equals(entries).match(actual)
 
186
 
 
187
 
 
188
class RevisionHistoryMatches(Matcher):
 
189
    """A matcher that checks if a branch has a specific revision history.
 
190
 
 
191
    :ivar history: Revision history, as list of revisions. Oldest first.
 
192
    """
 
193
 
 
194
    def __init__(self, history):
 
195
        Matcher.__init__(self)
 
196
        self.expected = history
 
197
 
 
198
    def __str__(self):
 
199
        return 'RevisionHistoryMatches(%r)' % self.expected
 
200
 
 
201
    def match(self, branch):
 
202
        branch.lock_read()
 
203
        try:
 
204
            graph = branch.repository.get_graph()
 
205
            history = list(graph.iter_lefthand_ancestry(
 
206
                branch.last_revision(), [_mod_revision.NULL_REVISION]))
 
207
            history.reverse()
 
208
        finally:
 
209
            branch.unlock()
 
210
        return Equals(self.expected).match(history)
 
211
 
 
212
 
 
213
class _NoVfsCallsMismatch(Mismatch):
 
214
    """Mismatch describing a list of HPSS calls which includes VFS requests."""
 
215
 
 
216
    def __init__(self, vfs_calls):
 
217
        self.vfs_calls = vfs_calls
 
218
 
 
219
    def describe(self):
 
220
        return "no VFS calls expected, got: %s" % ",".join([
 
221
            "%s(%s)" % (c.method,
 
222
                ", ".join([repr(a) for a in c.args])) for c in self.vfs_calls])
 
223
 
 
224
 
 
225
class ContainsNoVfsCalls(Matcher):
 
226
    """Ensure that none of the specified calls are HPSS calls."""
 
227
 
 
228
    def __str__(self):
 
229
        return 'ContainsNoVfsCalls()'
 
230
 
 
231
    @classmethod
 
232
    def match(cls, hpss_calls):
 
233
        vfs_calls = []
 
234
        for call in hpss_calls:
 
235
            try:
 
236
                request_method = smart_request_handlers.get(call.call.method)
 
237
            except KeyError:
 
238
                # A method we don't know about doesn't count as a VFS method.
 
239
                continue
 
240
            if issubclass(request_method, vfs.VfsRequest):
 
241
                vfs_calls.append(call.call)
 
242
        if len(vfs_calls) == 0:
 
243
            return None
 
244
        return _NoVfsCallsMismatch(vfs_calls)