class TestRegistry(TestCase):
    """Tests for the basic (non-lazy) ``registry.Registry`` API."""

    def register_stuff(self, a_registry):
        """Populate *a_registry* with the fixed entries the tests rely on.

        Registers 'one', 'two', 'four' and 'five' mapped to the matching
        ints.  'three' and 'six' are deliberately left unregistered so the
        tests can probe missing-key behaviour.
        """
        a_registry.register('one', 1)
        a_registry.register('two', 2)
        a_registry.register('four', 4)
        a_registry.register('five', 5)

    def test_registry(self):
        """Check get()/default_key on a registry created without a default."""
        a_registry = registry.Registry()
        self.register_stuff(a_registry)

        self.failUnless(a_registry.default_key is None)

        # test get() (self.default_key == None)
        self.assertRaises(KeyError, a_registry.get)
        self.assertRaises(KeyError, a_registry.get, None)
        self.assertEqual(2, a_registry.get('two'))
        self.assertRaises(KeyError, a_registry.get, 'three')

        # test _set_default_key
        a_registry.default_key = 'five'
        self.failUnless(a_registry.default_key == 'five')
        self.assertEqual(5, a_registry.get())
        self.assertEqual(5, a_registry.get(None))
        # If they ask for a specific entry, they should get KeyError
        # not the default value. They can always pass None if they prefer
        self.assertRaises(KeyError, a_registry.get, 'six')
        # Setting the default to an unregistered key is rejected.
        self.assertRaises(KeyError, a_registry._set_default_key, 'six')

        # keys() is expected to come back in sorted order.
        self.assertEqual(['five', 'four', 'one', 'two'], a_registry.keys())

    def test_registry_with_first_is_default(self):
        """Registry(True) makes the first registered key ('one') the default."""
        a_registry = registry.Registry(True)
        self.register_stuff(a_registry)

        self.failUnless(a_registry.default_key == 'one')

        # test get() (self.default_key == 'one')
        self.assertEqual(1, a_registry.get())
        self.assertEqual(1, a_registry.get(None))
        self.assertEqual(2, a_registry.get('two'))
        # An explicit but unknown key raises rather than falling back.
        self.assertRaises(KeyError, a_registry.get, 'three')

        # test _set_default_key
        a_registry.default_key = 'five'
        self.failUnless(a_registry.default_key == 'five')
        self.assertEqual(5, a_registry.get())
        self.assertEqual(5, a_registry.get(None))
        self.assertRaises(KeyError, a_registry.get, 'six')
        self.assertRaises(KeyError, a_registry._set_default_key, 'six')

    def test_registry_like_dict(self):
        """Check the dict-style protocol: ``in``, ``[]``, ``len`` and views."""
        a_registry = registry.Registry()
        self.register_stuff(a_registry)

        self.failUnless('one' in a_registry)
        # NOTE(review): the removal statement on this line was missing from
        # the scraped diff; 'one' must be removed for the next two asserts
        # to hold.  Confirm the intended call (``del`` vs ``remove('one')``).
        del a_registry['one']
        self.failIf('one' in a_registry)
        self.assertRaises(KeyError, a_registry.get, 'one')

        # Re-adding through item assignment stores the new value.
        a_registry['one'] = 'one'
        self.assertEqual('one', a_registry['one'])
        self.assertEqual(4, len(a_registry))

        # Lazy iterator views.
        self.assertEqual(['five', 'four', 'one', 'two'],
                         sorted(a_registry.iterkeys()))
        self.assertEqual([('five', 5), ('four', 4),
                          ('one', 'one'), ('two', 2)],
                         sorted(a_registry.iteritems()))
        self.assertEqual([2, 4, 5, 'one'],
                         sorted(a_registry.itervalues()))

        # List-returning views.
        self.assertEqual(['five', 'four', 'one', 'two'],
                         sorted(a_registry.keys()))
        self.assertEqual([('five', 5), ('four', 4),
                          ('one', 'one'), ('two', 2)],
                         sorted(a_registry.items()))
        self.assertEqual([2, 4, 5, 'one'],
                         sorted(a_registry.values()))
113
class TestRegistryWithDirs(TestCaseInTempDir):
114
"""Registry tests that require temporary dirs"""
82
116
def create_plugin_file(self, contents):
117
"""Create a file to be used as a plugin.
119
This is created in a temporary directory, so that we
120
are sure that it doesn't start in the plugin path.
83
123
plugin_name = 'bzr_plugin_a_%s' % (osutils.rand_chars(4),)
84
open(plugin_name+'.py', 'wb').write(contents)
124
open('tmp/'+plugin_name+'.py', 'wb').write(contents)
85
125
return plugin_name
87
127
def create_simple_plugin(self):
100
140
def test_lazy_import_registry(self):
101
141
plugin_name = self.create_simple_plugin()
102
factory = registry.LazyImportRegistry()
103
factory.register('obj', plugin_name, 'object1')
104
factory.register('function', plugin_name, 'function')
105
factory.register('klass', plugin_name, 'MyClass')
106
factory.register('module', plugin_name, None)
142
a_registry = registry.Registry()
143
a_registry.register_lazy('obj', plugin_name, 'object1')
144
a_registry.register_lazy('function', plugin_name, 'function')
145
a_registry.register_lazy('klass', plugin_name, 'MyClass')
146
a_registry.register_lazy('module', plugin_name, None)
108
148
self.assertEqual(['function', 'klass', 'module', 'obj'],
109
sorted(factory.keys()))
149
sorted(a_registry.keys()))
110
150
# The plugin should not be loaded until we grab the first object
111
151
self.failIf(plugin_name in sys.modules)
113
153
# By default the plugin won't be in the search path
114
self.assertRaises(ImportError, factory.get, 'obj')
154
self.assertRaises(ImportError, a_registry.get, 'obj')
156
plugin_path = os.getcwd() + '/tmp'
157
sys.path.append(plugin_path)
119
obj = factory.get('obj')
159
obj = a_registry.get('obj')
120
160
self.assertEqual('foo', obj)
121
161
self.failUnless(plugin_name in sys.modules)
123
163
# Now grab another object
124
func = factory.get('function')
164
func = a_registry.get('function')
125
165
self.assertEqual(plugin_name, func.__module__)
126
166
self.assertEqual('function', func.__name__)
127
167
self.assertEqual((1, [], '3'), func(1, [], '3'))
129
169
# And finally a class
130
klass = factory.get('klass')
170
klass = a_registry.get('klass')
131
171
self.assertEqual(plugin_name, klass.__module__)
132
172
self.assertEqual('MyClass', klass.__name__)