From 44b6fee87fb976afca759fa816ce6900969e5511 Mon Sep 17 00:00:00 2001 From: Alexandre Gramfort Date: Tue, 23 Dec 2025 14:26:46 +0100 Subject: [PATCH] make sure picard does not alter sklearn FastICA._parameter_constraints --- picard/dropin_sklearn.py | 4 +++- picard/tests/test_sklearn.py | 13 +++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/picard/dropin_sklearn.py b/picard/dropin_sklearn.py index bb3c2f9..6817909 100644 --- a/picard/dropin_sklearn.py +++ b/picard/dropin_sklearn.py @@ -106,10 +106,12 @@ def __init__(self, n_components=None, *, ortho=True, extended=None, super().__init__() # update parameters constraint dict - self._parameter_constraints["fun"] = [ + _parameter_constraints = self._parameter_constraints.copy() + _parameter_constraints["fun"] = [ StrOptions({"tanh", "exp", "cube"}), callable, ] + self._parameter_constraints = _parameter_constraints if max_iter < 1: raise ValueError("max_iter should be greater than 1, got " "(max_iter={})".format(max_iter)) diff --git a/picard/tests/test_sklearn.py b/picard/tests/test_sklearn.py index 36d5bc3..b74024f 100644 --- a/picard/tests/test_sklearn.py +++ b/picard/tests/test_sklearn.py @@ -7,6 +7,7 @@ import numpy as np +from sklearn.decomposition import FastICA from sklearn.utils._testing import assert_array_almost_equal from picard import Picard @@ -90,3 +91,15 @@ def test_inverse_transform(): def test_picard_errors(): with pytest.raises(ValueError, match='max_iter should be greater than 1'): Picard(max_iter=0) + + +def test_picard_fastica_fun_parameter_constraints(): + """Test that the 'fun' parameter constraints have been updated to include + 'tanh', 'exp', 'cube', and callable. + This test ensures that the modifications made to the Picard class to + update the parameter constraints for 'fun' do not inadvertently affect + the FastICA class from scikit-learn. + """ + picard = Picard() + fast_ica = FastICA() + assert picard._parameter_constraints != fast_ica._parameter_constraints