crowsonkb / k-diffusion

Karras et al. (2022) diffusion models for PyTorch
MIT License

Example of how to use the VDenoiser/OpenAIDenoiser/CompVisDenoiser wrappers? #15

Open Quasimondo opened 2 years ago

Quasimondo commented 2 years ago

It would be marvelous if you could add a small example at some point that shows how to use these wrappers. My guess is that I first have to initialize the inner_model using the respective other codebase and then just wrap it?

crowsonkb commented 2 years ago

Yes, you just load it using the other codebase in the normal way and wrap it. There is an example at https://colab.research.google.com/drive/1w0HQqxOKCk37orHATPxV8qb0wb4v-qa0, where I do:

    model_wrap = K.external.OpenAIDenoiser(model, diffusion, device=device)
    sigmas = model_wrap.get_sigmas(n_steps)
    if init is not None:
        sigmas = sigmas[sigmas <= sigma_start]

The get_sigmas() method on the model wrapper retrieves the model's original noise schedule, if it has one, and optionally respaces it to the given number of timesteps.
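For reference, a minimal sketch of feeding the wrapped model into one of the k-diffusion sampling loops (the batch size, image size, step count, and choice of sample_lms below are illustrative assumptions, not what the notebook does):

    import torch
    import k_diffusion as K

    # model, diffusion, and device are assumed to come from the guided-diffusion codebase
    model_wrap = K.external.OpenAIDenoiser(model, diffusion, device=device)
    sigmas = model_wrap.get_sigmas(50)  # the model's schedule respaced to 50 steps
    # start from pure noise scaled to the largest sigma in the schedule
    x = torch.randn([4, 3, 256, 256], device=device) * sigmas[0]
    # any k-diffusion sampler works here; sample_lms is just one choice
    samples = K.sampling.sample_lms(model_wrap, x, sigmas)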

Quasimondo commented 2 years ago

Thanks for making that clearer!

Now I have tried to monkey-wrench that into sample_clip_guided.py, using the CC12M1Model from v-diffusion, and whilst everything seems to load fine, I am getting an error during forward():

    inner_model = K.models.CC12M1Model().eval().requires_grad_(False).to(device)
    inner_model.load_state_dict(torch.load("cc12m_1_cfg.pth", map_location='cpu'))
    inner_model = K.external.VDenoiser(inner_model)

....

    ...k_diffusion/external.py", line 38, in forward
        return self.inner_model(input * c_in, self.sigma_to_t(sigma), **kwargs) * c_out + input * c_skip
    TypeError: forward() missing 1 required positional argument: 'clip_embed'

I guess I might be asking for too much and should just use regular sampling for now?

crowsonkb commented 2 years ago

cc12m_1_cfg expects a ViT-B/16 CLIP embedding as input, so you have to supply one somehow, either by passing clip_embed=something to the model directly or as extra_args={'clip_embed': something} to one of the k-diffusion sampling loops. You can pass zeros in to get unconditional sampling (torch.zeros([batch_size, 512], device=device)).
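In code, something like this (a minimal sketch building on your snippet above, where inner_model is the VDenoiser-wrapped model; the sigma range, step count, image size, and choice of sample_lms are illustrative assumptions):

    # zeros give unconditional sampling; otherwise encode text/images with ViT-B/16 CLIP
    clip_embed = torch.zeros([batch_size, 512], device=device)
    # VDenoiser has no built-in schedule, so build one, e.g. a Karras schedule;
    # these sigma bounds are example values
    sigmas = K.sampling.get_sigmas_karras(50, 1e-2, 80, device=device)
    x = torch.randn([batch_size, 3, 256, 256], device=device) * sigmas[0]
    samples = K.sampling.sample_lms(inner_model, x, sigmas,
                                    extra_args={'clip_embed': clip_embed})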

Quasimondo commented 2 years ago

Thanks again for helping! Now I am getting somewhere.

tinapan-pt commented 2 years ago

Thank you very much for your answer! I would like to ask why the image quality becomes very poor once the number of steps drops below 100 when I use the new sampler method on the OpenAI guided-diffusion model. In theory, shouldn't DPM-Solver converge faster?