huggingface / distil-whisper

Distilled variant of Whisper for speech recognition. 6x faster, 50% smaller, within 1% word error rate.

mask BOS in prompted ids #77

Closed · sanchit-gandhi closed this issue 5 months ago

sanchit-gandhi commented 5 months ago

When conditioning on previous context text, we should mask the previous (prompt) tokens and the BOS token from the loss computation. Training the model to predict these decoder input ids doesn't make sense: it would be asking the model to "predict" where the prompt ids end and the transcription begins. At inference time, the BOS token is provided as a decoder input id, so it should not be predicted by the model, only passed as an input.
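A minimal sketch of what this masking could look like in PyTorch, assuming the label sequences have the prompt ids prepended and that the loss uses the standard Hugging Face ignore index of -100 (the helper name, tensor layout, and token ids here are illustrative, not the repo's actual implementation):

```python
import torch

# Illustrative sketch (not the repo's actual code): given label sequences with
# the prompt ids prepended, set the prompt tokens and the BOS/decoder-start
# token to -100 so the cross-entropy loss ignores them.
def mask_prompt_and_bos(labels: torch.Tensor,
                        prompt_lengths: torch.Tensor,
                        bos_token_id: int) -> torch.Tensor:
    labels = labels.clone()
    for i, prompt_len in enumerate(prompt_lengths.tolist()):
        # mask the prepended prompt (previous context) tokens
        labels[i, :prompt_len] = -100
    # mask the BOS token itself: it is always fed as a decoder input id at
    # inference time, so it should never be a prediction target
    labels[labels == bos_token_id] = -100
    return labels

# Example: batch of one, a 3-token prompt followed by BOS and the transcription.
labels = torch.tensor([[101, 102, 103, 50258, 7, 8, 9]])
masked = mask_prompt_and_bos(labels, prompt_lengths=torch.tensor([3]), bos_token_id=50258)
print(masked)  # tensor([[-100, -100, -100, -100, 7, 8, 9]])
```

With the -100 ignore index, the cross-entropy loss is computed only over the transcription tokens, while the prompt and BOS still participate as decoder inputs through teacher forcing.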

sanchit-gandhi commented 5 months ago

cc @peregilk as I think this could be propagated to your Flax code: https://github.com/NbAiLab/nb-whisper/blob/628933e7ab617580c2302ff9391aca141de1184e/run_nb_flax_speech_recognition_seq2seq_streaming_dev.py#L593-L596