lucidrains / spear-tts-pytorch

Implementation of Spear-TTS - multi-speaker text-to-speech attention network, in Pytorch
MIT License
254 stars, 19 forks

Add support for specifying a mask id during the speech-to-speech pretraining task #6

Closed: lucasnewman closed this 1 year ago

lucasnewman commented 1 year ago

This allows masking variable-length sequences at the correct rate (e.g. p = 0.6) with respect to the sequence mask, i.e. the masking probability applies to the valid tokens in each sequence rather than to the padded length.

I had to branch the masking behavior when the mask id is provided, because the existing batch -> mask -> rearrange back to batch flow doesn't work for ragged sequences: each sequence in the batch can have a different number of tokens. @lucidrains let me know if you are OK with this approach or if you have a better idea 🙏
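To make the ragged case concrete, here is a minimal sketch of one way to mask a fixed fraction of each sequence's valid tokens. It is not the code from this PR: the function name `mask_ragged_tokens`, its arguments, and the rank-based selection are assumptions for illustration only.

```python
import torch

def mask_ragged_tokens(tokens, seq_mask, mask_id, mask_prob=0.6, generator=None):
    """Hypothetical helper: replace about `mask_prob` of the valid tokens in
    each sequence with `mask_id`, never touching padding.

    tokens:   (batch, seq_len) long tensor, right-padded
    seq_mask: (batch, seq_len) bool tensor, True where a real token is present
    """
    # Number of tokens to mask, proportional to each sequence's true length.
    lengths = seq_mask.sum(dim=-1, keepdim=True)
    num_to_mask = (lengths * mask_prob).round().long()  # (batch, 1)

    # Random score per position; padding gets a score above any valid draw
    # (torch.rand samples from [0, 1)), so it always sorts last.
    scores = torch.rand(seq_mask.shape, generator=generator, device=seq_mask.device)
    scores = scores.masked_fill(~seq_mask, 2.0)

    # Double argsort yields the ascending rank of every position in its row.
    ranks = scores.argsort(dim=-1).argsort(dim=-1)

    # Select the lowest-ranked positions; `& seq_mask` is a safeguard.
    positions = (ranks < num_to_mask) & seq_mask
    return tokens.masked_fill(positions, mask_id), positions


# Usage: two sequences of length 10 and 5, padded to length 10.
lengths = torch.tensor([10, 5])
seq_mask = torch.arange(10)[None, :] < lengths[:, None]
tokens = torch.randint(0, 100, (2, 10))
masked, positions = mask_ragged_tokens(tokens, seq_mask, mask_id=100, mask_prob=0.6)
```

Because the count is derived from each row's own length, the shorter sequence gets proportionally fewer masked tokens, which is the property a single batch-wide reshape cannot provide when lengths differ.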

I also added an option to return the logits along with the loss, to allow computing downstream metrics (the metrics themselves are not included in this PR). I also put in some minor fixes for the decoder depth and cleaned up some logging issues from the original trainer work.
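As an illustration of what returned logits could enable, the sketch below computes accuracy over the masked positions only. The metric and the name `masked_token_accuracy` are assumptions, not part of this PR or the library's API.

```python
import torch

def masked_token_accuracy(logits, targets, positions):
    """Hypothetical metric: fraction of masked positions predicted correctly.

    logits:    (batch, seq_len, vocab) float tensor
    targets:   (batch, seq_len) long tensor of the original token ids
    positions: (batch, seq_len) bool tensor, True where a token was masked
    """
    preds = logits.argmax(dim=-1)
    correct = (preds == targets) & positions
    # Clamp so a batch with no masked positions yields 0.0 instead of NaN.
    total = positions.sum().clamp(min=1)
    return (correct.sum() / total).item()


# Usage: one sequence of four tokens, three of which were masked.
targets = torch.tensor([[0, 1, 2, 1]])
positions = torch.tensor([[True, True, False, True]])
logits = torch.zeros(1, 4, 3)
logits[0, 0, 0] = 1.0  # correct
logits[0, 1, 1] = 1.0  # correct
logits[0, 2, 2] = 1.0  # correct, but not a masked position
logits[0, 3, 0] = 1.0  # wrong: target is 1
accuracy = masked_token_accuracy(logits, targets, positions)
```

Restricting the metric to masked positions keeps unmasked and padded tokens from inflating the score.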

lucidrains commented 1 year ago

@lucasnewman yes this is exactly how i would have approached it! thank you Lucas! 💯 🚀

lucidrains commented 1 year ago

code is impeccable, but now i know why haha