sanchit-gandhi / seq2seq-speech

Repository for fine-tuning Transformers 🤗 based seq2seq speech models in JAX/Flax.

Negative Losses in CTC Training #28

Closed sanchit-gandhi closed 2 years ago

sanchit-gandhi commented 2 years ago

Training the baseline CTC model on the Common Voice 9 (CV9) dataset, we observe that the training loss drops below zero after ~1.5k train steps: https://wandb.ai/sanchit-gandhi/commonvoice_9_0/runs/y593pwm4?workspace=user-sanchit-gandhi. Since the CTC loss is the negative log-likelihood of the label sequence, it should always be non-negative.
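A quick sanity check along these lines (a sketch only; the tokenizer path and checkpoint name below are placeholders, not the exact ones in the training script) is to compare the tokenizer's vocabulary size against the number of output classes the model is configured with:

```python
from transformers import AutoTokenizer, FlaxWav2Vec2ForCTC

# Placeholder paths: the CV9 tokenizer directory and the pre-trained checkpoint
tokenizer = AutoTokenizer.from_pretrained("./cv9-tokenizer")
model = FlaxWav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-lv60")

print("tokenizer vocab size :", len(tokenizer))
print("model output classes :", model.config.vocab_size)

# If the tokenizer is larger than the model's output dimension, some label IDs
# have no corresponding logit and the CTC loss is no longer well defined
assert model.config.vocab_size >= len(tokenizer), "labels can index past the logits"
```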

sanchit-gandhi commented 2 years ago

Found it! When we define the model: https://github.com/sanchit-gandhi/seq2seq-speech/blob/b1bf2c2148910d59fd8ba3f0086244e0879a65b7/run_flax_speech_recognition_ctc.py#L849-L856 we need to set the config attribute `vocab_size` to the number of elements in the tokenizer's vocabulary. Otherwise, it defaults to the `vocab_size` of the Wav2Vec2-large-lv60 checkpoint, which is the vocab size of the default Wav2Vec2 tokenizer built on LibriSpeech ASR. If the actual tokenizer's vocab size is greater than that of the default Wav2Vec2 tokenizer, the logits only span a sub-space of the full tokenizer vocabulary. These ill-defined logits then (likely) give rise to an ill-defined CTC loss.
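A minimal sketch of the fix (checkpoint and tokenizer paths are placeholders, not the exact call in the training script): load the config first, override `vocab_size` from the tokenizer, and only then instantiate the model with that config:

```python
from transformers import AutoConfig, AutoTokenizer, FlaxWav2Vec2ForCTC

tokenizer = AutoTokenizer.from_pretrained("./cv9-tokenizer")  # placeholder path

config = AutoConfig.from_pretrained("facebook/wav2vec2-large-lv60")
config.vocab_size = len(tokenizer)  # lm_head now spans the full CV9 vocabulary

# The pre-trained (not CTC fine-tuned) checkpoint ships no lm_head, so the
# output layer is freshly initialised at whatever vocab_size the config
# specifies -- overriding it here is sufficient, no pre-trained weights are lost.
model = FlaxWav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-large-lv60",
    config=config,
)
```

Equivalently, passing `vocab_size=len(tokenizer)` directly to `from_pretrained` updates the loaded config in the same way.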

patrickvonplaten commented 2 years ago

Great catch! Because of this, the tokenizer produces token IDs that the model's logits don't cover, which surely messes up the CTC loss. You're exactly right, we should add a `vocab_size=len(tokenizer)` here.