Skip to content
17 changes: 16 additions & 1 deletion aeon/testing/estimator_checking/_yield_transformation_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,12 @@
from aeon.testing.testing_data import FULL_TEST_DATA_DICT
from aeon.testing.utils.deep_equals import deep_equals
from aeon.testing.utils.estimator_checks import _run_estimator_method
from aeon.transformations.collection import CollectionInverseTransformerMixin
from aeon.transformations.collection.channel_selection.base import BaseChannelSelector
from aeon.transformations.series import BaseSeriesTransformer
from aeon.transformations.series import (
BaseSeriesTransformer,
SeriesInverseTransformerMixin,
)
from aeon.utils.data_types import COLLECTIONS_DATA_TYPES, VALID_SERIES_INNER_TYPES


Expand Down Expand Up @@ -82,8 +86,19 @@ def check_transformer_overrides_and_tags(estimator_class):
else: # must be a list
assert any([t in valid_unequal_types for t in X_inner_type])

inherits_inverse = (
issubclass(estimator_class, SeriesInverseTransformerMixin)
if issubclass(estimator_class, BaseSeriesTransformer)
else issubclass(estimator_class, CollectionInverseTransformerMixin)
)
if estimator_class.get_class_tag("capability:inverse_transform"):
assert inherits_inverse
assert "inverse_transform" not in estimator_class.__dict__
assert "_inverse_transform" in estimator_class.__dict__
else:
assert not inherits_inverse
assert "inverse_transform" not in estimator_class.__dict__
assert "_inverse_transform" not in estimator_class.__dict__


def check_transformer_output(estimator, datatype):
Expand Down
3 changes: 2 additions & 1 deletion aeon/transformations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

__all__ = [
"BaseTransformer",
"InverseTransformerMixin",
]

from aeon.transformations.base import BaseTransformer
from aeon.transformations.base import BaseTransformer, InverseTransformerMixin
77 changes: 75 additions & 2 deletions aeon/transformations/base.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""Base class for transformers."""

__maintainer__ = ["MatthewMiddlehurst", "TonyBagnall"]
__all__ = ["BaseTransformer"]
__all__ = ["BaseTransformer", "InverseTransformerMixin"]

from abc import abstractmethod
from abc import ABC, abstractmethod

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -112,3 +112,76 @@ def _check_y(self, y, n_cases=None):
f"Mismatch in number of cases. Number in X = {n_cases} nos in y = "
f"{n_labels}"
)


class InverseTransformerMixin(ABC):
"""Mixin for transformers that support inverse transformation."""

_tags = {
"capability:inverse_transform": True,
}

@abstractmethod
def inverse_transform(self, X, y=None, axis=1):
"""Inverse transform X and return an inverse transformed version.

Currently it is assumed that only transformers with tags
"input_data_type"="Series", "output_data_type"="Series",
can have an inverse_transform.

State required:
Requires state to be "fitted".

Accesses in self:
_is_fitted : must be True
fitted model attributes (ending in "_") : accessed by _inverse_transform

