~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/tests/test_branch.py

  • Committer: Vincent Ladeuil
  • Date: 2012-01-09 20:55:07 UTC
  • mto: This revision was merged to the branch mainline in revision 6468.
  • Revision ID: v.ladeuil+lp@free.fr-20120109205507-2i3en5r4w4ohdchj
Use idioms coherently and add comments to make their purpose clearer.

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2006-2010 Canonical Ltd
 
1
# Copyright (C) 2006-2012 Canonical Ltd
2
2
#
3
3
# This program is free software; you can redistribute it and/or modify
4
4
# it under the terms of the GNU General Public License as published by
16
16
 
17
17
"""Tests for the Branch facility that are not interface  tests.
18
18
 
19
 
For interface tests see tests/per_branch/*.py.
 
19
For interface tests see `tests/per_branch/*.py`.
20
20
 
21
21
For concrete class tests see this file, and for meta-branch tests
22
22
also see this file.
29
29
    bzrdir,
30
30
    config,
31
31
    errors,
 
32
    symbol_versioning,
32
33
    tests,
33
34
    trace,
34
 
    transport,
35
35
    urlutils,
36
36
    )
37
37
 
40
40
 
41
41
    def test_default_format(self):
42
42
        # update this if you change the default branch format
43
 
        self.assertIsInstance(_mod_branch.BranchFormat.get_default_format(),
 
43
        self.assertIsInstance(_mod_branch.format_registry.get_default(),
44
44
                _mod_branch.BzrBranchFormat7)
45
45
 
46
46
    def test_default_format_is_same_as_bzrdir_default(self):
48
48
        # set, but at the moment that's not true -- mbp 20070814 --
49
49
        # https://bugs.launchpad.net/bzr/+bug/132376
50
50
        self.assertEqual(
51
 
            _mod_branch.BranchFormat.get_default_format(),
 
51
            _mod_branch.format_registry.get_default(),
52
52
            bzrdir.BzrDirFormat.get_default_format().get_branch_format())
53
53
 
54
54
    def test_get_set_default_format(self):
55
55
        # set the format and then set it back again
56
 
        old_format = _mod_branch.BranchFormat.get_default_format()
57
 
        _mod_branch.BranchFormat.set_default_format(SampleBranchFormat())
 
56
        old_format = _mod_branch.format_registry.get_default()
 
57
        _mod_branch.format_registry.set_default(SampleBranchFormat())
58
58
        try:
59
59
            # the default branch format is used by the meta dir format
60
60
            # which is not the default bzrdir format at this point
62
62
            result = dir.create_branch()
63
63
            self.assertEqual(result, 'A branch')
64
64
        finally:
65
 
            _mod_branch.BranchFormat.set_default_format(old_format)
 
65
            _mod_branch.format_registry.set_default(old_format)
66
66
        self.assertEqual(old_format,
67
 
                         _mod_branch.BranchFormat.get_default_format())
 
67
                         _mod_branch.format_registry.get_default())
68
68
 
69
69
 
70
70
class TestBranchFormat5(tests.TestCaseWithTransport):
74
74
        url = self.get_url()
75
75
        bdir = bzrdir.BzrDirMetaFormat1().initialize(url)
76
76
        bdir.create_repository()
77
 
        branch = bdir.create_branch()
 
77
        branch = _mod_branch.BzrBranchFormat5().initialize(bdir)
78
78
        t = self.get_transport()
79
79
        self.log("branch instance is %r" % branch)
80
80
        self.assert_(isinstance(branch, _mod_branch.BzrBranch5))
86
86
        self.assertIsDirectory('.bzr/branch/lock/held', t)
87
87
 
88
88
    def test_set_push_location(self):
89
 
        from bzrlib.config import (locations_config_filename,
90
 
                                   ensure_config_dir_exists)
91
 
        ensure_config_dir_exists()
92
 
        fn = locations_config_filename()
93
 
        # write correct newlines to locations.conf
94
 
        # by default ConfigObj uses native line-endings for new files
95
 
        # but uses already existing line-endings if file is not empty
96
 
        f = open(fn, 'wb')
97
 
        try:
98
 
            f.write('# comment\n')
99
 
        finally:
100
 
            f.close()
 
89
        conf = config.LocationConfig.from_string('# comment\n', '.', save=True)
101
90
 
102
91
        branch = self.make_branch('.', format='knit')
103
92
        branch.set_push_location('foo')
106
95
                             "[%s]\n"
107
96
                             "push_location = foo\n"
108
97
                             "push_location:policy = norecurse\n" % local_path,
109
 
                             fn)
 
98
                             config.locations_config_filename())
110
99
 
111
100
    # TODO RBC 20051029 test getting a push location from a branch in a
112
101
    # recursive section - that is, it appends the branch name.
113
102
 
114
103
 
115
 
class SampleBranchFormat(_mod_branch.BranchFormat):
 
104
class SampleBranchFormat(_mod_branch.BranchFormatMetadir):
116
105
    """A sample format
