diff --git a/doc/metrics.rst b/doc/metrics.rst
index f7e249c02..10e5e3fc6 100644
--- a/doc/metrics.rst
+++ b/doc/metrics.rst
@@ -60,6 +60,14 @@ The :func:`macro_averaged_mean_absolute_error` :cite:`esuli2009ordinal` is used
 for imbalanced ordinal classification. The mean absolute error is computed for
 each class and averaged over classes, giving an equal weight to each class.
 
+.. _macro_averaged_mean_squared_error:
+
+Macro-Averaged Mean Squared Error (MA-MSE)
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Like MA-MAE, but based on the squared error, so predictions further from the
+ground truth are penalized more harshly, just as MSE does relative to MAE.
+
 .. _classification_report:
 
 Summary of important metrics
diff --git a/doc/references/metrics.rst b/doc/references/metrics.rst
index 200e88d97..618d7135e 100644
--- a/doc/references/metrics.rst
+++ b/doc/references/metrics.rst
@@ -23,6 +23,7 @@ See the :ref:`metrics` section of the user guide for further details.
    specificity_score
    geometric_mean_score
    macro_averaged_mean_absolute_error
+   macro_averaged_mean_squared_error
    make_index_balanced_accuracy
 
 Pairwise metrics
diff --git a/imblearn/metrics/__init__.py b/imblearn/metrics/__init__.py
index a95bab00f..f3ce21eab 100644
--- a/imblearn/metrics/__init__.py
+++ b/imblearn/metrics/__init__.py
@@ -7,6 +7,7 @@
     classification_report_imbalanced,
     geometric_mean_score,
     macro_averaged_mean_absolute_error,
+    macro_averaged_mean_squared_error,
     make_index_balanced_accuracy,
     sensitivity_score,
     sensitivity_specificity_support,
@@ -21,4 +22,5 @@
     "make_index_balanced_accuracy",
     "classification_report_imbalanced",
     "macro_averaged_mean_absolute_error",
+    "macro_averaged_mean_squared_error",
 ]
diff --git a/imblearn/metrics/_classification.py b/imblearn/metrics/_classification.py
index 17797a9a7..d4d22956d 100644
--- a/imblearn/metrics/_classification.py
+++ b/imblearn/metrics/_classification.py
@@ -21,7 +21,11 @@
 
 import numpy as np
 import scipy as sp
-from sklearn.metrics import mean_absolute_error, precision_recall_fscore_support
+from sklearn.metrics import (
+    mean_absolute_error,
+    mean_squared_error,
+    precision_recall_fscore_support,
+)
 from sklearn.metrics._classification import _check_targets, _prf_divide
 from sklearn.preprocessing import LabelEncoder
 from sklearn.utils._param_validation import Interval, StrOptions
@@ -1139,3 +1143,76 @@ def macro_averaged_mean_absolute_error(y_true, y_pred, *, sample_weight=None):
         )
 
     return np.sum(mae) / len(mae)
+
+
+@validate_params(
+    {
+        "y_true": ["array-like"],
+        "y_pred": ["array-like"],
+        "sample_weight": ["array-like", None],
+    },
+    prefer_skip_nested_validation=True,
+)
+def macro_averaged_mean_squared_error(y_true, y_pred, *, sample_weight=None):
+    """Compute Macro-Averaged MSE for imbalanced ordinal classification.
+
+    This function computes the MSE for each class and averages them,
+    giving an equal weight to each class.
+
+    Read more in the :ref:`User Guide <macro_averaged_mean_squared_error>`.
+
+    .. versionadded:: 0.14
+
+    Parameters
+    ----------
+    y_true : array-like of shape (n_samples,) or (n_samples, n_outputs)
+        Ground truth (correct) target values.
+
+    y_pred : array-like of shape (n_samples,) or (n_samples, n_outputs)
+        Estimated targets as returned by a classifier.
+
+    sample_weight : array-like of shape (n_samples,), default=None
+        Sample weights.
+
+    Returns
+    -------
+    loss : float or ndarray of floats
+        Macro-Averaged MSE output is non-negative floating point.
+        The best value is 0.0.
+
+    Examples
+    --------
+    >>> from sklearn.metrics import mean_squared_error
+    >>> from imblearn.metrics import macro_averaged_mean_squared_error
+    >>> y_true_balanced = [1, 1, 3, 3]
+    >>> y_true_imbalanced = [1, 3, 3, 3]
+    >>> y_pred = [1, 3, 1, 3]
+    >>> mean_squared_error(y_true_balanced, y_pred)
+    2.0
+    >>> mean_squared_error(y_true_imbalanced, y_pred)
+    1.0
+    >>> macro_averaged_mean_squared_error(y_true_balanced, y_pred)
+    2.0
+    >>> macro_averaged_mean_squared_error(y_true_imbalanced, y_pred)
+    0.66...
+    """
+    _, y_true, y_pred = _check_targets(y_true, y_pred)
+    if sample_weight is not None:
+        sample_weight = column_or_1d(sample_weight)
+    else:
+        sample_weight = np.ones(y_true.shape)
+    check_consistent_length(y_true, y_pred, sample_weight)
+    labels = unique_labels(y_true, y_pred)
+    mse = []
+    for possible_class in labels:
+        indices = np.flatnonzero(y_true == possible_class)
+
+        mse.append(
+            mean_squared_error(
+                y_true[indices],
+                y_pred[indices],
+                sample_weight=sample_weight[indices],
+            )
+        )
+
+    return np.sum(mse) / len(mse)
diff --git a/imblearn/metrics/tests/test_classification.py b/imblearn/metrics/tests/test_classification.py
index 26d6be0ad..d39d33faa 100644
--- a/imblearn/metrics/tests/test_classification.py
+++ b/imblearn/metrics/tests/test_classification.py
@@ -30,6 +30,7 @@
     classification_report_imbalanced,
     geometric_mean_score,
     macro_averaged_mean_absolute_error,
+    macro_averaged_mean_squared_error,
     make_index_balanced_accuracy,
     sensitivity_score,
     sensitivity_specificity_support,
@@ -550,3 +551,33 @@ def test_macro_averaged_mean_absolute_error_sample_weight():
     )
 
     assert ma_mae_unit_weights == pytest.approx(ma_mae_no_weights)
+
+
+@pytest.mark.parametrize(
+    "y_true, y_pred, expected_ma_mse",
+    [
+        ([1, 1, 1, 2, 2, 2], [1, 2, 1, 2, 1, 2], 0.333),
+        ([1, 1, 1, 1, 1, 2], [1, 2, 1, 2, 1, 2], 0.2),
+        ([1, 1, 1, 2, 2, 2, 3, 3, 3], [1, 3, 1, 2, 1, 1, 2, 3, 3], 0.777),
+        ([1, 1, 1, 1, 1, 1, 2, 3, 3], [1, 3, 1, 2, 1, 1, 2, 3, 3], 0.277),
+    ],
+)
+def test_macro_averaged_mean_squared_error(y_true, y_pred, expected_ma_mse):
+    ma_mse = macro_averaged_mean_squared_error(y_true, y_pred)
+    assert ma_mse == pytest.approx(expected_ma_mse, rel=R_TOL)
+
+
+def test_macro_averaged_mean_squared_error_sample_weight():
+    y_true = [1, 1, 1, 2, 2, 2]
+    y_pred = [1, 2, 1, 2, 1, 2]
+
+    ma_mse_no_weights = macro_averaged_mean_squared_error(y_true, y_pred)
+
+    sample_weight = [1, 1, 1, 1, 1, 1]
+    ma_mse_unit_weights = macro_averaged_mean_squared_error(
+        y_true,
+        y_pred,
+        sample_weight=sample_weight,
+    )
+
+    assert ma_mse_unit_weights == pytest.approx(ma_mse_no_weights)
diff --git a/imblearn/tests/test_public_functions.py b/imblearn/tests/test_public_functions.py
index 067569ee6..a1dcdcbd1 100644
--- a/imblearn/tests/test_public_functions.py
+++ b/imblearn/tests/test_public_functions.py
@@ -17,6 +17,7 @@
     "imblearn.metrics.classification_report_imbalanced",
     "imblearn.metrics.geometric_mean_score",
     "imblearn.metrics.macro_averaged_mean_absolute_error",
+    "imblearn.metrics.macro_averaged_mean_squared_error",
     "imblearn.metrics.make_index_balanced_accuracy",
     "imblearn.metrics.sensitivity_specificity_support",
     "imblearn.metrics.sensitivity_score",
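
The docstring numbers above can be sanity-checked without applying the patch: the short sketch below (not part of the diff) recomputes the macro average by hand using only numpy and scikit-learn, mirroring the imbalanced example from the new docstring.

import numpy as np
from sklearn.metrics import mean_squared_error

y_true = np.array([1, 3, 3, 3])  # class 1 occurs once, class 3 three times
y_pred = np.array([1, 3, 1, 3])

# Plain MSE is dominated by the majority class: (0 + 0 + 4 + 0) / 4 = 1.0
print(mean_squared_error(y_true, y_pred))

# Macro-averaged MSE: one MSE per class, then an unweighted mean over classes.
# Class 1 -> 0.0, class 3 -> 4/3, mean -> 0.666..., matching the 0.66... doctest.
per_class_mse = [
    mean_squared_error(y_true[y_true == label], y_pred[y_true == label])
    for label in np.unique(y_true)
]
print(np.mean(per_class_mse))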