[official] PyTorch implementation of TimeVQVAE from the paper ["Vector Quantized Time Series Generation with a Bidirectional Prior Model", AISTATS 2023]

ML4ITS/TimeVQVAE

TimeVQVAE

This is the official GitHub repository for the PyTorch implementation of TimeVQVAE from our paper "Vector Quantized Time Series Generation with a Bidirectional Prior Model", AISTATS 2023.

TimeVQVAE is a robust time series generation model that uses vector quantization to compress data into a discrete latent space (stage 1) and a bidirectional transformer to learn the prior (stage 2).
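The core of stage 1's vector quantization is a nearest-neighbor lookup: each continuous encoder output is replaced by its closest entry in a learned codebook, and that entry's index becomes a discrete token for the stage-2 prior. The dependency-free sketch below shows only that lookup, with a toy codebook and toy latents as assumptions; the actual model works on PyTorch tensors, and the straight-through gradient and codebook losses used during training are omitted here.

```python
from typing import List, Sequence, Tuple

Vector = Sequence[float]


def squared_distance(a: Vector, b: Vector) -> float:
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def quantize(
    latents: Sequence[Vector], codebook: Sequence[Vector]
) -> Tuple[List[int], List[List[float]]]:
    """Map each continuous latent vector to its nearest codebook entry.

    Returns the discrete token indices and the quantized vectors.
    """
    indices: List[int] = []
    quantized: List[List[float]] = []
    for z in latents:
        # argmin over codebook entries e_k of ||z - e_k||^2
        k = min(range(len(codebook)), key=lambda j: squared_distance(z, codebook[j]))
        indices.append(k)
        quantized.append(list(codebook[k]))
    return indices, quantized


# Toy example: a 3-entry codebook of 2-D vectors and four encoder outputs.
codebook = [[0.0, 0.0], [1.0, 1.0], [-1.0, 2.0]]
latents = [[0.1, -0.2], [0.9, 1.2], [-0.8, 1.7], [0.7, 0.4]]
tokens, z_q = quantize(latents, codebook)
print(tokens)  # discrete token sequence passed on to the prior
```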

Notes

The implementation has been modified for better performance and lower memory consumption. Therefore, the resulting evaluation metrics probably differ somewhat from the scores reported in the paper. We made these changes so that the code is more useful to the community in practice. For details, see the Update Notes section below.

Install / Environment setup

First, create and activate a virtual environment. Then install the necessary libraries by running the following command:

pip install -r requirements.txt

Dataset and Dataset Download

The UCR archive datasets are downloaded automatically when you run any of the training commands below, such as python stage1.py. To download the datasets without running training, run

python preprocessing/preprocess_ucr.py

[update note on July 8, 2024] We now use a larger training set, obtained by reorganizing the original UCR archive datasets as follows: 1) the existing training and test sets are merged, and 2) the merged set is resplit into an 80% training set and a 20% test set using StratifiedShuffleSplit (from sklearn). We did so because the original datasets have two primary issues for training a time series generative model. First, a majority of the datasets have a larger test set than training set. Second, for some of the datasets, the patterns clearly differ between the training and test sets. The data format remains the same.
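The merge-then-resplit procedure can be sketched without dependencies as follows. This is an illustrative stand-in for sklearn's StratifiedShuffleSplit, not a reproduction of it: the function name `stratified_resplit`, the per-class rounding rule, and the toy labels are assumptions, so exact split sizes may differ slightly from sklearn's allocation.

```python
import random
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple


def stratified_resplit(
    labels: Sequence[int], test_ratio: float = 0.2, seed: int = 0
) -> Tuple[List[int], List[int]]:
    """Return (train_indices, test_indices), drawing ~test_ratio of each class into the test set.

    `labels` are the labels of the merged (original train + original test) dataset.
    """
    rng = random.Random(seed)
    by_class: Dict[int, List[int]] = defaultdict(list)
    for idx, y in enumerate(labels):
        by_class[y].append(idx)

    train_idx: List[int] = []
    test_idx: List[int] = []
    for y in sorted(by_class):
        idxs = by_class[y][:]
        rng.shuffle(idxs)  # random draw within each class keeps class proportions
        n_test = round(len(idxs) * test_ratio)
        test_idx.extend(idxs[:n_test])
        train_idx.extend(idxs[n_test:])
    return sorted(train_idx), sorted(test_idx)


# Step 1: merge the original splits (toy labels; note the larger original test set).
original_train_labels = [0] * 10 + [1] * 5
original_test_labels = [0] * 30 + [1] * 15
merged_labels = original_train_labels + original_test_labels

# Step 2: resplit 80/20 while preserving class proportions.
train_idx, test_idx = stratified_resplit(merged_labels, test_ratio=0.2)
print(len(train_idx), len(test_idx))
```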

Usage

Configuration

  • configs/config.yaml: configuration for dataset, data loading, optimizer, and models (i.e., encoder, decoder, vector-quantizer, and MaskGIT)
  • configs/config_cas.yaml: configuration for running CAS, the Classification Accuracy Score (= TSTR, Train on Synthetic, Test on Real).

Training: Stage1 and Stage2

python stage1.py --dataset_names Wafer --gpu_device_ind 0
python stage2.py --dataset_names Wafer --gpu_device_ind 0

The trained model is saved in saved_models/. The details of the logged metrics are documented in evaluation/README.md.

Evaluation

Computes FID and IS, and provides a visual inspection of $p(X)$ versus $p_\theta(\hat{X})$, along with the corresponding comparison in an evaluation latent space.
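For reference, the standard definitions of the two scores are given below. Here $(\mu_r, \Sigma_r)$ and $(\mu_g, \Sigma_g)$ denote the mean and covariance of the evaluation-space features of real samples from $p(X)$ and generated samples from $p_\theta(\hat{X})$, and $p(y \mid \hat{x})$ is a classifier's predicted label distribution. These are the textbook formulas; which feature extractor and classifier evaluate.py plugs in is determined by its own settings. Lower FID and higher IS are better.

```latex
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2
  + \operatorname{Tr}\!\left( \Sigma_r + \Sigma_g - 2\,(\Sigma_r \Sigma_g)^{1/2} \right)

\mathrm{IS} = \exp\!\Big( \mathbb{E}_{\hat{x} \sim p_\theta(\hat{X})}\,
  D_{\mathrm{KL}}\big( p(y \mid \hat{x}) \,\big\|\, p(y) \big) \Big),
  \qquad p(y) = \mathbb{E}_{\hat{x} \sim p_\theta(\hat{X})}\, p(y \mid \hat{x})
```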

python evaluate.py --dataset_names Wafer --gpu_device_idx 0

Run CAS (Classification Accuracy Score)

python run_CAS.py  --dataset_names Wafer --gpu_device_idx 0
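CAS follows the TSTR protocol: a classifier is trained only on synthetic, class-conditionally generated samples, and its accuracy is then measured on real test samples. The sketch below illustrates just that protocol; the nearest-centroid classifier, the function names, and the toy data are hypothetical stand-ins, and run_CAS.py uses its own classifier and configuration.

```python
from collections import defaultdict
from typing import Dict, List, Sequence

Series = Sequence[float]


def fit_nearest_centroid(X: Sequence[Series], y: Sequence[int]) -> Dict[int, List[float]]:
    """Stand-in classifier: one mean vector (centroid) per class."""
    sums: Dict[int, List[float]] = {}
    counts: Dict[int, int] = defaultdict(int)
    for x, label in zip(X, y):
        if label not in sums:
            sums[label] = [0.0] * len(x)
        sums[label] = [s + v for s, v in zip(sums[label], x)]
        counts[label] += 1
    return {label: [s / counts[label] for s in sums[label]] for label in sums}


def predict(centroids: Dict[int, List[float]], x: Series) -> int:
    """Assign x to the class whose centroid is closest (squared Euclidean distance)."""
    return min(centroids, key=lambda c: sum((a - b) ** 2 for a, b in zip(x, centroids[c])))


def classification_accuracy_score(X_syn, y_syn, X_real, y_real) -> float:
    """TSTR: train on synthetic samples only, report accuracy on real test samples."""
    centroids = fit_nearest_centroid(X_syn, y_syn)
    correct = sum(predict(centroids, x) == y for x, y in zip(X_real, y_real))
    return correct / len(y_real)


# Toy data: synthetic samples with their conditioning labels, and a real test set.
X_syn = [[0.0, 0.1], [0.2, 0.0], [1.0, 0.9], [0.9, 1.1]]
y_syn = [0, 0, 1, 1]
X_real = [[0.1, 0.0], [1.1, 1.0], [0.8, 0.2]]
y_real = [0, 1, 1]
score = classification_accuracy_score(X_syn, y_syn, X_real, y_real)
print(score)
```

A higher CAS indicates that the synthetic samples carry enough class-discriminative structure to train a classifier that generalizes to real data.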

Minimal Code for Sampling

Refer to simple_sampling.ipynb.

Train it on a Custom Dataset

  1. A template class, DatasetImporterCustom, is provided in preprocessing/preprocess_ucr.py.
    • You do not need to modify any code other than DatasetImporterCustom to train TimeVQVAE on your dataset.
  2. Write the data-loading code for your dataset in __init__ within DatasetImporterCustom.
  3. Run stage 1 and stage 2 training, followed by evaluation, with the following commands:
python stage1.py --use_custom_dataset True --dataset_names custom --gpu_device_ind 0
python stage2.py --use_custom_dataset True --dataset_names custom --gpu_device_ind 0
python evaluate.py --use_custom_dataset True --dataset_names custom --gpu_device_idx 0
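For step 2, the loading code might look roughly like the sketch below. It is entirely hypothetical: the class name DatasetImporterCustomSketch, the CSV layout (label first, then the series values), the random train/test split, and the attribute names X_train, Y_train, X_test, Y_test are assumptions. Match them, including the expected array type and shape, to what the DatasetImporterCustom template actually requires before adapting this.

```python
import csv
import io
import random
from typing import List


class DatasetImporterCustomSketch:
    """Hypothetical stand-in for the loading logic in DatasetImporterCustom.__init__.

    Assumes each CSV row is: label, x_1, ..., x_L (one univariate series per row),
    and that the importer should expose X_train, Y_train, X_test, Y_test.
    """

    def __init__(self, csv_text: str, train_ratio: float = 0.8, seed: int = 0):
        # Parse rows, skipping blank lines.
        rows = [row for row in csv.reader(io.StringIO(csv_text)) if row]
        random.Random(seed).shuffle(rows)

        labels = [int(r[0]) for r in rows]
        series = [[float(v) for v in r[1:]] for r in rows]

        # Simple random split into training and test sets.
        n_train = int(len(rows) * train_ratio)
        self.X_train: List[List[float]] = series[:n_train]
        self.Y_train: List[int] = labels[:n_train]
        self.X_test: List[List[float]] = series[n_train:]
        self.Y_test: List[int] = labels[n_train:]


csv_text = "0,0.1,0.2,0.3\n1,0.9,0.8,0.7\n0,0.0,0.1,0.2\n1,1.0,0.9,0.8\n0,0.2,0.3,0.4\n"
data = DatasetImporterCustomSketch(csv_text, train_ratio=0.8)
print(len(data.X_train), len(data.X_test))
```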

You can also sample synthetic time series with custom_dataset_sampling.ipynb.

Google Colab

(NB! Make sure to change your notebook's runtime setting to GPU.)

A Google Colab notebook is available for time series generation with the pretrained VQVAE. The usage is simple:

  1. User Settings: specify dataset_name and n_samples_to_generate.
  2. Sampling: Run the unconditional sampling and class-conditional sampling.

Update Notes

Implementation Modifications

  • [2024.07.26] Updated the encoder $E$ and decoder $D$ so that their hidden dimension sizes increase with depth; a cosine-annealing learning-rate scheduler with linear warmup is now used; the reconstruction loss is computed in the time domain only, while the discrete latent space is still modeled from the time-frequency domain as before.
  • [2024.07.23] The Snake activation function [6] is used instead of (Leaky)ReLU in the encoder and decoder. In my experiments, it generally improved the VQVAE's reconstruction capability and was especially beneficial for periodic time series such as those in the FordA dataset.
  • [2024.07.08] The reorganized datasets are used instead of the original datasets, as described above in the Dataset and Dataset Download section.
  • [2024.07.04] FID can be computed with ROCKET representations in evaluate.py by setting --feature_extractor_type rocket. We found that the representations from ROCKET [5] yield more robust PCA distribution plots and FID scores. This is because ROCKET representations are the least biased: ROCKET is not trained at all, unlike supervised methods such as the supervised FCN. This is now the default setting. It also enables FID computation on a custom dataset, which the supervised FCN cannot do.
  • [2024.07.02] A convolution-based upsampling layer (nearest-neighbor interpolation followed by convolutions) is used to lengthen the LF token embeddings to match the length of the HF embeddings; a linear layer was used previously. Strong dropout is applied to the LF and HF embeddings within forward_hf in bidirectional_transformer.py to make the sampling process robust. A smaller HF transformer is used due to an overfitting problem. An n_fft of 4 is used instead of 8.
  • [2024.07.01] The prior loss is computed only on the masked locations instead of on all tokens.
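Three of the modifications above are small enough to sketch directly: the cosine-annealing schedule with linear warmup (2024.07.26), the Snake activation (2024.07.23), and the masked-only prior loss (2024.07.01). The sketch uses plain Python scalars and lists for illustration. The function names, the exact warmup convention, and the use of cross-entropy as the prior loss are assumptions; the repository's PyTorch code may differ in details.

```python
import math
from typing import Sequence


def warmup_cosine_lr(step: int, base_lr: float, warmup_steps: int,
                     total_steps: int, min_lr: float = 0.0) -> float:
    """Linear warmup up to base_lr, then cosine annealing down to min_lr."""
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


def snake(x: float, a: float = 1.0) -> float:
    """Snake activation [6]: x + sin^2(a * x) / a, where a > 0 sets the frequency."""
    return x + math.sin(a * x) ** 2 / a


def masked_token_loss(logits: Sequence[Sequence[float]], targets: Sequence[int],
                      mask: Sequence[bool]) -> float:
    """Mean cross-entropy over masked positions only; unmasked tokens contribute nothing."""
    total, count = 0.0, 0
    for row, target, is_masked in zip(logits, targets, mask):
        if not is_masked:
            continue
        m = max(row)  # log-sum-exp with max subtraction for numerical stability
        log_z = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_z - row[target]
        count += 1
    if count == 0:
        raise ValueError("no masked positions")
    return total / count


print([round(warmup_cosine_lr(s, 1e-3, 10, 110), 6) for s in (0, 9, 60, 110)])
print(snake(0.5), masked_token_loss([[0.0, 0.0], [10.0, -10.0]], [0, 1], [True, False]))
```

Restricting the loss to masked positions matches how the prior is used at sampling time, where it only has to predict tokens that are hidden.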

Fidelity Enhancer for Vector Quantized Time Series Generator (FE-VQTSG) [3] (not published yet)

FE-VQTSG is a U-Net-based mapping model that makes a synthetic time series generated by a VQ-based TSG method more realistic while retaining its original context.

FE-VQTSG can be trained once stage 1 and stage 2 training are finished. To train it, run

python stage_fid_enhancer.py  --dataset_names Wafer --gpu_device_ind 0

During the evaluation, FE-VQTSG can be employed by setting --use_fidelity_enhancer True.

python evaluate.py --dataset_names Wafer --gpu_device_idx 0 --use_fidelity_enhancer True

TimeVQVAE for Anomaly Detection (TimeVQVAE-AD) [4]

TimeVQVAE learns a prior, and we can use the learned prior to measure the likelihood of a time series segment: a high likelihood indicates a normal state, while a low likelihood indicates an abnormal state (i.e., an anomaly). Based on this principle, we developed TimeVQVAE-AD. It not only achieves state-of-the-art anomaly detection accuracy on the UCR Anomaly archive, but also provides a high level of explainability, including counterfactual sampling (i.e., answering the question, "what would the time series look like if there were no anomaly?"). If you are interested in AD, please check out the paper. Its code is also open source.
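The likelihood principle can be illustrated with a generic score: given the probability the learned prior assigns to each observed token in a segment, average the negative log-likelihoods and flag segments whose score exceeds a threshold. This is a simplified sketch of likelihood-based scoring in general, not TimeVQVAE-AD's actual procedure (see [4] for that); the function names, toy probabilities, and threshold are hypothetical.

```python
import math
from typing import List, Sequence


def segment_anomaly_score(token_likelihoods: Sequence[float], eps: float = 1e-12) -> float:
    """Mean negative log-likelihood of a segment's tokens under the prior.

    A low likelihood (high score) suggests an abnormal segment.
    """
    return -sum(math.log(max(p, eps)) for p in token_likelihoods) / len(token_likelihoods)


def flag_anomalies(segments: Sequence[Sequence[float]], threshold: float) -> List[bool]:
    """Flag each segment whose anomaly score exceeds the threshold."""
    return [segment_anomaly_score(s) > threshold for s in segments]


# Hypothetical per-token probabilities assigned by a learned prior.
segments = [
    [0.9, 0.8, 0.85, 0.9],   # tokens the prior finds likely
    [0.9, 0.01, 0.02, 0.8],  # several unlikely tokens
]
flags = flag_anomalies(segments, threshold=1.0)
print(flags)
```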

Citation

[1] Lee, Daesoo, Sara Malacarne, and Erlend Aune. "Vector Quantized Time Series Generation with a Bidirectional Prior Model." International Conference on Artificial Intelligence and Statistics. PMLR, 2023.

[3]

[4] Lee, Daesoo, Sara Malacarne, and Erlend Aune. "Explainable time series anomaly detection using masked latent generative modeling." Pattern Recognition (2024): 110826.

[5] Dempster, Angus, François Petitjean, and Geoffrey I. Webb. "ROCKET: exceptionally fast and accurate time series classification using random convolutional kernels." Data Mining and Knowledge Discovery 34.5 (2020): 1454-1495.

[6] Ziyin, Liu, Tilman Hartwig, and Masahito Ueda. "Neural networks fail to learn periodic functions and how to fix it." Advances in Neural Information Processing Systems 33 (2020): 1583-1594.
