TensorFlow implementation for stochastic adversarial video prediction. Given a sequence of initial frames, our model can predict frames for multiple possible futures. For example, in the next two sequences, we show the ground truth sequence on the left and random predictions of our model on the right. Predicted frames are indicated by the yellow bar at the bottom. For more examples, visit the project page.
Stochastic Adversarial Video Prediction,
Alex X. Lee, Richard Zhang, Frederik Ebert, Pieter Abbeel, Chelsea Finn, Sergey Levine.
arXiv preprint arXiv:1804.01523, 2018.
An alternative implementation of SAVP is available in the Tensor2Tensor library.
git clone -b master --single-branch https://github.com/alexlee-gk/video_prediction.git
cd video_prediction
pip install -r requirements.txt
Add the repository root to the PYTHONPATH, e.g. export PYTHONPATH=path/to/video_prediction.
(… download_model.sh script).
Download and preprocess a dataset (e.g. bair):
bash data/download_and_preprocess_dataset.sh bair
Download a pre-trained model (e.g. ours_savp) for the action-free version of that dataset (i.e. bair_action_free):
bash pretrained_models/download_model.sh bair_action_free ours_savp
Sample predictions from the model:
CUDA_VISIBLE_DEVICES=0 python scripts/generate.py --input_dir data/bair \
--dataset_hparams sequence_length=30 \
--checkpoint pretrained_models/bair_action_free/ours_savp \
--mode test \
--results_dir results_test_samples/bair_action_free
The sampled predictions are saved in results_test_samples/bair_action_free/ours_savp.
Evaluate the predictions of the model:
CUDA_VISIBLE_DEVICES=0 python scripts/evaluate.py --input_dir data/bair \
--dataset_hparams sequence_length=30 \
--checkpoint pretrained_models/bair_action_free/ours_savp \
--mode test \
--results_dir results_test/bair_action_free
The evaluation results are saved in results_test/bair_action_free/ours_savp.
See scripts/generate_all.sh and scripts/evaluate_all.sh for the evaluation details of our experiments.
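The example commands suggest that both scripts write into a subdirectory of --results_dir named after the last path component of --checkpoint (results_test/bair_action_free plus ours_savp). The sketch below reproduces that layout with a hypothetical helper, results_subdir; it is inferred from the paths above, not taken from the scripts themselves.

```python
import os

def results_subdir(results_dir: str, checkpoint: str) -> str:
    """Hypothetical helper: <results_dir>/<last path component of checkpoint>.

    Inferred from the example commands; the real scripts may differ."""
    # normpath strips a trailing separator so basename() is not empty.
    name = os.path.basename(os.path.normpath(checkpoint))
    return os.path.join(results_dir, name)

print(results_subdir("results_test/bair_action_free",
                     "pretrained_models/bair_action_free/ours_savp"))
# results_test/bair_action_free/ours_savp (on POSIX systems)
```

Normalizing the checkpoint path first means a trailing slash (pretrained_models/bair_action_free/ours_savp/) still yields ours_savp as the subdirectory name.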
To train a model, download and preprocess a dataset (e.g. bair):
bash data/download_and_preprocess_dataset.sh bair
Train a model (e.g. ours_savp on the action-free BAIR dataset):
CUDA_VISIBLE_DEVICES=0 python scripts/train.py --input_dir data/bair --dataset bair \
--model savp --model_hparams_dict hparams/bair_action_free/ours_savp/model_hparams.json \
--output_dir logs/bair_action_free/ours_savp
To monitor training, run
tensorboard --logdir logs/bair_action_free --port 6006
and open http://localhost:6006.
(… long_sequence_length differs from sequence_length).
For multi-GPU training, set CUDA_VISIBLE_DEVICES to a comma-separated list of devices, e.g. CUDA_VISIBLE_DEVICES=0,1,2,3. To use the CPU, set CUDA_VISIBLE_DEVICES="".
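The sketch below shows how a CUDA_VISIBLE_DEVICES value is commonly interpreted: a comma-separated list selects those GPUs, an empty string hides every GPU (so TensorFlow runs on the CPU), and an unset variable leaves all GPUs visible. The helper visible_gpus is hypothetical and is not part of this repository.

```python
import os
from typing import List, Mapping, Optional

def visible_gpus(env: Optional[Mapping[str, str]] = None) -> Optional[List[str]]:
    """Hypothetical helper: list the device ids named in CUDA_VISIBLE_DEVICES.

    Returns None when the variable is unset (CUDA then exposes every GPU)
    and [] when it is the empty string (no GPU is visible, so TensorFlow
    places ops on the CPU)."""
    env = os.environ if env is None else env
    value = env.get("CUDA_VISIBLE_DEVICES")
    if value is None:
        return None
    # Ids are kept as strings because CUDA also accepts GPU UUIDs here.
    return [d.strip() for d in value.split(",") if d.strip()]

print(visible_gpus({"CUDA_VISIBLE_DEVICES": "0,1,2,3"}))  # ['0', '1', '2', '3']
print(visible_gpus({"CUDA_VISIBLE_DEVICES": ""}))         # []
```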
See scripts/train_all.sh for training details of other datasets and models.
Download the datasets using the following script. These datasets were collected by other researchers. Please cite their papers if you use the data.
bash data/download_and_preprocess_dataset.sh dataset_name
The dataset_name should be one of the following:
- bair: BAIR robot pushing dataset. [Citation]
- kth: KTH human actions dataset. [Citation]
To use a different dataset, preprocess it into TFRecords files and define a class for it. See kth_dataset.py for an example where the original dataset is given as videos.
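A TFRecord file is a sequence of length-prefixed records, each guarded by masked CRC-32C checksums. The pure-Python sketch below writes and reads that framing so the format is visible without TensorFlow installed; in practice one would use TensorFlow's own TFRecord writer, and each payload would be a serialized tf.train.Example holding the frames. The helper names and placeholder payloads are hypothetical, and the feature layout a dataset class expects is defined by that class (see kth_dataset.py), not by this sketch.

```python
import struct
from typing import Iterator, List

def _crc32c_table() -> List[int]:
    # Lookup table for the reflected CRC-32C (Castagnoli) polynomial.
    poly = 0x82F63B78
    table = []
    for i in range(256):
        crc = i
        for _ in range(8):
            crc = (crc >> 1) ^ poly if crc & 1 else crc >> 1
        table.append(crc)
    return table

_TABLE = _crc32c_table()

def crc32c(data: bytes) -> int:
    crc = 0xFFFFFFFF
    for byte in data:
        crc = _TABLE[(crc ^ byte) & 0xFF] ^ (crc >> 8)
    return crc ^ 0xFFFFFFFF

def masked_crc(data: bytes) -> int:
    # TFRecord stores a rotated CRC plus a constant, not the raw CRC.
    crc = crc32c(data)
    rotated = ((crc >> 15) | (crc << 17)) & 0xFFFFFFFF
    return (rotated + 0xA282EAD8) & 0xFFFFFFFF

def encode_record(payload: bytes) -> bytes:
    # uint64 length | uint32 masked CRC of length | payload | uint32 masked CRC of payload
    header = struct.pack("<Q", len(payload))
    return (header + struct.pack("<I", masked_crc(header))
            + payload + struct.pack("<I", masked_crc(payload)))

def decode_records(blob: bytes) -> Iterator[bytes]:
    offset = 0
    while offset < len(blob):
        header = blob[offset:offset + 8]
        (length,) = struct.unpack("<Q", header)
        (header_crc,) = struct.unpack_from("<I", blob, offset + 8)
        if header_crc != masked_crc(header):
            raise ValueError("corrupt length header")
        start = offset + 12
        payload = blob[start:start + length]
        (payload_crc,) = struct.unpack_from("<I", blob, start + length)
        if payload_crc != masked_crc(payload):
            raise ValueError("corrupt payload")
        yield payload
        offset = start + length + 4

# Placeholder payloads standing in for serialized tf.train.Example protos.
records = [b"video-0", b"video-1"]
blob = b"".join(encode_record(r) for r in records)
print(list(decode_records(blob)) == records)  # True
```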
Note: the bair dataset is used for both the action-free and action-conditioned experiments. Set the hyperparameter use_state=True to use the action-conditioned version of the dataset.
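The example commands pass dataset hyperparameters as key=value strings via --dataset_hparams (e.g. sequence_length=30). Assuming the common convention of comma-separated overrides, the sketch below parses a string such as sequence_length=30,use_state=True against a set of defaults and coerces each value to its default's type. parse_hparams and the default values are hypothetical, and whether use_state is set through --dataset_hparams is also an assumption.

```python
from typing import Any, Dict

def parse_hparams(spec: str, defaults: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical sketch: apply comma-separated key=value overrides.

    Each value is converted to the type of its default, so "use_state=True"
    becomes the bool True and "sequence_length=30" becomes the int 30.
    Unknown keys raise KeyError."""
    result = dict(defaults)
    for item in (s.strip() for s in spec.split(",")):
        if not item:
            continue
        key, sep, raw = item.partition("=")
        key, raw = key.strip(), raw.strip()
        if not sep:
            raise ValueError(f"expected key=value, got {item!r}")
        if key not in defaults:
            raise KeyError(f"unknown hyperparameter {key!r}")
        default = defaults[key]
        if isinstance(default, bool):  # check bool before int: bool subclasses int
            if raw.lower() not in ("true", "false"):
                raise ValueError(f"{key} expects True or False, got {raw!r}")
            result[key] = raw.lower() == "true"
        elif isinstance(default, int):
            result[key] = int(raw)
        elif isinstance(default, float):
            result[key] = float(raw)
        else:
            result[key] = raw
    return result

# Hypothetical defaults, for illustration only.
defaults = {"sequence_length": 12, "use_state": False}
print(parse_hparams("sequence_length=30,use_state=True", defaults))
# {'sequence_length': 30, 'use_state': True}
```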
bash pretrained_models/download_model.sh dataset_name model_name
The dataset_name should be one of the following: bair_action_free, kth, or bair.
The model_name should be one of the available pre-trained models:
- ours_savp: our complete model, trained with variational and adversarial losses. Also referred to as ours_vae_gan.
The following are ablations of our model:
- ours_gan: trained with L1 and adversarial loss, with latent variables sampled from the prior at training time.
- ours_vae: trained with L1 and KL loss.
- ours_deterministic: trained with L1 loss, with no stochastic latent variables.
See pretrained_models/download_model.sh for a complete list of available pre-trained models.
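The variants above differ in which loss terms they optimize. The sketch below records that mapping and combines the active terms as a weighted sum. Two assumptions: the variational loss of ours_savp is taken to be an L1 reconstruction term plus a KL term (as in ours_vae), and the weights are placeholders rather than values from the provided hparams files. The sketch also does not model how the latent variables are sampled.

```python
from typing import Dict

# Loss terms per variant, following the descriptions above. Assumption: the
# "variational" part of ours_savp is L1 reconstruction plus KL, as in ours_vae.
VARIANT_LOSSES = {
    "ours_savp": {"l1", "kl", "gan"},
    "ours_vae_gan": {"l1", "kl", "gan"},  # other name for ours_savp
    "ours_gan": {"l1", "gan"},
    "ours_vae": {"l1", "kl"},
    "ours_deterministic": {"l1"},
}

def total_loss(variant: str, losses: Dict[str, float], weights: Dict[str, float]) -> float:
    """Weighted sum of the loss terms active for `variant`; other terms are ignored."""
    return sum(weights[term] * losses[term] for term in sorted(VARIANT_LOSSES[variant]))

# Placeholder values, not taken from the hparams files.
losses = {"l1": 1.0, "kl": 2.0, "gan": 3.0}
weights = {"l1": 1.0, "kl": 0.5, "gan": 0.1}
print(total_loss("ours_vae", losses, weights))  # 2.0
```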
The implementation is designed so that each video prediction model defines its own architecture and training procedure, and includes reasonable default hyperparameters.
Still, a few hyperparameters should be overridden for each combination of dataset and model.
The hyperparameters used in our experiments are provided as JSON files in hparams, and they can be passed to the training script with the --model_hparams_dict flag.
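Overriding defaults with a JSON file amounts to a dictionary merge. The sketch below (hypothetical helper merge_hparams, hypothetical hyperparameter names) keeps each model default unless the JSON text overrides it, and rejects unknown keys so that typos fail loudly. The repository's own handling of --model_hparams_dict may differ.

```python
import json
from typing import Any, Dict

def merge_hparams(defaults: Dict[str, Any], json_text: str) -> Dict[str, Any]:
    """Overlay JSON overrides on a model's default hyperparameters.

    Hypothetical sketch: defaults stay in effect unless overridden, and
    unknown keys raise KeyError so typos are caught early."""
    overrides = json.loads(json_text)
    unknown = sorted(set(overrides) - set(defaults))
    if unknown:
        raise KeyError(f"unknown hyperparameters: {unknown}")
    merged = dict(defaults)
    merged.update(overrides)
    return merged

# Hypothetical defaults and override file contents.
defaults = {"batch_size": 16, "lr": 0.001, "max_steps": 300000}
print(merge_hparams(defaults, '{"lr": 0.0002}'))
# {'batch_size': 16, 'lr': 0.0002, 'max_steps': 300000}
```

To read the overrides from disk, pass the file contents, e.g. open("hparams/bair_action_free/ours_savp/model_hparams.json").read(), as json_text.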
If you find this useful for your research, please cite the following.
@article{lee2018savp,
title={Stochastic Adversarial Video Prediction},
author={Alex X. Lee and Richard Zhang and Frederik Ebert and Pieter Abbeel and Chelsea Finn and Sergey Levine},
journal={arXiv preprint arXiv:1804.01523},
year={2018}
}