Closed charmeleonz closed 1 year ago
Hi, thanks for your work.
In the implementation of the non-target mask for the NKD loss:
mask = torch.ones_like(logits).scatter(1, label, 1).bool()
Shouldn't it be mask = torch.ones_like(logits).scatter(1, label, 0).bool() instead?
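The difference is easy to check directly: scattering 1 into a tensor that is already all ones is a no-op, so the "non-target" mask never actually excludes the target class. A minimal sketch (shapes and labels here are illustrative, not from the repo):

```python
import torch

logits = torch.randn(2, 5)        # batch of 2, 5 classes (illustrative)
label = torch.tensor([[1], [3]])  # target class index per sample

# Original line: scattering 1 into a ones tensor changes nothing,
# so the mask is all-True and masks out no class.
mask_buggy = torch.ones_like(logits).scatter(1, label, 1).bool()

# Proposed fix: scatter 0 at the target index, so the mask is False
# exactly at the target class and True everywhere else.
mask_fixed = torch.ones_like(logits).scatter(1, label, 0).bool()

print(mask_buggy.all().item())  # True: buggy mask keeps every class
print(mask_fixed.sum(dim=1))    # one class dropped per row
```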
Thank you. That was my mistake when I updated the code. I will fix this bug.
Hi, thanks for your reply. Also, is the part of the code that equalises the sums of the student and teacher non-target logits missing?
The softmax here makes them the same: we drop the target class and then apply softmax over the remaining logits, so both non-target distributions sum to one.
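That normalisation can be verified in a few lines. This is a hedged sketch of the step described above, not the repository's exact code; the variable names and shapes are assumptions:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logit_s = torch.randn(2, 5)       # student logits (illustrative)
logit_t = torch.randn(2, 5)       # teacher logits (illustrative)
label = torch.tensor([[1], [3]])  # target class index per sample

# Non-target mask: False at the target class, True elsewhere.
mask = torch.ones_like(logit_s).scatter(1, label, 0).bool()

# Drop the target class, then softmax over what remains.
nt_s = logit_s[mask].view(2, -1)
nt_t = logit_t[mask].view(2, -1)
p_s = F.softmax(nt_s, dim=1)
p_t = F.softmax(nt_t, dim=1)

# Each row of both distributions now sums to 1, i.e. the student and
# teacher non-target masses are equalised.
print(p_s.sum(dim=1), p_t.sum(dim=1))
```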