Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion picard/dropin_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,12 @@ def __init__(self, n_components=None, *, ortho=True, extended=None,
super().__init__()

# update parameters constraint dict
self._parameter_constraints["fun"] = [
_parameter_constraints = self._parameter_constraints.copy()
_parameter_constraints["fun"] = [
StrOptions({"tanh", "exp", "cube"}),
callable,
]
self._parameter_constraints = _parameter_constraints
if max_iter < 1:
raise ValueError("max_iter should be greater than 1, got "
"(max_iter={})".format(max_iter))
Expand Down
13 changes: 13 additions & 0 deletions picard/tests/test_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import numpy as np

from sklearn.decomposition import FastICA
from sklearn.utils._testing import assert_array_almost_equal

from picard import Picard
Expand Down Expand Up @@ -90,3 +91,15 @@ def test_inverse_transform():
def test_picard_errors():
with pytest.raises(ValueError, match='max_iter should be greater than 1'):
Picard(max_iter=0)


def test_picard_fastica_fun_parameter_constraints():
"""Test that the 'fun' parameter constraints have been updated to include
'tanh', 'exp', 'cube', and callable.
This test ensures that the modifications made to the Picard class to
update the parameter constraints for 'fun' do not inadvertently affect
the FastICA class from scikit-learn.
"""
picard = Picard()
fast_ica = FastICA()
assert picard._parameter_constraints != fast_ica._parameter_constraints
Loading