onnx / models

A collection of pre-trained, state-of-the-art models in the ONNX format
http://onnx.ai/models/
Apache License 2.0

run_onnx_squad.py fails with "Model requires 4 inputs. Input Feed contains 3" #216

Open WilliamTambellini opened 5 years ago

WilliamTambellini commented 5 years ago

run_onnx_squad.py from https://github.com/onnx/models/tree/master/text/machine_comprehension/bert-squad fails with exception: "Model requires 4 inputs. Input Feed contains 3"

Steps to repro: clone that repo, then fetch the model archive:

```
git lfs fetch --include=text/machine_comprehension/bert-squad/model/bertsquad10.onnx.tar.gz
```

wait for download

extract the real ONNX file from the local git cache (couldn't find a better way):

```
$ file ~/repos/models/.git/lfs/objects/1c/ec/1cec14b36fac3e09b2ea54b8de297e12bafabd0fb9d123ad10b6d45459a835a6
/home/wtambellini/repos/models/.git/lfs/objects/1c/ec/1cec14b36fac3e09b2ea54b8de297e12bafabd0fb9d123ad10b6d45459a835a6: gzip compressed data, was "bert.onnx.onnx"
```

then decompress it to text/machine_comprehension/bert-squad/model/bertsquad10.onnx
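For reference, once the archive has been fetched, the `.onnx` member can also be pulled out with Python's `tarfile` module instead of poking at the LFS object cache; a minimal sketch (`extract_onnx` and the paths are illustrative, not part of the repo):

```python
import tarfile
from pathlib import Path

def extract_onnx(archive_path, dest_dir):
    """Extract the first .onnx member from a .tar.gz archive and return its path."""
    dest = Path(dest_dir)
    dest.mkdir(parents=True, exist_ok=True)
    with tarfile.open(archive_path, "r:gz") as tar:
        for member in tar.getmembers():
            if member.name.endswith(".onnx"):
                tar.extract(member, path=dest)
                return dest / member.name
    raise FileNotFoundError("no .onnx member found in archive")
```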

create the inputs.json as explained there: https://github.com/onnx/models/tree/master/text/machine_comprehension/bert-squad
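A small script can generate such a file; the field layout below is assumed from the SQuAD v1.1 format that BERT-SQuAD consumes (the context and question are made-up placeholders, adapt to the readme if it differs):

```python
import json

# Minimal SQuAD-style prediction input: one paragraph, one question.
# Layout is an assumption based on the SQuAD v1.1 schema.
inputs = {
    "version": "1.4",
    "data": [
        {
            "paragraphs": [
                {
                    "context": "ONNX is an open format for machine learning models.",
                    "qas": [
                        {"id": "1", "question": "What is ONNX?"}
                    ],
                }
            ]
        }
    ],
}

with open("inputs.json", "w") as f:
    json.dump(inputs, f, indent=2)
```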

download the vocab file from the BERT model zip in the BERT repo: https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip

install a recent onnxruntime:

```
$ sudo pip3.5 install --upgrade onnxruntime
Collecting onnxruntime
  Downloading https://files.pythonhosted.org/packages/2a/26/52b66fcea1a79b1c873df22bc9844895e6b1ef356c5bb7ee4da260af2ad2/onnxruntime-0.5.0-cp35-cp35m-manylinux2010_x86_64.whl (3.2MB)
    100% |████████████████████████████████| 3.2MB 219kB/s
Installing collected packages: onnxruntime
  Found existing installation: onnxruntime 0.2.1
    Uninstalling onnxruntime-0.2.1:
      Successfully uninstalled onnxruntime-0.2.1
Successfully installed onnxruntime-0.5.0
```

try to run a simple inference:

```
python3.5 dependencies/run_onnx_squad.py --model model/bertsquad10.onnx --vocab_file ~/Downloads/bert/uncased_L-12_H-768_A-12/vocab.txt --predict_file inputs.json --output /tmp
```

See that the ONNX model expects 4 inputs but the Python script only feeds 3:

onnxruntime's expected inputs:

```
NodeArg(name='unique_ids_raw_output___9:0', type='tensor(int64)', shape=['unk__485'])
NodeArg(name='segment_ids:0', type='tensor(int64)', shape=['unk__486', 256])
NodeArg(name='input_mask:0', type='tensor(int64)', shape=['unk__487', 256])
NodeArg(name='input_ids:0', type='tensor(int64)', shape=['unk__488', 256])
```
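To spot at a glance which name the feed lacks, the declared input names can be diffed against the feed keys; `missing_inputs` is a hypothetical helper for illustration, not part of the script:

```python
def missing_inputs(expected_names, feed):
    """Return the model input names absent from the feed dict, in model order."""
    return [name for name in expected_names if name not in feed]

# Names taken from the NodeArg listing above.
expected = [
    "unique_ids_raw_output___9:0",
    "segment_ids:0",
    "input_mask:0",
    "input_ids:0",
]
feed = {"input_ids:0": None, "input_mask:0": None, "segment_ids:0": None}
print(missing_inputs(expected, feed))  # -> ['unique_ids_raw_output___9:0']
```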

The input data is created at line 556 with:

```python
data = {"input_ids:0": input_ids[idx:idx + bs],
        "input_mask:0": input_mask[idx:idx + bs],
        "segment_ids:0": segment_ids[idx:idx + bs]}
```

so indeed, 'unique_ids_raw_output___9:0' is missing.
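For comparison, a feed that satisfies all four declared inputs might look like the sketch below; the zero tensors are dummies standing in for the script's real arrays, and using `np.arange` ids for `unique_ids_raw_output___9:0` is an assumption, not something the source confirms:

```python
import numpy as np

# Dummy stand-ins for the arrays the script builds (shape [n_examples, 256]).
bs, idx = 1, 0
input_ids = np.zeros((1, 256), dtype=np.int64)
input_mask = np.zeros((1, 256), dtype=np.int64)
segment_ids = np.zeros((1, 256), dtype=np.int64)

data = {
    # The fourth input the model asks for: a 1-D int64 id per example
    # (assumed semantics; the ids are only echoed back by the graph).
    "unique_ids_raw_output___9:0": np.arange(idx, idx + bs, dtype=np.int64),
    "input_ids:0": input_ids[idx:idx + bs],
    "input_mask:0": input_mask[idx:idx + bs],
    "segment_ids:0": segment_ids[idx:idx + bs],
}
print(len(data))  # 4 inputs, matching what onnxruntime reports
```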

According to the doc, the missing input is: "label_ids: one-hot encoded labels for the text".

Kind regards

vinitra-zz commented 4 years ago

In https://github.com/onnx/models/pull/232#issuecomment-549672816, @KeDengMS mentioned a possible fix that he had implemented as follows:

```python
import onnx

# Load the model, remove the unused unique_ids input/output pair,
# then drop the Identity node that connects them.
mp = onnx.load('/path/to/bertsquad10.onnx')
input = [i for i in mp.graph.input if i.name == 'unique_ids_raw_output___9:0']
mp.graph.input.remove(input[0])
output = [o for o in mp.graph.output if o.name == 'unique_ids:0']
mp.graph.output.remove(output[0])
node = [n for n in mp.graph.node if n.name == 'unique_ids_graph_outputs_Identity__10']
mp.graph.node.remove(node[0])
onnx.save(mp, '/path/to/bertsquad10.onnx')
```
lefromage commented 4 years ago

Now the error is:

```
NodeArg(name='segment_ids:0', type='tensor(int64)', shape=['unk__493', 256])
NodeArg(name='input_mask:0', type='tensor(int64)', shape=['unk__494', 256])
NodeArg(name='input_ids:0', type='tensor(int64)', shape=['unk__495', 256])
at 100 0.015084990309551358sec per item
at 200 0.007544085942208767sec per item
at 300 0.009093221391861637sec per item
at 400 0.0068210667744278905sec per item
at 500 0.0054574373587965965sec per item
at 600 0.0062560025447358685sec per item
at 700 0.005363190172772323sec per item
total time: 3.7544796019792557sec, 0.004888645315077156sec per item
Traceback (most recent call last):
  File "models/text/machine_comprehension/bert-squad/dependencies/run_onnx_squad.py", line 576, in <module>
    sys.exit(main())
  File "models/text/machine_comprehension/bert-squad/dependencies/run_onnx_squad.py", line 568, in main
    True, output_prediction_file, output_nbest_file)
  File "models/text/machine_comprehension/bert-squad/dependencies/run_onnx_squad.py", line 284, in write_predictions
    start_indexes = _get_best_indexes(result.start_logits, n_best_size)
  File "models/text/machine_comprehension/bert-squad/dependencies/run_onnx_squad.py", line 476, in _get_best_indexes
    index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)
TypeError: 'numpy.float32' object is not iterable
```
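The TypeError suggests `_get_best_indexes` received a scalar rather than a vector of logits; one plausible cause (an assumption, not verified against the script) is that removing the `unique_ids:0` output shifted the positions in the session's output list, so a per-example scalar landed where `start_logits` was expected. The failure mode itself is easy to reproduce:

```python
import numpy as np

# The sort works on a 1-D array of logits, as the script intends:
logits = np.array([0.1, 0.9, 0.3], dtype=np.float32)
best = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)
print(best[0][0])  # index of the highest logit -> 1

# ...but fails exactly like the traceback when handed a 0-d scalar:
scalar = np.float32(0.5)
try:
    sorted(enumerate(scalar), key=lambda x: x[1], reverse=True)
except TypeError as e:
    print(e)  # 'numpy.float32' object is not iterable
```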

kunmingho commented 4 years ago

I fixed several issues in run_onnx_squad.py; you can give it a try: run_onnx_squad.py.gz