# Query-generation MLM loss (excerpt; see the location note below).
#
# NOTE(review): the disabled line below would reshape generated_q_hidden to
# (bz, sample_num, -1, hidden) via .view(...) and then keep only sample index 0
# of each group via [:,0,:,:], i.e. one sample per group -- presumably the
# positive; confirm against how the batch is laid out. Because that line is
# commented out, the active code below does NOT slice out positives.
# generated_q_hidden = generated_q_hidden[:,None,:,:].view(bz, self.args.sample_num, -1, p_hidden.size(-1))[:,0,:,:] # [bz, seq_len, 768]
if mlm_labels is not None and mlm_labels['decoder_mlm_labels'] is not None:
    # Query-generation loss, added only when decoder MLM labels are supplied.
    # NOTE(review): no positive/negative mask is applied here. Negatives are
    # excluded only if their decoder_mlm_labels are set upstream to the loss's
    # ignore index (e.g. -100). TODO: confirm in the data collator and in
    # self.mlm_loss.
    mlm_loss += self.mlm_loss(generated_q_hidden, mlm_labels['decoder_mlm_labels'])
Source: `src/modeling`, inside the `if self.training:` branch.
只对正样本（positive）计算 generation loss。
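如何保证只对【正样本】计算 mlm_loss？负样本是如何过滤的？（线索，待确认：上方被注释掉的那行通过 `[:,0,:,:]` 只保留每组第 0 个样本；当前生效的代码没有做这个切片，因此需要确认 `decoder_mlm_labels` 中负样本的标签是否已被置为损失函数的 ignore_index（如 -100）。）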