117
106
 
118
107
    this format is initializable, unsupported to aid in testing the
119
108
    open and open_downlevel routines.
120
109
    """
121
110
 
122
 
    def get_format_string(self):
 
111
    @classmethod
 
112
    def get_format_string(cls):
123
113
        """See BzrBranchFormat.get_format_string()."""
124
114
        return "Sample branch format."
125
115
 
126
 
    def initialize(self, a_bzrdir, name=None):
 
116
    def initialize(self, a_bzrdir, name=None, repository=None,
 
117
                   append_revisions_only=None):
127
118
        """Format 4 branches cannot be created."""
128
119
        t = a_bzrdir.get_branch_transport(self, name=name)
129
120
        t.put_bytes('format', self.get_format_string())
132
123
    def is_supported(self):
133
124
        return False
134
125
 
135
 
    def open(self, transport, name=None, _found=False, ignore_fallbacks=False):
 
126
    def open(self, transport, name=None, _found=False, ignore_fallbacks=False,
 
127
             possible_transports=None):
136
128
        return "opened branch."
137
129
 
138
130
 
 
131
# Demonstrating how lazy loading is often implemented:
 
132
# A constant string is created.
 
133
SampleSupportedBranchFormatString = "Sample supported branch format."
 
134
 
 
135
# And the format class can then reference the constant to avoid skew.
 
136
class SampleSupportedBranchFormat(_mod_branch.BranchFormatMetadir):
 
137
    """A sample supported format."""
 
138
 
 
139
    @classmethod
 
140
    def get_format_string(cls):
 
141
        """See BzrBranchFormat.get_format_string()."""
 
142
        return SampleSupportedBranchFormatString
 
143
 
 
144
    def initialize(self, a_bzrdir, name=None, append_revisions_only=None):
 
145
        t = a_bzrdir.get_branch_transport(self, name=name)
 
146
        t.put_bytes('format', self.get_format_string())
 
147
        return 'A branch'
 
148
 
 
149
    def open(self, transport, name=None, _found=False, ignore_fallbacks=False,
 
150
             possible_transports=None):
 
151
        return "opened supported branch."
 
152
 
 
153
 
 
154
class SampleExtraBranchFormat(_mod_branch.BranchFormat):
 
155
    """A sample format that is not usable in a metadir."""
 
156
 
 
157
    def get_format_string(self):
 
158
        # This format is not usable in a metadir.
 
159
        return None
 
160
 
 
161
    def network_name(self):
 
162
        # Network name always has to be provided.
 
163
        return "extra"
 
164
 
 
165
    def initialize(self, a_bzrdir, name=None):
 
166
        raise NotImplementedError(self.initialize)
 
167
 
 
168
    def open(self, transport, name=None, _found=False, ignore_fallbacks=False,
 
169
             possible_transports=None):
 
170
        raise NotImplementedError(self.open)
 
171
 
 
172
 
139
173
class TestBzrBranchFormat(tests.TestCaseWithTransport):
140
174
    """Tests for the BzrBranchFormat facility."""
141
175
 
148
182
            dir = format._matchingbzrdir.initialize(url)
149
183
            dir.create_repository()
150
184
            format.initialize(dir)
151
 
            found_format = _mod_branch.BranchFormat.find_format(dir)
152
 
            self.failUnless(isinstance(found_format, format.__class__))
 
185
            found_format = _mod_branch.BranchFormatMetadir.find_format(dir)
 
186
            self.assertIsInstance(found_format, format.__class__)
153
187
        check_format(_mod_branch.BzrBranchFormat5(), "bar")
154
188
 
 
189
    def test_find_format_factory(self):
 
190
        dir = bzrdir.BzrDirMetaFormat1().initialize(self.get_url())
 
191
        SampleSupportedBranchFormat().initialize(dir)
 
192
        factory = _mod_branch.MetaDirBranchFormatFactory(
 
193
            SampleSupportedBranchFormatString,
 
194
            "bzrlib.tests.test_branch", "SampleSupportedBranchFormat")
 
195
        _mod_branch.format_registry.register(factory)
 
196
        self.addCleanup(_mod_branch.format_registry.remove, factory)
 
197
        b = _mod_branch.Branch.open(self.get_url())
 
198
        self.assertEqual(b, "opened supported branch.")
 
199
 
 
200
    def test_from_string(self):
 
201
        self.assertIsInstance(
 
202
            SampleBranchFormat.from_string("Sample branch format."),
 
203
            SampleBranchFormat)
 
204
        self.assertRaises(AssertionError,
 
205
            SampleBranchFormat.from_string, "Different branch format.")
 
206
 
155
207
    def test_find_format_not_branch(self):
156
208
        dir = bzrdir.BzrDirMetaFormat1().initialize(self.get_url())
157
209
        self.assertRaises(errors.NotBranchError,
158
 
                          _mod_branch.BranchFormat.find_format,
 
210
                          _mod_branch.BranchFormatMetadir.find_format,
159
211
                          dir)
160
212
 
161
213
    def test_find_format_unknown_format(self):
162
214
        dir = bzrdir.BzrDirMetaFormat1().initialize(self.get_url())
163
215
        SampleBranchFormat().initialize(dir)
164
216
        self.assertRaises(errors.UnknownFormatError,
165
 
                          _mod_branch.BranchFormat.find_format,
 
217
                          _mod_branch.BranchFormatMetadir.find_format,
166
218
                          dir)
167
219
 
 
220
    def test_find_format_with_features(self):
 
221
        tree = self.make_branch_and_tree('.', format='2a')
 
222
        tree.branch.update_feature_flags({"name": "optional"})
 
223
        found_format = _mod_branch.BranchFormatMetadir.find_format(tree.bzrdir)
 
224
        self.assertIsInstance(found_format, _mod_branch.BranchFormatMetadir)
 
225
        self.assertEquals(found_format.features.get("name"), "optional")
 
226
        tree.branch.update_feature_flags({"name": None})
 
227
        branch = _mod_branch.Branch.open('.')
 
228
        self.assertEquals(branch._format.features, {})
 
229
 
168
230
    def test_register_unregister_format(self):
 
231
        # Test the deprecated format registration functions
169
232
        format = SampleBranchFormat()
170
233
        # make a control dir
171
234
        dir = bzrdir.BzrDirMetaFormat1().initialize(self.get_url())
172
235
        # make a branch
173
236
        format.initialize(dir)
174
237
        # register a format for it.
175
 
        _mod_branch.BranchFormat.register_format(format)
 
238
        self.applyDeprecated(symbol_versioning.deprecated_in((2, 4, 0)),
 
239
            _mod_branch.BranchFormat.register_format, format)
176
240
        # which branch.Open will refuse (not supported)
177
241
        self.assertRaises(errors.UnsupportedFormatError,
178
242
                          _mod_branch.Branch.open, self.get_url())
182
246
            format.open(dir),
183
247
            bzrdir.BzrDir.open(self.get_url()).open_branch(unsupported=True))
184
248
        # unregister the format
185
 
        _mod_branch.BranchFormat.unregister_format(format)
 
249
        self.applyDeprecated(symbol_versioning.deprecated_in((2, 4, 0)),
 
250
            _mod_branch.BranchFormat.unregister_format, format)
186
251
        self.make_branch_and_tree('bar')
187
252
 
188
253
 
 
254
class TestBranchFormatRegistry(tests.TestCase):
 
255
 
 
256
    def setUp(self):
 
257
        super(TestBranchFormatRegistry, self).setUp()
 
258
        self.registry = _mod_branch.BranchFormatRegistry()
 
259
 
 
260
    def test_default(self):
 
261
        self.assertIs(None, self.registry.get_default())
 
262
        format = SampleBranchFormat()
 
263
        self.registry.set_default(format)
 
264
        self.assertEquals(format, self.registry.get_default())
 
265
 
 
266
    def test_register_unregister_format(self):
 
267
        format = SampleBranchFormat()
 
268
        self.registry.register(format)
 
269
        self.assertEquals(format,
 
270
            self.registry.get("Sample branch format."))
 
271
        self.registry.remove(format)
 
272
        self.assertRaises(KeyError, self.registry.get,
 
273
            "Sample branch format.")
 
274
 
 
275
    def test_get_all(self):
 
276
        format = SampleBranchFormat()
 
277
        self.assertEquals([], self.registry._get_all())
 
278
        self.registry.register(format)
 
279
        self.assertEquals([format], self.registry._get_all())
 
280
 
 
281
    def test_register_extra(self):
 
282
        format = SampleExtraBranchFormat()
 
283
        self.assertEquals([], self.registry._get_all())
 
284
        self.registry.register_extra(format)
 
285
        self.assertEquals([format], self.registry._get_all())
 
286
 
 
287
    def test_register_extra_lazy(self):
 
288
        self.assertEquals([], self.registry._get_all())
 
289
        self.registry.register_extra_lazy("bzrlib.tests.test_branch",
 
290
            "SampleExtraBranchFormat")
 
291
        formats = self.registry._get_all()
 
292
        self.assertEquals(1, len(formats))
 
293
        self.assertIsInstance(formats[0], SampleExtraBranchFormat)
 
294
 
 
295
 
 
296
#Used by TestMetaDirBranchFormatFactory 
 
297
FakeLazyFormat = None
 
298
 
 
299
 
 
300
class TestMetaDirBranchFormatFactory(tests.TestCase):
 
301
 
 
302
    def test_get_format_string_does_not_load(self):
 
303
        """Formats have a static format string."""
 
304
        factory = _mod_branch.MetaDirBranchFormatFactory("yo", None, None)
 
305
        self.assertEqual("yo", factory.get_format_string())
 
306
 
 
307
    def test_call_loads(self):
 
308
        # __call__ is used by the network_format_registry interface to get a
 
309
        # Format.
 
310
        global FakeLazyFormat
 
311
        del FakeLazyFormat
 
312
        factory = _mod_branch.MetaDirBranchFormatFactory(None,
 
313
            "bzrlib.tests.test_branch", "FakeLazyFormat")
 
314
        self.assertRaises(AttributeError, factory)
 
315
 
 
316
    def test_call_returns_call_of_referenced_object(self):
 
317
        global FakeLazyFormat
 
318
        FakeLazyFormat = lambda:'called'
 
319
        factory = _mod_branch.MetaDirBranchFormatFactory(None,
 
320
            "bzrlib.tests.test_branch", "FakeLazyFormat")
 
321
        self.assertEqual('called', factory())
 
322
 
 
323
 
189
324
class TestBranch67(object):
190
325
    """Common tests for both branch 6 and 7 which are mostly the same."""
191
326
 
210
345
 
211
346
    def test_layout(self):
212
347
        branch = self.make_branch('a', format=self.get_format_name())
213
 
        self.failUnlessExists('a/.bzr/branch/last-revision')
214
 
        self.failIfExists('a/.bzr/branch/revision-history')
215
 
        self.failIfExists('a/.bzr/branch/references')
 
348
        self.assertPathExists('a/.bzr/branch/last-revision')
 
349
        self.assertPathDoesNotExist('a/.bzr/branch/revision-history')
 
350
        self.assertPathDoesNotExist('a/.bzr/branch/references')
216
351
 
217
352
    def test_config(self):
218
353
        """Ensure that all configuration data is stored in the branch"""
219
354
        branch = self.make_branch('a', format=self.get_format_name())
220
 
        branch.set_parent('http://bazaar-vcs.org')
221
 
        self.failIfExists('a/.bzr/branch/parent')
222
 
        self.assertEqual('http://bazaar-vcs.org', branch.get_parent())
223
 
        branch.set_push_location('sftp://bazaar-vcs.org')
224
 
        config = branch.get_config()._get_branch_data_config()
225
 
        self.assertEqual('sftp://bazaar-vcs.org',
226
 
                         config.get_user_option('push_location'))
227
 
        branch.set_bound_location('ftp://bazaar-vcs.org')
228
 
        self.failIfExists('a/.bzr/branch/bound')
229
 
        self.assertEqual('ftp://bazaar-vcs.org', branch.get_bound_location())
 
355
        self.addCleanup(branch.lock_write().unlock)
 
356
        branch.set_parent('http://example.com')
 
357
        self.assertPathDoesNotExist('a/.bzr/branch/parent')
 
358
        self.assertEqual('http://example.com', branch.get_parent())
 
359
        branch.set_push_location('sftp://example.com')
 
360
        conf = branch.get_config_stack()
 
361
        self.assertEqual('sftp://example.com', conf.get('push_location'))
 
362
        branch.set_bound_location('ftp://example.com')
 
363
        self.assertPathDoesNotExist('a/.bzr/branch/bound')
 
364
        self.assertEqual('ftp://example.com', branch.get_bound_location())
230
365
 
231
366
    def test_set_revision_history(self):
232
367
        builder = self.make_branch_builder('.', format=self.get_format_name())
237
372
        branch = builder.get_branch()
238
373
        branch.lock_write()
239
374
        self.addCleanup(branch.unlock)
240
 
        branch.set_revision_history(['foo', 'bar'])
241
 
        branch.set_revision_history(['foo'])
 
375
        self.applyDeprecated(symbol_versioning.deprecated_in((2, 4, 0)),
 
376
            branch.set_revision_history, ['foo', 'bar'])
 
377
        self.applyDeprecated(symbol_versioning.deprecated_in((2, 4, 0)),
 
378
                branch.set_revision_history, ['foo'])
242
379
        self.assertRaises(errors.NotLefthandHistory,
243
 
                          branch.set_revision_history, ['bar'])
 
380
            self.applyDeprecated, symbol_versioning.deprecated_in((2, 4, 0)),
 
381
            branch.set_revision_history, ['bar'])
244
382
 
245
383
    def do_checkout_test(self, lightweight=False):
246
384
        tree = self.make_branch_and_tree('source',
259
397
        subtree.commit('a subtree file')
260
398
        subsubtree.commit('a subsubtree file')
261
399
        tree.branch.create_checkout('target', lightweight=lightweight)
262
 
        self.failUnlessExists('target')
263
 
        self.failUnlessExists('target/subtree')
264
 
        self.failUnlessExists('target/subtree/file')
265
 
        self.failUnlessExists('target/subtree/subsubtree/file')
 
400
        self.assertPathExists('target')
 
401
        self.assertPathExists('target/subtree')
 
402
        self.assertPathExists('target/subtree/file')
 
403
        self.assertPathExists('target/subtree/subsubtree/file')
266
404
        subbranch = _mod_branch.Branch.open('target/subtree/subsubtree')
267
405
        if lightweight:
268
406
            self.assertEndsWith(subbranch.base, 'source/subtree/subsubtree/')
275
413
    def test_light_checkout_with_references(self):
276
414
        self.do_checkout_test(lightweight=True)
277
415
 
278
 
    def test_set_push(self):
279
 
        branch = self.make_branch('source', format=self.get_format_name())
280
 
        branch.get_config().set_user_option('push_location', 'old',
281
 
            store=config.STORE_LOCATION)
282
 
        warnings = []
283
 
        def warning(*args):
284
 
            warnings.append(args[0] % args[1:])
285
 
        _warning = trace.warning
286
 
        trace.warning = warning
287
 
        try:
288
 
            branch.set_push_location('new')
289
 
        finally:
290
 
            trace.warning = _warning
291
 
        self.assertEqual(warnings[0], 'Value "new" is masked by "old" from '
292
 
                         'locations.conf')
293
 
 
294
416
 
295
417
class TestBranch6(TestBranch67, tests.TestCaseWithTransport):
296
418
 
425
547
        self.assertEqual(('path3', 'location3'),
426
548
                         branch.get_reference_info('file-id'))
427
549
 
 
550
    def _recordParentMapCalls(self, repo):
 
551
        self._parent_map_calls = []
 
552
        orig_get_parent_map = repo.revisions.get_parent_map
 
553
        def get_parent_map(q):
 
554
            q = list(q)
 
555
            self._parent_map_calls.extend([e[0] for e in q])
 
556
            return orig_get_parent_map(q)
 
557
        repo.revisions.get_parent_map = get_parent_map
 
558
 
 
559
 
428
560
class TestBranchReference(tests.TestCaseWithTransport):
429
561
    """Tests for the branch reference facility."""
430
562
 
431
563
    def test_create_open_reference(self):
432
564
        bzrdirformat = bzrdir.BzrDirMetaFormat1()
433
 
        t = transport.get_transport(self.get_url('.'))
 
565
        t = self.get_transport()
434
566
        t.mkdir('repo')
435
567
        dir = bzrdirformat.initialize(self.get_url('repo'))
436
568
        dir.create_repository()
492
624
        self.assertTrue(hasattr(params, 'bzrdir'))
493
625
        self.assertTrue(hasattr(params, 'branch'))
494
626
 
 
627
    def test_post_branch_init_hook_repr(self):
 
628
        param_reprs = []
 
629
        _mod_branch.Branch.hooks.install_named_hook('post_branch_init',
 
630
            lambda params: param_reprs.append(repr(params)), None)
 
631
        branch = self.make_branch('a')
 
632
        self.assertLength(1, param_reprs)
 
633
        param_repr = param_reprs[0]
 
634
        self.assertStartsWith(param_repr, '<BranchInitHookParams of ')
 
635
 
495
636
    def test_post_switch_hook(self):
496
637
        from bzrlib import switch
497
638
        calls = []
521
662
    def setUp(self):
522
663
        super(TestBranchOptions, self).setUp()
523
664
        self.branch = self.make_branch('.')
524
 
        self.config = self.branch.get_config()
 
665
        self.config_stack = self.branch.get_config_stack()
525
666
 
526
667
    def check_append_revisions_only(self, expected_value, value=None):
527
668
        """Set append_revisions_only in config and check its interpretation."""
528
669
        if value is not None:
529
 
            self.config.set_user_option('append_revisions_only', value)
 
670
            self.branch.lock_write()
 
671
            try:
 
672
                self.config_stack.set('append_revisions_only', value)
 
673
            finally:
 
674
                self.branch.unlock()
530
675
        self.assertEqual(expected_value,
531
 
                         self.branch._get_append_revisions_only())
 
676
                         self.branch.get_append_revisions_only())
532
677
 
533
678
    def test_valid_append_revisions_only(self):
534
679
        self.assertEquals(None,
535
 
                          self.config.get_user_option('append_revisions_only'))
 
680
                          self.config_stack.get('append_revisions_only'))
536
681
        self.check_append_revisions_only(None)
537
682
        self.check_append_revisions_only(False, 'False')
538
683
        self.check_append_revisions_only(True, 'True')
550
695
        self.check_append_revisions_only(None, 'not-a-bool')
551
696
        self.assertLength(1, self.warnings)
552
697
        self.assertEqual(
553
 
            'Value "not-a-bool" is not a boolean for "append_revisions_only"',
 
698
            'Value "not-a-bool" is not valid for "append_revisions_only"',
554
699
            self.warnings[0])
555
700
 
556
701
 
564
709
        # this usage of results is not recommended for new code (because it
565
710
        # doesn't describe very well what happened), but for api stability
566
711
        # it's still supported
567
 
        a = "%d revisions pulled" % r
568
 
        self.assertEqual(a, "10 revisions pulled")
 
712
        self.assertEqual(self.applyDeprecated(
 
713
            symbol_versioning.deprecated_in((2, 3, 0)),
 
714
            r.__int__),
 
715
            10)
569
716
 
570
717
    def test_report_changed(self):
571
718
        r = _mod_branch.PullResult()
576
723
        f = StringIO()
577
724
        r.report(f)
578
725
        self.assertEqual("Now on revision 20.\n", f.getvalue())
 
726
        self.assertEqual("Now on revision 20.\n", f.getvalue())
579
727
 
580
728
    def test_report_unchanged(self):
581
729
        r = _mod_branch.PullResult()
583
731
        r.new_revid = "same-revid"
584
732
        f = StringIO()
585
733
        r.report(f)
586
 
        self.assertEqual("No revisions to pull.\n", f.getvalue())
587
 
 
588
 
 
589
 
class _StubLockable(object):
590
 
    """Helper for TestRunWithWriteLockedTarget."""
591
 
 
592
 
    def __init__(self, calls, unlock_exc=None):
593
 
        self.calls = calls
594
 
        self.unlock_exc = unlock_exc
595
 
 
596
 
    def lock_write(self):
597
 
        self.calls.append('lock_write')
598
 
 
599
 
    def unlock(self):
600
 
        self.calls.append('unlock')
601
 
        if self.unlock_exc is not None:
602
 
            raise self.unlock_exc
603
 
 
604
 
 
605
 
class _ErrorFromCallable(Exception):
606
 
    """Helper for TestRunWithWriteLockedTarget."""
607
 
 
608
 
 
609
 
class _ErrorFromUnlock(Exception):
610
 
    """Helper for TestRunWithWriteLockedTarget."""
611
 
 
612
 
 
613
 
class TestRunWithWriteLockedTarget(tests.TestCase):
614
 
    """Tests for _run_with_write_locked_target."""
615
 
 
616
 
    def setUp(self):
617
 
        tests.TestCase.setUp(self)
618
 
        self._calls = []
619
 
 
620
 
    def func_that_returns_ok(self):
621
 
        self._calls.append('func called')
622
 
        return 'ok'
623
 
 
624
 
    def func_that_raises(self):
625
 
        self._calls.append('func called')
626
 
        raise _ErrorFromCallable()
627
 
 
628
 
    def test_success_unlocks(self):
629
 
        lockable = _StubLockable(self._calls)
630
 
        result = _mod_branch._run_with_write_locked_target(
631
 
            lockable, self.func_that_returns_ok)
632
 
        self.assertEqual('ok', result)
633
 
        self.assertEqual(['lock_write', 'func called', 'unlock'], self._calls)
634
 
 
635
 
    def test_exception_unlocks_and_propagates(self):
636
 
        lockable = _StubLockable(self._calls)
637
 
        self.assertRaises(_ErrorFromCallable,
638
 
                          _mod_branch._run_with_write_locked_target,
639
 
                          lockable, self.func_that_raises)
640
 
        self.assertEqual(['lock_write', 'func called', 'unlock'], self._calls)
641
 
 
642
 
    def test_callable_succeeds_but_error_during_unlock(self):
643
 
        lockable = _StubLockable(self._calls, unlock_exc=_ErrorFromUnlock())
644
 
        self.assertRaises(_ErrorFromUnlock,
645
 
                          _mod_branch._run_with_write_locked_target,
646
 
                          lockable, self.func_that_returns_ok)
647
 
        self.assertEqual(['lock_write', 'func called', 'unlock'], self._calls)
648
 
 
649
 
    def test_error_during_unlock_does_not_mask_original_error(self):
650
 
        lockable = _StubLockable(self._calls, unlock_exc=_ErrorFromUnlock())
651
 
        self.assertRaises(_ErrorFromCallable,
652
 
                          _mod_branch._run_with_write_locked_target,
653
 
                          lockable, self.func_that_raises)
654
 
        self.assertEqual(['lock_write', 'func called', 'unlock'], self._calls)
655
 
 
656
 
 
 
734
        self.assertEqual("No revisions or tags to pull.\n", f.getvalue())