lucidrains / spear-tts-pytorch

Implementation of Spear-TTS - multi-speaker text-to-speech attention network, in Pytorch
MIT License
249 stars 18 forks

Add trainers for the pretraining and backtranslation tasks #4

Closed. lucasnewman closed 1 year ago

lucasnewman commented 1 year ago

I was able to get both the pretraining and backtranslation tasks working using graphemes, as described in Appendix G of the paper. It required some model tweaks to allow freezing the encoder and specifying a target mask when the modalities are different. I added optional beam search to the generation loop, since it seems to outperform temperature sampling for the backtranslation task.
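For readers unfamiliar with the difference, here is a minimal, framework-free sketch of beam search over a toy next-token distribution. It is not the code from this PR: `beam_search`, `toy_model`, and the token ids are hypothetical names made up for illustration, and the sketch ranks hypotheses by raw summed log-probability with no length normalization.

```python
import math
from typing import Callable, Dict, List, Tuple

EOS = 9  # hypothetical end-of-sequence token id


def toy_model(seq: List[int]) -> Dict[int, float]:
    """Hypothetical next-token log-probabilities, keyed on the prefix so far."""
    table = {
        (0,): {1: 0.6, 2: 0.4},
        (0, 1): {3: 0.5, 4: 0.5},
        (0, 2): {3: 0.9, 4: 0.1},
    }
    # Any prefix not in the table ends the sequence with probability 1.
    probs = table.get(tuple(seq), {EOS: 1.0})
    return {token: math.log(p) for token, p in probs.items()}


def beam_search(
    step_log_probs: Callable[[List[int]], Dict[int, float]],
    start: int,
    eos: int,
    beam_size: int = 4,
    max_len: int = 16,
) -> List[int]:
    """Return the highest-scoring sequence found with the given beam width."""
    beams: List[Tuple[float, List[int]]] = [(0.0, [start])]
    finished: List[Tuple[float, List[int]]] = []
    for _ in range(max_len):
        # Expand every live hypothesis by every possible next token.
        candidates = [
            (score + logp, seq + [token])
            for score, seq in beams
            for token, logp in step_log_probs(seq).items()
        ]
        candidates.sort(key=lambda c: c[0], reverse=True)
        # Keep the top `beam_size`; hypotheses ending in EOS leave the beam.
        beams = []
        for score, seq in candidates[:beam_size]:
            (finished if seq[-1] == eos else beams).append((score, seq))
        if not beams:
            break
    finished.extend(beams)  # hypotheses cut off at max_len still compete
    return max(finished, key=lambda c: c[0])[1]


greedy = beam_search(toy_model, start=0, eos=EOS, beam_size=1)
wider = beam_search(toy_model, start=0, eos=EOS, beam_size=2)
```

With a beam of 1 the search commits to token `1` (probability 0.6) and ends at total probability 0.30, while a beam of 2 keeps the less likely first token `2` alive and finds the 0.36 path. That is the kind of early commitment that a wider beam can recover from.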

I also fixed an issue with loading checkpoints and added support for restoring checkpoints without the optimizer state. This makes it possible both to resume training runs and to fine-tune from the pretrained model.
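As a rough illustration of that idea, the sketch below saves model and optimizer state together and lets the caller skip the optimizer on restore. It is not the trainer code from this PR: `save_checkpoint`, `load_checkpoint`, the `restore_optimizer` flag, and the checkpoint keys are assumptions made for the example, and a tiny `nn.Linear` stands in for the real model.

```python
import io

import torch
from torch import nn


def save_checkpoint(file, model: nn.Module, optimizer: torch.optim.Optimizer) -> None:
    """Write model and optimizer state into one checkpoint (hypothetical layout)."""
    torch.save({"model": model.state_dict(), "optimizer": optimizer.state_dict()}, file)


def load_checkpoint(file, model: nn.Module, optimizer=None, restore_optimizer: bool = True) -> bool:
    """Restore model weights; restore optimizer state only when asked to.

    Returns True if the optimizer state was restored.
    """
    checkpoint = torch.load(file, map_location="cpu")
    model.load_state_dict(checkpoint["model"])
    if restore_optimizer and optimizer is not None and "optimizer" in checkpoint:
        optimizer.load_state_dict(checkpoint["optimizer"])
        return True
    return False


# "Pretraining": take one optimizer step so Adam has state worth saving.
torch.manual_seed(0)
pretrained = nn.Linear(4, 2)
pretrain_opt = torch.optim.Adam(pretrained.parameters(), lr=1e-3)
pretrained(torch.randn(8, 4)).sum().backward()
pretrain_opt.step()

buffer = io.BytesIO()  # in-memory stand-in for a checkpoint file
save_checkpoint(buffer, pretrained, pretrain_opt)

# Fine-tuning: reuse the weights but start from a fresh optimizer.
buffer.seek(0)
finetune_model = nn.Linear(4, 2)
finetune_opt = torch.optim.Adam(finetune_model.parameters(), lr=1e-4)
restored = load_checkpoint(buffer, finetune_model, finetune_opt, restore_optimizer=False)
```

Passing `restore_optimizer=True` instead would correspond to resuming a run, with the saved learning rate and Adam moments carried over.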

This might be too big of a change to take, but I thought it might be useful for others! Let me know if you'd prefer me to break it up into smaller changes.

Here's an example of a sample generation using the (severely undertrained 🙈, in my case) models:

semantic_token_ids: tensor([[ 17, 296, 114, 258, 271,  31,  39,  54, 142, 397, 345, 141, 281, 269,
           9, 142, 221, 196, 309, 479, 331, 307, 405, 206, 167, 385, 233,  82,
         227, 419, 483, 225, 226,  82, 209,  83, 145, 253, 368, 453, 168, 177,
         457, 196, 217, 473, 476, 171, 252, 422, 186, 162, 232, 172, 115, 273,
         444, 360, 434, 339, 203, 381, 117, 404, 229,  82, 247, 126, 326, 101,
         149, 228,  82, 289, 320,   7, 217, 473, 286, 468, 134, 175, 359,  81,
         166, 324,   3, 440, 188,  44,  38, 225, 164, 205, 261,  25, 485, 286,
         468, 406, 337,  41, 246,  19, 454, 229, 414,  82,  80,  82,  80,  82,
         140, 108, 119, 351, 278, 330, 388,  33, 195, 471, 368, 310, 107, 447,
           6, 272, 161, 397, 345, 333, 220, 402, 478,  66, 482, 232, 172, 115,
         273, 106, 499, 306, 396, 245, 143, 458, 259, 192, 445, 351, 486, 460,
         368, 342,  54, 224, 168, 494, 275, 203, 381,  48, 417, 421, 128, 491,
         193,  17]], device='mps:0')

Reference: This was not, as it may seem, merely a theory tinged with sarcasm.
 52%|█████▏    | 67/128 [00:01<00:01, 47.14it/s]
Backtranslated (sampling): This was not as it may seem, merely a theory, tinged, with sarcasm.
 52%|█████▏    | 67/128 [00:02<00:02, 25.31it/s]
Backtranslated (beam size 4): This was not, as it may seem, merely a theory, tinged with sarcasm.

lucidrains commented 1 year ago

dude! you MVP - i was about to do all this this weekend

caught up with Voicebox work today and tomorrow, but will circle back to this on Sunday and get it merged! 🙏 (looks great on a cursory scan)

lucidrains commented 1 year ago

:100: :rocket:

even humored me and followed in line with the strange styling :laughing: