THUDM / WebGLM

WebGLM: An Efficient Web-enhanced Question Answering System (KDD 2023)
Apache License 2.0
1.57k stars 135 forks source link

关于train_retriever.py中的loss #49

Open llllooong opened 1 year ago

llllooong commented 1 year ago

麻烦问一下train_retriever.py文件中第44行求loss的函数中,cross_entropy的训练target为什么是是torch.arange(0, len(l_pos)呀? image

Longin-Yu commented 1 year ago
  1. 每一条训练数据包含一条强关联(作为 positive sample)与弱关联(作为 hard negative sample)。
  2. 训练过程中,若 batchsize 为 $n$,则同一个 batch 内将包含 $n$ 条 positive sample 和 $n$ 条 hard negative sample,对于每一条数据而言,只有它的 positive sample 是正例,其余 $2n - 1$ 条全都是负例。
  3. 将这 2n 条数据按 $(pos_1, \cdots, pos_n, neg_1, \cdots, neg_n)$ 的方式拼接起来后,第 $i$ 条数据的正样本 index 即为 $i$。
llllooong commented 1 year ago

那是不是得确保一个batch内,尽量少有相似问题?