xupei0610 / SocialVAE

[ECCV2022] SocialVAE: Human Trajectory Prediction using Timewise Latents
MIT License

SocialVAE: Human Trajectory Prediction using Timewise Latents

This is the official implementation for SocialVAE: Human Trajectory Prediction using Timewise Latents. [arXiv] [YouTube]

Abstract -- Predicting pedestrian movement is critical for human behavior analysis and also for safe and efficient human-agent interactions. However, despite significant advancements, it is still challenging for existing approaches to capture the uncertainty and multimodality of human navigation decision-making. In this paper, we propose SocialVAE, a novel approach for human trajectory prediction. The core of SocialVAE is a timewise variational autoencoder architecture that exploits stochastic recurrent neural networks to perform prediction, combined with a social attention mechanism and backward posterior approximation to allow for better extraction of pedestrian navigation strategies. We show that SocialVAE improves current state-of-the-art performance on several pedestrian trajectory prediction benchmarks, including the ETH/UCY benchmark, the Stanford Drone Dataset and the SportVU NBA movement dataset.

@inproceedings{socialvae2022,
   author={Xu, Pei and Hayet, Jean-Bernard and Karamouzas, Ioannis},
   title={SocialVAE: Human Trajectory Prediction using Timewise Latents},
   booktitle={European Conference on Computer Vision},
   pages={511-528},
   year={2022},
   organization={Springer},
   doi={10.1007/978-3-031-19772-7_30}
}
Our approach shows low errors in trajectory prediction on challenging scenarios with complex and intensive human-human interactions. Below we show our model's predictions for basketball players. We also include our NBA datasets (data/nba) in this repository. Caution: the NBA datasets were recorded in feet. Please refer to our paper for more details.

[Figure panels: Predictions | Heatmap | Attention]
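Because the NBA data are in feet, positions and errors computed on them are presumably in feet as well (this assumes the evaluation reports errors in the data's own unit). A minimal sketch for converting such values to meters, e.g. to compare against metric datasets, is below; the helper name is hypothetical and not part of the repository.

```python
FEET_TO_METERS = 0.3048  # exact: 1 international foot = 0.3048 m


def feet_to_meters(values_ft):
    """Convert one length, or a list of lengths, from feet to meters (hypothetical helper)."""
    if isinstance(values_ft, (int, float)):
        return values_ft * FEET_TO_METERS
    return [v * FEET_TO_METERS for v in values_ft]


if __name__ == "__main__":
    # Hypothetical error values in feet, e.g. copied from an NBA evaluation run.
    print(feet_to_meters([1.0, 10.0]))
```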

Dependencies

We recommend installing all requirements through Conda:

$ conda create --name <env> --file requirements.txt -c pytorch -c conda-forge

Code Usage

Command to train a model from scratch:

$ python main.py --train <train_data_dir> --test <test_data_dir> --ckpt <checkpoint_dir> --config <config_file>

We provide the training and testing data in the data folder and the configuration files we used in the config folder. To reproduce the reported results, run:

# ETH/UCY benchmarks
$ python main.py --train data/eth/train --test data/eth/test --ckpt log_eth --config config/eth.py
$ python main.py --train data/hotel/train --test data/hotel/test --ckpt log_hotel --config config/hotel.py
$ python main.py --train data/univ/train --test data/univ/test --ckpt log_univ --config config/univ.py
$ python main.py --train data/zara01/train --test data/zara01/test --ckpt log_zara01 --config config/zara01.py
$ python main.py --train data/zara02/train --test data/zara02/test --ckpt log_zara02 --config config/zara02.py

# SDD benchmark
$ python main.py --train data/sdd/train --test data/sdd/test --ckpt log_sdd --config config/sdd.py

# NBA benchmark
$ python main.py --train data/nba/rebound/train --test data/nba/rebound/test --ckpt log_rebound --config config/nba_rebound.py
$ python main.py --train data/nba/score/train --test data/nba/score/test --ckpt log_score --config config/nba_score.py
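The training commands above can be scripted so that all benchmarks train one after another. The sketch below is hypothetical (it is not shipped with the repository); it assumes it is run from the repository root and uses only the paths and flags listed above. By default it only prints the commands; passing --run to the script launches them sequentially.

```python
import shlex
import subprocess
import sys

# (data directory, checkpoint directory, config file), copied from the commands above.
RUNS = [
    ("data/eth", "log_eth", "config/eth.py"),
    ("data/hotel", "log_hotel", "config/hotel.py"),
    ("data/univ", "log_univ", "config/univ.py"),
    ("data/zara01", "log_zara01", "config/zara01.py"),
    ("data/zara02", "log_zara02", "config/zara02.py"),
    ("data/sdd", "log_sdd", "config/sdd.py"),
    ("data/nba/rebound", "log_rebound", "config/nba_rebound.py"),
    ("data/nba/score", "log_score", "config/nba_score.py"),
]


def train_command(data_dir, ckpt_dir, config):
    """Build the main.py training command for one benchmark."""
    return [
        sys.executable, "main.py",
        "--train", f"{data_dir}/train",
        "--test", f"{data_dir}/test",
        "--ckpt", ckpt_dir,
        "--config", config,
    ]


def run_all(dry_run=True):
    for run in RUNS:
        cmd = train_command(*run)
        print(shlex.join(cmd))
        if not dry_run:
            # check=True stops at the first run that exits with an error.
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    run_all(dry_run="--run" not in sys.argv[1:])
```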

Evaluation and Pre-trained Models

Command to evaluate a pre-trained model:

$ python main.py --test <test_data_dir> --ckpt <checkpoint_dir> --config <config_file>

We provide our pre-trained models in the models folder. To evaluate them, run:

# ETH/UCY benchmarks
$ python main.py --test data/eth/test --ckpt models/eth --config config/eth.py
$ python main.py --test data/hotel/test --ckpt models/hotel --config config/hotel.py
$ python main.py --test data/univ/test --ckpt models/univ --config config/univ.py
$ python main.py --test data/zara01/test --ckpt models/zara01 --config config/zara01.py
$ python main.py --test data/zara02/test --ckpt models/zara02 --config config/zara02.py

# SDD benchmark
$ python main.py --test data/sdd/test --ckpt models/sdd --config config/sdd_pixel.py

# NBA benchmark
$ python main.py --test data/nba/rebound/test --ckpt models/nba/rebound --config config/nba_rebound.py
$ python main.py --test data/nba/score/test --ckpt models/nba/score --config config/nba_score.py

Due to the large size of the NBA scoring dataset (592,640 trajectories), a full test with FPC (final position clustering) may take about 2 hours. For faster testing, use the --no-fpc option to run the model without FPC.
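To give a rough idea of what FPC does, here is a sketch of the general idea: draw more trajectory samples than needed, run k-means on their final positions, and keep one sample per cluster. This is our own reading written in plain NumPy, not the repository's implementation; the function name, the (N, T, 2) array layout, and keeping the cluster member closest to each centroid are assumptions.

```python
import numpy as np


def final_position_clustering(samples, k, n_iters=50, seed=0):
    """Pick k of N sampled trajectories via k-means on their final positions (sketch).

    samples: array of shape (N, T, 2), with N >= k predicted trajectories.
    Returns an array of shape (k, T, 2).
    """
    samples = np.asarray(samples, dtype=float)
    finals = samples[:, -1, :]  # (N, 2) predicted final positions
    if len(finals) < k:
        raise ValueError("need at least k samples")
    rng = np.random.default_rng(seed)
    # Initialize centers at k distinct, randomly chosen final positions.
    centers = finals[rng.choice(len(finals), size=k, replace=False)]
    for _ in range(n_iters):
        labels = _nearest(finals, centers)
        # Move each center to the mean of its members; keep it if the cluster is empty.
        new_centers = np.array([
            finals[labels == j].mean(axis=0) if np.any(labels == j) else centers[j]
            for j in range(k)
        ])
        converged = np.allclose(new_centers, centers)
        centers = new_centers
        if converged:
            break
    labels = _nearest(finals, centers)
    chosen = []
    for j in range(k):
        members = np.flatnonzero(labels == j)
        if members.size == 0:  # empty cluster: fall back to all samples
            members = np.arange(len(finals))
        # Keep the member whose final position is closest to this center.
        d = np.linalg.norm(finals[members] - centers[j], axis=1)
        chosen.append(members[d.argmin()])
    return samples[np.array(chosen)]


def _nearest(points, centers):
    """Index of the nearest center for every point."""
    d = np.linalg.norm(points[:, None, :] - centers[None, :, :], axis=-1)
    return d.argmin(axis=1)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Hypothetical: 20 sampled 12-step futures for one agent, reduced to 5.
    samples = rng.normal(size=(20, 12, 2)).cumsum(axis=1)
    print(final_position_clustering(samples, k=5).shape)  # (5, 12, 2)
```

The cost grows with the number of samples drawn per agent, which is consistent with a full FPC test on a large dataset taking much longer than a run with --no-fpc.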

All our training and testing were done on machines equipped with V100 GPUs. Test results may vary slightly on machines with different hardware or different versions of PyTorch/CUDA.

Training New Models

Prepare your own dataset

Our code supports loading trajectories from multiple files, each of which represents a scene. Split your data into training and testing sets and put each scene as a .txt file in the corresponding folder.

Each line in the data files is in the format of

frame_ID:int  agent_ID:int  pos_x:float  pos_y:float  group:str

where frame_ID and agent_ID are integers and pos_x and pos_y are floats. The optional group field identifies the agent type/group, so that the model can be trained to predict only specific groups/types of agents. See config/nba_rebound.py for an example, where the model is trained to predict only the players' movement and the basketball appears only as a neighbor of other agents.
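A minimal reader for this format may help when checking your own files before training. It assumes whitespace-separated columns (spaces or tabs) and tolerates IDs written as floats such as 780.0; the names Record, parse_line, group_by_agent, and load_scene are hypothetical and do not mirror how main.py loads data.

```python
from collections import defaultdict
from typing import Dict, Iterable, List, NamedTuple, Optional


class Record(NamedTuple):
    frame_id: int
    agent_id: int
    x: float
    y: float
    group: Optional[str]  # None when the optional group column is absent


def parse_line(line: str) -> Record:
    """Parse 'frame_ID agent_ID pos_x pos_y [group]' with any whitespace separator."""
    fields = line.split()
    if len(fields) not in (4, 5):
        raise ValueError(f"expected 4 or 5 fields, got {len(fields)}: {line!r}")
    group = fields[4] if len(fields) == 5 else None
    # int(float(...)) also accepts IDs written as floats, e.g. '780.0'.
    return Record(int(float(fields[0])), int(float(fields[1])),
                  float(fields[2]), float(fields[3]), group)


def group_by_agent(lines: Iterable[str]) -> Dict[int, List[Record]]:
    """Collect the records of one scene per agent_ID, sorted by frame_ID."""
    tracks: Dict[int, List[Record]] = defaultdict(list)
    for line in lines:
        if line.strip():  # skip blank lines
            rec = parse_line(line)
            tracks[rec.agent_id].append(rec)
    for recs in tracks.values():
        recs.sort(key=lambda r: r.frame_id)
    return dict(tracks)


def load_scene(path: str) -> Dict[int, List[Record]]:
    """Load one scene file, e.g. a .txt file in the training folder."""
    with open(path) as f:
        return group_by_agent(f)


if __name__ == "__main__":
    # Hypothetical scene: agent 1 listed out of frame order, agent 2 tagged 'ball'.
    scene = ["10 1 0.0 1.0", "0 1 0.5 1.0", "0 2 3.0 4.0 ball"]
    tracks = group_by_agent(scene)
    print([r.frame_id for r in tracks[1]])  # [0, 10]
    print(tracks[2][0].group)  # ball
```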

Set up your config file

Our config files in the config folder can be used as references.

A key hyperparameter to pay attention to is NEIGHBOR_RADIUS. In common scenarios with casual human walking, values from 2 to 5 are typical. For intensive human movement, it could be 5 to 10 or even larger.

Training

$ python main.py --train <folder_of_training_data> --test <folder_of_testing_data> --ckpt <checkpoint_folder> --config <config_file>

Evaluation

$ python main.py --test <folder_of_testing_data> --ckpt <checkpoint_folder> --config <config_file>

The script automatically runs FPC hyperparameter finetuning after training is done. To perform the finetuning manually for an existing model, run the evaluation command with the --fpc_finetune option.