NVlabs / RED-diff


Scaling of PiGDM guidance #3

Open man-sean opened 9 months ago

man-sean commented 9 months ago

In the PiGDM paper (Sec. A.1, Algorithm 1), it says that we need to scale the guidance term by $\sqrt{\alpha_t}$. In the code, we scale by $\sqrt{\alpha_t} \cdot \sqrt{\alpha_{t-1}}$:

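```python
# alpha_s plays the role of alpha_{t-1}, alpha_t is the current step
coeff = alpha_s.sqrt()
if not self.awd:
    # extra correction term, applied only when awd is disabled
    coeff = coeff - c2 * alpha_t.sqrt() / (1 - alpha_t).sqrt()
# final scaling: sqrt(alpha_t) times the user-set guidance weight
coeff = coeff * alpha_t.sqrt() * self.grad_term_weight
```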
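Written out, the snippet's coefficient is the following, reading `alpha_s` as $\alpha_{t-1}$ (as assumed in this issue) and writing $w$ for `grad_term_weight`:

$$
\text{coeff} =
\begin{cases}
\sqrt{\alpha_{t-1}}\,\sqrt{\alpha_t}\,w & \text{if awd} \\
\left(\sqrt{\alpha_{t-1}} - c_2\,\dfrac{\sqrt{\alpha_t}}{\sqrt{1-\alpha_t}}\right)\sqrt{\alpha_t}\,w & \text{otherwise}
\end{cases}
$$

In the `awd` case this differs from the paper's $\sqrt{\alpha_t}\,w$ by exactly the factor $\sqrt{\alpha_{t-1}}$.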

If we scale only by $\sqrt{\alpha_t}$, we get NaNs during inference because the guidance becomes too large. Where does this additional scaling by $\sqrt{\alpha_{t-1}}$ come from?
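A minimal sketch for comparing the two scalings numerically. It assumes `alpha_t` and `alpha_s` are cumulative products $\bar\alpha$ from a standard DDPM linear beta schedule (1e-4 to 0.02 over 1000 steps) and that `alpha_s` is the previous step $\alpha_{t-1}$. `linear_alpha_bar`, `code_coeff`, and `paper_coeff` are hypothetical helpers written for illustration, not functions from the RED-diff repo, and plain floats stand in for tensors. With `awd=True` the ratio of the two coefficients is exactly $\sqrt{\alpha_{t-1}}$, which is close to 0 at the noisiest steps, so scaling by $\sqrt{\alpha_t}$ alone gives a much larger guidance there.

```python
import math


def linear_alpha_bar(num_steps: int = 1000, beta_start: float = 1e-4,
                     beta_end: float = 0.02) -> list:
    """Hypothetical helper: cumulative alpha_bar_t for an assumed linear beta schedule."""
    alpha_bar = []
    prod = 1.0
    for i in range(num_steps):
        beta = beta_start + (beta_end - beta_start) * i / (num_steps - 1)
        prod *= 1.0 - beta
        alpha_bar.append(prod)
    return alpha_bar


def code_coeff(alpha_t: float, alpha_s: float, c2: float, awd: bool,
               grad_term_weight: float = 1.0) -> float:
    """Mirrors the quoted snippet with floats; c2 is only used when awd is False."""
    coeff = math.sqrt(alpha_s)
    if not awd:
        coeff = coeff - c2 * math.sqrt(alpha_t) / math.sqrt(1 - alpha_t)
    return coeff * math.sqrt(alpha_t) * grad_term_weight


def paper_coeff(alpha_t: float, grad_term_weight: float = 1.0) -> float:
    """Scaling by sqrt(alpha_t) alone, as read from PiGDM Algorithm 1."""
    return math.sqrt(alpha_t) * grad_term_weight


alpha_bar = linear_alpha_bar()
for t in (999, 500, 100, 10):
    a_t, a_s = alpha_bar[t], alpha_bar[t - 1]
    # c2 is ignored when awd=True, so its value does not matter here
    ratio = code_coeff(a_t, a_s, c2=0.0, awd=True) / paper_coeff(a_t)
    print(f"t={t:4d}  sqrt(alpha_s)={math.sqrt(a_s):.4f}  code/paper={ratio:.4f}")
```

With `awd=False`, `c2` enters the coefficient and the ratio is no longer just $\sqrt{\alpha_{t-1}}$, so that case would need the actual `c2` used by the sampler.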