diff --git a/keras/src/wrappers/sklearn_test.py b/keras/src/wrappers/sklearn_test.py
index 250b12c51274..6b3dc3fb59f3 100644
--- a/keras/src/wrappers/sklearn_test.py
+++ b/keras/src/wrappers/sklearn_test.py
@@ -57,7 +57,7 @@ def patched_more_tags(self):
     return parametrize_with_checks(estimators)
 
 
-def dynamic_model(X, y, loss, layers=[10]):
+def dynamic_model(X, y, loss, out_activation_function="softmax", layers=[10]):
     """Creates a basic MLP classifier dynamically choosing binary/multiclass
     classification loss and ouput activations.
     """
@@ -69,7 +69,7 @@ def dynamic_model(X, y, loss, layers=[10]):
         hidden = Dense(layer_size, activation="relu")(hidden)
 
     n_outputs = y.shape[1] if len(y.shape) > 1 else 1
-    out = [Dense(n_outputs, activation="softmax")(hidden)]
+    out = [Dense(n_outputs, activation=out_activation_function)(hidden)]
     model = Model(inp, out)
     model.compile(loss=loss, optimizer="rmsprop")
 
@@ -107,6 +107,9 @@ def use_floatx(x):
         ),
         "check_supervised_y_2d": "This test assumes reproducibility in fit.",
         "check_fit_idempotent": "This test assumes reproducibility in fit.",
+        "check_classifiers_train": (
+            "decision_function can return both probabilities and logits"
+        ),
     },
     "SKLearnRegressor": {
         "check_parameters_default_constructible": (
@@ -158,3 +161,51 @@ def test_sklearn_estimator_checks(estimator, check):
             pytest.xfail("Backend not implemented")
         else:
             raise
+
+
+@pytest.mark.parametrize(
+    "estimator",
+    [
+        SKLearnClassifier(
+            model=dynamic_model,
+            model_kwargs={
+                "out_activation_function": "softmax",
+                "loss": "binary_crossentropy",
+            },
+            fit_kwargs={"epochs": 1},
+        ),
+        SKLearnClassifier(
+            model=dynamic_model,
+            model_kwargs={
+                "out_activation_function": "linear",
+                "loss": "binary_crossentropy",
+            },
+            fit_kwargs={"epochs": 1},
+        ),
+    ],
+)
+def test_sklearn_estimator_decision_function(estimator):
+    """Checks that the argmax of ``decision_function`` is the same as
+    ``predict`` for classifiers.
+    """
+    try:
+        X, y = sklearn.datasets.make_classification(
+            n_samples=10,
+            n_features=10,
+            n_informative=4,
+            n_classes=2,
+            random_state=42,
+        )
+        estimator.fit(X, y)
+        assert (
+            estimator.decision_function(X[:1]).argmax(axis=-1)
+            == estimator.predict(X[:1]).flatten()
+        )
+    except Exception as exc:
+        if keras.config.backend() in ["numpy", "openvino"] and (
+            isinstance(exc, NotImplementedError)
+            or "NotImplementedError" in str(exc)
+        ):
+            pytest.xfail("Backend not implemented")
+        else:
+            raise
diff --git a/keras/src/wrappers/sklearn_wrapper.py b/keras/src/wrappers/sklearn_wrapper.py
index 90d36c669792..02f5777bf6f5 100644
--- a/keras/src/wrappers/sklearn_wrapper.py
+++ b/keras/src/wrappers/sklearn_wrapper.py
@@ -18,6 +18,7 @@
     from sklearn.base import ClassifierMixin
     from sklearn.base import RegressorMixin
     from sklearn.base import TransformerMixin
+    from sklearn.utils._array_api import get_namespace
 except ImportError:
     sklearn = None
 
@@ -278,6 +279,15 @@ def dynamic_model(X, y, loss, layers=[10]):
         ```
     """
 
+    def decision_function(self, X):
+        """Get raw model outputs."""
+        from sklearn.utils.validation import check_is_fitted
+
+        check_is_fitted(self)
+
+        X = _validate_data(self, X, reset=False)
+        return self.model_.predict(X)
+
     def _process_target(self, y, reset=False):
         """Classifiers do OHE."""
         target_type = type_of_target(y, raise_unknown=True)
diff --git a/keras/src/wrappers/utils.py b/keras/src/wrappers/utils.py
index 8c2954b055ad..6d4bbaabdea2 100644
--- a/keras/src/wrappers/utils.py
+++ b/keras/src/wrappers/utils.py
@@ -80,7 +80,7 @@ def inverse_transform(self, y):
             If the transformer was fit to a 1D numpy array,
             and a 2D numpy array with a singleton second
             dimension is passed, it will be squeezed back to
             1D. Otherwise, it
-            will eb left untouched.
+            will be left untouched.
         """
         from sklearn.utils.validation import check_is_fitted