BrikerMan / Kashgari

Kashgari is a production-level NLP transfer learning framework built on top of tf.keras for text labeling and text classification, and it includes Word2Vec, BERT, and GPT2 language embeddings.
http://kashgari.readthedocs.io/
Apache License 2.0

Problem calling a NER model served with tf_serving #480

Closed: kehlaaa closed this issue 3 years ago

kehlaaa commented 3 years ago

You must follow the issue template and provide as much information as possible; otherwise, this issue will be closed. Issues that do not follow the template will be ignored and closed.

Check List

Thanks for considering opening an issue. Before you submit it, please confirm that all of these boxes are checked.

You can post pictures, but if specific text or code is required to reproduce the issue, please provide the text in a plain text format for easy copy/paste.

Environment

[Paste requirements.txt file here]

Question

Hello, following the method described in the v2.0.1 documentation, I deployed a NER model built with BERT + CRF using tf_serving in Docker, and I call it like this:

import requests
import numpy as np
from kashgari.processors import load_processors_from_model

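# Load the text and label processors saved alongside the exported model
text_processor, label_processor = load_processors_from_model('/Users/brikerman/Desktop/tf-serving/1603683152')

samples = [
    ['hello', 'world'],
    ['你', '好', '世', '界']
]
# Convert tokens into padded index sequences
tensor = text_processor.transform(samples)

# One request instance per sample; the BERT embedding also takes a segment input (all zeros here)
instances = [{
   "Input-Token": i.tolist(),
   "Input-Segment": np.zeros(i.shape).tolist()
} for i in tensor]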

# Send the batch to the TF Serving REST predict endpoint
r = requests.post("http://localhost:8501/v1/models/bgru:predict", json={"instances": instances})
predictions = r.json()['predictions']
# Convert result back to labels
labels = label_processor.inverse_transform(np.array(predictions).argmax(-1))
print(labels)

At the last step, when converting the output back to labels, I got the following error:

Traceback (most recent call last):
  File "ner_pred.py", line 168, in <module>
    labels = label_processor.inverse_transform(np.array(predictions).argmax(-1))
  File "/usr/local/miniconda3/lib/python3.7/site-packages/kashgari/processors/sequence_processor.py", line 154, in inverse_transform
    for index, seq in enumerate(labels):
TypeError: 'numpy.int64' object is not iterable

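I have only recently started working with deep learning, so this is probably a basic question. I'd appreciate an explanation, thanks.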
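The traceback suggests that `np.array(predictions)` has one dimension fewer than `argmax(-1)` expects, so each sequence collapses to a single `numpy.int64` before reaching `inverse_transform`. A minimal diagnostic sketch, assuming (not confirmed by the Kashgari docs) that a CRF head served through TF Serving may return already-decoded tag ids shaped [batch, seq_len], while a softmax head returns per-token label scores shaped [batch, seq_len, num_labels]. `predictions_to_label_ids` and the toy outputs are hypothetical, not part of Kashgari.

```python
import numpy as np


def predictions_to_label_ids(predictions):
    """Turn a TF Serving `predictions` list into a 2-D array of label ids.

    Assumption: per-token scores are shaped [batch, seq_len, num_labels];
    already-decoded CRF tag ids are shaped [batch, seq_len].
    """
    arr = np.asarray(predictions)
    if arr.ndim == 3:
        # Per-token scores: pick the highest-scoring label for each token.
        return arr.argmax(-1)
    if arr.ndim == 2:
        # Already one label id per token. Calling argmax(-1) here would
        # reduce each sequence to a single int, which is consistent with
        # the "'numpy.int64' object is not iterable" error above.
        return arr.astype(int)
    raise ValueError(f"unexpected predictions shape {arr.shape}")


# Hypothetical toy outputs: 2 samples x 4 tokens (x 3 labels for scores).
softmax_style = np.random.default_rng(0).random((2, 4, 3)).tolist()
crf_style = [[0, 1, 2, 0], [0, 2, 2, 1]]

print(predictions_to_label_ids(softmax_style).shape)  # (2, 4)
print(predictions_to_label_ids(crf_style).shape)      # (2, 4)
```

With a real response, the result would then go to `label_processor.inverse_transform(predictions_to_label_ids(predictions))`. Printing `np.array(predictions).shape` first is the quickest way to see which case applies to a given served model.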