"""Train and evaluate a Temporal Fusion Transformer (TFT) on Zephyr
wind-turbine SCADA data.

Pipeline: load a pre-formatted CSV, derive per-turbine time indices and
train/validation cutoffs, build ``pytorch_forecasting`` ``TimeSeriesDataSet``
splits, and fit a TFT wrapped in a manually-optimized ``LightningModule``,
logging to Weights & Biases.
"""

import csv
from datetime import datetime

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import wandb
from lightning.pytorch.loggers import WandbLogger
from pytorch_forecasting import TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.metrics import RMSE
from torch.optim import Adam

# Fractions of the maximum time index used as train / validation cutoffs.
TRAIN_PCT = 0.7
VAL_PCT = 0.85

DATA_PATH = "/mnt/zephyr/modeling/tft_format.csv"
DF = pd.read_csv(DATA_PATH,
                 header=0,
                 date_format="%Y-%m-%d %H:%M:%S",
                 parse_dates=True,
                 keep_date_col=False)

# Column order of DATA_PATH: meta columns first (time_idx, time, label,
# turbine), then sensor aggregates, then the synthetic static real
# "groupconst" appended by format_data().
COLUMN_NAMES = [
    "time_idx", "time", "label", "turbine",
    "WGEN_IntTmp_max", "WGEN_IntTmp_mean", "WGEN_IntTmp_min", "WGEN_IntTmp_sd",
    "WGEN_ClSt_max", "WGEN_ClSt_mean", "WGEN_ClSt_min",
    "WCNV_Torq_mean", "WCNV_Torq_max",
    "WGEN_Spd_max", "WGEN_Spd_mean", "WGEN_Spd_min",
    "WGEN_W_mean", "WGEN_W_max", "WCNV_Torq_sd", "WGEN_Spd_sd",
    "WGEN_W_min", "WGEN_W_sd", "WTOW_PlatTmp_mean",
    "WNAC_WdSpd1_mean", "WNAC_WdSpd1_max", "WNAC_WdSpd2_max", "WNAC_WdSpd2_mean",
    "WNAC_WdDir1_mean", "WNAC_WdDir2_mean", "WNAC_ExtPres_mean",
    "WNAC_WdSpdAvg_mean", "WNAC_WdSpdAvg_max", "WNAC_WdSpdAvg_min",
    "WNAC_IntTmp_mean", "WNAC_IntTmp_max", "WNAC_IntTmp_min", "WNAC_IntTmp_sd",
    "WNAC_Vib1_max", "WNAC_Vib1_min", "WNAC_Vib1_mean",
    "WNAC_Vib2_max", "WNAC_Vib2_mean", "WNAC_Vib2_min",
    "WNAC_Vib3_max", "WNAC_Vib3_mean", "WNAC_Vib3_min",
    "WNAC_Vib4_max", "WNAC_Vib4_mean", "WNAC_Vib4_min",
    "WTRM_FtrPres1_mean", "WTRM_FtrPres1_max", "WTRM_FtrPres2_max",
    "WTRM_FtrPres2_mean", "WTRM_FtrPres1_min", "WTRM_FtrPres2_min",
    "WROT_Spd1_mean", "WROT_Spd1_max", "WROT_Spd1_min", "WROT_Spd1_sd",
    "WROT_Spd2_max", "WROT_Spd2_mean", "WROT_Spd2_min", "WROT_Spd2_sd",
    "WROT_Pos_max", "WROT_Pos_mean", "WROT_Pos_min",
    "WROT_MnBrgTemp1_max", "WROT_MnBrgTemp1_mean", "WROT_MnBrgTemp1_min",
    "WROT_MnBrgTemp2_max", "WROT_MnBrgTemp2_mean", "WROT_MnBrgTemp2_min",
    "WTRM_HyFtrPres1_mean", "WTRM_HyFtrPres1_max", "WTRM_HyFtrPres1_min",
    "WTRM_HyFtrPres1_sd", "WTRM_HyFtrPres2_max", "WTRM_HyFtrPres2_mean",
    "WTRM_HyFtrPres2_min", "WTRM_HyFtrPres2_sd",
    "WTRM_HySysPres1_max", "WTRM_HySysPres1_mean", "WTRM_HySysPres1_min",
    "WTRM_HySysLockPres1_max", "WTRM_HySysLockPres1_mean",
    "WTRM_HySysLockPres1_sd", "WTRM_HySysLockPres1_min",
    "WROT_LockPos1_mean", "WROT_LockPos1_max", "WROT_LockPos1_min", "WROT_LockPos1_sd",
    "WROT_LockPos2_max", "WROT_LockPos2_mean", "WROT_LockPos2_min", "WROT_LockPos2_sd",
    "WROT_LockPos3_max", "WROT_LockPos3_mean", "WROT_LockPos3_min", "WROT_LockPos3_sd",
    "WROT_Brk2HyTmp6_sd", "WROT_Brk2HyTmp6_min", "WROT_Brk2HyTmp6_mean", "WROT_Brk2HyTmp6_max",
    "WROT_Brk2HyTmp5_sd", "WROT_Brk2HyTmp5_min", "WROT_Brk2HyTmp5_mean", "WROT_Brk2HyTmp5_max",
    "WROT_Brk1HyTmp6_min", "WROT_Brk1HyTmp6_sd", "WROT_Brk1HyTmp6_mean", "WROT_Brk1HyTmp6_max",
    "WROT_Brk1HyTmp5_sd", "WROT_Brk1HyTmp5_min", "WROT_Brk1HyTmp5_mean", "WROT_Brk1HyTmp5_max",
    "WROT_Brk2HyTmp4_sd", "WROT_Brk2HyTmp4_min", "WROT_Brk2HyTmp4_max",
    "WROT_Brk2HyTmp3_sd", "WROT_Brk2HyTmp4_mean", "WROT_Brk2HyTmp3_min",
    "WROT_Brk2HyTmp3_mean", "WROT_Brk2HyTmp3_max",
    "WROT_Brk1HyTmp4_sd", "WROT_Brk1HyTmp4_min", "WROT_Brk1HyTmp4_mean", "WROT_Brk1HyTmp4_max",
    "WROT_Brk1HyTmp3_sd", "WROT_Brk1HyTmp3_min", "WROT_Brk1HyTmp3_mean", "WROT_Brk1HyTmp3_max",
    "WROT_Brk2HyTmp2_sd", "WROT_Brk2HyTmp2_min", "WROT_Brk2HyTmp2_mean", "WROT_Brk2HyTmp2_max",
    "WROT_HyOilTmp1_sd", "WROT_HyOilTmp1_min", "WROT_HyOilTmp1_mean", "WROT_HyOilTmp1_max",
    "WROT_Brk2HyTmp1_sd", "WROT_Brk2HyTmp1_min", "WROT_Brk2HyTmp1_mean", "WROT_Brk2HyTmp1_max",
    "WROT_Brk1HyTmp2_sd", "WROT_Brk1HyTmp2_min", "WROT_Brk1HyTmp2_mean", "WROT_Brk1HyTmp2_max",
    "WROT_Brk1HyTmp1_sd", "WROT_Brk1HyTmp1_min", "WROT_Brk1HyTmp1_mean", "WROT_Brk1HyTmp1_max",
    "WROT_Brk1HyPres_max", "WROT_Brk1HyPres_mean", "WROT_Brk1HyPres_min", "WROT_Brk1HyPres_sd",
    "WROT_Brk2HyPres_max", "WROT_Brk2HyPres_mean", "WROT_Brk2HyPres_min", "WROT_Brk2HyPres_sd",
    "WROT_Brk1HyAccPres_max", "WROT_Brk1HyAccPres_mean",
    "WROT_Brk1HyAccPres_min", "WROT_Brk1HyAccPres_sd",
    "WROT_Brk2HyAccPres_max", "WROT_Brk2HyAccPres_mean",
    "WROT_Brk2HyAccPres_min", "WROT_Brk2HyAccPres_sd",
    "groupconst",
]


def format_data(DF: pd.DataFrame):
    """Prepare the raw frame for TimeSeriesDataSet construction.

    Adds a per-turbine monotonically increasing ``time_idx``, a constant
    ``groupconst`` static real, and computes train/validation cutoff indices.

    Returns:
        (df, training_cutoff, validation_cutoff) where the cutoffs are ints
        (``time_idx`` positions).
    """
    # FIX: the original did `df = DF` (an alias) while the comment claimed a
    # copy — mutations leaked into the module-level DF.
    df = DF.copy()

    # 'label' bool -> float target; 'turbine' is the group id, so categorical.
    df['label'] = df['label'].astype(float)
    df['turbine'] = df['turbine'].astype('category')

    # Per-turbine 0-based time index, ordered by timestamp.
    df['time'] = pd.to_datetime(df['time'], format="%Y-%m-%d %H:%M:%S")
    df = df.sort_values(by=['time', 'turbine'])
    df.insert(0, 'time_idx', df.groupby('turbine').cumcount())
    # Single constant static real; all turbines share group 0.
    df['groupconst'] = 0

    # DEBUGGING: dump the formatted frame (FIX: write via to_csv directly
    # instead of open()/write()/close() without a context manager).
    df.to_csv("/home/boom90lb/Zephyr/tftmodel/tft_formatted.csv", index=False)

    # FIX: cast to int — np.floor returns a float, but these values are used
    # as TimeSeriesDataSet min_prediction_idx (an integer time index).
    training_cutoff = int(np.floor(df['time_idx'].max() * TRAIN_PCT))
    validation_cutoff = int(np.floor(df['time_idx'].max() * VAL_PCT))

    return df, training_cutoff, validation_cutoff


