huggingface / distil-whisper

Distilled variant of Whisper for speech recognition. 6x faster, 50% smaller, within 1% word error rate.
MIT License
3.33k stars 238 forks

Why we need KL-divergence #45

Open penolove opened 7 months ago

penolove commented 7 months ago

https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html According to the formula, it seems to be equivalent to optimizing a soft-label cross-entropy, since the target y_true is just a constant determined by the teacher model.

https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html#crossentropyloss Isn't it? Or did I misunderstand something?

sanchit-gandhi commented 7 months ago

Hey @penolove - the CE and KL loss terms are closely related:

  1. For the CE loss, y_true is 1 for the correct class label and 0 for all others. In other words, it is a 'one-hot' encoding of the target distribution.
  2. For the KL loss, y_true is the predicted class probability from the teacher model. Thus, it can take any value in [0, 1]. In other words, it is the predicted probability distribution from the teacher model.

In summary, the CE loss trains the student model to maximise the probability of the correct class label, whereas the KL loss trains the student to mimic the distribution of the teacher model over all class labels.
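To make the contrast concrete, here is a minimal pure-Python sketch of the two loss terms for a single token over a toy 3-class vocabulary. The helper names (`softmax`, `hard_label_ce`, `kl_to_teacher`) and the logit values are made up for illustration and are not taken from the distil-whisper training code; in PyTorch the corresponding modules would be `torch.nn.CrossEntropyLoss` and `torch.nn.KLDivLoss`.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def hard_label_ce(student_logits, target_index):
    """Cross-entropy against a one-hot target: -log q[target]."""
    q = softmax(student_logits)
    return -math.log(q[target_index])

def kl_to_teacher(student_logits, teacher_probs):
    """KL(p || q) = sum_i p_i * (log p_i - log q_i), with 0 * log 0 taken as 0."""
    q = softmax(student_logits)
    return sum(
        p * (math.log(p) - math.log(q_i))
        for p, q_i in zip(teacher_probs, q)
        if p > 0
    )

# Toy values (assumptions for illustration only).
student_logits = [2.0, 0.5, -1.0]
teacher_probs = softmax([1.5, 1.0, -2.0])

ce = hard_label_ce(student_logits, target_index=0)  # only class 0 matters
kl = kl_to_teacher(student_logits, teacher_probs)   # every class contributes
```

The CE term only looks at the student's probability for the single labelled class, while the KL term compares the student and teacher on every class, and reaches zero only when the two distributions match.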

penolove commented 7 months ago

Yes, item 1 is the hard-label CE, where the target is the one-hot label generated by the teacher. But for item 2, the KL loss, the formula is: y_true * log(y_true) - y_true * log(y_pred) (this is a vector). The first term is fixed, because y_true (the probability of each token) has already been generated by the teacher. The second term is what I mean by soft-label cross-entropy. So we could use only the second term instead of the full KL, and I think that would save the computation of y_true * log(y_true).
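The argument above is the standard decomposition KL(p || q) = CE(p, q) - H(p), where H(p) = -sum(p * log(p)) depends only on the teacher. A small stdlib-only check of this, with made-up teacher probabilities and student logits, could look as follows. The function names are hypothetical, and the gradients are approximated with central finite differences rather than autograd.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over a list of logits."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - lse for z in logits]

def soft_label_ce(student_logits, teacher_probs):
    """Soft-label cross-entropy: -sum_i p_i * log q_i."""
    log_q = log_softmax(student_logits)
    return -sum(p * lq for p, lq in zip(teacher_probs, log_q))

def kl_div(student_logits, teacher_probs):
    """KL(p || q) = sum_i p_i * log p_i - sum_i p_i * log q_i."""
    log_q = log_softmax(student_logits)
    return sum(p * (math.log(p) - lq) for p, lq in zip(teacher_probs, log_q) if p > 0)

def finite_diff_grad(fn, student_logits, teacher_probs, eps=1e-6):
    """Central finite-difference gradient of fn w.r.t. the student logits."""
    grad = []
    for i in range(len(student_logits)):
        up = list(student_logits)
        down = list(student_logits)
        up[i] += eps
        down[i] -= eps
        grad.append((fn(up, teacher_probs) - fn(down, teacher_probs)) / (2 * eps))
    return grad

teacher_probs = [0.7, 0.2, 0.1]  # toy teacher distribution (assumption)
neg_entropy = sum(p * math.log(p) for p in teacher_probs)  # -H(p), constant

gaps, grad_diffs = [], []
for student_logits in ([2.0, 0.5, -1.0], [-0.3, 0.1, 1.2]):
    # KL - CE should equal -H(p) whatever the student predicts.
    gaps.append(kl_div(student_logits, teacher_probs)
                - soft_label_ce(student_logits, teacher_probs))
    g_kl = finite_diff_grad(kl_div, student_logits, teacher_probs)
    g_ce = finite_diff_grad(soft_label_ce, student_logits, teacher_probs)
    grad_diffs.append(max(abs(a - b) for a, b in zip(g_kl, g_ce)))
```

Since the two losses differ by a constant, their gradients with respect to the student logits agree, so training on either one gives the same updates; only the reported loss value shifts by H(p). In PyTorch, `torch.nn.CrossEntropyLoss` also accepts class probabilities as the target, which is one way to compute the soft-label term directly.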

sanchit-gandhi commented 7 months ago

Yes, you're right: we could likely run the teacher with .generate to get the probability mass function for each example, and then compute a simplified loss by merging the CE and KL loss terms.

However, we wanted to run pseudo-labelling ahead of time and only keep the text predictions (since generation is the most time-consuming step). Therefore, we opted for running generation beforehand and saving the text predictions, and then passing the pseudo-labelled text back through the teacher model during training to get the teacher's probability mass function.

Doing it this way, it seemed simpler to keep the CE and KL terms separate, since they then tie in with the mathematical formulations.
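As a rough illustration of keeping the two terms separate, the sketch below combines a CE term on the saved pseudo-label tokens with a KL term against teacher logits obtained from a forward pass over the same pseudo-labelled text. This is a stdlib-only toy and not the distil-whisper training loop: the function name `distillation_loss`, the weights `ce_weight` and `kl_weight`, and all numeric values are hypothetical placeholders.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over a list of logits."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - lse for z in logits]

def distillation_loss(student_logits, teacher_logits, pseudo_label_ids,
                      ce_weight=1.0, kl_weight=1.0):
    """Token-averaged CE and KL terms, combined as a weighted sum.

    student_logits / teacher_logits: one list of logits per token position.
    pseudo_label_ids: token ids saved from the teacher's earlier generation.
    ce_weight / kl_weight: hypothetical placeholder weights.
    """
    ce_total, kl_total = 0.0, 0.0
    for s, t, y in zip(student_logits, teacher_logits, pseudo_label_ids):
        log_q = log_softmax(s)  # student log-probabilities
        log_p = log_softmax(t)  # teacher log-probabilities
        ce_total += -log_q[y]   # hard-label CE on the pseudo-label
        kl_total += sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))
    n = len(pseudo_label_ids)
    ce, kl = ce_total / n, kl_total / n
    return ce_weight * ce + kl_weight * kl, ce, kl

# Two token positions over a toy 3-token vocabulary (made-up values).
student = [[2.0, 0.5, -1.0], [0.1, 1.4, 0.3]]
teacher = [[3.0, 0.2, -2.0], [0.0, 2.5, 0.1]]
pseudo_labels = [0, 1]

loss, ce, kl = distillation_loss(student, teacher, pseudo_labels)
```

Returning the CE and KL values alongside the total keeps each term individually loggable, which is one practical benefit of not merging them.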