XiYe20 / STDiffProject

[AAAI'24] "STDiff: Spatio-temporal Diffusion for Continuous Stochastic Video Prediction". Xi Ye, Guillaume-Alexandre Bilodeau

STDiff: Spatio-temporal diffusion for continuous stochastic video prediction

arXiv | code

STDiff_BAIR_15

Overview

STDiff Architecture

Installation

  1. Install the custom diffusers library
    git clone https://github.com/XiYe20/CustomDiffusers.git
    cd CustomDiffusers
    pip install -e .
  2. Install the requirements of STDiff
    pip install -r requirements.txt
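
After both steps, a quick import check can confirm the environment is ready. The sketch below is not part of the repository: `missing_packages` is a hypothetical helper, and it assumes the custom library is importable as `diffusers`, as the upstream Hugging Face package is.

```python
from importlib import util


def missing_packages(names):
    """Return the top-level package names that cannot be found on the Python path."""
    return [name for name in names if util.find_spec(name) is None]


# Example call (assumed import names; extend the list to match requirements.txt):
#   missing_packages(["diffusers", "accelerate"])
```

An empty list means every listed package was found; any name returned still needs to be installed.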

Datasets

Processed KTH dataset: https://drive.google.com/file/d/1RbJyGrYdIp4ROy8r0M-lLAbAMxTRQ-sd/view?usp=sharing
SM-MNIST: https://drive.google.com/file/d/1eSpXRojBjvE4WoIgeplUznFyRyI3X64w/view?usp=drive_link

For the other datasets, please download them from their official websites. The expected folder structure for each dataset is shown below.

BAIR

Please download the original BAIR dataset and use the "/utils/read_BAIR_tfrecords.py" script to convert it into frames with the following layout:

    /BAIR
        test/
            example_0/
                0000.png
                0001.png
                ...
            example1/
                0000.png
                0001.png
                ...
            example...
        train/
            example0/
                0000.png
                0001.png
                ...
            example...
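
To sanity-check the converted frames, the layout above can be walked with a few lines of Python. This is a minimal sketch, not the repository's data loader: `list_bair_examples` is a hypothetical helper, and it assumes only that each example folder holds zero-padded PNG frames as shown.

```python
from pathlib import Path


def list_bair_examples(root, split):
    """Map each example folder name under root/split to its sorted frame paths."""
    split_dir = Path(root) / split
    examples = {}
    for example_dir in sorted(p for p in split_dir.iterdir() if p.is_dir()):
        # Zero-padded names (0000.png, 0001.png, ...) sort in temporal order.
        examples[example_dir.name] = sorted(example_dir.glob("*.png"))
    return examples
```

For instance, `list_bair_examples("/BAIR", "train")` returns a dictionary whose values can be checked for empty or unexpectedly short sequences before training.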

Cityscapes

Please download "leftImg8bit_sequence_trainvaltest.zip" from the official Cityscapes website. Center-crop and resize all frames to 128x128. Save the frames as follows:

    /Cityscapes
        test/
            berlin/
                berlin_000000_000000_leftImg8bit.png
                berlin_000000_000001_leftImg8bit.png
                ...
            bielefeld/
                bielefeld_000000_000302_leftImg8bit.png
                bielefeld_000000_000303_leftImg8bit.png
                ...
            ...
        train/
            aachen/
                ...
            bochum/
                ...
            ...
        val/
            ...
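
The crop-and-resize step is not scripted in this README, so here is one possible sketch using Pillow. It assumes "center crop" means taking the largest centered square, and the bicubic interpolation is also an assumption; `center_crop_box` and `preprocess_frame` are hypothetical names, not functions from the repository.

```python
def center_crop_box(width, height):
    """Return (left, upper, right, lower) of the largest centered square."""
    side = min(width, height)
    left = (width - side) // 2
    upper = (height - side) // 2
    return (left, upper, left + side, upper + side)


def preprocess_frame(src_path, dst_path, size=128):
    """Center-crop one frame to a square, resize it to size x size, and save it."""
    # Pillow is imported here so that center_crop_box has no dependencies.
    from PIL import Image

    with Image.open(src_path) as img:
        box = center_crop_box(*img.size)
        # Bicubic resampling is an assumed choice, not taken from the paper.
        img.crop(box).resize((size, size), Image.BICUBIC).save(dst_path)
```

For a 2048x1024 Cityscapes frame, `center_crop_box(2048, 1024)` gives `(512, 0, 1536, 1024)`, so the middle 1024x1024 region is kept before downscaling. The same helper applies to any dataset that calls for a 128x128 center crop.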

KITTI

Please download the raw data (synced+rectified) from the official KITTI website. Center-crop and resize all frames to 128x128. Save the frames as follows:

    /KITTI
        2011_09_26_drive_0001_sync/
            0000000000.png
            0000000001.png
            ...
        2011_09_26_drive_0002_sync/
            ...
        ...
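
Because video prediction relies on temporally contiguous frames, it may be worth checking each drive folder for gaps after preprocessing. The sketch below is an illustration only: `missing_frame_indices` is a hypothetical helper, and it assumes frames are named by their integer index, as in the layout above.

```python
from pathlib import Path


def missing_frame_indices(drive_dir):
    """Return frame indices absent between the first and last PNG in a drive folder."""
    indices = sorted(
        int(p.stem) for p in Path(drive_dir).glob("*.png") if p.stem.isdigit()
    )
    if not indices:
        return []
    present = set(indices)
    return [i for i in range(indices[0], indices[-1] + 1) if i not in present]
```

An empty list means the folder is contiguous; any returned index points to a frame that was lost during download or conversion.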

Training and Evaluation

The STDiff project uses accelerate for training. The training and test configuration files for each dataset are in stdiff/configs.

Training

  1. Check train_script.sh: set the visible GPUs and num_process, and select the correct train_config file
  2. Run training
    . ./train_script.sh

Test

  1. Check test_script.sh and select the correct test_config file
  2. Run the test
    . ./test_script.sh

Citation

@inproceedings{ye2024stdiff,
  title={STDiff: Spatio-Temporal Diffusion for Continuous Stochastic Video Prediction},
  author={Ye, Xi and Bilodeau, Guillaume-Alexandre},
  booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
  volume={38},
  number={7},
  pages={6666--6674},
  year={2024}
}

Uncurated prediction examples of STDiff for multiple datasets.

The temporal coordinates are shown in the top-left corner of each frame. Frames with red temporal coordinates are future frames predicted by our model.

BAIR

STDiff_BAIR_0

STDiff_BAIR_15

SM-MNIST

STDiff_SMMNIST_7

STDiff_SMMNIST_10

KITTI

STDiff_KITTI_0

STDiff_KITTI_22

Cityscapes

STDiff_City_110

STDiff_City_120