FlagOpen / FlagEmbedding

Retrieval and Retrieval-augmented LLMs
MIT License
7.39k stars 534 forks source link

关于微调的问题 #145

Open Alemax067 opened 1 year ago

Alemax067 commented 1 year ago

看源码微调的数据处理是在‘pos’中choice一个,然后把‘neg’接在后面构成passages:

    passages = []
    pos = random.choice(self.dataset[item]['pos'])
    passages.append(pos)

    if len(self.dataset[item]['neg']) < self.args.train_group_size - 1:
        num = math.ceil((self.args.train_group_size - 1) / len(self.dataset[item]['neg']))
        negs = random.sample(self.dataset[item]['neg'] * num, self.args.train_group_size - 1)
    else:
        negs = random.sample(self.dataset[item]['neg'], self.args.train_group_size - 1)
    passages.extend(negs)

计算loss的过程中passege也是一个整体:

    q_reps = self.encode(query)
    p_reps = self.encode(passage)

    if self.training:
        if self.negatives_cross_device:
            q_reps = self._dist_gather_tensor(q_reps)
            p_reps = self._dist_gather_tensor(p_reps)

        scores = self.compute_similarity(q_reps, p_reps)
        scores = scores / self.temperature
        scores = scores.view(q_reps.size(0), -1)

        target = torch.arange(scores.size(0), device=scores.device, dtype=torch.long)
        target = target * (p_reps.size(0) // q_reps.size(0))
        loss = self.compute_loss(scores, target)

请问模型是怎么区分pos和neg的,或者说pos和neg怎么在微调中发挥作用的?

staoxiao commented 1 year ago

举个例子,假如每个样本取了7个neg,target就为[0, 8, 16, ...], 即第一个query的pos的索引为0, 第二个query的索引为8(前8个passage中是第一个query的pos和7个neg)

Alemax067 commented 1 year ago

举个例子,假如每个样本取了7个neg,target就为[0, 8, 16, ...], 即第一个query的pos的索引为0, 第二个query的索引为8(前8个passage中是第一个query的pos和7个neg)

懂了,谢谢!

sevenandseven commented 5 months ago

举个例子,假如每个样本取了7个neg,target就为[0, 8, 16, ...], 即第一个query的pos的索引为0, 第二个query的索引为8(前8个passage中是第一个query的pos和7个neg)

懂了,谢谢!

你好,还不是特别明白;假如说是target里边有8个值,那么第一个应该就是query对应的pos的索引,第二个是query对应的neg的索引,第三个也是query对应的neg的索引,依次向下,然后pos的索引一定是0吗?

请问是这样理解的吗?

Alemax067 commented 5 months ago

举个例子,假如每个样本取了7个neg,target就为[0, 8, 16, ...], 即第一个query的pos的索引为0, 第二个query的索引为8(前8个passage中是第一个query的pos和7个neg)

懂了,谢谢!

你好,还不是特别明白;假如说是target里边有8个值,那么第一个应该就是query对应的pos的索引,第二个是query对应的neg的索引,第三个也是query对应的neg的索引,依次向下,然后pos的索引一定是0吗?

请问是这样理解的吗?

不是哦,target中的每个索引都是pos在passage中的索引,比如说batchsize为8,passage_num也为8,意思就是每条数据1个pos加7个neg,八条数据的passage拼在一起所有pos的索引就是[0, 8, 16, 24, 32, 40, 48, 56]这样子,这好像跟loss函数有关,我记得他源码里的loss函数就要求target是这种构造方式

sevenandseven commented 5 months ago

举个例子,假如每个样本取了7个neg,target就为[0, 8, 16, ...], 即第一个query的pos的索引为0, 第二个query的索引为8(前8个passage中是第一个query的pos和7个neg)

懂了,谢谢!

你好,还不是特别明白;假如说是target里边有8个值,那么第一个应该就是query对应的pos的索引,第二个是query对应的neg的索引,第三个也是query对应的neg的索引,依次向下,然后pos的索引一定是0吗? 请问是这样理解的吗?

不是哦,target中的每个索引都是pos在passage中的索引,比如说batchsize为8,passage_num也为8,意思就是每条数据1个pos加7个neg,八条数据的passage拼在一起所有pos的索引就是[0, 8, 16, 24, 32, 40, 48, 56]这样子,这好像跟loss函数有关,我记得他源码里的loss函数就要求target是这种构造方式

明白了,感谢回复。