bojone / SimCSE

A simple experiment with SimCSE on Chinese tasks

How should the loss function be understood? #5

Closed FanWan closed 3 years ago

FanWan commented 3 years ago

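def simcse_loss(y_true, y_pred):
    """Loss used for SimCSE training."""
    # Build the labels: row i is paired with row i+1 if i is even, i-1 if i is odd
    idxs = K.arange(0, K.shape(y_pred)[0])
    idxs_1 = idxs[None, :]
    idxs_2 = (idxs + 1 - idxs % 2 * 2)[:, None]
    y_true = K.equal(idxs_1, idxs_2)
    y_true = K.cast(y_true, K.floatx())
    # Compute cosine similarities between all rows of the batch
    y_pred = K.l2_normalize(y_pred, axis=1)
    similarities = K.dot(y_pred, K.transpose(y_pred))
    # Push the diagonal to a large negative value so a row never matches itself
    similarities = similarities - tf.eye(K.shape(y_pred)[0]) * 1e12
    # Scale by 20, which is the same as dividing by a temperature of 0.05
    similarities = similarities * 20
    loss = K.categorical_crossentropy(y_true, similarities, from_logits=True)
    return K.mean(loss)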
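One way to see what the function does is to trace it on a tiny batch. Below is a minimal NumPy sketch of the same steps, meant for inspection rather than training. The name `simcse_loss_np`, the `scale` parameter and the toy inputs are hypothetical and not part of the repository. It also assumes, as the label construction suggests, that rows 2k and 2k+1 of the batch are two encodings of the same sentence.

```python
import numpy as np

def simcse_loss_np(y_pred, scale=20.0):
    """NumPy mirror of simcse_loss above (hypothetical helper, for inspection only)."""
    n = y_pred.shape[0]
    idxs = np.arange(n)
    # Partner index: even i -> i + 1, odd i -> i - 1
    partner = idxs + 1 - idxs % 2 * 2
    # One-hot label matrix: y_true[i, j] == 1 exactly when j is the partner of i
    y_true = (idxs[None, :] == partner[:, None]).astype(np.float64)
    # Cosine similarity = dot product of L2-normalised rows
    z = y_pred / np.linalg.norm(y_pred, axis=1, keepdims=True)
    sims = z @ z.T
    # Mask the diagonal so a row cannot pick itself, then apply the scale
    sims = (sims - np.eye(n) * 1e12) * scale
    # Row-wise softmax cross-entropy, computed stably via log-sum-exp
    sims = sims - sims.max(axis=1, keepdims=True)
    log_probs = sims - np.log(np.exp(sims).sum(axis=1, keepdims=True))
    loss = -(y_true * log_probs).sum(axis=1)
    return partner, y_true, loss.mean()

# Toy batch (assumed layout): rows 0/1 and rows 2/3 are positive pairs
rng = np.random.default_rng(0)
partner, y_true, loss = simcse_loss_np(rng.normal(size=(4, 8)))
print(partner)  # [1 0 3 2]
print(y_true)
print(loss)

# Perfectly aligned pairs that are orthogonal to each other give a loss near zero
aligned = np.array([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])
_, _, aligned_loss = simcse_loss_np(aligned)
print(aligned_loss)
```

Read this way, each row is a classification problem over the other rows in the batch: the correct class is the row's partner, every other row acts as a negative, and the `y_true` argument passed in by Keras is ignored because the labels are rebuilt from the batch positions.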