fastnlp / fastNLP

fastNLP: A Modularized and Extensible NLP Framework. Currently still in incubation.
https://gitee.com/fastnlp/fastNLP
Apache License 2.0
3.06k stars 450 forks source link

您好,请问怎么替换掉CRF层呢 #298

Closed tyistyler closed 4 years ago

tyistyler commented 4 years ago

我在使用fastnlp框架进行Bert+Bi-LSTM+CRF实验时,想要使用交叉熵直接代替CRF,但是每次迭代都会有一个预测的过程(crf对应为target is None-->self.crf.viterbi_decode(logits, mask)),请问这里可以怎么处理呀,十分感谢。

yhcc commented 4 years ago

也许再判断一下是否有crf,大概这样

if target is None:
  if hasattr(self, 'crf'):
    pred = self.crf.viterbi_decode(logits, mask))
  else:
    pred = logits.argmax(dim=-1)
tyistyler commented 4 years ago

非常感谢

tyistyler commented 4 years ago

还有一个小问题想再跟您请教一下,我在修改去掉CRF的时候,直接使用定义好的交叉熵来预测NER标签,但是结果很差,请问在fastnlp中我可以这样训练么? train.py-- loss = CrossEntropyLoss(pred='output',target='target',class_in_dim=2)

model.py---

(model.forward) chars, (hn, cn) = self.lstm(chars) chars = self.fc_dropout(chars) chars = self.out_fc(chars)#model_dim-->len(tag_vocab) logits = F.logsoftmax(chars, dim=-1) if target is None:#prediction ---#paths, = self.crf.viterbi_decode(logits, mask) ---#return {'pred': paths} ---#pred = torch.argmax(logits, axis=2) ---pred = logits.argmax(dim=-1) ---return {'pred': pred} else:#train ---#loss = self.crf(logits, target, mask) --- #return {'loss': loss} ---return {'output': logits}

yhcc commented 4 years ago

效果差的原因是由于直接argmax解码会有一个问题,就是非法越迁。例如出现B-PER,而下一个tag是I-LOC这种问题。可以通过一些规则来解决,或者使用维特比解码来搞定,

from fastNLP.modules import allowed_transitions, viterbi_decode

# 这个tag_vocab就是你的target的vocab,这部分代码你放在初始化
trans = allowed_transitions(tag_vocab, include_start_end=False)
        constrain = torch.full((len(tag_vocab), len(tag_vocab)), fill_value=-10000.0, dtype=torch.float)
for from_tag_id, to_tag_id in trans:
        constrain[from_tag_id, to_tag_id] = 0

# 预测的时候用这个预测, mask是句子的mask,为0的地方为pad
pred = viterbi_decode(pred, constrain.to(pred), mask=mask)[0]

这样就能排除非法越迁,通过这种方式,一般来说可以得到比较接近CRF的性能。但有一个缺点就是,由于用了维特比,解码速度比argmax慢了。