Skip to content

Commit a5db385

Browse files
fkiralyastrogilda
andauthored
feat: skbase inheritance in central base classes (#10)
Implements #8 - without additional testing - by changing inheritance to `BaseObject` in the base bootstrapper and the base bootstrap config. Main changes in this refactor: * bootstrap classes now inherit from `scikit-base` `BaseObject` * methods provided by `BaseObject` are removed - `__repr__`, `__str__`, `__hash__` * methods provided by `BaseObject` but not currently present are added - `get_params`, `set_params` Depends on #23, as `poetry.lock` will lock the dependencies in place (so imo it should be removed). --------- Co-authored-by: Sankalp Gilda <[email protected]>
1 parent a33b40a commit a5db385

File tree

4 files changed

+20
-373
lines changed

4 files changed

+20
-373
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ homepage = "https://tsbootstrap.readthedocs.io/en/latest/"
2323

2424
[tool.poetry.dependencies]
2525
python = ">=3.10,<3.12"
26+
scikit-base = ">= 0.6.1"
2627
arch = "~5.6"
2728
hmmlearn = "~0.3"
2829
pyclustering = "~0.10"

src/tsbootstrap/base_bootstrap.py

Lines changed: 5 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import numpy as np
1010
from scipy.stats import rv_continuous
11+
from skbase.base import BaseObject
1112

1213
from tsbootstrap.utils.odds_and_ends import time_series_split
1314

@@ -24,7 +25,7 @@
2425
from tsbootstrap.tsfit import TSFitBestLag
2526

2627

27-
class BaseTimeSeriesBootstrap(metaclass=ABCMeta):
28+
class BaseTimeSeriesBootstrap(BaseObject):
2829
"""
2930
Base class for time series bootstrapping.
3031
@@ -34,6 +35,8 @@ class BaseTimeSeriesBootstrap(metaclass=ABCMeta):
3435
If n_bootstraps is not greater than 0.
3536
"""
3637

38+
_tags = {"object_type": "bootstrap"}
39+
3740
def __init__(self, config: BaseTimeSeriesBootstrapConfig) -> None:
3841
self.config = config
3942

@@ -131,14 +134,14 @@ def _generate_samples(
131134
else:
132135
yield data
133136

134-
@abstractmethod
135137
def _generate_samples_single_bootstrap(
136138
self, X: np.ndarray, y: np.ndarray | None = None
137139
) -> tuple[list[np.ndarray], list[np.ndarray]]:
138140
"""Generates list of bootstrapped indices and samples for a single bootstrap iteration.
139141
140142
Should be implemented in derived classes.
141143
"""
144+
raise NotImplementedError("abstract method")
142145

143146
def _check_input(self, X):
144147
"""Checks if the input is valid."""
@@ -154,24 +157,6 @@ def get_n_bootstraps(
154157
"""Returns the number of bootstrapping iterations."""
155158
return self.config.n_bootstraps # type: ignore
156159

157-
def __repr__(self) -> str:
158-
"""Returns the string representation of the object."""
159-
return f"{self.__class__.__name__}(config={self.config})"
160-
161-
def __str__(self) -> str:
162-
"""Returns the string representation of the object."""
163-
return f"{self.__class__.__name__}(config={self.config})"
164-
165-
def __eq__(self, __value: object) -> bool:
166-
"""Returns True if the objects are equal, False otherwise."""
167-
if not isinstance(__value, BaseTimeSeriesBootstrap):
168-
return NotImplemented
169-
return self.config == __value.config
170-
171-
def __hash__(self) -> int:
172-
"""Returns the hash of the object."""
173-
return hash(self.config)
174-
175160

176161
class BaseResidualBootstrap(BaseTimeSeriesBootstrap):
177162
"""
@@ -234,24 +219,6 @@ def _fit_model(self, X: np.ndarray, y: np.ndarray | None = None) -> None:
234219
self.order = fit_obj.get_order()
235220
self.coefs = fit_obj.get_coefs()
236221

237-
def __repr__(self) -> str:
238-
"""Returns the string representation of the object."""
239-
return f"{self.__class__.__name__}(config={self.config})"
240-
241-
def __str__(self) -> str:
242-
"""Returns the string representation of the object."""
243-
return f"{self.__class__.__name__}(config={self.config})"
244-
245-
def __eq__(self, __value: object) -> bool:
246-
"""Returns True if the objects are equal, False otherwise."""
247-
if not isinstance(__value, BaseResidualBootstrap):
248-
return NotImplemented
249-
return self.config == __value.config
250-
251-
def __hash__(self) -> int:
252-
"""Returns the hash of the object."""
253-
return hash((super().__hash__(), self.config))
254-
255222

256223
class BaseMarkovBootstrap(BaseResidualBootstrap):
257224
"""
@@ -288,24 +255,6 @@ def __init__(
288255

289256
self.hmm_object = None
290257

291-
def __repr__(self) -> str:
292-
"""Returns the string representation of the object."""
293-
return f"{self.__class__.__name__}(config={self.config})"
294-
295-
def __str__(self) -> str:
296-
"""Returns the string representation of the object."""
297-
return self.__repr__()
298-
299-
def __eq__(self, __value: object) -> bool:
300-
"""Returns True if the objects are equal, False otherwise."""
301-
if not isinstance(__value, BaseMarkovBootstrap):
302-
return NotImplemented
303-
return self.config == __value.config
304-
305-
def __hash__(self) -> int:
306-
"""Returns the hash of the object."""
307-
return hash((super().__hash__(), self.config))
308-
309258

310259
class BaseStatisticPreservingBootstrap(BaseTimeSeriesBootstrap):
311260
"""Bootstrap class that generates bootstrapped samples preserving a specific statistic.
@@ -352,24 +301,6 @@ def _calculate_statistic(self, X: np.ndarray) -> np.ndarray:
352301
statistic_X = self.config.statistic(X, **kwargs_stat)
353302
return statistic_X
354303

355-
def __repr__(self) -> str:
356-
"""Returns the string representation of the object."""
357-
return f"{self.__class__.__name__}(config={self.config})"
358-
359-
def __str__(self) -> str:
360-
"""Returns the string representation of the object."""
361-
return self.__repr__()
362-
363-
def __eq__(self, __value: object) -> bool:
364-
"""Returns True if the objects are equal, False otherwise."""
365-
if not isinstance(__value, BaseStatisticPreservingBootstrap):
366-
return NotImplemented
367-
return self.config == __value.config
368-
369-
def __hash__(self) -> int:
370-
"""Returns the hash of the object."""
371-
return hash((super().__hash__(), self.config))
372-
373304

374305
# We can only fit uni-variate distributions, so X must be a 1D array, and `model_type` in BaseResidualBootstrap must not be "var".
375306
class BaseDistributionBootstrap(BaseResidualBootstrap):
@@ -450,24 +381,6 @@ def _fit_distribution(
450381
resids_dist_params = resids_dist.fit(resids)
451382
return resids_dist, resids_dist_params
452383

453-
def __repr__(self) -> str:
454-
"""Returns the string representation of the object."""
455-
return f"{self.__class__.__name__}(config={self.config})"
456-
457-
def __str__(self) -> str:
458-
"""Returns the string representation of the object."""
459-
return self.__repr__()
460-
461-
def __eq__(self, __value: object) -> bool:
462-
"""Returns True if the objects are equal, False otherwise."""
463-
if not isinstance(__value, BaseDistributionBootstrap):
464-
return NotImplemented
465-
return self.config == __value.config
466-
467-
def __hash__(self) -> int:
468-
"""Returns the hash of the object."""
469-
return hash((super().__hash__(), self.config))
470-
471384

472385
class BaseSieveBootstrap(BaseResidualBootstrap):
473386
"""
@@ -536,21 +449,3 @@ def _fit_resids_model(self, X: np.ndarray) -> None:
536449
self.resids_fit_model = resids_fit_model
537450
self.resids_order = resids_order
538451
self.resids_coefs = resids_coefs
539-
540-
def __repr__(self) -> str:
541-
"""Returns the string representation of the object."""
542-
return f"{self.__class__.__name__}(config={self.config})"
543-
544-
def __str__(self) -> str:
545-
"""Returns the string representation of the object."""
546-
return self.__repr__()
547-
548-
def __eq__(self, __value: object) -> bool:
549-
"""Returns True if the objects are equal, False otherwise."""
550-
if not isinstance(__value, BaseSieveBootstrap):
551-
return NotImplemented
552-
return self.config == __value.config
553-
554-
def __hash__(self) -> int:
555-
"""Returns the hash of the object."""
556-
return hash((super().__hash__(), self.config))

0 commit comments

Comments
 (0)