lucidrains / spear-tts-pytorch

Implementation of Spear-TTS - multi-speaker text-to-speech attention network, in Pytorch

EOS token not predicted while training from scratch #15

Open · Kodhandarama opened this issue 8 months ago

Kodhandarama commented 8 months ago

I am currently training S1 from scratch as an ablation study, as described in the paper. The paper states that the authors use a decoder-only architecture and a 12-layer transformer as described in the T5 paper. I made those changes to the current framework and trained the text-to-semantic transformer model. The first few tokens are predicted correctly, but the predictions then degrade and the model fails to predict the EOS token. Is there a fix you have in mind for this?

Side note: one of the errors I kept running into was a CUDA error in the function rotate_queries_with_cached_keys. To avoid it, I set use_xpos = True in the RotaryEmbedding class.
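
Roughly, that workaround amounts to the following (a sketch against the rotary-embedding-torch library this repo uses; dim is just an example value):

rotary_emb = RotaryEmbedding(dim = 32, use_xpos = True)

# enabling the xpos variant changes the code path taken by
# rotate_queries_with_cached_keys, which avoided the CUDA error in my case
from rotary_embedding_torch import RotaryEmbedding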

My aim here is to be able to predict semantic tokens from text. I am open to trying out other models for this as well if you have suggestions :)

lucasnewman commented 8 months ago

I'm not sure about the rotary embedding (Phil is the expert there), but in terms of the EOS token, some of it depends on how you're handling EOS in the pretraining stage.

One thing to watch out for: if you're pretraining on random crops of semantic data (imho this is necessary for reasonable training efficiency), you'll want to omit the automatic EOS handling for that stage, but also make sure to allocate space in the embedding for the EOS token that gets added later in S1 training.

e.g. pretraining could look like:

text_to_semantic_model = TextToSemantic(
    num_semantic_token_ids = num_semantic_tokens + 1, # add an extra token for eos later
    ...
    autoset_semantic_eos_id = False, # don't autoset the eos for pretraining
)

trainer = SpeechSpeechPretrainer(
    model = text_to_semantic_model,
    ...
)

trainer.train()

and then S1 would look like:

text_to_semantic_model = TextToSemantic(
    num_semantic_token_ids = num_semantic_tokens, # don't include the eos token explicitly
    ...
    autoset_semantic_eos_id = True, # autoset the eos for text -> semantic generation
)

trainer = TextToSemanticTrainer(
    model = text_to_semantic_model,
    ...
)

trainer.load("pretrained_ckpt.pt", restore_optimizer = False)
trainer.train()
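
With autoset_semantic_eos_id = True, generation should terminate once the model emits EOS. As a rough sketch of what inference could look like afterwards (the exact generate signature here is an assumption, so check it against the repo):

semantic_ids = text_to_semantic_model.generate(
    source = text_token_ids,  # (batch, text seq len)
    source_type = 'text',
    target_type = 'speech',
    max_length = 2048         # upper bound; a predicted eos should stop decoding earlier
)
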
Kodhandarama commented 8 months ago

Thank you for your quick reply! I followed your S1 training strategy exactly. In my case, I am looking to train the text-to-semantic prediction model on a large parallel speech-text corpus (LJSpeech) without any pretraining or backtranslation (this is the ablation study mentioned in the Spear-TTS paper). The encoder and decoder are both unfrozen and only the speech embedding is frozen. (I did try training the model with a trainable speech embedding, but that did not change the EOS prediction problem.)

I've observed that during inference the model keeps predicting semantic tokens until it reaches the maximum limit of 2048 tokens. The EOS token is never predicted at any stage, leading to ongoing hallucination.
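
One way to quantify this is to log the probability the model assigns to the EOS id at each decoding step (a hypothetical helper; eos_id and the logits layout are assumptions that need adapting to the actual decoding loop):

import torch

def eos_probability(logits: torch.Tensor, eos_id: int) -> torch.Tensor:
    # logits: (batch, seq len, vocab) from the decoder
    # returns p(eos) at the latest decoding step for each batch element
    return logits[:, -1].softmax(dim = -1)[:, eos_id]

# e.g. inside the sampling loop:
# print(f"step {step}: p(eos) = {eos_probability(logits, eos_id)[0].item():.4f}")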

lucasnewman commented 8 months ago

Ok, it's tricky to debug without seeing the code in that case, unfortunately. If you can't share it, I would check to make sure your padding and/or mask tokens aren't colliding with the EOS token, and look at the inputs to the network to make sure the EOS token is being correctly applied and attended to. In the grand scheme of things, it's just a transformer network under the hood, so if you're skipping all the early stages and not freezing any layers, it should be straightforward to debug.
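
As a concrete (hypothetical) sketch of those sanity checks; pad_id, eos_id, and mask_id are stand-ins for however your code names its special tokens:

import torch

def check_special_tokens(targets: torch.Tensor, lengths: torch.Tensor,
                         eos_id: int, pad_id: int, mask_id: int):
    # special-token ids must not collide
    assert eos_id != pad_id and eos_id != mask_id, "eos collides with pad/mask id"

    # every target sequence should end with eos, right before any padding
    last_tokens = targets[torch.arange(targets.shape[0]), lengths - 1]
    assert (last_tokens == eos_id).all(), "eos missing from some target sequences"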

FWIW, in the context of Spear-TTS, LJSpeech is considered small, and you'll struggle to generalize to a broader range of speech. You might consider something like LibriTTS-R, which has higher overall audio quality and a number of preselected training subsets for experimentation, or you can pool them for up to 960 hours of audio. It's still going to be in the range of "audiobook vocabulary" compared to some of the state-of-the-art examples available now for semantic generation, though!