dropreg / R-Drop


About the implementation in transformers: why does the reduction in ce_loss use the mean (the default), while the KL loss uses sum? #24

Closed: XiaoqingNLP closed this issue 2 years ago

XiaoqingNLP commented 2 years ago
    for logits in logits_list:
        if labels is not None:
            if self.num_labels == 1:
                #  We are doing regression
                loss_fct = MSELoss()
                if loss:
                    loss += alpha * loss_fct(logits.view(-1), labels.view(-1))
                else:
                    loss = alpha * loss_fct(logits.view(-1), labels.view(-1))
            else:
                loss_fct = CrossEntropyLoss()
                if loss:
                    loss += alpha * loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
                else:
                    loss = alpha * loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

    if loss is not None:
        if self.num_labels == 1:
            loss_fct = MSELoss()
            loss += 1.0 * loss_fct(logits_list[0].view(-1), logits_list[-1].view(-1))
        else:
            p = torch.log_softmax(logits_list[0].view(-1, self.num_labels), dim=-1)
            p_tec = torch.softmax(logits_list[0].view(-1, self.num_labels), dim=-1)
            q = torch.log_softmax(logits_list[-1].view(-1, self.num_labels), dim=-1)
            q_tec = torch.softmax(logits_list[-1].view(-1, self.num_labels), dim=-1)

            kl_loss = torch.nn.functional.kl_div(p, q_tec, reduction='none').sum()
            reverse_kl_loss = torch.nn.functional.kl_div(q, p_tec, reduction='none').sum()
            loss += 0.5 * (kl_loss + reverse_kl_loss) / 2.

https://github.com/dropreg/R-Drop/blob/084365d0836b643e0743841528ff39ff88113eef/huggingface_transformer_src/src/transformers/models/bert/modeling_bert.py#L1545
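
To make the two reductions explicit, the classification branch above can be written as a standalone sketch (the function name and signature are made up for illustration; the logic follows the quoted snippet):

    import torch
    import torch.nn.functional as F
    from torch.nn import CrossEntropyLoss

    def r_drop_classification_loss(logits1, logits2, labels, num_labels, alpha=1.0):
        # Cross-entropy on both forward passes: CrossEntropyLoss defaults to
        # reduction='mean', i.e. one value averaged over the batch.
        ce_fct = CrossEntropyLoss()
        loss = alpha * ce_fct(logits1.view(-1, num_labels), labels.view(-1))
        loss = loss + alpha * ce_fct(logits2.view(-1, num_labels), labels.view(-1))

        # Symmetric KL between the two predicted distributions: here the
        # element-wise divergences are summed over batch and classes, not averaged.
        p = torch.log_softmax(logits1.view(-1, num_labels), dim=-1)
        p_tec = torch.softmax(logits1.view(-1, num_labels), dim=-1)
        q = torch.log_softmax(logits2.view(-1, num_labels), dim=-1)
        q_tec = torch.softmax(logits2.view(-1, num_labels), dim=-1)

        kl_loss = F.kl_div(p, q_tec, reduction='none').sum()
        reverse_kl_loss = F.kl_div(q, p_tec, reduction='none').sum()
        return loss + 0.5 * (kl_loss + reverse_kl_loss) / 2.

So the CE terms are averaged per example, while the KL regularizer is a raw sum over batch and classes.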

dropreg commented 2 years ago

The reduction "mean" or "sum" may require picking a different hyper-parameter alpha (we use "sum", as reported in the paper). I think both ways are fine.
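
For example, switching the KL term from the explicit sum to a per-example average such as 'batchmean' only rescales it by the batch size, so the difference can be absorbed into the coefficient. A quick numeric check (not from the repo; names and shapes are hypothetical):

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    batch_size, num_labels = 8, 3
    logits1 = torch.randn(batch_size, num_labels)
    logits2 = torch.randn(batch_size, num_labels)

    p = torch.log_softmax(logits1, dim=-1)
    q_tec = torch.softmax(logits2, dim=-1)

    kl_sum = F.kl_div(p, q_tec, reduction='none').sum()  # as in the repo code
    kl_bm = F.kl_div(p, q_tec, reduction='batchmean')     # averaged per example

    # The two differ only by a factor of the batch size, so changing the
    # reduction amounts to rescaling the weight on the KL term.
    print(torch.allclose(kl_sum, kl_bm * batch_size))  # True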