~bzr-pqm/bzr/bzr.dev

« back to all changes in this revision

Viewing changes to bzrlib/tests/test_branch.py

  • Committer: Robert Collins
  • Date: 2010-04-08 04:34:03 UTC
  • mfrom: (5138 +trunk)
  • mto: This revision was merged to the branch mainline in revision 5139.
  • Revision ID: robertc@robertcollins.net-20100408043403-56z0d07vdqrx7f3t
Update bugfix for 528114 to trunk.

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
# Copyright (C) 2005, 2006, 2007 Canonical Ltd
 
1
# Copyright (C) 2006-2010 Canonical Ltd
2
2
#
3
3
# This program is free software; you can redistribute it and/or modify
4
4
# it under the terms of the GNU General Public License as published by
12
12
#
13
13
# You should have received a copy of the GNU General Public License
14
14
# along with this program; if not, write to the Free Software
15
 
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 
15
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
16
16
 
17
17
"""Tests for the Branch facility that are not interface  tests.
18
18
 
19
 
For interface tests see tests/branch_implementations/*.py.
 
19
For interface tests see tests/per_branch/*.py.
20
20
 
21
21
For concrete class tests see this file, and for meta-branch tests
22
22
also see this file.
23
23
"""
24
24
 
25
 
from StringIO import StringIO
 
25
from cStringIO import StringIO
26
26
 
27
27
from bzrlib import (
28
28
    branch as _mod_branch,
29
29
    bzrdir,
30
30
    config,
31
31
    errors,
 
32
    tests,
32
33
    trace,
 
34
    transport,
33
35
    urlutils,
34
36
    )
35
 
from bzrlib.branch import (
36
 
    Branch,
37
 
    BranchHooks,
38
 
    BranchFormat,
39
 
    BranchReferenceFormat,
40
 
    BzrBranch5,
41
 
    BzrBranchFormat5,
42
 
    BzrBranchFormat6,
43
 
    PullResult,
44
 
    )
45
 
from bzrlib.bzrdir import (BzrDirMetaFormat1, BzrDirMeta1, 
46
 
                           BzrDir, BzrDirFormat)
47
 
from bzrlib.errors import (NotBranchError,
48
 
                           UnknownFormatError,
49
 
                           UnknownHook,
50
 
                           UnsupportedFormatError,
51
 
                           )
52
 
 
53
 
from bzrlib.tests import TestCase, TestCaseWithTransport
54
 
from bzrlib.transport import get_transport
55
 
 
56
 
class TestDefaultFormat(TestCase):
 
37
 
 
38
 
 
39
class TestDefaultFormat(tests.TestCase):
57
40
 
58
41
    def test_default_format(self):
59
42
        # update this if you change the default branch format
60
 
        self.assertIsInstance(BranchFormat.get_default_format(),
61
 
                BzrBranchFormat6)
 
43
        self.assertIsInstance(_mod_branch.BranchFormat.get_default_format(),
 
44
                _mod_branch.BzrBranchFormat7)
62
45
 
63
46
    def test_default_format_is_same_as_bzrdir_default(self):
64
47
        # XXX: it might be nice if there was only one place the default was
65
 
        # set, but at the moment that's not true -- mbp 20070814 -- 
 
48
        # set, but at the moment that's not true -- mbp 20070814 --
66
49
        # https://bugs.launchpad.net/bzr/+bug/132376
67
 
        self.assertEqual(BranchFormat.get_default_format(),
68
 
                BzrDirFormat.get_default_format().get_branch_format())
 
50
        self.assertEqual(
 
51
            _mod_branch.BranchFormat.get_default_format(),
 
52
            bzrdir.BzrDirFormat.get_default_format().get_branch_format())
69
53
 
70
54
    def test_get_set_default_format(self):
71
55
        # set the format and then set it back again
72
 
        old_format = BranchFormat.get_default_format()
