diff --git a/equinox/nn/_stateful.py b/equinox/nn/_stateful.py index 5a992d39..fc9b9ab0 100644 --- a/equinox/nn/_stateful.py +++ b/equinox/nn/_stateful.py @@ -1,3 +1,4 @@ +import types from collections.abc import Callable from typing import Any, Generic, TYPE_CHECKING, TypeVar, Union from typing_extensions import ParamSpec @@ -52,6 +53,15 @@ def __init__(self, init: _Value): - `init`: The initial value for the state. """ + if isinstance(init, types.FunctionType): + # Technically a function is valid here, since we could allow any pytree. + # In practice that's weird / kind of useless, so better to explicitly raise + # the deprecation error. + raise ValueError( + "As of Equinox v0.11.0, `eqx.nn.StateIndex` now accepts the value " + "of the initial state directly. (Not a function that creates the " + "initial state.)" + ) self.marker = object() self.init = init