Skip to content

Commit aba908d

Browse files
committed
'add-s4'
1 parent 962435e commit aba908d

7 files changed

Lines changed: 652 additions & 0 deletions

File tree

baselines/S4/ETTm2.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
import os
2+
import sys
3+
from easydict import EasyDict
4+
sys.path.append(os.path.abspath(__file__ + '/../../..'))
5+
6+
from basicts.metrics import masked_mae, masked_mse
7+
from basicts.data import TimeSeriesForecastingDataset
8+
from basicts.runners import SimpleTimeSeriesForecastingRunner
9+
from basicts.scaler import ZScoreScaler
10+
from basicts.utils import get_regular_settings
11+
12+
from .arch import S4
13+
14+
############################## Hot Parameters ##############################
15+
# Dataset & Metrics configuration
16+
DATA_NAME = 'ETTm2' # Dataset name
17+
regular_settings = get_regular_settings(DATA_NAME)
18+
INPUT_LEN = regular_settings['INPUT_LEN'] # Length of input sequence
19+
OUTPUT_LEN = regular_settings['OUTPUT_LEN'] # Length of output sequence
20+
TRAIN_VAL_TEST_RATIO = regular_settings['TRAIN_VAL_TEST_RATIO'] # Train/Validation/Test split ratios
21+
NORM_EACH_CHANNEL = regular_settings['NORM_EACH_CHANNEL'] # Whether to normalize each channel of the data
22+
RESCALE = regular_settings['RESCALE'] # Whether to rescale the data
23+
NULL_VAL = regular_settings['NULL_VAL'] # Null value in the data
24+
# Model architecture and parameters
25+
MODEL_ARCH = S4
26+
MODEL_PARAM = {
27+
"seq_len": INPUT_LEN,
28+
"pred_len": OUTPUT_LEN,
29+
"individual": False,
30+
"d_input": 7,
31+
"d_output": OUTPUT_LEN,
32+
"prenorm": True,
33+
"d_model": 128,
34+
"n_layers": 4,
35+
"dropout": 0.1,
36+
"lr": 0.01
37+
}
38+
39+
NUM_EPOCHS = 50
40+
41+
############################## General Configuration ##############################
42+
CFG = EasyDict()
43+
# General settings
44+
CFG.DESCRIPTION = 'An Example Config'
45+
CFG.GPU_NUM = 1 # Number of GPUs to use (0 for CPU mode)
46+
# Runner
47+
CFG.RUNNER = SimpleTimeSeriesForecastingRunner
48+
49+
############################## Dataset Configuration ##############################
50+
CFG.DATASET = EasyDict()
51+
# Dataset settings
52+
CFG.DATASET.NAME = DATA_NAME
53+
CFG.DATASET.TYPE = TimeSeriesForecastingDataset
54+
CFG.DATASET.PARAM = EasyDict({
55+
'dataset_name': DATA_NAME,
56+
'train_val_test_ratio': TRAIN_VAL_TEST_RATIO,
57+
'input_len': INPUT_LEN,
58+
'output_len': OUTPUT_LEN,
59+
# 'mode' is automatically set by the runner
60+
})
61+
62+
############################## Scaler Configuration ##############################
63+
CFG.SCALER = EasyDict()
64+
# Scaler settings
65+
CFG.SCALER.TYPE = ZScoreScaler # Scaler class
66+
CFG.SCALER.PARAM = EasyDict({
67+
'dataset_name': DATA_NAME,
68+
'train_ratio': TRAIN_VAL_TEST_RATIO[0],
69+
'norm_each_channel': NORM_EACH_CHANNEL,
70+
'rescale': RESCALE,
71+
})
72+
73+
############################## Model Configuration ##############################
74+
CFG.MODEL = EasyDict()
75+
# Model settings
76+
CFG.MODEL.NAME = MODEL_ARCH.__name__
77+
CFG.MODEL.ARCH = MODEL_ARCH
78+
CFG.MODEL.PARAM = MODEL_PARAM
79+
CFG.MODEL.FORWARD_FEATURES = [0]
80+
CFG.MODEL.TARGET_FEATURES = [0]
81+
82+
############################## Metrics Configuration ##############################
83+
84+
CFG.METRICS = EasyDict()
85+
# Metrics settings
86+
CFG.METRICS.FUNCS = EasyDict({
87+
'MAE': masked_mae,
88+
'MSE': masked_mse,
89+
})
90+
CFG.METRICS.TARGET = 'MSE'
91+
CFG.METRICS.NULL_VAL = NULL_VAL
92+
93+
############################## Training Configuration ##############################
94+
CFG.TRAIN = EasyDict()
95+
CFG.TRAIN.NUM_EPOCHS = NUM_EPOCHS
96+
CFG.TRAIN.CKPT_SAVE_DIR = os.path.join(
97+
'checkpoints',
98+
MODEL_ARCH.__name__,
99+
'_'.join([DATA_NAME, str(CFG.TRAIN.NUM_EPOCHS), str(INPUT_LEN), str(OUTPUT_LEN)])
100+
)
101+
CFG.TRAIN.LOSS = masked_mse
102+
# Optimizer settings
103+
CFG.TRAIN.OPTIM = EasyDict()
104+
CFG.TRAIN.OPTIM.TYPE = "Adam"
105+
CFG.TRAIN.OPTIM.PARAM = {
106+
"lr": 0.001
107+
}
108+
# Learning rate scheduler settings
109+
CFG.TRAIN.LR_SCHEDULER = EasyDict()
110+
CFG.TRAIN.LR_SCHEDULER.TYPE = "MultiStepLR"
111+
CFG.TRAIN.LR_SCHEDULER.PARAM = {
112+
"milestones": [1, 25]
113+
}
114+
CFG.TRAIN.CLIP_GRAD_PARAM = {
115+
'max_norm': 5.0
116+
}
117+
# Train data loader settings
118+
CFG.TRAIN.DATA = EasyDict()
119+
CFG.TRAIN.DATA.BATCH_SIZE = 64
120+
CFG.TRAIN.DATA.SHUFFLE = True
121+
122+
############################## Validation Configuration ##############################
123+
CFG.VAL = EasyDict()
124+
CFG.VAL.INTERVAL = 1
125+
CFG.VAL.DATA = EasyDict()
126+
CFG.VAL.DATA.BATCH_SIZE = 64
127+
128+
############################## Test Configuration ##############################
129+
CFG.TEST = EasyDict()
130+
CFG.TEST.INTERVAL = 1
131+
CFG.TEST.DATA = EasyDict()
132+
CFG.TEST.DATA.BATCH_SIZE = 64
133+
134+
############################## Evaluation Configuration ##############################
135+
136+
CFG.EVAL = EasyDict()
137+
138+
# Evaluation parameters
139+
CFG.EVAL.HORIZONS = [12, 24, 48, 96]
140+
CFG.EVAL.USE_GPU = True # Whether to use GPU for evaluation. Default: True