73
 
        BranchFormat.set_default_format(SampleBranchFormat())
 
56
        old_format = _mod_branch.BranchFormat.get_default_format()
 
57
        _mod_branch.BranchFormat.set_default_format(SampleBranchFormat())
74
58
        try:
75
59
            # the default branch format is used by the meta dir format
76
60
            # which is not the default bzrdir format at this point
77
 
            dir = BzrDirMetaFormat1().initialize('memory:///')
 
61
            dir = bzrdir.BzrDirMetaFormat1().initialize('memory:///')
78
62
            result = dir.create_branch()
79
63
            self.assertEqual(result, 'A branch')
80
64
        finally:
81
 
            BranchFormat.set_default_format(old_format)
82
 
        self.assertEqual(old_format, BranchFormat.get_default_format())
83
 
 
84
 
 
85
 
class TestBranchFormat5(TestCaseWithTransport):
 
65
            _mod_branch.BranchFormat.set_default_format(old_format)
 
66
        self.assertEqual(old_format,
 
67
                         _mod_branch.BranchFormat.get_default_format())
 
68
 
 
69
 
 
70
class TestBranchFormat5(tests.TestCaseWithTransport):
86
71
    """Tests specific to branch format 5"""
87
72
 
88
73
    def test_branch_format_5_uses_lockdir(self):
89
74
        url = self.get_url()
90
 
        bzrdir = BzrDirMetaFormat1().initialize(url)
91
 
        bzrdir.create_repository()
92
 
        branch = bzrdir.create_branch()
 
75
        bdir = bzrdir.BzrDirMetaFormat1().initialize(url)
 
76
        bdir.create_repository()
 
77
        branch = bdir.create_branch()
93
78
        t = self.get_transport()
94
79
        self.log("branch instance is %r" % branch)
95
 
        self.assert_(isinstance(branch, BzrBranch5))
 
80
        self.assert_(isinstance(branch, _mod_branch.BzrBranch5))
96
81
        self.assertIsDirectory('.', t)
97
82
        self.assertIsDirectory('.bzr/branch', t)
98
83
        self.assertIsDirectory('.bzr/branch/lock', t)
99
84
        branch.lock_write()
100
 
        try:
101
 
            self.assertIsDirectory('.bzr/branch/lock/held', t)
102
 
        finally:
103
 
            branch.unlock()
 
85
        self.addCleanup(branch.unlock)
 
86
        self.assertIsDirectory('.bzr/branch/lock/held', t)
104
87
 
105
88
    def test_set_push_location(self):
