ChuhuaW / SGNet.pytorch

PyTorch implementation of Stepwise Goal-Driven Networks for Trajectory Prediction (RA-L/ICRA 2022)
deep-learning pytorch trajectory-prediction


Installation

Cloning

We use part of the dataloader in Trajectron++, so we include Trajectron++ as a submodule.

git clone --recurse-submodules git@github.com:ChuhuaW/SGNet.pytorch.git
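If the repository was already cloned without `--recurse-submodules`, the Trajectron++ submodule can still be fetched afterwards with git's standard `git submodule update --init --recursive`. The following is a minimal Python sketch of that step. It assumes the clone lives in `SGNet.pytorch`, which is the directory name git picks for this URL. The helper names `submodule_init_command` and `ensure_submodules` are hypothetical and not part of the repository.

```python
import subprocess


def submodule_init_command(repo_root="SGNet.pytorch"):
    """Build the standard git command that fetches all submodules recursively."""
    # `git -C <dir>` runs git as if it had been started inside <dir>.
    return ["git", "-C", repo_root, "submodule", "update", "--init", "--recursive"]


def ensure_submodules(repo_root="SGNet.pytorch", dry_run=True):
    """Print the command, and execute it only when dry_run is False."""
    cmd = submodule_init_command(repo_root)
    print(" ".join(cmd))
    if not dry_run:
        # check=True raises CalledProcessError if git reports a failure.
        subprocess.run(cmd, check=True)
    return cmd
```

Calling `ensure_submodules(dry_run=False)` from the directory that contains the clone performs the fetch. The default dry run only prints the command.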

Environment

conda env create --file SGNet_env.yml

Data

ln -s path/to/dataset/ ./data/
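The `ln -s` command links an existing dataset folder into `./data/`. When `./data/` is already a directory, `ln` creates the link inside it, named after the dataset folder. The following is a minimal Python sketch of the same step that is safe to re-run. The helper name `link_dataset` is hypothetical. The `./data/<folder name>` layout is an assumption: this README does not specify which folder names the scripts expect.

```python
from pathlib import Path


def link_dataset(src, data_dir="data", name=None):
    """Create data_dir/<name> as a symlink to src (name defaults to src's folder name)."""
    src_path = Path(src).resolve()
    target_dir = Path(data_dir)
    target_dir.mkdir(parents=True, exist_ok=True)
    link = target_dir / (name or src_path.name)
    if link.is_symlink():
        if link.resolve() == src_path:
            return link  # already linked to the same dataset; nothing to do
        raise FileExistsError(f"{link} already points to {link.resolve()}")
    if link.exists():
        raise FileExistsError(f"{link} exists and is not a symlink")
    # Assumption: the dataset is a directory, as in `ln -s path/to/dataset/ ./data/`.
    link.symlink_to(src_path, target_is_directory=True)
    return link
```

Unlike `ln -s`, this sketch stores an absolute target. The link therefore keeps working if the repository is moved, but it breaks if the dataset folder is moved.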

Training

Stochastic prediction

Deterministic prediction

Evaluation

Stochastic prediction

Deterministic prediction

cd SGNet.pytorch
python tools/ethucy/eval_deterministic.py --gpu $CUDA_VISIBLE_DEVICES --dataset ETH --model SGNet --checkpoint path/to/checkpoint
python tools/ethucy/eval_deterministic.py --gpu $CUDA_VISIBLE_DEVICES --dataset HOTEL --model SGNet --checkpoint path/to/checkpoint
python tools/ethucy/eval_deterministic.py --gpu $CUDA_VISIBLE_DEVICES --dataset UNIV --model SGNet --checkpoint path/to/checkpoint
python tools/ethucy/eval_deterministic.py --gpu $CUDA_VISIBLE_DEVICES --dataset ZARA1 --model SGNet --checkpoint path/to/checkpoint
python tools/ethucy/eval_deterministic.py --gpu $CUDA_VISIBLE_DEVICES --dataset ZARA2 --model SGNet --checkpoint path/to/checkpoint
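The five commands above differ only in `--dataset`, so they can be generated in a loop. The following is a minimal Python sketch of that loop, intended to run from the repository root. It reuses exactly the flags shown above. Two details are assumptions: each split is assumed to get its own checkpoint, passed in as a split-to-path mapping, and the GPU falls back to `"0"` when `CUDA_VISIBLE_DEVICES` is unset. The helper names are hypothetical.

```python
import os
import shlex
import subprocess

# The five ETH/UCY splits, in the order used by the commands above.
ETH_UCY_SPLITS = ["ETH", "HOTEL", "UNIV", "ZARA1", "ZARA2"]


def build_eval_commands(checkpoints, gpu=None, model="SGNet"):
    """Return one argv list per split.

    `checkpoints` maps each split name to a checkpoint path (one checkpoint
    per split is an assumption). `gpu` defaults to CUDA_VISIBLE_DEVICES,
    as in the shell commands, falling back to "0" (an assumption).
    """
    if gpu is None:
        gpu = os.environ.get("CUDA_VISIBLE_DEVICES", "0")
    return [
        [
            "python", "tools/ethucy/eval_deterministic.py",
            "--gpu", str(gpu),
            "--dataset", split,
            "--model", model,
            "--checkpoint", checkpoints[split],
        ]
        for split in ETH_UCY_SPLITS
    ]


def run_all(checkpoints, gpu=None, dry_run=True):
    """Print each command; execute it only when dry_run is False."""
    for cmd in build_eval_commands(checkpoints, gpu):
        print(shlex.join(cmd))
        if not dry_run:
            # check=True stops at the first split whose evaluation fails.
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    placeholder = {split: "path/to/checkpoint" for split in ETH_UCY_SPLITS}
    run_all(placeholder, dry_run=True)
```

Building argv lists instead of shell strings avoids quoting problems when checkpoint paths contain spaces. `shlex.join` is used only to print a copy-pasteable form of each command.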

JAAD/PIE checkpoints

Citation

@ARTICLE{9691856,
  author={Wang, Chuhua and Wang, Yuchen and Xu, Mingze and Crandall, David J.},
  journal={IEEE Robotics and Automation Letters}, 
  title={Stepwise Goal-Driven Networks for Trajectory Prediction}, 
  year={2022}}
- Ranked 3rd on the nuScenes prediction task at the 6th AI Driving Olympics, ICRA 2021

The source code and pretrained models will be made available. Stay tuned.