diff --git a/.github/workflows/python-test.yml b/.github/workflows/python-test.yml index a9e8a0d4..a455c4f5 100644 --- a/.github/workflows/python-test.yml +++ b/.github/workflows/python-test.yml @@ -5,9 +5,9 @@ name: Pytest on: push: - branches: [ "master" ] + branches: [ "master", "release/*" ] pull_request: - branches: [ "master" ] + branches: [ "master", "release/*" ] permissions: contents: read @@ -29,6 +29,7 @@ jobs: pip install flake8 pytest pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + if [ -f tests/requirements.txt ]; then pip install -r tests/requirements.txt; fi - name: Test with pytest run: | pytest diff --git a/.isort.cfg b/.isort.cfg index 4e7a1fb3..d6e80717 100644 --- a/.isort.cfg +++ b/.isort.cfg @@ -1,3 +1,3 @@ [settings] -src_paths=basicts,tests -skip_glob=baselines/*,assets/*,examples/* \ No newline at end of file +src_paths=src/basicts,tests +skip_glob=baselines/*,assets/* \ No newline at end of file diff --git a/.pylintrc b/.pylintrc index 2066e61c..503e623d 100644 --- a/.pylintrc +++ b/.pylintrc @@ -12,7 +12,7 @@ ignore=baselines,assets,checkpoints,examples,scripts # Files or directories matching the regex patterns are skipped. The regex # matches against base names, not paths. -ignore-patterns=^\.|^_|^.*\.md|^.*\.txt|^.*\.csv|^.*\.CFF|^LICENSE +ignore-patterns=^\.|^_|^.*\.md|^.*\.txt|^.*\.csv|^.*\.CFF|^LICENSE|^.*\.toml # Pickle collected data for later comparisons. persistent=no diff --git a/examples/classification/classification_demo.py b/examples/classification/classification_demo.py index 75e1e7b7..a5ab6582 100644 --- a/examples/classification/classification_demo.py +++ b/examples/classification/classification_demo.py @@ -1,6 +1,7 @@ from basicts import BasicTSLauncher from basicts.configs import BasicTSClassificationConfig -from basicts.models.iTransformer import iTransformerForClassification, iTransformerConfig +from basicts.models.iTransformer import (iTransformerConfig, + iTransformerForClassification) def main(): diff --git a/examples/forecasting/forecasting_demo.py b/examples/forecasting/forecasting_demo.py index 939b126b..f4da5737 100644 --- a/examples/forecasting/forecasting_demo.py +++ b/examples/forecasting/forecasting_demo.py @@ -1,9 +1,11 @@ +from torch.optim.lr_scheduler import MultiStepLR + from basicts import BasicTSLauncher from basicts.configs import BasicTSForecastingConfig -from basicts.models.iTransformer import iTransformerForForecasting, iTransformerConfig -from basicts.runners.callback import EarlyStopping, GradientClipping from basicts.metrics import masked_mse -from torch.optim.lr_scheduler import MultiStepLR +from basicts.models.iTransformer import (iTransformerConfig, + iTransformerForForecasting) +from basicts.runners.callback import EarlyStopping, GradientClipping def main(): diff --git a/examples/imputation/imputation_demo.py b/examples/imputation/imputation_demo.py index ba267dbd..08caa181 100644 --- a/examples/imputation/imputation_demo.py +++ b/examples/imputation/imputation_demo.py @@ -1,6 +1,7 @@ from basicts import BasicTSLauncher from basicts.configs import BasicTSImputationConfig -from basicts.models.iTransformer import iTransformerForReconstruction, iTransformerConfig +from basicts.models.iTransformer import (iTransformerConfig, + iTransformerForReconstruction) def main(): diff --git a/pyproject.toml b/pyproject.toml index c9454733..a3d2cc04 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,7 @@ dependencies = [ 
"sympy", "openpyxl", "setuptools==59.5.0", - "numpy==1.24.4", + "numpy", "tqdm==4.67.1", "tensorboard==2.18.0", "transformers==4.40.1" diff --git a/src/basicts/__init__.py b/src/basicts/__init__.py index a4a935bf..62a90835 100644 --- a/src/basicts/__init__.py +++ b/src/basicts/__init__.py @@ -1,5 +1,5 @@ from .launcher import BasicTSLauncher -__version__ = '1.0.2' +__version__ = '1.1.0' __all__ = ['__version__', 'BasicTSLauncher'] diff --git a/src/basicts/configs/base_config.py b/src/basicts/configs/base_config.py index 76e18e47..c1a5c39e 100644 --- a/src/basicts/configs/base_config.py +++ b/src/basicts/configs/base_config.py @@ -9,17 +9,22 @@ from functools import partial from numbers import Number from types import FunctionType -from typing import Callable, List, Literal, Optional, Tuple, Union +from typing import (TYPE_CHECKING, Callable, List, Literal, Optional, Tuple, + Union) import numpy as np import torch -from basicts.runners.callback import BasicTSCallback -from basicts.runners.taskflow import BasicTSTaskFlow from easydict import EasyDict from torch.optim.lr_scheduler import LRScheduler from .model_config import BasicTSModelConfig +# avoid circular imports +if TYPE_CHECKING: + from basicts.runners.callback import BasicTSCallback + from basicts.runners.taskflow import BasicTSTaskFlow + + @dataclass(init=False) class BasicTSConfig(EasyDict): @@ -35,8 +40,8 @@ class BasicTSConfig(EasyDict): model_config: BasicTSModelConfig dataset_name: str - taskflow: BasicTSTaskFlow - callbacks: List[BasicTSCallback] + taskflow: "BasicTSTaskFlow" + callbacks: List["BasicTSCallback"] ############################## General Configuration ############################## @@ -277,7 +282,7 @@ def _pack_params(self, obj: type, obj_params: Union[dict, None]) -> dict: elif issubclass(obj, LRScheduler) and k == "optimizer": continue # short cut has higher priority than params in config - elif k in self: + elif k in self and self[k] is not None: obj_params[k] = self[k] return obj_params @@ -338,7 +343,7 @@ def _serialize_obj(self, obj: object) -> object: if not isinstance(is_default, bool): raise ValueError(f"Parameter {k} of {obj.__class__.__name__} is not serializable.") if not is_default: - params[k] = repr(v) + params[k] = self._serialize_obj(v) return { "name": obj.__class__.__name__, diff --git a/src/basicts/configs/tsc_config.py b/src/basicts/configs/tsc_config.py index 4d0a9d43..519273b3 100644 --- a/src/basicts/configs/tsc_config.py +++ b/src/basicts/configs/tsc_config.py @@ -2,12 +2,13 @@ from typing import Callable, List, Literal, Tuple, Union import numpy as np +from torch.nn import CrossEntropyLoss +from torch.optim import Adam + from basicts.data import UEADataset from basicts.runners.callback import BasicTSCallback from basicts.runners.taskflow import (BasicTSClassificationTaskFlow, BasicTSTaskFlow) -from torch.nn import CrossEntropyLoss -from torch.optim import Adam from .base_config import BasicTSConfig from .model_config import BasicTSModelConfig @@ -99,9 +100,11 @@ class BasicTSClassificationConfig(BasicTSConfig): # Dataset settings dataset_type: type = field(default=UEADataset, metadata={"help": "Dataset type."}) - dataset_params: Union[dict, None] = field(default=None, metadata={"help": "Dataset parameters."}) + dataset_params: Union[dict, None] = field( + default_factory=lambda: {"memmap": False}, + metadata={"help": "Dataset parameters."}) use_timestamps: bool = field(default=False, metadata={"help": "Whether to use timestamps as supplementary."}) - memmap: bool = field(default=False, 
metadata={"help": "Whether to use memmap to load datasets."}) + memmap: bool = field(default=None, metadata={"help": "Whether to use memmap to load datasets."}) null_val: float = field(default=np.nan, metadata={"help": "Null value."}) null_to_num: float = field(default=0.0, metadata={"help": "Null value to number."}) @@ -148,7 +151,7 @@ class BasicTSClassificationConfig(BasicTSConfig): optimizer_params: dict = field( default_factory=lambda: {"lr": 2e-4, "weight_decay": 5e-4}, metadata={"help": "Optimizer parameters."}) - lr: float = field(default=2e-4, metadata={"help": "Learning rate."}) + lr: float = field(default=None, metadata={"help": "Learning rate."}) # Learning rate scheduler lr_scheduler: Union[type, None] = field(default=None, metadata={"help": "Learning rate scheduler type."}) diff --git a/src/basicts/configs/tsf_config.py b/src/basicts/configs/tsf_config.py index 0013f441..627b72d4 100644 --- a/src/basicts/configs/tsf_config.py +++ b/src/basicts/configs/tsf_config.py @@ -2,12 +2,13 @@ from typing import Callable, List, Literal, Tuple, Union import numpy as np +from torch.optim import Adam + from basicts.data import BasicTSForecastingDataset from basicts.runners.callback import BasicTSCallback from basicts.runners.taskflow import (BasicTSForecastingTaskFlow, BasicTSTaskFlow) from basicts.scaler import ZScoreScaler -from torch.optim import Adam from .base_config import BasicTSConfig from .model_config import BasicTSModelConfig @@ -99,17 +100,26 @@ class BasicTSForecastingConfig(BasicTSConfig): # Dataset settings dataset_type: type = field(default=BasicTSForecastingDataset, metadata={"help": "Dataset type."}) - dataset_params: Union[dict, None] = field(default=None, metadata={"help": "Dataset parameters."}) - input_len: int = field(default=336, metadata={"help": "Input length."}) - output_len: int = field(default=336, metadata={"help": "Output length."}) - use_timestamps: bool = field(default=True, metadata={"help": "Whether to use timestamps as supplementary."}) - memmap: bool = field(default=False, metadata={"help": "Whether to use memmap to load datasets."}) - null_val: float = field(default=np.nan, metadata={"help": "Null value."}) - null_to_num: float = field(default=0.0, metadata={"help": "Null value to number."}) - + dataset_params: Union[dict, None] = field( + default_factory=lambda: { + "input_len": 336, + "output_len": 336, + "use_timestamps": True, + "memmap": False, + }, metadata={"help": "Dataset parameters."}) + + # shortcuts + input_len: int = field(default=None, metadata={"help": "Input length."}) + output_len: int = field(default=None, metadata={"help": "Output length."}) + use_timestamps: bool = field(default=None, metadata={"help": "Whether to use timestamps as supplementary."}) + memmap: bool = field(default=None, metadata={"help": "Whether to use memmap to load datasets."}) batch_size: Union[int, None] = field( default=None, metadata={"help": "Batch size. 
If setted, all dataloaders will be setted to the same batch size."}) + + null_val: float = field(default=np.nan, metadata={"help": "Null value."}) + null_to_num: float = field(default=0.0, metadata={"help": "Null value to number."}) + # Scaler settings scaler: type = field(default=ZScoreScaler, metadata={"help": "Scaler type."}) norm_each_channel: bool = field(default=True, metadata={"help": "Whether to normalize data for each channel independently."}) @@ -147,7 +157,7 @@ class BasicTSForecastingConfig(BasicTSConfig): optimizer_params: dict = field( default_factory=lambda: {"lr": 2e-4, "weight_decay": 5e-4}, metadata={"help": "Optimizer parameters."}) - lr: float = field(default=2e-4, metadata={"help": "Learning rate."}) + lr: float = field(default=None, metadata={"help": "Learning rate."}) # Learning rate scheduler lr_scheduler: Union[type, None] = field(default=None, metadata={"help": "Learning rate scheduler type."}) diff --git a/src/basicts/configs/tsfm_config.py b/src/basicts/configs/tsfm_config.py index b0980056..f209b658 100644 --- a/src/basicts/configs/tsfm_config.py +++ b/src/basicts/configs/tsfm_config.py @@ -3,12 +3,13 @@ import numpy as np import torch +from torch.optim import AdamW + from basicts.data import BasicTSForecastingDataset from basicts.runners.callback import BasicTSCallback from basicts.runners.optim.lr_schedulers import CosineWarmup from basicts.runners.taskflow import (BasicTSForecastingTaskFlow, BasicTSTaskFlow) -from torch.optim import AdamW from .base_config import BasicTSConfig from .model_config import BasicTSModelConfig @@ -100,10 +101,10 @@ class BasicTSFoundationModelConfig(BasicTSConfig): # Dataset settings dataset_type: type = field(default=BasicTSForecastingDataset, metadata={"help": "Dataset type."}) dataset_params: dict = field(default_factory=dict) - input_len: int = field(default=336, metadata={"help": "Input length."}) - output_len: int = field(default=336, metadata={"help": "Output length."}) - use_timestamps: bool = field(default=False, metadata={"help": "Whether to use timestamps as supplementary."}) - memmap: bool = field(default=False, metadata={"help": "Whether to use memmap to load datasets."}) + input_len: int = field(default=None, metadata={"help": "Input length."}) + output_len: int = field(default=None, metadata={"help": "Output length."}) + use_timestamps: bool = field(default=None, metadata={"help": "Whether to use timestamps as supplementary."}) + memmap: bool = field(default=None, metadata={"help": "Whether to use memmap to load datasets."}) batch_size: Optional[int] = field(default=None, metadata={"help": "Batch size. 
If setted, all dataloaders will be setted to the same batch size."}) null_val: float = field(default=np.nan, metadata={"help": "Null value."}) null_to_num: float = field(default=0.0, metadata={"help": "Null value to number."}) @@ -142,7 +143,7 @@ class BasicTSFoundationModelConfig(BasicTSConfig): # Optimizer optimizer: type = field(default=AdamW) optimizer_params: dict = field(default_factory=lambda: {"lr": 1e-3, "fused": True}) - lr: float = field(default=1e-3, metadata={"help": "Learning rate."}) + lr: float = field(default=None, metadata={"help": "Learning rate."}) # Learning rate scheduler lr_scheduler: type = field(default=CosineWarmup) diff --git a/src/basicts/configs/tsi_config.py b/src/basicts/configs/tsi_config.py index e7c96522..1d44a8e0 100644 --- a/src/basicts/configs/tsi_config.py +++ b/src/basicts/configs/tsi_config.py @@ -2,11 +2,12 @@ from typing import Callable, List, Literal, Tuple, Union import numpy as np +from torch.optim import Adam + from basicts.data import BasicTSImputationDataset from basicts.runners.callback import BasicTSCallback from basicts.runners.taskflow import BasicTSImputationTaskFlow, BasicTSTaskFlow from basicts.scaler import ZScoreScaler -from torch.optim import Adam from .base_config import BasicTSConfig from .model_config import BasicTSModelConfig @@ -98,17 +99,24 @@ class BasicTSImputationConfig(BasicTSConfig): # Dataset settings dataset_type: type = field(default=BasicTSImputationDataset, metadata={"help": "Dataset type."}) - dataset_params: Union[dict, None] = field(default=None, metadata={"help": "Dataset parameters."}) - input_len: int = field(default=336, metadata={"help": "Input length."}) + dataset_params: Union[dict, None] = field( + default_factory=lambda: { + "input_len": 336, + "use_timestamps": True, + "memmap": False, + }, metadata={"help": "Dataset parameters."}) + + # shortcuts + input_len: int = field(default=None, metadata={"help": "Input length."}) + use_timestamps: bool = field(default=None, metadata={"help": "Whether to use timestamps as supplementary."}) + memmap: bool = field(default=None, metadata={"help": "Whether to use memmap to load datasets."}) + batch_size: Union[int, None] = field( + default=None, metadata={"help": "Batch size. If setted, all dataloaders will be setted to the same batch size."}) + mask_ratio: float = field(default=0.25, metadata={"help": "Mask ratio."}) - use_timestamps: bool = field(default=True, metadata={"help": "Whether to use timestamps as supplementary."}) - memmap: bool = field(default=False, metadata={"help": "Whether to use memmap to load datasets."}) null_val: float = field(default=np.nan, metadata={"help": "Null value."}) null_to_num: float = field(default=0.0, metadata={"help": "Null value to number."}) - batch_size: Union[int, None] = field( - default=None, metadata={"help": "Batch size. 
If setted, all dataloaders will be setted to the same batch size."}) - # Scaler settings scaler: type = field(default=ZScoreScaler, metadata={"help": "Scaler type."}) norm_each_channel: bool = field(default=True, metadata={"help": "Whether to normalize data for each channel independently."}) @@ -146,7 +154,7 @@ class BasicTSImputationConfig(BasicTSConfig): optimizer_params: dict = field( default_factory=lambda: {"lr": 2e-4, "weight_decay": 5e-4}, metadata={"help": "Optimizer parameters."}) - lr: float = field(default=2e-4, metadata={"help": "Learning rate."}) + lr: float = field(default=None, metadata={"help": "Learning rate."}) # Learning rate scheduler lr_scheduler: Union[type, None] = field(default=None, metadata={"help": "Learning rate scheduler type."}) diff --git a/src/basicts/data/base_dataset.py b/src/basicts/data/base_dataset.py index b587178d..1d3e2271 100644 --- a/src/basicts/data/base_dataset.py +++ b/src/basicts/data/base_dataset.py @@ -1,9 +1,10 @@ from typing import Union import numpy as np -from basicts.utils.constants import BasicTSMode from torch.utils.data import Dataset +from basicts.utils.constants import BasicTSMode + class BasicTSDataset(Dataset): """ diff --git a/src/basicts/data/blast.py b/src/basicts/data/blast.py index e22429c0..b53901cb 100644 --- a/src/basicts/data/blast.py +++ b/src/basicts/data/blast.py @@ -3,6 +3,7 @@ from typing import Optional, Union import numpy as np + from basicts.utils.constants import BasicTSMode from .base_dataset import BasicTSDataset @@ -56,7 +57,7 @@ class BLAST(BasicTSDataset): def __post_init__(self): # load data - self.data = self._load_data() + self._data = self._load_data() self.output_len = self.output_len or 0 # minimum valid history sequence length @@ -244,6 +245,15 @@ def __getitem__(self, idx: int) -> tuple: def __len__(self): return self.data.shape[0] + def __getstate__(self): + state = self.__dict__.copy() + del state["_data"] + return state + + def __setstate__(self, state): + self.__dict__.update(state) + self._data = self._load_data() + @property def data(self) -> np.ndarray: return self._data diff --git a/src/basicts/data/tsf_dataset.py b/src/basicts/data/tsf_dataset.py index c088c0b8..619a550a 100644 --- a/src/basicts/data/tsf_dataset.py +++ b/src/basicts/data/tsf_dataset.py @@ -2,6 +2,7 @@ from typing import Union import numpy as np + from basicts.utils.constants import BasicTSMode from .base_dataset import BasicTSDataset diff --git a/src/basicts/data/tsi_dataset.py b/src/basicts/data/tsi_dataset.py index 2c1256ff..5b14ef12 100644 --- a/src/basicts/data/tsi_dataset.py +++ b/src/basicts/data/tsi_dataset.py @@ -2,6 +2,7 @@ from typing import Union import numpy as np + from basicts.utils.constants import BasicTSMode from .base_dataset import BasicTSDataset diff --git a/src/basicts/data/uea_dataset.py b/src/basicts/data/uea_dataset.py index 86191485..5a5e8358 100644 --- a/src/basicts/data/uea_dataset.py +++ b/src/basicts/data/uea_dataset.py @@ -2,6 +2,7 @@ from typing import Union import numpy as np + from basicts.utils import BasicTSMode from .base_dataset import BasicTSDataset diff --git a/src/basicts/launcher.py b/src/basicts/launcher.py index cbcb4883..e42a1143 100644 --- a/src/basicts/launcher.py +++ b/src/basicts/launcher.py @@ -1,12 +1,13 @@ import traceback from typing import Optional -from basicts.configs.base_config import BasicTSConfig -from basicts.runners import BasicTSRunner from easytorch.device import set_device_type from easytorch.launcher.dist_wrap import dist_wrap from easytorch.utils import 
get_logger, set_visible_devices +from basicts.configs.base_config import BasicTSConfig +from basicts.runners import BasicTSRunner + class BasicTSLauncher: @@ -46,19 +47,6 @@ def launch_training(cfg: BasicTSConfig, node_rank: int = 0) -> None: set_device_type("cpu") device_num = 0 - def training_func(cfg: BasicTSConfig): - # init runner - runner = BasicTSRunner(cfg) - # init logger (after making ckpt save dir) - runner.init_logger(logger_name="BasicTS-training", log_file_name="training_log") - # train - try: - runner.train() - except BaseException as e: - # log exception to file - runner.logger.error(traceback.format_exc()) - raise e - train_dist = dist_wrap( training_func, node_num=cfg.get("dist_node_num", 1), @@ -110,3 +98,16 @@ def launch_evaluation( # start the evaluation pipeline runner.eval(ckpt_path) + +def training_func(cfg: BasicTSConfig): + # init runner + runner = BasicTSRunner(cfg) + # init logger (after making ckpt save dir) + runner.init_logger(logger_name="BasicTS-training", log_file_name="training_log") + # train + try: + runner.train() + except BaseException as e: + # log exception to file + runner.logger.error(traceback.format_exc()) + raise e diff --git a/src/basicts/models/Autoformer/arch/autoformer_arch.py b/src/basicts/models/Autoformer/arch/autoformer_arch.py index 7e27fd00..1f297068 100644 --- a/src/basicts/models/Autoformer/arch/autoformer_arch.py +++ b/src/basicts/models/Autoformer/arch/autoformer_arch.py @@ -1,10 +1,13 @@ +from typing import Dict, Union + import torch +from torch import nn + from basicts.modules.decomposition import MovingAverageDecomposition from basicts.modules.embed import FeatureEmbedding from basicts.modules.mlps import MLPLayer from basicts.modules.norm import CenteredLayerNorm from basicts.modules.transformer import AutoCorrelation, Encoder -from torch import nn from ..config.autoformer_config import AutoformerConfig from .layers import (AutoformerDecoder, AutoformerDecoderLayer, @@ -81,9 +84,9 @@ def forward( self, inputs: torch.Tensor, targets: torch.Tensor, - inputs_timestamps: torch.Tensor, - targets_timestamps: torch.Tensor - ) -> torch.Tensor: + inputs_timestamps: torch.Tensor = None, + targets_timestamps: torch.Tensor = None + ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: """ Feed forward of Autoformer. @@ -95,6 +98,7 @@ def forward( Returns: Output data with shape: [batch_size, output_len, num_features] + Attention weights if output_attentions is True, otherwise None. 
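+            Note: the last ``label_len`` steps of ``inputs_timestamps`` are prepended
+            to ``targets_timestamps`` before the decoder embedding.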
""" # decomp init @@ -114,6 +118,7 @@ def forward( ) # decoder + targets_timestamps = torch.cat([inputs_timestamps[:, -self.label_len:, :], targets_timestamps], dim=1) dec_hidden_states = self.dec_embedding(seasonal, targets_timestamps) dec_output, trend, dec_self_attn_weights, dec_cross_attn_weights = self.decoder( diff --git a/src/basicts/models/Autoformer/arch/layers.py b/src/basicts/models/Autoformer/arch/layers.py index 17e51e4d..3df4f532 100644 --- a/src/basicts/models/Autoformer/arch/layers.py +++ b/src/basicts/models/Autoformer/arch/layers.py @@ -1,9 +1,10 @@ from typing import Any, Callable, Optional, Tuple import torch -from basicts.modules.transformer.utils import build_layer from torch import nn +from basicts.modules.transformer.utils import build_layer + from ..config.autoformer_config import AutoformerConfig diff --git a/src/basicts/models/Crossformer/arch/crossformer_arch.py b/src/basicts/models/Crossformer/arch/crossformer_arch.py index 2d8c3c12..b9be69b0 100644 --- a/src/basicts/models/Crossformer/arch/crossformer_arch.py +++ b/src/basicts/models/Crossformer/arch/crossformer_arch.py @@ -1,10 +1,11 @@ from math import ceil import torch -from basicts.modules.embed import PatchEmbedding from einops import rearrange from torch import nn +from basicts.modules.embed import PatchEmbedding + from ..config.crossformer_config import CrossformerConfig from .crossformer_layers import (CrossformerDecoder, CrossformerDecoderLayer, CrossformerEncoder, CrossformerEncoderLayer) diff --git a/src/basicts/models/Crossformer/arch/crossformer_layers.py b/src/basicts/models/Crossformer/arch/crossformer_layers.py index 6853254d..8d5d87ec 100644 --- a/src/basicts/models/Crossformer/arch/crossformer_layers.py +++ b/src/basicts/models/Crossformer/arch/crossformer_layers.py @@ -1,11 +1,12 @@ from typing import Iterable, Optional import torch -from basicts.modules.mlps import MLPLayer -from basicts.modules.transformer import MultiHeadAttention from einops import rearrange, repeat from torch import nn +from basicts.modules.mlps import MLPLayer +from basicts.modules.transformer import MultiHeadAttention + class PatchMergingLayer(nn.Module): """ diff --git a/src/basicts/models/DLinear/arch/dlinear_arch.py b/src/basicts/models/DLinear/arch/dlinear_arch.py index ac8292f2..ba3b8b66 100644 --- a/src/basicts/models/DLinear/arch/dlinear_arch.py +++ b/src/basicts/models/DLinear/arch/dlinear_arch.py @@ -1,7 +1,8 @@ import torch -from basicts.modules.decomposition import MovingAverageDecomposition from torch import nn +from basicts.modules.decomposition import MovingAverageDecomposition + from ..config.dlinear_config import DLinearConfig diff --git a/src/basicts/models/DUET/arch/duet_arch.py b/src/basicts/models/DUET/arch/duet_arch.py index 07f2130a..9aef59e4 100755 --- a/src/basicts/models/DUET/arch/duet_arch.py +++ b/src/basicts/models/DUET/arch/duet_arch.py @@ -1,12 +1,13 @@ from typing import Any, Dict, List import torch +from einops import rearrange +from torch import nn + from basicts.modules.mlps import MLPLayer from basicts.modules.transformer import (Encoder, EncoderLayer, MultiHeadAttention) from basicts.runners.callback import AddAuxiliaryLoss -from einops import rearrange -from torch import nn from ..config.duet_config import DUETConfig from .linear_extractor_cluster import LinearExtractorCluster diff --git a/src/basicts/models/DUET/arch/linear_extractor_cluster.py b/src/basicts/models/DUET/arch/linear_extractor_cluster.py index c7d0a608..ea5eda09 100755 --- 
a/src/basicts/models/DUET/arch/linear_extractor_cluster.py +++ b/src/basicts/models/DUET/arch/linear_extractor_cluster.py @@ -1,11 +1,12 @@ import torch -from basicts.models.DLinear import DLinear, DLinearConfig -from basicts.modules.mlps import MLPLayer -from basicts.modules.norm import RevIN from einops import rearrange from torch import nn from torch.distributions.normal import Normal +from basicts.models.DLinear import DLinear, DLinearConfig +from basicts.modules.mlps import MLPLayer +from basicts.modules.norm import RevIN + from ..config.duet_config import DUETConfig diff --git a/src/basicts/models/DUET/arch/mahalanobis_mask.py b/src/basicts/models/DUET/arch/mahalanobis_mask.py index a60cb8a6..1931b3c7 100755 --- a/src/basicts/models/DUET/arch/mahalanobis_mask.py +++ b/src/basicts/models/DUET/arch/mahalanobis_mask.py @@ -1,3 +1,4 @@ +# pylint: disable=not-callable import torch import torch.nn.functional as F from einops import rearrange diff --git a/src/basicts/models/FITS/arch/fits_arch.py b/src/basicts/models/FITS/arch/fits_arch.py index 35659151..1cbc7a3e 100644 --- a/src/basicts/models/FITS/arch/fits_arch.py +++ b/src/basicts/models/FITS/arch/fits_arch.py @@ -1,9 +1,11 @@ +# pylint: disable=not-callable from typing import Callable, Dict import torch +from torch import nn + from basicts.metrics import ALL_METRICS from basicts.modules.norm import RevIN -from torch import nn from ..config.fits_config import FITSConfig diff --git a/src/basicts/models/FiLM/__init__.py b/src/basicts/models/FiLM/__init__.py new file mode 100644 index 00000000..f56498d4 --- /dev/null +++ b/src/basicts/models/FiLM/__init__.py @@ -0,0 +1,2 @@ +from .arch import FiLM +from .config.film_config import FiLMConfig diff --git a/src/basicts/models/FiLM/arch/__init__.py b/src/basicts/models/FiLM/arch/__init__.py new file mode 100644 index 00000000..e297243d --- /dev/null +++ b/src/basicts/models/FiLM/arch/__init__.py @@ -0,0 +1 @@ +from .film_arch import FiLM \ No newline at end of file diff --git a/src/basicts/models/FiLM/arch/film_arch.py b/src/basicts/models/FiLM/arch/film_arch.py new file mode 100644 index 00000000..b382d10f --- /dev/null +++ b/src/basicts/models/FiLM/arch/film_arch.py @@ -0,0 +1,176 @@ +# pylint: disable=not-callable +import numpy as np +import torch +import torch.nn.functional as F +from scipy import signal, special +from torch import nn + +from basicts.modules.norm import RevIN + +from ..config.film_config import FiLMConfig + + +class HippoProj(nn.Module): + """ + Hippo projection layer. + """ + def __init__( + self, + order_hippo: int, + discretization_timestep: float = 1.0, + discretization_method: str = "bilinear"): + """ + order_hippo: the order of the HiPPO projection + discretization step size: It should be roughly inverse to the length of the sequence + """ + super().__init__() + self.order_hippo = order_hippo + Q = np.arange(order_hippo, dtype=np.float64) + R = (2 * Q + 1)[:, None] # / theta + j, i = np.meshgrid(Q, Q) + A = np.where(i < j, -1, (-1.) ** (i - j + 1)) * R + B = (-1.) 
** Q[:, None] * R + C = np.ones((1, self.order_hippo)) + D = np.zeros((1,)) + A, B, _, _, _ = signal.cont2discrete( + (A, B, C, D), dt = discretization_timestep, method = discretization_method) + B = B.squeeze(-1) + + self.register_buffer("A", torch.Tensor(A)) + self.register_buffer("B", torch.Tensor(B)) + vals = np.arange(0.0, 1.0, discretization_timestep) + self.register_buffer( + "eval_matrix", + torch.Tensor( + special.eval_legendre(np.arange(self.order_hippo)[:, None], 1 - 2 * vals).T) + ) + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + """ + inputs : (length, ...) + output : (length, ..., N) where N is the order of the HiPPO projection + """ + c = torch.zeros(inputs.shape[:-1] + tuple([self.order_hippo])).to(inputs.device) + cs = [] + for f in inputs.permute([-1, 0, 1]): + f = f.unsqueeze(-1) + new = f @ self.B.unsqueeze(0) + new = new.to(inputs.device) + c = F.linear(c, self.A.to(inputs.device)) + new + cs.append(c) + + return torch.stack(cs, dim=0) + + +class SpectralConv1d(nn.Module): + """ + 1D Fourier layer. It does FFT, linear transform, and Inverse FFT. + """ + def __init__(self, input_size: int, output_size: int, input_len: int): + super().__init__() + self.input_size = input_size + self.output_size = output_size + self.modes = min(32, input_len // 2) + self.index = list(range(0, self.modes)) + + self.scale = 1 / (self.input_size * self.output_size) + self.weights_real = nn.Parameter( + self.scale * torch.rand(self.input_size, self.output_size, len(self.index), dtype=torch.float)) + self.weights_imag = nn.Parameter( + self.scale * torch.rand(self.input_size, self.output_size, len(self.index), dtype=torch.float)) + + def compl_mul1d( + self, + order: str, + x: torch.Tensor, + weights_real: torch.Tensor, + weights_imag: torch.Tensor + ) -> torch.Tensor: + return torch.complex( + torch.einsum(order, x.real, weights_real) - torch.einsum(order, x.imag, weights_imag), + torch.einsum(order, x.real, weights_imag) + torch.einsum(order, x.imag, weights_real) + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, H, _, _ = x.shape + x_ft = torch.fft.rfft(x) + out_ft = torch.zeros( + B, H, self.output_size, x.size(-1) // 2 + 1, device=x.device, dtype=torch.cfloat) + x_ft = x_ft[:, :, :, :self.modes] + out_ft[:, :, :, :self.modes] = self.compl_mul1d( + "bjix,iox->bjox", x_ft, self.weights_real, self.weights_imag) + x = torch.fft.irfft(out_ft, n=x.size(-1)) + return x + + +class FiLM(nn.Module): + """ + Paper: FiLM: Frequency improved Legendre Memory Model for Long-term Time Series Forecasting + Official Code: https://github.com/tianzhou2011/FiLM + Link: https://arxiv.org/abs/2205.08897 + Venue: NeurIPS 2022 + Task: Long-term Time Series Forecasting + """ + def __init__(self, config: FiLMConfig): + super().__init__() + + self.input_len = config.input_len + self.output_len = config.output_len + self.use_revin = config.use_revin + # b, s, f means b, f + if self.use_revin: + self.revin = RevIN(config.num_features, affine=True) + + self.multiscale = config.multiscale + self.hidden_size = config.hidden_size + self.legts = nn.ModuleList([ + HippoProj( + order_hippo = hidden_size, + discretization_timestep = 1. 
/ self.output_len / i + ) for hidden_size in self.hidden_size for i in self.multiscale]) + self.spec_conv_1 = nn.ModuleList([ + SpectralConv1d( + input_size = hidden_size, + output_size = hidden_size, + input_len = min(self.output_len, self.input_len) + ) for hidden_size in self.hidden_size for _ in range(len(self.multiscale))]) + self.mlp = nn.Linear(len(self.multiscale) * len(self.hidden_size), 1) + + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + """Feed forward of FiLM. + + Args: + inputs (torch.Tensor): inputs data with shape [batch_size, input_len, num_features] + + Returns: + torch.Tensor: prediction with shape [batch_size, output_len, num_features] + """ + + # Normalization + if self.use_revin: + inputs = self.revin(inputs, "norm") + + # Backbone + prediction = [] + for i in range(0, len(self.multiscale) * len(self.hidden_size)): + x_in_len = self.multiscale[i % len(self.multiscale)] * self.output_len + x_in = inputs[:, -x_in_len:] + legt = self.legts[i] + x_in = legt(x_in.transpose(1, 2)).permute([1, 2, 3, 0]) + x_out = self.spec_conv_1[i](x_in) + if self.input_len >= self.output_len: + x_out = x_out.transpose(2, 3)[:, :, self.output_len - 1, :] + else: + x_out = x_out.transpose(2, 3)[:, :, -1, :] + x_out = x_out @ legt.eval_matrix[-self.output_len:, :].T + prediction.append(x_out) + + prediction = torch.stack(prediction, dim=-1) + prediction = self.mlp(prediction).squeeze(-1).permute(0, 2, 1) + + # De-Normalization + if self.use_revin: + prediction = self.revin(prediction, "denorm") + + return prediction diff --git a/src/basicts/models/FiLM/config/film_config.py b/src/basicts/models/FiLM/config/film_config.py new file mode 100644 index 00000000..40508f2f --- /dev/null +++ b/src/basicts/models/FiLM/config/film_config.py @@ -0,0 +1,19 @@ +from dataclasses import dataclass, field + +from basicts.configs import BasicTSModelConfig + + +@dataclass +class FiLMConfig(BasicTSModelConfig): + + """ + Config class for FiLM model. 
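+    Fields that default to ``None`` (``input_len``, ``output_len``, ``num_features``)
+    are placeholders and must be set when the config is instantiated.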
+ """ + + input_len: int = field(default=None, metadata={"help": "Input sequence length."}) + output_len: int = field(default=None, metadata={"help": "Output sequence length."}) + num_features: int = field(default=None, metadata={"help": "Number of features."}) + hidden_size: list = field(default_factory = lambda: [256], metadata={"help": "Hidden size."}) + multiscale: list = field(default_factory = lambda: [1, 2, 4], metadata={"help": "Different scales for input length."}) + use_revin: bool = field(default=True, metadata={"help": "Whether to use RevIN."}) + diff --git a/src/basicts/models/FreTS/arch/frets_arch.py b/src/basicts/models/FreTS/arch/frets_arch.py index 26d35370..9cbb932e 100644 --- a/src/basicts/models/FreTS/arch/frets_arch.py +++ b/src/basicts/models/FreTS/arch/frets_arch.py @@ -1,8 +1,10 @@ +# pylint: disable=not-callable import torch -from basicts.modules import MLPLayer from torch import nn from torch.nn import functional as F +from basicts.modules import MLPLayer + from ..config.frets_config import FreTSConfig diff --git a/src/basicts/models/HI/arch/hi_arch.py b/src/basicts/models/HI/arch/hi_arch.py index d3b6be76..f179badc 100644 --- a/src/basicts/models/HI/arch/hi_arch.py +++ b/src/basicts/models/HI/arch/hi_arch.py @@ -3,9 +3,10 @@ from typing import List import torch -from basicts.runners.callback import NoBP from torch import nn +from basicts.runners.callback import NoBP + from ..config.hi_config import HIConfig diff --git a/src/basicts/models/Informer/arch/conv.py b/src/basicts/models/Informer/arch/conv.py index 6a507b52..218532aa 100644 --- a/src/basicts/models/Informer/arch/conv.py +++ b/src/basicts/models/Informer/arch/conv.py @@ -1,7 +1,8 @@ import torch -from basicts.modules.activations import ACT2FN from torch import nn +from basicts.modules.activations import ACT2FN + class ConvLayer(nn.Module): diff --git a/src/basicts/models/Informer/arch/informer_arch.py b/src/basicts/models/Informer/arch/informer_arch.py index f2db51e8..612bb7c4 100644 --- a/src/basicts/models/Informer/arch/informer_arch.py +++ b/src/basicts/models/Informer/arch/informer_arch.py @@ -1,13 +1,14 @@ from typing import Optional import torch +from torch import nn + from basicts.modules.embed import FeatureEmbedding from basicts.modules.mlps import MLPLayer from basicts.modules.transformer import (EncoderLayer, MultiHeadAttention, ProbAttention, Seq2SeqDecoder, Seq2SeqDecoderLayer, prepare_causal_attention_mask) -from torch import nn from ..config.informer_config import InformerConfig from .conv import ConvLayer diff --git a/src/basicts/models/Koopa/__init__.py b/src/basicts/models/Koopa/__init__.py new file mode 100644 index 00000000..a028c781 --- /dev/null +++ b/src/basicts/models/Koopa/__init__.py @@ -0,0 +1,2 @@ +from .arch import Koopa +from .config.koopa_config import KoopaConfig \ No newline at end of file diff --git a/src/basicts/models/Koopa/arch/__init__.py b/src/basicts/models/Koopa/arch/__init__.py new file mode 100644 index 00000000..099dc9d7 --- /dev/null +++ b/src/basicts/models/Koopa/arch/__init__.py @@ -0,0 +1 @@ +from .koopa_arch import Koopa \ No newline at end of file diff --git a/src/basicts/models/Koopa/arch/koopa_arch.py b/src/basicts/models/Koopa/arch/koopa_arch.py new file mode 100644 index 00000000..6a85b1e1 --- /dev/null +++ b/src/basicts/models/Koopa/arch/koopa_arch.py @@ -0,0 +1,104 @@ +import torch +from torch import nn + +from ..callback.koopa_mask_init import KoopaMaskInitCallback +from ..config.koopa_config import KoopaConfig +from .layers import MLP, 
FourierFilter, TimeInvKP, TimeVarKP + + +class Koopa(nn.Module): + """ + Paper: Koopa: Learning Non-stationary Time Series Dynamics with Koopman Predictors + Official Code: https://github.com/thuml/Koopa + Link: https://arxiv.org/abs/2305.18803 + Venue: NeurIPS 2024 + Task: Long-term Time Series Forecasting + """ + + _required_callbacks: list[type] = [KoopaMaskInitCallback] + + def __init__(self, config: KoopaConfig): + super().__init__() + self.mask_spectrum = None + self.amps = None + self.alpha = config.alpha + self.enc_in = config.enc_in + self.input_len = config.input_len + self.output_len = config.output_len + self.seg_len = config.seg_len + self.num_blocks = config.num_blocks + self.dynamic_dim = config.dynamic_dim + self.hidden_dim = config.hidden_dim + self.hidden_layers = config.hidden_layers + self.multistep = config.multistep + self.disentanglement = FourierFilter(self.mask_spectrum) + # shared encoder/decoder to make koopman embedding consistent + self.time_inv_encoder = MLP(f_in=self.input_len, f_out=self.dynamic_dim, activation='relu', + hidden_dim=self.hidden_dim, hidden_layers=self.hidden_layers) + # fix: use self.output_len instead of non-existent attribute + self.time_inv_decoder = MLP(f_in=self.dynamic_dim, f_out=self.output_len, activation='relu', + hidden_dim=self.hidden_dim, hidden_layers=self.hidden_layers) + # separate module lists for time-invariant and time-variant KPs + self.time_inv_kps = nn.ModuleList([ + TimeInvKP(input_len=self.input_len, + pred_len=self.output_len, + dynamic_dim=self.dynamic_dim, + encoder=self.time_inv_encoder, + decoder=self.time_inv_decoder) + for _ in range(self.num_blocks)]) + + # shared encoder/decoder to make koopman embedding consistent + self.time_var_encoder = MLP(f_in=self.seg_len * self.enc_in, f_out=self.dynamic_dim, activation='tanh', + hidden_dim=self.hidden_dim, hidden_layers=self.hidden_layers) + self.time_var_decoder = MLP(f_in=self.dynamic_dim, f_out=self.seg_len * self.enc_in, activation='tanh', + hidden_dim=self.hidden_dim, hidden_layers=self.hidden_layers) + self.time_var_kps = nn.ModuleList([ + TimeVarKP(enc_in=self.enc_in, + input_len=self.input_len, + pred_len=self.output_len, + seg_len=self.seg_len, + dynamic_dim=self.dynamic_dim, + encoder=self.time_var_encoder, + decoder=self.time_var_decoder, + multistep=self.multistep) + for _ in range(self.num_blocks)]) + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + """ + Single-`inputs` forward to match runner API. 
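+        The input is instance-normalized, split by the Fourier filter into
+        time-variant and time-invariant components, passed block by block through
+        the Koopman predictors, and the accumulated forecast is de-normalized.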
+ + Args: + inputs (torch.Tensor): history input with shape [B, L, C] or [B, L, C, 1] + + Returns: + torch.Tensor: prediction tensor with shape [B, output_len, num_features] (may include trailing feature dim) + """ + history_data = inputs + + if history_data.dim() == 4: + x_enc = history_data[..., 0] + elif history_data.dim() == 3: + x_enc = history_data + else: + raise ValueError(f'Unsupported inputs shape: {tuple(history_data.shape)}') + + mean_enc = x_enc.mean(1, keepdim=True).detach() + x_enc = x_enc - mean_enc + std_enc = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5).detach() + x_enc = x_enc / std_enc + if self.disentanglement is None: + raise ValueError('Koopa mask_spectrum is not initialized.') + + residual, forecast = x_enc, None + for i in range(self.num_blocks): + time_var_input, time_inv_input = self.disentanglement(residual) + time_inv_output = self.time_inv_kps[i](time_inv_input) + time_var_backcast, time_var_output = self.time_var_kps[i](time_var_input) + residual = residual - time_var_backcast + if forecast is None: + forecast = time_inv_output + time_var_output + else: + forecast += (time_inv_output + time_var_output) + res = forecast * std_enc + mean_enc + if history_data is not None and history_data.dim() == 4 and res.dim() == 3: + res = res.unsqueeze(-1) + return res diff --git a/src/basicts/models/Koopa/arch/layers.py b/src/basicts/models/Koopa/arch/layers.py new file mode 100644 index 00000000..e0d90245 --- /dev/null +++ b/src/basicts/models/Koopa/arch/layers.py @@ -0,0 +1,238 @@ +# pylint: disable=not-callable +import math + +import torch +from torch import nn + + +class FourierFilter(nn.Module): + """ + Fourier Filter: to time-variant and time-invariant term + """ + def __init__(self, mask_spectrum): + super().__init__() + self.mask_spectrum = mask_spectrum + + def forward(self, x): + xf = torch.fft.rfft(x, dim=1) + mask = torch.ones_like(xf) + mask[:, self.mask_spectrum, :] = 0 + x_var = torch.fft.irfft(xf * mask, dim=1) + x_inv = x - x_var + + return x_var, x_inv + + +class MLP(nn.Module): + ''' + Multilayer perceptron to encode/decode high dimension representation of sequential data + ''' + + def __init__(self, + f_in, + f_out, + hidden_dim=128, + hidden_layers=2, + dropout=0.05, + activation='tanh'): + super().__init__() + self.f_in = f_in + self.f_out = f_out + self.hidden_dim = hidden_dim + self.hidden_layers = hidden_layers + self.dropout = dropout + if activation == 'relu': + self.activation = nn.ReLU() + elif activation == 'tanh': + self.activation = nn.Tanh() + else: + raise NotImplementedError + + layers = [nn.Linear(self.f_in, self.hidden_dim), + self.activation, nn.Dropout(self.dropout)] + for _ in range(self.hidden_layers - 2): + layers += [nn.Linear(self.hidden_dim, self.hidden_dim), + self.activation, nn.Dropout(dropout)] + + layers += [nn.Linear(hidden_dim, f_out)] + self.layers = nn.Sequential(*layers) + + def forward(self, x): + # x: B x S x f_in + # y: B x S x f_out + y = self.layers(x) + return y + + +class KPLayer(nn.Module): + """ + A demonstration of finding one step transition of linear system by DMD iteratively + """ + + def __init__(self): + super().__init__() + + self.K = None # B E E + + def one_step_forward(self, z, return_rec=False): + B, input_len, _ = z.shape + assert input_len > 1, 'snapshots number should be larger than 1' + x, y = z[:, :-1], z[:, 1:] + + # solve linear system + self.K = torch.linalg.lstsq(x, y).solution # B E E + if torch.isnan(self.K).any(): + print('Encounter K with nan, replace K by 
identity matrix') + self.K = torch.eye(self.K.shape[1]).to(self.K.device).unsqueeze(0).repeat(B, 1, 1) + + z_pred = torch.bmm(z[:, -1:], self.K) + if return_rec: + z_rec = torch.cat((z[:, :1], torch.bmm(x, self.K)), dim=1) + return z_rec, z_pred + + return z_pred + + def forward(self, z, pred_len=1): + assert pred_len >= 1, 'prediction length should not be less than 1' + z_rec, z_pred = self.one_step_forward(z, return_rec=True) + z_preds = [z_pred] + for _ in range(1, pred_len): + z_pred = torch.bmm(z_pred, self.K) + z_preds.append(z_pred) + z_preds = torch.cat(z_preds, dim=1) + return z_rec, z_preds + + +class KPLayerApprox(nn.Module): + """ + Find koopman transition of linear system by DMD with multistep K approximation + """ + + def __init__(self): + super().__init__() + + self.K = None # B E E + self.K_step = None # B E E + + def forward(self, z, pred_len=1): + # z: B L E, koopman invariance space representation + # z_rec: B L E, reconstructed representation + # z_pred: B S E, forecasting representation + B, input_len, _ = z.shape + assert input_len > 1, 'snapshots number should be larger than 1' + x, y = z[:, :-1], z[:, 1:] + + # solve linear system + self.K = torch.linalg.lstsq(x, y).solution # B E E + + if torch.isnan(self.K).any(): + print('Encounter K with nan, replace K by identity matrix') + self.K = torch.eye(self.K.shape[1]).to(self.K.device).unsqueeze(0).repeat(B, 1, 1) + + z_rec = torch.cat((z[:, :1], torch.bmm(x, self.K)), dim=1) # B L E + + if pred_len <= input_len: + self.K_step = torch.linalg.matrix_power(self.K, pred_len) + if torch.isnan(self.K_step).any(): + print('Encounter multistep K with nan, replace it by identity matrix') + self.K_step = torch.eye(self.K_step.shape[1]).to(self.K_step.device).unsqueeze(0).repeat(B, 1, 1) + z_pred = torch.bmm(z[:, -pred_len:, :], self.K_step) + else: + self.K_step = torch.linalg.matrix_power(self.K, input_len) + if torch.isnan(self.K_step).any(): + print('Encounter multistep K with nan, replace it by identity matrix') + self.K_step = torch.eye(self.K_step.shape[1]).to(self.K_step.device).unsqueeze(0).repeat(B, 1, 1) + temp_z_pred, all_pred = z, [] + for _ in range(math.ceil(pred_len / input_len)): + temp_z_pred = torch.bmm(temp_z_pred, self.K_step) + all_pred.append(temp_z_pred) + z_pred = torch.cat(all_pred, dim=1)[:, :pred_len, :] + + return z_rec, z_pred + + +class TimeVarKP(nn.Module): + """ + Koopman Predictor with DMD (analysitical solution of Koopman operator) + Utilize local variations within individual sliding window to predict the future of time-variant term + """ + + def __init__(self, + enc_in=8, + input_len=96, + pred_len=96, + seg_len=24, + dynamic_dim=128, + encoder=None, + decoder=None, + multistep=False, + ): + super().__init__() + self.input_len = input_len + self.pred_len = pred_len + self.enc_in = enc_in + self.seg_len = seg_len + self.dynamic_dim = dynamic_dim + self.multistep = multistep + self.encoder, self.decoder = encoder, decoder + self.freq = math.ceil(self.input_len / self.seg_len) # segment number of input + self.step = math.ceil(self.pred_len / self.seg_len) # segment number of output + self.padding_len = self.seg_len * self.freq - self.input_len + # Approximate mulitstep K by KPLayerApprox when pred_len is large + self.dynamics = KPLayerApprox() if self.multistep else KPLayer() + + def forward(self, x): + B, L, _ = x.shape + + res = torch.cat((x[:, L - self.padding_len:, :], x), dim=1) + + res = res.chunk(self.freq, dim=1) # F x B P C, P means seg_len + res = torch.stack(res, dim=1).reshape(B, 
self.freq, -1) # B F PC + + res = self.encoder(res) # B F H + x_rec, x_pred = self.dynamics(res, self.step) # B F H, B S H + + x_rec = self.decoder(x_rec) # B F PC + x_rec = x_rec.reshape(B, self.freq, self.seg_len, self.enc_in) + x_rec = x_rec.reshape(B, -1, self.enc_in)[:, :self.input_len, :] # B L C + + x_pred = self.decoder(x_pred) # B S PC + x_pred = x_pred.reshape(B, self.step, self.seg_len, self.enc_in) + x_pred = x_pred.reshape(B, -1, self.enc_in)[:, :self.pred_len, :] # B S C + + return x_rec, x_pred + + +class TimeInvKP(nn.Module): + """ + Koopman Predictor with learnable Koopman operator + Utilize lookback and forecast window snapshots to predict the future of time-invariant term + """ + + def __init__(self, + input_len=96, + pred_len=96, + dynamic_dim=128, + encoder=None, + decoder=None): + super().__init__() + self.dynamic_dim = dynamic_dim + self.input_len = input_len + self.pred_len = pred_len + self.encoder = encoder + self.decoder = decoder + + K_init = torch.randn(self.dynamic_dim, self.dynamic_dim) + U, _, V = torch.svd(K_init) # stable initialization + self.K = nn.Linear(self.dynamic_dim, self.dynamic_dim, bias=False) + self.K.weight.data = torch.mm(U, V.t()) + + def forward(self, x): + # x: B L C + res = x.transpose(1, 2) # B C L + res = self.encoder(res) # B C H + res = self.K(res) # B C H + res = self.decoder(res) # B C S + res = res.transpose(1, 2) # B S C + + return res diff --git a/src/basicts/models/Koopa/callback/koopa_mask_init.py b/src/basicts/models/Koopa/callback/koopa_mask_init.py new file mode 100644 index 00000000..1da54dff --- /dev/null +++ b/src/basicts/models/Koopa/callback/koopa_mask_init.py @@ -0,0 +1,57 @@ +# pylint: disable=not-callable +import torch +from easytorch.utils import get_logger + +from basicts.models.Koopa.arch.layers import FourierFilter +from basicts.runners.callback import BasicTSCallback + +logger = get_logger("BasicTS-training") + + +class KoopaMaskInitCallback(BasicTSCallback): + + """Callback for initializing Koopa mask during training. + + Changes made: + - Robust handling when training loader is empty. + - Ensure k >= 1 and k <= number of frequencies. + - Move mask indices and amps to model device. + - Update any existing FourierFilter module instances inside the model. + - Defensive typing of indices to torch.long. 
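+    The callback averages the rFFT amplitude spectrum over the training set and
+    selects the top ``alpha`` fraction of frequencies (by mean amplitude) as the
+    model's ``mask_spectrum``.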
+ """ + + def __init__(self, alpha: float = 0.2): + super().__init__() + self.alpha = alpha + + @torch.no_grad() + def on_train_begin(self, runner): + model = runner.model + loader = runner.train_data_loader + device = next(model.parameters()).device + + amps_sum = 0 + count = 0 + + for batch in loader: + x = batch["inputs"] + x = x.squeeze(-1) if x.dim() == 4 else x # (B, L, C) + r = torch.fft.rfft(x, dim=1) + amp = torch.abs(r).mean((0, 1)) # (F,) + amps_sum += amp + count += 1 + + if count == 0: + logger.info("No training data found, skipping mask init.") + return + + amps = amps_sum / count # (F,) + F = amps.numel() + + k = max(1, int(F * self.alpha)) + idx = torch.topk(amps, k).indices.to(device) + + model.mask_spectrum = idx + for m in model.modules(): + if isinstance(m, FourierFilter): + m.mask_spectrum = idx diff --git a/src/basicts/models/Koopa/config/koopa_config.py b/src/basicts/models/Koopa/config/koopa_config.py new file mode 100644 index 00000000..7990ea5c --- /dev/null +++ b/src/basicts/models/Koopa/config/koopa_config.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass, field + +from basicts.configs import BasicTSModelConfig + + +@dataclass +class KoopaConfig(BasicTSModelConfig): + """ + Config class for Koopa model. + """ + alpha: float = field(default=0.2, metadata={"help": "Scaling coefficient."}) + enc_in: int = field(default=7, metadata={"help": "Input feature dimension."}) + input_len: int = field(default=None, metadata={"help": "Input sequence length."}) + output_len: int = field(default=None, metadata={"help": "Prediction length."}) + seg_len: int = field(default=48, metadata={"help": "Segment length. Recommended: e.g., 24 for hourly data."}) + num_blocks: int = field(default=3, metadata={"help": "Number of blocks."}) + dynamic_dim: int = field(default=64, metadata={"help": "Dynamic feature dimension. 
Must be > 0."}) + hidden_dim: int = field(default=64, metadata={"help": "Hidden dimension."}) + hidden_layers: int = field(default=2, metadata={"help": "Number of hidden layers (>=2 recommended)."}) + multistep: bool = field(default=False, metadata={"help": "Whether to use multistep forecasting."}) diff --git a/src/basicts/models/Leddam/arch/leddam_arch.py b/src/basicts/models/Leddam/arch/leddam_arch.py index fdf6e400..24d34963 100644 --- a/src/basicts/models/Leddam/arch/leddam_arch.py +++ b/src/basicts/models/Leddam/arch/leddam_arch.py @@ -1,9 +1,10 @@ import torch +from torch import nn + from basicts.modules import MLPLayer from basicts.modules.embed import PositionEmbedding, SequenceEmbedding from basicts.modules.norm import RevIN from basicts.modules.transformer import Encoder, MultiHeadAttention -from torch import nn from ..config.leddam_config import LeddamConfig from .leddam_layers import (AutoAttention, LearnableDecomposition, diff --git a/src/basicts/models/Leddam/arch/leddam_layers.py b/src/basicts/models/Leddam/arch/leddam_layers.py index 039a7d44..a8a5663b 100644 --- a/src/basicts/models/Leddam/arch/leddam_layers.py +++ b/src/basicts/models/Leddam/arch/leddam_layers.py @@ -3,9 +3,10 @@ import torch import torch.nn.functional as F -from basicts.modules.transformer import EncoderLayer from torch import nn +from basicts.modules.transformer import EncoderLayer + class LeddamEncoderLayer(EncoderLayer): """ diff --git a/src/basicts/models/LightTS/arch/lightts_arch.py b/src/basicts/models/LightTS/arch/lightts_arch.py index bed03486..7d54f064 100644 --- a/src/basicts/models/LightTS/arch/lightts_arch.py +++ b/src/basicts/models/LightTS/arch/lightts_arch.py @@ -1,7 +1,8 @@ import torch -from basicts.modules import MLPLayer from torch import nn +from basicts.modules import MLPLayer + from ..config.lightts_config import LightTSConfig diff --git a/src/basicts/models/MTSMixer/arch/mtsmixer_arch.py b/src/basicts/models/MTSMixer/arch/mtsmixer_arch.py index fadbf6b0..cd64b883 100644 --- a/src/basicts/models/MTSMixer/arch/mtsmixer_arch.py +++ b/src/basicts/models/MTSMixer/arch/mtsmixer_arch.py @@ -1,7 +1,8 @@ import torch -from basicts.modules.norm import RevIN from torch import nn +from basicts.modules.norm import RevIN + from ..config.mtsmixer_config import MTSMixerConfig from .mtsmixer_layers import ChannelProjection, MixerLayer diff --git a/src/basicts/models/MTSMixer/arch/mtsmixer_layers.py b/src/basicts/models/MTSMixer/arch/mtsmixer_layers.py index 1c06dc02..5394c836 100644 --- a/src/basicts/models/MTSMixer/arch/mtsmixer_layers.py +++ b/src/basicts/models/MTSMixer/arch/mtsmixer_layers.py @@ -1,7 +1,8 @@ import torch -from basicts.modules import MLPLayer from torch import nn +from basicts.modules import MLPLayer + from ..config.mtsmixer_config import MTSMixerConfig diff --git a/src/basicts/models/NLinear/arch/nlinear_arch.py b/src/basicts/models/NLinear/arch/nlinear_arch.py index 1ab8b110..f689d286 100644 --- a/src/basicts/models/NLinear/arch/nlinear_arch.py +++ b/src/basicts/models/NLinear/arch/nlinear_arch.py @@ -1,7 +1,8 @@ import torch -from basicts.models.NLinear.config.nlinear_config import NLinearConfig from torch import nn +from basicts.models.NLinear.config.nlinear_config import NLinearConfig + class NLinear(nn.Module): """ diff --git a/src/basicts/models/NonstationaryTransformer/arch/ns_transformer_arch.py b/src/basicts/models/NonstationaryTransformer/arch/ns_transformer_arch.py index 15438da9..4522fc3d 100644 --- 
a/src/basicts/models/NonstationaryTransformer/arch/ns_transformer_arch.py +++ b/src/basicts/models/NonstationaryTransformer/arch/ns_transformer_arch.py @@ -1,11 +1,12 @@ from typing import List, Optional, Tuple import torch +from torch import nn + from basicts.modules.activations import ACT2FN from basicts.modules.embed import FeatureEmbedding from basicts.modules.mlps import MLPLayer from basicts.modules.transformer import Encoder, Seq2SeqDecoder -from torch import nn from ..config.ns_transformer_config import NonstationaryTransformerConfig from .ns_transformer_layers import (DSAttention, @@ -91,6 +92,7 @@ def forward( torch.Tensor: outputs with shape [batch_size, output_len, num_features] """ + inputs_raw = inputs.clone().detach() # Normalization if inputs_mask is None: inputs_mask = torch.ones_like(inputs) @@ -101,10 +103,10 @@ def forward( (inputs ** 2).sum(dim=1, keepdim=True) / valid_count + 1e-5) inputs /= self.std - tau = self.tau_learner(inputs, self.std) + tau = self.tau_learner(inputs_raw, self.std) tau_clamped = torch.clamp(tau, max=self.threshold) # avoid numerical overflow tau = tau_clamped.exp() - delta = self.delta_learner(inputs, self.mean) + delta = self.delta_learner(inputs_raw, self.mean) hidden_states = self.enc_embedding(inputs, inputs_timestamps) hidden_states, attns = self.encoder(hidden_states, tau=tau, delta=delta) return hidden_states, attns, tau, delta @@ -179,6 +181,7 @@ def forward( enc_hidden_states, enc_attn_weights, tau, delta = self.backbone(inputs, inputs_timestamps) dec_hidden_states = torch.cat([inputs[:, -self.label_len:, :], torch.zeros_like(targets)], dim=1) + targets_timestamps = torch.cat([inputs_timestamps[:, -self.label_len:, :], targets_timestamps], dim=1) dec_hidden_states = self.dec_embedding(dec_hidden_states, targets_timestamps) dec_hidden_states, dec_self_attn_weights, dec_cross_attn_weights = self.decoder( dec_hidden_states, enc_hidden_states, tau=tau, delta=delta) diff --git a/src/basicts/models/NonstationaryTransformer/arch/ns_transformer_layers.py b/src/basicts/models/NonstationaryTransformer/arch/ns_transformer_layers.py index c4a602ac..5b08aa09 100644 --- a/src/basicts/models/NonstationaryTransformer/arch/ns_transformer_layers.py +++ b/src/basicts/models/NonstationaryTransformer/arch/ns_transformer_layers.py @@ -1,9 +1,10 @@ from typing import Optional, Tuple import torch -from basicts.modules.transformer import EncoderLayer, Seq2SeqDecoderLayer from torch import nn +from basicts.modules.transformer import EncoderLayer, Seq2SeqDecoderLayer + class DSAttention(nn.Module): @@ -58,8 +59,9 @@ def forward( # Key/Value if is_cross: # cross-attn (typically does not use rope) - key = self._shape(self.k_proj(key_value_states), L) - value = self._shape(self.v_proj(key_value_states), L) + kv_len = key_value_states.size(1) + key = self._shape(self.k_proj(key_value_states), kv_len) + value = self._shape(self.v_proj(key_value_states), kv_len) else: # self-attn # compute key/value from hidden_states key = self._shape(self.k_proj(hidden_states), L) @@ -226,7 +228,7 @@ def __init__( super().__init__() self.series_conv = nn.Conv1d( - input_size, 1, kernel_size, padding=2, padding_mode="circular", bias=False) + input_size, 1, kernel_size, padding=1, padding_mode="circular", bias=False) layers = [nn.Linear(2 * num_features, hidden_size), nn.ReLU()] for _ in range(num_layers - 1): diff --git a/src/basicts/models/PatchTST/arch/patchtst_arch.py b/src/basicts/models/PatchTST/arch/patchtst_arch.py index a218b4b6..3fe7458a 100644 --- 
a/src/basicts/models/PatchTST/arch/patchtst_arch.py +++ b/src/basicts/models/PatchTST/arch/patchtst_arch.py @@ -2,13 +2,14 @@ from typing import List, Optional, Tuple import torch +from torch import nn + from basicts.modules.decomposition import MovingAverageDecomposition from basicts.modules.embed import PatchEmbedding from basicts.modules.mlps import MLPLayer from basicts.modules.norm import RevIN from basicts.modules.transformer import (Encoder, EncoderLayer, MultiHeadAttention) -from torch import nn from ..config.patchtst_config import PatchTSTConfig from .patchtst_layers import PatchTSTBatchNorm, PatchTSTHead @@ -147,10 +148,12 @@ def __init__(self, config: PatchTSTConfig): self.num_classes = config.num_classes self.backbone = PatchTSTBackbone(config) self.flatten = nn.Flatten(start_dim=1) - self.forecasting_head = PatchTSTHead( - config.num_features, self.backbone.num_patches * config.hidden_size, + self.classification_head = PatchTSTHead( + self.backbone.num_patches * config.hidden_size * config.num_features, config.num_classes, - dropout=config.head_dropout) + config.individual_head, + config.num_features, + config.head_dropout) self.use_revin = config.use_revin if self.use_revin: self.revin = RevIN( @@ -173,11 +176,9 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor: inputs = self.revin(inputs, "norm") # [batch_size, num_features, num_patches, hidden_size] hidden_states, attn_weights = self.backbone(inputs) - hidden_states = self.flatten(hidden_states) # [batch_size, num_features, num_patches * hidden_size] - # [batch_size, output_len, num_features] - prediction = self.forecasting_head(hidden_states).transpose(1, 2) - if self.use_revin: - prediction = self.revin(prediction, "denorm") + hidden_states = self.flatten(hidden_states) # [batch_size, num_features * num_patches * hidden_size] + # [batch_size, num_classes] + prediction = self.classification_head(hidden_states) if self.output_attentions: return {"prediction": prediction, "attn_weights": attn_weights} else: diff --git a/src/basicts/models/SOFTS/arch/softs_arch.py b/src/basicts/models/SOFTS/arch/softs_arch.py index 8c9a2a8e..24dd3394 100644 --- a/src/basicts/models/SOFTS/arch/softs_arch.py +++ b/src/basicts/models/SOFTS/arch/softs_arch.py @@ -1,9 +1,10 @@ import torch +from torch import nn + from basicts.modules.embed import SequenceEmbedding from basicts.modules.mlps import MLPLayer from basicts.modules.norm import RevIN from basicts.modules.transformer import Encoder, EncoderLayer -from torch import nn from ..config.softs_config import SOFTSConfig from .star import STAR @@ -44,7 +45,7 @@ def __init__(self, config: SOFTSConfig): self.use_revin = config.use_revin if self.use_revin: - self.revin = RevIN() + self.revin = RevIN(affine=False) def forward(self, inputs: torch.Tensor, inputs_timestamps: torch.Tensor) -> torch.Tensor: """ diff --git a/src/basicts/models/SOFTS/arch/star.py b/src/basicts/models/SOFTS/arch/star.py index 932cde60..8cc8eeed 100644 --- a/src/basicts/models/SOFTS/arch/star.py +++ b/src/basicts/models/SOFTS/arch/star.py @@ -1,8 +1,9 @@ import torch import torch.nn.functional as F +from torch import nn + from basicts.modules.activations import ACT2FN from basicts.modules.mlps import MLPLayer -from torch import nn class STAR(nn.Module): diff --git a/src/basicts/models/STID/arch/stid_arch.py b/src/basicts/models/STID/arch/stid_arch.py index 2627a6a0..234e08ca 100644 --- a/src/basicts/models/STID/arch/stid_arch.py +++ b/src/basicts/models/STID/arch/stid_arch.py @@ -1,7 +1,8 @@ import torch -from 
basicts.modules import ResMLPLayer from torch import nn +from basicts.modules import ResMLPLayer + from ..config.stid_config import STIDConfig diff --git a/src/basicts/models/STID/config/stid_config.py b/src/basicts/models/STID/config/stid_config.py index c65e8b6e..f14c9b29 100644 --- a/src/basicts/models/STID/config/stid_config.py +++ b/src/basicts/models/STID/config/stid_config.py @@ -1,4 +1,5 @@ from dataclasses import dataclass, field +from typing import Optional from basicts.configs import BasicTSModelConfig @@ -14,7 +15,7 @@ class STIDConfig(BasicTSModelConfig): output_len: int = field(default=None, metadata={"help": "Output sequence length."}) num_features: int = field(default=None, metadata={"help": "Number of features."}) input_hidden_size: int = field(default=32, metadata={"help": "Hidden size of the imput embedding."}) - intermediate_size: int | None = field(default=None, metadata={"help": "Intermediate size of MLP layers. " \ + intermediate_size: Optional[int] = field(default=None, metadata={"help": "Intermediate size of MLP layers. " \ "If None, use hidden_size in STID."}) hidden_act: str = field(default="relu", metadata={"help": "Activation function of MLP layers."}) num_layers: int = field(default=1, metadata={"help": "Number of MLP layers."}) diff --git a/src/basicts/models/SegRNN/arch/segrnn_arch.py b/src/basicts/models/SegRNN/arch/segrnn_arch.py index 177e9c0f..f6374afc 100644 --- a/src/basicts/models/SegRNN/arch/segrnn_arch.py +++ b/src/basicts/models/SegRNN/arch/segrnn_arch.py @@ -1,9 +1,10 @@ from math import ceil import torch +from torch import nn + from basicts.modules.embed import PatchEmbedding from basicts.modules.norm import RevIN -from torch import nn from ..config.segrnn_config import SegRNNConfig diff --git a/src/basicts/models/StemGNN/arch/stemgnn_arch.py b/src/basicts/models/StemGNN/arch/stemgnn_arch.py index 7e395868..b39e084e 100644 --- a/src/basicts/models/StemGNN/arch/stemgnn_arch.py +++ b/src/basicts/models/StemGNN/arch/stemgnn_arch.py @@ -1,3 +1,4 @@ +# pylint: disable=not-callable import torch import torch.nn.functional as F from torch import nn diff --git a/src/basicts/models/TiDE/__init__.py b/src/basicts/models/TiDE/__init__.py new file mode 100644 index 00000000..ad2e1e37 --- /dev/null +++ b/src/basicts/models/TiDE/__init__.py @@ -0,0 +1,2 @@ +from .arch import TiDE +from .config.tide_config import TiDEConfig diff --git a/src/basicts/models/TiDE/arch/__init__.py b/src/basicts/models/TiDE/arch/__init__.py new file mode 100644 index 00000000..2f24eb7a --- /dev/null +++ b/src/basicts/models/TiDE/arch/__init__.py @@ -0,0 +1 @@ +from .tide_arch import TiDE \ No newline at end of file diff --git a/src/basicts/models/TiDE/arch/tide_arch.py b/src/basicts/models/TiDE/arch/tide_arch.py new file mode 100644 index 00000000..dc5ba181 --- /dev/null +++ b/src/basicts/models/TiDE/arch/tide_arch.py @@ -0,0 +1,129 @@ +from typing import Optional + +import torch +from torch import nn + +from basicts.modules.norm import RevIN + +from ..config.tide_config import TiDEConfig + + +class ResBlock(nn.Module): + """ + This is the MLP-based Residual Block + """ + def __init__( + self, + input_size: int, + hidden_size: int, + output_size: int, + dropout: float = 0.1, + bias: bool = True): + super().__init__() + + self.fc1 = nn.Linear(input_size, hidden_size, bias=bias) + self.fc2 = nn.Linear(hidden_size, output_size, bias=bias) + self.fc3 = nn.Linear(input_size, output_size, bias=bias) + self.dropout = nn.Dropout(dropout) + self.relu = nn.ReLU() + self.ln = 
nn.LayerNorm(output_size, bias=bias) + + def forward(self, inputs): + outputs = self.fc1(inputs) + outputs = self.relu(outputs) + outputs = self.fc2(outputs) + outputs = self.dropout(outputs) + outputs = outputs + self.fc3(inputs) + outputs = self.ln(outputs) + return outputs + + +class TiDE(nn.Module): + """ + Paper: Long-term Forecasting with TiDE: Time-series Dense Encoder + Official Code: https://github.com/lich99/TiDE + Link: https://arxiv.org/abs/2304.08424 + Venue: TMLR 2023 + Task: Long-term Time Series Forecasting + """ + + def __init__(self, config: TiDEConfig): + super().__init__() + + self.input_len = config.input_len + self.num_features = config.num_features + self.output_len = config.output_len + self.hidden_size = config.hidden_size + self.num_encoder_layers = config.num_encoder_layers + self.num_decoder_layers = config.num_decoder_layers + self.intermediate_size = config.intermediate_size + self.num_timestamps = config.num_timestamps + self.timestamps_encode_size = config.timestamps_encode_size + self.dropout = config.dropout + self.bias = config.bias + self.use_revin = config.use_revin + if self.use_revin: + self.revin = RevIN(affine=False) + + flatten_dim = self.input_len + (self.input_len + self.output_len) * self.timestamps_encode_size + + self.feature_encoder = ResBlock( + self.num_timestamps, self.hidden_size, self.timestamps_encode_size, self.dropout, self.bias) + self.dense_encoders = nn.Sequential( + ResBlock(flatten_dim, self.hidden_size, self.hidden_size, self.dropout, self.bias), + *([ResBlock(self.hidden_size, self.hidden_size, self.hidden_size, self.dropout, self.bias) + for _ in range(self.num_encoder_layers - 1)])) + + self.dense_decoders = nn.Sequential( + *([ResBlock(self.hidden_size, self.hidden_size, self.hidden_size, self.dropout, self.bias) + for _ in range(self.num_decoder_layers - 1)]), + ResBlock(self.hidden_size, self.hidden_size, self.num_features * self.output_len, + self.dropout, self.bias)) + self.temporal_decoder = ResBlock( + self.num_features + self.timestamps_encode_size, self.intermediate_size, 1, + self.dropout, self.bias) + self.residual_proj = nn.Linear(self.input_len, self.output_len, bias=self.bias) + + def forward( + self, + inputs: torch.Tensor, + inputs_timestamps: Optional[torch.Tensor] = None, + targets_timestamps: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """Feed forward of TiDE. 
+ + Args: + inputs: Input data with shape: [batch_size, input_len, num_features] + inputs_timestamps: Input timestamps with shape: [batch_size, input_len, num_timestamps] + targets_timestamps: Future timestamps with shape: [batch_size, output_len, num_timestamps] + + Returns: + Output data with shape: [batch_size, output_len, num_features] + """ + # Normalization + if self.use_revin: + inputs = self.revin(inputs, "norm") + + # Timestamps + if targets_timestamps is None: + timestamps = torch.zeros((inputs.shape[0], self.input_len + self.output_len, self.num_timestamps)).to(inputs.device).detach() + else: + timestamps = torch.concat([inputs_timestamps, targets_timestamps[:, -self.output_len:, :]], dim=1) + + # Backbone + prediction = [] + for i in range(self.num_features): + x_enc = inputs[:,:,i] + feature = self.feature_encoder(timestamps) + hidden = self.dense_encoders(torch.cat([x_enc, feature.reshape(feature.shape[0], -1)], dim=-1)) + decoded = self.dense_decoders(hidden).reshape(hidden.shape[0], self.output_len, self.num_features) + dec_out = self.temporal_decoder(torch.cat([feature[:, -self.output_len:], decoded], dim=-1)).squeeze( -1) + self.residual_proj(x_enc) + prediction.append(dec_out.unsqueeze(-1)) + + # De-Normalization + prediction = torch.cat(prediction, dim=-1) + if self.use_revin: + prediction = self.revin(prediction, "denorm") + + return prediction diff --git a/src/basicts/models/TiDE/config/tide_config.py b/src/basicts/models/TiDE/config/tide_config.py new file mode 100644 index 00000000..4e323b4e --- /dev/null +++ b/src/basicts/models/TiDE/config/tide_config.py @@ -0,0 +1,22 @@ +from dataclasses import dataclass, field + +from basicts.configs import BasicTSModelConfig + + +@dataclass +class TiDEConfig(BasicTSModelConfig): + """ + Config class for TiDE model. 
+ """ + input_len: int = field(default=None, metadata={"help": "Input sequence length."}) + output_len: int = field(default=None, metadata={"help": "Output sequence length."}) + num_features: int = field(default=None, metadata={"help": "Number of features."}) + hidden_size: list = field(default=256, metadata={"help": "Hidden size."}) + dropout: float = field(default=0.3, metadata={"help": "Dropout rate."}) + use_revin: bool = field(default=True, metadata={"help": "Whether to use RevIN."}) + intermediate_size: int = field(default=256, metadata={"help": "Intermediate size of FFN layers."}) + num_encoder_layers: int = field(default=2, metadata={"help": "Number of encoder layers."}) + num_decoder_layers: int = field(default=2, metadata={"help": "Number of decoder layers."}) + timestamps_encode_size: int = field(default=2, metadata={"help": "Encoding size of Timestamps."}) + num_timestamps: str = field(default= 4 , metadata={"help": "Sizes of timestamps used."}) + bias: bool = field(default=True, metadata={"help": "Whether to use bias."}) diff --git a/src/basicts/models/TimeKAN/arch/timekan_arch.py b/src/basicts/models/TimeKAN/arch/timekan_arch.py index f14754d9..851b8203 100644 --- a/src/basicts/models/TimeKAN/arch/timekan_arch.py +++ b/src/basicts/models/TimeKAN/arch/timekan_arch.py @@ -1,7 +1,8 @@ import torch +from torch import nn + from basicts.modules.embed import FeatureEmbedding from basicts.modules.norm import RevIN -from torch import nn from ..config.timekan_config import TimeKANConfig from .timekan_layers import FrequencyDecompLayer, FrequencyMixingLayer diff --git a/src/basicts/models/TimeKAN/arch/timekan_layers.py b/src/basicts/models/TimeKAN/arch/timekan_layers.py index 7d5176be..9e165ecd 100644 --- a/src/basicts/models/TimeKAN/arch/timekan_layers.py +++ b/src/basicts/models/TimeKAN/arch/timekan_layers.py @@ -1,3 +1,4 @@ +# pylint: disable=not-callable from typing import Callable, List import torch diff --git a/src/basicts/models/TimeMixer/arch/mixing_layers.py b/src/basicts/models/TimeMixer/arch/mixing_layers.py index b1435de4..da5444eb 100644 --- a/src/basicts/models/TimeMixer/arch/mixing_layers.py +++ b/src/basicts/models/TimeMixer/arch/mixing_layers.py @@ -1,8 +1,9 @@ import torch +from torch import nn + from basicts.modules.decomposition import (DFTDecomposition, MovingAverageDecomposition) from basicts.modules.mlps import MLPLayer -from torch import nn from ..config.timemixer_config import TimeMixerConfig diff --git a/src/basicts/models/TimeMixer/arch/timemixer_arch.py b/src/basicts/models/TimeMixer/arch/timemixer_arch.py index 5a45a19f..4947c9fa 100644 --- a/src/basicts/models/TimeMixer/arch/timemixer_arch.py +++ b/src/basicts/models/TimeMixer/arch/timemixer_arch.py @@ -1,10 +1,11 @@ from typing import Optional import torch +from torch import nn + from basicts.modules.decomposition import MovingAverageDecomposition from basicts.modules.embed import FeatureEmbedding from basicts.modules.norm import RevIN -from torch import nn from ..config.timemixer_config import TimeMixerConfig from .mixing_layers import PastDecomposableMixing @@ -64,7 +65,7 @@ def _decomposition( trend_list.append(trend) return seasonal_list, trend_list - def _multi_scale_process_inputs( + def _prepare_multi_scale_inputs( self, inputs: torch.Tensor, inputs_timestamps: Optional[torch.Tensor] = None @@ -81,7 +82,8 @@ def _multi_scale_process_inputs( sample = down_sampled if inputs_timestamps is not None: - multi_scale_timestamps.append(sample_ts[:, :, ::self.down_sampling_window]) + 
multi_scale_timestamps.append( + sample_ts[:, :, ::self.down_sampling_window].permute(0, 2, 1)) sample_ts = sample_ts[:, :, ::self.down_sampling_window] return multi_scale_inputs, multi_scale_timestamps @@ -93,7 +95,7 @@ def forward(self, decomp: bool = False, ) -> tuple[list[torch.Tensor], Optional[list[torch.Tensor]]]: - x_list, x_ts_list = self._multi_scale_process_inputs(inputs, inputs_timestamps) + x_list, x_ts_list = self._prepare_multi_scale_inputs(inputs, inputs_timestamps) num_scales = len(x_list) for i in range(num_scales): diff --git a/src/basicts/models/TimeXer/arch/layers.py b/src/basicts/models/TimeXer/arch/layers.py index 3e482120..43dbece9 100644 --- a/src/basicts/models/TimeXer/arch/layers.py +++ b/src/basicts/models/TimeXer/arch/layers.py @@ -1,9 +1,10 @@ from typing import Callable, Tuple import torch -from basicts.modules.transformer import Seq2SeqDecoderLayer from torch import nn +from basicts.modules.transformer import Seq2SeqDecoderLayer + class FlattenHead(nn.Module): """ @@ -86,4 +87,4 @@ def forward( if not output_attentions: self_attn_weights = cross_attn_weights = None - return hidden_states, self_attn_weights, cross_attn_weights, None + return hidden_states, self_attn_weights, cross_attn_weights diff --git a/src/basicts/models/TimeXer/arch/timexer_arch.py b/src/basicts/models/TimeXer/arch/timexer_arch.py index ac4069da..76b6f1fb 100644 --- a/src/basicts/models/TimeXer/arch/timexer_arch.py +++ b/src/basicts/models/TimeXer/arch/timexer_arch.py @@ -1,9 +1,10 @@ import torch +from torch import nn + from basicts.modules.embed import PatchEmbedding, SequenceEmbedding from basicts.modules.mlps import MLPLayer from basicts.modules.norm import RevIN from basicts.modules.transformer import MultiHeadAttention, Seq2SeqDecoder -from torch import nn from ..config.timexer_config import TimeXerConfig from .layers import FlattenHead, TimeXerEncoderLayer @@ -93,7 +94,7 @@ def forward(self, inputs: torch.Tensor, inputs_timestamps: torch.Tensor) -> torc # add exogenous variables ex_hidden_states = self.ex_embed(inputs, inputs_timestamps) - hidden_states, self_attn_weights, cross_attn_weights, _ = self.encoder( + hidden_states, self_attn_weights, cross_attn_weights = self.encoder( hidden_states, ex_hidden_states, output_attentions=self.output_attentions) hidden_states = hidden_states.reshape( batch_size, self.num_features, self.num_patches + 1, -1) diff --git a/src/basicts/models/Timer/arch/timer_arch.py b/src/basicts/models/Timer/arch/timer_arch.py index da9c5ca1..56fcc15d 100644 --- a/src/basicts/models/Timer/arch/timer_arch.py +++ b/src/basicts/models/Timer/arch/timer_arch.py @@ -1,11 +1,12 @@ import torch +from torch import nn + from basicts.modules import MLPLayer from basicts.modules.embed import PatchEmbedding from basicts.modules.norm import RevIN from basicts.modules.transformer import (AutoRegressiveDecoder, DecoderOnlyLayer, MultiHeadAttention, prepare_causal_attention_mask) -from torch import nn from ..config import TimerConfig diff --git a/src/basicts/models/TimesNet/arch/times_block.py b/src/basicts/models/TimesNet/arch/times_block.py index f2bc8a4b..f7585f89 100644 --- a/src/basicts/models/TimesNet/arch/times_block.py +++ b/src/basicts/models/TimesNet/arch/times_block.py @@ -1,3 +1,4 @@ +# pylint: disable=not-callable import torch import torch.fft import torch.nn.functional as F diff --git a/src/basicts/models/TimesNet/arch/timesnet_arch.py b/src/basicts/models/TimesNet/arch/timesnet_arch.py index 8e7d5435..db0cd7b6 100644 --- 
a/src/basicts/models/TimesNet/arch/timesnet_arch.py +++ b/src/basicts/models/TimesNet/arch/timesnet_arch.py @@ -1,7 +1,8 @@ import torch +from torch import nn + from basicts.modules.embed import FeatureEmbedding from basicts.modules.norm import RevIN -from torch import nn from ..config.timesnet_config import TimesNetConfig from .times_block import TimesBlock diff --git a/src/basicts/models/iTransformer/arch/itransformer_arch.py b/src/basicts/models/iTransformer/arch/itransformer_arch.py index 3ec5d119..717d9343 100644 --- a/src/basicts/models/iTransformer/arch/itransformer_arch.py +++ b/src/basicts/models/iTransformer/arch/itransformer_arch.py @@ -1,13 +1,14 @@ from typing import List, Optional, Tuple import torch +from torch import nn + from basicts.modules.activations import ACT2FN from basicts.modules.embed import SequenceEmbedding from basicts.modules.mlps import MLPLayer from basicts.modules.norm import RevIN from basicts.modules.transformer import (Encoder, EncoderLayer, MultiHeadAttention) -from torch import nn from ..config.itransformer_config import iTransformerConfig diff --git a/src/basicts/modules/decomposition.py b/src/basicts/modules/decomposition.py index d113106c..bee8ddf9 100644 --- a/src/basicts/modules/decomposition.py +++ b/src/basicts/modules/decomposition.py @@ -1,3 +1,4 @@ +# pylint: disable=not-callable from typing import Sequence, Tuple import torch diff --git a/src/basicts/modules/transformer/attentions/auto_correlation.py b/src/basicts/modules/transformer/attentions/auto_correlation.py index 28d8f8a7..b5e9724e 100644 --- a/src/basicts/modules/transformer/attentions/auto_correlation.py +++ b/src/basicts/modules/transformer/attentions/auto_correlation.py @@ -1,3 +1,4 @@ +# pylint: disable=not-callable import math from typing import Optional, Tuple diff --git a/src/basicts/runners/basicts_runner.py b/src/basicts/runners/basicts_runner.py index aae93984..dd473e3a 100644 --- a/src/basicts/runners/basicts_runner.py +++ b/src/basicts/runners/basicts_runner.py @@ -10,9 +10,6 @@ import numpy as np import setproctitle import torch -from basicts.metrics import ALL_METRICS -from basicts.scaler import BasicTSScaler -from basicts.utils import BasicTSMode, MeterPool, RunnerStatus from easytorch.core.checkpoint import (backup_last_ckpt, clear_ckpt, load_ckpt, save_ckpt) from easytorch.device import to_device @@ -25,6 +22,10 @@ from torch.utils.tensorboard import SummaryWriter from tqdm import tqdm +from basicts.metrics import ALL_METRICS +from basicts.scaler import BasicTSScaler +from basicts.utils import BasicTSMode, MeterPool, RunnerStatus + from .builder import Builder from .callback import BasicTSCallbackHandler # from .distributed import distributed diff --git a/src/basicts/runners/builder.py b/src/basicts/runners/builder.py index f5f84863..da359390 100644 --- a/src/basicts/runners/builder.py +++ b/src/basicts/runners/builder.py @@ -3,8 +3,6 @@ from typing import TYPE_CHECKING import torch -from basicts.scaler import BasicTSScaler -from basicts.utils import BasicTSMode from easytorch.device import to_device from easytorch.utils import get_local_rank, get_world_size from easytorch.utils.data_prefetcher import DataLoaderX @@ -16,6 +14,9 @@ from torch.utils.data import DataLoader, Dataset from torch.utils.data.distributed import DistributedSampler +from basicts.scaler import BasicTSScaler +from basicts.utils import BasicTSMode + if TYPE_CHECKING: from basicts.configs import BasicTSConfig diff --git a/src/basicts/runners/callback/__init__.py 
b/src/basicts/runners/callback/__init__.py index cbe83303..43a671c4 100644 --- a/src/basicts/runners/callback/__init__.py +++ b/src/basicts/runners/callback/__init__.py @@ -16,5 +16,5 @@ 'EarlyStopping', 'GradAccumulation', 'NoBP', - 'SelectiveLearning', + 'SelectiveLearning' ] diff --git a/src/basicts/runners/callback/selective_learning.py b/src/basicts/runners/callback/selective_learning.py index 5f246d75..30c6a3bd 100644 --- a/src/basicts/runners/callback/selective_learning.py +++ b/src/basicts/runners/callback/selective_learning.py @@ -1,15 +1,16 @@ from typing import TYPE_CHECKING, Optional import torch -from basicts.utils import RunnerStatus from easytorch.core.checkpoint import load_ckpt from easytorch.device import to_device from easytorch.utils import get_local_rank -from torch import nn from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.data import DataLoader from torch.utils.data._utils.collate import default_collate +from basicts.configs import BasicTSModelConfig +from basicts.utils import RunnerStatus + from .callback import BasicTSCallback if TYPE_CHECKING: @@ -27,7 +28,8 @@ class SelectiveLearning(BasicTSCallback): Args: r_u (float, optional): Uncertainty mask ratio, a float in (0, 1). Default: None. r_a (float, optional): Anomaly mask ratio, a float in (0, 1). Default: None. - estimator (nn.Module, optional): Estimation model for anomaly mask. Default: None. + estimator (type, optional): Estimation model class for anomaly mask. Default: None. + estimator_config (BasicTSModelConfig, optional): Config of the estimation model. Default: None. ckpt_path (str, optional): Path to the checkpoint of the estimation model. Default: None. """ @@ -35,18 +37,24 @@ def __init__( self, r_u: Optional[float] = None, r_a: Optional[float] = None, - estimator: Optional[nn.Module] = None, + estimator: Optional[type] = None, + estimator_config: Optional[BasicTSModelConfig] = None, ckpt_path: Optional[str] = None): super().__init__() + + # config self.r_u = r_u self.r_a = r_a self.estimator = estimator + self.estimator_config = estimator_config self.ckpt_path = ckpt_path - if self.r_a is not None and self.estimator is None: + self.estimation_model = self.estimator(estimator_config) if self.estimator is not None else None + + if self.r_a is not None and self.estimation_model is None: raise RuntimeError("Anomaly mask ratio is set but estimation model is not provided.") - if self.estimator is not None and self.ckpt_path is None: + if self.estimation_model is not None and self.ckpt_path is None: raise RuntimeError("Estimation model is set but checkpoint path is not provided.") self.history_residual: torch.Tensor = None @@ -56,6 +64,7 @@ def on_train_start(self, runner: "BasicTSRunner"): runner.logger.info(f"Use selective learning with r_u={self.r_u}, r_a={self.r_a}.") self._load_estimator(runner) + self.estimation_model.eval() self.num_samples = len(runner.train_data_loader.dataset) runner.train_data_loader = _DataLoaderWithIndex(runner.train_data_loader) @@ -81,15 +90,16 @@ def on_compute_loss(self, runner: "BasicTSRunner", **kwargs): # Anomaly mask if self.r_a is not None: - est_fr = runner._forward(self.estimator, data, step=0) - residual_lb = torch.abs(est_fr["prediction"] - forward_return["targets"]) + with torch.no_grad(): + est_forward_return = runner._forward(self.estimation_model, data, step=0) + residual_lb = torch.abs(est_forward_return["prediction"] - forward_return["targets"]) dist = residual - residual_lb thresholds = torch.quantile( dist, self.r_a, dim=1, keepdim=True) ano_mask = dist > 
thresholds forward_return["targets_mask"] = forward_return["targets_mask"] * ano_mask - def on_epoch_end(self, runner, **kwargs): + def on_epoch_end(self, runner: "BasicTSRunner", **kwargs): if self.r_u is not None: res_entropy = self._compute_entropy(self.history_residual) thresholds = torch.quantile( @@ -98,13 +108,13 @@ def on_epoch_end(self, runner, **kwargs): def _load_estimator(self, runner: "BasicTSRunner"): - runner.logger.info(f"Building estimation model {self.estimator.__class__.__name__}.") - self.estimator = to_device(self.estimator) + runner.logger.info(f"Building estimation model {self.estimation_model.__class__.__name__}.") + self.estimation_model = to_device(self.estimation_model) # DDP if torch.distributed.is_initialized(): - self.estimator = DDP( - self.estimator, + self.estimation_model = DDP( + self.estimation_model, device_ids=[get_local_rank()], find_unused_parameters=runner.cfg.ddp_find_unused_parameters ) @@ -112,10 +122,10 @@ def _load_estimator(self, runner: "BasicTSRunner"): # load model weights try: checkpoint_dict = load_ckpt(None, ckpt_path=self.ckpt_path, logger=runner.logger) - if isinstance(self.estimator, DDP): - self.estimator.module.load_state_dict(checkpoint_dict["model_state_dict"]) + if isinstance(self.estimation_model, DDP): + self.estimation_model.module.load_state_dict(checkpoint_dict["model_state_dict"]) else: - self.estimator.load_state_dict(checkpoint_dict["model_state_dict"]) + self.estimation_model.load_state_dict(checkpoint_dict["model_state_dict"]) except (IndexError, OSError) as e: raise OSError(f"Ckpt file {self.ckpt_path} does not exist") from e diff --git a/src/basicts/runners/taskflow/classification_taskflow.py b/src/basicts/runners/taskflow/classification_taskflow.py index 70814005..e90c31cd 100644 --- a/src/basicts/runners/taskflow/classification_taskflow.py +++ b/src/basicts/runners/taskflow/classification_taskflow.py @@ -1,6 +1,7 @@ from typing import TYPE_CHECKING, Any, Dict import torch + from basicts.utils.mask import null_val_mask from .basicts_taskflow import BasicTSTaskFlow diff --git a/src/basicts/runners/taskflow/forecasting_taskflow.py b/src/basicts/runners/taskflow/forecasting_taskflow.py index b21140af..cacf9747 100644 --- a/src/basicts/runners/taskflow/forecasting_taskflow.py +++ b/src/basicts/runners/taskflow/forecasting_taskflow.py @@ -1,6 +1,7 @@ from typing import TYPE_CHECKING, Any, Dict import torch + from basicts.utils.mask import null_val_mask from .basicts_taskflow import BasicTSTaskFlow diff --git a/src/basicts/runners/taskflow/imputation_taskflow.py b/src/basicts/runners/taskflow/imputation_taskflow.py index 38a0f6b1..b8db8901 100644 --- a/src/basicts/runners/taskflow/imputation_taskflow.py +++ b/src/basicts/runners/taskflow/imputation_taskflow.py @@ -1,6 +1,7 @@ from typing import TYPE_CHECKING, Any, Dict import torch + from basicts.utils.mask import null_val_mask, reconstruction_mask from .basicts_taskflow import BasicTSTaskFlow diff --git a/src/test.py b/src/test.py new file mode 100644 index 00000000..d54a8041 --- /dev/null +++ b/src/test.py @@ -0,0 +1,31 @@ +from basicts.models.DLinear import DLinear, DLinearConfig +from basicts.models.iTransformer import iTransformerConfig, iTransformerForForecasting +from basicts import BasicTSLauncher +from basicts.configs import BasicTSForecastingConfig +from basicts.runners.callback import SelectiveLearning + + +if __name__ == "__main__": + + cb = SelectiveLearning( + r_u=0.3, + r_a=0.3, + estimator=DLinear, + estimator_config=DLinearConfig(input_len=336, 
output_len=336), + ckpt_path="checkpoints/DLinear/ETTh1_100_336_336/1f037d3a0fb4a6de40ce3dcb2656b136/DLinear_best_val_MSE.pt" + ) + + BasicTSLauncher.launch_training( + BasicTSForecastingConfig( + model=iTransformerForForecasting, + input_len=336, + output_len=336, + use_timestamps=False, + model_config=iTransformerConfig( + input_len=336, + output_len=336, + num_features=7), + dataset_name="ETTh1", + gpus="0", + callbacks=[cb], + )) \ No newline at end of file diff --git a/tests/requirements.txt b/tests/requirements.txt new file mode 100644 index 00000000..8fdd8e95 --- /dev/null +++ b/tests/requirements.txt @@ -0,0 +1 @@ +einops \ No newline at end of file diff --git a/tests/smoke_test/datasets/ETTh2_mini/meta.json b/tests/smoke_test/datasets/ETTh2_mini/meta.json new file mode 100644 index 00000000..7ed90f3b --- /dev/null +++ b/tests/smoke_test/datasets/ETTh2_mini/meta.json @@ -0,0 +1,36 @@ +{ + "name": "ETTh2_mini", + "domain": "electricity transformer temperature", + "frequency (minutes)": 60, + "shape": [ + 1344, + 7 + ], + "timestamps_shape": [ + 1344, + 4 + ], + "timestamps_description": [ + "time of day", + "day of week", + "day of month", + "day of year" + ], + "num_time_steps": 1344, + "num_vars": 7, + "has_graph": false, + "regular_settings": { + "train_val_test_ratio": [ + 0.6, + 0.2, + 0.2 + ], + "norm_each_channel": true, + "rescale": false, + "metrics": [ + "MAE", + "MSE" + ], + "null_val": NaN + } +} \ No newline at end of file diff --git a/tests/smoke_test/datasets/ETTh2_mini/test_data.npy b/tests/smoke_test/datasets/ETTh2_mini/test_data.npy new file mode 100644 index 00000000..546d8289 Binary files /dev/null and b/tests/smoke_test/datasets/ETTh2_mini/test_data.npy differ diff --git a/tests/smoke_test/datasets/ETTh2_mini/test_timestamps.npy b/tests/smoke_test/datasets/ETTh2_mini/test_timestamps.npy new file mode 100644 index 00000000..a54465aa Binary files /dev/null and b/tests/smoke_test/datasets/ETTh2_mini/test_timestamps.npy differ diff --git a/tests/smoke_test/datasets/ETTh2_mini/train_data.npy b/tests/smoke_test/datasets/ETTh2_mini/train_data.npy new file mode 100644 index 00000000..6a217bd1 Binary files /dev/null and b/tests/smoke_test/datasets/ETTh2_mini/train_data.npy differ diff --git a/tests/smoke_test/datasets/ETTh2_mini/train_timestamps.npy b/tests/smoke_test/datasets/ETTh2_mini/train_timestamps.npy new file mode 100644 index 00000000..f6804ebc Binary files /dev/null and b/tests/smoke_test/datasets/ETTh2_mini/train_timestamps.npy differ diff --git a/tests/smoke_test/datasets/ETTh2_mini/val_data.npy b/tests/smoke_test/datasets/ETTh2_mini/val_data.npy new file mode 100644 index 00000000..85db42d1 Binary files /dev/null and b/tests/smoke_test/datasets/ETTh2_mini/val_data.npy differ diff --git a/tests/smoke_test/datasets/ETTh2_mini/val_timestamps.npy b/tests/smoke_test/datasets/ETTh2_mini/val_timestamps.npy new file mode 100644 index 00000000..ecf417d3 Binary files /dev/null and b/tests/smoke_test/datasets/ETTh2_mini/val_timestamps.npy differ diff --git a/tests/smoke_test/test_crossformer.py b/tests/smoke_test/test_crossformer.py new file mode 100644 index 00000000..0ac7c6c7 --- /dev/null +++ b/tests/smoke_test/test_crossformer.py @@ -0,0 +1,41 @@ +# pylint: disable=wrong-import-position +import os +import sys + +sys.path.append(os.path.abspath(__file__ + "/../../../src/")) +os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__)))) + +from basicts.configs import BasicTSForecastingConfig +from basicts.launcher import BasicTSLauncher +from 
basicts.models.Crossformer import Crossformer, CrossformerConfig + + +def test_crossformer_smoke_test(): + output_len = 24 + input_len = 96 + crossformer_config = CrossformerConfig( + input_len=input_len, + output_len=output_len, + num_features=7, + patch_len=24, + hidden_size=256, + intermediate_size=512, + n_heads=4, + dropout=0.2, + baseline=False, + ) + BasicTSLauncher.launch_training( + BasicTSForecastingConfig( + model=Crossformer, + dataset_name="ETTh1_mini", + model_config=crossformer_config, + gpus=None, + num_epochs=1, + input_len=input_len, + output_len=output_len, + lr=0.001, + ) + ) + +if __name__ == "__main__": + test_crossformer_smoke_test() diff --git a/tests/smoke_test/test_dlinear.py b/tests/smoke_test/test_dlinear.py index f3447bde..f21dc3cd 100644 --- a/tests/smoke_test/test_dlinear.py +++ b/tests/smoke_test/test_dlinear.py @@ -8,8 +8,7 @@ from basicts.configs import BasicTSForecastingConfig from basicts.launcher import BasicTSLauncher -from basicts.models.DLinear import DLinear -from basicts.models.DLinear.config.dlinear_config import DLinearConfig +from basicts.models.DLinear import DLinear, DLinearConfig def test_dlinear_smoke_test(): diff --git a/tests/smoke_test/test_duet.py b/tests/smoke_test/test_duet.py new file mode 100644 index 00000000..bcf6727d --- /dev/null +++ b/tests/smoke_test/test_duet.py @@ -0,0 +1,37 @@ +# pylint: disable=wrong-import-position +import os +import sys + +sys.path.append(os.path.abspath(__file__ + "/../../../src/")) +os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__)))) + +from basicts.configs import BasicTSForecastingConfig +from basicts.launcher import BasicTSLauncher +from basicts.models.DUET import DUET, DUETConfig +from basicts.runners.callback import AddAuxiliaryLoss + + +def test_duet_smoke_test(): + output_len = 24 + input_len = 96 + duet_config = DUETConfig( + input_len=input_len, + output_len=output_len, + num_features=7, + ) + BasicTSLauncher.launch_training( + BasicTSForecastingConfig( + model=DUET, + dataset_name="ETTh1_mini", + model_config=duet_config, + gpus=None, + num_epochs=1, + input_len=input_len, + output_len=output_len, + lr=0.001, + callbacks=[AddAuxiliaryLoss()] + ) + ) + +if __name__ == "__main__": + test_duet_smoke_test() diff --git a/tests/smoke_test/test_fits.py b/tests/smoke_test/test_fits.py new file mode 100644 index 00000000..49f13ab8 --- /dev/null +++ b/tests/smoke_test/test_fits.py @@ -0,0 +1,35 @@ +# pylint: disable=wrong-import-position +import os +import sys + +sys.path.append(os.path.abspath(__file__ + "/../../../src/")) +os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__)))) + +from basicts.configs import BasicTSForecastingConfig +from basicts.launcher import BasicTSLauncher +from basicts.models.FITS import FITS, FITSConfig + + +def test_fits_smoke_test(): + output_len = 24 + input_len = 96 + fits_config = FITSConfig( + input_len=input_len, + output_len=output_len, + num_features=7, + ) + BasicTSLauncher.launch_training( + BasicTSForecastingConfig( + model=FITS, + dataset_name="ETTh1_mini", + model_config=fits_config, + gpus=None, + num_epochs=1, + input_len=input_len, + output_len=output_len, + lr=0.001 + ) + ) + +if __name__ == "__main__": + test_fits_smoke_test() diff --git a/tests/smoke_test/test_frets.py b/tests/smoke_test/test_frets.py new file mode 100644 index 00000000..b4eb3797 --- /dev/null +++ b/tests/smoke_test/test_frets.py @@ -0,0 +1,35 @@ +# pylint: disable=wrong-import-position +import os +import sys + +sys.path.append(os.path.abspath(__file__ + 
"/../../../src/")) +os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__)))) + +from basicts.configs import BasicTSForecastingConfig +from basicts.launcher import BasicTSLauncher +from basicts.models.FreTS import FreTS, FreTSConfig + + +def test_frets_smoke_test(): + output_len = 24 + input_len = 96 + frets_config = FreTSConfig( + input_len=input_len, + output_len=output_len, + num_features=7, + ) + BasicTSLauncher.launch_training( + BasicTSForecastingConfig( + model=FreTS, + dataset_name="ETTh1_mini", + model_config=frets_config, + gpus=None, + num_epochs=1, + input_len=input_len, + output_len=output_len, + lr=0.001 + ) + ) + +if __name__ == "__main__": + test_frets_smoke_test() diff --git a/tests/smoke_test/test_informer.py b/tests/smoke_test/test_informer.py new file mode 100644 index 00000000..1c267ce2 --- /dev/null +++ b/tests/smoke_test/test_informer.py @@ -0,0 +1,41 @@ +# pylint: disable=wrong-import-position + +import os +import sys + +sys.path.append(os.path.abspath(__file__ + "/../../../src/")) +os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__)))) + +from basicts import BasicTSLauncher +from basicts.configs import BasicTSForecastingConfig +from basicts.models.Informer import Informer, InformerConfig + + +def test_informer_smoke_test(): + + output_len = 48 + input_len = 96 + informer_config = InformerConfig( + input_len=input_len, + output_len=output_len, + num_features=7, + label_len=input_len / 2, + use_timestamps=True, + timestamp_sizes=[24, 7, 31, 366], + + ) + + BasicTSLauncher.launch_training(BasicTSForecastingConfig( + model=Informer, + dataset_name="ETTh1_mini", + model_config=informer_config, + gpus=None, + num_epochs=1, + input_len=input_len, + output_len=output_len, + lr=0.001 + )) + + +if __name__ == "__main__": + test_informer_smoke_test() diff --git a/tests/smoke_test/test_koopa.py b/tests/smoke_test/test_koopa.py new file mode 100644 index 00000000..d22480ad --- /dev/null +++ b/tests/smoke_test/test_koopa.py @@ -0,0 +1,41 @@ +# pylint: disable=wrong-import-position +import os +import sys + +sys.path.append(os.path.abspath(__file__ + "/../../../src/")) +os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__)))) + +from basicts.configs import BasicTSForecastingConfig +from basicts.launcher import BasicTSLauncher +from basicts.models.Koopa import Koopa, KoopaConfig +from basicts.models.Koopa.callback.koopa_mask_init import KoopaMaskInitCallback + + +def test_koopa_smoke_test(): + output_len = 48 + input_len = 96 + koopa_config = KoopaConfig( + input_len=input_len, + output_len=output_len, + enc_in=7, + seg_len=48, + dynamic_dim=64, + hidden_dim=512, + num_blocks=4 + ) + BasicTSLauncher.launch_training( + BasicTSForecastingConfig( + model=Koopa, + dataset_name="ETTh2_mini", + model_config=koopa_config, + gpus=None, + num_epochs=1, + input_len=input_len, + output_len=output_len, + lr=0.001, + callbacks=[KoopaMaskInitCallback(alpha=0.2)], + ) + ) + +if __name__ == "__main__": + test_koopa_smoke_test() diff --git a/tests/smoke_test/test_leddam.py b/tests/smoke_test/test_leddam.py new file mode 100644 index 00000000..86212bce --- /dev/null +++ b/tests/smoke_test/test_leddam.py @@ -0,0 +1,35 @@ +# pylint: disable=wrong-import-position +import os +import sys + +sys.path.append(os.path.abspath(__file__ + "/../../../src/")) +os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__)))) + +from basicts.configs import BasicTSForecastingConfig +from basicts.launcher import BasicTSLauncher +from basicts.models.Leddam import Leddam, LeddamConfig 
+ + +def test_leddam_smoke_test(): + output_len = 24 + input_len = 96 + leddam_config = LeddamConfig( + input_len=input_len, + output_len=output_len, + num_features=7, + ) + BasicTSLauncher.launch_training( + BasicTSForecastingConfig( + model=Leddam, + dataset_name="ETTh1_mini", + model_config=leddam_config, + gpus=None, + num_epochs=1, + input_len=input_len, + output_len=output_len, + lr=0.001 + ) + ) + +if __name__ == "__main__": + test_leddam_smoke_test() diff --git a/tests/smoke_test/test_lightts.py b/tests/smoke_test/test_lightts.py new file mode 100644 index 00000000..2fd00803 --- /dev/null +++ b/tests/smoke_test/test_lightts.py @@ -0,0 +1,35 @@ +# pylint: disable=wrong-import-position +import os +import sys + +sys.path.append(os.path.abspath(__file__ + "/../../../src/")) +os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__)))) + +from basicts.configs import BasicTSForecastingConfig +from basicts.launcher import BasicTSLauncher +from basicts.models.LightTS import LightTS, LightTSConfig + + +def test_lightts_smoke_test(): + output_len = 24 + input_len = 96 + lightts_config = LightTSConfig( + input_len=input_len, + output_len=output_len, + num_features=7, + ) + BasicTSLauncher.launch_training( + BasicTSForecastingConfig( + model=LightTS, + dataset_name="ETTh1_mini", + model_config=lightts_config, + gpus=None, + num_epochs=1, + input_len=input_len, + output_len=output_len, + lr=0.001 + ) + ) + +if __name__ == "__main__": + test_lightts_smoke_test() diff --git a/tests/smoke_test/test_mtsmixer.py b/tests/smoke_test/test_mtsmixer.py new file mode 100644 index 00000000..0975a977 --- /dev/null +++ b/tests/smoke_test/test_mtsmixer.py @@ -0,0 +1,35 @@ +# pylint: disable=wrong-import-position +import os +import sys + +sys.path.append(os.path.abspath(__file__ + "/../../../src/")) +os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__)))) + +from basicts.configs import BasicTSForecastingConfig +from basicts.launcher import BasicTSLauncher +from basicts.models.MTSMixer import MTSMixer, MTSMixerConfig + + +def test_mtsmixer_smoke_test(): + output_len = 24 + input_len = 96 + mtsmixer_config = MTSMixerConfig( + input_len=input_len, + output_len=output_len, + num_features=7, + ) + BasicTSLauncher.launch_training( + BasicTSForecastingConfig( + model=MTSMixer, + dataset_name="ETTh1_mini", + model_config=mtsmixer_config, + gpus=None, + num_epochs=1, + input_len=input_len, + output_len=output_len, + lr=0.001 + ) + ) + +if __name__ == "__main__": + test_mtsmixer_smoke_test() diff --git a/tests/smoke_test/test_nlinear.py b/tests/smoke_test/test_nlinear.py new file mode 100644 index 00000000..d0eef7f8 --- /dev/null +++ b/tests/smoke_test/test_nlinear.py @@ -0,0 +1,34 @@ +# pylint: disable=wrong-import-position +import os +import sys + +sys.path.append(os.path.abspath(__file__ + "/../../../src/")) +os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__)))) + +from basicts.configs import BasicTSForecastingConfig +from basicts.launcher import BasicTSLauncher +from basicts.models.NLinear import NLinear, NLinearConfig + + +def test_nlinear_smoke_test(): + output_len = 24 + input_len = 96 + nlinear_config = NLinearConfig( + input_len=input_len, + output_len=output_len, + ) + BasicTSLauncher.launch_training( + BasicTSForecastingConfig( + model=NLinear, + dataset_name="ETTh1_mini", + model_config=nlinear_config, + gpus=None, + num_epochs=1, + input_len=input_len, + output_len=output_len, + lr=0.001 + ) + ) + +if __name__ == "__main__": + test_nlinear_smoke_test() diff --git 
a/tests/smoke_test/test_patchtstforf.py b/tests/smoke_test/test_patchtstforf.py new file mode 100644 index 00000000..7e0de93f --- /dev/null +++ b/tests/smoke_test/test_patchtstforf.py @@ -0,0 +1,36 @@ +# pylint: disable=wrong-import-position +import os +import sys + +sys.path.append(os.path.abspath(__file__ + "/../../../src/")) +os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__)))) + +from basicts.configs import BasicTSForecastingConfig +from basicts.launcher import BasicTSLauncher +from basicts.models.PatchTST import PatchTSTConfig, PatchTSTForForecasting + + +def test_patchtst_smoke_test(): + output_len = 24 + input_len = 96 + patchtst_config = PatchTSTConfig( + input_len=input_len, + output_len=output_len, + num_features=7, + intermediate_size=128, + ) + BasicTSLauncher.launch_training( + BasicTSForecastingConfig( + model=PatchTSTForForecasting, + dataset_name="ETTh1_mini", + model_config=patchtst_config, + gpus=None, + num_epochs=1, + input_len=input_len, + output_len=output_len, + lr=0.001 + ) + ) + +if __name__ == "__main__": + test_patchtst_smoke_test() diff --git a/tests/smoke_test/test_patchtstforr.py b/tests/smoke_test/test_patchtstforr.py new file mode 100644 index 00000000..9373e887 --- /dev/null +++ b/tests/smoke_test/test_patchtstforr.py @@ -0,0 +1,34 @@ +# pylint: disable=wrong-import-position + +import os +import sys + +sys.path.append(os.path.abspath(__file__ + "/../../../src/")) +os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__)))) + +from basicts import BasicTSLauncher +from basicts.configs import BasicTSImputationConfig +from basicts.models.PatchTST import PatchTSTConfig, PatchTSTForReconstruction + + +def test_patchtstforr_smoke_test(): + input_len=32 + model_config = PatchTSTConfig( + input_len=input_len, + num_features=7 + ) + + BasicTSLauncher.launch_training(BasicTSImputationConfig( + model=PatchTSTForReconstruction, + model_config=model_config, + dataset_name="ETTh1_mini", + mask_ratio=0.25, + gpus=None, + batch_size=16, + input_len=input_len, + num_epochs=1, + )) + + +if __name__ == "__main__": + test_patchtstforr_smoke_test() diff --git a/tests/smoke_test/test_segrnn.py b/tests/smoke_test/test_segrnn.py new file mode 100644 index 00000000..ecdbf7a5 --- /dev/null +++ b/tests/smoke_test/test_segrnn.py @@ -0,0 +1,35 @@ +# pylint: disable=wrong-import-position +import os +import sys + +sys.path.append(os.path.abspath(__file__ + "/../../../src/")) +os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__)))) + +from basicts.configs import BasicTSForecastingConfig +from basicts.launcher import BasicTSLauncher +from basicts.models.SegRNN import SegRNN, SegRNNConfig + + +def test_segrnn_smoke_test(): + output_len = 24 + input_len = 96 + segrnn_config = SegRNNConfig( + input_len=input_len, + output_len=output_len, + num_features=7, + ) + BasicTSLauncher.launch_training( + BasicTSForecastingConfig( + model=SegRNN, + dataset_name="ETTh1_mini", + model_config=segrnn_config, + gpus=None, + num_epochs=1, + input_len=input_len, + output_len=output_len, + lr=0.001 + ) + ) + +if __name__ == "__main__": + test_segrnn_smoke_test() diff --git a/tests/smoke_test/test_sparsetsf.py b/tests/smoke_test/test_sparsetsf.py new file mode 100644 index 00000000..330a4b48 --- /dev/null +++ b/tests/smoke_test/test_sparsetsf.py @@ -0,0 +1,35 @@ +# pylint: disable=wrong-import-position +import os +import sys + +sys.path.append(os.path.abspath(__file__ + "/../../../src/")) +os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__)))) + +from basicts.configs 
import BasicTSForecastingConfig +from basicts.launcher import BasicTSLauncher +from basicts.models.SparseTSF import SparseTSF, SparseTSFConfig + + +def test_sparsetsf_smoke_test(): + output_len = 24 + input_len = 96 + sparsetsf_config = SparseTSFConfig( + input_len=input_len, + output_len=output_len, + period_len=24, + ) + BasicTSLauncher.launch_training( + BasicTSForecastingConfig( + model=SparseTSF, + dataset_name="ETTh1_mini", + model_config=sparsetsf_config, + gpus=None, + num_epochs=1, + input_len=input_len, + output_len=output_len, + lr=0.001 + ) + ) + +if __name__ == "__main__": + test_sparsetsf_smoke_test() diff --git a/tests/smoke_test/test_stemgnn.py b/tests/smoke_test/test_stemgnn.py new file mode 100644 index 00000000..c4905b52 --- /dev/null +++ b/tests/smoke_test/test_stemgnn.py @@ -0,0 +1,35 @@ +# pylint: disable=wrong-import-position +import os +import sys + +sys.path.append(os.path.abspath(__file__ + "/../../../src/")) +os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__)))) + +from basicts.configs import BasicTSForecastingConfig +from basicts.launcher import BasicTSLauncher +from basicts.models.StemGNN import StemGNN, StemGNNConfig + + +def test_stemgnn_smoke_test(): + output_len = 24 + input_len = 96 + stemgnn_config = StemGNNConfig( + input_len=input_len, + output_len=output_len, + num_features=7, + ) + BasicTSLauncher.launch_training( + BasicTSForecastingConfig( + model=StemGNN, + dataset_name="ETTh1_mini", + model_config=stemgnn_config, + gpus=None, + num_epochs=1, + input_len=input_len, + output_len=output_len, + lr=0.001 + ) + ) + +if __name__ == "__main__": + test_stemgnn_smoke_test() diff --git a/tests/smoke_test/test_stid.py b/tests/smoke_test/test_stid.py new file mode 100644 index 00000000..1d445957 --- /dev/null +++ b/tests/smoke_test/test_stid.py @@ -0,0 +1,35 @@ +# pylint: disable=wrong-import-position +import os +import sys + +sys.path.append(os.path.abspath(__file__ + "/../../../src/")) +os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__)))) + +from basicts.configs import BasicTSForecastingConfig +from basicts.launcher import BasicTSLauncher +from basicts.models.STID import STID, STIDConfig + + +def test_stid_smoke_test(): + output_len = 24 + input_len = 96 + stid_config = STIDConfig( + input_len=input_len, + output_len=output_len, + num_features=7, + ) + BasicTSLauncher.launch_training( + BasicTSForecastingConfig( + model=STID, + dataset_name="ETTh1_mini", + model_config=stid_config, + gpus=None, + num_epochs=1, + input_len=input_len, + output_len=output_len, + lr=0.001 + ) + ) + +if __name__ == "__main__": + test_stid_smoke_test() diff --git a/tests/smoke_test/test_timekan.py b/tests/smoke_test/test_timekan.py new file mode 100644 index 00000000..30c05616 --- /dev/null +++ b/tests/smoke_test/test_timekan.py @@ -0,0 +1,35 @@ +# pylint: disable=wrong-import-position +import os +import sys + +sys.path.append(os.path.abspath(__file__ + "/../../../src/")) +os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__)))) + +from basicts.configs import BasicTSForecastingConfig +from basicts.launcher import BasicTSLauncher +from basicts.models.TimeKAN import TimeKAN, TimeKANConfig + + +def test_timekan_smoke_test(): + output_len = 24 + input_len = 96 + timekan_config = TimeKANConfig( + input_len=input_len, + output_len=output_len, + num_features=7, + ) + BasicTSLauncher.launch_training( + BasicTSForecastingConfig( + model=TimeKAN, + dataset_name="ETTh1_mini", + model_config=timekan_config, + gpus=None, + num_epochs=1, + 
input_len=input_len, + output_len=output_len, + lr=0.001 + ) + ) + +if __name__ == "__main__": + test_timekan_smoke_test() diff --git a/tests/smoke_test/test_timesnet.py b/tests/smoke_test/test_timesnet.py new file mode 100644 index 00000000..ed6d7e80 --- /dev/null +++ b/tests/smoke_test/test_timesnet.py @@ -0,0 +1,39 @@ +# pylint: disable=wrong-import-position +import os +import sys + +sys.path.append(os.path.abspath(__file__ + "/../../../src/")) +os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__)))) + +from basicts.configs import BasicTSForecastingConfig +from basicts.launcher import BasicTSLauncher +from basicts.models.TimesNet import TimesNetConfig, TimesNetForForecasting + + +def test_timesnet_smoke_test(): + output_len = 24 + input_len = 96 + timesnet_config = TimesNetConfig( + input_len=input_len, + output_len=output_len, + num_features=7, + use_timestamps=True, + timestamp_sizes=[24, 7, 31, 366], + hidden_size=16, + intermediate_size=64, + ) + BasicTSLauncher.launch_training( + BasicTSForecastingConfig( + model=TimesNetForForecasting, + dataset_name="ETTh1_mini", + model_config=timesnet_config, + gpus=None, + num_epochs=1, + input_len=input_len, + output_len=output_len, + lr=0.001 + ) + ) + +if __name__ == "__main__": + test_timesnet_smoke_test()
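
Note: this changeset adds the TiDE model under src/basicts/models/TiDE/ but, at least in the portion shown here, no matching smoke test. A minimal sketch following the same pattern as the other new smoke tests could look like the file below. The file name tests/smoke_test/test_tide.py and the parameter choices are illustrative assumptions, not part of the diff, and it is not verified here how the default forecasting taskflow supplies TiDE's optional inputs_timestamps/targets_timestamps arguments.

# Hypothetical smoke test for the new TiDE model (sketch, not part of this diff).
# pylint: disable=wrong-import-position
import os
import sys

sys.path.append(os.path.abspath(__file__ + "/../../../src/"))
os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__))))

from basicts.configs import BasicTSForecastingConfig
from basicts.launcher import BasicTSLauncher
from basicts.models.TiDE import TiDE, TiDEConfig


def test_tide_smoke_test():
    # 96 -> 24 forecasting on 7-feature ETT data, mirroring the other smoke tests.
    output_len = 24
    input_len = 96
    tide_config = TiDEConfig(
        input_len=input_len,
        output_len=output_len,
        num_features=7,  # remaining fields fall back to TiDEConfig defaults
    )
    BasicTSLauncher.launch_training(
        BasicTSForecastingConfig(
            model=TiDE,
            dataset_name="ETTh1_mini",
            model_config=tide_config,
            gpus=None,
            num_epochs=1,
            input_len=input_len,
            output_len=output_len,
            lr=0.001
        )
    )


if __name__ == "__main__":
    test_tide_smoke_test()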