106
89
        from bzrlib.config import (locations_config_filename,
122
105
        self.assertFileEqual("# comment\n"
123
106
                             "[%s]\n"
124
107
                             "push_location = foo\n"
125
 
                             "push_location:policy = norecurse" % local_path,
 
108
                             "push_location:policy = norecurse\n" % local_path,
126
109
                             fn)
127
110
 
128
111
    # TODO RBC 20051029 test getting a push location from a branch in a
129
112
    # recursive section - that is, it appends the branch name.
130
113
 
131
114
 
132
 
class SampleBranchFormat(BranchFormat):
 
115
class SampleBranchFormat(_mod_branch.BranchFormat):
133
116
    """A sample format
134
117
 
135
 
    this format is initializable, unsupported to aid in testing the 
 
118
    this format is initializable, unsupported to aid in testing the
136
119
    open and open_downlevel routines.
137
120
    """
138
121
 
140
123
        """See BzrBranchFormat.get_format_string()."""
141
124
        return "Sample branch format."
142
125
 
143
 
    def initialize(self, a_bzrdir):
 
126
    def initialize(self, a_bzrdir, name=None):
144
127
        """Format 4 branches cannot be created."""
145
 
        t = a_bzrdir.get_branch_transport(self)
 
128
        t = a_bzrdir.get_branch_transport(self, name=name)
146
129
        t.put_bytes('format', self.get_format_string())
147
130
        return 'A branch'
148
131
 
149
132
    def is_supported(self):
150
133
        return False
151
134
 
152
 
    def open(self, transport, _found=False):
 
135
    def open(self, transport, name=None, _found=False, ignore_fallbacks=False):
153
136
        return "opened branch."
154
137
 
155
138
 
156
 
class TestBzrBranchFormat(TestCaseWithTransport):
 
139
class TestBzrBranchFormat(tests.TestCaseWithTransport):
157
140
    """Tests for the BzrBranchFormat facility."""
158
141
 
159
142
    def test_find_format(self):
160
143
        # is the right format object found for a branch?
161
144
        # create a branch with a few known format objects.
162
 
        # this is not quite the same as 
 
145
        # this is not quite the same as
163
146
        self.build_tree(["foo/", "bar/"])
164
147
        def check_format(format, url):
165
148
            dir = format._matchingbzrdir.initialize(url)
166
149
            dir.create_repository()
167
150
            format.initialize(dir)
168
 
            found_format = BranchFormat.find_format(dir)
 
151
            found_format = _mod_branch.BranchFormat.find_format(dir)
169
152
            self.failUnless(isinstance(found_format, format.__class__))
170
 
        check_format(BzrBranchFormat5(), "bar")
171
 
        
 
153
        check_format(_mod_branch.BzrBranchFormat5(), "bar")
 
154
 
172
155
    def test_find_format_not_branch(self):
173
156
        dir = bzrdir.BzrDirMetaFormat1().initialize(self.get_url())
174
 
        self.assertRaises(NotBranchError,
175
 
                          BranchFormat.find_format,
 
157
        self.assertRaises(errors.NotBranchError,
 
158
                          _mod_branch.BranchFormat.find_format,
176
159
                          dir)
177
160
 
178
161
    def test_find_format_unknown_format(self):
179
162
        dir = bzrdir.BzrDirMetaFormat1().initialize(self.get_url())
180
163
        SampleBranchFormat().initialize(dir)
181
 
        self.assertRaises(UnknownFormatError,
182
 
                          BranchFormat.find_format,
 
164
        self.assertRaises(errors.UnknownFormatError,
 
165
                          _mod_branch.BranchFormat.find_format,
183
166
                          dir)
184
167
 
185
168
    def test_register_unregister_format(self):
189
172
        # make a branch
190
173
        format.initialize(dir)
191
174
        # register a format for it.
192
 
        BranchFormat.register_format(format)
 
175
        _mod_branch.BranchFormat.register_format(format)
193
176
        # which branch.Open will refuse (not supported)
194
 
        self.assertRaises(UnsupportedFormatError, Branch.open, self.get_url())
 
177
        self.assertRaises(errors.UnsupportedFormatError,
 
178
                          _mod_branch.Branch.open, self.get_url())
195
179
        self.make_branch_and_tree('foo')
196
180
        # but open_downlevel will work
197
 
        self.assertEqual(format.open(dir), bzrdir.BzrDir.open(self.get_url()).open_branch(unsupported=True))
 
181
        self.assertEqual(
 
182
            format.open(dir),
 
183
            bzrdir.BzrDir.open(self.get_url()).open_branch(unsupported=True))
198
184
        # unregister the format
199
 
        BranchFormat.unregister_format(format)
 
185
        _mod_branch.BranchFormat.unregister_format(format)
200
186
        self.make_branch_and_tree('bar')
201
187
 
202
188
 
203
 
class TestBranch6(TestCaseWithTransport):
 
189
class TestBranch67(object):
 
190
    """Common tests for both branch 6 and 7 which are mostly the same."""
 
191
 
 
192
    def get_format_name(self):
 
193
        raise NotImplementedError(self.get_format_name)
 
194
 
 
195
    def get_format_name_subtree(self):
 
196
        raise NotImplementedError(self.get_format_name)
 
197
 
 
198
    def get_class(self):
 
199
        raise NotImplementedError(self.get_class)
204
200
 
205
201
    def test_creation(self):
206
 
        format = BzrDirMetaFormat1()
 
202
        format = bzrdir.BzrDirMetaFormat1()
207
203
        format.set_branch_format(_mod_branch.BzrBranchFormat6())
208
204
        branch = self.make_branch('a', format=format)
209
 
        self.assertIsInstance(branch, _mod_branch.BzrBranch6)
210
 
        branch = self.make_branch('b', format='dirstate-tags')
211
 
        self.assertIsInstance(branch, _mod_branch.BzrBranch6)
 
205
        self.assertIsInstance(branch, self.get_class())
 
206
        branch = self.make_branch('b', format=self.get_format_name())
 
207
        self.assertIsInstance(branch, self.get_class())
212
208
        branch = _mod_branch.Branch.open('a')
213
 
        self.assertIsInstance(branch, _mod_branch.BzrBranch6)
 
209
        self.assertIsInstance(branch, self.get_class())
214
210
 
215
211
    def test_layout(self):
216
 
        branch = self.make_branch('a', format='dirstate-tags')
 
212
        branch = self.make_branch('a', format=self.get_format_name())
217
213
        self.failUnlessExists('a/.bzr/branch/last-revision')
218
214
        self.failIfExists('a/.bzr/branch/revision-history')
 
215
        self.failIfExists('a/.bzr/branch/references')
219
216
 
220
217
    def test_config(self):
221
218
        """Ensure that all configuration data is stored in the branch"""
222
 
        branch = self.make_branch('a', format='dirstate-tags')
 
219
        branch = self.make_branch('a', format=self.get_format_name())
223
220
        branch.set_parent('http://bazaar-vcs.org')
224
221
        self.failIfExists('a/.bzr/branch/parent')
225
222
        self.assertEqual('http://bazaar-vcs.org', branch.get_parent())
232
229
        self.assertEqual('ftp://bazaar-vcs.org', branch.get_bound_location())
233
230
 
234
231
    def test_set_revision_history(self):
235
 
        tree = self.make_branch_and_memory_tree('.',
236
 
            format='dirstate-tags')
237
 
        tree.lock_write()
238
 
        try:
239
 
            tree.add('.')
240
 
            tree.commit('foo', rev_id='foo')
241
 
            tree.commit('bar', rev_id='bar')
242
 
            tree.branch.set_revision_history(['foo', 'bar'])
243
 
            tree.branch.set_revision_history(['foo'])
244
 
            self.assertRaises(errors.NotLefthandHistory,
245
 
                              tree.branch.set_revision_history, ['bar'])
246
 
        finally:
247
 
            tree.unlock()
 
232
        builder = self.make_branch_builder('.', format=self.get_format_name())
 
233
        builder.build_snapshot('foo', None,
 
234
            [('add', ('', None, 'directory', None))],
 
235
            message='foo')
 
236
        builder.build_snapshot('bar', None, [], message='bar')
 
237
        branch = builder.get_branch()
 
238
        branch.lock_write()
 
239
        self.addCleanup(branch.unlock)
 
240
        branch.set_revision_history(['foo', 'bar'])
 
241
        branch.set_revision_history(['foo'])
 
242
        self.assertRaises(errors.NotLefthandHistory,
 
243
                          branch.set_revision_history, ['bar'])
248
244
 
249
245
    def do_checkout_test(self, lightweight=False):
250
 
        tree = self.make_branch_and_tree('source', format='dirstate-with-subtree')
 
246
        tree = self.make_branch_and_tree('source',
 
247
            format=self.get_format_name_subtree())
251
248
        subtree = self.make_branch_and_tree('source/subtree',
252
 
            format='dirstate-with-subtree')
 
249
            format=self.get_format_name_subtree())
253
250
        subsubtree = self.make_branch_and_tree('source/subtree/subsubtree',
254
 
            format='dirstate-with-subtree')
 
251
            format=self.get_format_name_subtree())
255
252
        self.build_tree(['source/subtree/file',
256
253
                         'source/subtree/subsubtree/file'])
257
254
        subsubtree.add('file')
279
276
        self.do_checkout_test(lightweight=True)
280
277
 
281
278
    def test_set_push(self):
282
 
        branch = self.make_branch('source', format='dirstate-tags')
 
279
        branch = self.make_branch('source', format=self.get_format_name())
283
280
        branch.get_config().set_user_option('push_location', 'old',
284
281
            store=config.STORE_LOCATION)
285
282
        warnings = []
294
291
        self.assertEqual(warnings[0], 'Value "new" is masked by "old" from '
295
292
                         'locations.conf')
296
293
 
297
 
class TestBranchReference(TestCaseWithTransport):
 
294
 
 
295
class TestBranch6(TestBranch67, tests.TestCaseWithTransport):
    """Run the shared TestBranch67 tests against BzrBranch6.

    Also checks that stacking operations on this format raise
    UnstackableBranchFormat.
    """

    def get_class(self):
        """Return the branch class the shared tests expect to be created."""
        return _mod_branch.BzrBranch6

    def get_format_name(self):
        """Return the format name passed to make_branch by the tests."""
        return "dirstate-tags"

    def get_format_name_subtree(self):
        """Return the format name used by the nested-tree checkout tests."""
        return "dirstate-with-subtree"

    def test_set_stacked_on_url_errors(self):
        """set_stacked_on_url raises UnstackableBranchFormat."""
        branch = self.make_branch('a', format=self.get_format_name())
        self.assertRaises(errors.UnstackableBranchFormat,
            branch.set_stacked_on_url, None)

    def test_default_stacked_location(self):
        """get_stacked_on_url raises UnstackableBranchFormat."""
        branch = self.make_branch('a', format=self.get_format_name())
        self.assertRaises(errors.UnstackableBranchFormat,
            branch.get_stacked_on_url)
 
314
 
 
315
 
 
316
class TestBranch7(TestBranch67, tests.TestCaseWithTransport):
    """Run the shared TestBranch67 tests against BzrBranch7.

    Adds tests for stacking, which the tests below exercise through
    set_stacked_on_url / get_stacked_on_url.
    """

    def get_class(self):
        """Return the branch class the shared tests expect to be created."""
        return _mod_branch.BzrBranch7

    def get_format_name(self):
        """Return the format name passed to make_branch by the tests."""
        return "1.9"

    def get_format_name_subtree(self):
        """Return the format name used by the nested-tree checkout tests."""
        return "development-subtree"

    def test_set_stacked_on_url_unstackable_repo(self):
        """Stacking a format 7 branch in a 'dirstate-tags' repository fails.

        The expected error is UnstackableRepositoryFormat.
        """
        repo = self.make_repository('a', format='dirstate-tags')
        control = repo.bzrdir
        branch = _mod_branch.BzrBranchFormat7().initialize(control)
        target = self.make_branch('b')
        self.assertRaises(errors.UnstackableRepositoryFormat,
            branch.set_stacked_on_url, target.base)

    def test_clone_stacked_on_unstackable_repo(self):
        """Cloning a control dir holding such a branch must not raise."""
        repo = self.make_repository('a', format='dirstate-tags')
        control = repo.bzrdir
        # Only the side effect of creating the branch matters here.
        _mod_branch.BzrBranchFormat7().initialize(control)
        # Calling clone should not raise UnstackableRepositoryFormat.
        control.clone('cloned')

    def _test_default_stacked_location(self):
        # NOTE(review): the leading underscore keeps this from being
        # collected as a test; presumably disabled on purpose -- confirm.
        branch = self.make_branch('a', format=self.get_format_name())
        self.assertRaises(errors.NotStacked, branch.get_stacked_on_url)

    def test_stack_and_unstack(self):
        """Revisions of the stacked-on branch are visible only while stacked."""
        branch = self.make_branch('a', format=self.get_format_name())
        target = self.make_branch_and_tree('b', format=self.get_format_name())
        branch.set_stacked_on_url(target.branch.base)
        self.assertEqual(target.branch.base, branch.get_stacked_on_url())
        revid = target.commit('foo')
        self.assertTrue(branch.repository.has_revision(revid))
        # Passing None removes the stacking again.
        branch.set_stacked_on_url(None)
        self.assertRaises(errors.NotStacked, branch.get_stacked_on_url)
        self.assertFalse(branch.repository.has_revision(revid))

    def test_open_opens_stacked_reference(self):
        """A re-opened stacked branch still sees the stacked-on revisions."""
        branch = self.make_branch('a', format=self.get_format_name())
        target = self.make_branch_and_tree('b', format=self.get_format_name())
        branch.set_stacked_on_url(target.branch.base)
        branch = branch.bzrdir.open_branch()
        revid = target.commit('foo')
        self.assertTrue(branch.repository.has_revision(revid))
 
364
 
 
365
 
 
366
class BzrBranch8(tests.TestCaseWithTransport):
    """Tests for reference-info caching on BzrBranchFormat8 branches."""

    def make_branch(self, location, format=None):
        """Create a branch at location, defaulting to branch format 8.

        When no format is given, a '1.9' bzrdir format is built and its
        branch format is switched to BzrBranchFormat8.
        """
        if format is None:
            format = bzrdir.format_registry.make_bzrdir('1.9')
            format.set_branch_format(_mod_branch.BzrBranchFormat8())
        return tests.TestCaseWithTransport.make_branch(
            self, location, format=format)

    def create_branch_with_reference(self):
        """Return a new branch holding one reference entry for 'file-id'."""
        branch = self.make_branch('branch')
        branch._set_all_reference_info({'file-id': ('path', 'location')})
        return branch

    @staticmethod
    def instrument_branch(branch, gets):
        """Wrap branch._transport.get so every call is appended to gets.

        Each call is recorded as an (args, kwargs) tuple before being
        forwarded to the original get.
        """
        old_get = branch._transport.get
        def get(*args, **kwargs):
            gets.append((args, kwargs))
            return old_get(*args, **kwargs)
        branch._transport.get = get

    def test_reference_info_caching_read_locked(self):
        """Under a read lock, two lookups cause a single transport get."""
        gets = []
        branch = self.create_branch_with_reference()
        branch.lock_read()
        self.addCleanup(branch.unlock)
        self.instrument_branch(branch, gets)
        branch.get_reference_info('file-id')
        branch.get_reference_info('file-id')
        self.assertEqual(1, len(gets))

    def test_reference_info_caching_read_unlocked(self):
        """Without a lock, each lookup causes its own transport get."""
        gets = []
        branch = self.create_branch_with_reference()
        self.instrument_branch(branch, gets)
        branch.get_reference_info('file-id')
        branch.get_reference_info('file-id')
        self.assertEqual(2, len(gets))

    def test_reference_info_caching_write_locked(self):
        """Under a write lock, info just set is read back with no get."""
        gets = []
        branch = self.make_branch('branch')
        branch.lock_write()
        # Register the unlock straight away so a failure while
        # instrumenting cannot leave the branch locked.
        self.addCleanup(branch.unlock)
        self.instrument_branch(branch, gets)
        branch._set_all_reference_info({'file-id': ('path2', 'location2')})
        path, location = branch.get_reference_info('file-id')
        self.assertEqual(0, len(gets))
        self.assertEqual('path2', path)
        self.assertEqual('location2', location)

    def test_reference_info_caches_cleared(self):
        """After unlocking, changes made via another object are visible."""
        branch = self.make_branch('branch')
        branch.lock_write()
        try:
            branch.set_reference_info('file-id', 'path2', 'location2')
        finally:
            branch.unlock()
        doppelganger = _mod_branch.Branch.open('branch')
        doppelganger.set_reference_info('file-id', 'path3', 'location3')
        self.assertEqual(('path3', 'location3'),
                         branch.get_reference_info('file-id'))
 
427
 
 
428
class TestBranchReference(tests.TestCaseWithTransport):
298
429
    """Tests for the branch reference facility."""
299
430
 
300
431
    def test_create_open_reference(self):
301
432
        bzrdirformat = bzrdir.BzrDirMetaFormat1()
302
 
        t = get_transport(self.get_url('.'))
 
433
        t = transport.get_transport(self.get_url('.'))
303
434
        t.mkdir('repo')
304
435
        dir = bzrdirformat.initialize(self.get_url('repo'))
305
436
        dir.create_repository()
306
437
        target_branch = dir.create_branch()
307
438
        t.mkdir('branch')
308
439
        branch_dir = bzrdirformat.initialize(self.get_url('branch'))
309
 
        made_branch = BranchReferenceFormat().initialize(branch_dir, target_branch)
 
440
        made_branch = _mod_branch.BranchReferenceFormat().initialize(
 
441
            branch_dir, target_branch=target_branch)
310
442
        self.assertEqual(made_branch.base, target_branch.base)
311
443
        opened_branch = branch_dir.open_branch()
312
444
        self.assertEqual(opened_branch.base, target_branch.base)
323
455
            _mod_branch.BranchReferenceFormat().get_reference(checkout.bzrdir))
