~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/tests/TestUtil.py

Merge bzr.dev to resolve conflicts

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# Copyright (C) 2005-2011 Canonical Ltd
 
2
#       Author: Robert Collins <robert.collins@canonical.com>
 
3
#
 
4
# This program is free software; you can redistribute it and/or modify
 
5
# it under the terms of the GNU General Public License as published by
 
6
# the Free Software Foundation; either version 2 of the License, or
 
7
# (at your option) any later version.
 
8
#
 
9
# This program is distributed in the hope that it will be useful,
 
10
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 
11
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
12
# GNU General Public License for more details.
 
13
#
 
14
# You should have received a copy of the GNU General Public License
 
15
# along with this program; if not, write to the Free Software
 
16
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 
17
#
 
18
 
 
19
import sys
 
20
import logging
 
21
import unittest
 
22
import weakref
 
23
 
 
24
from bzrlib import pyutils
 
25
 
 
26
# Mark this python module as being part of the implementation
 
27
# of unittest: this gives us better tracebacks where the last
 
28
# shown frame is the test code, not our assertXYZ.
 
29
__unittest = 1
 
30
 
 
31
 
 
32
class LogCollector(logging.Handler):
 
33
 
 
34
    def __init__(self):
 
35
        logging.Handler.__init__(self)
 
36
        self.records=[]
 
37
 
 
38
    def emit(self, record):
 
39
        self.records.append(record.getMessage())
 
40
 
 
41
 
 
42
def makeCollectingLogger():
 
43
    """I make a logger instance that collects its logs for programmatic analysis
 
44
    -> (logger, collector)"""
 
45
    logger=logging.Logger("collector")
 
46
    handler=LogCollector()
 
47
    handler.setFormatter(logging.Formatter("%(levelname)s: %(message)s"))
 
48
    logger.addHandler(handler)
 
49
    return logger, handler
 
50
 
 
51
 
 
52
def visitTests(suite, visitor):
 
53
    """A foreign method for visiting the tests in a test suite."""
 
54
    for test in suite._tests:
 
55
        #Abusing types to avoid monkey patching unittest.TestCase.
 
56
        # Maybe that would be better?
 
57
        try:
 
58
            test.visit(visitor)
 
59
        except AttributeError:
 
60
            if isinstance(test, unittest.TestCase):
 
61
                visitor.visitCase(test)
 
62
            elif isinstance(test, unittest.TestSuite):
 
63
                visitor.visitSuite(test)
 
64
                visitTests(test, visitor)
 
65
            else:
 
66
                print "unvisitable non-unittest.TestCase element %r (%r)" % (
 
67
                    test, test.__class__)
 
68
 
 
69
 
 
70
class TestSuite(unittest.TestSuite):
 
71
    """I am an extended TestSuite with a visitor interface.
 
72
    This is primarily to allow filtering of tests - and suites or
 
73
    more in the future. An iterator of just tests wouldn't scale..."""
 
74
 
 
75
    def visit(self, visitor):
 
76
        """visit the composite. Visiting is depth-first.
 
77
        current callbacks are visitSuite and visitCase."""
 
78
        visitor.visitSuite(self)
 
79
        visitTests(self, visitor)
 
80
 
 
81
    def run(self, result):
 
82
        """Run the tests in the suite, discarding references after running."""
 
83
        tests = list(self)
 
84
        tests.reverse()
 
85
        self._tests = []
 
86
        stream = getattr(result, "stream", None)
 
87
        # With subunit, not only is stream underscored, but the actual result
 
88
        # object is hidden inside a wrapper decorator, get out the real stream
 
89
        if stream is None:
 
90
            stream = result.decorated._stream
 
91
        stored_count = 0
 
92
        count_stored_tests = getattr(result, "_count_stored_tests", int)
 
93
        from bzrlib.tests import selftest_debug_flags
 
94
        notify = "uncollected_cases" in selftest_debug_flags
 
95
        while tests:
 
96
            if result.shouldStop:
 
97
                self._tests = reversed(tests)
 
98
                break
 
99
            case = _run_and_collect_case(tests.pop(), result)()
 
100
            new_stored_count = count_stored_tests()
 
101
            if case is not None and isinstance(case, unittest.TestCase):
 
102
                if stored_count == new_stored_count and notify:
 
103
                    # Testcase didn't fail, but somehow is still alive
 
104
                    stream.write("Uncollected test case: %s\n" % (case.id(),))
 
105
                # Zombie the testcase but leave a working stub id method
 
106
                case.__dict__ = {"id": lambda _id=case.id(): _id}
 
