88
99import numpy as np
1010from scipy .stats import rv_continuous
11+ from skbase .base import BaseObject
1112
1213from tsbootstrap .utils .odds_and_ends import time_series_split
1314
2425from tsbootstrap .tsfit import TSFitBestLag
2526
2627
27- class BaseTimeSeriesBootstrap (metaclass = ABCMeta ):
28+ class BaseTimeSeriesBootstrap (BaseObject ):
2829 """
2930 Base class for time series bootstrapping.
3031
@@ -34,6 +35,8 @@ class BaseTimeSeriesBootstrap(metaclass=ABCMeta):
3435 If n_bootstraps is not greater than 0.
3536 """
3637
38+ _tags = {"object_type" : "bootstrap" }
39+
3740 def __init__ (self , config : BaseTimeSeriesBootstrapConfig ) -> None :
3841 self .config = config
3942
@@ -131,14 +134,14 @@ def _generate_samples(
131134 else :
132135 yield data
133136
134- @abstractmethod
135137 def _generate_samples_single_bootstrap (
136138 self , X : np .ndarray , y : np .ndarray | None = None
137139 ) -> tuple [list [np .ndarray ], list [np .ndarray ]]:
138140 """Generates list of bootstrapped indices and samples for a single bootstrap iteration.
139141
140142 Should be implemented in derived classes.
141143 """
144+ raise NotImplementedError ("abstract method" )
142145
143146 def _check_input (self , X ):
144147 """Checks if the input is valid."""
@@ -154,24 +157,6 @@ def get_n_bootstraps(
154157 """Returns the number of bootstrapping iterations."""
155158 return self .config .n_bootstraps # type: ignore
156159
157- def __repr__ (self ) -> str :
158- """Returns the string representation of the object."""
159- return f"{ self .__class__ .__name__ } (config={ self .config } )"
160-
161- def __str__ (self ) -> str :
162- """Returns the string representation of the object."""
163- return f"{ self .__class__ .__name__ } (config={ self .config } )"
164-
165- def __eq__ (self , __value : object ) -> bool :
166- """Returns True if the objects are equal, False otherwise."""
167- if not isinstance (__value , BaseTimeSeriesBootstrap ):
168- return NotImplemented
169- return self .config == __value .config
170-
171- def __hash__ (self ) -> int :
172- """Returns the hash of the object."""
173- return hash (self .config )
174-
175160
176161class BaseResidualBootstrap (BaseTimeSeriesBootstrap ):
177162 """
@@ -234,24 +219,6 @@ def _fit_model(self, X: np.ndarray, y: np.ndarray | None = None) -> None:
234219 self .order = fit_obj .get_order ()
235220 self .coefs = fit_obj .get_coefs ()
236221
237- def __repr__ (self ) -> str :
238- """Returns the string representation of the object."""
239- return f"{ self .__class__ .__name__ } (config={ self .config } )"
240-
241- def __str__ (self ) -> str :
242- """Returns the string representation of the object."""
243- return f"{ self .__class__ .__name__ } (config={ self .config } )"
244-
245- def __eq__ (self , __value : object ) -> bool :
246- """Returns True if the objects are equal, False otherwise."""
247- if not isinstance (__value , BaseResidualBootstrap ):
248- return NotImplemented
249- return self .config == __value .config
250-
251- def __hash__ (self ) -> int :
252- """Returns the hash of the object."""
253- return hash ((super ().__hash__ (), self .config ))
254-
255222
256223class BaseMarkovBootstrap (BaseResidualBootstrap ):
257224 """
@@ -288,24 +255,6 @@ def __init__(
288255
289256 self .hmm_object = None
290257
291- def __repr__ (self ) -> str :
292- """Returns the string representation of the object."""
293- return f"{ self .__class__ .__name__ } (config={ self .config } )"
294-
295- def __str__ (self ) -> str :
296- """Returns the string representation of the object."""
297- return self .__repr__ ()
298-
299- def __eq__ (self , __value : object ) -> bool :
300- """Returns True if the objects are equal, False otherwise."""
301- if not isinstance (__value , BaseMarkovBootstrap ):
302- return NotImplemented
303- return self .config == __value .config
304-
305- def __hash__ (self ) -> int :
306- """Returns the hash of the object."""
307- return hash ((super ().__hash__ (), self .config ))
308-
309258
310259class BaseStatisticPreservingBootstrap (BaseTimeSeriesBootstrap ):
311260 """Bootstrap class that generates bootstrapped samples preserving a specific statistic.
@@ -352,24 +301,6 @@ def _calculate_statistic(self, X: np.ndarray) -> np.ndarray:
352301 statistic_X = self .config .statistic (X , ** kwargs_stat )
353302 return statistic_X
354303
355- def __repr__ (self ) -> str :
356- """Returns the string representation of the object."""
357- return f"{ self .__class__ .__name__ } (config={ self .config } )"
358-
359- def __str__ (self ) -> str :
360- """Returns the string representation of the object."""
361- return self .__repr__ ()
362-
363- def __eq__ (self , __value : object ) -> bool :
364- """Returns True if the objects are equal, False otherwise."""
365- if not isinstance (__value , BaseStatisticPreservingBootstrap ):
366- return NotImplemented
367- return self .config == __value .config
368-
369- def __hash__ (self ) -> int :
370- """Returns the hash of the object."""
371- return hash ((super ().__hash__ (), self .config ))
372-
373304
374305# We can only fit uni-variate distributions, so X must be a 1D array, and `model_type` in BaseResidualBootstrap must not be "var".
375306class BaseDistributionBootstrap (BaseResidualBootstrap ):
@@ -450,24 +381,6 @@ def _fit_distribution(
450381 resids_dist_params = resids_dist .fit (resids )
451382 return resids_dist , resids_dist_params
452383
453- def __repr__ (self ) -> str :
454- """Returns the string representation of the object."""
455- return f"{ self .__class__ .__name__ } (config={ self .config } )"
456-
457- def __str__ (self ) -> str :
458- """Returns the string representation of the object."""
459- return self .__repr__ ()
460-
461- def __eq__ (self , __value : object ) -> bool :
462- """Returns True if the objects are equal, False otherwise."""
463- if not isinstance (__value , BaseDistributionBootstrap ):
464- return NotImplemented
465- return self .config == __value .config
466-
467- def __hash__ (self ) -> int :
468- """Returns the hash of the object."""
469- return hash ((super ().__hash__ (), self .config ))
470-
471384
472385class BaseSieveBootstrap (BaseResidualBootstrap ):
473386 """
@@ -536,21 +449,3 @@ def _fit_resids_model(self, X: np.ndarray) -> None:
536449 self .resids_fit_model = resids_fit_model
537450 self .resids_order = resids_order
538451 self .resids_coefs = resids_coefs
539-
540- def __repr__ (self ) -> str :
541- """Returns the string representation of the object."""
542- return f"{ self .__class__ .__name__ } (config={ self .config } )"
543-
544- def __str__ (self ) -> str :
545- """Returns the string representation of the object."""
546- return self .__repr__ ()
547-
548- def __eq__ (self , __value : object ) -> bool :
549- """Returns True if the objects are equal, False otherwise."""
550- if not isinstance (__value , BaseSieveBootstrap ):
551- return NotImplemented
552- return self .config == __value .config
553-
554- def __hash__ (self ) -> int :
555- """Returns the hash of the object."""
556- return hash ((super ().__hash__ (), self .config ))
0 commit comments