wangyuxinwhy / uniem

unified embedding model
Apache License 2.0
826 stars 64 forks source link

CoSentLoss的一点疑惑, #95

Open NLPJCL opened 1 year ago

NLPJCL commented 1 year ago

🐛 bug 说明

CoSentLoss的输入数据格式是:

ScoredPairRecord 就是带有分数的句对样本,在 PairRecord 的基础上添加了句对的相似分数(程度)。字段的名称是 sentence1 和 sentence2,以及 label。

1.0 代表相似,0.0 代表不相似

scored_pair_record1 = ScoredPairRecord(sentence1='肾结石如何治疗?', sentence2='如何治愈肾结石', label=1.0) scored_pair_record2 = ScoredPairRecord(sentence1='肾结石如何治疗?', sentence2='胆结石有哪些治疗方法?', label=0.0) print(f'scored_pair_record: {scored_pair_record1}') print(f'scored_pair_record: {scored_pair_record2}')

1.CoSentLoss实现的时候,有一个 smaller_mask = true_similarity.unsqueeze(0) <= true_similarity.unsqueeze(1) 是不是写反了,应该是true_similarity.unsqueeze(1) <= true_similarity.unsqueeze(0) batch为4: label:[0,1,1,2] 那么第一个label得到结果就是[true,true,true,true].把label为0的loss全部mask掉。 2.这种情况下,是不是不能打乱句子的顺序了,否则对比的就是不同query和label的score值了,没有啥意义?

Python Version

None

NLPJCL commented 1 year ago

1.明白了,没写反,写的有点乱, 这里突然加了个负号,没注意看。 cosine_similarity_diff = -(predict_similarity.unsqueeze(0) - predict_similarity.unsqueeze(1)) 2.没仔细看,也明白了,其实loss优化的是query1和pos1的余弦相似度大于query2和pos2的余弦相似度。(labe1大于label2的情况下)。至于query1和pos1的余弦相似度拉到多近,由模型自己决定。

GingerNg commented 1 year ago

smaller_mask = true_similarity.unsqueeze(0) <= true_similarity.unsqueeze(1) 不太理解,为什么这里是这里是<=, 而不是<?

wangyuxinwhy commented 1 year ago

= 的目的是将那些相似度相同的 pair,也排除在 loss 之外。