Deriq-Qian-Dong / III-Retriever

Code for I3 Retriever, accepted by CIKM'23.

about generation loss #3

Open bjtuxck opened 5 months ago

bjtuxck commented 5 months ago

src/modeling code:

    if self.training:
        # Compute the generation loss only for the positive
        # generated_q_hidden = generated_q_hidden[:,None,:,:].view(bz, self.args.sample_num, -1, p_hidden.size(-1))[:,0,:,:]  # [bz, seq_len, 768]
        if mlm_labels is not None and mlm_labels['decoder_mlm_labels'] is not None:
            mlm_loss += self.mlm_loss(generated_q_hidden, mlm_labels['decoder_mlm_labels'])  # query generation loss

How is it ensured that mlm_loss is computed only for the [positive samples]? How are the negative samples filtered out?

Deriq-Qian-Dong commented 3 months ago

The first position of p_hidden is always the positive sample's hidden state, so using only the first one is enough. See p_hidden.size(-1))[:,0,:,:] in the code.
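
For anyone reading along, here is a minimal sketch of what that indexing does. It is not the repository's code. It assumes the hidden states arrive as a flat batch of shape [bz * sample_num, seq_len, dim], with each query's positive passage first and its negatives after it, which is what the reply above suggests. NumPy's reshape stands in for the torch .view call, and select_positive is a hypothetical helper name. Note that in the snippet quoted in the question the corresponding line is commented out, so whether it is applied in the actual training path is not confirmed here.

```python
import numpy as np

def select_positive(hidden, bz, sample_num):
    """Keep only the first passage of each query's group (hypothetical helper).

    hidden: array of shape [bz * sample_num, seq_len, dim]. Assumed layout:
    for every query, the positive passage comes first, then its negatives.
    """
    seq_len, dim = hidden.shape[1], hidden.shape[2]
    # Group the flat batch by query: [bz, sample_num, seq_len, dim]
    grouped = hidden.reshape(bz, sample_num, seq_len, dim)
    # Index 0 along the sample axis is the positive under the assumed layout
    return grouped[:, 0, :, :]  # [bz, seq_len, dim]

bz, sample_num, seq_len, dim = 2, 3, 4, 5
# Fill every passage with its flat index so the selection is easy to inspect
flat_ids = np.arange(bz * sample_num)
hidden = np.broadcast_to(flat_ids[:, None, None], (bz * sample_num, seq_len, dim)).copy()

positives = select_positive(hidden, bz, sample_num)
print(positives.shape)     # (2, 4, 5)
print(positives[:, 0, 0])  # [0 3]
```

With sample_num = 3, flat rows 0 and 3 are kept and rows 1, 2, 4, 5 (the assumed negatives) are dropped, so a loss computed on the result only sees positives.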