[Closed] charmeleonz closed this issue 8 months ago
Thank you for the reminder. KL_Loss() is just a very simple implementation of the KL-divergence loss for 1-D logit vectors. It is defined as follows:
```python
import torch
import torch.nn.functional as F

def KL_Loss(output_batch, teacher_outputs, T):
    # Treat the two 1-D logit vectors as a batch of size 1.
    output_batch = output_batch.unsqueeze(0)
    teacher_outputs = teacher_outputs.unsqueeze(0)
    # Temperature-scaled student log-probabilities and teacher probabilities;
    # the small epsilon keeps torch.log() below away from log(0).
    output_batch = F.log_softmax(output_batch / T, dim=1)
    teacher_outputs = F.softmax(teacher_outputs / T, dim=1) + 10 ** (-7)
    # T^2-scaled KL(teacher || student), the standard soft-label KD loss.
    loss = T * T * torch.sum(torch.mul(teacher_outputs, torch.log(teacher_outputs) - output_batch))
    return loss
```
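For reference, a minimal usage sketch using the KL_Loss definition above; the 100-class shape and T = 4.0 are illustrative assumptions, not values from this repo:

```python
import torch

# Hypothetical inputs: one 100-class logit vector from each model.
student_logits = torch.randn(100)
teacher_logits = torch.randn(100)

loss = KL_Loss(student_logits, teacher_logits, T=4.0)
print(loss.item())  # scalar value of T^2 * KL(teacher || student)
```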
Hello, if I want to reproduce instance-T, is it enough to modify lines 16-18 of KD.py and the KL_Loss here? And where is the code for the concat-then-MLP step mentioned in the paper?
I have been studying your work recently. Thank you very much for your prompt reply; I will try it right away.
@zhengli97 Thanks a lot for your timely response.
Hi, may I ask where the function KL_Loss() in `KD_loss += KL_Loss(y_s[i], y_t[i], T[i])` is defined? Thank you.
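For context, the quoted line looks like part of an accumulation loop over paired outputs. A minimal sketch of what that surrounding loop might look like, assuming y_s, y_t, and T are same-length lists of student logits, teacher logits, and temperatures (a guess at the structure, not the repo's actual KD.py code):

```python
# Hypothetical reconstruction of the loop around the quoted line;
# assumes KL_Loss from the reply above is in scope.
KD_loss = 0.0
for i in range(len(y_s)):
    KD_loss += KL_Loss(y_s[i], y_t[i], T[i])
```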