robert-graf / Readable-Conditional-Denoising-Diffusion

Readable Conditional Denoising Diffusion
MIT License

How to sample and save an image for afhq #2

Open cutoken opened 6 months ago

cutoken commented 6 months ago

Hi @robert-graf, thank you for this wonderful reference implementation. I'm trying to run inference and save the sampled images for the AFHQ dataset. Below is my code, adapted from your MNIST inference code. Somehow the saved image is just blank/empty:

```python
import torch, torchvision
from loader.arguments import get_latest_Checkpoint
from diffusion import Diffusion
from loader import load_dataset
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

name = "afhq"  # exp_name
version = "*"  # "*" picks the newest version; use a number for a specific version
batch_size = 3
checkpoint = get_latest_Checkpoint(name, log_dir_name="logs_diffusion", best=False, version=version)
assert checkpoint is not None, "did not find checkpoint"
model = Diffusion.load_from_checkpoint(checkpoint, strict=False)
# model.cpu()  # alternative to .cuda() for running on the CPU
model.cuda()

label = torch.tensor([0, 1, 2]).cuda()

# DDPM
image_ddpm = model.forward(batch_size, 1000, label=label)
# DDIM
image_ddim, inter = model.forward_ddim(batch_size, [i for i in range(0, 1000, 20)], label=label, eta=0.0)  # type: ignore
assert isinstance(image_ddpm, torch.Tensor)
grid = torchvision.utils.make_grid(torch.cat([image_ddpm, image_ddim], dim=0), nrow=5).cpu()
a = grid.permute(1, 2, 0)
a = torch.clamp(a, 0, 1)
plt.figure(figsize=(40, 40))
plt.savefig('sample.jpg')
```

Could you let me know if I'm missing something here.

cutoken commented 6 months ago

Got it working with the help of GPT-4. Here are some derpy little friends for reference:

[Image: sample(1)]
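For readers hitting the same problem: in the snippet from the question, `plt.figure(figsize=(40, 40))` creates a new, empty figure and `plt.savefig('sample.jpg')` saves it right away. The clamped grid `a` is never drawn (there is no `plt.imshow(a)`), which is the most likely reason the file comes out blank. Below is a minimal sketch of the save step. It assumes `grid` is the CHW CPU tensor returned by `make_grid`. `save_grid` and its `from_minus_one` flag are hypothetical helpers, not part of the repository, and whether the model returns values in [0, 1] or [-1, 1] is an assumption to check against the repo's own MNIST example.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend so savefig works without a display
import matplotlib.pyplot as plt


def save_grid(grid, path, from_minus_one=False):
    """Draw a CHW image grid and write it to `path` (hypothetical helper).

    from_minus_one: set True if the samples are in [-1, 1] (an assumption,
    not confirmed by the repo); they are then mapped to [0, 1] first.
    """
    # np.asarray also accepts a CPU torch tensor that does not require grad
    img = np.asarray(grid, dtype=np.float32)
    if from_minus_one:
        img = (img + 1.0) / 2.0
    # Clamp to the displayable range and reorder CHW -> HWC for imshow
    img = np.clip(img, 0.0, 1.0).transpose(1, 2, 0)
    if img.shape[2] == 1:
        img = img[:, :, 0]  # single-channel grid: drop the channel axis
    fig = plt.figure(figsize=(10, 10))
    # The step missing from the question: actually draw the image
    plt.imshow(img, cmap="gray" if img.ndim == 2 else None)
    plt.axis("off")
    fig.savefig(path, bbox_inches="tight")
    plt.close(fig)  # free the figure so repeated calls don't accumulate
    return img
```

Usage would be `save_grid(grid.detach(), "sample.jpg")`. If matplotlib is not needed at all, `torchvision.utils.save_image(grid, "sample.png")` writes a tensor straight to disk; it expects values in [0, 1] and clamps anything outside that range.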