hemingkx / CLUENER2020

A PyTorch implementation of BiLSTM/BERT/RoBERTa (+CRF) models for Named Entity Recognition.
470 stars 107 forks

How do I use the trained model to make predictions? #16

Open stay-leave opened 1 year ago

sociem commented 6 months ago

Same question here.

At-Leisure commented 5 months ago

Reference function:

./model.py/dev(data_loader, vocab, model, device, mode='dev')

Reference code:

import torch

model.eval()  # switch off dropout for inference
true_tags = []
pred_tags = []
sent_data = []

with torch.no_grad():  # gradients are not needed for prediction
    for batch_samples in dev_loader:
        sentences, labels, masks, lens = batch_samples
        # Map token ids back to characters, skipping padded positions
        sent_data.extend([[vocab.id2word.get(idx.item()) for i, idx in enumerate(indices) if mask[i] > 0]
                          for (mask, indices) in zip(masks, sentences)])
        sentences = sentences.to(device)
        labels = labels.to(device)
        masks = masks.to(device)
        y_pred = model(sentences)  # per-token tag scores
        labels_pred = model.crf.decode(y_pred, mask=masks)  # best tag-id sequence per sentence
        # Trim the gold labels to each sentence's real length
        targets = [itag[:ilen] for itag, ilen in zip(labels.cpu().numpy(), lens)]
        true_tags.extend([[vocab.id2label.get(idx) for idx in indices] for indices in targets])
        pred_tags.extend([[vocab.id2label.get(idx) for idx in indices] for indices in labels_pred])

print(f'Predicted tags: {pred_tags}')
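
The loop above prints raw tag sequences. To read them as entities, something like the following might help. It is only a sketch and not part of the repository: the helper name tags_to_entities is made up for illustration, and it assumes the labels follow a BIOS scheme (B-type, I-type, S-type, O), which should be checked against the label set your model was trained with.

```python
from typing import List, Optional, Tuple


def tags_to_entities(tags: List[Optional[str]]) -> List[Tuple[str, int, int]]:
    """Turn one sentence's tags into (entity_type, start, end) spans, end inclusive.

    Hypothetical helper; assumes BIOS-style tags such as 'B-name', 'I-name',
    'S-name' and 'O'. A missing tag (None) is treated as 'O'.
    """
    entities = []
    start, current = None, None  # start index and type of the span being built
    for i, tag in enumerate(tags):
        prefix, _, etype = (tag or 'O').partition('-')
        if prefix == 'I' and current == etype:
            continue  # still inside the open span
        # Any other tag closes the open span, if there is one
        if current is not None:
            entities.append((current, start, i - 1))
            start, current = None, None
        if prefix == 'S':
            entities.append((etype, i, i))  # single-character entity
        elif prefix == 'B':
            start, current = i, etype  # open a new span
        # 'O' and stray 'I-' tags open nothing
    if current is not None:
        entities.append((current, start, len(tags) - 1))
    return entities


# Made-up sample data, shaped like one entry of sent_data / pred_tags
chars = ['张', '三', '是', '王', '的', '同', '事']
tags = ['B-name', 'I-name', 'O', 'S-name', 'O', 'O', 'O']
for etype, s, e in tags_to_entities(tags):
    print(etype, ''.join(chars[s:e + 1]))
# prints:
# name 张三
# name 王
```

With the variables from the loop above, the same call can be applied per sentence, e.g. for sent, tags in zip(sent_data, pred_tags).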

sociem commented 4 months ago

Thanks a lot!!! I'll try it right away 💪