324
456
 
325
457
 
326
 
class TestHooks(TestCase):
 
458
class TestHooks(tests.TestCase):
327
459
 
328
460
    def test_constructor(self):
329
461
        """Check that creating a BranchHooks instance has the right defaults."""
330
 
        hooks = BranchHooks()
 
462
        hooks = _mod_branch.BranchHooks()
331
463
        self.assertTrue("set_rh" in hooks, "set_rh not in %s" % hooks)
332
464
        self.assertTrue("post_push" in hooks, "post_push not in %s" % hooks)
333
465
        self.assertTrue("post_commit" in hooks, "post_commit not in %s" % hooks)
334
466
        self.assertTrue("pre_commit" in hooks, "pre_commit not in %s" % hooks)
335
467
        self.assertTrue("post_pull" in hooks, "post_pull not in %s" % hooks)
336
 
        self.assertTrue("post_uncommit" in hooks, "post_uncommit not in %s" % hooks)
 
468
        self.assertTrue("post_uncommit" in hooks,
 
469
                        "post_uncommit not in %s" % hooks)
 
470
        self.assertTrue("post_change_branch_tip" in hooks,
 
471
                        "post_change_branch_tip not in %s" % hooks)
337
472
 
338
473
    def test_installed_hooks_are_BranchHooks(self):
