XiaoqingNLP (issue author, since closed):
```python
# Excerpt from the R-Drop fork of modeling_bert.py (see the link below), inside
# BertForSequenceClassification.forward; `logits_list` holds the logits from two
# dropout-perturbed forward passes over the same batch.
import torch
from torch.nn import CrossEntropyLoss, MSELoss

for logits in logits_list:
    if labels is not None:
        if self.num_labels == 1:
            # We are doing regression
            loss_fct = MSELoss()
            if loss:
                loss += alpha * loss_fct(logits.view(-1), labels.view(-1))
            else:
                loss = alpha * loss_fct(logits.view(-1), labels.view(-1))
        else:
            loss_fct = CrossEntropyLoss()
            if loss:
                loss += alpha * loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            else:
                loss = alpha * loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

if loss is not None:
    if self.num_labels == 1:
        loss_fct = MSELoss()
        loss += 1.0 * loss_fct(logits_list[0].view(-1), logits_list[-1].view(-1))
    else:
        # Bidirectional KL between the first and last forward pass.
        p = torch.log_softmax(logits_list[0].view(-1, self.num_labels), dim=-1)
        p_tec = torch.softmax(logits_list[0].view(-1, self.num_labels), dim=-1)
        q = torch.log_softmax(logits_list[-1].view(-1, self.num_labels), dim=-1)
        q_tec = torch.softmax(logits_list[-1].view(-1, self.num_labels), dim=-1)

        kl_loss = torch.nn.functional.kl_div(p, q_tec, reduction='none').sum()
        reverse_kl_loss = torch.nn.functional.kl_div(q, p_tec, reduction='none').sum()
        loss += 0.5 * (kl_loss + reverse_kl_loss) / 2.
```
https://github.com/dropreg/R-Drop/blob/084365d0836b643e0743841528ff39ff88113eef/huggingface_transformer_src/src/transformers/models/bert/modeling_bert.py#L1545
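For reference, a minimal self-contained sketch of the recipe this snippet implements: two dropout-perturbed forward passes, a weighted task loss on each, plus a symmetric KL term tying the two predictions together. The function name `rdrop_loss` and the bare `model(input_ids)` call are our assumptions for illustration, not the repository's API:

```python
import torch
import torch.nn.functional as F

def rdrop_loss(model, input_ids, labels, alpha=1.0):
    """Sketch of the objective above, assuming `model` returns classification
    logits and is in train mode (so dropout makes the two passes differ)."""
    logits_a = model(input_ids)
    logits_b = model(input_ids)

    # Task loss on each pass, weighted by alpha as in the snippet above.
    ce = alpha * (F.cross_entropy(logits_a, labels) +
                  F.cross_entropy(logits_b, labels))

    # Bidirectional KL, summed over all elements ('sum' reduction, matching
    # the reduction='none').sum() in the snippet).
    p_log = torch.log_softmax(logits_a, dim=-1)
    q_log = torch.log_softmax(logits_b, dim=-1)
    kl_loss = F.kl_div(p_log, q_log.exp(), reduction='sum')
    reverse_kl_loss = F.kl_div(q_log, p_log.exp(), reduction='sum')
    return ce + 0.5 * (kl_loss + reverse_kl_loss) / 2.0
```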
Using "mean" or "sum" as the reduction may require picking a different hyper-parameter alpha (we use "sum", as reported in the paper). I think both ways are fine.
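To make the scale issue concrete, a small check (shapes are arbitrary, chosen only for illustration) showing that the reductions differ only by a constant factor, which is why alpha has to be retuned when the reduction changes:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, num_labels = 8, 3
log_p = torch.log_softmax(torch.randn(batch, num_labels), dim=-1)
q = torch.softmax(torch.randn(batch, num_labels), dim=-1)

kl_sum = F.kl_div(log_p, q, reduction='sum')
kl_mean = F.kl_div(log_p, q, reduction='mean')            # sum / (batch * num_labels)
kl_batchmean = F.kl_div(log_p, q, reduction='batchmean')  # sum / batch

# The reductions differ only by constant factors, so switching between them
# rescales the KL term; the weight alpha must be rescaled to compensate.
assert torch.allclose(kl_sum, kl_mean * batch * num_labels)
assert torch.allclose(kl_sum, kl_batchmean * batch)
```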