lucidrains / voicebox-pytorch

Implementation of Voicebox, the new SOTA text-to-speech network from Meta AI, in PyTorch
MIT License
589 stars, 49 forks

Disable gradients for null conditioning when CFG is enabled #37

Closed: lucasnewman closed this 9 months ago

lucasnewman commented 9 months ago

I had a training run whose loss blew up to NaN after a while with conditional dropout enabled. From debugging the model weights, it looks like the null conditioning had gradients enabled and would eventually overflow. This PR just disables gradients for the null conditioning parameter. I verified that the loss still converges as expected.
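For readers unfamiliar with the pattern, here is a minimal sketch of what "disabling gradients for the null conditioning" means in PyTorch. It is not the repo's code: the `CondDropout` module, its `null_cond` parameter, and the `cond_drop_prob` / `learned_null` arguments are hypothetical names chosen for illustration. During classifier-free guidance (CFG) training, each sample's conditioning is swapped for a null embedding with some probability; creating that embedding with `requires_grad=False` keeps the optimizer from ever updating it.

```python
import torch
from torch import nn


class CondDropout(nn.Module):
    """Hypothetical sketch (not voicebox-pytorch's actual module).

    Replaces the conditioning with a null embedding with probability
    `cond_drop_prob` during training, as in classifier-free guidance.
    """

    def __init__(self, dim: int, cond_drop_prob: float = 0.2, learned_null: bool = False):
        super().__init__()
        self.cond_drop_prob = cond_drop_prob
        # requires_grad=False freezes the null embedding, so no gradient
        # ever accumulates on it and it cannot drift or overflow
        self.null_cond = nn.Parameter(torch.zeros(dim), requires_grad=learned_null)

    def forward(self, cond: torch.Tensor) -> torch.Tensor:
        # cond: (batch, seq, dim); dropout only applies in training mode
        if not self.training or self.cond_drop_prob == 0.0:
            return cond
        batch = cond.shape[0]
        # one drop decision per sample, shaped (batch, 1, 1) for broadcasting
        drop = (torch.rand(batch, device=cond.device) < self.cond_drop_prob)[:, None, None]
        return torch.where(drop, self.null_cond.expand_as(cond), cond)


torch.manual_seed(0)
frozen = CondDropout(dim=8, cond_drop_prob=1.0)  # always drop, to exercise the null path
proj = nn.Linear(8, 8)
loss = proj(frozen(torch.randn(2, 4, 8))).pow(2).mean()
loss.backward()
print(frozen.null_cond.grad)  # None: the frozen null embedding receives no gradient
```

Setting `learned_null=True` recovers the learned null conditioning that the original issue traced the NaN to; in that case `null_cond.grad` is populated after `backward()` and the optimizer will update it.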

I also included a couple of drive-by fixes (let me know if you want them in another PR):

lucidrains commented 9 months ago

@lucasnewman that's really interesting that learned null conditioning led to instability! i'll have to think about that one

rest lgtm!

lucidrains commented 9 months ago
[Screenshot: Screen Shot 2023-11-25 at 4:51:34 PM]

may be seeing synergy between gateloop and attention layers (the combined green run actually has fewer parameters than either the gateloop or the attention run alone)

recommend giving that a try!