339
474
        """The installed hooks object should be a BranchHooks."""
340
475
        # the installed hooks are saved in self._preserved_hooks.
341
 
        self.assertIsInstance(self._preserved_hooks[_mod_branch.Branch], BranchHooks)
342
 
 
343
 
 
344
 
class TestPullResult(TestCase):
 
476
        self.assertIsInstance(self._preserved_hooks[_mod_branch.Branch][1],
 
477
                              _mod_branch.BranchHooks)
 
478
 
 
479
 
 
480
class TestPullResult(tests.TestCase):
345
481
 
346
482
    def test_pull_result_to_int(self):
347
483
        # to support old code, the pull result can be used as an int
348
 
        r = PullResult()
 
484
        r = _mod_branch.PullResult()
349
485
        r.old_revno = 10
350
486
        r.new_revno = 20
351
487
        # this usage of results is not recommended for new code (because it
353
489
        # it's still supported
354
490
        a = "%d revisions pulled" % r
355
491
        self.assertEqual(a, "10 revisions pulled")
 
492
 
 
493
    def test_report_changed(self):
 
494
        r = _mod_branch.PullResult()
 
495
        r.old_revid = "old-revid"
 
496
        r.old_revno = 10
 
497
        r.new_revid = "new-revid"
 
498
        r.new_revno = 20
 
