zhengli97 / CTKD

[AAAI 2023] Official PyTorch Code for "Curriculum Temperature for Knowledge Distillation"
https://zhengli97.github.io/CTKD/
Apache License 2.0

Implementation of Instance_T #13

Closed. charmeleonz closed this issue 8 months ago.

charmeleonz commented 8 months ago

Hi, may I ask where the function KL_Loss() in KD_loss += KL_Loss(y_s[i], y_t[i], T[i]) is defined? Thank you.

zhengli97 commented 8 months ago

Thank you for the reminder. KL_Loss() is just a very simple implementation of the KL-divergence loss for 1-d logit vectors. It is defined as follows:

import torch
import torch.nn.functional as F

def KL_Loss(output_batch, teacher_outputs, T):
    # Inputs are 1-d logit vectors; add a batch dimension of size 1.
    output_batch = output_batch.unsqueeze(0)
    teacher_outputs = teacher_outputs.unsqueeze(0)

    # Temperature-scaled student log-probabilities and teacher probabilities.
    output_batch = F.log_softmax(output_batch / T, dim=1)
    # The small epsilon keeps the log() below away from zero.
    teacher_outputs = F.softmax(teacher_outputs / T, dim=1) + 10 ** (-7)

    # KL(teacher || student), scaled by T^2 as is standard in KD.
    loss = T * T * torch.sum(torch.mul(teacher_outputs, torch.log(teacher_outputs) - output_batch))
    return loss
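
For context, the call site quoted in the question loops over the batch and applies a per-sample temperature. A minimal sketch of that loop (the final averaging step is an illustrative assumption; variable names follow the snippet in the question):

# y_s, y_t: (B, C) student/teacher logits; T: (B,) per-instance temperatures.
KD_loss = 0.0
for i in range(y_s.size(0)):
    KD_loss += KL_Loss(y_s[i], y_t[i], T[i])
KD_loss = KD_loss / y_s.size(0)  # averaging over the batch is an assumption here
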
wenyi-zhh commented 8 months ago

Hi, to reproduce instance-T, do I only need to modify lines 16-18 of KD.py and the KL_Loss here? Where is the code for the concat-then-MLP step mentioned in the paper?

zhengli97 commented 8 months ago

1. Yes, you need to modify the corresponding loss.
2. The MLP code is in the cloud-drive link provided in the README; it includes a .py implementation.
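
For reference, a minimal sketch of what the concat-then-MLP temperature predictor described in the paper could look like; the class name InstanceTemp, the hidden width, and the bounded temperature range are illustrative assumptions, and the .py file in the cloud-drive link is the authoritative implementation:

import torch
import torch.nn as nn

class InstanceTemp(nn.Module):
    # Hypothetical sketch: predicts one temperature per sample
    # from the concatenated student/teacher logits.
    def __init__(self, num_classes, hidden_dim=128, t_min=1.0, t_max=21.0):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(2 * num_classes, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )
        self.t_min = t_min
        self.t_max = t_max

    def forward(self, y_s, y_t):
        # Concatenate student and teacher logits: (B, C) + (B, C) -> (B, 2C).
        x = torch.cat([y_s, y_t], dim=1)
        # Squash to (0, 1), then map into a bounded temperature range: (B,).
        t = torch.sigmoid(self.mlp(x)).squeeze(1)
        return self.t_min + (self.t_max - self.t_min) * t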
wenyi-zhh commented 8 months ago

I have been studying your work recently. Thank you very much for the prompt reply; I will try it right away.

charmeleonz commented 8 months ago

@zhengli97 Thanks a lot for your timely response.