[🐞] Problem reproducing STID on PEMS data #311

@klayc-gzl

Is there an existing issue / discussion for this?

  • I have searched the existing issues / discussions

Is there an existing answer for this in the tutorial?

  • I have searched the tutorial

Current Behavior

STID trained with the config below performs quite differently on PEMS08 from the results reported in the paper. Is this a problem with the config?
import sys
import os

# Make the local `src` directory importable
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))

from torch.optim import Adam
from torch.optim.lr_scheduler import MultiStepLR
from basicts.configs import BasicTSForecastingConfig
from basicts.models import STID
from basicts.data import BasicTSForecastingDataset


def get_config():
    # Create model config for STID
    model_config = STID.STIDConfig(
        input_len=12,
        output_len=12,
        num_features=170,  # Number of nodes in PEMS08
        input_hidden_size=32,
        intermediate_size=None,
        hidden_act="relu",
        num_layers=1,
        if_spatial=True,
        spatial_hidden_size=32,
        if_time_in_day=False,
        if_day_in_week=False,
        num_time_in_day=24,
        num_day_in_week=7,
        tid_hidden_size=32,
        diw_hidden_size=32
    )

    # Create and return the config using the proper constructor
    config = BasicTSForecastingConfig(
        model=STID.STID,
        model_config=model_config,
        dataset_name="PEMS08",
        dataset_type=BasicTSForecastingDataset,
        dataset_params={
            "dataset_name": "PEMS08",
            "input_len": 12,
            "output_len": 12,
            "use_timestamps": True,
            "memmap": False,
        },
        num_epochs=100,
        optimizer=Adam,
        optimizer_params={
            "lr": 0.001,
            "weight_decay": 0.0001
        },
        lr_scheduler=MultiStepLR,
        lr_scheduler_params={
            "milestones": [50, 80],
            "gamma": 0.1
        },
        batch_size=64,
        train_batch_size=64,
        val_batch_size=64,
        test_batch_size=64,
        metrics=["MAE", "MAPE", "RMSE"],
        target_metric="MAE",
        best_metric="min",
        # Global config parameters
        gpus="0",  # Specify GPU to use
        norm_each_channel=True,  # Normalize each channel separately
        null_val=0.0,  # Value to fill for missing values
        rescale=True,  # Whether to rescale data (Updated for spatial-temporal forecasting)
        scaler=None  # Use default scaler
    )

    return config


if __name__ == "__main__":
    config = get_config()
    print(config)
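
One thing that may be worth checking in the config above: both temporal switches (`if_time_in_day`, `if_day_in_week`) are off, and `num_time_in_day` is 24. The sketch below assumes that PEMS08 is sampled at 5-minute intervals, and uses a hypothetical helper `slots_per_day` that is not part of BasicTS, to compute how many time-of-day slots such a sampling interval implies. It only illustrates the arithmetic; whether these settings explain the gap to the paper is not established here.

```python
# Hypothetical helper (not a BasicTS API): number of time-of-day slots
# implied by a dataset's sampling interval.
MINUTES_PER_DAY = 24 * 60


def slots_per_day(sampling_minutes: int) -> int:
    """Return how many samples fall in one day for the given interval."""
    if sampling_minutes <= 0 or MINUTES_PER_DAY % sampling_minutes != 0:
        raise ValueError("sampling interval must evenly divide one day")
    return MINUTES_PER_DAY // sampling_minutes


# Assumption: PEMS08 is recorded every 5 minutes.
pems08_slots = slots_per_day(5)

# Values copied from the config in this issue.
configured = {
    "if_time_in_day": False,
    "if_day_in_week": False,
    "num_time_in_day": 24,
}

print("slots per day at 5-minute sampling:", pems08_slots)
print("configured num_time_in_day:", configured["num_time_in_day"])
print("temporal embeddings enabled:",
      configured["if_time_in_day"] or configured["if_day_in_week"])
```

If the two numbers differ, or the temporal embeddings are disabled, the run is not using the same temporal identity setup as a 5-minute, 288-slot configuration would.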

[Three images attached]

Expected Behavior

No response

Environment

- OS:
- DEVICE:
- NVIDIA Driver:
- CUDA:
- NVIDIA GPU Memory:
- PyTorch:

BasicTS logs

No response

Steps To Reproduce

No response

Anything else?

No response

Labels: bug, needs-triaged
