crowsonkb / k-diffusion

Karras et al. (2022) diffusion models for PyTorch
MIT License

How CLIP guided sampling works #20

Open htzheng opened 2 years ago

htzheng commented 2 years ago

Thank you for the nice code. I wonder if you could explain how the CLIP-guided sampling works. Thank you!

UdonDa commented 2 years ago

The sampling is implemented at https://github.com/crowsonkb/k-diffusion/blob/master/sample_clip_guided.py#L32.

To perform conditional sampling, we need the score of the conditional distribution: $\nabla_{x} \log p(x|y) = \nabla_{x} \log p(x) + \nabla_{x} \log p(y|x)$. According to [Karras+ arXiv22], the score of $p(x)$ at noise level $\sigma$ can be expressed through the denoiser $D$ as $\nabla_{x} \log p(x; \sigma) = (D(x;\sigma) - x) / \sigma^{2}$.
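As a toy sanity check of that score identity (not from the repo): for 1-D Gaussian data $\mathcal{N}(0, s^2)$, the optimal denoiser has the standard closed form $D(x;\sigma) = x\,s^2/(s^2+\sigma^2)$, and $(D(x;\sigma) - x)/\sigma^{2}$ should match the analytic score of the noised marginal $\mathcal{N}(0, s^2+\sigma^2)$:

```python
import math

# Toy 1-D Gaussian setup (hypothetical numbers, chosen for illustration)
s, sigma, x = 0.7, 1.3, 0.25

D = x * s**2 / (s**2 + sigma**2)      # closed-form optimal denoiser D(x; sigma)
score_from_D = (D - x) / sigma**2     # (D(x; sigma) - x) / sigma^2
score_exact = -x / (s**2 + sigma**2)  # analytic grad_x log p(x; sigma)

assert math.isclose(score_from_D, score_exact)
```

Both expressions give $-x/(s^2+\sigma^2)$, as the identity predicts.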

So the sampling code can be interpreted as follows. In L32, $\mathrm{denoised} = D(x;\sigma) + \sigma^{2}\,\nabla_{x} \log p(y|x)$ is computed, where the second term is the CLIP guidance gradient scaled by $\sigma^{2}$. Then `to_d` converts this denoised prediction into the ODE derivative $\mathrm{d}x/\mathrm{d}\sigma$ (which is $-\sigma$ times the score):

$\mathrm{to\_d}(x, \sigma, \mathrm{denoised}) = (x - \mathrm{denoised}) / \sigma$
$= \left(x - D(x;\sigma) - \sigma^{2}\,\nabla_{x} \log p(y|x)\right) / \sigma$
$= -\sigma \left( (D(x;\sigma) - x) / \sigma^{2} + \nabla_{x} \log p(y|x) \right)$
$= -\sigma \left( \nabla_{x} \log p(x;\sigma) + \nabla_{x} \log p(y|x) \right)$
$= -\sigma\, \nabla_{x} \log p(x|y;\sigma)$

So the sampler ends up following the conditional score.

Therefore, we can perform the CLIP-guided sampling!
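The algebra above can be checked numerically with a tiny scalar sketch (hypothetical numbers; `to_d` here mirrors k-diffusion's `(x - denoised) / sigma` definition):

```python
def to_d(x, sigma, denoised):
    # k-diffusion's to_d: converts a denoised prediction into the
    # ODE derivative dx/dsigma = (x - denoised) / sigma.
    return (x - denoised) / sigma

# Toy scalar values (hypothetical, for illustration only)
x, sigma = 0.8, 2.0
D = 0.5          # unconditional denoiser output D(x; sigma)
cond_grad = 0.1  # grad_x log p(y|x), e.g. from a CLIP loss

# Guided denoised prediction, as in the conditional model wrapper
denoised = D + cond_grad * sigma**2

score_uncond = (D - x) / sigma**2  # grad_x log p(x; sigma)
d = to_d(x, sigma, denoised)

# d equals -sigma times the conditional score
assert abs(d - (-sigma) * (score_uncond + cond_grad)) < 1e-12
```

The assertion confirms that feeding the guided `denoised` through `to_d` yields $-\sigma$ times the sum of the unconditional score and the guidance gradient.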