Closed by @willmostly 3 months ago
@willmostly
The root cause of the exception is that PtBertQATranslator has a limitation in batch mode: it requires padding. We didn't enable padding by default because it hurts performance in the single-prediction case. It should work if you use the following code:
```java
Criteria<QAInput, String> criteria =
        Criteria.builder()
                .optApplication(Application.NLP.QUESTION_ANSWER)
                .setTypes(QAInput.class, String.class)
                .optFilter("backbone", "bert")
                .optEngine("PyTorch")
                .optDevice(Device.cpu())
                .optArgument("padding", "true")
                .optProgress(new ProgressBar())
                .build();
```
A few comments regarding your project settings:
Tysm for the guidance @frankfliu! I confirm that adding .optArgument("padding", "true") resolves the error message; I'll close this issue.

It appears that batchPredict is less accurate than running this model in single-prediction mode. In single-prediction mode it produces correct answers to each test question, while this is not the case with batchPredict. I'm just noting this in case you want to follow up on it; I'll transition over to the Hugging Face models going forward.
Description
Using the batchPredict method with the BERT model and the PyTorch engine throws an exception when more than one QAInput is submitted.

I attempted to extend the BERT QA example to use batchPredict. My background is in Java, not ML, so I'm not sure how to interpret this error. Padding the input strings to the same length did not help. If I submit a List with a single entry, the error does not occur.

Expected Behavior
batchPredict returns without error
Error Message
How to Reproduce?
(If you developed your own code, please provide a short script that reproduces the error. For existing examples, please provide link.)
Steps to reproduce
(Paste the commands you ran that produced the error.)
I'm running this in a standalone project with the dependencies
What have you tried to solve it?
Environment Info
Please run the command
./gradlew debugEnv
from the root directory of DJL (if necessary, clone DJL first). It will output information about your system, environment, and installation that can help us debug your issue. Paste the output of the command below: