lonePatient / BERT-NER-Pytorch

Chinese NER(Named Entity Recognition) using BERT(Softmax, CRF, Span)
MIT License
2.05k stars 425 forks source link

总觉 在bert span中start位置提供的信息结合倒end的方式有问题。 #68

Open orangetwo opened 2 years ago

orangetwo commented 2 years ago

如果 start_positions (假设batch size=1, 忽略batch这个纬度) 为 tensor([0, 0, 0, 2, 0, 8, 0, 0, 0, 0]) end_positions 为 tensor([0, 0, 0, 0, 2, 0, 0, 8, 0, 0]) 那么 label_logits = tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0.], [1., 0., 0., 0., 0., 0., 0., 0., 0.], [1., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 1., 0., 0., 0., 0., 0., 0.], [1., 0., 0., 0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 0., 0., 0., 0., 1.], [1., 0., 0., 0., 0., 0., 0., 0., 0.], [1., 0., 0., 0., 0., 0., 0., 0., 0.], [1., 0., 0., 0., 0., 0., 0., 0., 0.], [1., 0., 0., 0., 0., 0., 0., 0., 0.]])

代码里self.end_fc(sequence_output, label_logits) 会对end 和 label_logits进行拼接 这里我觉得对 end端输入 start部分的信息 没什么用,因为 对于 end中的第5个token id 即2 它要拼接的对象是 label_logits中第5行 在我的理解中 因为它是end 他应该拼接的是 start中 token_id =2的信息 即应该拼接为label_logits中的第4行 如果 模型是序列的 还会有可能把start的信息 传递给 end,但是bert后的各个token的fc是独立计算的 所以 感觉这里并没有把start信息很好的传递给end的

WinnieRerverse commented 2 years ago

我的理解,self.end_fc(sequence_output, label_logits)这个是BERT的隐层输出和 label_logits拼接,就是想计算end_logits 除了sequence_output信息外能融合start_positions(label_logits)的信息,而计算start_logits只有sequence_output信息,但是我觉得这样直接拼接效果不好,一是隐层的维度比label的维度大的多,二是label_logits都是0或1,比较稀疏,和隐层embeding的数值不太一样