This is the official GitHub repository for OptTrajDiff.
OptTrajDiff: Optimizing Diffusion Models for Joint Trajectory Prediction and Controllable Generation
Yixiao Wang, Chen Tang, Lingfeng Sun, Simone Rossi, Yichen Xie, Chensheng Peng, Thomas Hannagan, Stefano Sabatini, Nicola Poerio, Masayoshi Tomizuka, Wei Zhan
Diffusion models are promising for joint trajectory prediction and controllable generation in autonomous driving, but they face challenges of inefficient inference time and high computational demands. To tackle these challenges, we introduce Optimal Gaussian Diffusion (OGD) and Estimated Clean Manifold (ECM) Guidance. OGD optimizes the prior distribution for a small diffusion time $T$ and starts the reverse diffusion process from it. ECM directly injects guidance gradients to the estimated clean manifold, eliminating extensive gradient backpropagation throughout the network. Our methodology streamlines the generative process, enabling practical applications with reduced computational overhead. Experimental validation on the large-scale Argoverse 2 dataset demonstrates our approach's superior performance, offering a viable solution for computationally efficient, high-quality joint trajectory generation and controllable generation for autonomous driving.
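For intuition only, here is a minimal sketch of the OGD idea under simplifying assumptions (an epsilon-prediction denoiser, a linear noise schedule, and a Gaussian approximation of the data); it is not the code in this repository, and names such as denoiser, data_mean, and data_var are hypothetical placeholders. The point is that the reverse process starts at a small diffusion time $T$ from a Gaussian prior matched to the forward-diffused data statistics, instead of from $\mathcal{N}(0, I)$ at the full horizon.

```python
# Conceptual sketch of Optimal Gaussian Diffusion (OGD) sampling -- NOT this repo's code.
# Assumes an epsilon-prediction denoiser, a linear beta schedule, and a Gaussian
# approximation of the data; `denoiser`, `data_mean`, and `data_var` are placeholders.
import torch

num_diffusion_steps = 100
betas = torch.linspace(1e-4, 5e-2, num_diffusion_steps)   # beta_1 .. beta_T
alphas_bar = torch.cumprod(1.0 - betas, dim=0)            # \bar{alpha}_t

def sample_ogd(denoiser, data_mean, data_var, T_small=10):
    """Start the reverse process at a small time T_small from a Gaussian prior
    fitted to the scaled data statistics, rather than from N(0, I) at the full horizon."""
    a_bar = alphas_bar[T_small]
    prior_mean = a_bar.sqrt() * data_mean                  # forward-diffused data mean
    prior_var = a_bar * data_var + (1.0 - a_bar)           # forward-diffused data variance
    x = prior_mean + prior_var.sqrt() * torch.randn_like(data_mean)

    # Deterministic DDIM-style reverse updates from T_small down to 0.
    for t in range(T_small, 0, -1):
        a_bar_t, a_bar_prev = alphas_bar[t], alphas_bar[t - 1]
        eps_hat = denoiser(x, t)                           # predicted noise
        x0_hat = (x - (1.0 - a_bar_t).sqrt() * eps_hat) / a_bar_t.sqrt()
        x = a_bar_prev.sqrt() * x0_hat + (1.0 - a_bar_prev).sqrt() * eps_hat
    return x
```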
Step 1: Download the code by cloning the repository:
git clone https://github.com/YixiaoWang7/OptTrajDiff.git && cd OptTrajDiff
Step 2: Set up a new conda environment and install required packages:
conda env create -f environment.yml
conda activate OptTrajDiff
Step 3: Install the Argoverse 2 API (e.g., via pip install av2) and download the Argoverse 2 Motion Forecasting Dataset. Please see the Argoverse 2 User Guide for details.
Train the diffusion network:
python train_diffnet_tb.py --root <Path to dataset> --train_batch_size 16 --val_batch_size 4 --test_batch_size 4 --dataset argoverse_v2 --num_historical_steps 50 --num_future_steps 60 --num_recurrent_steps 3 --pl2pl_radius 150 --time_span 10 --pl2a_radius 50 --a2a_radius 50 --num_t2m_steps 30 --pl2m_radius 150 --a2m_radius 150 --devices "4,5,6" --qcnet_ckpt_path <Path to QCNet checkpoint> --num_workers 4 --num_denoiser_layers 3 --num_diffusion_steps 100 --T_max 30 --max_epochs 30 --lr 0.005 --beta_1 0.0001 --beta_T 0.05 --diff_type opd --sampling ddim --sampling_stride 10 --num_eval_samples 6 --choose_best_mode FDE --std_reg 0.3 --check_val_every_n_epoch 3 --path_pca_s_mean 'pca/imp_org/s_mean_10.npy' --path_pca_VT_k 'pca/imp_org/VT_k_10.npy' --path_pca_V_k 'pca/imp_org/V_k_10.npy' --path_pca_latent_mean 'pca/imp_org/latent_mean_10.npy' --path_pca_latent_std 'pca/imp_org/latent_std_10.npy'
Below are the significant arguments related to our work (a short sketch of how the schedule and sampling flags interact follows this list):

- --devices: Specifies the GPUs you want to use.
- --qcnet_ckpt_path: Provides the path to the QCNet checkpoint.
- --num_denoiser_layers: Defines the number of layers in the diffusion network.
- --num_diffusion_steps: Sets the number of diffusion steps.
- --max_epochs: Determines the total number of training epochs.
- --lr: Sets the learning rate.
- --beta_1: Specifies $\beta_1$, the diffusion schedule parameter.
- --beta_T: Specifies $\beta_T$, the diffusion schedule parameter.
- --sampling_stride: Defines the sampling stride for DDIM.
- --num_eval_samples: Indicates the number of evaluation samples.
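The snippet below illustrates one plausible reading of how --beta_1, --beta_T, --num_diffusion_steps, and --sampling_stride interact (a linear beta schedule and strided DDIM timesteps are common conventions; this is an assumption, not verified against this repository's exact indexing):

```python
# Illustration of how the schedule/sampling flags could interact (assumed conventions,
# not extracted from this repository's code).
num_diffusion_steps = 100    # --num_diffusion_steps
sampling_stride = 10         # --sampling_stride
beta_1, beta_T = 1e-4, 5e-2  # --beta_1, --beta_T

# Linear noise schedule from beta_1 to beta_T.
betas = [beta_1 + (beta_T - beta_1) * t / (num_diffusion_steps - 1)
         for t in range(num_diffusion_steps)]

# With DDIM and a stride of 10, only every 10th timestep is visited,
# so sampling needs 10 denoiser calls instead of 100.
ddim_timesteps = list(range(num_diffusion_steps - 1, -1, -sampling_stride))
print(ddim_timesteps)        # [99, 89, 79, 69, 59, 49, 39, 29, 19, 9]
```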
To evaluate the trained diffusion network on the validation set, run:
python val_diffnet.py --root <Path to dataset> --ckpt_path <Path to diffusion network checkpoint> --devices '5,' --batch_size 8 --sampling ddim --sampling_stride 10 --num_eval_samples 128 --std_reg 0.3 --path_pca_V_k 'pca/imp_org/V_k_10.npy' --network_mode 'val'
To run guided sampling for controllable generation, run:
python val_diffnet.py --root <Path to dataset> --ckpt_path <Path to diffusion network checkpoint> --devices '2,' --batch_size 16 --sampling ddim --sampling_stride 10 --num_eval_samples 128 --std_reg 0.3 --path_pca_V_k 'pca/imp_org/V_k_10.npy' --network_mode 'val' --guid_sampling 'guid' --guid_task 'rand_goal_5s' --guid_method <Guided sampling method> --guid_plot plot --cost_param_costl 10.0 --cost_param_threl 1.0
Below are the key arguments relevant to our work (a short sketch of ECM-style guidance follows this list):

- --guid_method: Specifies the guided sampling method you wish to use. Options are ['none', 'ECM', 'ECMR'].
- --guid_task: Defines the controllable task you plan to test. Options include ['none', 'goal', 'goal_5s', 'goal_at5s', 'rand_goal', 'rand_goal_5s', 'rand_goal_at5s'].
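For intuition about the ECM option, here is a minimal sketch of guidance injected on the estimated clean sample, using a goal-reaching cost as an example; it is not this repository's implementation, and denoiser, alphas_bar, and the cost are hypothetical placeholders. The key property is that the guidance gradient is taken with respect to the clean estimate x0_hat, so nothing is backpropagated through the denoiser network.

```python
# Conceptual sketch of ECM-style guidance -- NOT the code in this repo.
import torch

def ecm_guided_step(x_t, t, denoiser, alphas_bar, goal, step_size=0.1):
    a_bar_t = alphas_bar[t]
    a_bar_prev = alphas_bar[t - 1] if t > 0 else torch.tensor(1.0)

    with torch.no_grad():                                   # no backprop through the network
        eps_hat = denoiser(x_t, t)

    # Estimated clean sample (epsilon-prediction parameterization).
    x0_hat = (x_t - (1.0 - a_bar_t).sqrt() * eps_hat) / a_bar_t.sqrt()

    # Differentiable cost on the clean estimate, e.g. squared distance of the
    # final waypoint to a goal point (trajectory shape assumed [..., steps, 2]).
    x0_hat = x0_hat.detach().requires_grad_(True)
    cost = ((x0_hat[..., -1, :] - goal) ** 2).sum()
    (grad,) = torch.autograd.grad(cost, x0_hat)
    x0_guided = x0_hat.detach() - step_size * grad          # nudge the clean estimate

    # Deterministic DDIM update toward the previous timestep.
    return a_bar_prev.sqrt() * x0_guided + (1.0 - a_bar_prev).sqrt() * eps_hat
```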
If you find our work helpful, please star 🌟 this repo and cite 📑 our paper. Thanks for your support!
@inproceedings{wang2025optimizing,
title={Optimizing diffusion models for joint trajectory prediction and controllable generation},
author={Wang, Yixiao and Tang, Chen and Sun, Lingfeng and Rossi, Simone and Xie, Yichen and Peng, Chensheng and Hannagan, Thomas and Sabatini, Stefano and Poerio, Nicola and Tomizuka, Masayoshi and others},
booktitle={European Conference on Computer Vision},
pages={324--341},
year={2025},
organization={Springer}
}
This repository is built on top of QCNet (Query-Centric Trajectory Prediction). Please also consider citing:
@inproceedings{zhou2023query,
title={Query-Centric Trajectory Prediction},
author={Zhou, Zikang and Wang, Jianping and Li, Yung-Hui and Huang, Yu-Kai},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
year={2023}
}