Open htzheng opened 2 years ago
The sampling is implemented at https://github.com/crowsonkb/k-diffusion/blob/master/sample_clip_guided.py#L32.
To sample conditionally, we need the score $\nabla_{x} \mathrm{log}p(x|y) = \nabla_{x} \mathrm{log}p(x) + \nabla_{x} \mathrm{log}p(y|x)$, which follows from Bayes' rule because $\mathrm{log}p(y)$ does not depend on $x$. According to [Karras+ arXiv22], the score of $p(x)$ at noise level $\sigma$ is given by $\nabla_{x} \mathrm{log}p(x; \sigma) = (D(x;\sigma) - x) / \sigma^{2}$.
So the sampling code can be interpreted as follows.
In L32, $\mathrm{denoised} = D(x;\sigma) + \nabla_{x} \mathrm{log}p(y|x) \sigma^{2}$ is calculated.
Then the score implied by this value is as follows (note that `to_d` itself computes $\frac{dx}{dt} = (x - \mathrm{denoised}) / \sigma$, which is $-\sigma$ times this score when $\sigma(t) = t$):
$(\mathrm{denoised} - x) / \sigma^{2}$
$= (D(x;\sigma) + \nabla_{x} \mathrm{log}p(y|x) \sigma^{2} - x) / \sigma^{2}$
$= (D(x;\sigma) - x) / \sigma^{2} + \nabla_{x} \mathrm{log}p(y|x)$
$= \nabla_{x} \mathrm{log}p(x; \sigma) + \nabla_{x} \mathrm{log}p(y|x)$
Therefore, the modified denoiser output yields the conditional score $\nabla_{x} \mathrm{log}p(x|y)$, and we can perform CLIP-guided sampling!
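
To check the algebra numerically, here is a minimal 1-D sketch, not the actual k-diffusion code. Everything in it is a toy assumption: the data prior is $N(0, s_0^2)$ so the ideal denoiser has a closed form, the CLIP gradient is replaced by a Gaussian likelihood with a hypothetical target `y` and width `tau`, and `to_d` is written as `(x - denoised) / sigma`, which is how I understand k-diffusion defines it.

```python
import math

def denoiser(x, sigma, s0=1.0):
    """Ideal denoiser D(x; sigma) when clean data ~ N(0, s0^2) (toy assumption)."""
    return x * s0**2 / (s0**2 + sigma**2)

def prior_score(x, sigma, s0=1.0):
    """Exact score of the noised prior N(0, s0^2 + sigma^2)."""
    return -x / (s0**2 + sigma**2)

def cond_grad(x, y, tau=0.5):
    """Stand-in for grad_x log p(y|x); a Gaussian likelihood, not CLIP."""
    return -(x - y) / tau**2

def cond_denoised(x, sigma, y):
    """Guided denoiser output: D(x; sigma) + sigma^2 * grad_x log p(y|x)."""
    return denoiser(x, sigma) + cond_grad(x, y) * sigma**2

def to_d(x, sigma, denoised):
    """ODE derivative dx/dt for sigma(t) = t (assumed form of k-diffusion's to_d)."""
    return (x - denoised) / sigma

x, sigma, y = 0.7, 2.0, -1.0
guided = cond_denoised(x, sigma, y)

# Score implied by the guided denoiser vs. the sum of the two scores.
implied_score = (guided - x) / sigma**2
target_score = prior_score(x, sigma) + cond_grad(x, y)
assert math.isclose(implied_score, target_score)

# to_d differs from the score only by the factor -sigma.
assert math.isclose(to_d(x, sigma, guided), -sigma * target_score)
```

Both assertions pass for these values, which matches the derivation: adding $\sigma^{2} \nabla_{x} \mathrm{log}p(y|x)$ to the denoiser output shifts the implied score by exactly the guidance gradient.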
Thank you for the nice code. I wonder if you could explain how the CLIP-guided sampling works. Thank you!