13
13
# You should have received a copy of the GNU General Public License
14
14
# along with this program; if not, write to the Free Software
15
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
15
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17
17
"""Tests for the Branch facility that are not interface tests.
19
For interface tests see tests/branch_implementations/*.py.
19
For interface tests see tests/per_branch/*.py.
21
21
For concrete class tests see this file, and for meta-branch tests
22
22
also see this file.
25
from StringIO import StringIO
28
from bzrlib.branch import (BzrBranch5,
30
import bzrlib.bzrdir as bzrdir
31
from bzrlib.bzrdir import (BzrDirMetaFormat1, BzrDirMeta1,
33
from bzrlib.errors import (NotBranchError,
35
UnsupportedFormatError,
38
from bzrlib.tests import TestCase, TestCaseWithTransport
39
from bzrlib.transport import get_transport
41
class TestDefaultFormat(TestCase):
25
from cStringIO import StringIO
28
branch as _mod_branch,
40
class TestDefaultFormat(tests.TestCase):
42
def test_default_format(self):
43
# update this if you change the default branch format
44
self.assertIsInstance(_mod_branch.BranchFormat.get_default_format(),
45
_mod_branch.BzrBranchFormat7)
47
def test_default_format_is_same_as_bzrdir_default(self):
48
# XXX: it might be nice if there was only one place the default was
49
# set, but at the moment that's not true -- mbp 20070814 --
50
# https://bugs.launchpad.net/bzr/+bug/132376
52
_mod_branch.BranchFormat.get_default_format(),
53
bzrdir.BzrDirFormat.get_default_format().get_branch_format())
43
55
def test_get_set_default_format(self):
44
old_format = bzrlib.branch.BranchFormat.get_default_format()
46
self.assertTrue(isinstance(old_format, bzrlib.branch.BzrBranchFormat5))
47
bzrlib.branch.BranchFormat.set_default_format(SampleBranchFormat())
56
# set the format and then set it back again
57
old_format = _mod_branch.BranchFormat.get_default_format()
58
_mod_branch.BranchFormat.set_default_format(SampleBranchFormat())
49
60
# the default branch format is used by the meta dir format
50
61
# which is not the default bzrdir format at this point
51
dir = BzrDirMetaFormat1().initialize('memory:///')
62
dir = bzrdir.BzrDirMetaFormat1().initialize('memory:///')
52
63
result = dir.create_branch()
53
64
self.assertEqual(result, 'A branch')
55
bzrlib.branch.BranchFormat.set_default_format(old_format)
56
self.assertEqual(old_format, bzrlib.branch.BranchFormat.get_default_format())
59
class TestBranchFormat5(TestCaseWithTransport):
66
_mod_branch.BranchFormat.set_default_format(old_format)
67
self.assertEqual(old_format,
68
_mod_branch.BranchFormat.get_default_format())
71
class TestBranchFormat5(tests.TestCaseWithTransport):
60
72
"""Tests specific to branch format 5"""
62
74
def test_branch_format_5_uses_lockdir(self):
63
75
url = self.get_url()
64
bzrdir = BzrDirMetaFormat1().initialize(url)
65
bzrdir.create_repository()
66
branch = bzrdir.create_branch()
76
bdir = bzrdir.BzrDirMetaFormat1().initialize(url)
77
bdir.create_repository()
78
branch = bdir.create_branch()
67
79
t = self.get_transport()
68
80
self.log("branch instance is %r" % branch)
69
self.assert_(isinstance(branch, BzrBranch5))
81
self.assert_(isinstance(branch, _mod_branch.BzrBranch5))
70
82
self.assertIsDirectory('.', t)
71
83
self.assertIsDirectory('.bzr/branch', t)
72
84
self.assertIsDirectory('.bzr/branch/lock', t)
73
85
branch.lock_write()
75
self.assertIsDirectory('.bzr/branch/lock/held', t)
80
class TestBranchEscaping(TestCaseWithTransport):
81
"""Test a branch can be correctly stored and used on a vfat-like transport
83
Makes sure we have proper escaping of invalid characters, etc.
85
It'd be better to test all operations on the FakeVFATTransportDecorator,
86
but working trees go straight to the os not through the Transport layer.
87
Therefore we build some history first in the regular way and then
88
check it's safe to access for vfat.
95
super(TestBranchEscaping, self).setUp()
96
from bzrlib.repository import RepositoryFormatKnit1
97
bzrdir = BzrDirMetaFormat1().initialize(self.get_url())
98
repo = RepositoryFormatKnit1().initialize(bzrdir)
99
branch = bzrdir.create_branch()
100
wt = bzrdir.create_workingtree()
101
self.build_tree_contents([("foo", "contents of foo")])
102
# add file with id containing wierd characters
103
wt.add(['foo'], [self.FOO_ID])
104
wt.commit('this is my new commit', rev_id=self.REV_ID)
106
def test_branch_on_vfat(self):
107
from bzrlib.transport.fakevfat import FakeVFATTransportDecorator
108
# now access over vfat; should be safe
109
transport = FakeVFATTransportDecorator('vfat+' + self.get_url())
110
bzrdir, junk = BzrDir.open_containing_from_transport(transport)
111
branch = bzrdir.open_branch()
112
revtree = branch.repository.revision_tree(self.REV_ID)
113
contents = revtree.get_file_text(self.FOO_ID)
114
self.assertEqual(contents, 'contents of foo')
117
class SampleBranchFormat(bzrlib.branch.BranchFormat):
86
self.addCleanup(branch.unlock)
87
self.assertIsDirectory('.bzr/branch/lock/held', t)
89
def test_set_push_location(self):
90
conf = config.LocationConfig.from_string('# comment\n', '.', save=True)
92
branch = self.make_branch('.', format='knit')
93
branch.set_push_location('foo')
94
local_path = urlutils.local_path_from_url(branch.base[:-1])
95
self.assertFileEqual("# comment\n"
97
"push_location = foo\n"
98
"push_location:policy = norecurse\n" % local_path,
99
config.locations_config_filename())
101
# TODO RBC 20051029 test getting a push location from a branch in a
102
# recursive section - that is, it appends the branch name.
105
class SampleBranchFormat(_mod_branch.BranchFormat):
118
106
"""A sample format
120
this format is initializable, unsupported to aid in testing the
108
this format is initializable, unsupported to aid in testing the
121
109
open and open_downlevel routines.
175
195
format.initialize(dir)
176
196
# register a format for it.
177
bzrlib.branch.BranchFormat.register_format(format)
197
_mod_branch.BranchFormat.register_format(format)
178
198
# which branch.Open will refuse (not supported)
179
self.assertRaises(UnsupportedFormatError, bzrlib.branch.Branch.open, self.get_url())
199
self.assertRaises(errors.UnsupportedFormatError,
200
_mod_branch.Branch.open, self.get_url())
201
self.make_branch_and_tree('foo')
180
202
# but open_downlevel will work
181
self.assertEqual(format.open(dir), bzrdir.BzrDir.open(self.get_url()).open_branch(unsupported=True))
205
bzrdir.BzrDir.open(self.get_url()).open_branch(unsupported=True))
182
206
# unregister the format
183
bzrlib.branch.BranchFormat.unregister_format(format)
186
class TestBranchReference(TestCaseWithTransport):
207
_mod_branch.BranchFormat.unregister_format(format)
208
self.make_branch_and_tree('bar')
211
#Used by TestMetaDirBranchFormatFactory
212
FakeLazyFormat = None
215
class TestMetaDirBranchFormatFactory(tests.TestCase):
217
def test_get_format_string_does_not_load(self):
218
"""Formats have a static format string."""
219
factory = _mod_branch.MetaDirBranchFormatFactory("yo", None, None)
220
self.assertEqual("yo", factory.get_format_string())
222
def test_call_loads(self):
223
# __call__ is used by the network_format_registry interface to get a
225
global FakeLazyFormat
227
factory = _mod_branch.MetaDirBranchFormatFactory(None,
228
"bzrlib.tests.test_branch", "FakeLazyFormat")
229
self.assertRaises(AttributeError, factory)
231
def test_call_returns_call_of_referenced_object(self):
232
global FakeLazyFormat
233
FakeLazyFormat = lambda:'called'
234
factory = _mod_branch.MetaDirBranchFormatFactory(None,
235
"bzrlib.tests.test_branch", "FakeLazyFormat")
236
self.assertEqual('called', factory())
239
class TestBranch67(object):
240
"""Common tests for both branch 6 and 7 which are mostly the same."""
242
def get_format_name(self):
243
raise NotImplementedError(self.get_format_name)
245
def get_format_name_subtree(self):
246
raise NotImplementedError(self.get_format_name)
249
raise NotImplementedError(self.get_class)
251
def test_creation(self):
252
format = bzrdir.BzrDirMetaFormat1()
253
format.set_branch_format(_mod_branch.BzrBranchFormat6())
254
branch = self.make_branch('a', format=format)
255
self.assertIsInstance(branch, self.get_class())
256
branch = self.make_branch('b', format=self.get_format_name())
257
self.assertIsInstance(branch, self.get_class())
258
branch = _mod_branch.Branch.open('a')
259
self.assertIsInstance(branch, self.get_class())
261
def test_layout(self):
262
branch = self.make_branch('a', format=self.get_format_name())
263
self.failUnlessExists('a/.bzr/branch/last-revision')
264
self.failIfExists('a/.bzr/branch/revision-history')
265
self.failIfExists('a/.bzr/branch/references')
267
def test_config(self):
268
"""Ensure that all configuration data is stored in the branch"""
269
branch = self.make_branch('a', format=self.get_format_name())
270
branch.set_parent('http://bazaar-vcs.org')
271
self.failIfExists('a/.bzr/branch/parent')
272
self.assertEqual('http://bazaar-vcs.org', branch.get_parent())
273
branch.set_push_location('sftp://bazaar-vcs.org')
274
config = branch.get_config()._get_branch_data_config()
275
self.assertEqual('sftp://bazaar-vcs.org',
276
config.get_user_option('push_location'))
277
branch.set_bound_location('ftp://bazaar-vcs.org')
278
self.failIfExists('a/.bzr/branch/bound')
279
self.assertEqual('ftp://bazaar-vcs.org', branch.get_bound_location())
281
def test_set_revision_history(self):
282
builder = self.make_branch_builder('.', format=self.get_format_name())
283
builder.build_snapshot('foo', None,
284
[('add', ('', None, 'directory', None))],
286
builder.build_snapshot('bar', None, [], message='bar')
287
branch = builder.get_branch()
289
self.addCleanup(branch.unlock)
290
branch.set_revision_history(['foo', 'bar'])
291
branch.set_revision_history(['foo'])
292
self.assertRaises(errors.NotLefthandHistory,
293
branch.set_revision_history, ['bar'])
295
def do_checkout_test(self, lightweight=False):
296
tree = self.make_branch_and_tree('source',
297
format=self.get_format_name_subtree())
298
subtree = self.make_branch_and_tree('source/subtree',
299
format=self.get_format_name_subtree())
300
subsubtree = self.make_branch_and_tree('source/subtree/subsubtree',
301
format=self.get_format_name_subtree())
302
self.build_tree(['source/subtree/file',
303
'source/subtree/subsubtree/file'])
304
subsubtree.add('file')
306
subtree.add_reference(subsubtree)
307
tree.add_reference(subtree)
308
tree.commit('a revision')
309
subtree.commit('a subtree file')
310
subsubtree.commit('a subsubtree file')
311
tree.branch.create_checkout('target', lightweight=lightweight)
312
self.failUnlessExists('target')
313
self.failUnlessExists('target/subtree')
314
self.failUnlessExists('target/subtree/file')
315
self.failUnlessExists('target/subtree/subsubtree/file')
316
subbranch = _mod_branch.Branch.open('target/subtree/subsubtree')
318
self.assertEndsWith(subbranch.base, 'source/subtree/subsubtree/')
320
self.assertEndsWith(subbranch.base, 'target/subtree/subsubtree/')
322
def test_checkout_with_references(self):
323
self.do_checkout_test()
325
def test_light_checkout_with_references(self):
326
self.do_checkout_test(lightweight=True)
328
def test_set_push(self):
329
branch = self.make_branch('source', format=self.get_format_name())
330
branch.get_config().set_user_option('push_location', 'old',
331
store=config.STORE_LOCATION)
334
warnings.append(args[0] % args[1:])
335
_warning = trace.warning
336
trace.warning = warning
338
branch.set_push_location('new')
340
trace.warning = _warning
341
self.assertEqual(warnings[0], 'Value "new" is masked by "old" from '
345
class TestBranch6(TestBranch67, tests.TestCaseWithTransport):
348
return _mod_branch.BzrBranch6
350
def get_format_name(self):
351
return "dirstate-tags"
353
def get_format_name_subtree(self):
354
return "dirstate-with-subtree"
356
def test_set_stacked_on_url_errors(self):
357
branch = self.make_branch('a', format=self.get_format_name())
358
self.assertRaises(errors.UnstackableBranchFormat,
359
branch.set_stacked_on_url, None)
361
def test_default_stacked_location(self):
362
branch = self.make_branch('a', format=self.get_format_name())
363
self.assertRaises(errors.UnstackableBranchFormat, branch.get_stacked_on_url)
366
class TestBranch7(TestBranch67, tests.TestCaseWithTransport):
369
return _mod_branch.BzrBranch7
371
def get_format_name(self):
374
def get_format_name_subtree(self):
375
return "development-subtree"
377
def test_set_stacked_on_url_unstackable_repo(self):
378
repo = self.make_repository('a', format='dirstate-tags')
379
control = repo.bzrdir
380
branch = _mod_branch.BzrBranchFormat7().initialize(control)
381
target = self.make_branch('b')
382
self.assertRaises(errors.UnstackableRepositoryFormat,
383
branch.set_stacked_on_url, target.base)
385
def test_clone_stacked_on_unstackable_repo(self):
386
repo = self.make_repository('a', format='dirstate-tags')
387
control = repo.bzrdir
388
branch = _mod_branch.BzrBranchFormat7().initialize(control)
389
# Calling clone should not raise UnstackableRepositoryFormat.
390
cloned_bzrdir = control.clone('cloned')
392
def _test_default_stacked_location(self):
393
branch = self.make_branch('a', format=self.get_format_name())
394
self.assertRaises(errors.NotStacked, branch.get_stacked_on_url)
396
def test_stack_and_unstack(self):
397
branch = self.make_branch('a', format=self.get_format_name())
398
target = self.make_branch_and_tree('b', format=self.get_format_name())
399
branch.set_stacked_on_url(target.branch.base)
400
self.assertEqual(target.branch.base, branch.get_stacked_on_url())
401
revid = target.commit('foo')
402
self.assertTrue(branch.repository.has_revision(revid))
403
branch.set_stacked_on_url(None)
404
self.assertRaises(errors.NotStacked, branch.get_stacked_on_url)
405
self.assertFalse(branch.repository.has_revision(revid))
407
def test_open_opens_stacked_reference(self):
408
branch = self.make_branch('a', format=self.get_format_name())
409
target = self.make_branch_and_tree('b', format=self.get_format_name())
410
branch.set_stacked_on_url(target.branch.base)
411
branch = branch.bzrdir.open_branch()
412
revid = target.commit('foo')
413
self.assertTrue(branch.repository.has_revision(revid))
416
class BzrBranch8(tests.TestCaseWithTransport):
418
def make_branch(self, location, format=None):
420
format = bzrdir.format_registry.make_bzrdir('1.9')
421
format.set_branch_format(_mod_branch.BzrBranchFormat8())
422
return tests.TestCaseWithTransport.make_branch(
423
self, location, format=format)
425
def create_branch_with_reference(self):
426
branch = self.make_branch('branch')
427
branch._set_all_reference_info({'file-id': ('path', 'location')})
431
def instrument_branch(branch, gets):
432
old_get = branch._transport.get
433
def get(*args, **kwargs):
434
gets.append((args, kwargs))
435
return old_get(*args, **kwargs)
436
branch._transport.get = get
438
def test_reference_info_caching_read_locked(self):
440
branch = self.create_branch_with_reference()
442
self.addCleanup(branch.unlock)
443
self.instrument_branch(branch, gets)
444
branch.get_reference_info('file-id')
445
branch.get_reference_info('file-id')
446
self.assertEqual(1, len(gets))
448
def test_reference_info_caching_read_unlocked(self):
450
branch = self.create_branch_with_reference()
451
self.instrument_branch(branch, gets)
452
branch.get_reference_info('file-id')
453
branch.get_reference_info('file-id')
454
self.assertEqual(2, len(gets))
456
def test_reference_info_caching_write_locked(self):
458
branch = self.make_branch('branch')
460
self.instrument_branch(branch, gets)
461
self.addCleanup(branch.unlock)
462
branch._set_all_reference_info({'file-id': ('path2', 'location2')})
463
path, location = branch.get_reference_info('file-id')
464
self.assertEqual(0, len(gets))
465
self.assertEqual('path2', path)
466
self.assertEqual('location2', location)
468
def test_reference_info_caches_cleared(self):
469
branch = self.make_branch('branch')
471
branch.set_reference_info('file-id', 'path2', 'location2')
473
doppelganger = _mod_branch.Branch.open('branch')
474
doppelganger.set_reference_info('file-id', 'path3', 'location3')
475
self.assertEqual(('path3', 'location3'),
476
branch.get_reference_info('file-id'))
478
class TestBranchReference(tests.TestCaseWithTransport):
187
479
"""Tests for the branch reference facility."""
189
481
def test_create_open_reference(self):
190
482
bzrdirformat = bzrdir.BzrDirMetaFormat1()
191
t = get_transport(self.get_url('.'))
483
t = transport.get_transport(self.get_url('.'))
193
485
dir = bzrdirformat.initialize(self.get_url('repo'))
194
486
dir.create_repository()
195
487
target_branch = dir.create_branch()
196
488
t.mkdir('branch')
197
489
branch_dir = bzrdirformat.initialize(self.get_url('branch'))
198
made_branch = bzrlib.branch.BranchReferenceFormat().initialize(branch_dir, target_branch)
490
made_branch = _mod_branch.BranchReferenceFormat().initialize(
491
branch_dir, target_branch=target_branch)
199
492
self.assertEqual(made_branch.base, target_branch.base)
200
493
opened_branch = branch_dir.open_branch()
201
494
self.assertEqual(opened_branch.base, target_branch.base)
496
def test_get_reference(self):
497
"""For a BranchReference, get_reference should reutrn the location."""
498
branch = self.make_branch('target')
499
checkout = branch.create_checkout('checkout', lightweight=True)
500
reference_url = branch.bzrdir.root_transport.abspath('') + '/'
501
# if the api for create_checkout changes to return different checkout types
502
# then this file read will fail.
503
self.assertFileEqual(reference_url, 'checkout/.bzr/branch/location')
504
self.assertEqual(reference_url,
505
_mod_branch.BranchReferenceFormat().get_reference(checkout.bzrdir))
508
class TestHooks(tests.TestCaseWithTransport):
510
def test_constructor(self):
511
"""Check that creating a BranchHooks instance has the right defaults."""
512
hooks = _mod_branch.BranchHooks()
513
self.assertTrue("set_rh" in hooks, "set_rh not in %s" % hooks)
514
self.assertTrue("post_push" in hooks, "post_push not in %s" % hooks)
515
self.assertTrue("post_commit" in hooks, "post_commit not in %s" % hooks)
516
self.assertTrue("pre_commit" in hooks, "pre_commit not in %s" % hooks)
517
self.assertTrue("post_pull" in hooks, "post_pull not in %s" % hooks)
518
self.assertTrue("post_uncommit" in hooks,
519
"post_uncommit not in %s" % hooks)
520
self.assertTrue("post_change_branch_tip" in hooks,
521
"post_change_branch_tip not in %s" % hooks)
522
self.assertTrue("post_branch_init" in hooks,
523
"post_branch_init not in %s" % hooks)
524
self.assertTrue("post_switch" in hooks,
525
"post_switch not in %s" % hooks)
527
def test_installed_hooks_are_BranchHooks(self):
528
"""The installed hooks object should be a BranchHooks."""
529
# the installed hooks are saved in self._preserved_hooks.
530
self.assertIsInstance(self._preserved_hooks[_mod_branch.Branch][1],
531
_mod_branch.BranchHooks)
533
def test_post_branch_init_hook(self):
535
_mod_branch.Branch.hooks.install_named_hook('post_branch_init',
537
self.assertLength(0, calls)
538
branch = self.make_branch('a')
539
self.assertLength(1, calls)
541
self.assertIsInstance(params, _mod_branch.BranchInitHookParams)
542
self.assertTrue(hasattr(params, 'bzrdir'))
543
self.assertTrue(hasattr(params, 'branch'))
545
def test_post_branch_init_hook_repr(self):
547
_mod_branch.Branch.hooks.install_named_hook('post_branch_init',
548
lambda params: param_reprs.append(repr(params)), None)
549
branch = self.make_branch('a')
550
self.assertLength(1, param_reprs)
551
param_repr = param_reprs[0]
552
self.assertStartsWith(param_repr, '<BranchInitHookParams of ')
554
def test_post_switch_hook(self):
555
from bzrlib import switch
557
_mod_branch.Branch.hooks.install_named_hook('post_switch',
559
tree = self.make_branch_and_tree('branch-1')
560
self.build_tree(['branch-1/file-1'])
563
to_branch = tree.bzrdir.sprout('branch-2').open_branch()
564
self.build_tree(['branch-1/file-2'])
566
tree.remove('file-1')
568
checkout = tree.branch.create_checkout('checkout')
569
self.assertLength(0, calls)
570
switch.switch(checkout.bzrdir, to_branch)
571
self.assertLength(1, calls)
573
self.assertIsInstance(params, _mod_branch.SwitchHookParams)
574
self.assertTrue(hasattr(params, 'to_branch'))
575
self.assertTrue(hasattr(params, 'revision_id'))
578
class TestBranchOptions(tests.TestCaseWithTransport):
581
super(TestBranchOptions, self).setUp()
582
self.branch = self.make_branch('.')
583
self.config = self.branch.get_config()
585
def check_append_revisions_only(self, expected_value, value=None):
586
"""Set append_revisions_only in config and check its interpretation."""
587
if value is not None:
588
self.config.set_user_option('append_revisions_only', value)
589
self.assertEqual(expected_value,
590
self.branch._get_append_revisions_only())
592
def test_valid_append_revisions_only(self):
593
self.assertEquals(None,
594
self.config.get_user_option('append_revisions_only'))
595
self.check_append_revisions_only(None)
596
self.check_append_revisions_only(False, 'False')
597
self.check_append_revisions_only(True, 'True')
598
# The following values will cause compatibility problems on projects
599
# using older bzr versions (<2.2) but are accepted
600
self.check_append_revisions_only(False, 'false')
601
self.check_append_revisions_only(True, 'true')
603
def test_invalid_append_revisions_only(self):
604
"""Ensure warning is noted on invalid settings"""
607
self.warnings.append(args[0] % args[1:])
608
self.overrideAttr(trace, 'warning', warning)
609
self.check_append_revisions_only(None, 'not-a-bool')
610
self.assertLength(1, self.warnings)
612
'Value "not-a-bool" is not a boolean for "append_revisions_only"',
616
class TestPullResult(tests.TestCase):
618
def test_pull_result_to_int(self):
619
# to support old code, the pull result can be used as an int
620
r = _mod_branch.PullResult()
623
# this usage of results is not recommended for new code (because it
624
# doesn't describe very well what happened), but for api stability
625
# it's still supported
626
self.assertEqual(self.applyDeprecated(
627
symbol_versioning.deprecated_in((2, 3, 0)),
631
def test_report_changed(self):
632
r = _mod_branch.PullResult()
633
r.old_revid = "old-revid"
635
r.new_revid = "new-revid"
639
self.assertEqual("Now on revision 20.\n", f.getvalue())
641
def test_report_unchanged(self):
642
r = _mod_branch.PullResult()
643
r.old_revid = "same-revid"
644
r.new_revid = "same-revid"
647
self.assertEqual("No revisions to pull.\n", f.getvalue())
650
class _StubLockable(object):
651
"""Helper for TestRunWithWriteLockedTarget."""
653
def __init__(self, calls, unlock_exc=None):
655
self.unlock_exc = unlock_exc
657
def lock_write(self):
658
self.calls.append('lock_write')
661
self.calls.append('unlock')
662
if self.unlock_exc is not None:
663
raise self.unlock_exc
666
class _ErrorFromCallable(Exception):
667
"""Helper for TestRunWithWriteLockedTarget."""
670
class _ErrorFromUnlock(Exception):
671
"""Helper for TestRunWithWriteLockedTarget."""
674
class TestRunWithWriteLockedTarget(tests.TestCase):
675
"""Tests for _run_with_write_locked_target."""
678
tests.TestCase.setUp(self)
681
def func_that_returns_ok(self):
682
self._calls.append('func called')
685
def func_that_raises(self):
686
self._calls.append('func called')
687
raise _ErrorFromCallable()
689
def test_success_unlocks(self):
690
lockable = _StubLockable(self._calls)
691
result = _mod_branch._run_with_write_locked_target(
692
lockable, self.func_that_returns_ok)
693
self.assertEqual('ok', result)
694
self.assertEqual(['lock_write', 'func called', 'unlock'], self._calls)
696
def test_exception_unlocks_and_propagates(self):
697
lockable = _StubLockable(self._calls)
698
self.assertRaises(_ErrorFromCallable,
699
_mod_branch._run_with_write_locked_target,
700
lockable, self.func_that_raises)
701
self.assertEqual(['lock_write', 'func called', 'unlock'], self._calls)
703
def test_callable_succeeds_but_error_during_unlock(self):
704
lockable = _StubLockable(self._calls, unlock_exc=_ErrorFromUnlock())
705
self.assertRaises(_ErrorFromUnlock,
706
_mod_branch._run_with_write_locked_target,
707
lockable, self.func_that_returns_ok)
708
self.assertEqual(['lock_write', 'func called', 'unlock'], self._calls)
710
def test_error_during_unlock_does_not_mask_original_error(self):
711
lockable = _StubLockable(self._calls, unlock_exc=_ErrorFromUnlock())
712
self.assertRaises(_ErrorFromCallable,
713
_mod_branch._run_with_write_locked_target,
714
lockable, self.func_that_raises)
715
self.assertEqual(['lock_write', 'func called', 'unlock'], self._calls)