diff --git a/aeon/testing/estimator_checking/_yield_transformation_checks.py b/aeon/testing/estimator_checking/_yield_transformation_checks.py index f01d2b5370..178ec5e167 100644 --- a/aeon/testing/estimator_checking/_yield_transformation_checks.py +++ b/aeon/testing/estimator_checking/_yield_transformation_checks.py @@ -10,8 +10,12 @@ from aeon.testing.testing_data import FULL_TEST_DATA_DICT from aeon.testing.utils.deep_equals import deep_equals from aeon.testing.utils.estimator_checks import _run_estimator_method +from aeon.transformations.collection import CollectionInverseTransformerMixin from aeon.transformations.collection.channel_selection.base import BaseChannelSelector -from aeon.transformations.series import BaseSeriesTransformer +from aeon.transformations.series import ( + BaseSeriesTransformer, + SeriesInverseTransformerMixin, +) from aeon.utils.data_types import COLLECTIONS_DATA_TYPES, VALID_SERIES_INNER_TYPES @@ -82,8 +86,19 @@ def check_transformer_overrides_and_tags(estimator_class): else: # must be a list assert any([t in valid_unequal_types for t in X_inner_type]) + inherits_inverse = ( + issubclass(estimator_class, SeriesInverseTransformerMixin) + if issubclass(estimator_class, BaseSeriesTransformer) + else issubclass(estimator_class, CollectionInverseTransformerMixin) + ) if estimator_class.get_class_tag("capability:inverse_transform"): + assert inherits_inverse + assert "inverse_transform" not in estimator_class.__dict__ assert "_inverse_transform" in estimator_class.__dict__ + else: + assert not inherits_inverse + assert "inverse_transform" not in estimator_class.__dict__ + assert "_inverse_transform" not in estimator_class.__dict__ def check_transformer_output(estimator, datatype): diff --git a/aeon/transformations/__init__.py b/aeon/transformations/__init__.py index 23421c062b..eae067f9cb 100644 --- a/aeon/transformations/__init__.py +++ b/aeon/transformations/__init__.py @@ -2,6 +2,7 @@ __all__ = [ "BaseTransformer", + 
"InverseTransformerMixin", ] -from aeon.transformations.base import BaseTransformer +from aeon.transformations.base import BaseTransformer, InverseTransformerMixin diff --git a/aeon/transformations/base.py b/aeon/transformations/base.py index f0ae37d008..9b55affebc 100644 --- a/aeon/transformations/base.py +++ b/aeon/transformations/base.py @@ -1,9 +1,9 @@ """Base class for transformers.""" __maintainer__ = ["MatthewMiddlehurst", "TonyBagnall"] -__all__ = ["BaseTransformer"] +__all__ = ["BaseTransformer", "InverseTransformerMixin"] -from abc import abstractmethod +from abc import ABC, abstractmethod import numpy as np import pandas as pd @@ -112,3 +112,76 @@ def _check_y(self, y, n_cases=None): f"Mismatch in number of cases. Number in X = {n_cases} nos in y = " f"{n_labels}" ) + + +class InverseTransformerMixin(ABC): + """Mixin for transformers that support inverse transformation.""" + + _tags = { + "capability:inverse_transform": True, + } + + @abstractmethod + def inverse_transform(self, X, y=None, axis=1): + """Inverse transform X and return an inverse transformed version. + + Currently it is assumed that only transformers with tags + "input_data_type"="Series", "output_data_type"="Series", + can have an inverse_transform. + + State required: + Requires state to be "fitted". + + Accesses in self: + _is_fitted : must be True + fitted model attributes (ending in "_") : accessed by _inverse_transform + + Parameters + ---------- + X : Series or Collection, any supported type + Data to fit transform to, of python type as follows: + Series: 2D np.ndarray shape (n_channels, n_timepoints) + Collection: 3D np.ndarray shape (n_cases, n_channels, n_timepoints) + or list of 2D np.ndarray, case i has shape (n_channels, n_timepoints_i) + y : Series, default=None + Additional data, e.g., labels for transformation. + axis : int, default = 1 + Axis of time in the input series. + If ``axis == 0``, it is assumed each column is a time series and each row is + a time point. i.e. 
the shape of the data is ``(n_timepoints, + n_channels)``. + ``axis == 1`` indicates the time series are in rows, i.e. the shape of + the data is ``(n_channels, n_timepoints)``. ``axis is None`` indicates + that the axis of X is the same as ``self.axis``. + + Only relevant for ``aeon.transformations.series`` transformers. + + Returns + ------- + inverse transformed version of X + of the same type as X + """ + ... + + @abstractmethod + def _inverse_transform(self, X, y=None): + """Inverse transform X and return an inverse transformed version. + + private _inverse_transform containing core logic, called from inverse_transform. + + Parameters + ---------- + X : Series or Collection, any supported type + Data to fit transform to, of python type as follows: + Series: 2D np.ndarray shape (n_channels, n_timepoints) + Collection: 3D np.ndarray shape (n_cases, n_channels, n_timepoints) + or list of 2D np.ndarray, case i has shape (n_channels, n_timepoints_i) + y : Series, default=None + Additional data, e.g., labels for transformation. + + Returns + ------- + inverse transformed version of X + of the same type as X. + """ + ... 
diff --git a/aeon/transformations/collection/__init__.py b/aeon/transformations/collection/__init__.py index 19dddc6e99..306fb46f49 100644 --- a/aeon/transformations/collection/__init__.py +++ b/aeon/transformations/collection/__init__.py @@ -3,6 +3,7 @@ __all__ = [ # base class and series wrapper "BaseCollectionTransformer", + "CollectionInverseTransformerMixin", # transformers "AutocorrelationFunctionTransformer", "ARCoefficientTransformer", @@ -32,4 +33,7 @@ from aeon.transformations.collection._reduce import Tabularizer from aeon.transformations.collection._rescale import Centerer, MinMaxScaler, Normalizer from aeon.transformations.collection._slope import SlopeTransformer -from aeon.transformations.collection.base import BaseCollectionTransformer +from aeon.transformations.collection.base import ( + BaseCollectionTransformer, + CollectionInverseTransformerMixin, +) diff --git a/aeon/transformations/collection/_broadcaster.py b/aeon/transformations/collection/_broadcaster.py index f284416a47..9a4c36faa0 100644 --- a/aeon/transformations/collection/_broadcaster.py +++ b/aeon/transformations/collection/_broadcaster.py @@ -5,12 +5,17 @@ import numpy as np -from aeon.transformations.collection.base import BaseCollectionTransformer +from aeon.transformations.collection.base import ( + BaseCollectionTransformer, + CollectionInverseTransformerMixin, +) from aeon.transformations.series.base import BaseSeriesTransformer from aeon.utils.validation.collection import get_n_cases -class SeriesToCollectionBroadcaster(BaseCollectionTransformer): +class SeriesToCollectionBroadcaster( + CollectionInverseTransformerMixin, BaseCollectionTransformer +): """Broadcast a ``BaseSeriesTransformer`` over a collection of time series. Uses the ``BaseSeriesTransformer`` passed in the constructor. 
If the diff --git a/aeon/transformations/collection/base.py b/aeon/transformations/collection/base.py index 54aa1bb839..55df22b0a2 100644 --- a/aeon/transformations/collection/base.py +++ b/aeon/transformations/collection/base.py @@ -20,15 +20,13 @@ class name: BaseCollectionTransformer """ __maintainer__ = ["MatthewMiddlehurst"] -__all__ = [ - "BaseCollectionTransformer", -] +__all__ = ["BaseCollectionTransformer", "CollectionInverseTransformerMixin"] from abc import abstractmethod from typing import final from aeon.base import BaseCollectionEstimator -from aeon.transformations.base import BaseTransformer +from aeon.transformations.base import BaseTransformer, InverseTransformerMixin from aeon.utils.validation.collection import get_n_cases @@ -208,63 +206,6 @@ def fit_transform(self, X, y=None): self.is_fitted = True return Xt - @final - def inverse_transform(self, X, y=None): - """Inverse transform X and return an inverse transformed version. - - Currently it is assumed that only transformers with tags - "input_data_type"="Series", "output_data_type"="Series", - can have an inverse_transform. - - State required: - Requires state to be "fitted". - - Accesses in self: - _is_fitted : must be True - fitted model attributes (ending in "_") : accessed by _inverse_transform - - Parameters - ---------- - X : np.ndarray or list - Data to fit transform to, of valid collection type. Input data, - any number of channels, equal length series of shape ``( - n_cases, n_channels, n_timepoints)`` or list of numpy arrays (number - of channels, series length) of shape ``[n_cases]``, 2D np.array - ``(n_channels, n_timepoints_i)``, where ``n_timepoints_i`` is length of - series ``i``. Other types are allowed and converted into one of the above. - - Different estimators have different capabilities to handle different - types of input. If ``self.get_tag("capability:multivariate")`` is False, - they cannot handle multivariate series. 
If ``self.get_tag( - "capability:unequal_length")`` is False, they cannot handle unequal - length input. In both situations, a ``ValueError`` is raised if X has a - characteristic that the estimator does not have the capability to handle. - y : np.ndarray, default=None - 1D np.array of float or str, of shape ``(n_cases)`` - class labels - (ground truth) for fitting indices corresponding to instance indices in X. - If None, no labels are used in fitting. - - Returns - ------- - inverse transformed version of X - of the same type as X - """ - if not self.get_tag("capability:inverse_transform"): - raise NotImplementedError( - f"{type(self)} does not implement inverse_transform" - ) - - # check whether is fitted - self._check_is_fitted() - - # input check and conversion for X/y - X_inner = self._preprocess_collection(X, store_metadata=False) - y_inner = y - - Xt = self._inverse_transform(X=X_inner, y=y_inner) - - return Xt - def _fit(self, X, y=None): """Fit transformer to X and y. @@ -323,23 +264,58 @@ def _fit_transform(self, X, y=None): self._fit(X, y) return self._transform(X, y) - def _inverse_transform(self, X, y=None): + +class CollectionInverseTransformerMixin(InverseTransformerMixin): + """Mixin for transformers that support inverse transformation.""" + + _tags = { + "capability:inverse_transform": True, + } + + @final + def inverse_transform(self, X, y=None): """Inverse transform X and return an inverse transformed version. - private _inverse_transform containing core logic, called from inverse_transform. + Currently it is assumed that only transformers with tags + "input_data_type"="Series", "output_data_type"="Series", + can have an inverse_transform. + + State required: + Requires state to be "fitted". + + Accesses in self: + _is_fitted : must be True + fitted model attributes (ending in "_") : accessed by _inverse_transform Parameters ---------- - X : Input data - Data to fit transform to, of valid collection type. 
- y : Target variable, default=None - Additional data, e.g., labels for transformation + X : Series or Collection, any supported type + Data to fit transform to, of python type as follows: + Series: 2D np.ndarray shape (n_channels, n_timepoints) + Collection: 3D np.ndarray shape (n_cases, n_channels, n_timepoints) + or list of 2D np.ndarray, case i has shape (n_channels, n_timepoints_i) + y : Series, default=None + Additional data, e.g., labels for transformation. + axis : int, default = 1 + Axis of time in the input series. + If ``axis == 0``, it is assumed each column is a time series and each row is + a time point. i.e. the shape of the data is ``(n_timepoints, + n_channels)``. + ``axis == 1`` indicates the time series are in rows, i.e. the shape of + the data is ``(n_channels, n_timepoints)``. ``axis is None`` indicates + that the axis of X is the same as ``self.axis``. + + Only relevant for ``aeon.transformations.series`` transformers. Returns ------- inverse transformed version of X - of the same type as X. 
+ of the same type as X """ - raise NotImplementedError( - f"{self.__class__.__name__} does not support inverse_transform" - ) + # check whether is fitted + self._check_is_fitted() + + # input check and conversion for X/y + X_inner = self._preprocess_collection(X, store_metadata=False) + Xt = self._inverse_transform(X=X_inner, y=y) + return Xt diff --git a/aeon/transformations/collection/compose/_identity.py b/aeon/transformations/collection/compose/_identity.py index a359255242..12f5c4a3aa 100644 --- a/aeon/transformations/collection/compose/_identity.py +++ b/aeon/transformations/collection/compose/_identity.py @@ -1,16 +1,16 @@ """Identity transformer.""" from aeon.transformations.collection import BaseCollectionTransformer +from aeon.transformations.collection.base import CollectionInverseTransformerMixin from aeon.utils.data_types import COLLECTIONS_DATA_TYPES -class CollectionId(BaseCollectionTransformer): +class CollectionId(CollectionInverseTransformerMixin, BaseCollectionTransformer): """Identity transformer, returns data unchanged in transform/inverse_transform.""" _tags = { "X_inner_type": COLLECTIONS_DATA_TYPES, "fit_is_empty": True, - "capability:inverse_transform": True, "capability:multivariate": True, "capability:unequal_length": True, "capability:missing_values": True, diff --git a/aeon/transformations/collection/dictionary_based/_borf.py b/aeon/transformations/collection/dictionary_based/_borf.py index 738ef1cda1..9b09924edd 100644 --- a/aeon/transformations/collection/dictionary_based/_borf.py +++ b/aeon/transformations/collection/dictionary_based/_borf.py @@ -100,7 +100,6 @@ class BORF(BaseCollectionTransformer): _tags = { "X_inner_type": "numpy3D", - "capability:inverse_transform": False, "capability:missing_values": True, "capability:multivariate": True, "capability:multithreading": True, diff --git a/aeon/transformations/collection/tests/test_base.py b/aeon/transformations/collection/tests/test_base.py index c4aac11407..ada1cf2325 100644 --- 
a/aeon/transformations/collection/tests/test_base.py +++ b/aeon/transformations/collection/tests/test_base.py @@ -11,7 +11,10 @@ make_example_3d_numpy_list, make_example_pandas_series, ) -from aeon.transformations.collection import BaseCollectionTransformer +from aeon.transformations.collection import ( + BaseCollectionTransformer, + CollectionInverseTransformerMixin, +) @pytest.mark.parametrize( @@ -44,18 +47,7 @@ def test_collection_transformer_invalid_input(dtype): t.fit_transform(y) -def test_raise_inverse_transform(): - """Test that inverse transform raises NotImplementedError.""" - d = _Dummy() - x, _ = make_example_3d_numpy() - d.fit(x) - with pytest.raises( - NotImplementedError, match="does not implement " "inverse_transform" - ): - d.inverse_transform(x) - - -class _Dummy(BaseCollectionTransformer): +class _Dummy(CollectionInverseTransformerMixin, BaseCollectionTransformer): """Dummy transformer for testing. Converts a numpy array to a list of numpy arrays. diff --git a/aeon/transformations/series/__init__.py b/aeon/transformations/series/__init__.py index 96d2c78958..c491479f36 100644 --- a/aeon/transformations/series/__init__.py +++ b/aeon/transformations/series/__init__.py @@ -1,8 +1,11 @@ """Series transformations.""" __all__ = [ - "AutoCorrelationSeriesTransformer", + # base class "BaseSeriesTransformer", + "SeriesInverseTransformerMixin", + # transformers + "AutoCorrelationSeriesTransformer", "ClaSPTransformer", "Dobin", "MatrixProfileTransformer", @@ -38,4 +41,7 @@ from aeon.transformations.series._pla import PLASeriesTransformer from aeon.transformations.series._scaled_logit import ScaledLogitSeriesTransformer from aeon.transformations.series._warping import WarpingSeriesTransformer -from aeon.transformations.series.base import BaseSeriesTransformer +from aeon.transformations.series.base import ( + BaseSeriesTransformer, + SeriesInverseTransformerMixin, +) diff --git a/aeon/transformations/series/_boxcox.py b/aeon/transformations/series/_boxcox.py 
index 4df8f95720..9f9a926c6b 100644 --- a/aeon/transformations/series/_boxcox.py +++ b/aeon/transformations/series/_boxcox.py @@ -8,7 +8,10 @@ from scipy.special import boxcox, inv_boxcox from scipy.stats import boxcox_llf, distributions, variation -from aeon.transformations.series.base import BaseSeriesTransformer +from aeon.transformations.series.base import ( + BaseSeriesTransformer, + SeriesInverseTransformerMixin, +) # copy-pasted from scipy 1.7.3 since it moved in 1.8.0 and broke this estimator @@ -38,7 +41,7 @@ def _calc_uniform_order_statistic_medians(n): return v -class BoxCoxTransformer(BaseSeriesTransformer): +class BoxCoxTransformer(SeriesInverseTransformerMixin, BaseSeriesTransformer): r"""Box-Cox power transform. Box-Cox transformation is a power transformation that is used to @@ -106,7 +109,6 @@ class BoxCoxTransformer(BaseSeriesTransformer): "X_inner_type": "np.ndarray", "fit_is_empty": False, "capability:multivariate": False, - "capability:inverse_transform": True, } def __init__(self, bounds=None, method="mle", sp=None): diff --git a/aeon/transformations/series/_log.py b/aeon/transformations/series/_log.py index 45d72664f5..d2e0cabead 100644 --- a/aeon/transformations/series/_log.py +++ b/aeon/transformations/series/_log.py @@ -5,10 +5,13 @@ import numpy as np -from aeon.transformations.series.base import BaseSeriesTransformer +from aeon.transformations.series.base import ( + BaseSeriesTransformer, + SeriesInverseTransformerMixin, +) -class LogTransformer(BaseSeriesTransformer): +class LogTransformer(SeriesInverseTransformerMixin, BaseSeriesTransformer): """Natural logarithm transformation. 
The Natural logarithm transformation can be used to make the data more normally @@ -41,7 +44,6 @@ class LogTransformer(BaseSeriesTransformer): "X_inner_type": "np.ndarray", "fit_is_empty": True, "capability:multivariate": True, - "capability:inverse_transform": True, } def __init__(self, offset=0, scale=1): diff --git a/aeon/transformations/series/_scaled_logit.py b/aeon/transformations/series/_scaled_logit.py index e400385f35..69e8f2127d 100644 --- a/aeon/transformations/series/_scaled_logit.py +++ b/aeon/transformations/series/_scaled_logit.py @@ -8,10 +8,15 @@ import numpy as np -from aeon.transformations.series.base import BaseSeriesTransformer +from aeon.transformations.series.base import ( + BaseSeriesTransformer, + SeriesInverseTransformerMixin, +) -class ScaledLogitSeriesTransformer(BaseSeriesTransformer): +class ScaledLogitSeriesTransformer( + SeriesInverseTransformerMixin, BaseSeriesTransformer +): r"""Scaled logit transform or Log transform. If both lower_bound and upper_bound are not None, a scaled logit transform is @@ -59,7 +64,6 @@ class ScaledLogitSeriesTransformer(BaseSeriesTransformer): "X_inner_type": "np.ndarray", "fit_is_empty": True, "capability:multivariate": True, - "capability:inverse_transform": True, } def __init__(self, lower_bound=None, upper_bound=None): diff --git a/aeon/transformations/series/base.py b/aeon/transformations/series/base.py index 03b0946577..022bd3eca9 100644 --- a/aeon/transformations/series/base.py +++ b/aeon/transformations/series/base.py @@ -14,7 +14,7 @@ class name: BaseSeriesTransformer from deprecated.sphinx import deprecated from aeon.base import BaseSeriesEstimator -from aeon.transformations.base import BaseTransformer +from aeon.transformations.base import BaseTransformer, InverseTransformerMixin class BaseSeriesTransformer(BaseSeriesEstimator, BaseTransformer): @@ -168,44 +168,6 @@ def fit_transform(self, X, y=None, axis=1): self.is_fitted = True return self._postprocess_series(Xt, axis=axis) - @final - def 
inverse_transform(self, X, y=None, axis=1): - """Inverse transform X and return an inverse transformed version. - - State required: - Requires state to be "fitted". - - Parameters - ---------- - X : Input data - Data to fit transform to, of valid collection type. - y : Target variable, default=None - Additional data, e.g., labels for transformation - axis : int, default = 1 - Axis of time in the input series. - If ``axis == 0``, it is assumed each column is a time series and each row is - a time point. i.e. the shape of the data is ``(n_timepoints, - n_channels)``. - ``axis == 1`` indicates the time series are in rows, i.e. the shape of - the data is ``(n_channels, n_timepoints)`.``axis is None`` indicates - that the axis of X is the same as ``self.axis``. - - Returns - ------- - inverse transformed version of X - of the same type as X - """ - if not self.get_tag("capability:inverse_transform"): - raise NotImplementedError( - f"{type(self)} does not implement inverse_transform" - ) - - # check whether is fitted - self._check_is_fitted() - X = self._preprocess_series(X, axis=axis, store_metadata=False) - Xt = self._inverse_transform(X=X, y=y) - return self._postprocess_series(Xt, axis=axis) - def _fit(self, X, y=None): """Fit transformer to X and y. @@ -267,27 +229,6 @@ def _fit_transform(self, X, y=None): self._fit(X, y) return self._transform(X, y) - def _inverse_transform(self, X, y=None): - """Inverse transform X and return an inverse transformed version. - - private _inverse_transform containing core logic, called from inverse_transform. - - Parameters - ---------- - X : Input data - Time series to fit transform to, of valid collection type. - y : Target variable, default=None - Additional data, e.g., labels for transformation - - Returns - ------- - inverse transformed version of X - of the same type as X. 
- """ - raise NotImplementedError( - f"{self.__class__.__name__} does not support inverse_transform" - ) - def _postprocess_series(self, Xt, axis): """Postprocess data Xt to revert to original shape. @@ -357,3 +298,58 @@ def update(self, X, y=None, update_params=True, axis=1): def _update(self, X, y=None, update_params=True): # standard behaviour: no update takes place, new data is ignored return self + + +class SeriesInverseTransformerMixin(InverseTransformerMixin): + """Mixin for transformers that support inverse transformation.""" + + _tags = { + "capability:inverse_transform": True, + } + + @final + def inverse_transform(self, X, y=None, axis=1): + """Inverse transform X and return an inverse transformed version. + + Currently it is assumed that only transformers with tags + "input_data_type"="Series", "output_data_type"="Series", + can have an inverse_transform. + + State required: + Requires state to be "fitted". + + Accesses in self: + _is_fitted : must be True + fitted model attributes (ending in "_") : accessed by _inverse_transform + + Parameters + ---------- + X : Series or Collection, any supported type + Data to fit transform to, of python type as follows: + Series: 2D np.ndarray shape (n_channels, n_timepoints) + Collection: 3D np.ndarray shape (n_cases, n_channels, n_timepoints) + or list of 2D np.ndarray, case i has shape (n_channels, n_timepoints_i) + y : Series, default=None + Additional data, e.g., labels for transformation. + axis : int, default = 1 + Axis of time in the input series. + If ``axis == 0``, it is assumed each column is a time series and each row is + a time point. i.e. the shape of the data is ``(n_timepoints, + n_channels)``. + ``axis == 1`` indicates the time series are in rows, i.e. the shape of + the data is ``(n_channels, n_timepoints)``. ``axis is None`` indicates + that the axis of X is the same as ``self.axis``. + + Only relevant for ``aeon.transformations.series`` transformers. 
+ + Returns + ------- + inverse transformed version of X + of the same type as X + """ + # check whether is fitted + self._check_is_fitted() + + X = self._preprocess_series(X, axis=axis, store_metadata=False) + Xt = self._inverse_transform(X=X, y=y) + return self._postprocess_series(Xt, axis=axis) diff --git a/aeon/transformations/series/compose/_identity.py b/aeon/transformations/series/compose/_identity.py index d9135fae36..b186b6fb70 100644 --- a/aeon/transformations/series/compose/_identity.py +++ b/aeon/transformations/series/compose/_identity.py @@ -1,16 +1,16 @@ """Identity transformer.""" from aeon.transformations.series import BaseSeriesTransformer +from aeon.transformations.series.base import SeriesInverseTransformerMixin from aeon.utils.data_types import VALID_SERIES_INNER_TYPES -class SeriesId(BaseSeriesTransformer): +class SeriesId(SeriesInverseTransformerMixin, BaseSeriesTransformer): """Identity transformer, returns data unchanged in transform/inverse_transform.""" _tags = { "X_inner_type": VALID_SERIES_INNER_TYPES, "fit_is_empty": True, - "capability:inverse_transform": True, "capability:multivariate": True, "capability:missing_values": True, }