499
        f = StringIO()
 
500
        r.report(f)
 
501
        self.assertEqual("Now on revision 20.\n", f.getvalue())
 
502
 
 
503
    def test_report_unchanged(self):
 
504
        r = _mod_branch.PullResult()
 
505
        r.old_revid = "same-revid"
 
506
        r.new_revid = "same-revid"
 
507
        f = StringIO()
 
508
        r.report(f)
 
509
        self.assertEqual("No revisions to pull.\n", f.getvalue())
 
510
 
 
511
 
 
512
class _StubLockable(object):
 
513
    """Helper for TestRunWithWriteLockedTarget."""
 
514
 
 
515
    def __init__(self, calls, unlock_exc=None):
 
516
        self.calls = calls
 
517
        self.unlock_exc = unlock_exc
 
518
 
 
519
    def lock_write(self):
 
520
        self.calls.append('lock_write')
 
521
 
 
522
    def unlock(self):
 
523
        self.calls.append('unlock')
 
524
        if self.unlock_exc is not None:
 
525
            raise self.unlock_exc
 
526
 
 
527
 
 
528
class _ErrorFromCallable(Exception):
 
529
    """Helper for TestRunWithWriteLockedTarget."""
 
530
 
 
531
 
 
532
class _ErrorFromUnlock(Exception):
 