baselines/S4/Electricity.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
import os
2+
import sys
3+
from easydict import EasyDict
4+
sys.path.append(os.path.abspath(__file__ + '/../../..'))
5+
6+
from basicts.metrics import masked_mae, masked_mse
7+
from basicts.data import TimeSeriesForecastingDataset
8+
from basicts.runners import SimpleTimeSeriesForecastingRunner
9+
from basicts.scaler import ZScoreScaler
10+
from basicts.utils import get_regular_settings
11+
12+
from .arch import S4
13+
14+
############################## Hot Parameters ##############################
15+
# Dataset & Metrics configuration
16+
DATA_NAME = 'Electricity' # Dataset name
17+
regular_settings = get_regular_settings(DATA_NAME)
18+
INPUT_LEN = regular_settings['INPUT_LEN'] # Length of input sequence
19+
OUTPUT_LEN = regular_settings['OUTPUT_LEN'] # Length of output sequence
20+
TRAIN_VAL_TEST_RATIO = regular_settings['TRAIN_VAL_TEST_RATIO'] # Train/Validation/Test split ratios
21+
NORM_EACH_CHANNEL = regular_settings['NORM_EACH_CHANNEL'] # Whether to normalize each channel of the data
22+
RESCALE = regular_settings['RESCALE'] # Whether to rescale the data
23+
NULL_VAL = regular_settings['NULL_VAL'] # Null value in the data
24+
# Model architecture and parameters
25+
MODEL_ARCH = S4
26+
MODEL_PARAM = {
27+
"seq_len": INPUT_LEN,
28+
"pred_len": OUTPUT_LEN,
29+
"individual": False,
30+
"d_input": 321,
31+
"d_output": OUTPUT_LEN,
32+
"prenorm": True,
33+
"d_model": 128,
34+
"n_layers": 4,
35+
"dropout": 0.1,
36+
"lr": 0.01
37+
}
38+
39+
NUM_EPOCHS = 50
40+
41+
############################## General Configuration ##############################
42+
CFG = EasyDict()
43+
# General settings
44+
CFG.DESCRIPTION = 'An Example Config'
45+
CFG.GPU_NUM = 1 # Number of GPUs to use (0 for CPU mode)
46+
# Runner
47+
CFG.RUNNER = SimpleTimeSeriesForecastingRunner
48+
49+
############################## Dataset Configuration ##############################
50+
CFG.DATASET = EasyDict()
51+
# Dataset settings
52+
CFG.DATASET.NAME = DATA_NAME
53+
CFG.DATASET.TYPE = TimeSeriesForecastingDataset
54+
CFG.DATASET.PARAM = EasyDict({
55+
'dataset_name': DATA_NAME,
56+
'train_val_test_ratio': TRAIN_VAL_TEST_RATIO,
57+
'input_len': INPUT_LEN,
58+
'output_len': OUTPUT_LEN,
59+
# 'mode' is automatically set by the runner
60+
})
61+
62+
############################## Scaler Configuration ##############################
63+
CFG.SCALER = EasyDict()
64+
# Scaler settings
65+
CFG.SCALER.TYPE = ZScoreScaler # Scaler class
66+
CFG.SCALER.PARAM = EasyDict({
67+
'dataset_name': DATA_NAME,
68+
'train_ratio': TRAIN_VAL_TEST_RATIO[0],
69+
'norm_each_channel': NORM_EACH_CHANNEL,
70+
'rescale': RESCALE,
71+
})
72+
73+
############################## Model Configuration ##############################
74+
CFG.MODEL = EasyDict()
75+
# Model settings
76+
CFG.MODEL.NAME = MODEL_ARCH.__name__
77+
CFG.MODEL.ARCH = MODEL_ARCH
78+
CFG.MODEL.PARAM = MODEL_PARAM
79+
CFG.MODEL.FORWARD_FEATURES = [0]
80+
CFG.MODEL.TARGET_FEATURES = [0]
81+
82+
############################## Metrics Configuration ##############################
83+
84+
CFG.METRICS = EasyDict()
85+
# Metrics settings
86+
CFG.METRICS.FUNCS = EasyDict({
87+
'MAE': masked_mae,
88+
'MSE': masked_mse,
89+
})
90+
CFG.METRICS.TARGET = 'MSE'
91+
CFG.METRICS.NULL_VAL = NULL_VAL
92+
93+
############################## Training Configuration ##############################
94+
CFG.TRAIN = EasyDict()
95+
CFG.TRAIN.NUM_EPOCHS = NUM_EPOCHS
96+
CFG.TRAIN.CKPT_SAVE_DIR = os.path.join(
97+
'checkpoints',
98+
MODEL_ARCH.__name__,
99+
'_'.join([DATA_NAME, str(CFG.TRAIN.NUM_EPOCHS), str(INPUT_LEN), str(OUTPUT_LEN)])
100+
)
101+
CFG.TRAIN.LOSS = masked_mse
102+
# Optimizer settings
103+
CFG.TRAIN.OPTIM = EasyDict()
104+
CFG.TRAIN.OPTIM.TYPE = "Adam"
105+
CFG.TRAIN.OPTIM.PARAM = {
106+
"lr": 0.001
107+
}
108+
# Learning rate scheduler settings
109+
CFG.TRAIN.LR_SCHEDULER = EasyDict()
110+
CFG.TRAIN.LR_SCHEDULER.TYPE = "MultiStepLR"
111+
CFG.TRAIN.LR_SCHEDULER.PARAM = {
112+
"milestones": [1, 25]
113+
}
114+
CFG.TRAIN.CLIP_GRAD_PARAM = {
115+
'max_norm': 5.0
116+
}
117+
# Train data loader settings
118+
CFG.TRAIN.DATA = EasyDict()
119+
CFG.TRAIN.DATA.BATCH_SIZE = 64
120+
CFG.TRAIN.DATA.SHUFFLE = True
121+
122+
############################## Validation Configuration ##############################
123+
CFG.VAL = EasyDict()
124+
CFG.VAL.INTERVAL = 1
125+
CFG.VAL.DATA = EasyDict()
126+
CFG.VAL.DATA.BATCH_SIZE = 64
127+
128+
############################## Test Configuration ##############################
129+
CFG.TEST = EasyDict()
130+
CFG.TEST.INTERVAL = 1
131+
CFG.TEST.DATA = EasyDict()
132+
CFG.TEST.DATA.BATCH_SIZE = 64
133+
134+
############################## Evaluation Configuration ##############################
135+
136+
CFG.EVAL = EasyDict()
137+
138+
# Evaluation parameters
139+
CFG.EVAL.HORIZONS = [12, 24, 48, 96]
140+
CFG.EVAL.USE_GPU = True # Whether to use GPU for evaluation. Default: True

0 commit comments

Comments
 (0)