Hi @robert-graf,
Thank you for this wonderful reference implementation. I'm trying to run inference and save sampled images for the afhq dataset. Below is my code, adapted from your MNIST inference code. However, the saved image is just blank/empty:
import torch, torchvision
from loader.arguments import get_latest_Checkpoint
from diffusion import Diffusion
from loader import load_dataset
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
name = "afhq" # exp_name
version = "*" # "*" is the newest; use a number for a specific version
batch_size = 3
checkpoint = get_latest_Checkpoint(name, log_dir_name="logs_diffusion", best=False, version=version)
assert checkpoint is not None, "did not find checkpoint"
model = Diffusion.load_from_checkpoint(checkpoint, strict=False)
# model.cpu()  # alternatively, run on the CPU instead of .cuda()
model.cuda()
label = torch.tensor([0,1,2]).cuda()
# DDPM
image_ddpm = model.forward(batch_size, 1000, label=label)
# DDIM
image_ddim, inter = model.forward_ddim(batch_size, list(range(0, 1000, 20)), label=label, eta=0.0) # type: ignore
assert isinstance(image_ddpm, torch.Tensor)
grid = torchvision.utils.make_grid(torch.cat([image_ddpm, image_ddim], dim=0), nrow=5).cpu()
a = grid.permute(1, 2, 0)
a = torch.clamp(a, 0, 1)
plt.figure(figsize=(40, 40))
plt.savefig('sample.jpg')
Could you let me know if I'm missing something here.
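For anyone hitting the same symptom: matplotlib writes an empty canvas when `plt.savefig` is called on a figure that nothing has been drawn on, so it is worth checking that an `imshow` call happens between `plt.figure` and `plt.savefig`. Below is a minimal sketch, assuming the grid has already been moved to the CPU, converted to a NumPy array of shape (C, H, W), and scaled to the [0, 1] range. The helper names `to_hwc`, `save_with_figure` and `save_pixels` are hypothetical, not part of this repository.

```python
import numpy as np
import matplotlib

matplotlib.use("Agg")  # headless backend, e.g. on a GPU server without a display
import matplotlib.pyplot as plt


def to_hwc(grid_chw: np.ndarray) -> np.ndarray:
    """Convert a (C, H, W) grid, as produced by make_grid, to a clipped (H, W, C) array."""
    return np.clip(np.transpose(grid_chw, (1, 2, 0)), 0.0, 1.0)


def save_with_figure(grid_chw: np.ndarray, path: str) -> None:
    """Draw the grid onto a figure, then save that figure."""
    fig = plt.figure(figsize=(10, 10))
    plt.imshow(to_hwc(grid_chw))  # without this call the saved figure is empty
    plt.axis("off")
    fig.savefig(path, bbox_inches="tight")
    plt.close(fig)  # free the figure when saving many grids in a loop


def save_pixels(grid_chw: np.ndarray, path: str) -> None:
    """Write the pixels directly (one image pixel per file pixel), no figure involved."""
    plt.imsave(path, to_hwc(grid_chw))
```

In the script above, `grid` is a CPU tensor, so `grid.detach().numpy()` should give the array these helpers expect. If torchvision is already imported, `torchvision.utils.save_image(grid, "sample.png")` is another way to write the grid without going through matplotlib.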