fastnlp / fastNLP

fastNLP: A Modularized and Extensible NLP Framework. Currently still in incubation.
https://gitee.com/fastnlp/fastNLP
Apache License 2.0

Getting the prediction from sequence labeling #313

Closed tim6220 closed 3 years ago

tim6220 commented 4 years ago

Is there an API to get the NER predictions in the sequence labeling tutorial (https://fastnlp.readthedocs.io/zh/latest/tutorials/%E5%BA%8F%E5%88%97%E6%A0%87%E6%B3%A8.html), rather than only the accuracy result? Thank you.

yhcc commented 4 years ago

You need to keep the Vocabulary objects used in preprocessing, which are recorded in the data_bundle. An example follows (add this code after the training code; otherwise you need to save and reload both the vocabulary and the model):

import torch

vocab = data_bundle.get_vocab('chars')
target_vocab = data_bundle.get_vocab('target')

chars = list("这是一个测试")  # split the sentence into single characters, one token each
indexed_chars = [vocab.to_index(c) for c in chars]

indexed_chars = torch.LongTensor([indexed_chars])  # batch of one sentence: shape 1 x len(chars)
seq_len = torch.LongTensor([len(chars)])  # usually you need to move indexed_chars and seq_len to the device where the model is located
pred = model.predict(indexed_chars, seq_len)['pred']  # a tensor of shape 1 x len(chars)

pred = [target_vocab.to_word(w) for w in pred[0].tolist()]  # you will get something like ['O', 'O', ...]
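The tag list from the last line is per character. To read off whole entities, the tags can be grouped into spans. Below is a minimal sketch in plain Python, assuming the target vocabulary uses BIO-style tags such as `B-PER` / `I-PER` / `O` (an assumption; check the tags in your own `target_vocab`). `bio_to_spans` is a hypothetical helper written for illustration, not part of fastNLP, and the example sentence and tags are made up.

```python
from typing import List, Optional, Tuple


def bio_to_spans(chars: List[str], tags: List[str]) -> List[Tuple[str, str]]:
    """Group character-level BIO tags into (entity_text, entity_type) pairs.

    Assumes tags look like "B-TYPE", "I-TYPE" or "O" (hypothetical scheme).
    """
    spans: List[Tuple[str, str]] = []
    current_chars: List[str] = []
    current_type: Optional[str] = None

    def flush() -> None:
        # Close the entity that is currently open, if any.
        nonlocal current_chars, current_type
        if current_chars:
            spans.append(("".join(current_chars), current_type))
        current_chars, current_type = [], None

    for ch, tag in zip(chars, tags):
        if tag.startswith("B-"):
            flush()
            current_chars, current_type = [ch], tag[2:]
        elif tag.startswith("I-") and current_type == tag[2:]:
            current_chars.append(ch)
        else:
            # "O", or an "I-" tag that does not continue the open entity:
            # close the open entity and skip this character.
            flush()
    flush()
    return spans


# Made-up example input; in practice `chars` and `tags` would be the
# character list and the decoded `pred` list from the snippet above.
example = bio_to_spans(
    list("小明在北京"),
    ["B-PER", "I-PER", "O", "B-LOC", "I-LOC"],
)
print(example)  # [('小明', 'PER'), ('北京', 'LOC')]
```

An `I-` tag that does not follow a matching `B-` tag is dropped here; whether to drop it or start a new entity is a design choice that depends on how noisy the model's predictions are.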