Description
Is there an existing issue / discussion for this?
- I have searched the existing issues / discussions
Is there an existing answer for this in tutorial?
- I have searched tutorial
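Current Behavior
STID trained with the config below on the PEMS08 dataset gives results that differ considerably from those reported in the paper. Is this caused by a problem in the config?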
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
from torch.optim import Adam
from torch.optim.lr_scheduler import MultiStepLR
from basicts.configs import BasicTSForecastingConfig
from basicts.models import STID
from basicts.data import BasicTSForecastingDataset
def get_config():
    # Create model config for STID
    model_config = STID.STIDConfig(
        input_len=12,
        output_len=12,
        num_features=170,  # Number of nodes in PEMS08
        input_hidden_size=32,
        intermediate_size=None,
        hidden_act="relu",
        num_layers=1,
        if_spatial=True,
        spatial_hidden_size=32,
        if_time_in_day=False,
        if_day_in_week=False,
        num_time_in_day=24,
        num_day_in_week=7,
        tid_hidden_size=32,
        diw_hidden_size=32
    )
    # Create and return the config using the proper constructor
    config = BasicTSForecastingConfig(
        model=STID.STID,
        model_config=model_config,
        dataset_name="PEMS08",
        dataset_type=BasicTSForecastingDataset,
        dataset_params={
            "dataset_name": "PEMS08",
            "input_len": 12,
            "output_len": 12,
            "use_timestamps": True,
            "memmap": False,
        },
        num_epochs=100,
        optimizer=Adam,
        optimizer_params={
            "lr": 0.001,
            "weight_decay": 0.0001
        },
        lr_scheduler=MultiStepLR,
        lr_scheduler_params={
            "milestones": [50, 80],
            "gamma": 0.1
        },
        batch_size=64,
        train_batch_size=64,
        val_batch_size=64,
        test_batch_size=64,
        metrics=["MAE", "MAPE", "RMSE"],
        target_metric="MAE",
        best_metric="min",
        # Global config parameters
        gpus="0",  # Specify GPU to use
        norm_each_channel=True,  # Normalize each channel separately
        null_val=0.0,  # Value to fill for missing values
        rescale=True,  # Whether to rescale data (Updated for spatial-temporal forecasting)
        scaler=None  # Use default scaler
    )
    return config
if __name__ == "__main__":
    config = get_config()
    print(config)
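A few settings in the config above may be worth a quick sanity check before comparing against the paper. The sketch below is only a standalone helper, not part of BasicTS: `slots_per_day` and `list_suspects` are hypothetical names, the `posted` values are copied from the `STIDConfig` call above, and the 5-minute sampling interval for PEMS08 is an assumption that should be verified against the dataset description. It also assumes that `if_time_in_day` and `if_day_in_week` switch the corresponding temporal embeddings on or off, as their names suggest.

```python
# Values copied from the STIDConfig call in this issue.
posted = {
    "if_time_in_day": False,
    "if_day_in_week": False,
    "num_time_in_day": 24,
}

# Assumption: PEMS08 records one sample every 5 minutes.
SAMPLING_INTERVAL_MINUTES = 5


def slots_per_day(interval_minutes):
    """Return how many time-of-day slots a fixed sampling interval yields."""
    minutes_per_day = 24 * 60
    if interval_minutes <= 0 or minutes_per_day % interval_minutes != 0:
        raise ValueError("interval must be positive and divide a day evenly")
    return minutes_per_day // interval_minutes


def list_suspects(cfg, interval_minutes):
    """Collect settings that look inconsistent with the sampling interval
    or that appear to turn off a temporal embedding (hypothetical helper)."""
    suspects = []
    expected = slots_per_day(interval_minutes)
    if cfg["num_time_in_day"] != expected:
        suspects.append(
            f"num_time_in_day={cfg['num_time_in_day']}, "
            f"but the sampling interval implies {expected}"
        )
    for flag in ("if_time_in_day", "if_day_in_week"):
        if not cfg[flag]:
            suspects.append(f"{flag}=False, temporal embedding presumably unused")
    return suspects


for line in list_suspects(posted, SAMPLING_INTERVAL_MINUTES):
    print(line)
```

If the flagged values turn out to be intentional, the remaining differences are more likely in the optimizer or scheduler settings, which this check does not cover.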
Expected Behavior
No response
Environment
- OS:
- DEVICE:
- NVIDIA Driver:
- CUDA:
- NVIDIA GPU Memory:
- PyTorch:
BasicTS logs
No response
Steps To Reproduce
No response
Anything else?
No response