huggingface / distil-whisper

Distilled variant of Whisper for speech recognition. 6x faster, 50% smaller, within 1% word error rate.

I want to confirm how the knowledge distillation is implemented #73

Open hxypqr opened 5 months ago

hxypqr commented 5 months ago

I don't quite understand how knowledge distillation is implemented here.

Whisper is trained autoregressively on 680,000 hours of weakly labelled data. According to Section 4 of the paper, Distil-Whisper is trained on 21,170 hours of data with pseudo-labels generated by Whisper, with the first and 32nd layer parameters frozen from Whisper. Does this mean the model only needs to be trained on 21,170 hours of pseudo-labelled data, with a model structure similar to Whisper's, the first and 32nd layers frozen, and a weighted KL divergence plus label cross-entropy objective, to achieve good results?

If this is the case, it is a significant finding: it suggests that similar methods can be applied after pre-training to reduce a model's parameter count and inference time without a significant loss of accuracy.

Thank you in advance

sanchit-gandhi commented 5 months ago

That's almost right! We freeze the entire encoder (32 layers) and initialise the student decoder from the first and last layers of the teacher's decoder (2 layers in total). We then train the model with the knowledge distillation objective on 22k hours of data. You can read more about how we initialise and train the model here: https://github.com/huggingface/distil-whisper#3-approach-%EF%B8%8F
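
For illustration only, here is a minimal sketch of that initialisation using the Hugging Face `transformers` Whisper classes. This is not the repo's actual training code; the `openai/whisper-large-v2` checkpoint and the exact set of weights copied are assumptions of the sketch.

```python
# Sketch: build a 2-decoder-layer student from a 32-layer Whisper teacher by
# copying the full encoder plus the first and last decoder layers, then
# freezing the encoder. Illustrative only -- see the repo's training scripts
# for the reference implementation.
import copy
from transformers import WhisperForConditionalGeneration

teacher = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v2")

# Student config: identical to the teacher's, but with only 2 decoder layers.
student_config = copy.deepcopy(teacher.config)
student_config.decoder_layers = 2
student = WhisperForConditionalGeneration(student_config)

# Copy the full encoder and the decoder's embeddings / final layer norm.
student.model.encoder.load_state_dict(teacher.model.encoder.state_dict())
student.model.decoder.embed_tokens.load_state_dict(teacher.model.decoder.embed_tokens.state_dict())
student.model.decoder.embed_positions.load_state_dict(teacher.model.decoder.embed_positions.state_dict())
student.model.decoder.layer_norm.load_state_dict(teacher.model.decoder.layer_norm.state_dict())

# Initialise the student's two decoder layers from the teacher's first and
# last decoder layers.
student.model.decoder.layers[0].load_state_dict(teacher.model.decoder.layers[0].state_dict())
student.model.decoder.layers[1].load_state_dict(teacher.model.decoder.layers[-1].state_dict())

# Freeze the encoder so that only the decoder is updated during distillation.
for param in student.model.encoder.parameters():
    param.requires_grad = False
```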

Since the model is pre-trained on so much data, the encoder's representation of the audio is already extremely good. We then only need to train the two student decoder layers to behave like the full original 32 decoder layers, which requires far less data than full pre-training.
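
For reference, here is a minimal sketch of the objective the question describes: a weighted sum of cross-entropy on the Whisper pseudo-labels and KL divergence between the student and teacher token distributions. The function name, the weights `alpha_ce` / `alpha_kl`, and the temperature are illustrative placeholders, not the exact values used in the paper.

```python
# Sketch of a weighted KL + cross-entropy distillation loss in PyTorch.
# Weights and temperature are placeholders, not the paper's exact values.
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, pseudo_labels,
                      alpha_ce=1.0, alpha_kl=1.0, temperature=2.0):
    # Cross-entropy against the Whisper-generated pseudo-labels
    # (padding positions marked with -100 are ignored).
    ce_loss = F.cross_entropy(
        student_logits.view(-1, student_logits.size(-1)),
        pseudo_labels.view(-1),
        ignore_index=-100,
    )

    # KL divergence between temperature-scaled teacher and student distributions.
    kl_loss = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.softmax(teacher_logits / temperature, dim=-1),
        reduction="batchmean",
    ) * temperature ** 2

    return alpha_ce * ce_loss + alpha_kl * kl_loss
```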

Section 9.2 of the paper gives a nice analysis of the effect of dataset size for distillation: https://arxiv.org/pdf/2311.00430.pdf