STGAT: Modeling Spatial-Temporal Interactions for Human Trajectory Prediction
Correction: the paper's description of the Average Displacement Error (ADE) is wrong. The metric should be described as RMSE or L2 distance (as in SocialAttention and SocialGan).
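The note above does not give formulas, so the following is only a minimal sketch of the two quantities those terms commonly denote, assuming trajectories are lists of 2D (x, y) positions. The function names and the sample trajectories are hypothetical and are not taken from this repository. Check the evaluation code to see which quantity it actually reports.

```python
import math


def l2_errors(pred, truth):
    """Per-time-step Euclidean (L2) distance between predicted and true (x, y)."""
    return [math.hypot(px - tx, py - ty) for (px, py), (tx, ty) in zip(pred, truth)]


def mean_l2(pred, truth):
    """Average of the per-step L2 distances over the prediction horizon."""
    errs = l2_errors(pred, truth)
    return sum(errs) / len(errs)


def rmse(pred, truth):
    """Root of the mean squared per-step L2 distance."""
    errs = l2_errors(pred, truth)
    return math.sqrt(sum(e * e for e in errs) / len(errs))


# Illustrative data only (assumption): a 3-step prediction and its ground truth.
pred = [(0.0, 0.0), (1.0, 0.0), (2.0, 0.0)]
truth = [(0.0, 0.0), (1.0, 1.0), (2.0, 2.0)]

print(mean_l2(pred, truth), rmse(pred, truth))
```

The two values differ whenever the per-step errors are not all equal, which is why the wording of the metric matters when comparing against other papers.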
- Python 3
- PyTorch (1.2)
- Matplotlib
All the data comes from the SGAN model without any further processing.
- First
cd STGAT
- To train the model run
python train.py
(see the code for all the arguments that can be given to the command)
- To evaluate the model run
python evaluate_model.py
- Using the default parameters in the code, you can get most of the numerical results presented in the paper. A reasonable attention visualization, however, may require longer training and some parameter tuning. For example, for the zara1 dataset with pred_len set to 8 time-steps, you can set num_epochs to 600 (line 36 in train.py) and the learning rate in step3 to 1e-4 (line 180 in train.py).
- The attachment folder contains the code that produces the attention figures presented in the paper.
- Check the issues of this repo to find out how to get better results on the ETH dataset.
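The tuning suggestion in the list above (600 epochs, with a learning rate of 1e-4 in step3) can be pictured as a staged schedule. The sketch below is hypothetical: the stage boundaries (150 and 250), the base learning rate, and both function names are illustrative assumptions, not values or code read from train.py. Only num_epochs = 600 and the step3 rate of 1e-4 come from the text.

```python
# Hypothetical sketch: the stage boundaries and the learning rate for steps 1-2
# are illustrative assumptions, not values read from train.py.
def training_step(epoch, step2_start=150, step3_start=250):
    """Return which training stage (1, 2 or 3) an epoch belongs to."""
    if epoch < step2_start:
        return 1
    if epoch < step3_start:
        return 2
    return 3


def learning_rate(epoch, base_lr=1e-3, step3_lr=1e-4):
    """Use base_lr in steps 1-2 and switch to step3_lr once step 3 begins."""
    return step3_lr if training_step(epoch) == 3 else base_lr


num_epochs = 600  # value suggested above for zara1 with pred_len = 8
schedule = [learning_rate(epoch) for epoch in range(num_epochs)]

print(schedule[0], schedule[-1])
```

In the actual script the equivalent change is made by editing the lines named above. This sketch only shows the shape of the schedule being described.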
All data and part of the code come from the SGAN model. If you find this code useful in your research, please also cite their paper.
If you have any questions, please contact [email protected], and if you find this repository useful for your research, please cite the following paper:
@InProceedings{Huang_2019_ICCV,
author = {Huang, Yingfan and Bi, Huikun and Li, Zhaoxin and Mao, Tianlu and Wang, Zhaoqi},
title = {STGAT: Modeling Spatial-Temporal Interactions for Human Trajectory Prediction},
booktitle = {The IEEE International Conference on Computer Vision (ICCV)},
month = {October},
year = {2019}
}