
Commit 3c8d68f (1 parent: 83c9d69)

feat: 🎸 support time series classification task

34 files changed: 3,368 additions & 61 deletions

README.md

Lines changed: 3 additions & 2 deletions
```diff
@@ -123,8 +123,9 @@ MOIRAI (inference) | Unified Training of Universal Time Series Forecasting Trans
 
 | 📊Baseline | 📝Title | 📄Paper | 💻Code | 🏛Venue | 🎯Task |
 | :--------- | :------------------------------------------------------------------------------------------------------------------- | :--------------------------------------------------- | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | :---------- | :----- |
-| STDN | Spatiotemporal-aware Trend-Seasonality Decomposition Network for Traffic Flow Forecasting | [Link](https://ojs.aaai.org/index.php/AAAI/article/view/33247) | [Link](https://github.com/roarer008/STDN) | AAAI'25 | STF |
-| STPGNN | Spatio-Temporal Pivotal Graph Neural Networks for Traffic Flow Forecasting | [Link](https://ojs.aaai.org/index.php/AAAI/article/view/28707) | [Link](https://github.com/Kongwy5689/STPGNN?tab=readme-ov-file) | AAAI'24 | STF |
+| STDN | Spatiotemporal-aware Trend-Seasonality Decomposition Network for Traffic Flow Forecasting | [Link](https://ojs.aaai.org/index.php/AAAI/article/view/33247) | [Link](https://github.com/roarer008/STDN) | AAAI'25 | STF |
+| HimNet | Heterogeneity-Informed Meta-Parameter Learning for Spatiotemporal Time Series Forecasting | [Link](https://arxiv.org/abs/2405.10800) | [Link](https://github.com/XDZhelheim/HimNet) | SIGKDD'24 | STF |
+| STPGNN | Spatio-Temporal Pivotal Graph Neural Networks for Traffic Flow Forecasting | [Link](https://ojs.aaai.org/index.php/AAAI/article/view/28707) | [Link](https://github.com/Kongwy5689/STPGNN?tab=readme-ov-file) | AAAI'24 | STF |
 | BigST | Linear Complexity Spatio-Temporal Graph Neural Network for Traffic Forecasting on Large-Scale Road Networks | [Link](https://dl.acm.org/doi/10.14778/3641204.3641217) | [Link](https://github.com/usail-hkust/BigST?tab=readme-ov-file) | VLDB'24 | STF |
 | STDMAE | Spatio-Temporal-Decoupled Masked Pre-training for Traffic Forecasting | [Link](https://arxiv.org/abs/2312.00516) | [Link](https://github.com/Jimmy-7664/STD-MAE) | IJCAI'24 | STF |
 | STWave | When Spatio-Temporal Meet Wavelets: Disentangled Traffic Forecasting via Efficient Spectral Graph Attention Networks | [Link](https://ieeexplore.ieee.org/document/10184591) | [Link](https://github.com/LMissher/STWave) | ICDE'23 | STF |
```

README_CN.md

Lines changed: 3 additions & 2 deletions
```diff
@@ -125,8 +125,9 @@ MOIRAI (inference) | Unified Training of Universal Time Series Forecasting Trans
 
 | 📊Baseline | 📝Title | 📄Paper | 💻Code | 🏛Venue | 🎯Task |
 | :--------- | :------------------------------------------------------------------------------------------------------------------- | :--------------------------------------------------- | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | :---------- | :----- |
-| STDN | Spatiotemporal-aware Trend-Seasonality Decomposition Network for Traffic Flow Forecasting | [Link](https://ojs.aaai.org/index.php/AAAI/article/view/33247) | [Link](https://github.com/roarer008/STDN) | AAAI'25 | STF |
-| STPGNN | Spatio-Temporal Pivotal Graph Neural Networks for Traffic Flow Forecasting | [Link](https://ojs.aaai.org/index.php/AAAI/article/view/28707) | [Link](https://github.com/Kongwy5689/STPGNN?tab=readme-ov-file) | AAAI'24 | STF |
+| STDN | Spatiotemporal-aware Trend-Seasonality Decomposition Network for Traffic Flow Forecasting | [Link](https://ojs.aaai.org/index.php/AAAI/article/view/33247) | [Link](https://github.com/roarer008/STDN) | AAAI'25 | STF |
+| HimNet | Heterogeneity-Informed Meta-Parameter Learning for Spatiotemporal Time Series Forecasting | [Link](https://arxiv.org/abs/2405.10800) | [Link](https://github.com/XDZhelheim/HimNet) | SIGKDD'24 | STF |
+| STPGNN | Spatio-Temporal Pivotal Graph Neural Networks for Traffic Flow Forecasting | [Link](https://ojs.aaai.org/index.php/AAAI/article/view/28707) | [Link](https://github.com/Kongwy5689/STPGNN?tab=readme-ov-file) | AAAI'24 | STF |
 | BigST | Linear Complexity Spatio-Temporal Graph Neural Network for Traffic Forecasting on Large-Scale Road Networks | [Link](https://dl.acm.org/doi/10.14778/3641204.3641217) | [Link](https://github.com/usail-hkust/BigST?tab=readme-ov-file) | VLDB'24 | STF |
 | STDMAE | Spatio-Temporal-Decoupled Masked Pre-training for Traffic Forecasting | [Link](https://arxiv.org/abs/2312.00516) | [Link](https://github.com/Jimmy-7664/STD-MAE) | IJCAI'24 | STF |
 | STWave | When Spatio-Temporal Meet Wavelets: Disentangled Traffic Forecasting via Efficient Spectral Graph Attention Networks | [Link](https://ieeexplore.ieee.org/document/10184591) | [Link](https://github.com/LMissher/STWave) | ICDE'23 | STF |
```

baselines/HimNet/METR-LA.py

Lines changed: 167 additions & 0 deletions
```python
import os
import sys
import numpy as np
from easydict import EasyDict

sys.path.append(os.path.abspath(__file__ + '/../../..'))

from basicts.metrics import masked_mae, masked_mape, masked_rmse, masked_huber
from basicts.data import TimeSeriesForecastingDataset
from basicts.scaler import ZScoreScaler
from basicts.utils import get_regular_settings

from .arch import HimNet
from .runner import HimNetRunner

############################## Hot Parameters ##############################
# Dataset & Metrics configuration
DATA_NAME = 'METR-LA'  # Dataset name
regular_settings = get_regular_settings(DATA_NAME)
INPUT_LEN = 12  # Length of input sequence
OUTPUT_LEN = 12  # Length of output sequence
TRAIN_VAL_TEST_RATIO = regular_settings['TRAIN_VAL_TEST_RATIO']  # Train/Validation/Test split ratios
NORM_EACH_CHANNEL = regular_settings['NORM_EACH_CHANNEL']  # Whether to normalize each channel of the data
RESCALE = regular_settings['RESCALE']  # Whether to rescale the data
NULL_VAL = regular_settings['NULL_VAL']  # Null value in the data
# Model architecture and parameters
MODEL_ARCH = HimNet
HIMNET_CONFIG = {
    'lr': 0.001,
    'eps': 0.001,
    'weight_decay': 0.0005,
    'milestones': [30, 40],
    'clip_grad': 5,
    'batch_size': 16,
    'max_epochs': 200,
    'early_stop': 20
}

NUM_EPOCHS = HIMNET_CONFIG['max_epochs']

MODEL_PARAM = {
    'num_nodes': 207,
    'input_dim': 3,
    'output_dim': 1,
    'out_steps': 12,
    'hidden_dim': 64,
    'num_layers': 1,
    'cheb_k': 2,
    'ycov_dim': 2,
    'tod_embedding_dim': 8,
    'dow_embedding_dim': 8,
    'node_embedding_dim': 16,
    'st_embedding_dim': 16,
    'tf_decay_steps': 6000,
    'use_teacher_forcing': True
}

############################## General Configuration ##############################
CFG = EasyDict()
# General settings
CFG.DESCRIPTION = 'An Example Config'
CFG.GPU_NUM = 8  # Number of GPUs to use (0 for CPU mode)
# Runner
CFG.RUNNER = HimNetRunner


############################## Dataset Configuration ##############################
CFG.DATASET = EasyDict()
# Dataset settings
CFG.DATASET.NAME = DATA_NAME
CFG.DATASET.TYPE = TimeSeriesForecastingDataset
CFG.DATASET.PARAM = EasyDict({
    'dataset_name': DATA_NAME,
    'train_val_test_ratio': TRAIN_VAL_TEST_RATIO,
    'input_len': INPUT_LEN,
    'output_len': OUTPUT_LEN,
    # 'mode' is automatically set by the runner
})

############################## Scaler Configuration ##############################
CFG.SCALER = EasyDict()
# Scaler settings
CFG.SCALER.TYPE = ZScoreScaler  # Scaler class
CFG.SCALER.PARAM = EasyDict({
    'dataset_name': DATA_NAME,
    'train_ratio': TRAIN_VAL_TEST_RATIO[0],
    'norm_each_channel': NORM_EACH_CHANNEL,
    'rescale': RESCALE,
})

############################## Model Configuration ##############################
CFG.MODEL = EasyDict()
# Model settings
CFG.MODEL.NAME = MODEL_ARCH.__name__
CFG.MODEL.ARCH = MODEL_ARCH
CFG.MODEL.PARAM = MODEL_PARAM
CFG.MODEL.FORWARD_FEATURES = [0, 1, 2]
CFG.MODEL.TARGET_FEATURES = [0]
CFG.MODEL.SETUP_GRAPH = True

############################## Metrics Configuration ##############################

CFG.METRICS = EasyDict()
# Metrics settings
CFG.METRICS.FUNCS = EasyDict({
    'MAE': masked_mae,
    'MAPE': masked_mape,
    'RMSE': masked_rmse,
})
CFG.METRICS.TARGET = 'MAE'
CFG.METRICS.NULL_VAL = NULL_VAL

############################## Training Configuration ##############################
CFG.TRAIN = EasyDict()
CFG.TRAIN.NUM_EPOCHS = NUM_EPOCHS
CFG.TRAIN.CKPT_SAVE_DIR = os.path.join(
    'checkpoints',
    MODEL_ARCH.__name__,
    '_'.join([DATA_NAME, str(CFG.TRAIN.NUM_EPOCHS), str(INPUT_LEN), str(OUTPUT_LEN)])
)
CFG.TRAIN.LOSS = masked_mae
# Optimizer settings
CFG.TRAIN.OPTIM = EasyDict()
CFG.TRAIN.OPTIM.TYPE = "Adam"
CFG.TRAIN.OPTIM.PARAM = {
    "lr": max(HIMNET_CONFIG['lr'] * CFG.GPU_NUM, HIMNET_CONFIG['lr']),
    "eps": HIMNET_CONFIG['eps'],
    "weight_decay": HIMNET_CONFIG['weight_decay'],
}
# Learning rate scheduler settings
CFG.TRAIN.LR_SCHEDULER = EasyDict()
CFG.TRAIN.LR_SCHEDULER.TYPE = "MultiStepLR"
CFG.TRAIN.LR_SCHEDULER.PARAM = {
    'milestones': HIMNET_CONFIG['milestones'],
    'gamma': HIMNET_CONFIG.get('lr_decay_rate', 0.1),
    'verbose': False
}
# Train data loader settings
CFG.TRAIN.DATA = EasyDict()
CFG.TRAIN.DATA.BATCH_SIZE = HIMNET_CONFIG['batch_size']
CFG.TRAIN.DATA.SHUFFLE = True

CFG.TRAIN.EARLY_STOPPING_PATIENCE = HIMNET_CONFIG['early_stop']

CFG.TRAIN.CLIP_GRAD_PARAM = {
    'max_norm': HIMNET_CONFIG.get('clip_grad', 5)
}

############################## Validation Configuration ##############################
CFG.VAL = EasyDict()
CFG.VAL.INTERVAL = 1
CFG.VAL.DATA = EasyDict()
CFG.VAL.DATA.BATCH_SIZE = 64

############################## Test Configuration ##############################
CFG.TEST = EasyDict()
CFG.TEST.INTERVAL = 1
CFG.TEST.DATA = EasyDict()
CFG.TEST.DATA.BATCH_SIZE = 64

############################## Evaluation Configuration ##############################

CFG.EVAL = EasyDict()

# Evaluation parameters
CFG.EVAL.HORIZONS = [3, 6, 12]  # Prediction horizons for evaluation. Default: []
CFG.EVAL.USE_GPU = True  # Whether to use GPU for evaluation. Default: True
```
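The optimizer section of this config scales the base learning rate linearly with the GPU count, a common heuristic for data-parallel training where the effective batch size grows with the number of devices; the `max(...)` guard keeps the rate at its base value when `GPU_NUM` is 0 (CPU mode). A minimal standalone sketch of that rule (plain Python, independent of BasicTS):

```python
def scaled_lr(base_lr: float, gpu_num: int) -> float:
    """Linear learning-rate scaling: multiply the base rate by the GPU
    count, but never drop below the base rate (covers gpu_num == 0)."""
    return max(base_lr * gpu_num, base_lr)

# With this config's values (base lr 0.001, GPU_NUM = 8):
print(scaled_lr(0.001, 8))  # 0.008
print(scaled_lr(0.001, 0))  # 0.001 -> CPU mode keeps the base rate
```

Whether linear scaling is appropriate at large GPU counts depends on the model and batch size; the config simply applies it unconditionally.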

baselines/HimNet/PEMS-BAY.py

Lines changed: 167 additions & 0 deletions
```python
import os
import sys
import numpy as np
from easydict import EasyDict

sys.path.append(os.path.abspath(__file__ + '/../../..'))

from basicts.metrics import masked_mae, masked_mape, masked_rmse, masked_huber
from basicts.data import TimeSeriesForecastingDataset
from basicts.scaler import ZScoreScaler
from basicts.utils import get_regular_settings

from .arch import HimNet
from .runner import HimNetRunner

############################## Hot Parameters ##############################
# Dataset & Metrics configuration
DATA_NAME = 'PEMS-BAY'  # Dataset name
regular_settings = get_regular_settings(DATA_NAME)
INPUT_LEN = 12  # Length of input sequence
OUTPUT_LEN = 12  # Length of output sequence
TRAIN_VAL_TEST_RATIO = regular_settings['TRAIN_VAL_TEST_RATIO']  # Train/Validation/Test split ratios
NORM_EACH_CHANNEL = regular_settings['NORM_EACH_CHANNEL']  # Whether to normalize each channel of the data
RESCALE = regular_settings['RESCALE']  # Whether to rescale the data
NULL_VAL = regular_settings['NULL_VAL']  # Null value in the data
# Model architecture and parameters
MODEL_ARCH = HimNet
HIMNET_CONFIG = {
    'lr': 0.001,
    'eps': 0.001,
    'weight_decay': 0.0001,
    'milestones': [25, 35],
    'clip_grad': 5,
    'batch_size': 16,
    'max_epochs': 200,
    'early_stop': 20
}

NUM_EPOCHS = HIMNET_CONFIG['max_epochs']

MODEL_PARAM = {
    'num_nodes': 325,
    'input_dim': 3,
    'output_dim': 1,
    'tod_embedding_dim': 8,
    'dow_embedding_dim': 8,
    'out_steps': 12,
    'hidden_dim': 64,
    'num_layers': 1,
    'cheb_k': 2,
    'ycov_dim': 2,
    'node_embedding_dim': 16,
    'st_embedding_dim': 16,
    'tf_decay_steps': 6000,
    'use_teacher_forcing': True
}

############################## General Configuration ##############################
CFG = EasyDict()
# General settings
CFG.DESCRIPTION = 'An Example Config'
CFG.GPU_NUM = 8  # Number of GPUs to use (0 for CPU mode)
# Runner
CFG.RUNNER = HimNetRunner


############################## Dataset Configuration ##############################
CFG.DATASET = EasyDict()
# Dataset settings
CFG.DATASET.NAME = DATA_NAME
CFG.DATASET.TYPE = TimeSeriesForecastingDataset
CFG.DATASET.PARAM = EasyDict({
    'dataset_name': DATA_NAME,
    'train_val_test_ratio': TRAIN_VAL_TEST_RATIO,
    'input_len': INPUT_LEN,
    'output_len': OUTPUT_LEN,
    # 'mode' is automatically set by the runner
})

############################## Scaler Configuration ##############################
CFG.SCALER = EasyDict()
# Scaler settings
CFG.SCALER.TYPE = ZScoreScaler  # Scaler class
CFG.SCALER.PARAM = EasyDict({
    'dataset_name': DATA_NAME,
    'train_ratio': TRAIN_VAL_TEST_RATIO[0],
    'norm_each_channel': NORM_EACH_CHANNEL,
    'rescale': RESCALE,
})

############################## Model Configuration ##############################
CFG.MODEL = EasyDict()
# Model settings
CFG.MODEL.NAME = MODEL_ARCH.__name__
CFG.MODEL.ARCH = MODEL_ARCH
CFG.MODEL.PARAM = MODEL_PARAM
CFG.MODEL.FORWARD_FEATURES = [0, 1, 2]
CFG.MODEL.TARGET_FEATURES = [0]
CFG.MODEL.SETUP_GRAPH = True

############################## Metrics Configuration ##############################

CFG.METRICS = EasyDict()
# Metrics settings
CFG.METRICS.FUNCS = EasyDict({
    'MAE': masked_mae,
    'MAPE': masked_mape,
    'RMSE': masked_rmse,
})
CFG.METRICS.TARGET = 'MAE'
CFG.METRICS.NULL_VAL = NULL_VAL

############################## Training Configuration ##############################
CFG.TRAIN = EasyDict()
CFG.TRAIN.NUM_EPOCHS = NUM_EPOCHS
CFG.TRAIN.CKPT_SAVE_DIR = os.path.join(
    'checkpoints',
    MODEL_ARCH.__name__,
    '_'.join([DATA_NAME, str(CFG.TRAIN.NUM_EPOCHS), str(INPUT_LEN), str(OUTPUT_LEN)])
)
CFG.TRAIN.LOSS = masked_mae
# Optimizer settings
CFG.TRAIN.OPTIM = EasyDict()
CFG.TRAIN.OPTIM.TYPE = "Adam"
CFG.TRAIN.OPTIM.PARAM = {
    "lr": max(HIMNET_CONFIG['lr'] * CFG.GPU_NUM, HIMNET_CONFIG['lr']),
    "eps": HIMNET_CONFIG['eps'],
    "weight_decay": HIMNET_CONFIG['weight_decay'],
}
# Learning rate scheduler settings
CFG.TRAIN.LR_SCHEDULER = EasyDict()
CFG.TRAIN.LR_SCHEDULER.TYPE = "MultiStepLR"
CFG.TRAIN.LR_SCHEDULER.PARAM = {
    'milestones': HIMNET_CONFIG['milestones'],
    'gamma': HIMNET_CONFIG.get('lr_decay_rate', 0.1),
    'verbose': False
}
# Train data loader settings
CFG.TRAIN.DATA = EasyDict()
CFG.TRAIN.DATA.BATCH_SIZE = HIMNET_CONFIG['batch_size']
CFG.TRAIN.DATA.SHUFFLE = True

CFG.TRAIN.EARLY_STOPPING_PATIENCE = HIMNET_CONFIG['early_stop']

CFG.TRAIN.CLIP_GRAD_PARAM = {
    'max_norm': HIMNET_CONFIG.get('clip_grad', 5)
}

############################## Validation Configuration ##############################
CFG.VAL = EasyDict()
CFG.VAL.INTERVAL = 1
CFG.VAL.DATA = EasyDict()
CFG.VAL.DATA.BATCH_SIZE = 64

############################## Test Configuration ##############################
CFG.TEST = EasyDict()
CFG.TEST.INTERVAL = 1
CFG.TEST.DATA = EasyDict()
CFG.TEST.DATA.BATCH_SIZE = 64

############################## Evaluation Configuration ##############################

CFG.EVAL = EasyDict()

# Evaluation parameters
CFG.EVAL.HORIZONS = [3, 6, 12]  # Prediction horizons for evaluation. Default: []
CFG.EVAL.USE_GPU = True  # Whether to use GPU for evaluation. Default: True
```
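Both configs use a `MultiStepLR` schedule: the learning rate is multiplied by `gamma` (0.1 here) at each milestone epoch. As an illustration of the schedule these parameters produce, here is a plain reimplementation of the rule (not the PyTorch class itself):

```python
def lr_at_epoch(base_lr, epoch, milestones, gamma=0.1):
    """Learning rate in effect at `epoch` under a MultiStepLR-style
    schedule: multiply by `gamma` once for each milestone already passed."""
    decays = sum(1 for m in milestones if epoch >= m)
    return base_lr * gamma ** decays

# PEMS-BAY config: milestones [25, 35], gamma 0.1, base lr 0.001
# (before the GPU_NUM scaling applied in CFG.TRAIN.OPTIM.PARAM).
for epoch in (0, 25, 35):
    print(epoch, lr_at_epoch(0.001, epoch, [25, 35]))
```

So with these milestones the rate drops to 1e-4 at epoch 25 and 1e-5 at epoch 35, well before the 200-epoch budget; in practice the 20-epoch early-stopping patience will usually end training first.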
