deepjavalibrary / djl

An Engine-Agnostic Deep Learning Framework in Java
https://djl.ai
Apache License 2.0

Error using batchPredict with QAInput #3383

Closed willmostly closed 3 months ago

willmostly commented 3 months ago

Description

Calling the batchPredict method with the BERT model and the PyTorch engine throws

Caused by: ai.djl.engine.EngineException: stack expects each tensor to be equal size, but got [56] at entry 0 and [55] at entry 1

when more than one QAInput is submitted.

I attempted to extend the Bert QA example to use batchPredict. My background is in Java, not ML, so I'm not sure how to interpret this error. Padding the input strings to the same length did not help. If I submit a List with a single entry, the error does not occur.

Expected Behavior

batchPredict returns without error

Error Message

SLF4J(W): No SLF4J providers were found.
SLF4J(W): Defaulting to no-operation (NOP) logger implementation
SLF4J(W): See https://www.slf4j.org/codes.html#noProviders for further details.
Loading:     100% |████████████████████████████████████████|
Exception in thread "main" ai.djl.translate.TranslateException: java.lang.IllegalArgumentException: You cannot batch data with different input shapes(55) vs (56)
    at ai.djl.inference.Predictor.batchPredict(Predictor.java:196)
    at com.starburst.djl.BatchFailExample.batchPredict(BatchFailExample.java:48)
    at com.starburst.djl.BatchFailExample.main(BatchFailExample.java:65)
Caused by: java.lang.IllegalArgumentException: You cannot batch data with different input shapes(55) vs (56)
    at ai.djl.translate.StackBatchifier.batchify(StackBatchifier.java:83)
    at ai.djl.translate.Translator.batchProcessInput(Translator.java:105)
    at ai.djl.inference.Predictor.batchPredict(Predictor.java:184)
    ... 2 more
Caused by: ai.djl.engine.EngineException: stack expects each tensor to be equal size, but got [56] at entry 0 and [55] at entry 1
    at ai.djl.pytorch.jni.PyTorchLibrary.torchStack(Native Method)
    at ai.djl.pytorch.jni.JniUtils.stack(JniUtils.java:626)
    at ai.djl.pytorch.engine.PtNDArrayEx.stack(PtNDArrayEx.java:663)
    at ai.djl.pytorch.engine.PtNDArrayEx.stack(PtNDArrayEx.java:33)
    at ai.djl.ndarray.NDArrays.stack(NDArrays.java:1825)
    at ai.djl.ndarray.NDArrays.stack(NDArrays.java:1785)
    at ai.djl.translate.StackBatchifier.batchify(StackBatchifier.java:54)
    ... 4 more

How to Reproduce?

Steps to reproduce

package com.starburst.djl;

import ai.djl.Application;
import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.nlp.qa.QAInput;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

public class BatchFailExample
{

    private final Criteria<QAInput, String> criteria;

    private final Predictor<QAInput, String> predictor;
    private final ZooModel<QAInput, String> model;

    public BatchFailExample()
    {
        criteria = Criteria.builder()
                .optApplication(Application.NLP.QUESTION_ANSWER)
                .setTypes(QAInput.class, String.class)
                .optFilter("backbone", "bert")
                .optEngine("PyTorch")
                .optDevice(Device.cpu())
                .optProgress(new ProgressBar())
                .build();
        try {
            model = criteria.loadModel();
            predictor = model.newPredictor();
        }
        catch (ModelNotFoundException | MalformedModelException | IOException e) {
            throw new RuntimeException(e);
        }
    }

    public List<String> batchPredict(List<QAInput> inputs)
            throws TranslateException
    {
        return predictor.batchPredict(inputs);
    }

    public static void main(String[] args)
            throws TranslateException
    {
        String data = """
               Claremont is the only city in Sullivan County, New Hampshire, United States.
               The population was 12,949 at the 2020 census.[4] Claremont is a core city of the 
               Lebanon–Claremont micropolitan area, a bi-state, four-county region in the upper Connecticut River valley.
               """;

        List<QAInput> qas = new ArrayList<>();
        qas.add(new QAInput("How many people live in Claremont?", data));
        qas.add(new QAInput("What river is Claremont near?", data));
        qas.add(new QAInput("How many cities are in Sullivan county", data));
        BatchFailExample batchFailExample = new BatchFailExample();
        List<String> answers = batchFailExample.batchPredict(qas);
        System.out.println(answers.getFirst());
        System.out.println(answers.get(1));
        System.out.println(answers.get(2));
    }
}

I'm running this in a standalone project with the following dependencies:

        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>api</artifactId>
            <version>0.29.0</version>
        </dependency>

        <dependency>
            <groupId>ai.djl.mxnet</groupId>
            <artifactId>mxnet-model-zoo</artifactId>
            <version>0.29.0</version>
        </dependency>
        <!-- Pytorch -->
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-model-zoo</artifactId>
            <version>0.29.0</version>
        </dependency>
        <!-- TensorFlow -->
        <dependency>
            <groupId>ai.djl.tensorflow</groupId>
            <artifactId>tensorflow-model-zoo</artifactId>
            <version>0.29.0</version>
        </dependency>
        <!-- ONNXRuntime -->
        <dependency>
            <groupId>ai.djl.onnxruntime</groupId>
            <artifactId>onnxruntime-engine</artifactId>
            <version>0.29.0</version>
        </dependency>

What have you tried to solve it?

  1. Padding the input strings to a fixed number of characters

Environment Info

Please run the command ./gradlew debugEnv from the root directory of DJL (if necessary, clone DJL first). It will output information about your system, environment, and installation that can help us debug your issue. Paste the output of the command below:

----------- System Properties -----------
java.specification.version: 21
sun.jnu.encoding: UTF-8
java.class.path: /Users/will.morrison/workspace/repos/djl/integration/build/classes/java/main:/Users/will.morrison/workspace/repos/djl/integration/build/resources/main:/Users/will.morrison/.gradle/caches/modules-2/files-2.1/commons-cli/commons-cli/1.8.0/41a4bff12057eecb6daaf9c7f36c237815be3da1/commons-cli-1.8.0.jar:/Users/will.morrison/.gradle/caches/modules-2/files-2.1/org.apache.logging.log4j/log4j-slf4j2-impl/2.23.1/c3ffee33404c3a178f026fd8c7ef0e058b01b01c/log4j-slf4j2-impl-2.23.1.jar:/Users/will.morrison/workspace/repos/djl/basicdataset/build/libs/basicdataset-0.30.0-SNAPSHOT.jar:/Users/will.morrison/workspace/repos/djl/model-zoo/build/libs/model-zoo-0.30.0-SNAPSHOT.jar:/Users/will.morrison/workspace/repos/djl/testing/build/libs/testing-0.30.0-SNAPSHOT.jar:/Users/will.morrison/workspace/repos/djl/engines/pytorch/pytorch-model-zoo/build/libs/pytorch-model-zoo-0.30.0-SNAPSHOT.jar:/Users/will.morrison/workspace/repos/djl/engines/pytorch/pytorch-jni/build/libs/pytorch-jni-2.4.0-0.30.0-SNAPSHOT.jar:/Users/will.morrison/workspace/repos/djl/engines/tensorflow/tensorflow-model-zoo/build/libs/tensorflow-model-zoo-0.30.0-SNAPSHOT.jar:/Users/will.morrison/workspace/repos/djl/engines/ml/xgboost/build/libs/xgboost-0.30.0-SNAPSHOT.jar:/Users/will.morrison/workspace/repos/djl/engines/ml/lightgbm/build/libs/lightgbm-0.30.0-SNAPSHOT.jar:/Users/will.morrison/workspace/repos/djl/engines/onnxruntime/onnxruntime-engine/build/libs/onnxruntime-engine-0.30.0-SNAPSHOT.jar:/Users/will.morrison/workspace/repos/djl/extensions/tokenizers/build/libs/tokenizers-0.30.0-SNAPSHOT.jar:/Users/will.morrison/.gradle/caches/modules-2/files-2.1/org.apache.logging.log4j/log4j-core/2.23.1/905802940e2c78042d75b837c136ac477d2b4e4d/log4j-core-2.23.1.jar:/Users/will.morrison/.gradle/caches/modules-2/files-2.1/org.apache.logging.log4j/log4j-api/2.23.1/9c15c29c526d9c6783049c0a77722693c66706e1/log4j-api-2.23.1.jar:/Users/will.morrison/workspace/repos/djl/engines/pytorch/pytorch-engine/build/libs/pytorch-engine
-0.30.0-SNAPSHOT.jar:/Users/will.morrison/workspace/repos/djl/engines/tensorflow/tensorflow-engine/build/libs/tensorflow-engine-0.30.0-SNAPSHOT.jar:/Users/will.morrison/workspace/repos/djl/api/build/libs/api-0.30.0-SNAPSHOT.jar:/Users/will.morrison/.gradle/caches/modules-2/files-2.1/org.testng/testng/7.10.2/30742acada21960d4333a4204039fbdc6a92083a/testng-7.10.2.jar:/Users/will.morrison/.gradle/caches/modules-2/files-2.1/org.slf4j/slf4j-api/2.0.13/80229737f704b121a318bba5d5deacbcf395bc77/slf4j-api-2.0.13.jar:/Users/will.morrison/.gradle/caches/modules-2/files-2.1/org.apache.commons/commons-csv/1.11.0/8f2dc805097da534612128b7cdf491a5a76752bf/commons-csv-1.11.0.jar:/Users/will.morrison/.gradle/caches/modules-2/files-2.1/ml.dmlc/xgboost4j_2.12/2.0.3/db511d04d1ca1364cde79a6c8238a2694e31c592/xgboost4j_2.12-2.0.3.jar:/Users/will.morrison/.gradle/caches/modules-2/files-2.1/commons-logging/commons-logging/1.3.3/580ad1a4f34876c4f964c083361de31b3d60be68/commons-logging-1.3.3.jar:/Users/will.morrison/.gradle/caches/modules-2/files-2.1/com.microsoft.ml.lightgbm/lightgbmlib/3.2.110/f6c85e5d7cc44d49c4544240ea5c96004680007b/lightgbmlib-3.2.110.jar:/Users/will.morrison/.gradle/caches/modules-2/files-2.1/com.microsoft.onnxruntime/onnxruntime/1.18.0/bacf73dc2e1d92941744c6a3f8c01fc674189d36/onnxruntime-1.18.0.jar:/Users/will.morrison/.gradle/caches/modules-2/files-2.1/com.google.code.gson/gson/2.11.0/527175ca6d81050b53bdd4c457a6d6e017626b0e/gson-2.11.0.jar:/Users/will.morrison/.gradle/caches/modules-2/files-2.1/net.java.dev.jna/jna/5.14.0/67bf3eaea4f0718cb376a181a629e5f88fa1c9dd/jna-5.14.0.jar:/Users/will.morrison/.gradle/caches/modules-2/files-2.1/org.apache.commons/commons-compress/1.26.2/eb1f823447af685208e684fce84783b43517960c/commons-compress-1.26.2.jar:/Users/will.morrison/.gradle/caches/modules-2/files-2.1/commons-io/commons-io/2.16.1/377d592e740dc77124e0901291dbfaa6810a200e/commons-io-2.16.1.jar:/Users/will.morrison/.gradle/caches/modules-2/files-2.1/commons-codec/commons-codec
/1.17.0/dbe8eef6e14460e73da07f7b11bf994d6626355/commons-codec-1.17.0.jar:/Users/will.morrison/.gradle/caches/modules-2/files-2.1/com.beust/jcommander/1.82/a7c5fef184d238065de38f81bbc6ee50cca2e21/jcommander-1.82.jar:/Users/will.morrison/.gradle/caches/modules-2/files-2.1/org.webjars/jquery/3.7.1/42088e652462c40a369b64d87e18e825644acfab/jquery-3.7.1.jar:/Users/will.morrison/workspace/repos/djl/engines/tensorflow/tensorflow-api/build/libs/tensorflow-api-0.30.0-SNAPSHOT.jar:/Users/will.morrison/.gradle/caches/modules-2/files-2.1/org.scala-lang.modules/scala-collection-compat_2.12/2.10.0/bf81785e892f4185f470bddd205b011237aab553/scala-collection-compat_2.12-2.10.0.jar:/Users/will.morrison/.gradle/caches/modules-2/files-2.1/com.google.errorprone/error_prone_annotations/2.27.0/91b2c29d8a6148b5e2e4930f070d4840e2e48e34/error_prone_annotations-2.27.0.jar:/Users/will.morrison/.gradle/caches/modules-2/files-2.1/org.tensorflow/tensorflow-core-api/1.0.0-rc.1/ea1878fb8e289742237e5a0ba6f15398f3e9b7ef/tensorflow-core-api-1.0.0-rc.1.jar:/Users/will.morrison/.gradle/caches/modules-2/files-2.1/org.tensorflow/tensorflow-core-native/1.0.0-rc.1/62b5fa3283865cc696dfbebf073ca2116b18f327/tensorflow-core-native-1.0.0-rc.1.jar:/Users/will.morrison/.gradle/caches/modules-2/files-2.1/org.bytedeco/javacpp/1.5.10/afb6ae145e7563c66b677cb4896dd0197d49fce6/javacpp-1.5.10.jar:/Users/will.morrison/.gradle/caches/modules-2/files-2.1/com.google.protobuf/protobuf-java/3.25.3/d3200261955f3298e0d85c9892201e70492ce8eb/protobuf-java-3.25.3.jar:/Users/will.morrison/.gradle/caches/modules-2/files-2.1/org.tensorflow/ndarray/1.0.0-rc.1/4a96a398ad87bec32be9177b1441b9880c04d822/ndarray-1.0.0-rc.1.jar
java.vm.vendor: Eclipse Adoptium
sun.arch.data.model: 64
user.variant: 
java.vendor.url: https://adoptium.net/
user.timezone: America/New_York
java.vm.specification.version: 21
os.name: Mac OS X
user.country: US
sun.java.launcher: SUN_STANDARD
sun.boot.library.path: /Library/Java/JavaVirtualMachines/temurin-21.jdk/Contents/Home/lib
sun.java.command: ai.djl.integration.util.DebugEnvironment
http.nonProxyHosts: local|*.local|169.254/16|*.169.254/16
jdk.debug: release
sun.cpu.endian: little
user.home: /Users/will.morrison
user.language: en
java.specification.vendor: Oracle Corporation
java.version.date: 2024-04-16
java.home: /Library/Java/JavaVirtualMachines/temurin-21.jdk/Contents/Home
file.separator: /
java.vm.compressedOopsMode: Zero based
line.separator: 

java.vm.specification.vendor: Oracle Corporation
java.specification.name: Java Platform API Specification
apple.awt.application.name: DebugEnvironment
sun.management.compiler: HotSpot 64-Bit Tiered Compilers
ftp.nonProxyHosts: local|*.local|169.254/16|*.169.254/16
java.runtime.version: 21.0.3+9-LTS
user.name: will.morrison
stdout.encoding: UTF-8
path.separator: :
os.version: 14.5
java.runtime.name: OpenJDK Runtime Environment
file.encoding: UTF-8
java.vm.name: OpenJDK 64-Bit Server VM
java.vendor.version: Temurin-21.0.3+9
java.vendor.url.bug: https://github.com/adoptium/adoptium-support/issues
java.io.tmpdir: /var/folders/2p/cp6y8k2951v45b0z_0r9nzb80000gp/T/
java.version: 21.0.3
user.dir: /Users/will.morrison/workspace/repos/djl/integration
os.arch: aarch64
java.vm.specification.name: Java Virtual Machine Specification
native.encoding: UTF-8
java.library.path: /Users/will.morrison/Library/Java/Extensions:/Library/Java/Extensions:/Network/Library/Java/Extensions:/System/Library/Java/Extensions:/usr/lib/java:.
java.vm.info: mixed mode
stderr.encoding: UTF-8
java.vendor: Eclipse Adoptium
java.vm.version: 21.0.3+9-LTS
sun.io.unicode.encoding: UnicodeBig
socksNonProxyHosts: local|*.local|169.254/16|*.169.254/16
java.class.version: 65.0

--------- Environment Variables ---------
CONDA_PROMPT_MODIFIER: (base) 
HOMEBREW_PREFIX: /opt/homebrew
SHLVL: 1
INFOPATH: /opt/homebrew/share/info:
SHELL: /bin/zsh
TMPDIR: /var/folders/2p/cp6y8k2951v45b0z_0r9nzb80000gp/T/
__CFBundleIdentifier: com.apple.Terminal
APP_ICON_5400: /Users/will.morrison/workspace/repos/djl/media/gradle.icns
HOME: /Users/will.morrison
LaunchInstanceID: E71FEB6E-B353-40D4-8B0A-FED17BCDEEE4
HOMEBREW_REPOSITORY: /opt/homebrew
CONDA_PREFIX: /opt/homebrew/Caskroom/miniconda/base
PATH: /opt/homebrew/Caskroom/miniconda/base/bin:/opt/homebrew/Caskroom/miniconda/base/condabin:/Users/will.morrison/go/bin:/opt/homebrew/anaconda3/bin:/opt/homebrew/opt/libpq/bin:/opt/homebrew/bin:/opt/homebrew/sbin:/usr/local/bin:/System/Cryptexes/App/usr/bin:/usr/bin:/bin:/usr/sbin:/sbin:/var/run/com.apple.security.cryptexd/codex.system/bootstrap/usr/local/bin:/var/run/com.apple.security.cryptexd/codex.system/bootstrap/usr/bin:/var/run/com.apple.security.cryptexd/codex.system/bootstrap/usr/appleinternal/bin:/Users/will.morrison/.orbstack/bin
LOGNAME: will.morrison
TERM: xterm-256color
__CF_USER_TEXT_ENCODING: 0x1F6:0x0:0x0
XPC_FLAGS: 0x0
_CE_CONDA: 
LANG: en_US.UTF-8
_CE_M: 
TERM_PROGRAM_VERSION: 453
TERM_PROGRAM: Apple_Terminal
CONDA_SHLVL: 1
CONDA_EXE: /opt/homebrew/Caskroom/miniconda/base/bin/conda
SSH_AUTH_SOCK: /private/tmp/com.apple.launchd.gankGUFb8K/Listeners
OLDPWD: /Users/will.morrison/workspace/repos/djl
XPC_SERVICE_NAME: 0
CONDA_DEFAULT_ENV: base
PROMPT: %! %4~ $ 
USER: will.morrison
SECURITYSESSIONID: 186b6
HOMEBREW_CELLAR: /opt/homebrew/Cellar
PWD: /Users/will.morrison/workspace/repos/djl
CONDA_PYTHON_EXE: /opt/homebrew/Caskroom/miniconda/base/bin/python

-------------- Directories --------------
temp directory: /var/folders/2p/cp6y8k2951v45b0z_0r9nzb80000gp/T
DJL cache directory: /Users/will.morrison/.djl.ai
Engine cache directory: /Users/will.morrison/.djl.ai

------------------ CUDA -----------------
GPU Count: 0

----------------- Engines ---------------
DJL version: 0.30.0-SNAPSHOT
[INFO ] - PyTorch graph executor optimizer is enabled, this may impact your inference latency and throughput. See: https://docs.djl.ai/docs/development/inference_performance_optimization.html#graph-executor-optimization
[INFO ] - Number of inter-op threads is 10
[INFO ] - Number of intra-op threads is 8
Default Engine: PyTorch:2.4.0, capabilities: [
        OPENMP,
]
PyTorch Library: /Users/will.morrison/.djl.ai/pytorch/2.4.0-cpu-osx-aarch64
Default Device: cpu()
Rust: 4
PyTorch: 2
XGBoost: 10
LightGBM: 10
OnnxRuntime: 10
TensorFlow: 3

--------------- Hardware --------------
Available processors (cores): 10
Byte Order: LITTLE_ENDIAN
Free memory (bytes): 51015736
Maximum memory (bytes): 17179869184
Total memory available to JVM (bytes): 58720256
Heap committed: 58720256
Heap nonCommitted: 33619968
GCC: 
Apple clang version 15.0.0 (clang-1500.3.9.4)
Target: arm64-apple-darwin23.5.0
Thread model: posix
InstalledDir: /Library/Developer/CommandLineTools/usr/bin

BUILD SUCCESSFUL in 5s

frankfliu commented 3 months ago

@willmostly

The root cause of the exception is a limitation of PtBertQATranslator in batch mode: it requires padding so that every input in the batch has the same length. We didn't enable padding by default because it hurts performance in the single-prediction case.

It should work if you use the following code:

        Criteria<QAInput, String> criteria =
                Criteria.builder()
                        .optApplication(Application.NLP.QUESTION_ANSWER)
                        .setTypes(QAInput.class, String.class)
                        .optFilter("backbone", "bert")
                        .optEngine("PyTorch")
                        .optDevice(Device.cpu())
                        .optArgument("padding", "true")
                        .optProgress(new ProgressBar())
                        .build();
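
To illustrate why padding is needed, here is a minimal, self-contained sketch in plain Java (no DJL calls). It is not the PtBertQATranslator implementation, which is not shown in this thread; the class name, method names, and PAD_ID value are hypothetical. It uses two token-id sequences with the lengths from the error message (56 and 55), shows that they cannot be stacked as-is, and right-pads them to the longest length. It also builds an attention mask, which models usually need in order to ignore the padded positions.

```java
import java.util.Arrays;
import java.util.List;

public class PaddingSketch {

    // Hypothetical padding id; a real tokenizer defines its own.
    static final long PAD_ID = 0L;

    // Mirrors the precondition of a stack operation: all sequences must have equal length.
    static boolean canStack(List<long[]> sequences) {
        for (long[] seq : sequences) {
            if (seq.length != sequences.get(0).length) {
                return false;
            }
        }
        return true;
    }

    // Length of the longest sequence in the batch.
    static int maxLength(List<long[]> sequences) {
        int maxLen = 0;
        for (long[] seq : sequences) {
            maxLen = Math.max(maxLen, seq.length);
        }
        return maxLen;
    }

    // Right-pads every sequence with PAD_ID up to the longest length.
    static long[][] padToLongest(List<long[]> sequences) {
        int maxLen = maxLength(sequences);
        long[][] batch = new long[sequences.size()][];
        for (int i = 0; i < sequences.size(); i++) {
            long[] seq = sequences.get(i);
            long[] padded = Arrays.copyOf(seq, maxLen);
            Arrays.fill(padded, seq.length, maxLen, PAD_ID);
            batch[i] = padded;
        }
        return batch;
    }

    // Builds a mask with 1 for real tokens and 0 for padded positions.
    static long[][] attentionMask(List<long[]> sequences) {
        long[][] mask = new long[sequences.size()][maxLength(sequences)];
        for (int i = 0; i < sequences.size(); i++) {
            Arrays.fill(mask[i], 0, sequences.get(i).length, 1L);
        }
        return mask;
    }

    public static void main(String[] args) {
        // Placeholder token ids; only the lengths (56 and 55) matter here.
        List<long[]> sequences = List.of(new long[56], new long[55]);
        System.out.println("stackable before padding: " + canStack(sequences));
        long[][] batch = padToLongest(sequences);
        System.out.println("stackable after padding: " + canStack(Arrays.asList(batch)));
    }
}
```

The lengths that matter are token counts, not character counts, which is likely why padding the input strings to the same number of characters did not help: two strings of equal character length can still tokenize to different numbers of tokens.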

A few comments regarding your project settings:

  1. MXNet is deprecated; please avoid using the MXNet model zoo in the future.
  2. The PyTorch model zoo is mainly for demo purposes. We strongly recommend using the Huggingface model zoo (HfModelZoo) instead; see the examples at: https://github.com/deepjavalibrary/djl-demo/tree/master/huggingface/nlp/src/main/java/com/examples
    1. The Huggingface translators support not only batch forward, but also batch tokenization and batched post-processing.
    2. They support all the advanced tokenizers implemented in Rust, with rich features and superior performance.
    3. You can access most of the models on the Huggingface model hub.
    4. You can import models from Huggingface or from a local folder into DJL with djl-converter.
    5. We have optimizations (using OnnxRuntime and Rust) for text embedding models.

willmostly commented 3 months ago

Thanks so much for the guidance @frankfliu! I confirm that adding .optArgument("padding", "true") resolves the error, so I will close this issue.

It appears that batchPredict is less accurate than running this model in single-prediction mode: single prediction answers each test question correctly, while batchPredict does not. I'm just noting this in case you want to follow up on it. I'll transition to the Huggingface models going forward.
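
One way to make the single-versus-batch discrepancy reproducible is a small comparison helper. The sketch below is plain Java with stand-in functions rather than a real model; BatchConsistencyCheck and mismatches are hypothetical names. With DJL, the two functions would wrap Predictor.predict and Predictor.batchPredict, which throw the checked TranslateException and therefore need a small wrapper before they fit java.util.function.Function.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.function.Function;

public class BatchConsistencyCheck {

    // Returns the indices at which the batch result differs from the single-item result.
    static <I, O> List<Integer> mismatches(
            List<I> inputs,
            Function<I, O> single,
            Function<List<I>, List<O>> batch) {
        List<O> batched = batch.apply(inputs);
        if (batched.size() != inputs.size()) {
            throw new IllegalStateException("batch returned " + batched.size()
                    + " results for " + inputs.size() + " inputs");
        }
        List<Integer> differing = new ArrayList<>();
        for (int i = 0; i < inputs.size(); i++) {
            O expected = single.apply(inputs.get(i));
            if (!Objects.equals(expected, batched.get(i))) {
                differing.add(i);
            }
        }
        return differing;
    }

    public static void main(String[] args) {
        // Stand-in "models": the batch path deliberately gets the second item wrong.
        List<String> inputs = List.of("a", "bb", "ccc");
        Function<String, String> single = String::toUpperCase;
        Function<List<String>, List<String>> batch = in -> List.of("A", "??", "CCC");
        System.out.println("mismatching indices: " + mismatches(inputs, single, batch));
    }
}
```

Reporting the mismatching indices together with the corresponding questions would give maintainers a concrete case to investigate.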