diff --git a/src/sphinxnotes/any/domain.py b/src/sphinxnotes/any/domain.py index b5b584b..c8ac43a 100644 --- a/src/sphinxnotes/any/domain.py +++ b/src/sphinxnotes/any/domain.py @@ -620,29 +620,30 @@ def setup_external_header_anchor( @dataclass class PendingObject(UnresolvedContext): - domain: ObjDomain + domain: str objtype: str objid: str @override - def resolve(self) -> ResolvedContext: + def resolve(self, env: BuildEnvironment) -> ResolvedContext: + domain: ObjDomain = cast(ObjDomain, env.get_domain(self.domain)) objid = self.objid - if (self.objtype, objid) not in self.domain.objects: + if (self.objtype, objid) not in domain.objects: objids = set() - for objtype, objfield, objref in self.domain.references: + for objtype, objfield, objref in domain.references: if objtype == self.objtype and objref == objid: - objids.update(self.domain.references[objtype, objfield, objref]) + objids.update(domain.references[objtype, objfield, objref]) if len(objids) >= 1: objid = objids.pop() else: raise KeyError(f'Object not found: {(self.objtype, objid)}') - _, _, obj = self.domain.objects[self.objtype, objid] + _, _, obj = domain.objects[self.objtype, objid] return obj def __hash__(self) -> int: - return hash((self.domain.name, self.objtype, self.objid)) + return hash((self.domain, self.objtype, self.objid)) class ObjEmbedDirective(BaseContextDirective): @@ -671,7 +672,7 @@ def get_rendered_objects(self) -> list[nodes.Node]: def current_context(self) -> UnresolvedContext | ResolvedContext: domain, objtype = self.get_domain_and_type() objid = self.arguments[0] - return PendingObject(domain, objtype, objid) + return PendingObject(domain.name, objtype, objid) @override def current_template(self) -> Template: diff --git a/tests/test_domain.py b/tests/test_domain.py index 629d1c2..7ac28ef 100644 --- a/tests/test_domain.py +++ b/tests/test_domain.py @@ -9,16 +9,19 @@ class TestPendingObjectResolve(unittest.TestCase): def setUp(self): + self.env = MagicMock() self.domain = ObjDomain(MagicMock()) + self.domain.name = 'obj' self.domain.data['objects'] = {} self.domain.data['references'] = {} + self.env.get_domain.return_value = self.domain def test_resolve_by_id(self): obj = MagicMock() self.domain.objects['foo', 'id1'] = ('doc1', 'anchor1', obj) - pending = PendingObject(domain=self.domain, objtype='foo', objid='id1') - result = pending.resolve() + pending = PendingObject(domain='obj', objtype='foo', objid='id1') + result = pending.resolve(self.env) self.assertIs(result, obj) @@ -27,8 +30,8 @@ def test_resolve_by_reference(self): self.domain.objects['foo', 'id1'] = ('doc1', 'anchor1', obj) self.domain.references['foo', 'field1', 'ref1'] = {'id1'} - pending = PendingObject(domain=self.domain, objtype='foo', objid='ref1') - result = pending.resolve() + pending = PendingObject(domain='obj', objtype='foo', objid='ref1') + result = pending.resolve(self.env) self.assertIs(result, obj) @@ -36,10 +39,10 @@ def test_resolve_not_found(self): self.domain.objects['foo', 'id1'] = ('doc1', 'anchor1', MagicMock()) self.domain.references['foo', 'field1', 'ref1'] = {'id1'} - pending = PendingObject(domain=self.domain, objtype='foo', objid='nonexistent') + pending = PendingObject(domain='obj', objtype='foo', objid='nonexistent') with self.assertRaises(KeyError): - pending.resolve() + pending.resolve(self.env) def test_resolve_multiple_references(self): obj1 = MagicMock() @@ -48,8 +51,8 @@ def test_resolve_multiple_references(self): self.domain.objects['foo', 'id2'] = ('doc2', 'anchor2', obj2) self.domain.references['foo', 'field1', 'ref1'] = {'id1', 'id2'} - pending = PendingObject(domain=self.domain, objtype='foo', objid='ref1') - result = pending.resolve() + pending = PendingObject(domain='obj', objtype='foo', objid='ref1') + result = pending.resolve(self.env) self.assertIn(result, [obj1, obj2])