Hi, I don't understand the subtraction order above. Why not like this:
prod_probs = torch.masked_fill(teacher_probs * teacher_logprobs, inf_mask, 0)
prod_probs -= torch.masked_fill(teacher_probs * mixed_logprobs, inf_mask, 0)
Because the final distil_loss takes the negative of prod_probs.
The subtraction order you mentioned is also correct, provided that the negative coefficient is removed from the distil_loss calculation. It should be as follows:
prod_probs = torch.masked_fill(teacher_probs * teacher_logprobs, inf_mask, 0)
prod_probs -= torch.masked_fill(teacher_probs * mixed_logprobs, inf_mask, 0)
x = torch.sum(prod_probs, dim=-1).view(-1)
distil_loss = torch.sum(x * mask.view(-1), dim=0) / torch.sum(mask.view(-1), dim=0)
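For illustration, here is a small self-contained check that the two sign conventions give the same loss value. The logits, shapes, and mask below are made up for the toy example; only the tensor names mirror the snippets above.

import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Toy setup: 2 positions over a 5-token vocabulary (shapes and values are made up).
logits = torch.randn(2, 5)            # student logits
teacher_logits = torch.randn(2, 5)
lam = 0.1
mask = torch.ones(2)                  # pretend every position is a valid target

teacher_probs = F.softmax(teacher_logits, dim=-1)
teacher_logprobs = F.log_softmax(teacher_logits, dim=-1)
mixed_probs = lam * teacher_probs + (1 - lam) * F.softmax(logits, dim=-1)
mixed_logprobs = torch.log(mixed_probs)
inf_mask = torch.isinf(logits) | torch.isinf(teacher_logits)

# Ordering discussed above: (mixed - teacher), negated at the end.
prod_a = torch.masked_fill(teacher_probs * mixed_logprobs, inf_mask, 0)
prod_a -= torch.masked_fill(teacher_probs * teacher_logprobs, inf_mask, 0)
loss_a = -torch.sum(torch.sum(prod_a, dim=-1) * mask) / torch.sum(mask)

# Ordering asked about: (teacher - mixed), with the leading minus dropped.
prod_b = torch.masked_fill(teacher_probs * teacher_logprobs, inf_mask, 0)
prod_b -= torch.masked_fill(teacher_probs * mixed_logprobs, inf_mask, 0)
loss_b = torch.sum(torch.sum(prod_b, dim=-1) * mask) / torch.sum(mask)

print(torch.allclose(loss_a, loss_b))  # True: the two formulations agree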
Hi all, sorry for the late reply.
Regarding the teacher_probs * log(teacher_probs) term: since the teacher probabilities are a constant with respect to the student parameters, this term does not affect gradient descent on the student, so we don't need to consider it. Thanks for your interest in our work.
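This can be checked directly in code. In the toy example below, the student logits stand in for the student parameters (the values are made up): the loss with and without the teacher_probs * log(teacher_probs) term differs only by a constant, so the gradients coincide.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
student_logits = torch.randn(5, requires_grad=True)  # stands in for the student parameters
teacher_logits = torch.randn(5)
lam = 0.1

p_t = F.softmax(teacher_logits, dim=-1)
p_mix = lam * p_t + (1 - lam) * F.softmax(student_logits, dim=-1)

# Loss without the entropy term (the form described above) vs. the full KL(p_t || p_mix).
loss_without_entropy = -(p_t * torch.log(p_mix)).sum()
loss_full_kl = (p_t * (torch.log(p_t) - torch.log(p_mix))).sum()

g_without = torch.autograd.grad(loss_without_entropy, student_logits, retain_graph=True)[0]
g_full = torch.autograd.grad(loss_full_kl, student_logits)[0]
print(torch.allclose(g_without, g_full))  # True: the extra term is constant w.r.t. the student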
Considering the definition of the KL divergence, the two functions forward_kl and skewed_forward_kl in distillm/losses.py need to be modified: in both of them, the teacher_probs * log(teacher_probs) term is not considered. The corrected version of skewed_forward_kl is as follows:
def skewed_forward_kl(logits, teacher_logits, no_model_batch, lam=0.1):
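A sketch of what the full corrected function might look like, assuming the usual structure of the DistiLLM loss (softmax over the last dimension, a lambda-mixed distribution, an inf mask, and a label mask built from no_model_batch["label"] with -100 marking padding); details beyond the lines quoted above are assumptions, not the repository's exact code.

import torch
import torch.nn.functional as F

def skewed_forward_kl(logits, teacher_logits, no_model_batch, lam=0.1):
    # Teacher and student token distributions.
    teacher_probs = F.softmax(teacher_logits, dim=-1, dtype=torch.float32)
    student_probs = F.softmax(logits, dim=-1, dtype=torch.float32)
    teacher_logprobs = F.log_softmax(teacher_logits, dim=-1, dtype=torch.float32)

    # Skewed (lambda-mixed) distribution and its log.
    mixed_probs = lam * teacher_probs + (1 - lam) * student_probs
    mixed_logprobs = torch.log(mixed_probs)

    # Ignore positions where either model produced -inf logits.
    inf_mask = torch.isinf(logits) | torch.isinf(teacher_logits)

    # Per-position KL(teacher || mixed): sum over the vocab of p_T * (log p_T - log p_mix).
    prod_probs = torch.masked_fill(teacher_probs * teacher_logprobs, inf_mask, 0)
    prod_probs -= torch.masked_fill(teacher_probs * mixed_logprobs, inf_mask, 0)
    x = torch.sum(prod_probs, dim=-1).view(-1)

    # Average over non-padding target positions; no leading minus, per the discussion above.
    mask = (no_model_batch["label"] != -100).int()
    distil_loss = torch.sum(x * mask.view(-1), dim=0) / torch.sum(mask.view(-1), dim=0)
    return distil_loss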
The same modifications are applicable to forward_kl as well.