YixiaoWang7 / OptTrajDiff



This is the official GitHub repository for OptTrajDiff.

OptTrajDiff: Optimizing Diffusion Models for Joint Trajectory Prediction and Controllable Generation
Yixiao Wang, Chen Tang, Lingfeng Sun, Simone Rossi, Yichen Xie, Chensheng Peng, Thomas Hannagan, Stefano Sabatini, Nicola Poerio, Masayoshi Tomizuka, Wei Zhan

Diffusion models are promising for joint trajectory prediction and controllable generation in autonomous driving, but they suffer from slow inference and high computational cost. To address this, we introduce Optimal Gaussian Diffusion (OGD) and Estimated Clean Manifold (ECM) Guidance. OGD optimizes the prior distribution for a small diffusion time $T$ and starts the reverse diffusion process from it. ECM injects guidance gradients directly into the estimated clean manifold, avoiding extensive gradient backpropagation through the network. Together, these streamline the generative process and reduce computational overhead, making practical applications feasible. Experiments on the large-scale Argoverse 2 dataset demonstrate our approach's superior performance, offering a computationally efficient, high-quality solution for joint trajectory prediction and controllable generation in autonomous driving.
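To make the role of a small diffusion time concrete, here is a minimal sketch (not the repository's implementation) of a moment-matched Gaussian prior. It assumes the standard variance-preserving forward process, `x_T = sqrt(alpha_bar_T) * x_0 + sqrt(1 - alpha_bar_T) * eps`, with independent dimensions. The helper name `matched_gaussian_prior` and all numbers are illustrative assumptions; how OGD actually parameterizes and optimizes its prior may differ.

```python
import math

def matched_gaussian_prior(data_mean, data_std, alpha_bar_T):
    """Per-dimension mean/std of x_T = sqrt(a) * x_0 + sqrt(1 - a) * eps.

    Hypothetical helper: assumes independent dimensions and eps ~ N(0, 1).
    """
    scale = math.sqrt(alpha_bar_T)
    mean = [scale * m for m in data_mean]
    std = [math.sqrt(alpha_bar_T * s * s + (1.0 - alpha_bar_T)) for s in data_std]
    return mean, std

# Illustrative statistics for a 2-D trajectory endpoint; not taken from the dataset.
data_mean, data_std = [20.0, 0.5], [8.0, 2.0]

# Long diffusion: almost no signal is left, so N(0, I) is a reasonable prior.
print(matched_gaussian_prior(data_mean, data_std, alpha_bar_T=1e-4))
# Short diffusion: x_T still carries most of the data statistics.
print(matched_gaussian_prior(data_mean, data_std, alpha_bar_T=0.8))
```

With a short diffusion the noised data remains far from a standard normal, which is why the starting distribution of the reverse process matters when $T$ is small.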



Initial Setup

Step 1: Download the code by cloning the repository:

git clone https://github.com/YixiaoWang7/OptTrajDiff.git && cd OptTrajDiff

Step 2: Set up a new conda environment and install required packages:

conda env create -f environment.yml
conda activate OptTrajDiff

Step 3: Install the Argoverse 2 API and download the Argoverse 2 Motion Forecasting Dataset. See the Argoverse 2 User Guide for instructions.

Joint Trajectory Prediction with Optimal Gaussian Diffusion

Training Command

python train_diffnet_tb.py --root <Path to dataset> \
  --train_batch_size 16 --val_batch_size 4 --test_batch_size 4 \
  --dataset argoverse_v2 --num_historical_steps 50 --num_future_steps 60 --num_recurrent_steps 3 \
  --pl2pl_radius 150 --time_span 10 --pl2a_radius 50 --a2a_radius 50 \
  --num_t2m_steps 30 --pl2m_radius 150 --a2m_radius 150 \
  --devices "4,5,6" --qcnet_ckpt_path <Path to QCNet checkpoint> --num_workers 4 \
  --num_denoiser_layers 3 --num_diffusion_steps 100 --T_max 30 --max_epochs 30 --lr 0.005 \
  --beta_1 0.0001 --beta_T 0.05 --diff_type opd --sampling ddim --sampling_stride 10 \
  --num_eval_samples 6 --choose_best_mode FDE --std_reg 0.3 --check_val_every_n_epoch 3 \
  --path_pca_s_mean 'pca/imp_org/s_mean_10.npy' \
  --path_pca_VT_k 'pca/imp_org/VT_k_10.npy' \
  --path_pca_V_k 'pca/imp_org/V_k_10.npy' \
  --path_pca_latent_mean 'pca/imp_org/latent_mean_10.npy' \
  --path_pca_latent_std 'pca/imp_org/latent_std_10.npy'
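For intuition about the diffusion-related values in the command above, the following sketch computes a noise schedule and the timesteps a strided sampler would visit. It rests on assumptions that the source does not confirm: that `--beta_1` and `--beta_T` are the endpoints of a linear schedule over `--num_diffusion_steps`, that `--T_max` is the step at which reverse diffusion starts, and that `--sampling_stride` is the gap between visited steps. The helper names are hypothetical.

```python
import math

def linear_betas(beta_1, beta_T, num_steps):
    """Linearly spaced noise variances from beta_1 to beta_T (assumed schedule)."""
    step = (beta_T - beta_1) / (num_steps - 1)
    return [beta_1 + i * step for i in range(num_steps)]

def cumulative_alpha_bar(betas):
    """alpha_bar[t] = product over s <= t of (1 - beta_s)."""
    out, prod = [], 1.0
    for b in betas:
        prod *= 1.0 - b
        out.append(prod)
    return out

betas = linear_betas(beta_1=0.0001, beta_T=0.05, num_steps=100)
alpha_bar = cumulative_alpha_bar(betas)

T_max, stride = 30, 10
# Timesteps visited by a strided sampler that starts at T_max (1-indexed).
timesteps = list(range(T_max, 0, -stride))

print(f"signal scale at step {T_max}: {math.sqrt(alpha_bar[T_max - 1]):.3f}")
print(f"signal scale at step 100: {math.sqrt(alpha_bar[-1]):.3f}")
print("sampler timesteps:", timesteps)
```

Under these assumptions the sampler needs only three denoiser evaluations per sample, and the data at step 30 still retains most of its signal.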

Below are the significant arguments related to our work:

Validation Command

python val_diffnet.py --root <Path to dataset> --ckpt_path <Path to diffusion network checkpoint> \
  --devices '5,' --batch_size 8 --sampling ddim --sampling_stride 10 \
  --num_eval_samples 128 --std_reg 0.3 --path_pca_V_k 'pca/imp_org/V_k_10.npy' --network_mode 'val'
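The command above selects DDIM sampling. As a reminder of what one such update does, here is a minimal, element-wise sketch of the standard deterministic DDIM step (eta = 0). It is a generic illustration rather than the repository's sampler; `ddim_step` and the toy values are assumptions, and a real denoiser would replace the known noise used here.

```python
import math

def ddim_step(x_t, eps_pred, alpha_bar_t, alpha_bar_prev):
    """One deterministic DDIM update (eta = 0), applied element-wise."""
    sa, sn = math.sqrt(alpha_bar_t), math.sqrt(1.0 - alpha_bar_t)
    # Estimate the clean sample implied by the predicted noise.
    x0_hat = [(x - sn * e) / sa for x, e in zip(x_t, eps_pred)]
    # Move to the previous (less noisy) level, reusing the same predicted noise.
    sa_p, sn_p = math.sqrt(alpha_bar_prev), math.sqrt(1.0 - alpha_bar_prev)
    x_prev = [sa_p * x0 + sn_p * e for x0, e in zip(x0_hat, eps_pred)]
    return x_prev, x0_hat

# Toy check with a perfect noise predictor: the clean sample is recovered.
x0 = [1.5, -0.7]
eps = [0.3, -1.1]
a_t, a_prev = 0.8, 0.9
x_t = [math.sqrt(a_t) * x + math.sqrt(1.0 - a_t) * e for x, e in zip(x0, eps)]
x_prev, x0_hat = ddim_step(x_t, eps, a_t, a_prev)
print(x0_hat)  # approximately equal to x0
```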

Controllable Generation with ECM and ECMR

python val_diffnet.py --root <Path to dataset> --ckpt_path <Path to diffusion network checkpoint> \
  --devices '2,' --batch_size 16 --sampling ddim --sampling_stride 10 \
  --num_eval_samples 128 --std_reg 0.3 --path_pca_V_k 'pca/imp_org/V_k_10.npy' --network_mode 'val' \
  --guid_sampling 'guid' --guid_task 'rand_goal_5s' --guid_method <Guided sampling method> \
  --guid_plot plot --cost_param_costl 10.0 --cost_param_threl 1.0
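The idea of applying guidance on the estimated clean sample, rather than backpropagating through the denoiser, can be sketched as follows. This is a simplified illustration under stated assumptions, not the repository's ECM or ECMR code: the quadratic goal cost, `step_size`, the function names, and the toy values are all hypothetical, and the repository's cost functions and their `--cost_param_*` settings are not reproduced here.

```python
import math

def guided_ddim_step(x_t, eps_pred, alpha_bar_t, alpha_bar_prev, goal, step_size):
    """DDIM update with a gradient step taken on the estimated clean sample.

    The cost 0.5 * ||x0_hat - goal||^2 has gradient (x0_hat - goal) with respect
    to x0_hat, so no backpropagation through the denoiser is needed.
    """
    sa, sn = math.sqrt(alpha_bar_t), math.sqrt(1.0 - alpha_bar_t)
    x0_hat = [(x - sn * e) / sa for x, e in zip(x_t, eps_pred)]
    # Gradient descent on the cost, applied directly to the clean estimate.
    x0_guided = [x0 - step_size * (x0 - g) for x0, g in zip(x0_hat, goal)]
    sa_p, sn_p = math.sqrt(alpha_bar_prev), math.sqrt(1.0 - alpha_bar_prev)
    x_prev = [sa_p * x0 + sn_p * e for x0, e in zip(x0_guided, eps_pred)]
    return x_prev, x0_hat, x0_guided

def dist(a, b):
    """Euclidean distance between two points."""
    return math.sqrt(sum((p - q) ** 2 for p, q in zip(a, b)))

x_t = [10.0, 2.0]        # noisy endpoint (toy values)
eps_pred = [0.5, -0.2]   # stand-in for the denoiser output
goal = [15.0, 0.0]       # hypothetical goal point
x_prev, x0_hat, x0_guided = guided_ddim_step(
    x_t, eps_pred, 0.8, 0.9, goal, step_size=0.5
)
print(dist(x0_hat, goal), dist(x0_guided, goal))
```

Because the gradient is taken with respect to the clean estimate only, each guided step costs one denoiser forward pass plus a cheap cost gradient.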

Below are the key arguments relevant to our work:


If you find our work helpful, please star 🌟 this repo and cite 📑 our paper. BibTeX will be updated soon. Thanks for your support!

  title={Optimizing diffusion models for joint trajectory prediction and controllable generation},
  author={Wang, Yixiao and Tang, Chen and Sun, Lingfeng and Rossi, Simone and Xie, Yichen and Peng, Chensheng and Hannagan, Thomas and Sabatini, Stefano and Poerio, Nicola and Tomizuka, Masayoshi and others},
  booktitle={European Conference on Computer Vision},

This repository builds on Query-Centric Trajectory Prediction (QCNet). Please also consider citing:

  title={Query-Centric Trajectory Prediction},
  author={Zhou, Zikang and Wang, Jianping and Li, Yung-Hui and Huang, Yu-Kai},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},