jongwooko / distillm

Official PyTorch implementation of DistiLLM: Towards Streamlined Distillation for Large Language Models (ICML 2024)
https://arxiv.org/abs/2402.03898

incorrect loss functions based on KLD definition #7

Closed mohammadreza-molapanah closed 5 months ago

mohammadreza-molapanah commented 6 months ago

Considering the definition of the KL divergence, the two functions forward_kl and skewed_forward_kl in distillm/losses.py need to be modified: neither of them includes the teacher_probs * log(teacher_probs) term. The corrected version of skewed_forward_kl is as follows:

def skewed_forward_kl(logits, teacher_logits, no_model_batch, lam=0.1):
    teacher_probs = F.softmax(teacher_logits, dim=-1, dtype=torch.float32)
    student_probs = F.softmax(logits, dim=-1, dtype=torch.float32)
    mixed_probs = lam * teacher_probs + (1-lam) * student_probs

    teacher_logprobs = F.log_softmax(teacher_logits, dim=-1, dtype=torch.float32)
    mixed_logprobs = torch.log(mixed_probs)

    mask = (no_model_batch["label"] != -100).int()
    inf_mask = torch.isinf(logits) | torch.isinf(teacher_logits)

    prod_probs = torch.masked_fill(teacher_probs * mixed_logprobs, inf_mask, 0)
    prod_probs -= torch.masked_fill(teacher_probs * teacher_logprobs, inf_mask, 0)
    x = torch.sum(prod_probs, dim=-1).view(-1)
    distil_loss = -torch.sum(x * mask.view(-1), dim=0) / torch.sum(mask.view(-1), dim=0)
    return distil_loss
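For reference, this matches the usual definition of the (skew) forward KL; writing $p$ for the teacher distribution, $q_\theta$ for the student, and $m = \lambda p + (1-\lambda) q_\theta$ (this notation is mine, not from the repository):

$$
\mathrm{KL}(p \,\|\, m) = \sum_i p_i \log p_i - \sum_i p_i \log m_i ,
$$

i.e. the correction adds the $\sum_i p_i \log p_i$ term that the original code leaves out.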

The same modifications are applicable to forward_kl as well.
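For completeness, here is a minimal sketch of what the analogous change to forward_kl could look like, assuming the same signature and masking convention as the snippet above (my reconstruction, not code taken from the repository):

import torch
import torch.nn.functional as F

def forward_kl(logits, teacher_logits, no_model_batch):
    # Teacher distribution and both log-distributions in float32.
    teacher_probs = F.softmax(teacher_logits, dim=-1, dtype=torch.float32)
    teacher_logprobs = F.log_softmax(teacher_logits, dim=-1, dtype=torch.float32)
    student_logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)

    # Ignore padding positions (label == -100) and any inf logits.
    mask = (no_model_batch["label"] != -100).int()
    inf_mask = torch.isinf(logits) | torch.isinf(teacher_logits)

    # Accumulate p*log q - p*log p; the leading minus below turns this
    # into KL(p || q) = sum p*log p - sum p*log q.
    prod_probs = torch.masked_fill(teacher_probs * student_logprobs, inf_mask, 0)
    prod_probs -= torch.masked_fill(teacher_probs * teacher_logprobs, inf_mask, 0)
    x = torch.sum(prod_probs, dim=-1).view(-1)
    distil_loss = -torch.sum(x * mask.view(-1), dim=0) / torch.sum(mask.view(-1), dim=0)
    return distil_loss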

wangfan120 commented 6 months ago

hi, I don't understand the subtraction order above, why not like this:
prod_probs = torch.masked_fill(teacher_probs * teacher_logprobs, inf_mask, 0)
prod_probs -= torch.masked_fill(teacher_probs * mixed_logprobs, inf_mask, 0)

songmzhang commented 6 months ago

hi, I don't understand the subtraction order above, why not like this:
prod_probs = torch.masked_fill(teacher_probs * teacher_logprobs, inf_mask, 0)
prod_probs -= torch.masked_fill(teacher_probs * mixed_logprobs, inf_mask, 0)

Because the final distil_loss takes the negative of the summed prod_probs.

mohammadreza-molapanah commented 6 months ago

hi, I don't understand the subtraction order above, why not like this:
prod_probs = torch.masked_fill(teacher_probs * teacher_logprobs, inf_mask, 0)
prod_probs -= torch.masked_fill(teacher_probs * mixed_logprobs, inf_mask, 0)

The subtraction order you mentioned is also correct, provided that the leading negative sign is removed from the distil_loss calculation. It should then be as follows:

prod_probs = torch.masked_fill(teacher_probs * teacher_logprobs, inf_mask, 0)
prod_probs -= torch.masked_fill(teacher_probs * mixed_logprobs, inf_mask, 0)
x = torch.sum(prod_probs, dim=-1).view(-1)
distil_loss = torch.sum(x * mask.view(-1), dim=0) / torch.sum(mask.view(-1), dim=0)
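A quick standalone sanity check (toy tensors, not code from the repository) that the two orderings agree once the sign is handled consistently:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
lam = 0.1
logits = torch.randn(2, 5)          # toy student logits
teacher_logits = torch.randn(2, 5)  # toy teacher logits

p = F.softmax(teacher_logits, dim=-1)
q = F.softmax(logits, dim=-1)
m = lam * p + (1 - lam) * q

# Ordering 1: sum(p*log m - p*log p), then negate.
v1 = -(p * (m.log() - p.log())).sum(dim=-1)
# Ordering 2: sum(p*log p - p*log m), no negation.
v2 = (p * (p.log() - m.log())).sum(dim=-1)

print(torch.allclose(v1, v2))  # True; both equal KL(p || m) per position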

jongwooko commented 5 months ago

Hi all, sorry for the late reply.

  1. Regarding the teacher_probs term: since the teacher probabilities are constant with respect to the student parameters, the term does not affect the gradient, so we don't need to consider it (see the sketch after this list).
  2. Regarding the subtraction order: as @songmzhang mentioned, it is because the final distil_loss takes the negative of prod_probs.
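For illustration of point 1, a small standalone check (toy tensors, not code from the repository) that adding the constant teacher-entropy term leaves the gradient with respect to the student logits unchanged:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
teacher_logits = torch.randn(2, 5)
p = F.softmax(teacher_logits, dim=-1)

def grad_wrt_student(include_teacher_entropy):
    # Re-seed so both runs use identical student logits.
    torch.manual_seed(1)
    logits = torch.randn(2, 5, requires_grad=True)
    loss = -(p * F.log_softmax(logits, dim=-1)).sum()   # cross-entropy part
    if include_teacher_entropy:
        loss = loss + (p * p.log()).sum()                # constant w.r.t. logits
    loss.backward()
    return logits.grad

print(torch.allclose(grad_wrt_student(False), grad_wrt_student(True)))  # True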

Thanks for your interest in our work.