Parameters
----------
X : Series or Collection, any supported type
Data to fit transform to, of python type as follows:
Series: 2D np.ndarray shape (n_channels, n_timepoints)
Collection: 3D np.ndarray shape (n_cases, n_channels, n_timepoints)
or list of 2D np.ndarray, case i has shape (n_channels, n_timepoints_i)
y : Series, default=None
Additional data, e.g., labels for transformation.
axis : int, default = 1
Axis of time in the input series.
If ``axis == 0``, it is assumed each column is a time series and each row is
a time point. i.e. the shape of the data is ``(n_timepoints,
n_channels)``.
``axis == 1`` indicates the time series are in rows, i.e. the shape of
the data is ``(n_channels, n_timepoints)`.``axis is None`` indicates
that the axis of X is the same as ``self.axis``.

Only relevant for ``aeon.transformations.series`` transformers.

Returns
-------
inverse transformed version of X
of the same type as X
"""
...

@abstractmethod
def _inverse_transform(self, X, y=None):
"""Inverse transform X and return an inverse transformed version.

private _inverse_transform containing core logic, called from inverse_transform.

Parameters
----------
X : Series or Collection, any supported type
Data to fit transform to, of python type as follows:
Series: 2D np.ndarray shape (n_channels, n_timepoints)
Collection: 3D np.ndarray shape (n_cases, n_channels, n_timepoints)
or list of 2D np.ndarray, case i has shape (n_channels, n_timepoints_i)
y : Series, default=None
Additional data, e.g., labels for transformation.

Returns
-------
inverse transformed version of X
of the same type as X.
"""
...
6 changes: 5 additions & 1 deletion aeon/transformations/collection/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
__all__ = [
# base class and series wrapper
"BaseCollectionTransformer",
"CollectionInverseTransformerMixin",
# transformers
"AutocorrelationFunctionTransformer",
"ARCoefficientTransformer",
Expand Down Expand Up @@ -32,4 +33,7 @@
from aeon.transformations.collection._reduce import Tabularizer
from aeon.transformations.collection._rescale import Centerer, MinMaxScaler, Normalizer
from aeon.transformations.collection._slope import SlopeTransformer
from aeon.transformations.collection.base import BaseCollectionTransformer
from aeon.transformations.collection.base import (
BaseCollectionTransformer,
CollectionInverseTransformerMixin,
)
9 changes: 7 additions & 2 deletions aeon/transformations/collection/_broadcaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,17 @@

import numpy as np

from aeon.transformations.collection.base import BaseCollectionTransformer
from aeon.transformations.collection.base import (
BaseCollectionTransformer,
CollectionInverseTransformerMixin,
)
from aeon.transformations.series.base import BaseSeriesTransformer
from aeon.utils.validation.collection import get_n_cases


class SeriesToCollectionBroadcaster(BaseCollectionTransformer):
class SeriesToCollectionBroadcaster(
CollectionInverseTransformerMixin, BaseCollectionTransformer
):
"""Broadcast a ``BaseSeriesTransformer`` over a collection of time series.

Uses the ``BaseSeriesTransformer`` passed in the constructor. If the
Expand Down
118 changes: 47 additions & 71 deletions aeon/transformations/collection/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,13 @@ class name: BaseCollectionTransformer
"""

__maintainer__ = ["MatthewMiddlehurst"]
__all__ = [
"BaseCollectionTransformer",
]
__all__ = ["BaseCollectionTransformer", "CollectionInverseTransformerMixin"]

from abc import abstractmethod
from typing import final

from aeon.base import BaseCollectionEstimator
from aeon.transformations.base import BaseTransformer
from aeon.transformations.base import BaseTransformer, InverseTransformerMixin
from aeon.utils.validation.collection import get_n_cases


Expand Down Expand Up @@ -208,63 +206,6 @@ def fit_transform(self, X, y=None):
self.is_fitted = True
return Xt

@final
def inverse_transform(self, X, y=None):
"""Inverse transform X and return an inverse transformed version.

Currently it is assumed that only transformers with tags
"input_data_type"="Series", "output_data_type"="Series",
can have an inverse_transform.

State required:
Requires state to be "fitted".

Accesses in self:
_is_fitted : must be True
fitted model attributes (ending in "_") : accessed by _inverse_transform

Parameters
----------
X : np.ndarray or list
Data to fit transform to, of valid collection type. Input data,
any number of channels, equal length series of shape ``(
n_cases, n_channels, n_timepoints)`` or list of numpy arrays (number
of channels, series length) of shape ``[n_cases]``, 2D np.array
``(n_channels, n_timepoints_i)``, where ``n_timepoints_i`` is length of
series ``i``. Other types are allowed and converted into one of the above.

Different estimators have different capabilities to handle different
types of input. If ``self.get_tag("capability:multivariate")`` is False,
they cannot handle multivariate series. If ``self.get_tag(
"capability:unequal_length")`` is False, they cannot handle unequal
length input. In both situations, a ``ValueError`` is raised if X has a
characteristic that the estimator does not have the capability to handle.
y : np.ndarray, default=None
1D np.array of float or str, of shape ``(n_cases)`` - class labels
(ground truth) for fitting indices corresponding to instance indices in X.
If None, no labels are used in fitting.

Returns
-------
inverse transformed version of X
of the same type as X
"""
if not self.get_tag("capability:inverse_transform"):
raise NotImplementedError(
f"{type(self)} does not implement inverse_transform"
)

# check whether is fitted
self._check_is_fitted()

# input check and conversion for X/y
X_inner = self._preprocess_collection(X, store_metadata=False)
y_inner = y

Xt = self._inverse_transform(X=X_inner, y=y_inner)

return Xt

def _fit(self, X, y=None):
"""Fit transformer to X and y.

Expand Down Expand Up @@ -323,23 +264,58 @@ def _fit_transform(self, X, y=None):
self._fit(X, y)
return self._transform(X, y)

def _inverse_transform(self, X, y=None):

class CollectionInverseTransformerMixin(InverseTransformerMixin):
"""Mixin for transformers that support inverse transformation."""

_tags = {
"capability:inverse_transform": True,
}

@final
def inverse_transform(self, X, y=None):
"""Inverse transform X and return an inverse transformed version.

private _inverse_transform containing core logic, called from inverse_transform.
Currently it is assumed that only transformers with tags
"input_data_type"="Series", "output_data_type"="Series",
can have an inverse_transform.

State required:
Requires state to be "fitted".

Accesses in self:
_is_fitted : must be True
fitted model attributes (ending in "_") : accessed by _inverse_transform

Parameters
----------
X : Input data
Data to fit transform to, of valid collection type.
y : Target variable, default=None
Additional data, e.g., labels for transformation
X : Series or Collection, any supported type
Data to fit transform to, of python type as follows:
Series: 2D np.ndarray shape (n_channels, n_timepoints)
Collection: 3D np.ndarray shape (n_cases, n_channels, n_timepoints)
or list of 2D np.ndarray, case i has shape (n_channels, n_timepoints_i)
y : Series, default=None
Additional data, e.g., labels for transformation.
axis : int, default = 1
Axis of time in the input series.
If ``axis == 0``, it is assumed each column is a time series and each row is
a time point. i.e. the shape of the data is ``(n_timepoints,
n_channels)``.
``axis == 1`` indicates the time series are in rows, i.e. the shape of
the data is ``(n_channels, n_timepoints)`.``axis is None`` indicates
that the axis of X is the same as ``self.axis``.

