yzd-v / cls_KD

'NKD and USKD' (ICCV 2023) and 'ViTKD' (CVPRW 2024)
Apache License 2.0

Implementation of non-target mask #16

Closed: charmeleonz closed this issue 1 year ago

charmeleonz commented 1 year ago

Hi, thanks for your work.

In the implementation of the non-target mask for the NKD loss:

mask = torch.ones_like(logits).scatter(1, label, 1).bool()

Shouldn't it be mask = torch.ones_like(logits).scatter(1, label, 0).bool() instead?
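A quick check of the two versions on a toy batch. This is a minimal sketch: the logits and labels are made-up values, and `label` is assumed to have shape (N, 1), which `scatter` along dim 1 requires. Scattering 1 into a tensor of ones leaves it unchanged, so the original mask keeps the target class too. Scattering 0 clears exactly the target entry in each row.

```python
import torch

# Toy batch (made-up values): 2 samples, 4 classes.
logits = torch.tensor([[2.0, 1.0, 0.5, -1.0],
                       [0.3, 0.2, 3.0, 0.1]])
# Target indices with shape (N, 1), as scatter along dim 1 expects.
label = torch.tensor([[0], [2]])

# Version quoted in the issue: writing 1 into a tensor of ones changes nothing,
# so every entry, including the target class, stays True.
buggy_mask = torch.ones_like(logits).scatter(1, label, 1).bool()

# Suggested fix: write 0 at the target index, so only non-target entries stay True.
mask = torch.ones_like(logits).scatter(1, label, 0).bool()

print(buggy_mask.all().item())  # True
print(mask.tolist())  # [[False, True, True, True], [True, True, False, True]]
```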

yzd-v commented 1 year ago

Thank you. It was my mistake when I updated the code. I will fix this bug.

charmeleonz commented 1 year ago

Hi, thanks for your reply. Also, is the part of the code that "equalises the sums of the student and teacher non-target logits" missing?

yzd-v commented 1 year ago

The code here makes them the same: we drop the target class and then apply softmax, so the non-target probabilities of both the student and the teacher sum to 1.
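The "drop the target class, then softmax" step can be sketched as follows. This is an illustration under stated assumptions, not the repository's code: `non_target_softmax` is a hypothetical helper, the logits are random, and no temperature is applied. With a full softmax, the non-target mass is 1 minus the target probability, so it differs between student and teacher. After dropping the target logit and applying softmax over the remaining C-1 classes, every row sums to 1 for both.

```python
import torch
import torch.nn.functional as F


def non_target_softmax(logits: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: remove the target-class logit from each row,
    then softmax over the remaining C-1 classes so each row sums to 1."""
    n, c = logits.shape
    # True everywhere except at the target index (the corrected mask from this issue).
    mask = torch.ones_like(logits).scatter(1, label.view(-1, 1), 0).bool()
    # masked_select flattens in row-major order; each row keeps exactly C-1 entries.
    non_target = logits.masked_select(mask).view(n, c - 1)
    return F.softmax(non_target, dim=1)


torch.manual_seed(0)
student_logits = torch.randn(3, 5)
teacher_logits = torch.randn(3, 5) * 4.0  # deliberately different scale
label = torch.tensor([1, 0, 4])  # int64 class indices, shape (N,)

# Non-target mass under a full softmax: 1 - p(target), different for each model.
idx = label.view(-1, 1)
mass_s = 1 - F.softmax(student_logits, dim=1).gather(1, idx).squeeze(1)
mass_t = 1 - F.softmax(teacher_logits, dim=1).gather(1, idx).squeeze(1)

# After dropping the target class, both distributions sum to 1 per row.
p_s = non_target_softmax(student_logits, label)
p_t = non_target_softmax(teacher_logits, label)
print(mass_s, mass_t)  # below 1, and not equal between the two models
print(p_s.sum(dim=1), p_t.sum(dim=1))  # each row sums to 1 (up to float rounding)
```

Dropping the logit before the softmax gives the same result as taking the full softmax, keeping the non-target entries, and renormalizing them by their sum. The masked version just skips the extra division.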