longlongman / CasRel-pytorch-reimplement

PyTorch reimplementation of the paper "A Novel Cascade Binary Tagging Framework for Relational Triple Extraction" (ACL 2020). The original code is written in Keras.

Bug in the loss computation #15

Open yuanshengjun opened 3 years ago

yuanshengjun commented 3 years ago

framework.py contains the following bug:

if los.shape != mask.shape:
    mask = mask.unsqueeze(dim=-1)

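The line below should be added after the unsqueeze. Otherwise, when the mean loss is computed later, torch.sum(input_masks) undercounts the number of elements: the unsqueezed mask is broadcast across the last dimension of the loss when the two are multiplied, but its own sum still counts each position only once.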

mask = mask.repeat((1, 1, loss.size(2)))
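A minimal sketch of the effect described above, assuming the element-wise loss has shape (batch, seq_len, num_rels) and the mask has shape (batch, seq_len). The function name masked_bce_mean, the repeat_mask flag, and the tensor sizes are hypothetical and chosen for illustration only; this is not the code from framework.py.

```python
import torch
import torch.nn.functional as F


def masked_bce_mean(pred, gold, mask, repeat_mask):
    """Masked mean of element-wise BCE (hypothetical helper).

    pred, gold: (batch, seq_len, num_rels); mask: (batch, seq_len).
    """
    los = F.binary_cross_entropy(pred, gold, reduction="none")
    if los.shape != mask.shape:
        mask = mask.unsqueeze(dim=-1)  # (batch, seq_len, 1)
        if repeat_mask:
            # Proposed fix: give the mask the same shape as the loss,
            # so that torch.sum(mask) counts every loss element.
            mask = mask.repeat((1, 1, los.size(2)))
    # The numerator is the same either way, because broadcasting
    # stretches a (batch, seq_len, 1) mask over the last dimension.
    # Only the denominator changes.
    return torch.sum(los * mask) / torch.sum(mask)


torch.manual_seed(0)
batch, seq_len, num_rels = 2, 5, 4  # assumed sizes
pred = torch.rand(batch, seq_len, num_rels).clamp(0.01, 0.99)
gold = torch.randint(0, 2, (batch, seq_len, num_rels)).float()
mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]], dtype=torch.float)

without_repeat = masked_bce_mean(pred, gold, mask, repeat_mask=False)
with_repeat = masked_bce_mean(pred, gold, mask, repeat_mask=True)

# Without the repeat, the result is num_rels times larger than the true mean.
print(without_repeat.item(), with_repeat.item())
```

Under these assumptions the two results differ by exactly a factor of num_rels, so the missing repeat rescales this loss term rather than changing which elements contribute to it. mask.expand_as(los) would give the same denominator without copying memory.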