Open yuanshengjun opened 3 years ago
framework.py文件中,存在以下bug:
if los.shape != mask.shape: mask= mask.unsqueeze(dim=-1)
mask= mask.repeat((1, 1, loss.size(2)))
framework.py文件中,存在以下bug:
if los.shape != mask.shape:
mask= mask.unsqueeze(dim=-1)
应该加入该行代码,否则在后续求loss均值时,torch.sum(input_masks), 计算的总数个数不齐全