This is the official implementation of the paper Toward Zero-forgetting Continual Learning for Interactive Trajectory Prediction: A Dynamically Expandable Approach.
The conda environment can be created with the following commands:
conda create -n DEITP python=3.8
conda activate DEITP
conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.7 -c pytorch -c nvidia
pip install -r requirements.txt
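After these commands finish, it can help to confirm that the pinned versions were actually installed. The script below is a minimal sketch and is not part of the repository. It assumes the packages expose pip-style metadata readable through `importlib.metadata`. A conda-installed package without such metadata would be reported as missing even when it imports fine. The helper names `installed_version` and `find_mismatches` are hypothetical.

```python
"""Sanity check for the DEITP conda environment (hypothetical sketch, not part of the repo)."""
import sys
from importlib import metadata
from typing import Dict, List, Optional

# Versions pinned by the conda install command above.
EXPECTED = {
    "torch": "1.13.1",
    "torchvision": "0.14.1",
    "torchaudio": "0.13.1",
}


def installed_version(package: str) -> Optional[str]:
    """Return the installed version of `package`, or None if no metadata is found."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def find_mismatches(installed: Dict[str, Optional[str]],
                    expected: Dict[str, str] = EXPECTED) -> List[str]:
    """List readable problems wherever an installed version differs from its pin."""
    problems = []
    for name, want in expected.items():
        have = installed.get(name)
        if have is None:
            problems.append(f"{name}: not installed (expected {want})")
        # Pip wheels may report versions such as '1.13.1+cu117'; compare the base part only.
        elif have.split("+")[0] != want:
            problems.append(f"{name}: found {have}, expected {want}")
    return problems


if __name__ == "__main__":
    if sys.version_info[:2] != (3, 8):
        print(f"Warning: running Python {sys.version_info.major}.{sys.version_info.minor}; "
              "the environment above pins Python 3.8.")
    found = {name: installed_version(name) for name in EXPECTED}
    for line in find_mismatches(found) or ["All pinned packages match."]:
        print(line)
```

Run it with `python check_env.py` (the file name is only a suggestion) inside the activated `DEITP` environment.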
The data used in this work comes from the publicly available INTERACTION dataset, and the preprocessing follows D-GSM. The original dataset used in the experiments can be downloaded from here. Please download the dataset and put it in the data/original folder. The training scripts will automatically process the data and save the processed data in the data/processed folder.
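Before you start training, you can check that the raw data is in place. The helper below is a hypothetical sketch and is not part of the repository. It relies only on what this README states: the raw files live in `data/original`, and the training scripts create `data/processed` themselves. It therefore checks only that `data/original` exists and is not empty. It does not validate individual file names.

```python
"""Check the data folder layout described above (hypothetical helper, not part of the repo)."""
from pathlib import Path
from typing import List


def check_data_layout(root: str = "data") -> List[str]:
    """Return a list of problems with the expected layout under `root`."""
    original = Path(root) / "original"
    problems = []
    if not original.is_dir():
        problems.append(f"missing folder: {original}")
    elif not any(original.iterdir()):
        problems.append(f"empty folder: {original} (download the dataset first)")
    # <root>/processed is written by the training scripts, so its absence is not an error.
    return problems


if __name__ == "__main__":
    issues = check_data_layout()
    print("\n".join(issues) if issues else "data/original exists and is not empty.")
```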
The repository provides training scripts for several continual learning approaches in continuous traffic scenarios. DEITP, referred to in the code as the Dynamically Expandable Model (DEM) for simplicity, can be trained by running the following command:
bash scripts/train_dem.sh
We also propose a familiarity autoencoder (FAE)-based approach for task detection in the task-free continual learning (TFCL) setting. It can be trained by running the following command:
bash scripts/train_fae.sh
The trained models can be tested with the following commands. To test DEM prediction on all tasks, run:
python test.py --task_detect 0 --task_free 0 --task_predict 1
To test prediction on all tasks in the TFCL setting, run:
python test.py --task_detect 0 --task_free 1 --task_predict 1
To test task detection only, run:
python test.py --task_detect 1 --task_free 0 --task_predict 0
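The three test commands differ only in their flag values. The sketch below collects those combinations in one table. It is a hypothetical convenience wrapper and is not part of the repository. The mode names `predict`, `predict_tfcl` and `detect` are illustrative labels, not options of `test.py`. The script only prints the commands. To execute one, pass the returned list to `subprocess.run` from the repository root.

```python
"""Build the documented test.py invocations (hypothetical wrapper, not part of the repo)."""
from typing import List

# Flag values copied from the three test commands above; mode names are illustrative only.
MODES = {
    "predict": {"task_detect": 0, "task_free": 0, "task_predict": 1},
    "predict_tfcl": {"task_detect": 0, "task_free": 1, "task_predict": 1},
    "detect": {"task_detect": 1, "task_free": 0, "task_predict": 0},
}


def build_test_command(mode: str, python: str = "python") -> List[str]:
    """Return the argv list for test.py in the given mode."""
    if mode not in MODES:
        raise ValueError(f"unknown mode {mode!r}; choose from {sorted(MODES)}")
    cmd = [python, "test.py"]
    for flag, value in MODES[mode].items():
        cmd += [f"--{flag}", str(value)]
    return cmd


if __name__ == "__main__":
    for name in MODES:
        print(f"{name}: {' '.join(build_test_command(name))}")
```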
We sincerely thank the authors of the following GitHub repositories, whose valuable code bases we build upon: