explainingai-code / DiT-PyTorch

This repo implements Diffusion Transformers (DiT) in PyTorch and provides training and inference code for the CelebHQ dataset.

Diffusion Transformers (DiT) Implementation in PyTorch

DiT Tutorial Video

<img alt="DiT Tutorial" src="https://github.com/user-attachments/assets/2b0deffa-0181-4676-b79f-fec9b12d8326" width="400">

Sample Output for DiT on CelebHQ

Trained for 200 epochs


This repository implements DiT in PyTorch for diffusion models. It provides code for the following:

This is very similar to the official DiT implementation, except for the following changes:

Setup

Data Preparation

CelebHQ

To set up CelebHQ, download the images from the official CelebAMask-HQ repo here and add them to the data directory. Ensure the directory structure is as follows:

DiT-PyTorch
    -> data
        -> CelebAMask-HQ
            -> CelebA-HQ-img
                -> *.jpg
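A quick way to confirm this layout before training is a small check like the one below. It is a minimal sketch, not part of this repo: the helper name find_celebhq_images is hypothetical, and the only assumption it makes is the path shown above.

```python
from pathlib import Path


def find_celebhq_images(repo_root):
    """Return sorted .jpg paths under the expected CelebHQ layout.

    Hypothetical helper; assumes the layout
    <repo_root>/data/CelebAMask-HQ/CelebA-HQ-img/*.jpg
    """
    img_dir = Path(repo_root) / "data" / "CelebAMask-HQ" / "CelebA-HQ-img"
    if not img_dir.is_dir():
        raise FileNotFoundError(f"Expected image directory not found: {img_dir}")
    # Only .jpg files directly inside the folder are collected.
    images = sorted(img_dir.glob("*.jpg"))
    if not images:
        raise FileNotFoundError(f"No .jpg files found in {img_dir}")
    return images
```

Called as find_celebhq_images(".") from the DiT-PyTorch root, it returns the image paths, or raises FileNotFoundError naming the directory that is missing or empty.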

Configuration

The YAML config allows you to experiment with different components of the DiT and the autoencoder.

Important configuration parameters

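The image size, the autoencoder's downsampling factor and the DiT patch size together determine the transformer's sequence length: a square latent of side L split into p x p patches gives (L/p)^2 tokens. The sketch below computes this. The parameter names are hypothetical and need not match the keys in this repo's YAML config, and the example values (256px images, 8x downsampling, patch size 2) are assumptions, not the repo's defaults.

```python
def dit_sequence_length(im_size, downsample_factor, patch_size):
    """Number of tokens the DiT transformer sees per image.

    im_size: side length of the square input image in pixels.
    downsample_factor: spatial reduction of the autoencoder (image -> latent).
    patch_size: side length of each square latent patch.
    All three names are illustrative, not this repo's config keys.
    """
    if im_size % downsample_factor != 0:
        raise ValueError("im_size must be divisible by downsample_factor")
    latent_size = im_size // downsample_factor
    if latent_size % patch_size != 0:
        raise ValueError("latent size must be divisible by patch_size")
    patches_per_side = latent_size // patch_size
    return patches_per_side * patches_per_side


# Hypothetical values: 256px images, 8x autoencoder downsampling, patch size 2.
print(dit_sequence_length(256, 8, 2))  # prints 256 (16 x 16 patches)
```

Because self-attention cost grows quadratically with the number of tokens, halving the patch size (4x the tokens) makes each transformer block substantially more expensive.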

Training

The repo provides training and inference for CelebHQ (unconditional DiT).

To work with your own dataset:
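PyTorch's DataLoader accepts any map-style dataset, i.e. an object implementing __len__ and __getitem__. As a minimal, dependency-free sketch (the class ImagePathDataset and its extensions argument are hypothetical, not this repo's dataset class), a folder of images can be indexed like this:

```python
from pathlib import Path


class ImagePathDataset:
    """Hypothetical minimal map-style dataset over a folder of images.

    __getitem__ returns the file path instead of a decoded tensor so the
    sketch has no dependencies beyond the standard library.
    """

    def __init__(self, root, extensions=(".jpg", ".jpeg", ".png")):
        self.root = Path(root)
        exts = {e.lower() for e in extensions}
        # Sort so the index -> file mapping is deterministic across runs.
        self.paths = sorted(
            p for p in self.root.iterdir()
            if p.is_file() and p.suffix.lower() in exts
        )

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        # A real dataset would typically open the image here (e.g. with PIL),
        # resize/crop it to the configured size and convert it to a tensor.
        return self.paths[index]
```

In practice __getitem__ would return a transformed image tensor, and the new class would be swapped into the training scripts in place of the CelebHQ dataset.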

Once the config and dataset are set up:

Training AutoEncoder for DiT

Training Unconditional DiT

Train the autoencoder first and set up the dataset accordingly.

For training the unconditional DiT, ensure the right dataset is used in train_vae_dit.py.
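Because the DiT is trained on autoencoder latents rather than pixels, a training step typically takes an image latent, corrupts it with noise at a random timestep t, and trains the transformer to predict that noise. The sketch below shows only the standard DDPM forward process, x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps, in plain Python. The linear schedule (1e-4 to 0.02 over 1000 steps) is the DDPM default and an assumption here, not necessarily what this repo's config uses, and the function names are hypothetical. It covers only the noise-prediction (MSE) target; the official DiT additionally learns the variance of the reverse process.

```python
import math
import random


def linear_alpha_bars(num_timesteps=1000, beta_start=1e-4, beta_end=0.02):
    """Cumulative products alpha_bar_t of (1 - beta_t) for a linear beta schedule."""
    alpha_bars = []
    running = 1.0
    for t in range(num_timesteps):
        beta = beta_start + (beta_end - beta_start) * t / (num_timesteps - 1)
        running *= 1.0 - beta
        alpha_bars.append(running)
    return alpha_bars


def add_noise(latent, t, alpha_bars, rng=random):
    """Sample x_t ~ q(x_t | x_0) for a flattened latent x_0.

    Returns (x_t, noise); the DiT is trained to predict `noise` from (x_t, t),
    e.g. with a mean-squared-error loss.
    """
    a = alpha_bars[t]
    noise = [rng.gauss(0.0, 1.0) for _ in latent]
    noisy = [math.sqrt(a) * x + math.sqrt(1.0 - a) * n for x, n in zip(latent, noise)]
    return noisy, noise


# Illustrative usage with a made-up 3-element "latent".
alpha_bars = linear_alpha_bars()
rng = random.Random(0)
latent = [0.5, -0.2, 1.0]  # stand-in for a flattened autoencoder latent
t = rng.randrange(len(alpha_bars))  # timestep sampled uniformly
x_t, eps = add_noise(latent, t, alpha_bars, rng)
```

In the real training loop the same operation runs on batched tensors, with one timestep sampled per latent in the batch.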

Output

Outputs are saved according to the configuration in the YAML files.

For every run, a folder named after the task_name key in the config is created.
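A minimal sketch of that behavior, assuming the YAML has already been parsed into a Python dict (for example with PyYAML's yaml.safe_load). The names make_run_dir and base_dir are hypothetical, and the real scripts may read task_name from a nested section of the config.

```python
from pathlib import Path


def make_run_dir(config, base_dir="."):
    """Create the per-run output folder named after config['task_name'].

    `config` is assumed to be the (sub-)dict of the parsed YAML that holds
    the task_name key; `base_dir` is a hypothetical parameter.
    """
    run_dir = Path(base_dir) / config["task_name"]
    # exist_ok=True lets a later run (or inference) reuse the same folder.
    run_dir.mkdir(parents=True, exist_ok=True)
    return run_dir
```

Reusing one folder per task_name keeps autoencoder checkpoints, DiT checkpoints and samples for the same experiment together.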

During training of the autoencoder, the following output will be saved:

During inference of the autoencoder, the following output will be saved:

During training and inference of the unconditional DiT, the following output will be saved:

Citations

@misc{peebles2023scalablediffusionmodelstransformers,
      title={Scalable Diffusion Models with Transformers}, 
      author={William Peebles and Saining Xie},
      year={2023},
      eprint={2212.09748},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2212.09748}, 
}