def get_column_names(data_path="."):
    """Return the header (first) row of the CSV at *data_path*.

    Returns None for an empty file. NOTE(review): the default "." is a
    directory and would raise on open — callers should always pass a path.
    """
    with open(data_path) as csv_file:
        for row in csv.reader(csv_file):
            return row
    return None


# TimeSeriesDataSet construction arguments shared by all three splits.
# NOTE(review): COLUMN_NAMES[6:-2] skips the first six columns (meta columns
# plus the first two WGEN temperature aggregates) — confirm the slice bounds
# against the intended feature set.
TSDS_PARAMS = {
    "time_idx": "time_idx",
    "target": "label",
    "group_ids": ["turbine"],
    "min_encoder_length": 1,
    "max_encoder_length": 16,
    "min_prediction_length": 1,
    "max_prediction_length": 8,
    "static_categoricals": [],
    "static_reals": ["groupconst"],
    "time_varying_known_reals": [],
    "time_varying_unknown_reals": COLUMN_NAMES[6:-2] + ["label"],
    "time_varying_unknown_categoricals": [],
    "allow_missing_timesteps": True,
}

# Arguments forwarded to TemporalFusionTransformer.from_dataset().
# NOTE(review): "monotone_constaints" matches pytorch-forecasting's own kwarg
# spelling (the library historically uses this misspelling) — verify against
# the installed version. max_encoder_length=32 here disagrees with the
# dataset's 16 above — confirm which is intended.
TFT_PARAMS = {
    "hidden_size": 16,
    "lstm_layers": 2,
    "dropout": 0.2,
    "output_size": 1,
    "loss": RMSE(),
    "attention_head_size": 4,
    "max_encoder_length": 32,
    "allowed_encoder_known_variable_names": COLUMN_NAMES[6:-2],
    "hidden_continuous_size": 8,
    "learning_rate": 0.005,
    "log_interval": 10,
    "optimizer": "Adam",
    "log_val_interval": 1,
    "reduce_on_plateau_patience": 4,
    "monotone_constaints": {},
    "share_single_variable_networks": False,
    "causal_attention": True,
}

TRAINER_PARAMS = {
    "max_epochs": 10,
    "accelerator": "auto",
}


class TFTLightningModule(pl.LightningModule):
    """Manual-optimization Lightning wrapper around a TemporalFusionTransformer.

    Args:
        df: the training ``TimeSeriesDataSet`` the TFT architecture is
            inferred from (name kept for backward compatibility with callers).
        **kwargs: forwarded to ``TemporalFusionTransformer.from_dataset``.
    """

    def __init__(self, df, **kwargs):
        super().__init__()
        self.model = TemporalFusionTransformer.from_dataset(df, **kwargs)
        self.automatic_optimization = False
        # FIX: the original assigned `self.optimizers = self.configure_optimizers()`,
        # clobbering LightningModule.optimizers(); training_step now fetches the
        # trainer-managed optimizer through that API instead.
        self.criterion = RMSE()
        # FIX: was `wandb.save_hyperparameters()` — wandb has no such function
        # (AttributeError at construction). Hyperparameter capture belongs to
        # the LightningModule; skip the (unpicklable) dataset argument.
        self.save_hyperparameters(ignore=["df"])

    def forward(self, x):
        out = self.model(x)
        # FIX: wandb.log requires a dict; the original passed the output as
        # the positional `step` argument (TypeError).
        wandb.log({"x_out": out})
        return out

    def training_step(self, batch, batch_idx=0):
        """One manual optimization step: forward, RMSE loss, backward, update."""
        opt = self.optimizers()  # the Adam configured in configure_optimizers()
        x, y = batch

        opt.zero_grad()

        out = self.forward(x)
        # FIX: call the metric itself to get an aggregated scalar; the
        # original used .loss(), which returns an unreduced per-sample tensor
        # that cannot be backpropagated directly.
        loss = self.criterion(out, y)

        # FIX: was `self.model.backward(loss)`; manual optimization must go
        # through Lightning so precision/accelerator handling is applied.
        self.manual_backward(loss)
        opt.step()

        wandb.log({"train_loss": loss.item()})

        return loss

    def configure_optimizers(self):
        # NOTE(review): lr=1e-3 here disagrees with TFT_PARAMS'
        # learning_rate=0.005 (which only parametrizes the inner model) —
        # confirm which rate is intended.
        return Adam(self.model.parameters(), lr=1e-3)


def train(training_data: TimeSeriesDataSet, validation_data: TimeSeriesDataSet):
    """Fit a fresh TFTLightningModule on the given splits.

    Returns:
        The fitted ``pl.Trainer`` (its ``.model`` is the LightningModule).
    """
    model2 = TFTLightningModule(training_data, **TFT_PARAMS)
    wandb.watch(model2)
    training_loader = training_data.to_dataloader(train=True, batch_size=8, num_workers=4)
    validation_loader = validation_data.to_dataloader(train=False, batch_size=8, num_workers=4)

    wandb_logger = WandbLogger(log_model="all")
    trainer = pl.Trainer(logger=wandb_logger, **TRAINER_PARAMS)
    trainer.fit(model2, training_loader, validation_loader)

    return trainer


def eval(training_data):  # noqa: A001 — shadows builtins.eval; name kept for callers
    """Compute the mean RMSE of a previously saved model over *training_data*."""
    frozen_model = TFTLightningModule(training_data, **TFT_PARAMS)
    frozen_model.load_state_dict(
        torch.load("/home/boom90lb/Zephyr/tftmodel/checkpoints/saves/11-23-2023;22-49-11.ckpt"))
    frozen_model.eval()

    training_dataloader = training_data.to_dataloader(False, 4)

    loss_fn = RMSE()
    loss_running_sum = 0.0
    num_batches = 0

    # FIX: the original wrote `for batch, i in enumerate(...)`, binding the
    # index to `batch` (so `x, y = batch` failed), and divided by the highest
    # index instead of the batch count (off-by-one, and ZeroDivisionError for
    # a single batch). Also run under no_grad — no gradients are needed.
    with torch.no_grad():
        for batch in training_dataloader:
            x, y = batch
            out = frozen_model.forward(x)
            loss = loss_fn(out, y)
            print(loss)
            loss_running_sum += loss
            num_batches += 1

    return loss_running_sum / max(num_batches, 1)


def main():
    """End-to-end: format data, build splits, train, and checkpoint."""
    df, training_cutoff, validation_cutoff = format_data(DF)

    training_data = TimeSeriesDataSet(df.loc[df['time_idx'] <= training_cutoff], **TSDS_PARAMS)

    validation_data = TimeSeriesDataSet(df.loc[df['time_idx'] <= validation_cutoff],
                                        **TSDS_PARAMS,
                                        min_prediction_idx=training_cutoff + 1)

    # Held-out split past the validation cutoff (kept for parity; unused here).
    test_data = TimeSeriesDataSet(df, **TSDS_PARAMS, min_prediction_idx=validation_cutoff + 1)  # noqa: F841

    # prev_eval = eval(training_data)

    trained_model = train(training_data, validation_data)

    # FIX: take the timestamp once — the original called datetime.now() twice,
    # giving the Trainer checkpoint and the state_dict save different names.
    stamp = datetime.now().strftime("%m-%d-%Y;%H-%M-%S")
    trained_model.save_checkpoint(
        "/home/boom90lb/Zephyr/tftmodel/checkpoints/model" + stamp + ".ckpt")
    torch.save(trained_model.model.state_dict(),
               "/home/boom90lb/Zephyr/tftmodel/checkpoints/saves/" + stamp + ".ckpt")

    wandb.finish()


if __name__ == "__main__":
    main()
