Skip to content

Commit 1933314

Browse files
fkiralyastrogilda
andauthored
feat: skbase inheritance for config objects (#31)
This PR changes config and sampler objects to inherit from `skbase` `BaseObject`, and in turn removes redundant methods such as `__repr__`. Depends on #10 which should be merged first. --------- Co-authored-by: Sankalp Gilda <[email protected]>
1 parent a5db385 commit 1933314

File tree

2 files changed

+2
-301
lines changed

2 files changed

+2
-301
lines changed

src/tsbootstrap/block_bootstrap_configs.py

Lines changed: 0 additions & 284 deletions
Original file line numberDiff line numberDiff line change
@@ -247,45 +247,6 @@ def min_block_length(self, value) -> None:
247247
validate_single_integer(value, min_value=1)
248248
self._min_block_length = value
249249

250-
def __repr__(self) -> str:
251-
"""Return repr(self)."""
252-
base_repr = super().__repr__()
253-
return (
254-
f"{base_repr[:-1]}, "
255-
f"block_length={self.block_length}, "
256-
f"block_length_distribution={self.block_length_distribution}, "
257-
f"wrap_around_flag={self.wrap_around_flag}, "
258-
f"overlap_flag={self.overlap_flag}, "
259-
f"combine_generation_and_sampling_flag={self.combine_generation_and_sampling_flag}, "
260-
f"block_weights={self.block_weights}, "
261-
f"tapered_weights={self.tapered_weights}, "
262-
f"overlap_length={self.overlap_length}, "
263-
f"min_block_length={self.min_block_length})"
264-
)
265-
266-
def __str__(self) -> str:
267-
"""Return str(self)."""
268-
return self.__repr__()
269-
270-
def __eq__(self, other: object) -> bool:
271-
"""Return self == other."""
272-
if not isinstance(other, BlockBootstrapConfig):
273-
return False
274-
return (
275-
super().__eq__(other)
276-
and self.block_length == other.block_length
277-
and self.block_length_distribution
278-
== other.block_length_distribution
279-
and self.wrap_around_flag == other.wrap_around_flag
280-
and self.overlap_flag == other.overlap_flag
281-
and self.combine_generation_and_sampling_flag
282-
== other.combine_generation_and_sampling_flag
283-
and self.block_weights == other.block_weights
284-
and self.tapered_weights == other.tapered_weights
285-
and self.overlap_length == other.overlap_length
286-
and self.min_block_length == other.min_block_length
287-
)
288-
289250

290251
class BaseBlockBootstrapConfig(BlockBootstrapConfig):
291252
"""
@@ -327,24 +288,6 @@ def bootstrap_type(self, value: str | None):
327288
raise ValueError(f"bootstrap_type must be one of {valid_types}.")
328289
self._bootstrap_type = value
329290

330-
def __repr__(self) -> str:
331-
"""Return repr(self)."""
332-
base_repr = super().__repr__()
333-
return f"{base_repr[:-1]}, bootstrap_type={self.bootstrap_type})"
334-
335-
def __str__(self) -> str:
336-
"""Return str(self)."""
337-
return self.__repr__()
338-
339-
def __eq__(self, other: object) -> bool:
340-
"""Return self == other."""
341-
if not isinstance(other, BaseBlockBootstrapConfig):
342-
return False
343-
return (
344-
super().__eq__(other)
345-
and self.bootstrap_type == other.bootstrap_type
346-
)
347-
348291

349292
class MovingBlockBootstrapConfig(BlockBootstrapConfig):
350293
"""
@@ -385,34 +328,6 @@ def __init__(
385328
self._overlap_flag = True
386329
self._block_length_distribution = None
387330

388-
def __repr__(self) -> str:
389-
"""Return repr(self)."""
390-
base_repr = super().__repr__()
391-
return (
392-
f"{base_repr[:-1]}, "
393-
f"block_length={self.block_length}, "
394-
f"wrap_around_flag={self.wrap_around_flag}, "
395-
f"overlap_flag={self.overlap_flag}, "
396-
f"block_length_distribution={self.block_length_distribution})"
397-
)
398-
399-
def __str__(self) -> str:
400-
"""Return str(self)."""
401-
return self.__repr__()
402-
403-
def __eq__(self, other: object) -> bool:
404-
"""Return self == other."""
405-
if not isinstance(other, MovingBlockBootstrapConfig):
406-
return False
407-
return (
408-
super().__eq__(other)
409-
and self.block_length == other.block_length
410-
and self.wrap_around_flag == other.wrap_around_flag
411-
and self.overlap_flag == other.overlap_flag
412-
and self.block_length_distribution
413-
== other.block_length_distribution
414-
)
415-
416331

417332
class StationaryBlockBootstrapConfig(BlockBootstrapConfig):
418333
"""
@@ -453,34 +368,6 @@ def __init__(
453368
self._overlap_flag = True
454369
self._block_length_distribution = "geometric"
455370

456-
def __repr__(self) -> str:
457-
"""Return repr(self)."""
458-
base_repr = super().__repr__()
459-
return (
460-
f"{base_repr[:-1]}, "
461-
f"block_length={self.block_length}, "
462-
f"wrap_around_flag={self.wrap_around_flag}, "
463-
f"overlap_flag={self.overlap_flag}, "
464-
f"block_length_distribution={self.block_length_distribution})"
465-
)
466-
467-
def __str__(self) -> str:
468-
"""Return str(self)."""
469-
return self.__repr__()
470-
471-
def __eq__(self, other: object) -> bool:
472-
"""Return self == other."""
473-
if not isinstance(other, StationaryBlockBootstrapConfig):
474-
return False
475-
return (
476-
super().__eq__(other)
477-
and self.block_length == other.block_length
478-
and self.wrap_around_flag == other.wrap_around_flag
479-
and self.overlap_flag == other.overlap_flag
480-
and self.block_length_distribution
481-
== other.block_length_distribution
482-
)
483-
484371

485372
class CircularBlockBootstrapConfig(BlockBootstrapConfig):
486373
"""
@@ -521,34 +408,6 @@ def __init__(
521408
self._overlap_flag = True
522409
self._block_length_distribution = None
523410

524-
def __repr__(self) -> str:
525-
"""Return repr(self)."""
526-
base_repr = super().__repr__()
527-
return (
528-
f"{base_repr[:-1]}, "
529-
f"block_length={self.block_length}, "
530-
f"wrap_around_flag={self.wrap_around_flag}, "
531-
f"overlap_flag={self.overlap_flag}, "
532-
f"block_length_distribution={self.block_length_distribution})"
533-
)
534-
535-
def __str__(self) -> str:
536-
"""Return str(self)."""
537-
return self.__repr__()
538-
539-
def __eq__(self, other: object) -> bool:
540-
"""Return self == other."""
541-
if not isinstance(other, CircularBlockBootstrapConfig):
542-
return False
543-
return (
544-
super().__eq__(other)
545-
and self.block_length == other.block_length
546-
and self.wrap_around_flag == other.wrap_around_flag
547-
and self.overlap_flag == other.overlap_flag
548-
and self.block_length_distribution
549-
== other.block_length_distribution
550-
)
551-
552411

553412
class NonOverlappingBlockBootstrapConfig(BlockBootstrapConfig):
554413
"""
@@ -589,34 +448,6 @@ def __init__(
589448
self._overlap_flag = False
590449
self._block_length_distribution = None
591450

592-
def __repr__(self) -> str:
593-
"""Return repr(self)."""
594-
base_repr = super().__repr__()
595-
return (
596-
f"{base_repr[:-1]}, "
597-
f"block_length={self.block_length}, "
598-
f"wrap_around_flag={self.wrap_around_flag}, "
599-
f"overlap_flag={self.overlap_flag}, "
600-
f"block_length_distribution={self.block_length_distribution})"
601-
)
602-
603-
def __str__(self) -> str:
604-
"""Return str(self)."""
605-
return self.__repr__()
606-
607-
def __eq__(self, other: object) -> bool:
608-
"""Return self == other."""
609-
if not isinstance(other, NonOverlappingBlockBootstrapConfig):
610-
return False
611-
return (
612-
super().__eq__(other)
613-
and self.block_length == other.block_length
614-
and self.wrap_around_flag == other.wrap_around_flag
615-
and self.overlap_flag == other.overlap_flag
616-
and self.block_length_distribution
617-
== other.block_length_distribution
618-
)
619-
620451

621452
class BartlettsBootstrapConfig(BaseBlockBootstrapConfig):
622453
"""Config class for BartlettBootstrap.
@@ -653,29 +484,6 @@ def __init__(
653484
self._bootstrap_type = "moving"
654485
self._tapered_weights = np.bartlett
655486

656-
def __repr__(self) -> str:
657-
"""Return repr(self)."""
658-
base_repr = super().__repr__()
659-
return (
660-
f"{base_repr[:-1]}, "
661-
f"bootstrap_type={self.bootstrap_type}, "
662-
f"tapered_weights={self.tapered_weights})"
663-
)
664-
665-
def __str__(self) -> str:
666-
"""Return str(self)."""
667-
return self.__repr__()
668-
669-
def __eq__(self, other: object) -> bool:
670-
"""Return self == other."""
671-
if not isinstance(other, BartlettsBootstrapConfig):
672-
return False
673-
return (
674-
super().__eq__(other)
675-
and self.bootstrap_type == other.bootstrap_type
676-
and self.tapered_weights == other.tapered_weights
677-
)
678-
679487

680488
class HammingBootstrapConfig(BaseBlockBootstrapConfig):
681489
"""Config class for HammingBootstrap.
@@ -712,29 +520,6 @@ def __init__(
712520
self._bootstrap_type = "moving"
713521
self._tapered_weights = np.hamming
714522

715-
def __repr__(self) -> str:
716-
"""Return repr(self)."""
717-
base_repr = super().__repr__()
718-
return (
719-
f"{base_repr[:-1]}, "
720-
f"bootstrap_type={self.bootstrap_type}, "
721-
f"tapered_weights={self.tapered_weights})"
722-
)
723-
724-
def __str__(self) -> str:
725-
"""Return str(self)."""
726-
return self.__repr__()
727-
728-
def __eq__(self, other: object) -> bool:
729-
"""Return self == other."""
730-
if not isinstance(other, HammingBootstrapConfig):
731-
return False
732-
return (
733-
super().__eq__(other)
734-
and self.bootstrap_type == other.bootstrap_type
735-
and self.tapered_weights == other.tapered_weights
736-
)
737-
738523

739524
class HanningBootstrapConfig(BaseBlockBootstrapConfig):
740525
"""Config class for HanningBootstrap.
@@ -771,29 +556,6 @@ def __init__(
771556
self._bootstrap_type = "moving"
772557
self._tapered_weights = np.hanning
773558

774-
def __repr__(self) -> str:
775-
"""Return repr(self)."""
776-
base_repr = super().__repr__()
777-
return (
778-
f"{base_repr[:-1]}, "
779-
f"bootstrap_type={self.bootstrap_type}, "
780-
f"tapered_weights={self.tapered_weights})"
781-
)
782-
783-
def __str__(self) -> str:
784-
"""Return str(self)."""
785-
return self.__repr__()
786-
787-
def __eq__(self, other: object) -> bool:
788-
"""Return self == other."""
789-
if not isinstance(other, HanningBootstrapConfig):
790-
return False
791-
return (
792-
super().__eq__(other)
793-
and self.bootstrap_type == other.bootstrap_type
794-
and self.tapered_weights == other.tapered_weights
795-
)
796-
797559

798560
class BlackmanBootstrapConfig(BaseBlockBootstrapConfig):
799561
"""Config class for BlackmanBootstrap.
@@ -830,29 +592,6 @@ def __init__(
830592
self._bootstrap_type = "moving"
831593
self._tapered_weights = np.blackman
832594

833-
def __repr__(self) -> str:
834-
"""Return repr(self)."""
835-
base_repr = super().__repr__()
836-
return (
837-
f"{base_repr[:-1]}, "
838-
f"bootstrap_type={self.bootstrap_type}, "
839-
f"tapered_weights={self.tapered_weights})"
840-
)
841-
842-
def __str__(self) -> str:
843-
"""Return str(self)."""
844-
return self.__repr__()
845-
846-
def __eq__(self, other: object) -> bool:
847-
"""Return self == other."""
848-
if not isinstance(other, BlackmanBootstrapConfig):
849-
return False
850-
return (
851-
super().__eq__(other)
852-
and self.bootstrap_type == other.bootstrap_type
853-
and self.tapered_weights == other.tapered_weights
854-
)
855-
856595

857596
class TukeyBootstrapConfig(BaseBlockBootstrapConfig):
858597
"""Config class for TukeyBootstrap.
@@ -889,26 +628,3 @@ def __init__(
889628
alpha = kwargs.get("alpha", 0.5)
890629
self._bootstrap_type = "moving"
891630
self._tapered_weights = partial(tukey, alpha=alpha)
892-
893-
def __repr__(self) -> str:
894-
"""Return repr(self)."""
895-
base_repr = super().__repr__()
896-
return (
897-
f"{base_repr[:-1]}, "
898-
f"bootstrap_type={self.bootstrap_type}, "
899-
f"tapered_weights={self.tapered_weights})"
900-
)
901-
902-
def __str__(self) -> str:
903-
"""Return str(self)."""
904-
return self.__repr__()
905-
906-
def __eq__(self, other: object) -> bool:
907-
"""Return self == other."""
908-
if not isinstance(other, TukeyBootstrapConfig):
909-
return False
910-
return (
911-
super().__eq__(other)
912-
and self.bootstrap_type == other.bootstrap_type
913-
and self.tapered_weights == other.tapered_weights
914-
)

src/tsbootstrap/block_length_sampler.py

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,13 @@
3737
import numpy as np
3838
from numpy.random import Generator
3939
from scipy.stats import pareto, weibull_min
40+
from skbase.base import BaseObject
4041

4142
from tsbootstrap.utils.types import RngTypes
4243
from tsbootstrap.utils.validate import validate_integers, validate_rng
4344

4445

45-
class BlockLengthSampler:
46+
class BlockLengthSampler(BaseObject):
4647
"""
4748
A class for sampling block lengths for the random block length bootstrap.
4849
@@ -174,19 +175,3 @@ def sample_block_length(self) -> int:
174175
self.block_length_distribution
175176
](self.rng, self.avg_block_length)
176177
return max(round(sampled_block_length), 2)
177-
178-
def __repr__(self) -> str:
179-
return f"BlockLengthSampler(avg_block_length={self.avg_block_length}, block_length_distribution='{self.block_length_distribution}', rng={self.rng})"
180-
181-
def __str__(self) -> str:
182-
return f"BlockLengthSampler using avg_block_length={self.avg_block_length}, block_length_distribution='{self.block_length_distribution}', and random seed {self.rng}"
183-
184-
def __eq__(self, other: object) -> bool:
185-
if isinstance(other, BlockLengthSampler):
186-
return (
187-
self.avg_block_length == other.avg_block_length
188-
and self.block_length_distribution
189-
== other.block_length_distribution
190-
and self.rng == other.rng
191-
)
192-
return False

0 commit comments

Comments
 (0)