p0p4k / pflowtts_pytorch

Unofficial implementation of NVIDIA P-Flow TTS paper
https://neurips.cc/virtual/2023/poster/69899
MIT License
214 stars, 30 forks

Added 2nd-order Heun's method and midpoint method for ODE sampling #45

Closed: FENRlR closed this pull request 3 months ago

FENRlR commented 3 months ago

The code

dphi_dt = self.estimator(x, mask, mu, t, cond, training=training)

if guidance_scale > 0.0:
    mu_avg = mu.mean(2, keepdims=True).expand_as(mu)
    dphi_avg = self.estimator(x, mask, mu_avg, t, cond, training=training)
    dphi_dt = dphi_dt + guidance_scale * (dphi_dt - dphi_avg)

was factored out into a separate function, func_dphi_dt, so that both new methods can reuse it.
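The refactor matters because Heun's method and the midpoint method each evaluate the vector field twice per step (four estimator calls per step when guidance is on), whereas Euler evaluates it once. Below is a minimal sketch in plain Python, using lists in place of tensors and assuming a velocity function f(x, t). The names with_guidance, euler_step, heun_step, midpoint_step and solve are hypothetical and not taken from the repository, and the estimator signature is simplified from the PR's self.estimator(x, mask, mu, t, cond, training=training).

```python
from typing import Callable, List

Vector = List[float]
VelocityFn = Callable[[Vector, float], Vector]


def axpy(a: float, x: Vector, y: Vector) -> Vector:
    """Return y + a * x elementwise."""
    return [yi + a * xi for xi, yi in zip(x, y)]


def with_guidance(estimator, mu: Vector, mu_avg: Vector, guidance_scale: float) -> VelocityFn:
    """Hypothetical stand-in for func_dphi_dt: wraps a (simplified) estimator
    and applies the PR's rule dphi = d + s * (d - d_avg)."""
    def dphi_dt(x: Vector, t: float) -> Vector:
        d = estimator(x, t, mu)
        if guidance_scale > 0.0:
            d_avg = estimator(x, t, mu_avg)
            d = [di + guidance_scale * (di - ai) for di, ai in zip(d, d_avg)]
        return d
    return dphi_dt


def euler_step(f: VelocityFn, x: Vector, t: float, dt: float) -> Vector:
    # First order: one evaluation of f per step.
    return axpy(dt, f(x, t), x)


def heun_step(f: VelocityFn, x: Vector, t: float, dt: float) -> Vector:
    # Second order: Euler predictor, then average the slopes at both ends.
    k1 = f(x, t)
    x_pred = axpy(dt, k1, x)
    k2 = f(x_pred, t + dt)
    return [xi + 0.5 * dt * (a + b) for xi, a, b in zip(x, k1, k2)]


def midpoint_step(f: VelocityFn, x: Vector, t: float, dt: float) -> Vector:
    # Second order: take a half step, then use the slope at the midpoint.
    k1 = f(x, t)
    x_mid = axpy(0.5 * dt, k1, x)
    k2 = f(x_mid, t + 0.5 * dt)
    return axpy(dt, k2, x)


def solve(step, f: VelocityFn, x0: Vector, n_steps: int,
          t0: float = 0.0, t1: float = 1.0) -> Vector:
    """Integrate from t0 to t1 with a fixed step size using the given stepper."""
    dt = (t1 - t0) / n_steps
    x, t = list(x0), t0
    for _ in range(n_steps):
        x = step(f, x, t, dt)
        t += dt
    return x
```

As a usage check on the toy ODE dx/dt = x with x(0) = 1 (exact value e at t = 1), solve(heun_step, lambda x, t: x, [1.0], 10) lands much closer to e than solve(euler_step, ...) with the same 10 steps, at the cost of twice as many evaluations of f. Because every stepper only calls f(x, t), the guidance logic lives in one place and is shared by all three methods.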

p0p4k commented 3 months ago

Thank you!