Only relevant for ``aeon.transformations.series`` transformers.

Returns
-------
inverse transformed version of X
of the same type as X.
of the same type as X
"""
raise NotImplementedError(
f"{self.__class__.__name__} does not support inverse_transform"
)
# check whether is fitted
self._check_is_fitted()

# input check and conversion for X/y
X_inner = self._preprocess_collection(X, store_metadata=False)
Xt = self._inverse_transform(X=X_inner, y=y)
return Xt
4 changes: 2 additions & 2 deletions aeon/transformations/collection/compose/_identity.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
"""Identity transformer."""

from aeon.transformations.collection import BaseCollectionTransformer
from aeon.transformations.collection.base import CollectionInverseTransformerMixin
from aeon.utils.data_types import COLLECTIONS_DATA_TYPES


class CollectionId(BaseCollectionTransformer):
class CollectionId(CollectionInverseTransformerMixin, BaseCollectionTransformer):
"""Identity transformer, returns data unchanged in transform/inverse_transform."""

_tags = {
"X_inner_type": COLLECTIONS_DATA_TYPES,
"fit_is_empty": True,
"capability:inverse_transform": True,
"capability:multivariate": True,
"capability:unequal_length": True,
"capability:missing_values": True,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@ class BORF(BaseCollectionTransformer):

_tags = {
"X_inner_type": "numpy3D",
"capability:inverse_transform": False,
"capability:missing_values": True,
"capability:multivariate": True,
"capability:multithreading": True,
Expand Down
18 changes: 5 additions & 13 deletions aeon/transformations/collection/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@
make_example_3d_numpy_list,
make_example_pandas_series,
)
from aeon.transformations.collection import BaseCollectionTransformer
from aeon.transformations.collection import (
BaseCollectionTransformer,
CollectionInverseTransformerMixin,
)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -44,18 +47,7 @@ def test_collection_transformer_invalid_input(dtype):
t.fit_transform(y)


def test_raise_inverse_transform():
"""Test that inverse transform raises NotImplementedError."""
d = _Dummy()
x, _ = make_example_3d_numpy()
d.fit(x)
with pytest.raises(
NotImplementedError, match="does not implement " "inverse_transform"
):
d.inverse_transform(x)


class _Dummy(BaseCollectionTransformer):
class _Dummy(CollectionInverseTransformerMixin, BaseCollectionTransformer):
"""Dummy transformer for testing.

Converts a numpy array to a list of numpy arrays.
Expand Down
Loading