deepjavalibrary / djl

An Engine-Agnostic Deep Learning Framework in Java
https://djl.ai
Apache License 2.0

djl-convert does not produce working model from Huggingface #3518

Open AlEscher opened 8 hours ago

AlEscher commented 8 hours ago

Description

I am trying to convert a Hugging Face model so that it is compatible with DJL. My goal is to convert the model with djl-convert, load it locally, and then generate code embeddings for Java code, using e.g. CodeBERT. I ran djl-convert -m microsoft/codebert-base -o models/codebert and then used this code to load the model:

Criteria<String, float[]> criteria =
    Criteria.builder()
        .setTypes(String.class, float[].class)
        .optApplication(Application.NLP.TEXT_EMBEDDING)
        .optModelPath(Paths.get("models/codebert"))
        .optModelName("codebert-base.pt")
        .optTranslator(translator)
        .optProgress(new ProgressBar())
        .build();
ZooModel<String, float[]> model = ModelZoo.loadModel(criteria);
Predictor<String, float[]> predictor = model.newPredictor();
float[] embeddings = predictor.predict(input);

The translator is implemented like this:

protected ModelTranslator(String tokenizerPath, boolean useTokenTypes) throws IOException {
  this.tokenizer = HuggingFaceTokenizer.newInstance(Paths.get(tokenizerPath, "tokenizer.json"));
  this.useTokenTypes = useTokenTypes;
}

@Override
public NDList processInput(TranslatorContext ctx, String input) {
  return tokenizer.encode(input).toNDList(ctx.getNDManager(), useTokenTypes);
}

@Override
public float[] processOutput(TranslatorContext ctx, NDList list) {
  // Retrieve the embeddings from the output
  NDArray embeddings = list.singletonOrThrow();
  return embeddings.toFloatArray();
}

When generating the embeddings, the model fails with:

ai.djl.translate.TranslateException: ai.djl.engine.EngineException: Expected at most 3 argument(s) for operator 'forward', but received 4 argument(s). Declaration: forward(__torch__.transformers.models.roberta.modeling_roberta.RobertaModel self, Tensor input_ids, Tensor attention_mask) -> Dict(str, Tensor)

What am I doing wrong? Is there a better way to load a model from Hugging Face? codebert-base does not seem to be available in the Model Zoo.

Expected Behavior

The djl-convert tool produces a model that can be loaded locally and has a working forward method.

Error Message

ai.djl.translate.TranslateException: ai.djl.engine.EngineException: Expected at most 3 argument(s) for operator 'forward', but received 4 argument(s). Declaration: forward(__torch__.transformers.models.roberta.modeling_roberta.RobertaModel self, Tensor input_ids, Tensor attention_mask) -> Dict(str, Tensor)

How to Reproduce?

See provided code above

Steps to reproduce


  1. Run the djl-convert tool as described above
  2. Attempt to generate embeddings as described above

What have you tried to solve it?

I tried many different ways of getting a model from Hugging Face to work locally. This approach seems to be the intended one according to https://djl.ai/extensions/tokenizers/#convert-huggingface-model-to-torchscript

Environment Info

Please run the command ./gradlew debugEnv from the root directory of DJL (if necessary, clone DJL first). It will output information about your system, environment, and installation that can help us debug your issue. Paste the output of the command below:

Found C:\Dev\Research\djl\\gradle\wrapper\gradle-wrapper.jar
Starting a Gradle Daemon (subsequent builds will be faster)

> Task :engines:ml:xgboost:processResources
Downloading https://publish.djl.ai/xgboost/2.0.3/jnilib/linux/aarch64/libxgboost4j.so

> Task :engines:pytorch:pytorch-jni:processResources
Downloading https://publish.djl.ai/pytorch/2.4.0/jnilib/0.31.0/linux-x86_64/cpu/libdjl_torch.so
Downloading https://publish.djl.ai/pytorch/2.4.0/jnilib/0.31.0/linux-x86_64/cpu-precxx11/libdjl_torch.so
Downloading https://publish.djl.ai/pytorch/2.4.0/jnilib/0.31.0/linux-aarch64/cpu-precxx11/libdjl_torch.so
Downloading https://publish.djl.ai/pytorch/2.4.0/jnilib/0.31.0/osx-aarch64/cpu/libdjl_torch.dylib
Downloading https://publish.djl.ai/pytorch/2.4.0/jnilib/0.31.0/win-x86_64/cpu/djl_torch.dll
Downloading https://publish.djl.ai/pytorch/2.4.0/jnilib/0.31.0/linux-x86_64/cu124/libdjl_torch.so
Downloading https://publish.djl.ai/pytorch/2.4.0/jnilib/0.31.0/linux-x86_64/cu124-precxx11/libdjl_torch.so
Downloading https://publish.djl.ai/pytorch/2.4.0/jnilib/0.31.0/win-x86_64/cu124/djl_torch.dll

> Task :integration:debugEnv
----------- System Properties -----------
java.specification.version: 21
sun.cpu.isalist: amd64
sun.jnu.encoding: Cp1252
java.class.path: C:\Dev\Research\djl\integration\build\classes\java\main;C:\Dev\Research\djl\integration\build\resources\main;C:\Users\Alessandro\.gradle\caches\modules-2\files-2.1\commons-cli\commons-cli\1.9.0\e1cdfa8bf40ccbb7440b2d1232f9f45bb20a1844\commons-cli-1.9.0.jar;C:\Users\Alessandro\.gradle\caches
\modules-2\files-2.1\org.apache.logging.log4j\log4j-slf4j2-impl\2.24.0\3d550671b19e83591d5e66cc8c77272e7aaac34c\log4j-slf4j2-impl-2.24.0.jar;C:\Dev\Research\djl\basicdataset\build\libs\basicdataset-0.31.0-SNAPSHOT.jar;C:\Dev\Research\djl\model-zoo\build\libs\model-zoo-0.31.0-SNAPSHOT.jar;C:\Dev\Research\djl
\testing\build\libs\testing-0.31.0-SNAPSHOT.jar;C:\Dev\Research\djl\engines\mxnet\mxnet-model-zoo\build\libs\mxnet-model-zoo-0.31.0-SNAPSHOT.jar;C:\Dev\Research\djl\engines\pytorch\pytorch-model-zoo\build\libs\pytorch-model-zoo-0.31.0-SNAPSHOT.jar;C:\Dev\Research\djl\engines\pytorch\pytorch-jni\build\libs\p
ytorch-jni-2.4.0-0.31.0-SNAPSHOT.jar;C:\Dev\Research\djl\engines\tensorflow\tensorflow-model-zoo\build\libs\tensorflow-model-zoo-0.31.0-SNAPSHOT.jar;C:\Dev\Research\djl\engines\ml\xgboost\build\libs\xgboost-0.31.0-SNAPSHOT.jar;C:\Dev\Research\djl\engines\ml\lightgbm\build\libs\lightgbm-0.31.0-SNAPSHOT.jar;C
:\Dev\Research\djl\engines\onnxruntime\onnxruntime-engine\build\libs\onnxruntime-engine-0.31.0-SNAPSHOT.jar;C:\Dev\Research\djl\extensions\tokenizers\build\libs\tokenizers-0.31.0-SNAPSHOT.jar;C:\Users\Alessandro\.gradle\caches\modules-2\files-2.1\org.apache.logging.log4j\log4j-core\2.24.0\537543d3b84d78b4d7
ad055c98f8af13e5e7f3a8\log4j-core-2.24.0.jar;C:\Users\Alessandro\.gradle\caches\modules-2\files-2.1\org.apache.logging.log4j\log4j-api\2.24.0\c6d9bd0c95c9bb6c530f4800da9507b98f018654\log4j-api-2.24.0.jar;C:\Dev\Research\djl\engines\mxnet\mxnet-engine\build\libs\mxnet-engine-0.31.0-SNAPSHOT.jar;C:\Dev\Resear
ch\djl\engines\pytorch\pytorch-engine\build\libs\pytorch-engine-0.31.0-SNAPSHOT.jar;C:\Dev\Research\djl\engines\tensorflow\tensorflow-engine\build\libs\tensorflow-engine-0.31.0-SNAPSHOT.jar;C:\Dev\Research\djl\api\build\libs\api-0.31.0-SNAPSHOT.jar;C:\Users\Alessandro\.gradle\caches\modules-2\files-2.1\org.
testng\testng\7.10.2\30742acada21960d4333a4204039fbdc6a92083a\testng-7.10.2.jar;C:\Users\Alessandro\.gradle\caches\modules-2\files-2.1\org.slf4j\slf4j-api\2.0.16\172931663a09a1fa515567af5fbef00897d3c04\slf4j-api-2.0.16.jar;C:\Users\Alessandro\.gradle\caches\modules-2\files-2.1\org.apache.commons\commons-csv
\1.11.0\8f2dc805097da534612128b7cdf491a5a76752bf\commons-csv-1.11.0.jar;C:\Users\Alessandro\.gradle\caches\modules-2\files-2.1\ml.dmlc\xgboost4j_2.12\2.0.3\db511d04d1ca1364cde79a6c8238a2694e31c592\xgboost4j_2.12-2.0.3.jar;C:\Users\Alessandro\.gradle\caches\modules-2\files-2.1\commons-logging\commons-logging
\1.3.4\b9fc14968d63a8b8a8a2c1885fe3e90564239708\commons-logging-1.3.4.jar;C:\Users\Alessandro\.gradle\caches\modules-2\files-2.1\com.microsoft.ml.lightgbm\lightgbmlib\3.2.110\f6c85e5d7cc44d49c4544240ea5c96004680007b\lightgbmlib-3.2.110.jar;C:\Users\Alessandro\.gradle\caches\modules-2\files-2.1\com.microsoft
.onnxruntime\onnxruntime\1.19.0\52985f239457f0b1f635b9a0e9e5b0b03c76b22b\onnxruntime-1.19.0.jar;C:\Users\Alessandro\.gradle\caches\modules-2\files-2.1\com.google.code.gson\gson\2.11.0\527175ca6d81050b53bdd4c457a6d6e017626b0e\gson-2.11.0.jar;C:\Users\Alessandro\.gradle\caches\modules-2\files-2.1\net.java.dev
.jna\jna\5.14.0\67bf3eaea4f0718cb376a181a629e5f88fa1c9dd\jna-5.14.0.jar;C:\Users\Alessandro\.gradle\caches\modules-2\files-2.1\org.apache.commons\commons-compress\1.27.1\a19151084758e2fbb6b41eddaa88e7b8ff4e6599\commons-compress-1.27.1.jar;C:\Users\Alessandro\.gradle\caches\modules-2\files-2.1\commons-io\com
mons-io\2.16.1\377d592e740dc77124e0901291dbfaa6810a200e\commons-io-2.16.1.jar;C:\Users\Alessandro\.gradle\caches\modules-2\files-2.1\commons-codec\commons-codec\1.17.1\973638b7149d333563584137ebf13a691bb60579\commons-codec-1.17.1.jar;C:\Users\Alessandro\.gradle\caches\modules-2\files-2.1\com.beust\jcommande
r\1.82\a7c5fef184d238065de38f81bbc6ee50cca2e21\jcommander-1.82.jar;C:\Users\Alessandro\.gradle\caches\modules-2\files-2.1\org.webjars\jquery\3.7.1\42088e652462c40a369b64d87e18e825644acfab\jquery-3.7.1.jar;C:\Dev\Research\djl\engines\tensorflow\tensorflow-api\build\libs\tensorflow-api-0.31.0-SNAPSHOT.jar;C:\
Users\Alessandro\.gradle\caches\modules-2\files-2.1\org.scala-lang.modules\scala-collection-compat_2.12\2.10.0\bf81785e892f4185f470bddd205b011237aab553\scala-collection-compat_2.12-2.10.0.jar;C:\Users\Alessandro\.gradle\caches\modules-2\files-2.1\com.google.errorprone\error_prone_annotations\2.27.0\91b2c29d
8a6148b5e2e4930f070d4840e2e48e34\error_prone_annotations-2.27.0.jar;C:\Users\Alessandro\.gradle\caches\modules-2\files-2.1\org.tensorflow\tensorflow-core-api\1.0.0-rc.1\ea1878fb8e289742237e5a0ba6f15398f3e9b7ef\tensorflow-core-api-1.0.0-rc.1.jar;C:\Users\Alessandro\.gradle\caches\modules-2\files-2.1\org.tens
orflow\tensorflow-core-native\1.0.0-rc.1\62b5fa3283865cc696dfbebf073ca2116b18f327\tensorflow-core-native-1.0.0-rc.1.jar;C:\Users\Alessandro\.gradle\caches\modules-2\files-2.1\org.bytedeco\javacpp\1.5.10\afb6ae145e7563c66b677cb4896dd0197d49fce6\javacpp-1.5.10.jar;C:\Users\Alessandro\.gradle\caches\modules-2\files-2.1\com.google.protobuf\protobuf-java\3.25.5\5ae5c9ec39930ae9b5a61b32b93288818ec05ec1\protobuf-java-3.25.5.jar;C:\Users\Alessandro\.gradle\caches\modules-2\files-2.1\org.tensorflow\ndarray\1.0.0-rc.1\4a96a398ad87bec32be9177b1441b9880c04d822\ndarray-1.0.0-rc.1.jar
java.vm.vendor: Oracle Corporation
sun.arch.data.model: 64
user.variant:
java.vendor.url: https://java.oracle.com/
user.timezone: Europe/Berlin
java.vm.specification.version: 21
os.name: Windows 11
user.country: GB
sun.java.launcher: SUN_STANDARD
sun.boot.library.path: C:\Program Files\Java\jdk-21\bin
sun.java.command: ai.djl.integration.util.DebugEnvironment
jdk.debug: release
sun.cpu.endian: little
user.home: C:\Users\Alessandro
user.language: en
java.specification.vendor: Oracle Corporation
java.version.date: 2024-10-15
java.home: C:\Program Files\Java\jdk-21
file.separator: \
java.vm.compressedOopsMode: Zero based
line.separator:

java.vm.specification.vendor: Oracle Corporation
java.specification.name: Java Platform API Specification
user.script:
sun.management.compiler: HotSpot 64-Bit Tiered Compilers
java.runtime.version: 21.0.5+9-LTS-239
user.name: Alessandro
stdout.encoding: Cp1252
path.separator: ;
os.version: 10.0
java.runtime.name: Java(TM) SE Runtime Environment
file.encoding: UTF-8
java.vm.name: Java HotSpot(TM) 64-Bit Server VM
java.vendor.url.bug: https://bugreport.java.com/bugreport/
java.io.tmpdir: C:\Users\ALESSA~1\AppData\Local\Temp\
java.version: 21.0.5
user.dir: C:\Dev\Research\djl\integration
os.arch: amd64
java.vm.specification.name: Java Virtual Machine Specification
sun.os.patch.level:
native.encoding: Cp1252
java.library.path: C:\Program Files\Java\jdk-21\bin;C:\WINDOWS\Sun\Java\bin;C:\WINDOWS\system32;C:\WINDOWS;C:\Program Files\BullseyeCoverage\bin;C:\Testwell\CTC;C:\Program Files\Common Files\Oracle\Java\javapath;C:\WINDOWS\system32;C:\WINDOWS;C:\WINDOWS\System32\Wbem;C:\WINDOWS\System32\WindowsPowerShell\v1
.0\;C:\WINDOWS\System32\OpenSSH\;C:\ProgramData\chocolatey\bin;C:\Program Files\osquery;C:\LLVM\bin;C:\cygwin64\bin;C:\Strawberry\c\bin;C:\Strawberry\perl\site\bin;C:\Strawberry\perl\bin;C:\Program Files\microsoft.codecoverage.17.1.0\build\netstandard1.0\CodeCoverage;C:\Program Files\Git LFS;C:\Program File
s\teamscale-upload-windows;C:\Program Files\dotnet\;C:\Program Files\OpenCppCoverage;C:\Program Files\BullseyeCoverage\lib;C:\Program Files\Go\bin;C:\Program Files\apache-maven-3.9.0\bin;C:\Program Files (x86)\Google\Cloud SDK\google-cloud-sdk\bin;C:\Program Files (x86)\GtkSharp\2.12\bin;C:\Program Files\no
dejs\;C:\Program Files\Docker\Docker\resources\bin;C:\Users\Alessandro\AppData\Local\Android\Sdk\platform-tools;C:\Users\Alessandro\AppData\Local\Programs\CLion\bin\cmake\win\x64\bin;C:\Users\Alessandro\AppData\Local\Programs\CLion\bin\mingw\bin;C:\Users\Alessandro\AppData\Roaming\Python\Python312\Scripts;C
:\Users\Alessandro\AppData\Local\Programs\Python\Python312\Scripts\;C:\Users\Alessandro\AppData\Local\Programs\Python\Python312\;C:\Users\Alessandro\AppData\Local\Programs\Python\Launcher\;C:\Users\Alessandro\AppData\Local\pnpm;C:\Users\Alessandro\.poetry\bin;C:\Users\Alessandro\AppData\Local\Microsoft\WindowsApps;C:\Users\Alessandro\AppData\Local\Programs\Microsoft VS Code\bin;C:\Users\Alessandro\go\bin;C:\Users\Alessandro\AppData\Roaming\npm;C:\Users\Alessandro\AppData\Local\JetBrains\Toolbox\scripts;C:\Users\Alessandro\.dotnet\tools;C:\Users\Alessandro\AppData\Local\Programs\Git\cmd;.
java.vm.info: mixed mode, sharing
stderr.encoding: Cp1252
java.vendor: Oracle Corporation
java.vm.version: 21.0.5+9-LTS-239
sun.io.unicode.encoding: UnicodeLittle
java.class.version: 65.0

--------- Environment Variables ---------
USERDOMAIN_ROAMINGPROFILE: DESKTOP-022SVF2
PROCESSOR_LEVEL: 6
LCOV_HOME: C:\ProgramData\chocolatey\lib\lcov\tools
SESSIONNAME: Console
ALLUSERSPROFILE: C:\ProgramData
COVFILE: C:\Dev\bullseye-testwise-coverage\bullseye.cov
PROCESSOR_ARCHITECTURE: AMD64
PSModulePath: C:\Users\Alessandro\OneDrive\Documents\WindowsPowerShell\Modules;C:\Program Files\WindowsPowerShell\Modules;C:\WINDOWS\system32\WindowsPowerShell\v1.0\Modules;C:\Program Files (x86)\Google\Cloud SDK\google-cloud-sdk\platform\PowerShell
SystemDrive: C:
PNPM_HOME: C:\Users\Alessandro\AppData\Local\pnpm
DIRNAME: C:\Dev\Research\djl\
USERNAME: Alessandro
CMD_LINE_ARGS: debugEnv
ProgramFiles(x86): C:\Program Files (x86)
APP_HOME: C:\Dev\Research\djl\
PATHEXT: .COM;.EXE;.BAT;.CMD;.VBS;.VBE;.JS;.JSE;.WSF;.WSH;.MSC;.PY;.PYW;.CPL
DriverData: C:\Windows\System32\Drivers\DriverData
OneDriveConsumer: C:\Users\Alessandro\OneDrive
GOPATH: C:\Users\Alessandro\go
ProgramData: C:\ProgramData
GIT_LFS_PATH: C:\Program Files\Git LFS
ProgramW6432: C:\Program Files
HOMEPATH: \Users\Alessandro
PROCESSOR_IDENTIFIER: Intel64 Family 6 Model 158 Stepping 10, GenuineIntel
ProgramFiles: C:\Program Files
PUBLIC: C:\Users\Public
windir: C:\WINDOWS
CTCHOME: C:\Testwell\CTC
=::: ::\
ZES_ENABLE_SYSMAN: 1
_SKIP: 2
LOCALAPPDATA: C:\Users\Alessandro\AppData\Local
USERDOMAIN: DESKTOP-022SVF2
LOGONSERVER: \\DESKTOP-022SVF2
JAVA_HOME: C:\Program Files\Java\jdk-21
PROMPT: $P$G
JETBRAINS_INTELLIJ_COMMAND_END_MARKER: SUyi67dqQ9BKmVqVo3br2NnDywq1xvC4ulCIYXe9Obl4Owe0u0wC2bPj9Yi6YYBr
EFC_10204: 1
OneDrive: C:\Users\Alessandro\OneDrive
=C:: C:\Dev\Research\djl
APPDATA: C:\Users\Alessandro\AppData\Roaming
DOWNLOAD_URL: "https://raw.githubusercontent.com/gradle/gradle/master/gradle/wrapper/gradle-wrapper.jar"
GTK_BASEPATH: C:\Program Files (x86)\GtkSharp\2.12\
JAVA_EXE: C:\Program Files\Java\jdk-21/bin/java.exe
ChocolateyInstall: C:\ProgramData\chocolatey
CommonProgramFiles: C:\Program Files\Common Files
Path: C:\Program Files\BullseyeCoverage\bin;C:\Testwell\CTC;C:\Program Files\Common Files\Oracle\Java\javapath;C:\WINDOWS\system32;C:\WINDOWS;C:\WINDOWS\System32\Wbem;C:\WINDOWS\System32\WindowsPowerShell\v1.0\;C:\WINDOWS\System32\OpenSSH\;C:\ProgramData\chocolatey\bin;C:\Program Files\osquery;C:\LLVM\bin;C
:\cygwin64\bin;C:\Strawberry\c\bin;C:\Strawberry\perl\site\bin;C:\Strawberry\perl\bin;C:\Program Files\microsoft.codecoverage.17.1.0\build\netstandard1.0\CodeCoverage;C:\Program Files\Git LFS;C:\Program Files\teamscale-upload-windows;C:\Program Files\dotnet\;C:\Program Files\OpenCppCoverage;C:\Program Files
\BullseyeCoverage\lib;C:\Program Files\Go\bin;C:\Program Files\apache-maven-3.9.0\bin;C:\Program Files (x86)\Google\Cloud SDK\google-cloud-sdk\bin;C:\Program Files (x86)\GtkSharp\2.12\bin;C:\Program Files\nodejs\;C:\Program Files\Docker\Docker\resources\bin;C:\Users\Alessandro\AppData\Local\Android\Sdk\plat
form-tools;C:\Users\Alessandro\AppData\Local\Programs\CLion\bin\cmake\win\x64\bin;C:\Users\Alessandro\AppData\Local\Programs\CLion\bin\mingw\bin;C:\Users\Alessandro\AppData\Roaming\Python\Python312\Scripts;C:\Users\Alessandro\AppData\Local\Programs\Python\Python312\Scripts\;C:\Users\Alessandro\AppData\Local
\Programs\Python\Python312\;C:\Users\Alessandro\AppData\Local\Programs\Python\Launcher\;C:\Users\Alessandro\AppData\Local\pnpm;C:\Users\Alessandro\.poetry\bin;C:\Users\Alessandro\AppData\Local\Microsoft\WindowsApps;C:\Users\Alessandro\AppData\Local\Programs\Microsoft VS Code\bin;C:\Users\Alessandro\go\bin;C:\Users\Alessandro\AppData\Roaming\npm;C:\Users\Alessandro\AppData\Local\JetBrains\Toolbox\scripts;C:\Users\Alessandro\.dotnet\tools;C:\Users\Alessandro\AppData\Local\Programs\Git\cmd
OS: Windows_NT
COMPUTERNAME: DESKTOP-022SVF2
PROCESSOR_REVISION: 9e0a
CLASSPATH: C:\Dev\Research\djl\\gradle\wrapper\gradle-wrapper.jar
CommonProgramW6432: C:\Program Files\Common Files
ComSpec: C:\WINDOWS\system32\cmd.exe
APP_BASE_NAME: gradlew
TERMINAL_EMULATOR: JetBrains-JediTerm
PSExecutionPolicyPreference: Bypass
SystemRoot: C:\WINDOWS
TEMP: C:\Users\ALESSA~1\AppData\Local\Temp
HOMEDRIVE: C:
USERPROFILE: C:\Users\Alessandro
TMP: C:\Users\ALESSA~1\AppData\Local\Temp
CommonProgramFiles(x86): C:\Program Files (x86)\Common Files
NUMBER_OF_PROCESSORS: 12

-------------- Directories --------------
temp directory: C:\Users\ALESSA~1\AppData\Local\Temp
DJL cache directory: C:\Users\Alessandro\.djl.ai
Engine cache directory: C:\Users\Alessandro\.djl.ai

------------------ CUDA -----------------
GPU Count: 0

----------------- Engines ---------------
DJL version: 0.31.0-SNAPSHOT
[INFO ] - Downloading libgcc_s_seh-1.dll ...
[INFO ] - Downloading libgfortran-3.dll ...
[INFO ] - Downloading libopenblas.dll ...
[INFO ] - Downloading libquadmath-0.dll ...
[INFO ] - Downloading mxnet.dll ...
Default Engine: MXNet:1.9.0, capabilities: [
        SIGNAL_HANDLER,
        LAPACK,
        BLAS_OPEN,
        OPENMP,
        OPENCV,
        MKLDNN,
]
MXNet Library: C:\Users\Alessandro\.djl.ai\mxnet\1.9.1-mkl-win-x86_64\mxnet.dll
Default Device: cpu()
Rust: 4
PyTorch: 2
MXNet: 0
XGBoost: 10
LightGBM: 10
OnnxRuntime: 10
TensorFlow: 3

--------------- Hardware --------------
Available processors (cores): 12
Byte Order: LITTLE_ENDIAN
Free memory (bytes): 513832784
Maximum memory (bytes): 8527020032
Total memory available to JVM (bytes): 536870912
Heap committed: 536870912
Heap nonCommitted: 29818880

BUILD SUCCESSFUL in 49s
64 actionable tasks: 15 executed, 49 up-to-date

frankfliu commented 8 hours ago

@AlEscher Can you use our built-in TextEmbeddingTranslator?

The following code works for me for this model:

Criteria<String, float[]> criteria =
        Criteria.builder()
                .setTypes(String.class, float[].class)
                .optModelPath(path)
                .optEngine("PyTorch")
                .optTranslatorFactory(new TextEmbeddingTranslatorFactory())
                .optProgress(new ProgressBar())
                .build();

frankfliu commented 8 hours ago

@AlEscher

Your error may be caused by setting useTokenTypes = true.
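
The count in the error message supports this: the traced forward declares self, input_ids and attention_mask, so it accepts at most 3 arguments, and a third tensor for token types would be the fourth. The following is a minimal plain-Java sketch of that arithmetic, not DJL code. The class ForwardArityCheck and its methods are hypothetical names, and it assumes the tokenizer encoding contributes input_ids and attention_mask, plus token_type_ids only when useTokenTypes is true.

```java
import java.util.ArrayList;
import java.util.List;

public class ForwardArityCheck {

    // Parameters declared by the traced forward(), copied from the error message.
    static final List<String> DECLARED = List.of("self", "input_ids", "attention_mask");

    // Tensors the encoding is assumed to contribute, in order.
    // Assumption: token_type_ids is appended only when useTokenTypes is true.
    static List<String> encodedInputs(boolean useTokenTypes) {
        List<String> inputs = new ArrayList<>(List.of("input_ids", "attention_mask"));
        if (useTokenTypes) {
            inputs.add("token_type_ids");
        }
        return inputs;
    }

    // Arguments the engine counts: the implicit self plus every tensor passed in.
    static int receivedArguments(boolean useTokenTypes) {
        return 1 + encodedInputs(useTokenTypes).size();
    }

    // True when the call stays within the declared parameter count.
    static boolean fits(boolean useTokenTypes) {
        return receivedArguments(useTokenTypes) <= DECLARED.size();
    }

    public static void main(String[] args) {
        for (boolean useTokenTypes : new boolean[] {true, false}) {
            System.out.printf(
                    "useTokenTypes=%b: received %d argument(s), at most %d declared, fits=%b%n",
                    useTokenTypes,
                    receivedArguments(useTokenTypes),
                    DECLARED.size(),
                    fits(useTokenTypes));
        }
    }
}
```

If that assumption holds for the translator above, constructing it with useTokenTypes = false should leave forward with the three arguments it declares.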