vespa-engine / vespa

AI + Data, online. https://vespa.ai
Apache License 2.0

Issue importing ONNX model using a Tensorflow StringLookup layer #21774

Open: dj-shin-okcupid opened this issue 2 years ago

dj-shin-okcupid commented 2 years ago

**Describe the bug**
Vespa is unable to build an application package containing an ONNX model that uses a TensorFlow `StringLookup` layer. My use case feeds string values directly into a deep learning model, which is where the `StringLookup` layer comes in handy.

**To Reproduce**
Steps to reproduce the behavior:

1. Environment: Python 3.7.5 with TensorFlow 2.4.4 and tf2onnx 1.10.0
2. Used the sample application `model-evaluation` from GitHub and started a Vespa Docker container
3. Created a dummy model with the following Python code:

   ```python
   import pandas as pd
   import tensorflow as tf
   from tensorflow.keras.layers.experimental import preprocessing

   # Toy string feature; the vocabulary is its distinct values: ['a', 'b', 'c']
   sample_data = pd.DataFrame.from_dict({'x': ['a', 'a', 'a', 'b', 'a', 'c']})
   vocab = list(sample_data['x'].unique())

   # String input (the pandas object dtype is interpreted as tf.string)
   x = tf.keras.layers.Input(name='x', shape=(1,), dtype=sample_data['x'].dtype)
   # Map each string to an integer index; unknown values go to the single OOV bucket
   y = preprocessing.StringLookup(
       num_oov_indices=1,
       vocabulary=vocab,
       mask_token="MASK",
       oov_token="UNK",
       trainable=False)(x)
   model = tf.keras.Model(inputs=x, outputs=y)
   model.save('string_lookup_test')  # SavedModel directory consumed by tf2onnx
   ```

4. Converted the model to ONNX with tf2onnx: `python -m tf2onnx.convert --saved-model=string_lookup_test --output=string_lookup_test.onnx`
5. Added the serialized model to `<my_app_root>/src/main/application/models/`
6. From the app root directory, built the application package: `mvn clean package -U`
7. During the test phase, the build fails with the following error: `java.lang.IllegalArgumentException: A ONNX tensor with data type STRING cannot be converted to a Vespa tensor type`

**Expected behavior**
Application package builds to completion, ready for deployment

**Screenshots**

```
[INFO] Scanning for projects...
Downloading from central: https://repo.maven.apache.org/maven2/com/yahoo/vespa/cloud-tenant-base/maven-metadata.xml
Downloaded from central: https://repo.maven.apache.org/maven2/com/yahoo/vespa/cloud-tenant-base/maven-metadata.xml (8.5 kB at 31 kB/s)
[INFO]
[INFO] ------------------< ai.vespa.example:model-inference >------------------
[INFO] Building model-inference 1.0.1
[INFO] --------------------------[ container-plugin ]--------------------------
Downloading from central: https://repo.maven.apache.org/maven2/org/apache/maven/plugins/maven-resources-plugin/maven-metadata.xml
Downloaded from central: https://repo.maven.apache.org/maven2/org/apache/maven/plugins/maven-resources-plugin/maven-metadata.xml (874 B at 30 kB/s)
Downloading from central: https://repo.maven.apache.org/maven2/org/apache/maven/plugins/maven-install-plugin/maven-metadata.xml
Downloaded from central: https://repo.maven.apache.org/maven2/org/apache/maven/plugins/maven-install-plugin/maven-metadata.xml (663 B at 9.6 kB/s)
Downloading from central: https://repo.maven.apache.org/maven2/org/apache/maven/plugins/maven-deploy-plugin/maven-metadata.xml
Downloaded from central: https://repo.maven.apache.org/maven2/org/apache/maven/plugins/maven-deploy-plugin/maven-metadata.xml (783 B at 31 kB/s)
[INFO]
[INFO] --- maven-clean-plugin:2.5:clean (default-clean) @ model-inference ---
[INFO] Deleting /home/djshin/dev/myapp/target
[INFO]
[INFO] --- maven-enforcer-plugin:3.0.0-M2:enforce (enforce-java) @ model-inference ---
[INFO]
[INFO] --- maven-enforcer-plugin:3.0.0-M2:enforce (enforce-no-log4j) @ model-inference ---
[INFO]
[INFO] --- bundle-plugin:7.561.60:generateSources (default-generateSources) @ model-inference ---
[INFO]
[INFO] --- maven-resources-plugin:3.2.0:resources (default-resources) @ model-inference ---
[INFO] Using 'UTF-8' encoding to copy filtered resources.
[INFO] Using 'UTF-8' encoding to copy filtered properties files.
[INFO] skip non existing resourceDirectory /home/djshin/dev/myapp/src/main/resources
[INFO]
[INFO] --- maven-compiler-plugin:3.8.1:compile (default-compile) @ model-inference ---
[INFO] Changes detected - recompiling the module!
[INFO] Compiling 5 source files to /home/djshin/dev/myapp/target/classes
[INFO]
[INFO] --- maven-resources-plugin:3.2.0:testResources (default-testResources) @ model-inference ---
[INFO] Using 'UTF-8' encoding to copy filtered resources.
[INFO] Using 'UTF-8' encoding to copy filtered properties files.
[INFO] skip non existing resourceDirectory /home/djshin/dev/myapp/src/test/resources
[INFO]
[INFO] --- maven-compiler-plugin:3.8.1:testCompile (default-testCompile) @ model-inference ---
[INFO] Changes detected - recompiling the module!
[INFO] Compiling 4 source files to /home/djshin/dev/myapp/target/test-classes
[INFO]
[INFO] --- maven-surefire-plugin:2.22.0:test (default-test) @ model-inference ---
[INFO]
[INFO] -------------------------------------------------------
[INFO] T E S T S
[INFO] -------------------------------------------------------
[INFO] Running ai.vespa.example.MySearcherTest
[ERROR] Tests run: 1, Failures: 0, Errors: 1, Skipped: 0, Time elapsed: 0.321 s <<< FAILURE! - in ai.vespa.example.MySearcherTest
[ERROR] testMySearcher Time elapsed: 0.318 s <<< ERROR!
java.lang.IllegalArgumentException: A ONNX tensor with data type STRING cannot be converted to a Vespa tensor type
at com.yahoo.vespa.model.ml.OnnxModelInfo.onnxValueTypeToString(OnnxModelInfo.java:276)
at com.yahoo.vespa.model.ml.OnnxModelInfo.onnxTypeToJson(OnnxModelInfo.java:231)
at com.yahoo.vespa.model.ml.OnnxModelInfo.onnxModelToJson(OnnxModelInfo.java:190)
at com.yahoo.vespa.model.ml.OnnxModelInfo.loadFromFile(OnnxModelInfo.java:151)
at com.yahoo.vespa.model.ml.OnnxModelInfo.load(OnnxModelInfo.java:129)
at com.yahoo.vespa.model.VespaModel.loadOnnxModelInfo(VespaModel.java:362)
at com.yahoo.vespa.model.VespaModel.onnxModelInfoFromSource(VespaModel.java:343)
at com.yahoo.vespa.model.VespaModel.createGlobalRankProfiles(VespaModel.java:298)
at com.yahoo.vespa.model.VespaModel.<init>(VespaModel.java:183)
at com.yahoo.vespa.model.VespaModel.<init>(VespaModel.java:167)
at com.yahoo.vespa.model.VespaModel.<init>(VespaModel.java:145)
at com.yahoo.vespa.model.container.ml.ModelsEvaluatorTester.createRankProfileList(ModelsEvaluatorTester.java:113)
at com.yahoo.vespa.model.container.ml.ModelsEvaluatorTester.create(ModelsEvaluatorTester.java:76)
at ai.vespa.example.MySearcherTest.testMySearcher(MySearcherTest.java:24)
at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.base/java.lang.reflect.Method.invoke(Method.java:566)
at org.junit.platform.commons.util.ReflectionUtils.invokeMethod(ReflectionUtils.java:725)
at org.junit.jupiter.engine.execution.MethodInvocation.proceed(MethodInvocation.java:60)
at org.junit.jupiter.engine.execution.InvocationInterceptorChain$ValidatingInvocation.proceed(InvocationInterceptorChain.java:131)
at org.junit.jupiter.engine.extension.TimeoutExtension.intercept(TimeoutExtension.java:149)
at org.junit.jupiter.engine.extension.TimeoutExtension.interceptTestableMethod(TimeoutExtension.java:140)
at org.junit.jupiter.engine.extension.TimeoutExtension.interceptTestMethod(TimeoutExtension.java:84)
at org.junit.jupiter.engine.execution.ExecutableInvoker$ReflectiveInterceptorCall.lambda$ofVoidMethod$0(ExecutableInvoker.java:115)
at org.junit.jupiter.engine.execution.ExecutableInvoker.lambda$invoke$0(ExecutableInvoker.java:105)
at org.junit.jupiter.engine.execution.InvocationInterceptorChain$InterceptedInvocation.proceed(InvocationInterceptorChain.java:106)
at org.junit.jupiter.engine.execution.InvocationInterceptorChain.proceed(InvocationInterceptorChain.java:64)
at org.junit.jupiter.engine.execution.InvocationInterceptorChain.chainAndInvoke(InvocationInterceptorChain.java:45)
at org.junit.jupiter.engine.execution.InvocationInterceptorChain.invoke(InvocationInterceptorChain.java:37)
at org.junit.jupiter.engine.execution.ExecutableInvoker.invoke(ExecutableInvoker.java:104)
at org.junit.jupiter.engine.execution.ExecutableInvoker.invoke(ExecutableInvoker.java:98)
at org.junit.jupiter.engine.descriptor.TestMethodTestDescriptor.lambda$invokeTestMethod$7(TestMethodTestDescriptor.java:214)
at org.junit.platform.engine.support.hierarchical.ThrowableCollector.execute(ThrowableCollector.java:73)
at org.junit.jupiter.engine.descriptor.TestMethodTestDescriptor.invokeTestMethod(TestMethodTestDescriptor.java:210)
at org.junit.jupiter.engine.descriptor.TestMethodTestDescriptor.execute(TestMethodTestDescriptor.java:135)
at org.junit.jupiter.engine.descriptor.TestMethodTestDescriptor.execute(TestMethodTestDescriptor.java:66)
at org.junit.platform.engine.support.hierarchical.NodeTestTask.lambda$executeRecursively$6(NodeTestTask.java:151)
at org.junit.platform.engine.support.hierarchical.ThrowableCollector.execute(ThrowableCollector.java:73)
at org.junit.platform.engine.support.hierarchical.NodeTestTask.lambda$executeRecursively$8(NodeTestTask.java:141)
at org.junit.platform.engine.support.hierarchical.Node.around(Node.java:137)
at org.junit.platform.engine.support.hierarchical.NodeTestTask.lambda$executeRecursively$9(NodeTestTask.java:139)
at org.junit.platform.engine.support.hierarchical.ThrowableCollector.execute(ThrowableCollector.java:73)
at org.junit.platform.engine.support.hierarchical.NodeTestTask.executeRecursively(NodeTestTask.java:138)
at org.junit.platform.engine.support.hierarchical.NodeTestTask.execute(NodeTestTask.java:95)
at java.base/java.util.ArrayList.forEach(ArrayList.java:1541)
at org.junit.platform.engine.support.hierarchical.SameThreadHierarchicalTestExecutorService.invokeAll(SameThreadHierarchicalTestExecutorService.java:41)
at org.junit.platform.engine.support.hierarchical.NodeTestTask.lambda$executeRecursively$6(NodeTestTask.java:155)
at org.junit.platform.engine.support.hierarchical.ThrowableCollector.execute(ThrowableCollector.java:73)
at org.junit.platform.engine.support.hierarchical.NodeTestTask.lambda$executeRecursively$8(NodeTestTask.java:141)
at org.junit.platform.engine.support.hierarchical.Node.around(Node.java:137)
at org.junit.platform.engine.support.hierarchical.NodeTestTask.lambda$executeRecursively$9(NodeTestTask.java:139)
at org.junit.platform.engine.support.hierarchical.ThrowableCollector.execute(ThrowableCollector.java:73)
at org.junit.platform.engine.support.hierarchical.NodeTestTask.executeRecursively(NodeTestTask.java:138)
at org.junit.platform.engine.support.hierarchical.NodeTestTask.execute(NodeTestTask.java:95)
at java.base/java.util.ArrayList.forEach(ArrayList.java:1541)
at org.junit.platform.engine.support.hierarchical.SameThreadHierarchicalTestExecutorService.invokeAll(SameThreadHierarchicalTestExecutorService.java:41)
at org.junit.platform.engine.support.hierarchical.NodeTestTask.lambda$executeRecursively$6(NodeTestTask.java:155)
...
```



**Environment:**
 - OS: Ubuntu 20.04

**Vespa version**
8.555.50

lesters commented 2 years ago

So, Vespa's interface to model evaluation is through tensors, and Vespa's tensors do not currently support strings. That is the cause of the error here. DL models that take strings as input are frankly something we haven't considered yet, as usually a tokenizer runs before the actual DL model (I see this Keras layer is experimental).

We will have to discuss how to address this; unfortunately, there is no quick fix on our side.

However, as preprocessing as a step in model evaluation matures, this is something we will have to address soon.
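
Since Vespa tensors hold only numeric cells, one possible workaround is to perform the `StringLookup` step outside the model and pass integer ids to it instead. The sketch below is a minimal illustration under one assumption: it uses the index layout that `StringLookup` documents for TF 2.4 when a mask token is set and `num_oov_indices=1` (mask at 0, the single OOV bucket at 1, vocabulary from 2). `build_lookup` and `lookup` are hypothetical helpers, not part of TensorFlow or Vespa.

```python
from typing import Dict, List, Sequence

# Assumed index layout (StringLookup in TF 2.4 with a mask token and num_oov_indices=1)
MASK_ID = 0  # reserved for the mask token
OOV_ID = 1   # single bucket for values not in the vocabulary


def build_lookup(vocab: Sequence[str]) -> Dict[str, int]:
    """Assign vocabulary terms consecutive ids starting after the reserved ones."""
    return {term: OOV_ID + 1 + i for i, term in enumerate(vocab)}


def lookup(values: Sequence[str], table: Dict[str, int], mask_token: str = "MASK") -> List[int]:
    """Encode strings as integer ids, mirroring the assumed StringLookup layout."""
    return [MASK_ID if v == mask_token else table.get(v, OOV_ID) for v in values]


vocab = ["a", "b", "c"]  # same order as sample_data['x'].unique() in the report
table = build_lookup(vocab)
print(lookup(["a", "c", "z", "MASK"], table))  # prints [2, 4, 1, 0]
```

The exported ONNX model would then be built without the `StringLookup` layer and take these ids as a numeric input, so none of its inputs or outputs has the STRING type that triggers the error.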

dj-shin-okcupid commented 2 years ago

@lesters Thank you for the information. For the time being, a Document Processor that converts strings to numeric tensors, as described in this sample app, looks like a possible workaround (https://github.com/vespa-engine/sample-apps/tree/master/dense-passage-retrieval-with-ann). Are there any references on using an ONNX model inside a custom document processor?

Re: the Keras layer, my dev environment is currently on an older version of TF. As of TF v2.8, this layer is a part of tf.keras.layers and is no longer experimental.
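
The same string-to-tensor conversion a Document Processor would perform can, alternatively, be sketched client-side before feeding. This Python sketch builds a Vespa JSON put operation whose tensor field holds already-looked-up ids in the `cells`/`address` form of Vespa's tensor JSON. The namespace (`mynamespace`), document type (`mydoctype`), field name (`x_ids`) and dimension name (`d0`) are all hypothetical; the schema would need a matching field, for example `tensor<float>(d0[3])`.

```python
import json
from typing import Dict, List, Sequence


def to_tensor_cells(ids: Sequence[int], dim: str = "d0") -> Dict[str, List[dict]]:
    """Encode a 1-d sequence of ids in the 'cells' form of Vespa's tensor JSON."""
    return {"cells": [{"address": {dim: str(i)}, "value": float(v)} for i, v in enumerate(ids)]}


def put_operation(doc_id: str, ids: Sequence[int]) -> dict:
    """Build a document put; namespace, document type and field name are hypothetical."""
    return {
        "put": f"id:mynamespace:mydoctype::{doc_id}",
        "fields": {"x_ids": to_tensor_cells(ids)},
    }


op = put_operation("1", [2, 4, 1])
print(json.dumps(op))
```

Doing the conversion before feeding keeps the model inputs numeric, at the cost of keeping the vocabulary in sync between the feeding client and the model; a server-side Document Processor centralizes that mapping instead.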

lesters commented 2 years ago

There is the model-evaluation (https://github.com/vespa-engine/sample-apps/tree/master/model-evaluation) sample app, which is a bit more concrete. For an example of a document processor that uses an ONNX model to create embeddings:

https://github.com/vespa-engine/sample-apps/blob/master/model-evaluation/src/main/java/ai/vespa/example/MyDocumentProcessor.java

and a unit test for that one:

https://github.com/vespa-engine/sample-apps/blob/master/model-evaluation/src/test/java/ai/vespa/example/MyDocumentProcessorTest.java

jobergum commented 2 years ago

Were you able to make progress on this, @dj-shin-okcupid?

dj-shin-okcupid commented 2 years ago

This has been left in the backlog as is for now. LightGBM seems easier to work with since it readily supports string-valued attributes as features.