Closed rahulvigneswaran closed 1 year ago
The exact config is in `./configs/cifar10.json`. I can also provide an example code snippet to generate the CIFAR-10 results in `README.md`:
from v_diffusion import *
import torch
from torchvision.utils import save_image
# Constructor arguments for the class-conditional CIFAR-10 network.
# NOTE(review): out_channels is twice in_channels; presumably this pairs with
# model_out_type="both" in the diffusion setup further down — TODO confirm.
unet_kwargs = {
    "in_channels": 3,
    "hid_channels": 256,
    "out_channels": 6,
    "ch_multipliers": [1, 1, 1],
    "num_res_blocks": 3,
    "apply_attn": [False, True, True],
    "drop_rate": 0.2,
    "num_classes": 10,
}
model = UNet(**unet_kwargs)
# Use the first CUDA device when one is available; otherwise fall back to the
# CPU so the snippet still runs (slowly) on machines without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
# "XXX" in the path is a placeholder for the actual checkpoint file name.
# Only the weights stored under ["ema"]["shadow"] are used.
chkpt = torch.load("./chkpts/vdpm_cifar10_XXX.pt", map_location=device)["ema"]["shadow"]
model.to(device)

# Strip a leading "module." from parameter names so they match the bare model
# (presumably the prefix comes from a DataParallel/DistributedDataParallel
# wrapper used during training — TODO confirm).
for k in list(chkpt.keys()):  # iterate over a copy: the dict is mutated below
    if k.startswith("module."):
        chkpt[k.split(".", maxsplit=1)[1]] = chkpt.pop(k)
model.load_state_dict(chkpt)

# Inference only: switch to eval mode and disable gradient tracking.
model.eval()
model.requires_grad_(False)
del chkpt  # drop the now-unneeded reference to the loaded state dict
# "cosine" log-SNR schedule; -20 and 20 are presumably the lower and upper
# log-SNR bounds — TODO confirm against get_logsnr_schedule.
logsnr_fn = get_logsnr_schedule("cosine", -20, 20)

# Keyword settings for the diffusion wrapper, gathered in one place.
# w_guide here is only an initial value: the sampling loop later in this
# snippet assigns diffusion.w_guide before every run.
diffusion_kwargs = {
    "sample_timesteps": 256,
    "model_out_type": "both",
    "model_var_type": "fixed_large",
    "reweight_type": "truncated_snr",
    "loss_type": "mse",
    "w_guide": 0.1,
    "p_uncond": 0.1,
    "use_ddim": True,
}
diffusion = GaussianDiffusion(logsnr_fn, **diffusion_kwargs)
# Guidance weights to compare; one image grid is written per value.
ws = [0, 1, 3]

# Batch of 100 RGB 32x32 images; shared by the noise tensor and the sampler.
sample_shape = (100, 3, 32, 32)

# Seed once and draw a single noise batch that is reused for every guidance
# weight, so the saved grids differ only in w. The noise is drawn without a
# device argument and then moved to `device`.
torch.manual_seed(1234)
noise = torch.randn(sample_shape).to(device)

# Label values 1..10, each repeated 10 times consecutively (100 labels in
# total), so with nrow=10 below every grid row shows a single label value.
# The tensor does not depend on w, so it is built once outside the loop.
# NOTE(review): labels start at 1 and are float32 — presumably 0 is reserved
# for the unconditional case; confirm against the model's label handling.
labels = torch.arange(
    1, 11, device=device, dtype=torch.float32).repeat_interleave(10)

for w in ws:
    diffusion.w_guide = w
    x_gen = diffusion.p_sample(
        model, sample_shape, device=device,
        noise=noise,
        label=labels,
        use_ddim=True)
    # normalize=True with value_range=(-1, 1) rescales pixel values from
    # [-1, 1] to [0, 1] before the grid is written.
    save_image(
        x_gen.cpu(),
        f"./cifar10_w{w}.png",
        nrow=10,
        normalize=True,
        value_range=(-1, 1)
    )
Can you provide the exact config you used to generate the CIFAR-10 results shown in the README?