joh-fischer / PlantLDM

A latent diffusion model for visual synthesis of plant images.
MIT License


We implemented a latent diffusion model (Rombach et al., 2021) for the visual synthesis of plant images from the PlantNet-300K dataset. The model can be split into two stages. The first stage consists of a VQ-GAN (Esser et al., 2020), which encodes an image into a latent representation. A Denoising Diffusion Probabilistic Model (Ho et al., 2020) forms the second stage and synthesizes latent representations, which are then decoded by the decoder of the first stage.
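To make the second stage more concrete, below is a minimal, dependency-free sketch of the DDPM forward (noising) process applied to a flattened latent vector. It uses the linear beta schedule from Ho et al. (2020). This is not code from this repository: the function names and the stand-in latent are hypothetical, and the learned denoising network that reverses this process is omitted.

```python
import math
import random

def linear_beta_schedule(T=1000, beta_start=1e-4, beta_end=0.02):
    """Noise variances beta_1..beta_T, linearly spaced as in Ho et al. (2020)."""
    return [beta_start + (beta_end - beta_start) * t / (T - 1) for t in range(T)]

def alpha_bars(betas):
    """Cumulative products alpha_bar_t = prod_{s<=t} (1 - beta_s)."""
    out, prod = [], 1.0
    for beta in betas:
        prod *= 1.0 - beta
        out.append(prod)
    return out

def q_sample(z0, t, abar, rng=None):
    """Draw z_t ~ N(sqrt(abar_t) * z0, (1 - abar_t) * I) for a flattened latent z0."""
    if rng is None:
        rng = random.Random()
    a = abar[t]
    return [math.sqrt(a) * x + math.sqrt(1.0 - a) * rng.gauss(0.0, 1.0) for x in z0]

# Usage: noise a stand-in latent (hypothetical values) at the final time step.
abar = alpha_bars(linear_beta_schedule())
z0 = [0.5] * 16
z_T = q_sample(z0, len(abar) - 1, abar, random.Random(0))
```

At the last step alpha_bar is close to zero, so z_T is almost pure Gaussian noise. Sampling runs the learned reverse process from such noise back to a latent, which the VQ-GAN decoder then turns into an image.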


Clone the repository and (optionally) create a virtual environment:

virtualenv venv
source venv/bin/activate

Then install the dependencies:

pip install -r requirements.txt



After downloading the data, you can either crop and resize the images on the fly or preprocess them in advance, which saves computational resources during training. For the first option, specify a config file with

data_dir: '<your_dir_to_data>'
is_preprocessed: False

and start the training with the --data-config <your_cfg_file> command-line option. You also need to specify the image size with the --image-size <size> option.

To save resources by preprocessing the data, run the script and create a config with the respective data directory and the is_preprocessed option set to True. Then pass this config file with the --data-config <your_cfg_file> option.
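Whether it happens on the fly or in advance, preprocessing comes down to cropping each image and resizing it to the requested --image-size. The README does not say which crop is used, so the sketch below assumes a centered square crop. center_crop_box is a hypothetical helper, not part of the repository. It returns a (left, top, right, bottom) box, the format Pillow's Image.crop expects.

```python
def center_crop_box(width, height):
    """Box (left, top, right, bottom) of the largest centered square in a width x height image.

    Assumption: a center crop; the repository may crop differently (e.g. randomly).
    """
    side = min(width, height)
    left = (width - side) // 2
    top = (height - side) // 2
    return (left, top, left + side, top + side)

# Usage: a 400x300 landscape image keeps its central 300x300 square.
box = center_crop_box(400, 300)
```

With Pillow, the crop-and-resize step would then read img.crop(center_crop_box(*img.size)).resize((image_size, image_size)). Doing this once up front avoids repeating the same decode, crop, and resize work in every epoch.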


To monitor the losses and visualizations, cd into the repository and run

tensorboard --logdir=logs


Implementation of VQ-VAE (paper).


python3 --name myExp --epochs 2 --config configs/vqvae.yaml

To debug the code on CIFAR10 first, run

python3 --epochs 2 --config configs/vqvae.yaml --debug
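The core of the VQ-VAE (and of the VQ-GAN first stage) is the quantization step, which replaces each encoder output vector with its nearest entry in a learned codebook. The following sketch shows that nearest-neighbour lookup with plain lists instead of tensors. The function names are hypothetical, and the straight-through gradient estimator used during training is omitted.

```python
def squared_distance(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))

def quantize(z_e, codebook):
    """Map each encoder output vector in z_e to its nearest codebook entry.

    Returns the chosen codebook indices and the quantized vectors.
    """
    indices = [
        min(range(len(codebook)), key=lambda k: squared_distance(z, codebook[k]))
        for z in z_e
    ]
    return indices, [codebook[k] for k in indices]

# Usage: a toy 3-entry codebook of 2-dimensional embeddings.
codebook = [[0.0, 0.0], [1.0, 1.0], [5.0, 5.0]]
indices, z_q = quantize([[0.2, 0.1], [4.0, 4.5], [0.9, 1.2]], codebook)
```

Only the integer indices need to be stored to represent an image in the latent space, which is what makes the discrete latent compact.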


Implementation of VQ-GAN (paper).


python3 --name myExp --epochs 2 --config configs/vqgan.yaml

To debug the code on CIFAR10 first, run

python3 --epochs 2 --config configs/vqgan.yaml --debug

Model Settings

You can specify the model settings in the vqgan.yaml config file. The length of the channels list also determines how much the input image is down-sampled. For example, a list with two channels (e.g. [32, 64]) down-samples the image by a factor of 4.

    in_channels: 3
      - 32
      - 64
    dim_keys: 64
    n_heads: 4
  latent_dim: 32
  n_embeddings: 512
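The following sketch shows how the channels list, latent_dim, and the input size together determine the latent shape. It generalizes the README example (two entries give a factor of 4) by assuming that every entry in the channels list adds one stride-2 down-sampling stage. latent_shape is a hypothetical helper, not part of the repository.

```python
def latent_shape(image_size, channels, latent_dim):
    """Shape (latent_dim, h, w) of the latent for a square image_size x image_size input.

    Assumption: each entry in `channels` halves the spatial resolution,
    so the total down-sampling factor is 2 ** len(channels).
    """
    factor = 2 ** len(channels)
    if image_size % factor:
        raise ValueError(f"image_size {image_size} is not divisible by {factor}")
    side = image_size // factor
    return (latent_dim, side, side)

# Usage: the config above (channels [32, 64], latent_dim 32) on 256x256 images.
shape = latent_shape(256, [32, 64], 32)
```

Under this assumption, a 256x256 image becomes a 64x64 grid of 32-dimensional vectors. Each vector is then quantized to one of the n_embeddings (here 512) codebook entries.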


You can specify which losses to use, and the weight of each loss, in the vqgan.yaml config file.

  rec_loss_type: 'L1'
  perceptual_weight: 0.1
  codebook_weight: 0.9
  commitment_weight: 0.25
  disc_weight: 0.1
  disc_in_channels: 3
  disc_n_layers: 4
  disc_warm_up_iters: 5000
  disc_res_blocks: False
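The sketch below shows one plausible way these weights could combine into the first-stage training objective. It is an illustration under stated assumptions, not the repository's code. The L1 reconstruction term matches rec_loss_type: 'L1'. The codebook and commitment terms are assumed to be the usual VQ-VAE terms, and disc_weight is assumed to be a fixed factor on the adversarial term that switches on after disc_warm_up_iters (the original VQ-GAN additionally uses an adaptive weight). The perceptual and adversarial values are passed in as precomputed floats. LossConfig, l1_loss, and generator_loss are hypothetical names.

```python
from dataclasses import dataclass

@dataclass
class LossConfig:
    # Field names and defaults mirror the loss keys in configs/vqgan.yaml.
    perceptual_weight: float = 0.1
    codebook_weight: float = 0.9
    commitment_weight: float = 0.25
    disc_weight: float = 0.1
    disc_warm_up_iters: int = 5000

def l1_loss(x, x_hat):
    """Mean absolute error between an image and its reconstruction (flattened)."""
    return sum(abs(a - b) for a, b in zip(x, x_hat)) / len(x)

def generator_loss(step, rec, perceptual, codebook, commitment, adversarial, cfg):
    """Weighted sum of first-stage loss terms at a given training step.

    Assumption: the adversarial term is only added once `step` reaches
    cfg.disc_warm_up_iters, so the autoencoder first learns to reconstruct.
    """
    total = (rec
             + cfg.perceptual_weight * perceptual
             + cfg.codebook_weight * codebook
             + cfg.commitment_weight * commitment)
    if step >= cfg.disc_warm_up_iters:
        total += cfg.disc_weight * adversarial
    return total

# Usage: with every term equal to 1.0, the adversarial term only counts after warm-up.
cfg = LossConfig()
before = generator_loss(0, 1.0, 1.0, 1.0, 1.0, 1.0, cfg)
after = generator_loss(5000, 1.0, 1.0, 1.0, 1.0, 1.0, cfg)
```

Delaying the discriminator in this way is a common choice. An untrained discriminator gives little useful signal while reconstructions are still poor.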