107
            stored_count = new_stored_count
 
108
        return result
 
109
 
 
110
 
 
111
def _run_and_collect_case(case, res):
 
112
    """Run test case against result and use weakref to drop the refcount"""
 
113
    case.run(res)
 
114
    return weakref.ref(case)
 
115
 
 
116
 
 
117
class TestLoader(unittest.TestLoader):
 
118
    """Custom TestLoader to extend the stock python one."""
 
119
 
 
120
    suiteClass = TestSuite
 
121
    # Memoize test names by test class dict
 
122
    test_func_names = {}
 
123
 
 
124
    def loadTestsFromModuleNames(self, names):
 
125
        """use a custom means to load tests from modules.
 
126
 
 
127
        There is an undesirable glitch in the python TestLoader where a
 
128
        import error is ignore. We think this can be solved by ensuring the
 
129
        requested name is resolvable, if its not raising the original error.
 
130
        """
 
131
        result = self.suiteClass()
 
132
        for name in names:
 
133
            result.addTests(self.loadTestsFromModuleName(name))
 
134
        return result
 
135
 
 
136
    def loadTestsFromModuleName(self, name):
 
137
        result = self.suiteClass()
 
138
        module = pyutils.get_named_object(name)
 
139
 
 
140
        result.addTests(self.loadTestsFromModule(module))
 
141
        return result
 
142
 
 
143
    def loadTestsFromModule(self, module):
 
144
        """Load tests from a module object.
 
145
 
 
146
        This extension of the python test loader looks for an attribute
 
147
        load_tests in the module object, and if not found falls back to the
 
148
        regular python loadTestsFromModule.
 
149
 
 
150
        If a load_tests attribute is found, it is called and the result is
 
151
        returned.
 
152
 
 
153
        load_tests should be defined like so:
 
154
        >>> def load_tests(standard_tests, module, loader):
 
155
        >>>    pass
 
156
 
 
157
        standard_tests is the tests found by the stock TestLoader in the
 
158
        module, module and loader are the module and loader instances.
 
159
 
 
160
        For instance, to run every test twice, you might do:
 
161
        >>> def load_tests(standard_tests, module, loader):
 
162
        >>>     result = loader.suiteClass()
 
163
        >>>     for test in iter_suite_tests(standard_tests):
 
164
        >>>         result.addTests([test, test])
 
165
        >>>     return result
 
166
        """
 
167
        if sys.version_info < (2, 7):
 
168
            basic_tests = super(TestLoader, self).loadTestsFromModule(module)
 
169
        else:
 
170
            # GZ 2010-07-19: Python 2.7 unittest also uses load_tests but with
 
171
            #                a different and incompatible signature
 
172
            basic_tests = super(TestLoader, self).loadTestsFromModule(module,
 
173
                use_load_tests=False)
 
174
        load_tests = getattr(module, "load_tests", None)
 
175
        if load_tests is not None:
 
176
            return load_tests(basic_tests, module, self)
 
177
        else:
 
178
            return basic_tests
 
179
 
 
180
    def getTestCaseNames(self, test_case_class):
 
181
        test_fn_names = self.test_func_names.get(test_case_class, None)
 
182
        if test_fn_names is not None:
 
183
            # We already know them
 
184
            return test_fn_names
 
185
 
 
186
        test_fn_names = unittest.TestLoader.getTestCaseNames(self,
 
187
                                                             test_case_class)
 
188
        self.test_func_names[test_case_class] = test_fn_names
 
189
        return test_fn_names
 
190
 
 
191
 
 
192
class FilteredByModuleTestLoader(TestLoader):
 
193
    """A test loader that import only the needed modules."""
 
194
 
 
195
    def __init__(self, needs_module):
 
196
        """Constructor.
 
197
 
 
198
        :param needs_module: a callable taking a module name as a
 
199
            parameter returing True if the module should be loaded.
 
200
        """
 
201
        TestLoader.__init__(self)
 
202
        self.needs_module = needs_module
 
203
 
 
204
    def loadTestsFromModuleName(self, name):
 
205
        if self.needs_module(name):
 
206
            return TestLoader.loadTestsFromModuleName(self, name)
 
207
        else:
 
208
            return self.suiteClass()
 
209
 
 
210
 
 
211
class TestVisitor(object):
 
212
    """A visitor for Tests"""
 
213
 
 
214
    def visitSuite(self, aTestSuite):
 
215
        pass
 
216
 
 
217
    def visitCase(self, aTestCase):
 
218
        pass