Closed — sanchit-gandhi closed this 5 months ago
cc @peregilk as I think this could be propagated to your Flax code: https://github.com/NbAiLab/nb-whisper/blob/628933e7ab617580c2302ff9391aca141de1184e/run_nb_flax_speech_recognition_seq2seq_streaming_dev.py#L593-L596
When conditioning on previous context text, we should mask both the previous tokens and the BOS token from the loss computation. Training the model to predict these decoder input ids doesn't make sense: it would be asking the model to "predict" when the prompt ids finish and the transcription starts. At inference time, the BOS token is provided as a decoder input id, so it should never be predicted by the model, only passed as an input.
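As a minimal sketch of the masking described above (the helper name, the toy token ids, and the label layout are illustrative assumptions; `-100` is the standard ignore index used by PyTorch's `CrossEntropyLoss` and the HF `Trainer`):

```python
import numpy as np

IGNORE_INDEX = -100  # positions with this label are excluded from the loss


def mask_prompt_and_bos(labels: np.ndarray, num_prompt_tokens: int) -> np.ndarray:
    """Mask the prompt tokens and the BOS token so they don't contribute to the loss.

    Assumes `labels` is laid out as:
        [prompt tokens ...] [BOS] [transcription tokens ...] [EOS]
    so the first `num_prompt_tokens + 1` positions are masked.
    """
    labels = labels.copy()
    labels[: num_prompt_tokens + 1] = IGNORE_INDEX
    return labels


# Toy example: 3 prompt ids, then BOS (50258 here), then the transcription + EOS.
labels = np.array([101, 102, 103, 50258, 7, 8, 9, 50257])
masked = mask_prompt_and_bos(labels, num_prompt_tokens=3)
print(masked.tolist())  # [-100, -100, -100, -100, 7, 8, 9, 50257]
```

The prompt and BOS positions still appear in `decoder_input_ids` (shifted right as usual), so the model conditions on them; they are simply ignored when computing the cross-entropy loss.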