bojone / bert4keras

keras implement of transformers for humans
https://kexue.fm/archives/6915
Apache License 2.0
5.36k stars 927 forks source link

GPlink代码在训练过程中引发nan的原因及可能的解决方案 #436

Open bucm-tcm-tool opened 2 years ago

bucm-tcm-tool commented 2 years ago

GPlink的这种设计思路非常棒,很有启发。 问题 但是GPlink的代码在训练过程中,有可能出现loss突然变为nan(inf)的情况。特别是小样本训练中(如2000个句子,每个句子平均包含3个三元组的数据集),loss经常出现nan。 原因 一个直观的原因是:在sparse_multilabel_categorical_crossentropy函数中,某些样本被反复学习多个epoches后,基于这些样本计算的K.exp(aux_loss - all_loss)=1,从而使K.log(1 - K.exp(aux_loss - all_loss))-->inf,致使loss为nan。 解决 对于该问题,或许可以在引发nan的代码中,加一个epslon,以避免出现inf,例如:K.log(1 - K.exp(aux_loss - all_loss)+epslon)。

以下是原始的sparse_multilabel_categorical_crossentropy函数,以及该函数引发nan的关键代码(加粗的部分):

def sparse_multilabel_categorical_crossentropy(y_true, y_pred, mask_zero=False): """稀疏版多标签分类的交叉熵 说明:

  1. y_true.shape=[..., num_positive], y_pred.shape=[..., num_classes];
  2. 请保证y_pred的值域是全体实数,换言之一般情况下 y_pred不用加激活函数,尤其是不能加sigmoid或者 softmax;
  3. 预测阶段则输出y_pred大于0的类;
  4. 详情请看:https://kexue.fm/archives/7359 。 """ zeros = K.zeros_like(y_pred[..., :1]) y_pred = K.concatenate([y_pred, zeros], axis=-1) if mask_zero: infs = zeros + K.infinity() y_pred = K.concatenate([infs, y_pred[..., 1:]], axis=-1) y_pos_2 = batch_gather(y_pred, y_true) y_pos_1 = K.concatenate([y_pos_2, zeros], axis=-1) if mask_zero: y_pred = K.concatenate([-infs, y_pred[..., 1:]], axis=-1) y_pos_2 = batch_gather(y_pred, y_true) pos_loss = K.logsumexp(-y_pos_1, axis=-1) aux_loss = K.logsumexp(y_pos_2, axis=-1) all_loss = K.logsumexp(y_pred, axis=-1) neg_loss = all_loss + K.log(1 - K.exp(aux_loss - all_loss))
bojone commented 2 years ago

谢谢提议,已加 https://github.com/bojone/bert4keras/commit/2e1bc3495789cdaa682c014b31e8425a0efeaac3#diff-fbe4e655d7da485d9313ca966df4bf963c294e0716fb6693e2a3af51a3afb7d7