webonnx / wonnx

A WebGPU-accelerated ONNX inference run-time written 100% in Rust, ready for native and the web
1.61k stars, 59 forks

feat: add support for BERTSQuAD question-answer inputs #111

Closed: pixelspark closed this 2 years ago

pixelspark commented 2 years ago

Add support for the tokenization used in BERT question-answer models (BERTSQuAD and others)
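For readers unfamiliar with BERT tokenization, here is a minimal sketch of greedy longest-match-first WordPiece tokenization. It assumes an uncased vocabulary file with one token per line (token id = line index) and the "##" prefix for word continuations. The function names (load_vocab, basic_tokenize, wordpiece) are hypothetical and not WONNX's API. The real BERT basic tokenizer also does Unicode normalization, accent stripping and Unicode punctuation splitting; this sketch simplifies that to lowercasing and splitting on whitespace and ASCII punctuation.

```rust
use std::collections::HashMap;

/// Hypothetical helper: build a vocabulary map from the contents of a
/// vocab file with one token per line; the token id is the line index.
fn load_vocab(text: &str) -> HashMap<String, i64> {
    text.lines()
        .enumerate()
        .map(|(i, tok)| (tok.trim_end().to_string(), i as i64))
        .collect()
}

/// Lowercase, split on whitespace, and emit each ASCII punctuation character
/// as its own word (a simplification of BERT's basic tokenizer).
fn basic_tokenize(text: &str) -> Vec<String> {
    let mut words = Vec::new();
    for chunk in text.to_lowercase().split_whitespace() {
        let mut current = String::new();
        for c in chunk.chars() {
            if c.is_ascii_punctuation() {
                if !current.is_empty() {
                    words.push(std::mem::take(&mut current));
                }
                words.push(c.to_string());
            } else {
                current.push(c);
            }
        }
        if !current.is_empty() {
            words.push(current);
        }
    }
    words
}

/// Greedy longest-match-first WordPiece: repeatedly take the longest prefix
/// of the remaining characters that is in the vocabulary; non-initial pieces
/// are looked up with a "##" prefix. Unsplittable words become [UNK].
fn wordpiece(word: &str, vocab: &HashMap<String, i64>) -> Vec<String> {
    let chars: Vec<char> = word.chars().collect();
    let mut pieces = Vec::new();
    let mut start = 0;
    while start < chars.len() {
        let mut end = chars.len();
        let mut found = None;
        while end > start {
            let sub: String = chars[start..end].iter().collect();
            let candidate = if start == 0 { sub } else { format!("##{}", sub) };
            if vocab.contains_key(&candidate) {
                found = Some(candidate);
                break;
            }
            end -= 1;
        }
        match found {
            Some(piece) => {
                pieces.push(piece);
                start = end;
            }
            None => return vec!["[UNK]".to_string()],
        }
    }
    pieces
}

fn main() {
    // Tiny illustrative vocabulary (not the real BERT vocab file).
    let vocab = load_vocab("[PAD]\n[UNK]\n[CLS]\n[SEP]\nwhat\ndid\ni\n?\nun\n##finish\n##ed");
    let tokens: Vec<String> = basic_tokenize("What did I unfinished?")
        .iter()
        .flat_map(|w| wordpiece(w, &vocab))
        .collect();
    // Prints ["what", "did", "i", "un", "##finish", "##ed", "?"]
    println!("{:?}", tokens);
}
```

The token strings are then mapped to ids through the same vocabulary map before being fed to the model.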

pixelspark commented 2 years ago

So, after many hours of work... this is finally ready to go! With this PR, WONNX supports the BERTSQuAD model (and possibly, with some minor modifications, other BERT models). BERT is typically used for text comprehension tasks. Specifically, I tested the BERTSQuAD model for Q&A. The model accepts a 'context' text and a 'question' text, and identifies the span of the context that is most likely to answer the question.
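For reference, BERT-style question-answering models commonly take the question and context packed into one sequence, [CLS] question [SEP] context [SEP], together with segment (token type) ids and an attention mask. The sketch below builds those three vectors and pads them to a fixed sequence length, as a model with static input shapes requires. The struct and function names, the special-token ids and the sequence length are hypothetical; the actual input names, order and length of the modified BERTSQuAD model should be checked against the model itself.

```rust
/// Hypothetical container for the three integer inputs a BERT-style QA model
/// usually expects (names are illustrative, not WONNX's API).
#[derive(Debug, PartialEq)]
struct QaInputs {
    input_ids: Vec<i64>,
    segment_ids: Vec<i64>,
    input_mask: Vec<i64>,
}

/// Pack already-tokenized ids as [CLS] question [SEP] context [SEP] and pad
/// with `pad_id` to exactly `seq_len`. The context is truncated if needed.
fn build_qa_inputs(
    question: &[i64],
    context: &[i64],
    cls_id: i64,
    sep_id: i64,
    pad_id: i64,
    seq_len: usize,
) -> QaInputs {
    // [CLS], question, [SEP], at least nothing of context, final [SEP].
    assert!(question.len() + 3 <= seq_len, "question too long for seq_len");

    let mut input_ids = vec![cls_id];
    input_ids.extend_from_slice(question);
    input_ids.push(sep_id);
    // Everything so far is segment 0 (the question part).
    let question_part = input_ids.len();

    // Room left for context tokens, keeping one slot for the final [SEP].
    let room = seq_len - question_part - 1;
    input_ids.extend(context.iter().take(room));
    input_ids.push(sep_id);
    let used = input_ids.len();

    // Context tokens and the final [SEP] are segment 1; real tokens get mask 1.
    let mut segment_ids = vec![0; question_part];
    segment_ids.resize(used, 1);
    let mut input_mask = vec![1; used];

    // Pad all three vectors to the fixed sequence length.
    input_ids.resize(seq_len, pad_id);
    segment_ids.resize(seq_len, 0);
    input_mask.resize(seq_len, 0);

    QaInputs { input_ids, segment_ids, input_mask }
}

fn main() {
    // Illustrative ids only: question = [10, 11], context = [20, 21, 22].
    let inputs = build_qa_inputs(&[10, 11], &[20, 21, 22], 101, 102, 0, 10);
    println!("{:?}", inputs);
}
```

Truncating only the context keeps the question intact; a production tool would more likely split long contexts into overlapping windows instead.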

If you want to play with this, you can use this modified version of the BERTSQuAD model. It is based on the ONNX version of BERTSQuAD: I ran the model through onnx-simplifier and also had to apply a few other tricks to remove the dynamic dimensions from the model. You will also need this vocabulary file. Once you have both files, you can run the model as follows:

cargo run -- infer bertsquad-12-inferred-fixed-2.onnx \
    --vocab ./bertsquad-vocab.txt \
    --qa-answer \
    --context "I finally finished the pull request and uploaded it to GitHub" \
    --question "what did i finish?"

This should output "the pull request". Isn't that just extremely satisfying? :-)
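How an answer such as "the pull request" is typically recovered from a QA model's output can be sketched as follows, assuming the model returns one start logit and one end logit per token (as BERT QA heads usually do). The sketch picks the (start, end) pair inside the context that maximizes start_logits[start] + end_logits[end], with start <= end and a bounded length, then joins the WordPiece tokens back into text. The function names and the logits in main are made up for illustration; they are not WONNX's implementation or real model output.

```rust
use std::ops::Range;

/// Return the inclusive token span (start, end) inside `context` maximizing
/// start_logits[start] + end_logits[end], with start <= end and a span
/// length of at most `max_answer_len` tokens.
fn best_span(
    start_logits: &[f32],
    end_logits: &[f32],
    context: Range<usize>,
    max_answer_len: usize,
) -> Option<(usize, usize)> {
    // Never index past the shorter logits vector.
    let ctx_end = context.end.min(start_logits.len()).min(end_logits.len());
    let mut best: Option<(usize, usize, f32)> = None;
    for s in context.start..ctx_end {
        for e in s..ctx_end.min(s + max_answer_len) {
            let score = start_logits[s] + end_logits[e];
            if best.map_or(true, |(_, _, b)| score > b) {
                best = Some((s, e, score));
            }
        }
    }
    best.map(|(s, e, _)| (s, e))
}

/// Join WordPiece tokens back into text: "##" pieces attach to the previous
/// token, other tokens are separated by a space.
fn join_wordpieces(tokens: &[&str]) -> String {
    let mut out = String::new();
    for tok in tokens {
        if let Some(rest) = tok.strip_prefix("##") {
            out.push_str(rest);
        } else {
            if !out.is_empty() {
                out.push(' ');
            }
            out.push_str(tok);
        }
    }
    out
}

fn main() {
    let tokens = [
        "[CLS]", "what", "did", "i", "finish", "?", "[SEP]",
        "i", "finally", "finished", "the", "pull", "request", "[SEP]",
    ];
    // Made-up logits peaking at "the" (start) and "request" (end).
    let mut start_logits = vec![0.0_f32; tokens.len()];
    let mut end_logits = vec![0.0_f32; tokens.len()];
    start_logits[10] = 5.0;
    end_logits[12] = 5.0;
    // Context tokens are indices 7..13 (excluding the final [SEP]).
    if let Some((s, e)) = best_span(&start_logits, &end_logits, 7..13, 30) {
        println!("{}", join_wordpieces(&tokens[s..=e])); // the pull request
    }
}
```

The brute-force double loop is O(n · max_answer_len), which is cheap for a few hundred tokens.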

Some other improvements included in this PR:

@haixuanTao would you be so kind as to review this? (Tests pass on my machine; I haven't checked the CI results yet.)

pixelspark commented 2 years ago

@haixuanTao also, something weird seems to be happening with the Python tests, can you check? It seems unrelated to this PR. The Python onnx_backend test passes just fine on my computer...

haixuanTao commented 2 years ago

This is really impressive!!! Would you mind pinning the versions of the Python dependencies in requirements.txt to the ones installed on your computer? You can list them with pip freeze. I don't really have time to fix the issue right now. The PR looks fine to me :)