Closed FanWan closed 3 years ago
def simcse_loss(y_true, y_pred):
    """Compute the in-batch contrastive cross-entropy loss for SimCSE training.

    The incoming ``y_true`` is never read: the targets are rebuilt from the
    row positions of ``y_pred``. Row ``2k`` is labelled with row ``2k + 1``
    as its positive and vice versa; every other row in the batch acts as a
    negative.

    Args:
        y_true: Ignored. Present only so the function matches the
            ``(y_true, y_pred)`` loss signature.
        y_pred: Batch of embeddings, one per row.
            NOTE(review): assumed to be 2-D ``(batch, dim)`` with an even
            batch size in which consecutive rows form the positive pairs --
            confirm against the data generator.

    Returns:
        The mean over rows of the softmax cross-entropy between the rebuilt
        one-hot targets and the scaled similarity matrix.
    """
    batch_size = K.shape(y_pred)[0]

    # Build the labels: the partner of an even index i is i + 1 and the
    # partner of an odd index i is i - 1, so targets[i, j] == 1 exactly
    # when j is the partner of i.
    positions = K.arange(0, batch_size)
    partners = positions + 1 - positions % 2 * 2
    targets = K.cast(
        K.equal(positions[None, :], partners[:, None]),
        K.floatx(),
    )

    # Compute the similarities: rows are L2-normalised first, so the dot
    # product of two rows is their cosine similarity.
    embeddings = K.l2_normalize(y_pred, axis=1)
    logits = K.dot(embeddings, K.transpose(embeddings))
    # Subtract a huge constant on the diagonal so that a row can never
    # select itself in the softmax.
    logits = logits - tf.eye(batch_size) * 1e12
    # Sharpen the softmax; multiplying by 20 equals a temperature of 0.05.
    logits = logits * 20

    per_row_loss = K.categorical_crossentropy(
        targets, logits, from_logits=True
    )
    return K.mean(per_row_loss)
def simcse_loss(y_true, y_pred): """用于SimCSE训练的loss """
构造标签