diff --git a/pysr/sr.py b/pysr/sr.py index a304e9503..2747355d1 100644 --- a/pysr/sr.py +++ b/pysr/sr.py @@ -39,6 +39,7 @@ AbstractExpressionSpec, ExpressionSpec, ParametricExpressionSpec, + TemplateExpressionSpec, parametric_expression_deprecation_warning, ) from .feature_selection import run_feature_selection @@ -1467,6 +1468,19 @@ def __getstate__(self) -> dict[str, Any]: ] return pickled_state + def __setstate__(self, state: dict[str, Any]) -> None: + # ponytail: raise immediately on reload instead of confusing SymPy error later + self.__dict__.update(state) + if ( + "equations_" in state + and state["equations_"] is not None + and isinstance(self.expression_spec, TemplateExpressionSpec) + ): + raise NotImplementedError( + "Reloading fitted TemplateExpressionSpec models is not yet supported. " + "Please refit the model in the current session." + ) + def _checkpoint(self): """Save the model's current state to a checkpoint file. diff --git a/pysr/test/test_main.py b/pysr/test/test_main.py index 1f7304c1b..9a2d9a19c 100644 --- a/pysr/test/test_main.py +++ b/pysr/test/test_main.py @@ -2246,6 +2246,21 @@ def test_process_constraints_swaps_multiplication_constraints(self): class TestTemplateExpressionSpec(unittest.TestCase): + def test_reload_raises_clear_error(self): + # ponytail: one check — reload of fitted template spec raises immediately + import pickle + + model = PySRRegressor( + expression_spec=TemplateExpressionSpec( + combine="f(x)", expressions=["f"], variable_names=["x"] + ) + ) + model.equations_ = pd.DataFrame({"loss": [0.0]}) + model.feature_names_in_ = np.array(["x"]) + model.nout_ = 1 + with self.assertRaisesRegex(NotImplementedError, "not yet supported"): + pickle.loads(pickle.dumps(model)) + def _check_macro_str(self, spec, expected_str): self.assertEqual( spec._template_macro_str().strip(), dedent(expected_str).strip()