Skip to content

Commit 856d3f4

Browse files
authored
add mixins (#3009)
1 parent b5b645f commit 856d3f4

File tree

4 files changed

+6
-20
lines changed

4 files changed

+6
-20
lines changed

aeon/forecasting/machine_learning/_setar.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import numpy as np
77

8-
from aeon.forecasting.base import BaseForecaster
8+
from aeon.forecasting.base import BaseForecaster, IterativeForecastingMixin
99

1010

1111
def _lagmat_1d(y: np.ndarray, maxlag: int) -> np.ndarray:
@@ -38,7 +38,7 @@ def _ols_fit(X: np.ndarray, y: np.ndarray):
3838
return intercept, coefs, sse
3939

4040

41-
class SETAR(BaseForecaster):
41+
class SETAR(BaseForecaster, IterativeForecastingMixin):
4242
"""
4343
Self-Exciting Threshold AutoRegressive (SETAR) forecaster with 2 regimes.
4444

aeon/forecasting/machine_learning/_setarforest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,15 @@
44

55
import numpy as np
66

7-
from aeon.forecasting.base import BaseForecaster
7+
from aeon.forecasting.base import BaseForecaster, IterativeForecastingMixin
88

99
from ._setartree import SETARTree
1010

1111
__maintainer__ = ["TinaJin0228"]
1212
__all__ = ["SETARForest"]
1313

1414

15-
class SETARForest(BaseForecaster):
15+
class SETARForest(BaseForecaster, IterativeForecastingMixin):
1616
"""
1717
SETAR-Forest: Bagging + random subspace ensemble of SETAR-Tree base learners.
1818

aeon/forecasting/machine_learning/_setartree.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,13 @@
55
from scipy.stats import f
66
from sklearn.linear_model import LinearRegression
77

8-
from aeon.forecasting.base import BaseForecaster
8+
from aeon.forecasting.base import BaseForecaster, IterativeForecastingMixin
99

1010
__maintainer__ = ["TinaJin0228"]
1111
__all__ = ["SETARTree"]
1212

1313

14-
class SETARTree(BaseForecaster):
14+
class SETARTree(BaseForecaster, IterativeForecastingMixin):
1515
"""
1616
SETAR-Tree: A tree algorithm for global time series forecasting.
1717

aeon/forecasting/machine_learning/tests/test_setartree.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -25,20 +25,6 @@ def test_linear_series():
2525
assert np.isclose(pred, 21.0, atol=0.01), f"Prediction {pred} not close to 21.0"
2626

2727

28-
def test_iterative_forecast_linear():
29-
"""Test iterative forecasting on a linear series (1-D array of length H)."""
30-
y = np.arange(1, 11.0)
31-
f = SETARTree(lag=2)
32-
preds = f.iterative_forecast(y, 3)
33-
expected = np.arange(11.0, 14.0)
34-
assert preds.shape == (
35-
3,
36-
), f"iterative_forecast should return shape (H,), got {preds.shape}"
37-
assert np.allclose(
38-
preds, expected, atol=0.01
39-
), f"Predictions {preds} not close to {expected}"
40-
41-
4228
def test_fixed_lag():
4329
"""Test the fixed_lag parameter."""
4430
y = np.arange(1, 21.0)

0 commit comments

Comments
 (0)