Open · tanmey007 opened this issue 1 year ago
Information

The problem arises in chapter:

Describe the bug

The formula used to calculate the KL divergence is wrong: the teacher distribution is passed to nn.KLDivLoss as softmax probabilities rather than as log_softmax log-probabilities.
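For context, the quantity this loss is meant to compute (in the usual knowledge-distillation setup, with student logits $z_s$, teacher logits $z_t$, and temperature $T$) is

$$L_\mathrm{KD} = T^2 \, D_\mathrm{KL}\big(\mathrm{softmax}(z_t / T) \,\|\, \mathrm{softmax}(z_s / T)\big)$$

and the question is in which form, probabilities or log-probabilities, each argument must be handed to nn.KLDivLoss.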
To Reproduce

Steps to reproduce the behavior: compute the distillation loss as currently written in the book (here nn is torch.nn and F is torch.nn.functional):

loss_fct = nn.KLDivLoss(reduction="batchmean")
loss_kd = self.args.temperature ** 2 * loss_fct(
    F.log_softmax(logits_stu / self.args.temperature, dim=-1),
    F.softmax(logits_tea / self.args.temperature, dim=-1),
)
Expected behavior

The teacher distribution should also be passed in log space; per the docs linked below, nn.KLDivLoss then needs log_target=True so that the target is interpreted as log-probabilities:

loss_fct = nn.KLDivLoss(reduction="batchmean", log_target=True)
loss_kd = self.args.temperature ** 2 * loss_fct(
    F.log_softmax(logits_stu / self.args.temperature, dim=-1),
    F.log_softmax(logits_tea / self.args.temperature, dim=-1),
)

Reference: https://pytorch.org/docs/1.13/generated/torch.nn.KLDivLoss.html#torch.nn.KLDivLoss
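For anyone comparing the two calls, here is a minimal sketch (not from the book; the shapes, seed, and temperature are made up for illustration) showing that the probability-space target with the default settings and the log-space target with log_target=True both reproduce a manual KL(p_tea || p_stu):

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits_stu = torch.randn(4, 10)  # hypothetical student logits (batch=4, 10 classes)
logits_tea = torch.randn(4, 10)  # hypothetical teacher logits
T = 2.0                          # hypothetical temperature

log_p_stu = F.log_softmax(logits_stu / T, dim=-1)

# Book's formulation: target given as probabilities (default log_target=False)
p_tea = F.softmax(logits_tea / T, dim=-1)
kd_prob = nn.KLDivLoss(reduction="batchmean")(log_p_stu, p_tea)

# Log-space target: only valid together with log_target=True
log_p_tea = F.log_softmax(logits_tea / T, dim=-1)
kd_log = nn.KLDivLoss(reduction="batchmean", log_target=True)(log_p_stu, log_p_tea)

# Manual KL(p_tea || p_stu): sum over classes, average over the batch
kd_manual = (p_tea * (p_tea.log() - log_p_stu)).sum(dim=-1).mean()

print(kd_prob.item(), kd_log.item(), kd_manual.item())  # all three agree

Note that passing log_softmax output as the target while leaving log_target at its default of False makes KLDivLoss take the log of (negative) log-probabilities internally, which yields NaNs rather than a valid loss.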
Hi, you might find this answer helpful: https://discuss.pytorch.org/t/kl-loss-and-log-softmax/69136/2 (in short, nn.KLDivLoss by default expects the input as log-probabilities and the target as plain probabilities).