tedhuang96 / gst

[RA-L + ICRA22] Learning Sparse Interaction Graphs of Partially Detected Pedestrians for Trajectory Prediction
MIT License
38 stars 3 forks source link
deep-learning human-robot-interaction human-trajectory-prediction lstm pytorch trajectory-prediction transformer

GST

This is the implementation for the paper

Learning Sparse Interaction Graphs of Partially Detected Pedestrians for Trajectory Prediction

Zhe Huang, Ruohua Li, Kazuki Shin, Katherine Driggs-Campbell

published in RA-L.

[Paper] [arXiv] [Project]

GST is the abbreviation of our model Gumbel Social Transformer. All code was developed and tested on Ubuntu 18.04 with CUDA 10.2, Python 3.6.9, and PyTorch 1.7.1.

Citation

If you find this repo useful, please cite

@article{huang2022learning,
  title={Learning Sparse Interaction Graphs of Partially Detected Pedestrians for Trajectory Prediction},
  author={Huang, Zhe and Li, Ruohua and Shin, Kazuki and Driggs-Campbell, Katherine},
  journal={IEEE Robotics and Automation Letters},
  year={2022},
  volume={7},
  number={2},
  pages={1198-1205},
  doi={10.1109/LRA.2021.3138547}
}

Setup

1. Create a Virtual Environment. (Optional)
virtualenv -p /usr/bin/python3 myenv
source myenv/bin/activate
2. Install Packages

You can run either

pip install -r requirements.txt

or

pip install numpy
pip install scipy
pip install matplotlib
pip install tensorboardX
pip install torch==1.7.1

If you want to use tensorboard --logdir results to check training curves, install tensorflow by running

pip install tensorflow
3. Create Folders and Dataset Files.
sh run/make_dirs.sh
sh run/create_datasets.sh

Training and Evaluation on Various Configurations

To train and evaluate a model with n=1, i.e., the target pedestrian pays attention to at most one partially observed pedestrian, run

sh run/train_sparse.sh
sh run/eval_sparse.sh

To train and evaluate a model with n=1 and temporal component as a temporal convolution network, run

sh run/train_sparse_tcn.sh
sh run/eval_sparse_tcn.sh

To train and evaluate a model with full connection, i.e., the target pedestrian pays attention to all partially observed pedestrians in the scene, run

sh run/train_full_connection.sh
sh run/eval_full_connection.sh

To train and evaluate a model in which the target pedestrian pays attention to all fully observed pedestrians in the scene, run

sh run/train_full_connection_fully_observed.sh
sh run/eval_full_connection_fully_observed.sh

Important Arguments for Building Customized Configurations

Credits

Part of the code is based on the following works and repos:

[1] Mohamed, Abduallah, et al. "Social-stgcnn: A social spatio-temporal graph convolutional neural network for human trajectory prediction." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2020. [GitHub]

[2] Pytorch implementation of Multi-head Attention. [Modules] [Functional]

Contact

Please feel free to open an issue or send an email to zheh4@illinois.edu.