Open · stay-leave opened this issue 1 year ago
Reference function: `dev(data_loader, vocab, model, device, mode='dev')` in `./model.py`
import torch

# Evaluation loop adapted from dev(data_loader, vocab, model, device, mode='dev')
# in ./model.py. For every batch it collects:
#   sent_data  - the words of each sentence (masked-out positions dropped),
#   true_tags  - the gold tag names, trimmed to each sentence's length,
#   pred_tags  - the tag names decoded by the model's CRF layer,
# and finally prints the predicted tags once.
# NOTE(review): the referenced dev() signature names its loader `data_loader`,
# while this snippet iterates `dev_loader`; use whichever name exists in your
# scope (inside dev() it would be `data_loader`).
model.eval()
true_tags = []
pred_tags = []
sent_data = []
# Decoding needs no gradients; no_grad avoids building an autograd graph for
# every batch and keeps memory use flat during evaluation.
with torch.no_grad():
    for batch_samples in dev_loader:
        sentences, labels, masks, lens = batch_samples
        # Map token ids back to words, keeping only positions whose mask is > 0.
        sent_data.extend(
            [
                [vocab.id2word.get(tok.item()) for i, tok in enumerate(indices) if mask[i] > 0]
                for mask, indices in zip(masks, sentences)
            ]
        )
        sentences = sentences.to(device)
        masks = masks.to(device)
        # Call the module rather than .forward() so any registered hooks run.
        y_pred = model(sentences)
        labels_pred = model.crf.decode(y_pred, mask=masks)
        # Gold labels are only read on the CPU (as numpy), so they are never
        # moved to `device`; each row is trimmed to its true length.
        targets = [itag[:ilen] for itag, ilen in zip(labels.cpu().numpy(), lens)]
        true_tags.extend([[vocab.id2label.get(t) for t in indices] for indices in targets])
        pred_tags.extend([[vocab.id2label.get(t) for t in indices] for indices in labels_pred])
# Print once, after all batches, rather than re-printing the growing list per batch.
print(f'预测标签:{pred_tags }')
感谢大佬!!!立刻尝试💪
在 2024年5月10日,11:10,ZhouYaFei @.***> 写道:
参考函数 ./model.py/dev(data_loader, vocab, model, device, mode='dev')
参考代码
model.eval()
true_tags = []
pred_tags = []
sent_data = []
import torch

# Quoted copy of the evaluation loop from dev() in ./model.py, restored from its
# collapsed single-line form. It accumulates words, gold tag names and
# CRF-decoded tag names per sentence, then prints the predicted tags once.
# NOTE(review): the referenced dev() signature names its loader `data_loader`;
# this snippet iterates `dev_loader`. Use whichever name exists in your scope.
# Decoding needs no gradients; no_grad avoids building an autograd graph per batch.
with torch.no_grad():
    for batch_samples in dev_loader:
        sentences, labels, masks, lens = batch_samples
        # Map token ids back to words, keeping only positions whose mask is > 0.
        sent_data.extend(
            [
                [vocab.id2word.get(tok.item()) for i, tok in enumerate(indices) if mask[i] > 0]
                for mask, indices in zip(masks, sentences)
            ]
        )
        sentences = sentences.to(device)
        masks = masks.to(device)
        # Call the module rather than .forward() so any registered hooks run.
        y_pred = model(sentences)
        labels_pred = model.crf.decode(y_pred, mask=masks)
        # Gold labels are only read on the CPU, so they are never moved to `device`.
        targets = [itag[:ilen] for itag, ilen in zip(labels.cpu().numpy(), lens)]
        true_tags.extend([[vocab.id2label.get(t) for t in indices] for indices in targets])
        pred_tags.extend([[vocab.id2label.get(t) for t in indices] for indices in labels_pred])
print(f'预测标签:{pred_tags }')
— Reply to this email directly, view it on GitHub, or unsubscribe. You are receiving this because you commented. Message ID: @.***>
同问