suxuann / ddib

Dual Diffusion Implicit Bridges for Image-to-Image Translation. ICLR 2023.
MIT License

About the DDIM reconstruction error. #15

Open KwanghyunOn opened 1 year ago

KwanghyunOn commented 1 year ago

Hi there,

I've been attempting to use your code to reconstruct the original image with DDIM sampling functions, but I'm having some difficulty. Specifically, the image below is the result of my attempt to reconstruct the FFHQ dataset using a pre-trained diffusion model trained on the CelebA-HQ dataset.

Unfortunately, the reconstruction error is much larger than I expected. Is there something I'm overlooking or not taking into account? Any insights you could provide would be greatly appreciated.

Thank you.

[attached image: recon_ffhq_clip=TF]

def main():
    # ...

    for batch_idx, data in enumerate(dataloader):
        img_batch = data[0].to(device)
        # Encode: run DDIM in reverse (image -> latent noise), no clipping
        noise = diffusion.ddim_reverse_sample_loop(
            model,
            img_batch,
            clip_denoised=False,
            device=device,
            progress=True,
        )
        # Decode: run DDIM forward (latent noise -> image)
        recon_sample = diffusion.ddim_sample_loop(
            model,
            (args.batch_size, 3, args.image_size, args.image_size),
            noise=noise,
            clip_denoised=True,  # note: asymmetric with the encoding pass above
            device=device,
            eta=args.eta,  # note: DDIM is only deterministic (invertible) when eta = 0
            progress=True,
        )
        # Save originals and reconstructions side by side, rescaled from [-1, 1] to [0, 1]
        torchvision.utils.save_image(
            (torch.cat([img_batch, recon_sample], dim=0) + 1.) / 2.,
            "results/recon_ffhq_clip=FT.png",
            nrow=4,
        )

    # ...
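For reference, deterministic DDIM (eta = 0, with the same clipping behavior in both directions) inverts exactly whenever the encode and decode passes see the same noise prediction at each step. The toy sketch below (plain NumPy, not this repo's API; a fixed vector `eps` stands in for the network's prediction) shows the round trip reconstructing the input to machine precision under those conditions:

```python
import numpy as np

# Toy DDIM inversion round trip. A constant `eps` replaces the network, so
# the encode and decode passes are exact inverses of each other.
T = 50
betas = np.linspace(1e-4, 2e-2, T)
abar = np.cumprod(1.0 - betas)                  # alpha-bar_t
abar_prev = np.concatenate([[1.0], abar[:-1]])  # alpha-bar_{t-1}

rng = np.random.default_rng(0)
eps = rng.standard_normal(8)   # stand-in for the model's eps prediction
x0 = rng.standard_normal(8)    # "image"

def ddim_step(x, a_from, a_to, eps):
    # Deterministic DDIM move between noise levels a_from -> a_to (eta = 0).
    x0_pred = (x - np.sqrt(1.0 - a_from) * eps) / np.sqrt(a_from)
    return np.sqrt(a_to) * x0_pred + np.sqrt(1.0 - a_to) * eps

# Encode: x0 -> xT (analogue of ddim_reverse_sample_loop)
x = x0
for t in range(T):
    x = ddim_step(x, abar_prev[t], abar[t], eps)
xT = x

# Decode: xT -> x0 (analogue of ddim_sample_loop with eta = 0)
for t in reversed(range(T)):
    x = ddim_step(x, abar[t], abar_prev[t], eps)

print(np.max(np.abs(x - x0)))  # near machine precision
```

With a real network, the eps prediction at adjacent timesteps differs slightly, so some discretization error is unavoidable; but a nonzero `eta` (stochastic sampling) or clipping in only one direction breaks invertibility outright and can produce much larger reconstruction error.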
838959823 commented 7 months ago

I encounter the same problem using ImageNet-pretrained models. If the reconstruction error is this large, how can cycle consistency be guaranteed?
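Before debugging, it may help to quantify the error per image rather than judging it visually. A minimal NumPy sketch (`psnr_per_image` is a hypothetical helper, not part of this repo), assuming both batches are `(N, C, H, W)` tensors in `[-1, 1]` as in the snippet above:

```python
import numpy as np

def psnr_per_image(original, recon):
    # Rescale from [-1, 1] to [0, 1], then compute per-image MSE and PSNR.
    x = (np.asarray(original) + 1.0) / 2.0
    y = (np.asarray(recon) + 1.0) / 2.0
    mse = ((x - y) ** 2).reshape(len(x), -1).mean(axis=1)
    return -10.0 * np.log10(np.maximum(mse, 1e-12))

# Example: identical inputs give a very large PSNR (error floored at 1e-12).
a = np.zeros((1, 3, 8, 8))
print(psnr_per_image(a, a))
```

Tracking PSNR across different settings (eta = 0 vs. eta > 0, clipping on vs. off in each direction) would make it easier to isolate which choice is responsible for the error.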