From 1c6098470085c68a4e27c5118d7d9c66bd4ac2d9 Mon Sep 17 00:00:00 2001 From: MilesCranmerBot Date: Mon, 22 Jun 2026 19:29:51 +0000 Subject: [PATCH 1/2] fix: raise clear error on reload of TemplateExpressionSpec models Closes #846 Co-authored-by: Miles Cranmer --- pysr/sr.py | 12 ++++++++++++ pysr/test/test_main.py | 14 ++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/pysr/sr.py b/pysr/sr.py index a304e9503..ff677bfa1 100644 --- a/pysr/sr.py +++ b/pysr/sr.py @@ -39,6 +39,7 @@ AbstractExpressionSpec, ExpressionSpec, ParametricExpressionSpec, + TemplateExpressionSpec, parametric_expression_deprecation_warning, ) from .feature_selection import run_feature_selection @@ -1467,6 +1468,17 @@ def __getstate__(self) -> dict[str, Any]: ] return pickled_state + def __setstate__(self, state: dict[str, Any]) -> None: + # ponytail: raise immediately on reload instead of confusing SymPy error later + self.__dict__.update(state) + if "equations_" in state and state["equations_"] is not None and isinstance( + self.expression_spec, TemplateExpressionSpec + ): + raise NotImplementedError( + "Reloading fitted TemplateExpressionSpec models is not yet supported. " + "Please refit the model in the current session." + ) + def _checkpoint(self): """Save the model's current state to a checkpoint file. diff --git a/pysr/test/test_main.py b/pysr/test/test_main.py index 1f7304c1b..59ef24ce1 100644 --- a/pysr/test/test_main.py +++ b/pysr/test/test_main.py @@ -2246,6 +2246,20 @@ def test_process_constraints_swaps_multiplication_constraints(self): class TestTemplateExpressionSpec(unittest.TestCase): + def test_reload_raises_clear_error(self): + # ponytail: one check — reload of fitted template spec raises immediately + import pickle + model = PySRRegressor( + expression_spec=TemplateExpressionSpec( + combine="f(x)", expressions=["f"], variable_names=["x"] + ) + ) + model.equations_ = pd.DataFrame({"loss": [0.0]}) + model.feature_names_in_ = np.array(["x"]) + model.nout_ = 1 + with self.assertRaisesRegex(NotImplementedError, "not yet supported"): + pickle.loads(pickle.dumps(model)) + def _check_macro_str(self, spec, expected_str): self.assertEqual( spec._template_macro_str().strip(), dedent(expected_str).strip() From 762cb0c0cba1786c27967fcaf7dbafa54ea2f6fa Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 22 Jun 2026 19:30:05 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pysr/sr.py | 6 ++++-- pysr/test/test_main.py | 1 + 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/pysr/sr.py b/pysr/sr.py index ff677bfa1..2747355d1 100644 --- a/pysr/sr.py +++ b/pysr/sr.py @@ -1471,8 +1471,10 @@ def __getstate__(self) -> dict[str, Any]: def __setstate__(self, state: dict[str, Any]) -> None: # ponytail: raise immediately on reload instead of confusing SymPy error later self.__dict__.update(state) - if "equations_" in state and state["equations_"] is not None and isinstance( - self.expression_spec, TemplateExpressionSpec + if ( + "equations_" in state + and state["equations_"] is not None + and isinstance(self.expression_spec, TemplateExpressionSpec) ): raise NotImplementedError( "Reloading fitted TemplateExpressionSpec models is not yet supported. " diff --git a/pysr/test/test_main.py b/pysr/test/test_main.py index 59ef24ce1..9a2d9a19c 100644 --- a/pysr/test/test_main.py +++ b/pysr/test/test_main.py @@ -2249,6 +2249,7 @@ class TestTemplateExpressionSpec(unittest.TestCase): def test_reload_raises_clear_error(self): # ponytail: one check — reload of fitted template spec raises immediately import pickle + model = PySRRegressor( expression_spec=TemplateExpressionSpec( combine="f(x)", expressions=["f"], variable_names=["x"]