lucidrains / denoising-diffusion-pytorch

Implementation of Denoising Diffusion Probabilistic Model in Pytorch
MIT License

Sampled images are often mirrored across y-axis despite training images being oriented in same direction #35

Open sgbaird opened 2 years ago

sgbaird commented 2 years ago

Is there something in the repository that treats mirror images about the y-axis as equivalent? I'm wondering whether this is just an artifact of how the images are stacked in the checkpoint `sample-##.png` files, or whether these are the real outputs.
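One way to rule out the stacking explanation is to cut the saved grid back into individual tiles and inspect each sample on its own. The sketch below assumes the checkpoint image is a plain row-major tiling with uniform padding (as far as I know, this is the layout torchvision's `make_grid`/`save_image` produce); `tile_images`, `untile_grid`, and `top_left_zero` are hypothetical plain-Python helpers, not functions from either library. Because tiling only copies pixels into place, a round trip returns every image unchanged, so the stacking step by itself should not mirror a sample.

```python
from typing import List

Image = List[List[int]]


def tile_images(images: List[Image], per_row: int, padding: int = 2, pad_value: int = 0) -> Image:
    """Copy equally sized images into a row-major grid with `padding` pixels around each tile."""
    h, w = len(images[0]), len(images[0][0])
    n_rows = (len(images) + per_row - 1) // per_row
    grid_h = n_rows * (h + padding) + padding
    grid_w = per_row * (w + padding) + padding
    grid = [[pad_value] * grid_w for _ in range(grid_h)]
    for k, img in enumerate(images):
        r0 = padding + (k // per_row) * (h + padding)
        c0 = padding + (k % per_row) * (w + padding)
        for i in range(h):
            grid[r0 + i][c0:c0 + w] = img[i]  # copy one image row, same left-to-right order
    return grid


def untile_grid(grid: Image, n: int, h: int, w: int, per_row: int, padding: int = 2) -> List[Image]:
    """Cut the first n tiles of size h x w back out of a grid built by tile_images."""
    tiles = []
    for k in range(n):
        r0 = padding + (k // per_row) * (h + padding)
        c0 = padding + (k % per_row) * (w + padding)
        tiles.append([row[c0:c0 + w] for row in grid[r0:r0 + h]])
    return tiles


def top_left_zero(fill: int, size: int = 4) -> Image:
    """Toy image of `fill` values with a 2 x 2 block of zeros at the top-left."""
    img = [[fill] * size for _ in range(size)]
    for r in range(2):
        for c in range(2):
            img[r][c] = 0
    return img


images = [top_left_zero(v) for v in (1, 2, 3)]
grid = tile_images(images, per_row=2)
recovered = untile_grid(grid, n=3, h=4, w=4, per_row=2)
print(recovered == images)  # True: every tile keeps its original orientation
```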

Reproducer

I'm using ~16k grayscale training images that are $64 \times 64$ pixels.

```shell
conda create -n xtal2png-ddpm python=3.9
conda activate xtal2png-ddpm
pip install denoising_diffusion_pytorch xtal2png mp-time-split
```
```python
from os import path

import torch
from denoising_diffusion_pytorch import GaussianDiffusion, Trainer, Unet
from mp_time_split.core import MPTimeSplit

from xtal2png.core import XtalConverter

# Load the Materials Project time-split dataset and take the first fold
mpt = MPTimeSplit()
mpt.load()

fold = 0
train_inputs, val_inputs, train_outputs, val_outputs = mpt.get_train_and_val_data(fold)

# Convert the training structures to grayscale PNG images saved under data_path
data_path = path.join("data", "preprocessed", "mp-time-split")
xc = XtalConverter(save_dir=data_path)
xc.xtal2png(train_inputs.tolist())

# Single-channel (grayscale) U-Net denoiser
model = Unet(dim=64, dim_mults=(1, 2, 4, 8), channels=1).cuda()

diffusion = GaussianDiffusion(
    model, channels=1, image_size=64, timesteps=1000, loss_type="l1"
).cuda()

# Train on the PNG images in data_path
trainer = Trainer(
    diffusion,
    data_path,
    image_size=64,
    train_batch_size=32,
    train_lr=2e-5,
    train_num_steps=700000,  # total training steps
    gradient_accumulate_every=2,  # gradient accumulation steps
    ema_decay=0.995,  # exponential moving average decay
    amp=True,  # turn on mixed precision
)

trainer.train()

# Draw 100 samples from the trained model
sampled_images = diffusion.sample(batch_size=100)
```

Training examples


Notice how they're all oriented with the small square of zeros at the top-left (this is the case for all training data).
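To check this claim over the whole training set, or over a batch of samples, a small helper that reports which corner holds the all-zero block is enough. This is a hypothetical sketch: images are plain nested lists of pixel values, `zero_corner` and `with_zero_block` are illustrative names, and the block size `k` is an assumption, since the issue does not state the size of the square. The `tol` argument allows for sampled pixels that are close to, but not exactly, zero.

```python
from collections import Counter
from typing import List, Optional

Image = List[List[float]]


def zero_corner(img: Image, k: int, tol: float = 0.0) -> Optional[str]:
    """Return the first corner whose k x k block is entirely (near-)zero, else None."""
    h, w = len(img), len(img[0])
    origins = {
        "top-left": (0, 0),
        "top-right": (0, w - k),
        "bottom-left": (h - k, 0),
        "bottom-right": (h - k, w - k),
    }
    for name, (r0, c0) in origins.items():
        block = (img[r][c] for r in range(r0, r0 + k) for c in range(c0, c0 + k))
        if all(abs(v) <= tol for v in block):
            return name
    return None


def with_zero_block(size: int, k: int, top: bool, left: bool) -> Image:
    """Build a test image of ones with a k x k zero block in the chosen corner."""
    img = [[1.0] * size for _ in range(size)]
    rows = range(k) if top else range(size - k, size)
    cols = range(k) if left else range(size - k, size)
    for r in rows:
        for c in cols:
            img[r][c] = 0.0
    return img


batch = [
    with_zero_block(8, 2, top=True, left=True),
    with_zero_block(8, 2, top=True, left=False),
    with_zero_block(8, 2, top=True, left=True),
]
print(Counter(zero_corner(img, k=2) for img in batch))
# Counter({'top-left': 2, 'top-right': 1})
```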

sample-25.png

loss: 0.0535:   4%|  | 25507/700000 [4:26:37<94:32:36,  1.98it/s]

Notice how many of them are mirrored about the y-axis (the small square of zeros sits at the top-right instead of the top-left), which is not desired. If this is just an artifact of how the images are stacked, that's one thing, but if these are the actual sampled images, it's a bit worrisome. Note also that the samples never seem to have the small square of zeros at the bottom-left or bottom-right, as a mirror about the x-axis would produce.
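The asymmetry is consistent with a left-right flip somewhere in the pipeline. Mirroring about the y-axis reverses each row, sending the pixel at (row r, column c) to (r, W - 1 - c) for an image of width W: a top-left block lands at the top-right but never at the bottom. Only a mirror about the x-axis (reversing the order of the rows) would move it to the bottom-left. The plain-Python sketch below, with hypothetical helper names, makes this concrete on a 4×4 image with a single zero pixel.

```python
from typing import List, Set, Tuple

Image = List[List[int]]


def mirror_y_axis(img: Image) -> Image:
    """Left-right flip (what a 'horizontal flip' augmentation does): reverse each row."""
    return [row[::-1] for row in img]


def mirror_x_axis(img: Image) -> Image:
    """Top-bottom flip: reverse the order of the rows."""
    return [list(row) for row in img[::-1]]


def zero_pixels(img: Image) -> Set[Tuple[int, int]]:
    """Coordinates (row, col) of every zero-valued pixel."""
    return {(r, c) for r, row in enumerate(img) for c, v in enumerate(row) if v == 0}


img = [[1] * 4 for _ in range(4)]
img[0][0] = 0  # zero "square" at the top-left, as in every training image

print(zero_pixels(img))                 # {(0, 0)}  top-left
print(zero_pixels(mirror_y_axis(img)))  # {(0, 3)}  top-right
print(zero_pixels(mirror_x_axis(img)))  # {(3, 0)}  bottom-left
```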


cc my labmate @hasan-sayeed who has also been working on this

lucidrains commented 2 years ago

Try turning off the horizontal flip augmentation by setting `augment_horizontal_flip` (https://github.com/lucidrains/denoising-diffusion-pytorch/commit/79c5f045e10461c3f56f2ecf3767f318ec3719e9#diff-1ef0846017d5c27bd277b66e5afcab8798667afd39675e6f9382fc7e87a1fe28R584) to `False`.
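Why this would explain the samples: if every training image passes through a random left-right flip, assuming a flip probability of 0.5 (the default of torchvision's `RandomHorizontalFlip`), then roughly half of the examples the model sees have the zero block at the top-right. The model is then effectively trained on a distribution in which both orientations occur, so y-axis-mirrored samples are expected, while bottom orientations never appear. The stand-alone simulation below illustrates this; the flag name `augment_horizontal_flip` is borrowed from the library's `Trainer` option, but all functions here are hypothetical and are not the library's data pipeline.

```python
import random
from collections import Counter
from typing import List

Image = List[List[int]]


def training_image(size: int = 8, block: int = 2) -> Image:
    """Hypothetical training image: ones everywhere except a zero block at the top-left."""
    img = [[1] * size for _ in range(size)]
    for r in range(block):
        for c in range(block):
            img[r][c] = 0
    return img


def augment(img: Image, rng: random.Random, augment_horizontal_flip: bool, p: float = 0.5) -> Image:
    """With the flag on, mirror left-right (about the y-axis) with probability p."""
    if augment_horizontal_flip and rng.random() < p:
        return [row[::-1] for row in img]
    return img


def zero_corner(img: Image) -> str:
    """Name the corner whose pixel is zero (the toy images have exactly one such corner)."""
    corners = {
        "top-left": img[0][0],
        "top-right": img[0][-1],
        "bottom-left": img[-1][0],
        "bottom-right": img[-1][-1],
    }
    return next(name for name, value in corners.items() if value == 0)


def orientations_seen(n: int, augment_horizontal_flip: bool, seed: int = 0) -> Counter:
    """Tally the orientations of n augmented copies of the same training image."""
    rng = random.Random(seed)
    base = training_image()
    return Counter(zero_corner(augment(base, rng, augment_horizontal_flip)) for _ in range(n))


print(orientations_seen(1000, augment_horizontal_flip=True))   # roughly half top-left, half top-right
print(orientations_seen(1000, augment_horizontal_flip=False))  # Counter({'top-left': 1000})
```

With the flag off, every example keeps the top-left orientation, which is what the training data actually contains.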