nlp-with-transformers / notebooks

Jupyter notebooks for the Natural Language Processing with Transformers book
https://transformersbook.com/

Error in chapter 8 KL divergence formula #94

Open · tanmey007 opened this issue 1 year ago

tanmey007 commented 1 year ago

Information

The problem arises in chapter: 8 ("Making Transformers Efficient in Production")

Describe the bug

The formula used to calculate the KL divergence in the knowledge distillation loss is wrong: the teacher distribution passed as the target is computed with F.softmax instead of F.log_softmax.

To Reproduce

Steps to reproduce the behavior: run the knowledge distillation loss computation from the chapter 8 notebook, which contains:


loss_fct = nn.KLDivLoss(reduction="batchmean")
loss_kd = self.args.temperature ** 2 * loss_fct(
    F.log_softmax(logits_stu / self.args.temperature, dim=-1),
    F.softmax(logits_tea / self.args.temperature, dim=-1),
)
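For reference, a self-contained version of the snippet above (the random logits and hard-coded temperature are stand-ins for the trainer state, not values from the book):

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits_stu = torch.randn(4, 10)  # stand-in student logits: batch of 4, 10 classes
logits_tea = torch.randn(4, 10)  # stand-in teacher logits
temperature = 2.0                # stand-in for self.args.temperature

# Same computation as above: log-probabilities as input, probabilities as target
loss_fct = nn.KLDivLoss(reduction="batchmean")
loss_kd = temperature ** 2 * loss_fct(
    F.log_softmax(logits_stu / temperature, dim=-1),
    F.softmax(logits_tea / temperature, dim=-1),
)
print(loss_kd)  # scalar loss tensor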

Expected behavior


loss_fct = nn.KLDivLoss(reduction="batchmean")
loss_kd = self.args.temperature ** 2 * loss_fct(
    F.log_softmax(logits_stu / self.args.temperature, dim=-1),
    F.log_softmax(logits_tea / self.args.temperature, dim=-1),
)

Reference: https://pytorch.org/docs/1.13/generated/torch.nn.KLDivLoss.html#torch.nn.KLDivLoss
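Note that the linked docs tie the expected target format to the log_target flag, which defaults to False. A quick sketch of the two conventions side by side (dummy logits, not the book's data):

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits_stu, logits_tea = torch.randn(4, 10), torch.randn(4, 10)
T = 2.0
log_p_stu = F.log_softmax(logits_stu / T, dim=-1)

# log_target=False (the default): target is plain probabilities
loss_prob_target = nn.KLDivLoss(reduction="batchmean")(
    log_p_stu, F.softmax(logits_tea / T, dim=-1))

# log_target=True: target is log-probabilities
loss_log_target = nn.KLDivLoss(reduction="batchmean", log_target=True)(
    log_p_stu, F.log_softmax(logits_tea / T, dim=-1))

print(torch.allclose(loss_prob_target, loss_log_target))  # True: same KL either way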

triet1102 commented 5 months ago

Hi, you might find this answer helpful: https://discuss.pytorch.org/t/kl-loss-and-log-softmax/69136/2. In short, with the default log_target=False, nn.KLDivLoss expects its input as log-probabilities but its target as plain probabilities, so F.softmax for the teacher logits matches the documented convention.
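A minimal check of that convention (random tensors standing in for the student and teacher outputs): the default loss equals KL(teacher || student) computed by hand from probabilities.

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
log_q = F.log_softmax(torch.randn(4, 10), dim=-1)  # student log-probabilities
p = F.softmax(torch.randn(4, 10), dim=-1)          # teacher probabilities

loss = nn.KLDivLoss(reduction="batchmean")(log_q, p)
manual = (p * (p.log() - log_q)).sum(dim=-1).mean()  # KL(p || q), averaged over the batch
print(torch.allclose(loss, manual))  # True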