DRSY / EMO

[ICLR 2024] EMO: Earth Mover Distance Optimization for Auto-Regressive Language Modeling (https://arxiv.org/abs/2310.04691)

Distributed multi-node training: the loss becomes negative after 300 training steps #4

Open jiaruipeng1994 opened 10 months ago

jiaruipeng1994 commented 10 months ago

I integrated EMO into my own code and train in bf16. During training the loss becomes negative. What could be causing this?

loss: -2.334918e-02, loss_cur_dp: -2.334918e-02

DRSY commented 10 months ago

Hi, could you tell me which part of the repository's code you are using?

jiaruipeng1994 commented 10 months ago

```python
# mask and mle_loss are defined earlier in the surrounding method (not shown).

# ======================================================================== #
#                   Compute the EMO loss
# ======================================================================== #
labels_tmp = labels.clone()
labels_tmp[labels_tmp==(-100)] = 0
one_hot = torch.nn.functional.one_hot(labels_tmp, num_classes=self.config.vocab_size).to(logits.dtype)
stable_onehot = (one_hot+1e-15) / torch.linalg.vector_norm((one_hot+1e-15), ord=1, dim=-1, keepdim=True) # (bsz*seq_len, vocab_size)
embedding_matrix = self.lm_head.weight.data.detach() # (vocab_size, hidden_size)
embedding_matrix = embedding_matrix / torch.linalg.vector_norm(embedding_matrix, ord=2, dim=1, keepdim=True)
p_contextual_repr = stable_onehot @ embedding_matrix # (bsz*seq_len, hidden_size)
q_grad = torch.log_softmax(logits, dim=-1).exp() # (bsz*seq_len, vocab_size)
q_contextual_repr = q_grad @ embedding_matrix # (bsz*seq_len, hidden_size)
threshold = (1 - torch.sum(p_contextual_repr*p_contextual_repr, dim=-1))
emo_loss = (1 - torch.sum(p_contextual_repr*q_contextual_repr, dim=-1)) # (bsz*seq_len,)
emo_loss = (torch.abs(emo_loss-threshold)+threshold) * mask

# ======================================================================== #
#                   Compose the final loss
# ======================================================================== #
loss = (emo_loss / (mle_loss+1e-10)).detach() * mle_loss + emo_loss
loss = (loss * mask).sum() / (1e-15 + mask.sum())
```
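A minimal sketch of one possible explanation, offered as an assumption rather than a confirmed diagnosis. In exact arithmetic the per-token term in the snippet is never below `threshold = 1 - ||p||^2`, and `threshold` is at least 0 because `p_contextual_repr` is a convex combination of L2-normalized embedding rows. The composed loss has a value of roughly 2 × `emo_loss` (the detached ratio times `mle_loss` equals `emo_loss` up to the 1e-10 epsilon), so a negative total implies the per-token EMO term itself dropped below zero. Under bf16, a normalized row's squared norm can round to slightly above 1, which makes `threshold` negative; for a confident prediction (q ≈ p) the term then collapses to that negative `threshold`. The pure-Python emulation below rounds only the normalized embedding rows to bf16 (via the hypothetical helper `to_bf16`) and ignores the 1e-15 one-hot smoothing, the mask, and all other bf16 operations; `hidden_size = 64` and the random vectors are illustrative.

```python
import math
import random
import struct


def to_bf16(x: float) -> float:
    """Round x to the nearest bfloat16 value (ties to even).

    bfloat16 keeps the upper 16 bits of an IEEE-754 float32: sign, 8 exponent
    bits and 7 mantissa bits. NaN/inf handling is omitted for brevity.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)        # round to nearest, ties to even
    bits = ((bits >> 16) << 16) & 0xFFFFFFFF   # drop the low 16 bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def normalize(v):
    """L2-normalize a vector, mirroring embedding_matrix / vector_norm(..., ord=2)."""
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def emo_term(p, q):
    """Per-token EMO term as written in the snippet (mask omitted).

    p: target representation (normalized embedding row of the label token)
    q: expected representation under the model distribution
    """
    threshold = 1.0 - dot(p, p)
    emo = 1.0 - dot(p, q)
    return abs(emo - threshold) + threshold  # bounded below by threshold, not by 0


rng = random.Random(0)
hidden_size = 64  # hypothetical; real models use much larger hidden sizes
fp64_terms, bf16_terms = [], []
for _ in range(200):
    e = normalize([rng.gauss(0.0, 1.0) for _ in range(hidden_size)])
    e_bf16 = [to_bf16(x) for x in e]  # the same normalized row, stored in bf16
    # Limiting case of a confident model: all probability on the target,
    # so q_contextual_repr equals p_contextual_repr.
    fp64_terms.append(emo_term(e, e))
    bf16_terms.append(emo_term(e_bf16, e_bf16))

print(f"min EMO term, float64 rows: {min(fp64_terms):.2e}")
print(f"min EMO term, bf16 rows:    {min(bf16_terms):.2e}")
```

If rounding is indeed the cause, two untested options would be to compute the EMO block in float32 (for example calling `.float()` on `logits` and `embedding_matrix` before the normalization), which shrinks the rounding error by several orders of magnitude, or to clamp the per-token term with `torch.clamp(emo_loss, min=0)`. Neither is necessarily what the maintainer changed in the updated code.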
DRSY commented 10 months ago

The relevant code has been updated. Could you try again?