From 9e1c5d1b5932879717235a0126088a26a6c3a80f Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Thu, 7 Sep 2023 07:53:23 -0700 Subject: [PATCH] Now running tree_check when initialising a Module --- equinox/_module.py | 41 ++++++++++++++++++++++++++++++++++++++++- equinox/_tree.py | 19 ++++++++++++++----- tests/test_module.py | 23 +++++++++++++++++++++++ 3 files changed, 77 insertions(+), 6 deletions(-) diff --git a/equinox/_module.py b/equinox/_module.py index c8422ddd..9c0182e8 100644 --- a/equinox/_module.py +++ b/equinox/_module.py @@ -16,7 +16,7 @@ from ._caches import internal_lru_caches from ._doc_utils import doc_repr from ._pretty_print import tree_pformat -from ._tree import tree_equal +from ._tree import tree_check_internal, tree_equal _P = ParamSpec("_P") @@ -141,6 +141,11 @@ def _not_magic(k: str) -> bool: _has_dataclass_init = weakref.WeakKeyDictionary() +_has_been_checked = weakref.WeakValueDictionary() + + +def _skip(node): + return isinstance(node, Module) and node is _has_been_checked.get(id(node), None) _dummy_abstract = abc.abstractmethod(lambda self: 1) @@ -272,6 +277,8 @@ def __call__(cls, *args, **kwargs): else: setattr(self, field.name, converter(getattr(self, field.name))) object.__setattr__(self, "__class__", cls) + # Note that these checks only run during the initial creation, and not during + # unflattening. for kls in cls.__mro__: try: check = kls.__dict__["__check_init__"] @@ -279,6 +286,38 @@ def __call__(cls, *args, **kwargs): pass else: check(self) + try: + tree_check_internal(self, _skip) + except ValueError as e: + raise ValueError( + "As of Equinox v0.11.0, `equinox.Module`s now validate that there " + "aren't any repeated layers inside a module. 
This is because this was " + "previously a common bug.\n" + "As an example, something like this:\n" + "```\n" + "class MyModule(eqx.Module):\n" + " linear1: eqx.nn.Linear\n" + " linear2: eqx.nn.Linear\n" + "\n" + " def __init__(self, ...):\n" + " linear = eqx.nn.Linear(...)\n" + " self.linear1 = linear\n" + " self.linear2 = linear\n" + "```\n" + "resulted in two independent linear layers after a gradient update had " + "happened.\n" + "An exception is being thrown now as this error has been detected.\n" + "If you intended to share the layer, then use the new functionality " + "`eqx.nn.Shared`. If you intended to have two duplicate copies, then " + "please instantiate two separate layers. If it's easier, you can also " + "clone an existing layer by doing\n" + "```\n" + "layer = ...\n" + "leaves, treedef = jax.tree_util.tree_flatten(layer)\n" + "clone_layer = jax.tree_util.tree_unflatten(treedef, leaves)\n" + "```" + ) from e + _has_been_checked[id(self)] = self return self diff --git a/equinox/_tree.py b/equinox/_tree.py index de1b1dc3..ef25e357 100644 --- a/equinox/_tree.py +++ b/equinox/_tree.py @@ -338,6 +338,14 @@ def is_leaf(node): return jtu.tree_flatten(pytree, is_leaf=is_leaf) +def tree_check_internal(pytree, skip) -> None: + """As `tree_check`, but can skip checking some nodes (typically those that have + already been checked). + """ + all_nodes = {} + _tree_check(pytree, all_nodes, skip) + + def tree_check(pytree: Any) -> None: """Checks if the PyTree is well-formed: does it have no self-references, and does it have no duplicate layers. @@ -389,13 +397,13 @@ def tree_check(pytree: Any) -> None: A `ValueError` if the PyTree is not well-formed. 
""" all_nodes = {} - _tree_check(pytree, all_nodes) + _tree_check(pytree, all_nodes, skip=lambda _: False) _leaf_treedef = jtu.tree_structure(0) -def _tree_check(node, all_nodes): +def _tree_check(node, all_nodes, skip): subnodes, treedef = tree_flatten_one_level(node) # We allow duplicate leaves and empty containers, so don't raise an error with those if treedef != _leaf_treedef and treedef.num_leaves > 0: @@ -422,7 +430,8 @@ def _tree_check(node, all_nodes): except AttributeError: # AttributeError: in case we cannot get __name__ for some weird reason. type_string = "" - all_nodes[id(node)] = (True, type_string) - for subnode in subnodes: - _tree_check(subnode, all_nodes) + if not skip(node): + all_nodes[id(node)] = (True, type_string) + for subnode in subnodes: + _tree_check(subnode, all_nodes, skip) all_nodes[id(node)] = (False, type_string) diff --git a/tests/test_module.py b/tests/test_module.py index c954ad0c..1e298b72 100644 --- a/tests/test_module.py +++ b/tests/test_module.py @@ -1,6 +1,7 @@ import abc import dataclasses import functools as ft +import gc from collections.abc import Callable from typing import Any @@ -540,3 +541,25 @@ class Abstract3(eqx.Module, strict=True): @abc.abstractmethod def foo(self): pass + + +def test_tree_check_cache(getkey): + gc.collect() + has_been_checked = eqx._module._has_been_checked + num_checked = len(has_been_checked) + mlp = eqx.nn.MLP(2, 2, 2, 2, key=getkey()) + # +4: one for `MLP`, and three for its `Linear` layers inside. + assert len(has_been_checked) == num_checked + 4 + del mlp + gc.collect() + assert len(has_been_checked) == num_checked + + +def test_duplicate_layer_error(getkey): + class M(eqx.Module): + l1: eqx.nn.Linear + l2: eqx.nn.Linear + + linear = eqx.nn.Linear(2, 2, key=getkey()) + with pytest.raises(ValueError, match="As of Equinox v0.11.0"): + M(linear, linear)