tqch / v-diffusion-torch

PyTorch Implementation of V-objective Diffusion Probabilistic Models with Classifier-free Guidance
MIT License
29 stars 3 forks source link

CIFAR10 Config #1

Closed rahulvigneswaran closed 1 year ago

rahulvigneswaran commented 1 year ago

Can you provide the exact config you used to generate the CIFAR-10 results shown in the README?

tqch commented 1 year ago

The exact config is in ./configs/cifar10.json. I can also provide an example code snippet to generate the CIFAR-10 results in README.md:

from v_diffusion import *
import torch
from torchvision.utils import save_image

# Device that hosts the network and every tensor created below.
device = "cuda:0"

# Class-conditional UNet.
# NOTE(review): these hyper-parameters presumably have to match the ones the
# checkpoint was trained with — confirm against ./configs/cifar10.json.
model = UNet(
    in_channels=3, hid_channels=256, out_channels=6,
    ch_multipliers=[1, 1, 1], num_res_blocks=3,
    apply_attn=[False, True, True],
    drop_rate=0.2, num_classes=10,
)

# Keep only the checkpoint's ["ema"]["shadow"] entry (presumably the EMA copy
# of the weights — verify against the training script); the rest of the
# checkpoint is dropped immediately.
state_dict = torch.load(
    "./chkpts/vdpm_cifar10_XXX.pt", map_location=device
)["ema"]["shadow"]
model.to(device)

# Strip a leading "module." from parameter names, in place (such a prefix is
# presumably left behind by a DataParallel/DDP wrapper during training).
# The affected keys are collected first so the dict is not resized mid-iteration.
for key in [name for name in state_dict if name.startswith("module.")]:
    state_dict[key.partition(".")[2]] = state_dict.pop(key)

model.load_state_dict(state_dict)
del state_dict  # release the extra copy of the weights

# Inference only: freeze all parameters and switch to evaluation mode.
model.requires_grad_(False)
model.eval()

# "cosine" log-SNR schedule; -20 and 20 are presumably the lower/upper log-SNR
# bounds — TODO confirm against get_logsnr_schedule.
logsnr_fn = get_logsnr_schedule("cosine", -20, 20)

# Diffusion wrapper used for sampling.
diffusion = GaussianDiffusion(
    logsnr_fn,
    # sampler settings
    sample_timesteps=256, use_ddim=True,
    # model parameterisation
    model_out_type="both", model_var_type="fixed_large",
    # training-objective settings
    reweight_type="truncated_snr", loss_type="mse",
    # guidance settings; w_guide here is only an initial value, the sampling
    # loop below overwrites diffusion.w_guide before every run
    w_guide=0.1, p_uncond=0.1,
)

# Guidance weights to sweep; one 10x10 image grid is written per weight.
ws = [0, 1, 3]

# 100 RGB samples of size 32x32.
shape = (100, 3, 32, 32)

# The starting noise is drawn once under a fixed seed (on the CPU, then moved
# to the device) and reused for every guidance weight, so the grids differ
# only in the weight applied.
torch.manual_seed(1234)
noise = torch.randn(shape).to(device)

for guidance_weight in ws:
    diffusion.w_guide = guidance_weight

    # Values 1..10, each repeated ten times in a row; combined with nrow=10
    # below, every grid row shares a single label value.
    # NOTE(review): the 1-based range and float32 dtype are taken as-is from
    # the original snippet — confirm how the model consumes labels.
    labels = torch.arange(1, 11, device=device, dtype=torch.float32)
    labels = labels.repeat_interleave(10)

    x_gen = diffusion.p_sample(
        model, shape, device=device, noise=noise, label=labels, use_ddim=True)

    out_path = f"./cifar10_w{guidance_weight}.png"
    # normalize with value_range=(-1, 1) rescales that range to [0, 1]
    # before the grid is written.
    save_image(
        x_gen.cpu(), out_path, nrow=10, normalize=True, value_range=(-1, 1))