533
    """Helper for TestRunWithWriteLockedTarget."""
 
534
 
 
535
 
 
536
class TestRunWithWriteLockedTarget(tests.TestCase):
    """Tests for _run_with_write_locked_target."""

    def setUp(self):
        """Start each test with an empty record of calls."""
        tests.TestCase.setUp(self)
        self._calls = []

    def func_that_returns_ok(self):
        """Record the call and return 'ok'."""
        self._calls.append('func called')
        return 'ok'

    def func_that_raises(self):
        """Record the call and raise _ErrorFromCallable."""
        self._calls.append('func called')
        raise _ErrorFromCallable()

    def test_success_unlocks(self):
        """The callable's result is returned and the target is unlocked."""
        lockable = _StubLockable(self._calls)
        result = _mod_branch._run_with_write_locked_target(
            lockable, self.func_that_returns_ok)
        self.assertEqual('ok', result)
        self.assertEqual(['lock_write', 'func called', 'unlock'], self._calls)

    def test_exception_unlocks_and_propagates(self):
        """An error from the callable propagates after unlock is called."""
        lockable = _StubLockable(self._calls)
        self.assertRaises(_ErrorFromCallable,
                          _mod_branch._run_with_write_locked_target,
                          lockable, self.func_that_raises)
        self.assertEqual(['lock_write', 'func called', 'unlock'], self._calls)

    def test_callable_succeeds_but_error_during_unlock(self):
        """An unlock error propagates when the callable itself succeeded."""
        lockable = _StubLockable(self._calls, unlock_exc=_ErrorFromUnlock())
        self.assertRaises(_ErrorFromUnlock,
                          _mod_branch._run_with_write_locked_target,
                          lockable, self.func_that_returns_ok)
        self.assertEqual(['lock_write', 'func called', 'unlock'], self._calls)

    def test_error_during_unlock_does_not_mask_original_error(self):
        """When both fail, the callable's error is the one that surfaces."""
        lockable = _StubLockable(self._calls, unlock_exc=_ErrorFromUnlock())
        self.assertRaises(_ErrorFromCallable,
                          _mod_branch._run_with_write_locked_target,
                          lockable, self.func_that_raises)
        self.assertEqual(['lock_write', 'func called', 'unlock'], self._calls)
 
578
 
 
579