FENRlR closed this 3 months ago.
The code
dphi_dt = self.estimator(x, mask, mu, t, cond, training=training)
if guidance_scale > 0.0:
    mu_avg = mu.mean(2, keepdims=True).expand_as(mu)
    dphi_avg = self.estimator(x, mask, mu_avg, t, cond, training=training)
    dphi_dt = dphi_dt + guidance_scale * (dphi_dt - dphi_avg)
was moved into a separate function, def func_dphi_dt, so that both methods can reuse it.
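A minimal sketch of the refactoring pattern, for illustration only: one helper computes the (optionally guided) velocity, and two solvers call it instead of duplicating the guidance logic. The thread does not name "both methods", so the Euler and midpoint solvers here are hypothetical stand-ins, as is toy_estimator. The estimator signature is simplified (mask, cond and training are dropped), and plain Python lists stand in for tensors, so the time average replaces mu.mean(2, keepdims=True).expand_as(mu) under the assumption that dim 2 is the time axis.

```python
from typing import Callable, List

Vector = List[float]
# Hypothetical simplified estimator signature: (x, mu, t) -> velocity.
Estimator = Callable[[Vector, Vector, float], Vector]


def func_dphi_dt(estimator: Estimator, x: Vector, mu: Vector, t: float,
                 guidance_scale: float) -> Vector:
    """Velocity at (x, t), with the guidance term from the snippet above."""
    dphi_dt = estimator(x, mu, t)
    if guidance_scale > 0.0:
        # Replace the condition with its mean over time, broadcast back to
        # the original length (list analogue of mean(...).expand_as(mu)).
        mean_mu = sum(mu) / len(mu)
        mu_avg = [mean_mu] * len(mu)
        dphi_avg = estimator(x, mu_avg, t)
        # dphi_dt + s * (dphi_dt - dphi_avg), elementwise.
        dphi_dt = [d + guidance_scale * (d - a)
                   for d, a in zip(dphi_dt, dphi_avg)]
    return dphi_dt


def solver_euler(estimator: Estimator, x: Vector, mu: Vector,
                 n_steps: int, guidance_scale: float) -> Vector:
    """Hypothetical solver 1: fixed-step Euler from t = 0 to t = 1."""
    dt, t = 1.0 / n_steps, 0.0
    for _ in range(n_steps):
        v = func_dphi_dt(estimator, x, mu, t, guidance_scale)
        x = [xi + dt * vi for xi, vi in zip(x, v)]
        t += dt
    return x


def solver_midpoint(estimator: Estimator, x: Vector, mu: Vector,
                    n_steps: int, guidance_scale: float) -> Vector:
    """Hypothetical solver 2: fixed-step midpoint, reusing the same helper."""
    dt, t = 1.0 / n_steps, 0.0
    for _ in range(n_steps):
        v1 = func_dphi_dt(estimator, x, mu, t, guidance_scale)
        x_mid = [xi + 0.5 * dt * vi for xi, vi in zip(x, v1)]
        v2 = func_dphi_dt(estimator, x_mid, mu, t + 0.5 * dt, guidance_scale)
        x = [xi + dt * vi for xi, vi in zip(x, v2)]
        t += dt
    return x


def toy_estimator(x: Vector, mu: Vector, t: float) -> Vector:
    """Hypothetical estimator: velocity pointing from x toward mu."""
    return [m - xi for xi, m in zip(x, mu)]


if __name__ == "__main__":
    print(solver_euler(toy_estimator, [0.0, 0.0], [1.0, 3.0], 1, 0.0))
    print(solver_midpoint(toy_estimator, [0.0, 0.0], [1.0, 3.0], 1, 0.0))
```

Because the guidance branch lives in one place, a change to it (for example, a different unconditional input) only has to be made once, and every solver picks it up.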
Thank you!