gitdb<5,>=4.0.1 in ./lib/python3.10/site-packages (from GitPython!=3.1.29,>=1.0.0->wandb) (4.0.11)\n", + "Requirement already satisfied: idna<4,>=2.5 in ./lib/python3.10/site-packages (from requests<3,>=2.0.0->wandb) (3.4)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in ./lib/python3.10/site-packages (from requests<3,>=2.0.0->wandb) (3.3.2)\n", + "Requirement already satisfied: certifi>=2017.4.17 in ./lib/python3.10/site-packages (from requests<3,>=2.0.0->wandb) (2023.11.17)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in ./lib/python3.10/site-packages (from requests<3,>=2.0.0->wandb) (2.1.0)\n", + "Requirement already satisfied: smmap<6,>=3.0.1 in ./lib/python3.10/site-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb) (5.0.1)\n", + "Note: you may need to restart the kernel to use updated packages.\n", + "Requirement already satisfied: pytorch-forecasting in ./lib/python3.10/site-packages (1.0.0)\n", + "Requirement already satisfied: matplotlib in ./lib/python3.10/site-packages (from pytorch-forecasting) (3.8.2)\n", + "Requirement already satisfied: optuna<4.0.0,>=3.1.0 in ./lib/python3.10/site-packages (from pytorch-forecasting) (3.4.0)\n", + "Requirement already satisfied: lightning<3.0.0,>=2.0.0 in ./lib/python3.10/site-packages (from pytorch-forecasting) (2.1.2)\n", + "Requirement already satisfied: torch<3.0.0,>=2.0.0 in ./lib/python3.10/site-packages (from pytorch-forecasting) (2.1.1)\n", + "Requirement already satisfied: scikit-learn<2.0,>=1.2 in ./lib/python3.10/site-packages (from pytorch-forecasting) (1.3.2)\n", + "Requirement already satisfied: fastapi>=0.80 in ./lib/python3.10/site-packages (from pytorch-forecasting) (0.104.1)\n", + "Requirement already satisfied: pytorch-optimizer<3.0.0,>=2.5.1 in ./lib/python3.10/site-packages (from pytorch-forecasting) (2.12.0)\n", + "Requirement already satisfied: scipy<2.0,>=1.8 in ./lib/python3.10/site-packages (from pytorch-forecasting) (1.11.4)\n", + "Requirement 
already satisfied: pandas<=3.0.0,>=1.3.0 in ./lib/python3.10/site-packages (from pytorch-forecasting) (2.1.3)\n", + "Requirement already satisfied: statsmodels in ./lib/python3.10/site-packages (from pytorch-forecasting) (0.14.0)\n", + "Requirement already satisfied: anyio<4.0.0,>=3.7.1 in ./lib/python3.10/site-packages (from fastapi>=0.80->pytorch-forecasting) (3.7.1)\n", + "Requirement already satisfied: starlette<0.28.0,>=0.27.0 in ./lib/python3.10/site-packages (from fastapi>=0.80->pytorch-forecasting) (0.27.0)\n", + "Requirement already satisfied: typing-extensions>=4.8.0 in ./lib/python3.10/site-packages (from fastapi>=0.80->pytorch-forecasting) (4.8.0)\n", + "Requirement already satisfied: pydantic!=1.8,!=1.8.1,!=2.0.0,!=2.0.1,!=2.1.0,<3.0.0,>=1.7.4 in ./lib/python3.10/site-packages (from fastapi>=0.80->pytorch-forecasting) (2.5.1)\n", + "Requirement already satisfied: fsspec[http]<2025.0,>2021.06.0 in ./lib/python3.10/site-packages (from lightning<3.0.0,>=2.0.0->pytorch-forecasting) (2023.10.0)\n", + "Requirement already satisfied: numpy<3.0,>=1.17.2 in ./lib/python3.10/site-packages (from lightning<3.0.0,>=2.0.0->pytorch-forecasting) (1.26.2)\n", + "Requirement already satisfied: tqdm<6.0,>=4.57.0 in ./lib/python3.10/site-packages (from lightning<3.0.0,>=2.0.0->pytorch-forecasting) (4.66.1)\n", + "Requirement already satisfied: torchmetrics<3.0,>=0.7.0 in ./lib/python3.10/site-packages (from lightning<3.0.0,>=2.0.0->pytorch-forecasting) (1.2.0)\n", + "Requirement already satisfied: PyYAML<8.0,>=5.4 in ./lib/python3.10/site-packages (from lightning<3.0.0,>=2.0.0->pytorch-forecasting) (6.0.1)\n", + "Requirement already satisfied: packaging<25.0,>=20.0 in ./lib/python3.10/site-packages (from lightning<3.0.0,>=2.0.0->pytorch-forecasting) (23.2)\n", + "Requirement already satisfied: pytorch-lightning in ./lib/python3.10/site-packages (from lightning<3.0.0,>=2.0.0->pytorch-forecasting) (2.1.2)\n", + "Requirement already satisfied: lightning-utilities<2.0,>=0.8.0 
in ./lib/python3.10/site-packages (from lightning<3.0.0,>=2.0.0->pytorch-forecasting) (0.10.0)\n", + "Requirement already satisfied: alembic>=1.5.0 in ./lib/python3.10/site-packages (from optuna<4.0.0,>=3.1.0->pytorch-forecasting) (1.12.1)\n", + "Requirement already satisfied: sqlalchemy>=1.3.0 in ./lib/python3.10/site-packages (from optuna<4.0.0,>=3.1.0->pytorch-forecasting) (2.0.23)\n", + "Requirement already satisfied: colorlog in ./lib/python3.10/site-packages (from optuna<4.0.0,>=3.1.0->pytorch-forecasting) (6.7.0)\n", + "Requirement already satisfied: pytz>=2020.1 in ./lib/python3.10/site-packages (from pandas<=3.0.0,>=1.3.0->pytorch-forecasting) (2023.3.post1)\n", + "Requirement already satisfied: python-dateutil>=2.8.2 in ./lib/python3.10/site-packages (from pandas<=3.0.0,>=1.3.0->pytorch-forecasting) (2.8.2)\n", + "Requirement already satisfied: tzdata>=2022.1 in ./lib/python3.10/site-packages (from pandas<=3.0.0,>=1.3.0->pytorch-forecasting) (2023.3)\n", + "Requirement already satisfied: joblib>=1.1.1 in ./lib/python3.10/site-packages (from scikit-learn<2.0,>=1.2->pytorch-forecasting) (1.3.2)\n", + "Requirement already satisfied: threadpoolctl>=2.0.0 in ./lib/python3.10/site-packages (from scikit-learn<2.0,>=1.2->pytorch-forecasting) (3.2.0)\n", + "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in ./lib/python3.10/site-packages (from torch<3.0.0,>=2.0.0->pytorch-forecasting) (12.1.105)\n", + "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in ./lib/python3.10/site-packages (from torch<3.0.0,>=2.0.0->pytorch-forecasting) (8.9.2.26)\n", + "Requirement already satisfied: jinja2 in ./lib/python3.10/site-packages (from torch<3.0.0,>=2.0.0->pytorch-forecasting) (3.1.2)\n", + "Requirement already satisfied: nvidia-nccl-cu12==2.18.1 in ./lib/python3.10/site-packages (from torch<3.0.0,>=2.0.0->pytorch-forecasting) (2.18.1)\n", + "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in ./lib/python3.10/site-packages (from 
torch<3.0.0,>=2.0.0->pytorch-forecasting) (12.1.105)\n", + "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in ./lib/python3.10/site-packages (from torch<3.0.0,>=2.0.0->pytorch-forecasting) (12.1.105)\n", + "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in ./lib/python3.10/site-packages (from torch<3.0.0,>=2.0.0->pytorch-forecasting) (12.1.105)\n", + "Requirement already satisfied: filelock in ./lib/python3.10/site-packages (from torch<3.0.0,>=2.0.0->pytorch-forecasting) (3.13.1)\n", + "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in ./lib/python3.10/site-packages (from torch<3.0.0,>=2.0.0->pytorch-forecasting) (11.4.5.107)\n", + "Requirement already satisfied: triton==2.1.0 in ./lib/python3.10/site-packages (from torch<3.0.0,>=2.0.0->pytorch-forecasting) (2.1.0)\n", + "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in ./lib/python3.10/site-packages (from torch<3.0.0,>=2.0.0->pytorch-forecasting) (11.0.2.54)\n", + "Requirement already satisfied: sympy in ./lib/python3.10/site-packages (from torch<3.0.0,>=2.0.0->pytorch-forecasting) (1.12)\n", + "Requirement already satisfied: networkx in ./lib/python3.10/site-packages (from torch<3.0.0,>=2.0.0->pytorch-forecasting) (3.2.1)\n", + "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in ./lib/python3.10/site-packages (from torch<3.0.0,>=2.0.0->pytorch-forecasting) (10.3.2.106)\n", + "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in ./lib/python3.10/site-packages (from torch<3.0.0,>=2.0.0->pytorch-forecasting) (12.1.0.106)\n", + "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in ./lib/python3.10/site-packages (from torch<3.0.0,>=2.0.0->pytorch-forecasting) (12.1.3.1)\n", + "Requirement already satisfied: nvidia-nvjitlink-cu12 in ./lib/python3.10/site-packages (from nvidia-cusolver-cu12==11.4.5.107->torch<3.0.0,>=2.0.0->pytorch-forecasting) (12.3.101)\n", + "Requirement already satisfied: cycler>=0.10 in 
./lib/python3.10/site-packages (from matplotlib->pytorch-forecasting) (0.12.1)\n", + "Requirement already satisfied: contourpy>=1.0.1 in ./lib/python3.10/site-packages (from matplotlib->pytorch-forecasting) (1.2.0)\n", + "Requirement already satisfied: fonttools>=4.22.0 in ./lib/python3.10/site-packages (from matplotlib->pytorch-forecasting) (4.45.0)\n", + "Requirement already satisfied: pyparsing>=2.3.1 in ./lib/python3.10/site-packages (from matplotlib->pytorch-forecasting) (3.1.1)\n", + "Requirement already satisfied: kiwisolver>=1.3.1 in ./lib/python3.10/site-packages (from matplotlib->pytorch-forecasting) (1.4.5)\n", + "Requirement already satisfied: pillow>=8 in ./lib/python3.10/site-packages (from matplotlib->pytorch-forecasting) (10.1.0)\n", + "Requirement already satisfied: patsy>=0.5.2 in ./lib/python3.10/site-packages (from statsmodels->pytorch-forecasting) (0.5.3)\n", + "Requirement already satisfied: Mako in ./lib/python3.10/site-packages (from alembic>=1.5.0->optuna<4.0.0,>=3.1.0->pytorch-forecasting) (1.3.0)\n", + "Requirement already satisfied: exceptiongroup in ./lib/python3.10/site-packages (from anyio<4.0.0,>=3.7.1->fastapi>=0.80->pytorch-forecasting) (1.2.0)\n", + "Requirement already satisfied: sniffio>=1.1 in ./lib/python3.10/site-packages (from anyio<4.0.0,>=3.7.1->fastapi>=0.80->pytorch-forecasting) (1.3.0)\n", + "Requirement already satisfied: idna>=2.8 in ./lib/python3.10/site-packages (from anyio<4.0.0,>=3.7.1->fastapi>=0.80->pytorch-forecasting) (3.4)\n", + "Requirement already satisfied: requests in ./lib/python3.10/site-packages (from fsspec[http]<2025.0,>2021.06.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (2.31.0)\n", + "Requirement already satisfied: aiohttp!=4.0.0a0,!=4.0.0a1 in ./lib/python3.10/site-packages (from fsspec[http]<2025.0,>2021.06.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (3.9.0)\n", + "Requirement already satisfied: setuptools in ./lib/python3.10/site-packages (from 
lightning-utilities<2.0,>=0.8.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (59.6.0)\n", + "Requirement already satisfied: six in ./lib/python3.10/site-packages (from patsy>=0.5.2->statsmodels->pytorch-forecasting) (1.16.0)\n", + "Requirement already satisfied: pydantic-core==2.14.3 in ./lib/python3.10/site-packages (from pydantic!=1.8,!=1.8.1,!=2.0.0,!=2.0.1,!=2.1.0,<3.0.0,>=1.7.4->fastapi>=0.80->pytorch-forecasting) (2.14.3)\n", + "Requirement already satisfied: annotated-types>=0.4.0 in ./lib/python3.10/site-packages (from pydantic!=1.8,!=1.8.1,!=2.0.0,!=2.0.1,!=2.1.0,<3.0.0,>=1.7.4->fastapi>=0.80->pytorch-forecasting) (0.6.0)\n", + "Requirement already satisfied: greenlet!=0.4.17 in ./lib/python3.10/site-packages (from sqlalchemy>=1.3.0->optuna<4.0.0,>=3.1.0->pytorch-forecasting) (3.0.1)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in ./lib/python3.10/site-packages (from jinja2->torch<3.0.0,>=2.0.0->pytorch-forecasting) (2.1.3)\n", + "Requirement already satisfied: mpmath>=0.19 in ./lib/python3.10/site-packages (from sympy->torch<3.0.0,>=2.0.0->pytorch-forecasting) (1.3.0)\n", + "Requirement already satisfied: multidict<7.0,>=4.5 in ./lib/python3.10/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2025.0,>2021.06.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (6.0.4)\n", + "Requirement already satisfied: aiosignal>=1.1.2 in ./lib/python3.10/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2025.0,>2021.06.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (1.3.1)\n", + "Requirement already satisfied: yarl<2.0,>=1.0 in ./lib/python3.10/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2025.0,>2021.06.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (1.9.3)\n", + "Requirement already satisfied: async-timeout<5.0,>=4.0 in ./lib/python3.10/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2025.0,>2021.06.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (4.0.3)\n", + "Requirement already satisfied: 
frozenlist>=1.1.1 in ./lib/python3.10/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2025.0,>2021.06.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (1.4.0)\n", + "Requirement already satisfied: attrs>=17.3.0 in ./lib/python3.10/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]<2025.0,>2021.06.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (23.1.0)\n", + "Requirement already satisfied: certifi>=2017.4.17 in ./lib/python3.10/site-packages (from requests->fsspec[http]<2025.0,>2021.06.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (2023.11.17)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in ./lib/python3.10/site-packages (from requests->fsspec[http]<2025.0,>2021.06.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (3.3.2)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in ./lib/python3.10/site-packages (from requests->fsspec[http]<2025.0,>2021.06.0->lightning<3.0.0,>=2.0.0->pytorch-forecasting) (2.1.0)\n", + "Note: you may need to restart the kernel to use updated packages.\n", + "Requirement already satisfied: pytorch-lightning in ./lib/python3.10/site-packages (2.1.2)\n", + "Requirement already satisfied: typing-extensions>=4.0.0 in ./lib/python3.10/site-packages (from pytorch-lightning) (4.8.0)\n", + "Requirement already satisfied: PyYAML>=5.4 in ./lib/python3.10/site-packages (from pytorch-lightning) (6.0.1)\n", + "Requirement already satisfied: torch>=1.12.0 in ./lib/python3.10/site-packages (from pytorch-lightning) (2.1.1)\n", + "Requirement already satisfied: packaging>=20.0 in ./lib/python3.10/site-packages (from pytorch-lightning) (23.2)\n", + "Requirement already satisfied: torchmetrics>=0.7.0 in ./lib/python3.10/site-packages (from pytorch-lightning) (1.2.0)\n", + "Requirement already satisfied: numpy>=1.17.2 in ./lib/python3.10/site-packages (from pytorch-lightning) (1.26.2)\n", + "Requirement already satisfied: fsspec[http]>2021.06.0 in ./lib/python3.10/site-packages (from pytorch-lightning) 
(2023.10.0)\n", + "Requirement already satisfied: lightning-utilities>=0.8.0 in ./lib/python3.10/site-packages (from pytorch-lightning) (0.10.0)\n", + "Requirement already satisfied: tqdm>=4.57.0 in ./lib/python3.10/site-packages (from pytorch-lightning) (4.66.1)\n", + "Requirement already satisfied: requests in ./lib/python3.10/site-packages (from fsspec[http]>2021.06.0->pytorch-lightning) (2.31.0)\n", + "Requirement already satisfied: aiohttp!=4.0.0a0,!=4.0.0a1 in ./lib/python3.10/site-packages (from fsspec[http]>2021.06.0->pytorch-lightning) (3.9.0)\n", + "Requirement already satisfied: setuptools in ./lib/python3.10/site-packages (from lightning-utilities>=0.8.0->pytorch-lightning) (59.6.0)\n", + "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in ./lib/python3.10/site-packages (from torch>=1.12.0->pytorch-lightning) (12.1.105)\n", + "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in ./lib/python3.10/site-packages (from torch>=1.12.0->pytorch-lightning) (8.9.2.26)\n", + "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in ./lib/python3.10/site-packages (from torch>=1.12.0->pytorch-lightning) (12.1.0.106)\n", + "Requirement already satisfied: triton==2.1.0 in ./lib/python3.10/site-packages (from torch>=1.12.0->pytorch-lightning) (2.1.0)\n", + "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in ./lib/python3.10/site-packages (from torch>=1.12.0->pytorch-lightning) (12.1.105)\n", + "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in ./lib/python3.10/site-packages (from torch>=1.12.0->pytorch-lightning) (12.1.105)\n", + "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in ./lib/python3.10/site-packages (from torch>=1.12.0->pytorch-lightning) (10.3.2.106)\n", + "Requirement already satisfied: nvidia-nccl-cu12==2.18.1 in ./lib/python3.10/site-packages (from torch>=1.12.0->pytorch-lightning) (2.18.1)\n", + "Requirement already satisfied: filelock in ./lib/python3.10/site-packages (from 
torch>=1.12.0->pytorch-lightning) (3.13.1)\n", + "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in ./lib/python3.10/site-packages (from torch>=1.12.0->pytorch-lightning) (12.1.3.1)\n", + "Requirement already satisfied: networkx in ./lib/python3.10/site-packages (from torch>=1.12.0->pytorch-lightning) (3.2.1)\n", + "Requirement already satisfied: sympy in ./lib/python3.10/site-packages (from torch>=1.12.0->pytorch-lightning) (1.12)\n", + "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in ./lib/python3.10/site-packages (from torch>=1.12.0->pytorch-lightning) (12.1.105)\n", + "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in ./lib/python3.10/site-packages (from torch>=1.12.0->pytorch-lightning) (11.0.2.54)\n", + "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in ./lib/python3.10/site-packages (from torch>=1.12.0->pytorch-lightning) (11.4.5.107)\n", + "Requirement already satisfied: jinja2 in ./lib/python3.10/site-packages (from torch>=1.12.0->pytorch-lightning) (3.1.2)\n", + "Requirement already satisfied: nvidia-nvjitlink-cu12 in ./lib/python3.10/site-packages (from nvidia-cusolver-cu12==11.4.5.107->torch>=1.12.0->pytorch-lightning) (12.3.101)\n", + "Requirement already satisfied: async-timeout<5.0,>=4.0 in ./lib/python3.10/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>2021.06.0->pytorch-lightning) (4.0.3)\n", + "Requirement already satisfied: aiosignal>=1.1.2 in ./lib/python3.10/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>2021.06.0->pytorch-lightning) (1.3.1)\n", + "Requirement already satisfied: multidict<7.0,>=4.5 in ./lib/python3.10/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>2021.06.0->pytorch-lightning) (6.0.4)\n", + "Requirement already satisfied: attrs>=17.3.0 in ./lib/python3.10/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>2021.06.0->pytorch-lightning) (23.1.0)\n", + "Requirement already satisfied: frozenlist>=1.1.1 in 
./lib/python3.10/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>2021.06.0->pytorch-lightning) (1.4.0)\n", + "Requirement already satisfied: yarl<2.0,>=1.0 in ./lib/python3.10/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>2021.06.0->pytorch-lightning) (1.9.3)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in ./lib/python3.10/site-packages (from jinja2->torch>=1.12.0->pytorch-lightning) (2.1.3)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in ./lib/python3.10/site-packages (from requests->fsspec[http]>2021.06.0->pytorch-lightning) (3.3.2)\n", + "Requirement already satisfied: idna<4,>=2.5 in ./lib/python3.10/site-packages (from requests->fsspec[http]>2021.06.0->pytorch-lightning) (3.4)\n", + "Requirement already satisfied: certifi>=2017.4.17 in ./lib/python3.10/site-packages (from requests->fsspec[http]>2021.06.0->pytorch-lightning) (2023.11.17)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in ./lib/python3.10/site-packages (from requests->fsspec[http]>2021.06.0->pytorch-lightning) (2.1.0)\n", + "Requirement already satisfied: mpmath>=0.19 in ./lib/python3.10/site-packages (from sympy->torch>=1.12.0->pytorch-lightning) (1.3.0)\n", + "Note: you may need to restart the kernel to use updated packages.\n", + "Requirement already satisfied: torch-optim in ./lib/python3.10/site-packages (0.0.4)\n", + "Requirement already satisfied: torch-pruning>=0.2.7 in ./lib/python3.10/site-packages (from torch-optim) (1.3.2)\n", + "Requirement already satisfied: torchvision>=0.11.1 in ./lib/python3.10/site-packages (from torch-optim) (0.16.1)\n", + "Requirement already satisfied: deap>=1.3.1 in ./lib/python3.10/site-packages (from torch-optim) (1.4.1)\n", + "Requirement already satisfied: pytorch-ignite>=0.4.8 in ./lib/python3.10/site-packages (from torch-optim) (0.4.13)\n", + "Requirement already satisfied: torch>=1.10.0 in ./lib/python3.10/site-packages (from torch-optim) (2.1.1)\n", + "Requirement already satisfied: 
thop>=0.0.31 in ./lib/python3.10/site-packages (from torch-optim) (0.1.1.post2209072238)\n", + "Requirement already satisfied: numpy in ./lib/python3.10/site-packages (from deap>=1.3.1->torch-optim) (1.26.2)\n", + "Requirement already satisfied: packaging in ./lib/python3.10/site-packages (from pytorch-ignite>=0.4.8->torch-optim) (23.2)\n", + "Requirement already satisfied: typing-extensions in ./lib/python3.10/site-packages (from torch>=1.10.0->torch-optim) (4.8.0)\n", + "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in ./lib/python3.10/site-packages (from torch>=1.10.0->torch-optim) (12.1.105)\n", + "Requirement already satisfied: triton==2.1.0 in ./lib/python3.10/site-packages (from torch>=1.10.0->torch-optim) (2.1.0)\n", + "Requirement already satisfied: filelock in ./lib/python3.10/site-packages (from torch>=1.10.0->torch-optim) (3.13.1)\n", + "Requirement already satisfied: jinja2 in ./lib/python3.10/site-packages (from torch>=1.10.0->torch-optim) (3.1.2)\n", + "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in ./lib/python3.10/site-packages (from torch>=1.10.0->torch-optim) (11.4.5.107)\n", + "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in ./lib/python3.10/site-packages (from torch>=1.10.0->torch-optim) (8.9.2.26)\n", + "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in ./lib/python3.10/site-packages (from torch>=1.10.0->torch-optim) (11.0.2.54)\n", + "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in ./lib/python3.10/site-packages (from torch>=1.10.0->torch-optim) (10.3.2.106)\n", + "Requirement already satisfied: nvidia-nccl-cu12==2.18.1 in ./lib/python3.10/site-packages (from torch>=1.10.0->torch-optim) (2.18.1)\n", + "Requirement already satisfied: sympy in ./lib/python3.10/site-packages (from torch>=1.10.0->torch-optim) (1.12)\n", + "Requirement already satisfied: networkx in ./lib/python3.10/site-packages (from torch>=1.10.0->torch-optim) (3.2.1)\n", + "Requirement already 
satisfied: nvidia-cuda-runtime-cu12==12.1.105 in ./lib/python3.10/site-packages (from torch>=1.10.0->torch-optim) (12.1.105)\n", + "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in ./lib/python3.10/site-packages (from torch>=1.10.0->torch-optim) (12.1.105)\n", + "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in ./lib/python3.10/site-packages (from torch>=1.10.0->torch-optim) (12.1.3.1)\n", + "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in ./lib/python3.10/site-packages (from torch>=1.10.0->torch-optim) (12.1.0.106)\n", + "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in ./lib/python3.10/site-packages (from torch>=1.10.0->torch-optim) (12.1.105)\n", + "Requirement already satisfied: fsspec in ./lib/python3.10/site-packages (from torch>=1.10.0->torch-optim) (2023.10.0)\n", + "Requirement already satisfied: nvidia-nvjitlink-cu12 in ./lib/python3.10/site-packages (from nvidia-cusolver-cu12==11.4.5.107->torch>=1.10.0->torch-optim) (12.3.101)\n", + "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in ./lib/python3.10/site-packages (from torchvision>=0.11.1->torch-optim) (10.1.0)\n", + "Requirement already satisfied: requests in ./lib/python3.10/site-packages (from torchvision>=0.11.1->torch-optim) (2.31.0)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in ./lib/python3.10/site-packages (from jinja2->torch>=1.10.0->torch-optim) (2.1.3)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in ./lib/python3.10/site-packages (from requests->torchvision>=0.11.1->torch-optim) (2.1.0)\n", + "Requirement already satisfied: certifi>=2017.4.17 in ./lib/python3.10/site-packages (from requests->torchvision>=0.11.1->torch-optim) (2023.11.17)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in ./lib/python3.10/site-packages (from requests->torchvision>=0.11.1->torch-optim) (3.3.2)\n", + "Requirement already satisfied: idna<4,>=2.5 in ./lib/python3.10/site-packages (from 
requests->torchvision>=0.11.1->torch-optim) (3.4)\n", + "Requirement already satisfied: mpmath>=0.19 in ./lib/python3.10/site-packages (from sympy->torch>=1.10.0->torch-optim) (1.3.0)\n", + "Note: you may need to restart the kernel to use updated packages.\n", + "Requirement already satisfied: numpy in ./lib/python3.10/site-packages (1.26.2)\n", + "Note: you may need to restart the kernel to use updated packages.\n", + "Requirement already satisfied: pandas in ./lib/python3.10/site-packages (2.1.3)\n", + "Requirement already satisfied: numpy<2,>=1.22.4 in ./lib/python3.10/site-packages (from pandas) (1.26.2)\n", + "Requirement already satisfied: python-dateutil>=2.8.2 in ./lib/python3.10/site-packages (from pandas) (2.8.2)\n", + "Requirement already satisfied: pytz>=2020.1 in ./lib/python3.10/site-packages (from pandas) (2023.3.post1)\n", + "Requirement already satisfied: tzdata>=2022.1 in ./lib/python3.10/site-packages (from pandas) (2023.3)\n", + "Requirement already satisfied: six>=1.5 in ./lib/python3.10/site-packages (from python-dateutil>=2.8.2->pandas) (1.16.0)\n", + "Note: you may need to restart the kernel to use updated packages.\n", + "Requirement already satisfied: torch in ./lib/python3.10/site-packages (2.1.1)\n", + "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in ./lib/python3.10/site-packages (from torch) (12.1.105)\n", + "Requirement already satisfied: nvidia-nccl-cu12==2.18.1 in ./lib/python3.10/site-packages (from torch) (2.18.1)\n", + "Requirement already satisfied: typing-extensions in ./lib/python3.10/site-packages (from torch) (4.8.0)\n", + "Requirement already satisfied: fsspec in ./lib/python3.10/site-packages (from torch) (2023.10.0)\n", + "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in ./lib/python3.10/site-packages (from torch) (12.1.3.1)\n", + "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in ./lib/python3.10/site-packages (from torch) (11.4.5.107)\n", + "Requirement already 
satisfied: networkx in ./lib/python3.10/site-packages (from torch) (3.2.1)\n", + "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in ./lib/python3.10/site-packages (from torch) (8.9.2.26)\n", + "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in ./lib/python3.10/site-packages (from torch) (12.1.105)\n", + "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in ./lib/python3.10/site-packages (from torch) (10.3.2.106)\n", + "Requirement already satisfied: triton==2.1.0 in ./lib/python3.10/site-packages (from torch) (2.1.0)\n", + "Requirement already satisfied: jinja2 in ./lib/python3.10/site-packages (from torch) (3.1.2)\n", + "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in ./lib/python3.10/site-packages (from torch) (11.0.2.54)\n", + "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in ./lib/python3.10/site-packages (from torch) (12.1.105)\n", + "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in ./lib/python3.10/site-packages (from torch) (12.1.0.106)\n", + "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in ./lib/python3.10/site-packages (from torch) (12.1.105)\n", + "Requirement already satisfied: filelock in ./lib/python3.10/site-packages (from torch) (3.13.1)\n", + "Requirement already satisfied: sympy in ./lib/python3.10/site-packages (from torch) (1.12)\n", + "Requirement already satisfied: nvidia-nvjitlink-cu12 in ./lib/python3.10/site-packages (from nvidia-cusolver-cu12==11.4.5.107->torch) (12.3.101)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in ./lib/python3.10/site-packages (from jinja2->torch) (2.1.3)\n", + "Requirement already satisfied: mpmath>=0.19 in ./lib/python3.10/site-packages (from sympy->torch) (1.3.0)\n", + "Note: you may need to restart the kernel to use updated packages.\n" + ] + } + ], + "source": [ + "%pip install wandb\n", + "%pip install pytorch-forecasting\n", + "%pip install pytorch-lightning\n", + "%pip install torch-optim\n", + 
"%pip install numpy\n", + "%pip install pandas\n", + "%pip install torch\n", + "\n", + "from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer\n", + "from pytorch_forecasting.metrics.point import RMSE\n", + "from lightning.pytorch.loggers import WandbLogger\n", + "import pytorch_lightning as pl\n", + "import wandb\n", + "from datetime import datetime\n", + "import numpy as np\n", + "import pandas as pd\n", + "import torch\n", + "import csv" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "Tracking run with wandb version 0.16.0" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Run data is saved locally in /home/boom90lb/Zephyr/tftmodel/wandb/run-20231130_211121-hii2ru61" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Syncing run fresh-totem-23 to Weights & Biases (docs)
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View project at https://wandb.ai/b90/tft" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View run at https://wandb.ai/b90/tft/runs/hii2ru61" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 34, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "TRAIN_PCT = 0.5\n", + "VAL_PCT = 0.75\n", + "DATA_PATH = \"/mnt/zephyr/modeling/tft_format.csv\"\n", + "DF = pd.read_csv(DATA_PATH, \n", + " header=0, \n", + " date_format=\"%Y-%m-%d %H:%M:%S\", \n", + " parse_dates=True,\n", + " keep_date_col=False)\n", + "COLUMN_NAMES = [\"time_idx\",\"turbine_id\",\"time\",\"label\",\"timestamp\",\"turbine\",\"WGEN_IntTmp_max\",\"WGEN_IntTmp_mean\",\"WGEN_IntTmp_min\",\"WGEN_IntTmp_sd\",'WGEN_ClSt_max',\"WGEN_ClSt_mean\",'WGEN_ClSt_min',\"WCNV_Torq_mean\",\"WCNV_Torq_max\",\"WGEN_Spd_max\",\"WGEN_Spd_mean\",\"WGEN_Spd_min\",\"WGEN_W_mean\",\"WGEN_W_max\",\"WCNV_Torq_sd\",\"WGEN_Spd_sd\",\"WGEN_W_min\",\"WGEN_W_sd\",\"WTOW_PlatTmp_mean\",\"WNAC_WdSpd1_mean\",\"WNAC_WdSpd1_max\",\"WNAC_WdSpd2_max\",\"WNAC_WdSpd2_mean\",\"WNAC_WdDir1_mean\",\"WNAC_WdDir2_mean\",\"WNAC_ExtPres_mean\",\"WNAC_WdSpdAvg_mean\",\"WNAC_WdSpdAvg_max\",\"WNAC_WdSpdAvg_min\",\"WNAC_IntTmp_mean\",\"WNAC_IntTmp_max\",\"WNAC_IntTmp_min\",'WNAC_IntTmp_sd',\"WNAC_Vib1_max\",\"WNAC_Vib1_min\",'WNAC_Vib1_mean','WNAC_Vib2_max',\"WNAC_Vib2_mean\",\"WNAC_Vib2_min\",\"WNAC_Vib3_max\",\"WNAC_Vib3_mean\",\"WNAC_Vib3_min\",\"WNAC_Vib4_max\",\"WNAC_Vib4_mean\",\"WNAC_Vib4_min\",\"WTRM_FtrPres1_mean\",\"WTRM_FtrPres1_max\",\"WTRM_FtrPres2_max\",\"WTRM_FtrPres2_mean\",'WTRM_FtrPres1_min',\"WTRM_FtrPres2_min\",\"WROT_Spd1_mean\",\"WROT_Spd1_max\",\"WROT_Spd1_min\",
\"WROT_Spd1_sd\",\"WROT_Spd2_max\",\"WROT_Spd2_mean\",\"WROT_Spd2_min\",\"WROT_Spd2_sd\",\"WROT_Pos_max\",\"WROT_Pos_mean\",\"WROT_Pos_min\",\"WROT_MnBrgTemp1_max\",\"WROT_MnBrgTemp1_mean\",\"WROT_MnBrgTemp1_min\",\"WROT_MnBrgTemp2_max\",\"WROT_MnBrgTemp2_mean\",\"WROT_MnBrgTemp2_min\",\"WTRM_HyFtrPres1_mean\",\"WTRM_HyFtrPres1_max\",\"WTRM_HyFtrPres1_min\",\"WTRM_HyFtrPres1_sd\",\"WTRM_HyFtrPres2_max\",\"WTRM_HyFtrPres2_mean\",\"WTRM_HyFtrPres2_min\",\"WTRM_HyFtrPres2_sd\",\"WTRM_HySysPres1_max\",\"WTRM_HySysPres1_mean\",\"WTRM_HySysPres1_min\",\"WTRM_HySysLockPres1_max\",\"WTRM_HySysLockPres1_mean\",\"WTRM_HySysLockPres1_sd\",\"WTRM_HySysLockPres1_min\",\"WROT_LockPos1_mean\",\"WROT_LockPos1_max\",\"WROT_LockPos1_min\",\"WROT_LockPos1_sd\",\"WROT_LockPos2_max\",\"WROT_LockPos2_mean\",\"WROT_LockPos2_min\",\"WROT_LockPos2_sd\",\"WROT_LockPos3_max\",\"WROT_LockPos3_mean\",\"WROT_LockPos3_min\",\"WROT_LockPos3_sd\",\"WROT_Brk2HyTmp6_sd\",\"WROT_Brk2HyTmp6_min\",\"WROT_Brk2HyTmp6_mean\",\"WROT_Brk2HyTmp6_max\",\"WROT_Brk2HyTmp5_sd\",\"WROT_Brk2HyTmp5_min\",\"WROT_Brk2HyTmp5_mean\",\"WROT_Brk2HyTmp5_max\",\"WROT_Brk1HyTmp6_min\",\"WROT_Brk1HyTmp6_sd\",\"WROT_Brk1HyTmp6_mean\",\"WROT_Brk1HyTmp6_max\",\"WROT_Brk1HyTmp5_sd\",\"WROT_Brk1HyTmp5_min\",\"WROT_Brk1HyTmp5_mean\",\"WROT_Brk1HyTmp5_max\",\"WROT_Brk2HyTmp4_sd\",\"WROT_Brk2HyTmp4_min\",\"WROT_Brk2HyTmp4_max\",\"WROT_Brk2HyTmp3_sd\",\"WROT_Brk2HyTmp4_mean\",\"WROT_Brk2HyTmp3_min\",\"WROT_Brk2HyTmp3_mean\",\"WROT_Brk2HyTmp3_max\",\"WROT_Brk1HyTmp4_sd\",\"WROT_Brk1HyTmp4_min\",\"WROT_Brk1HyTmp4_mean\",\"WROT_Brk1HyTmp4_max\",\"WROT_Brk1HyTmp3_sd\",\"WROT_Brk1HyTmp3_min\",\"WROT_Brk1HyTmp3_mean\",\"WROT_Brk1HyTmp3_max\",\"WROT_Brk2HyTmp2_sd\",\"WROT_Brk2HyTmp2_min\",\"WROT_Brk2HyTmp2_mean\",\"WROT_Brk2HyTmp2_max\",\"WROT_HyOilTmp1_sd\",\"WROT_HyOilTmp1_min\",\"WROT_HyOilTmp1_mean\",\"WROT_HyOilTmp1_max\",\"WROT_Brk2HyTmp1_sd\",\"WROT_Brk2HyTmp1_min\",\"WROT_Brk2HyTmp1_mean\",\"WROT_Brk2HyTmp1_max\",\"WROT_Brk1HyTmp2_sd
\",\"WROT_Brk1HyTmp2_min\",\"WROT_Brk1HyTmp2_mean\",\"WROT_Brk1HyTmp2_max\",\"WROT_Brk1HyTmp1_sd\",\"WROT_Brk1HyTmp1_min\",\"WROT_Brk1HyTmp1_mean\",\"WROT_Brk1HyTmp1_max\",\"WROT_Brk1HyPres_max\",\"WROT_Brk1HyPres_mean\",\"WROT_Brk1HyPres_min\",\"WROT_Brk1HyPres_sd\",\"WROT_Brk2HyPres_max\",\"WROT_Brk2HyPres_mean\",\"WROT_Brk2HyPres_min\",\"WROT_Brk2HyPres_sd\",\"WROT_Brk1HyAccPres_max\",\"WROT_Brk1HyAccPres_mean\",\"WROT_Brk1HyAccPres_min\",\"WROT_Brk1HyAccPres_sd\",\"WROT_Brk2HyAccPres_max\",\"WROT_Brk2HyAccPres_mean\",\"WROT_Brk2HyAccPres_min\",\"WROT_Brk2HyAccPres_sd\"]\n", + "\n", + "wandb.login()\n", + "\n", + "wandb.init(\n", + " project = \"tft\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [], + "source": [ + "def format_data(DF: pd.DataFrame):\n", + " # Copy dataframe\n", + " df = DF\n", + "\n", + " # Ensure 'turbine' is a categorical column, turn 'label' from bool to float\n", + " df['label'] = df['label'].astype(float)\n", + " df['turbine'] = df['turbine'].astype('category')\n", + "\n", + " # Create a 'time_idx' column that resets for each unique 'group_id'\n", + " df['time'] = pd.to_datetime(df['time'], format=\"%Y-%m-%d %H:%M:%S\")\n", + " df = df.sort_values(by=['time', 'turbine'])\n", + " df.insert(0, 'time_idx', df.groupby('turbine').cumcount())\n", + "\n", + " # DEBUGGING\n", + " f = open(\"/home/boom90lb/Zephyr/tftmodel/tft_formatted.csv\", \"w\")\n", + " f.write(df.to_csv(None, index=False))\n", + " f.close()\n", + "\n", + " training_cutoff = np.floor(df['time_idx'].max()*TRAIN_PCT)\n", + " validation_cutoff = np.floor(df['time_idx'].max()*VAL_PCT)\n", + "\n", + " return df, training_cutoff, validation_cutoff\n", + "\n", + "def get_column_names(data_path=\".\"):\n", + " with open(data_path) as csv_file:\n", + " csv_reader = csv.reader(csv_file)\n", + " for row in csv_reader:\n", + " return row\n", + "\n", + "TSDS_PARAMS = {\n", + " \"time_idx\": \"time_idx\",\n", + " \"target\": 
\"label\",\n", + " \"group_ids\": [\"turbine\"],\n", + " \"min_encoder_length\": 1,\n", + " \"max_encoder_length\": 16,\n", + " \"min_prediction_length\": 1,\n", + " \"max_prediction_length\": 8,\n", + " \"static_categoricals\": [],\n", + " \"static_reals\": [],\n", + " \"time_varying_known_reals\": [],\n", + " \"time_varying_unknown_reals\": COLUMN_NAMES[6:-1] + [\"label\"],\n", + " \"time_varying_unknown_categoricals\": [],\n", + " \"allow_missing_timesteps\": True,\n", + "}\n", + "\n", + "TFT_PARAMS = {\n", + " \"hidden_size\": 16,\n", + " \"lstm_layers\": 2,\n", + " \"dropout\": 0.2,\n", + " \"output_size\": 1,\n", + " \"loss\": RMSE(),\n", + " \"attention_head_size\": 4,\n", + " \"max_encoder_length\": 16,\n", + " \"allowed_encoder_known_variable_names\": COLUMN_NAMES[6:-1],\n", + " \"hidden_continuous_size\": 8, \n", + " \"learning_rate\": 0.003, \n", + " \"log_interval\": 10,\n", + " \"optimizer\": \"Adam\",\n", + " \"log_val_interval\": 1, \n", + " \"reduce_on_plateau_patience\": 4,\n", + " \"monotone_constaints\": {},\n", + " \"share_single_variable_networks\": False,\n", + " \"causal_attention\": True,\n", + "}\n", + "\n", + "TRAINER_PARAMS = {\n", + " \"max_epochs\": 5,\n", + " \"accelerator\": \"auto\"\n", + "}\n", + "\n", + "class TFTLightningModule(pl.LightningModule):\n", + " def __init__(self, df, **kwargs):\n", + " super().__init__()\n", + " self.model = TemporalFusionTransformer.from_dataset(df, **kwargs)\n", + " self.criterion = RMSE()\n", + " self.optimizers = self.configure_optimizers()\n", + " self.save_hyperparameters()\n", + " wandb.watch(models=self.model, criterion=self.criterion, log=\"all\", log_freq=100)\n", + " \n", + " def forward(self, x):\n", + " return self.model.forward(x)\n", + " \n", + " def training_step(self, batch):\n", + " x, y = batch\n", + "\n", + " out = self.forward(x)\n", + " loss = self.criterion.loss(out, y)\n", + " \n", + " self.log(\"out\", out)\n", + " self.log(\"train/loss\", loss)\n", + " \n", + " return loss\n", 
+ "\n", + " def configure_optimizers(self):\n", + " return torch.optim.Adam(self.model.parameters(), lr=1e-3)\n", + " \n", + "def train(training_data: TimeSeriesDataSet, validation_data: TimeSeriesDataSet):\n", + " model = TFTLightningModule(training_data, **TFT_PARAMS)\n", + " \n", + " # wandb.log({\"test\": 1})\n", + " wandb_logger = WandbLogger(log_model=\"all\", save_dir=\"wandb/\")\n", + " \n", + " training_loader = training_data.to_dataloader(train=True, batch_size=8, num_workers=2)\n", + " validation_loader = validation_data.to_dataloader(train=False, batch_size=8, num_workers=2)\n", + " \n", + " trainer = pl.Trainer(logger=wandb_logger, log_every_n_steps=1, **TRAINER_PARAMS)\n", + " trainer.fit(model, training_loader, validation_loader)\n", + "\n", + " return trainer\n", + " \n", + "def eval(frozen_model, validation_data):\n", + "\n", + " frozen_model.load_state_dict(torch.load(\"/home/boom90lb/Zephyr/tftmodel/checkpoints/saves/11-23-2023;22-49-11.ckpt\"))\n", + "\n", + " frozen_model.eval()\n", + "\n", + " validation_dataloader = validation_data.to_dataloader(False, 4)\n", + "\n", + " loss_running_sum = 0\n", + " num_batches = 0\n", + "\n", + " for batch, i in enumerate(validation_dataloader):\n", + " x, y = batch\n", + " out = frozen_model.forward(x)\n", + " loss_fn = RMSE()\n", + " \n", + " loss = loss_fn.loss(out, y)\n", + " print(loss)\n", + " loss_running_sum += loss\n", + "\n", + " if i > num_batches:\n", + " num_batches = i\n", + "\n", + " return loss_running_sum/num_batches\n" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_1351773/3567532967.py:12: FutureWarning: The default of observed=False is deprecated and will be changed to True in a future version of pandas. 
Pass observed=False to retain current behavior or observed=True to adopt the future default and silence this warning.\n", + " df.insert(0, 'time_idx', df.groupby('turbine').cumcount())\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "trainlen: 2310\n", + "vallen: 2100\n", + "testlen: 2128\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/boom90lb/Zephyr/tftmodel/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:198: Attribute 'loss' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['loss'])`.\n", + "/home/boom90lb/Zephyr/tftmodel/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:198: Attribute 'logging_metrics' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['logging_metrics'])`.\n", + "/home/boom90lb/Zephyr/tftmodel/lib/python3.10/site-packages/pytorch_lightning/utilities/parsing.py:198: Attribute 'loss' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['loss'])`.\n", + "Trainer will use only 1 of 4 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=4)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.\n", + "GPU available: True (cuda), used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", + "HPU available: False, using: 0 HPUs\n", + "/home/boom90lb/Zephyr/tftmodel/lib/python3.10/site-packages/pytorch_lightning/trainer/configuration_validator.py:72: You passed in a `val_dataloader` but have no `validation_step`. 
Skipping val loop.\n", + "/home/boom90lb/Zephyr/tftmodel/lib/python3.10/site-packages/lightning/pytorch/loggers/wandb.py:389: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]\n", + "\n", + " | Name | Type | Params\n", + "--------------------------------------------------------\n", + "0 | model | TemporalFusionTransformer | 132 K \n", + "1 | criterion | RMSE | 0 \n", + "--------------------------------------------------------\n", + "132 K Trainable params\n", + "0 Non-trainable params\n", + "132 K Total params\n", + "0.530 Total estimated model params size (MB)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 4: 0%| | 0/288 [00:00fresh-totem-23 at: https://wandb.ai/b90/tft/runs/hii2ru61
Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Find logs at: ./wandb/run-20231130_211121-hii2ru61/logs" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "def main():\n", + " df, training_cutoff, validation_cutoff = format_data(DF)\n", + "\n", + " training_data = TimeSeriesDataSet(df.loc[df['time_idx'] <= training_cutoff], **TSDS_PARAMS)\n", + " \n", + " print(f'trainlen: {len(training_data)}')\n", + "\n", + " validation_data = TimeSeriesDataSet(df.loc[df['time_idx'] <= validation_cutoff], **TSDS_PARAMS, min_prediction_idx=training_cutoff + 1)\n", + " \n", + " print(f'vallen: {len(validation_data)}')\n", + "\n", + " test_data = TimeSeriesDataSet(df, **TSDS_PARAMS, min_prediction_idx=validation_cutoff + 1)\n", + " \n", + " print(f'testlen: {len(test_data)}')\n", + "\n", + " # prev_eval = eval(validation_data)\n", + "\n", + " trained_model = train(training_data, validation_data)\n", + "\n", + " trained_model.save_checkpoint(\"/home/boom90lb/Zephyr/tftmodel/checkpoints/model\" + datetime.now().strftime(\"%m-%d-%Y;%H-%M-%S\") + \".ckpt\")\n", + "\n", + " torch.save(trained_model.model.state_dict(), \"/home/boom90lb/Zephyr/tftmodel/checkpoints/saves/\" + datetime.now().strftime(\"%m-%d-%Y;%H-%M-%S\") + \".ckpt\")\n", + " \n", + " wandb.finish()\n", + "\n", + "if __name__ == \"__main__\":\n", + " main()\n", + "\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "tftmodel", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.6" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git 
a/mlstars/primitives/pytorch_forecasting.data.timeseries.TimeSeriesDataSet.json b/mlstars/primitives/pytorch_forecasting.data.timeseries.TimeSeriesDataSet.json new file mode 100644 index 0000000..8da98c9 --- /dev/null +++ b/mlstars/primitives/pytorch_forecasting.data.timeseries.TimeSeriesDataSet.json @@ -0,0 +1,92 @@ +{ + "name": "pytorch_forecasting.data.timeseries.TimeSeriesDataSet", + "contributors": [ + "Brendon Reperttang " + ], + "documentation": "https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html", + "description": "Two-dimensional size-mutable, potentially heterogeneous tabular data structure with labeled axes (rows and columns).", + "classifiers": { + "type": "helper", + "subtype": "fitting" + }, + "modalities": [], + "primitive": "pytorch_forecasting.data.timeseries.TimeSeriesDataSet", + "produce": { + "args": [ + { + "name": "X", + "keyword": "data", + "type": "DataFrame" + } + ], + "output": [ + { + "name": "dataset", + "type": "TimeSeriesDataSet" + } + ] + }, + "hyperparameters": { + "fixed": { + "time_idx": { + "type": "str", + "default": "time_idx", + "description": "Name of time_idx column, denoting the order of examples." + }, + "target": { + "type": "str", + "default": "target", + "description": "Name of target column, providing ground truth labels." + }, + "group_ids": { + "type": "ndarray", + "default": null, + "description": "If provided, the names of columns that associate groups of examples together with their own set of time_idx." 
+ }, + "static_categoricals": { + "type": "ndarray", + "default": null + }, + "static_reals": { + "type": "ndarray", + "default": null + }, + "time_varying_unknown_categoricals": { + "type": "ndarray", + "default": null + }, + "time_varying_known_categoricals": { + "type": "ndarray", + "default": null + }, + "time_varying_unknown_reals": { + "type": "ndarray", + "default": null + }, + "time_varying_known_reals": { + "type": "ndarray", + "default": null + }, + "allow_missing_timesteps": { + "type": "bool", + "default": true + }, + "min_encoder_length": { + "type": "int", + "default": 1 + }, + "min_prediction_length": { + "type": "int", + "default": 1 + }, + "max_encoder_length": { + "type": "int", + "default": 16 + }, + "max_prediction_length": { + "type": "int", + "default": 8 + } + } + } +} \ No newline at end of file diff --git a/mlstars/primitives/pytorch_forecasting.models.temporal_fusion_transformer.TemporalFusionTransformer.json b/mlstars/primitives/pytorch_forecasting.models.temporal_fusion_transformer.TemporalFusionTransformer.json new file mode 100644 index 0000000..f872014 --- /dev/null +++ b/mlstars/primitives/pytorch_forecasting.models.temporal_fusion_transformer.TemporalFusionTransformer.json @@ -0,0 +1,109 @@ +{ + "name": "TemporalFusionTransformer", + "contributors": [ + "Brendon Reperttang " + ], + "documentation": "https://pytorch-forecasting.readthedocs.io/en/stable/api/pytorch_forecasting.models.temporal_fusion_transformer.TemporalFusionTransformer.html", + "description": "Implementation of the article Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting. 
The network outperforms DeepAR by Amazon by 36-69% in benchmarks.", + "classifiers": { + "type": "estimator", + "subtype": "regressor" + }, + "modalities": [], + "primitive": "pytorch_forecasting.models.temporal_fusion_transformer.TemporalFusionTransformer", + "produce": { + "method": "from_dataset", + "args": [ + { + "name": "dataset", + "type": "TimeSeriesDataSet" + } + ], + "output": [ + { + "name": "model", + "type": "TemporalFusionTransformer" + } + ] + }, + "hyperparameters": { + "fixed": { + "hidden_size": { + "type": "int", + "default": 256, + "range": [ + 8, + 512 + ] + }, + "hidden_continuous_size": { + "type": "int", + "default": 256, + "range": [ + 8, + 512 + ] + }, + "lstm_layers": { + "type": "int", + "default": 2 + }, + "output_size": { + "type": "int", + "default": 1 + }, + "attention_head_size": { + "type": "int", + "default": 4 + }, + "max_encoder_length": { + "type": "int", + "default": 32 + }, + "allowed_encoder_known_variable_names": { + "type": "ndarray", + "default": null + }, + "log_val_interval": { + "type": "int", + "default": 1 + }, + "reduce_on_plateau_patience": { + "type": "int", + "default": 4 + }, + "share_single_variable_networks": { + "type": "bool", + "default": false + }, + "causal_attention": { + "type": "bool", + "default": true + }, + "optimizer": { + "type": "str", + "default": "Adam" + }, + "loss": { + "type": "function" + }, + "tunable": { + "dropout": { + "type": "float", + "default": 0.25, + "range": [ + 0, + 1 + ] + }, + "learning_rate": { + "type": "float", + "default": 0.01, + "range": [ + 0, + 1 + ] + } + } + } +}