Closed: gordicaleksa closed this issue 2 years ago
Hi! The gradient of (f(x1) + f(x2) + ... + f(xn)) with respect to {x1, x2, ..., xn} is equal to {grad_x1 f(x1), grad_x2 f(x2), ..., grad_xn f(xn)}, because each term in the sum depends on only one xi. Since batch elements do not interact (i.e. we don't use BatchNorm), this means we can run the whole batch through the model and compute the gradient in one go.

Thanks @unixpickle!
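A quick way to sanity-check this claim (a minimal PyTorch sketch, independent of this repo): the gradient of the batch-summed loss matches the per-sample gradients computed one image at a time.

```python
import torch

# Any per-sample function works, as long as batch elements don't interact.
def f(x):
    return (x ** 3).sum(dim=1)  # one scalar per batch element

x = torch.randn(4, 5)

# Gradient of the *sum* over the batch, in one backward pass.
x_batch = x.clone().requires_grad_(True)
grad_sum = torch.autograd.grad(f(x_batch).sum(), x_batch)[0]

# Per-sample gradients, one element at a time.
grad_each = []
for i in range(x.shape[0]):
    xi = x[i].clone().unsqueeze(0).requires_grad_(True)
    grad_each.append(torch.autograd.grad(f(xi).sum(), xi)[0].squeeze(0))
grad_each = torch.stack(grad_each)

print(torch.allclose(grad_sum, grad_each))  # True
```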
Hi @unixpickle @prafullasd @erinbeesley, I think I found 2 bugs:
1) Shouldn't we pass out["mean"] (i.e. x_t) instead of x (i.e. x_{t+1}) here, and correspondingly t-1 instead of t: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/gaussian_diffusion.py#L435 (see the first sketch below)
2) Shouldn't we compute the gradients separately here? https://github.com/openai/guided-diffusion/blob/main/scripts/classifier_sample.py#L61 (see the second sketch below) We need the gradient of each image's own log-probability with respect to that image, not the gradient of the sum of all log-probabilities. Optimizing the sum seems wrong, since each image should be guided by its own class.
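For context on point 1, the linked function is condition_mean; it looks roughly like this (a paraphrase from memory of the linked line, not a verbatim copy):

```python
# Sketch of GaussianDiffusion.condition_mean in gaussian_diffusion.py
# (paraphrased; treat names and exact arithmetic as assumptions).
def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
    # Point 1 of the issue: cond_fn is evaluated at the current noisy
    # sample x with timestep t, rather than at p_mean_var["mean"] (the
    # predicted mean of the next, less-noisy sample) with timestep t - 1.
    gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs)
    new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
    return new_mean
```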
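And for point 2, the linked cond_fn roughly does the following (again a paraphrase, not a verbatim copy; classifier and classifier_scale come from the surrounding script and are assumed here):

```python
import torch as th
import torch.nn.functional as F

# Sketch of the cond_fn in scripts/classifier_sample.py (paraphrased).
def cond_fn(x, t, y=None):
    assert y is not None
    with th.enable_grad():
        x_in = x.detach().requires_grad_(True)
        logits = classifier(x_in, t)                 # assumed: script-level classifier
        log_probs = F.log_softmax(logits, dim=-1)
        # Each image's own target-class log-probability...
        selected = log_probs[range(len(logits)), y.view(-1)]
        # ...differentiated as a *sum* over the batch. Because selected[i]
        # depends only on x_in[i], row i of this gradient equals the gradient
        # of selected[i] alone, which is why the sum is safe (see the reply
        # at the top of the thread).
        return th.autograd.grad(selected.sum(), x_in)[0] * classifier_scale  # assumed scale
```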
I might have misunderstood something; please let me know if so! :)