TRI-ML / RAP

This is the official code for the paper "RAP: Risk-Aware Prediction for Robust Planning" (https://arxiv.org/abs/2210.01368).

License statement

The code is provided under the Creative Commons Attribution-NonCommercial 4.0 International (CC BY-NC 4.0) license. Under this license, the code is provided royalty-free for non-commercial purposes only. The code may be covered by patents; if you want to use the code for commercial purposes, please contact us for a different license.

RAP: Risk-Aware Prediction

This is the official code for RAP: Risk-Aware Prediction for Robust Planning. You can test the results in our Hugging Face demo and see some additional experiments on the paper website.

A planner should react to low-probability events if they are dangerous. Biasing the predictions to better represent these events helps the planner to be cautious.

We define and train a trajectory forecasting model and bias its predictions towards risk, so that it helps a planner estimate risk by producing the relevant pessimistic trajectory forecasts for it to consider.
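To make this motivation concrete, the toy sketch below compares an average-cost estimate with a tail-risk measure, then computes how often a handful of forecasts misses a rare dangerous future entirely. It assumes conditional value-at-risk (CVaR) as the risk measure and uses made-up costs and probabilities; it only illustrates the idea and is not the risk objective or code used in this repository.

```python
def cvar(costs, alpha):
    """Empirical CVaR: mean of the worst (1 - alpha) fraction of sampled costs."""
    if not 0.0 <= alpha < 1.0:
        raise ValueError("alpha must be in [0, 1)")
    worst_first = sorted(costs, reverse=True)
    k = max(1, round(len(worst_first) * (1.0 - alpha)))
    return sum(worst_first[:k]) / k


def miss_probability(p_danger, n_samples):
    """Chance that n i.i.d. forecasts contain no dangerous trajectory at all."""
    return (1.0 - p_danger) ** n_samples


# Toy population (assumption): 2% of futures are dangerous (cost 100), the rest cost 0.
costs = [100.0] * 20 + [0.0] * 980

expected_cost = sum(costs) / len(costs)  # the rare danger is averaged away
tail_risk = cvar(costs, alpha=0.95)      # the worst 5% is dominated by the danger

# A planner drawing only 10 unbiased forecasts usually never sees the danger...
p_miss_unbiased = miss_probability(0.02, 10)
# ...while a forecaster biased towards risky futures (hypothetical 30% rate) rarely misses it.
p_miss_biased = miss_probability(0.30, 10)

print(expected_cost, tail_risk, round(p_miss_unbiased, 2), round(p_miss_biased, 2))
# prints: 2.0 40.0 0.82 0.03
```

A biased forecaster changes the sample distribution, so its samples are useful for exposing dangerous outcomes rather than for estimating the plain average cost.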

Datasets

This repository uses two datasets: the didactic simulation and the Waymo Open Motion Dataset (WOMD).

Forecasting model

A conditional variational auto-encoder (CVAE) model is used as the base pedestrian trajectory predictor. Its latent space is either quantized or Gaussian, depending on the parameter set in the config. It uses either multi-head attention or a modified version of context gating to account for interactions. Depending on the parameters, the trajectory encoder and decoder can be set to MLP, LSTM, or maskedLSTM.
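As a rough illustration of the two latent-space options, the sketch below contrasts a Gaussian latent sampled with the reparameterization trick and a quantized latent that snaps the encoder output to its nearest codebook vector (VQ-VAE style). The function names, the `latent_type` strings and the toy codebook are hypothetical, not this repository's config keys or implementation; in particular, the way this model discretizes its latent space may differ from nearest-codebook quantization.

```python
import math
import random


def sample_gaussian_latent(mean, log_var, rng):
    """Reparameterization trick: z = mean + sigma * eps, with eps ~ N(0, I)."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mean, log_var)]


def quantize_latent(z, codebook):
    """Snap z to its nearest codebook entry under squared Euclidean distance."""
    def sq_dist(code):
        return sum((a - b) ** 2 for a, b in zip(z, code))

    index = min(range(len(codebook)), key=lambda i: sq_dist(codebook[i]))
    return index, codebook[index]


def encode(mean, log_var, latent_type, codebook, rng):
    """Hypothetical switch mirroring a config choice between the two latent spaces."""
    if latent_type == "gaussian":
        return sample_gaussian_latent(mean, log_var, rng)
    if latent_type == "quantized":
        # The encoder output itself is snapped to a code; no noise is added here.
        return quantize_latent(mean, codebook)[1]
    raise ValueError(f"unknown latent_type: {latent_type!r}")


rng = random.Random(0)
codebook = [[0.0, 0.0], [1.0, 1.0], [-1.0, 2.0]]  # toy 2-D codebook
mean, log_var = [0.9, 1.2], [-20.0, -20.0]  # tiny variance keeps the sample near the mean

z_quantized = encode(mean, log_var, "quantized", codebook, rng)  # [1.0, 1.0]
z_gaussian = encode(mean, log_var, "gaussian", codebook, rng)    # close to [0.9, 1.2]
```

A discrete latent yields a finite set of distinct forecast modes, while a Gaussian latent gives a continuous family of samples; which one works better is a modeling choice the config leaves open.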

Usage

Installation

Setting up the data

Didactic simulation

WOMD

Configuration and training

Didactic simulation

WOMD

Training has two phases: first the unbiased predictor is trained, then the biased encoder is trained with the predictor model frozen. This second phase needs to draw many samples to estimate the risk, so your GPU may run out of memory at this stage. If it does, consider reducing the batch size and the number of samples "n_mc_samples_biased". Keeping the number of samples "n_mc_samples_risk" high makes the risk estimation more accurate, but training might be very slow.
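The trade-off between sample count and accuracy can be illustrated with a small Monte Carlo experiment. This is a toy sketch, not code from this repository: `n_samples` stands in for "n_mc_samples_risk", and `mc_risk_estimate` / `estimator_spread` are hypothetical helpers with a made-up cost model (a 2% chance of a dangerous future costing 100). The spread of the estimate shrinks roughly like 1/sqrt(n), while the sampling work grows linearly with n.

```python
import random
import statistics


def mc_risk_estimate(n_samples, rng, p_danger=0.02, danger_cost=100.0):
    """Monte Carlo estimate of expected cost from n sampled futures (toy cost model)."""
    hits = sum(1 for _ in range(n_samples) if rng.random() < p_danger)
    return hits * danger_cost / n_samples


def estimator_spread(n_samples, n_repeats=200, seed=0):
    """Standard deviation of the estimate across repeated runs; smaller is more accurate."""
    rng = random.Random(seed)
    estimates = [mc_risk_estimate(n_samples, rng) for _ in range(n_repeats)]
    return statistics.pstdev(estimates)


for n in (16, 256, 4096):
    # Theory for this toy model: spread is about 100 * sqrt(0.02 * 0.98 / n) = 14 / sqrt(n).
    print(n, round(estimator_spread(n), 2))
```

With these toy settings, going from 256 to 4096 samples buys roughly a 4x tighter estimate for 16x the sampling work, which is why a high "n_mc_samples_risk" slows training down.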

Evaluation

Many evaluation scripts are available in "scripts/eval_scripts" to compute results, plot graphs, draw the didactic experiment scene, etc.

You can also run the interactive interface locally with:

python scripts/scripts_utils/plotly_interface.py --load_from=<full path to the .ckpt checkpoint file> --cfg_path=<full path to the learning_config.py file from the checkpoint>

Unfortunately, the WOMD license does not allow us to provide the pre-trained weights of our model, so you will need to train it yourself.