dj-shin-okcupid opened this issue 2 years ago (status: Open)
So, Vespa's interface to model evaluation is through tensors, and Vespa's tensors do not currently support strings. That is the cause of the error here. DL models that take strings as input are frankly something we haven't considered yet, as there is usually a tokenizer in front of the actual DL model (I see this Keras layer is experimental).
This is something we will have to discuss how to address; unfortunately, there is no quick fix on our side.
However, as preprocessing as a step in model evaluation matures, this is something we will need to address soon.
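The usual pattern described above, where a tokenizer runs before the model so that the model only ever sees numbers, can be sketched in plain Python. This is a minimal illustration rather than Vespa code: the whitespace tokenizer, the reserved ids `PAD_ID` and `UNK_ID`, and the helpers `build_vocab` and `encode` are hypothetical names chosen for this example.

```python
from typing import Dict, List

PAD_ID = 0  # hypothetical: id reserved for padding
UNK_ID = 1  # hypothetical: id reserved for out-of-vocabulary tokens


def build_vocab(corpus: List[str]) -> Dict[str, int]:
    """Assign ids to tokens in order of first appearance, after the reserved ids."""
    vocab: Dict[str, int] = {}
    for text in corpus:
        for token in text.lower().split():
            if token not in vocab:
                vocab[token] = len(vocab) + 2  # ids 0 and 1 are reserved
    return vocab


def encode(text: str, vocab: Dict[str, int], max_len: int) -> List[int]:
    """Map a string to a fixed-length list of integer ids (truncate, then pad)."""
    ids = [vocab.get(token, UNK_ID) for token in text.lower().split()][:max_len]
    return ids + [PAD_ID] * (max_len - len(ids))


if __name__ == "__main__":
    vocab = build_vocab(["red shoes", "blue shoes"])
    print(encode("red boots", vocab, max_len=4))  # [2, 1, 0, 0]
```

Because the output is a fixed-length list of integers, it can be carried in a numeric tensor instead of being passed to the model as a string.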
@lesters Thank you for the information. For the time being, a Document Processor that converts strings to numeric tensors, as described in this sample app, looks like a possible workaround (https://github.com/vespa-engine/sample-apps/tree/master/dense-passage-retrieval-with-ann). Are there any references that might help with using something like an ONNX model in a custom document processor?
Re: the Keras layer, my dev environment is currently on an older version of TF. As of TF v2.8, this layer is part of tf.keras.layers and is no longer experimental.
There is the model-evaluation sample app (https://github.com/vespa-engine/sample-apps/tree/master/model-evaluation), which is a bit more concrete. For an example of a document processor using an ONNX model to create embeddings:
and a unit test for that one:
Were you able to make progress on this, @dj-shin-okcupid?
This has been left in the backlog as is for now. LightGBM seems easier to work with since it readily supports string-valued attributes as features.
Describe the bug
Vespa is unable to build an application package containing an ONNX model that uses a TensorFlow StringLookup layer. My use case involves feeding string values directly into a deep learning model, which is where the StringLookup layer comes in handy.
To Reproduce
Steps to reproduce the behavior:
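```python
import pandas as pd
import tensorflow as tf
# Experimental location of the layer in older TF versions
from tensorflow.keras.layers.experimental import preprocessing

sample_data = pd.DataFrame.from_dict({'x': ['a', 'a', 'a', 'b', 'a', 'c']})
vocab = list(sample_data['x'].unique())

# One string per example; the pandas object dtype corresponds to tf.string
x = tf.keras.layers.Input(name='x', shape=(1,), dtype=tf.string)
y = preprocessing.StringLookup(
    num_oov_indices=1,
    vocabulary=vocab,
    mask_token="MASK",
    oov_token="UNK",
    trainable=False)(x)

model = tf.keras.Model(inputs=x, outputs=y)
model.save('string_lookup_test')  # TF SavedModel; the ONNX conversion step is not shown
```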
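One way to follow the workaround discussed in this thread is to do the lookup outside the model, so that the exported ONNX model receives integer indices instead of a STRING input. A minimal pure-Python sketch, assuming StringLookup places the mask token at index 0, the single OOV bucket at index 1 and the vocabulary from index 2 (verify this layout against your TF version); `make_lookup` and `lookup` are hypothetical helper names:

```python
from typing import Dict, Iterable, List

MASK_TOKEN = "MASK"  # mask_token used in the reproduction
MASK_INDEX = 0       # assumed: the mask token is placed first
OOV_INDEX = 1        # assumed: num_oov_indices=1 gives one OOV bucket right after the mask


def make_lookup(vocabulary: Iterable[str]) -> Dict[str, int]:
    """Build a string -> index table: mask, then OOV bucket, then the vocabulary."""
    table = {MASK_TOKEN: MASK_INDEX}
    for offset, token in enumerate(vocabulary):
        table[token] = OOV_INDEX + 1 + offset
    return table


def lookup(values: Iterable[str], table: Dict[str, int]) -> List[int]:
    """Map each string to its index; unknown strings fall into the OOV bucket."""
    return [table.get(value, OOV_INDEX) for value in values]


if __name__ == "__main__":
    raw = ["a", "a", "a", "b", "a", "c"]
    # Keep first-appearance order, like pd.Series.unique() in the reproduction
    vocab = list(dict.fromkeys(raw))
    table = make_lookup(vocab)
    print(lookup(["a", "c", "z", "MASK"], table))  # [2, 4, 1, 0]
```

The model itself would then take an integer input instead of a string. Running this mapping in a document processor or searcher, as suggested earlier in the thread, is an assumption about how it would fit a given application.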
[INFO] Scanning for projects...
Downloading from central: https://repo.maven.apache.org/maven2/com/yahoo/vespa/cloud-tenant-base/maven-metadata.xml
Downloaded from central: https://repo.maven.apache.org/maven2/com/yahoo/vespa/cloud-tenant-base/maven-metadata.xml (8.5 kB at 31 kB/s)
[INFO]
[INFO] ------------------< ai.vespa.example:model-inference >------------------
[INFO] Building model-inference 1.0.1
[INFO] --------------------------[ container-plugin ]--------------------------
Downloading from central: https://repo.maven.apache.org/maven2/org/apache/maven/plugins/maven-resources-plugin/maven-metadata.xml
Downloaded from central: https://repo.maven.apache.org/maven2/org/apache/maven/plugins/maven-resources-plugin/maven-metadata.xml (874 B at 30 kB/s)
Downloading from central: https://repo.maven.apache.org/maven2/org/apache/maven/plugins/maven-install-plugin/maven-metadata.xml
Downloaded from central: https://repo.maven.apache.org/maven2/org/apache/maven/plugins/maven-install-plugin/maven-metadata.xml (663 B at 9.6 kB/s)
Downloading from central: https://repo.maven.apache.org/maven2/org/apache/maven/plugins/maven-deploy-plugin/maven-metadata.xml
Downloaded from central: https://repo.maven.apache.org/maven2/org/apache/maven/plugins/maven-deploy-plugin/maven-metadata.xml (783 B at 31 kB/s)
[INFO]
[INFO] --- maven-clean-plugin:2.5:clean (default-clean) @ model-inference ---
[INFO] Deleting /home/djshin/dev/myapp/target
[INFO]
[INFO] --- maven-enforcer-plugin:3.0.0-M2:enforce (enforce-java) @ model-inference ---
[INFO]
[INFO] --- maven-enforcer-plugin:3.0.0-M2:enforce (enforce-no-log4j) @ model-inference ---
[INFO]
[INFO] --- bundle-plugin:7.561.60:generateSources (default-generateSources) @ model-inference ---
[INFO]
[INFO] --- maven-resources-plugin:3.2.0:resources (default-resources) @ model-inference ---
[INFO] Using 'UTF-8' encoding to copy filtered resources.
[INFO] Using 'UTF-8' encoding to copy filtered properties files.
[INFO] skip non existing resourceDirectory /home/djshin/dev/myapp/src/main/resources
[INFO]
[INFO] --- maven-compiler-plugin:3.8.1:compile (default-compile) @ model-inference ---
[INFO] Changes detected - recompiling the module!
[INFO] Compiling 5 source files to /home/djshin/dev/myapp/target/classes
[INFO]
[INFO] --- maven-resources-plugin:3.2.0:testResources (default-testResources) @ model-inference ---
[INFO] Using 'UTF-8' encoding to copy filtered resources.
[INFO] Using 'UTF-8' encoding to copy filtered properties files.
[INFO] skip non existing resourceDirectory /home/djshin/dev/myapp/src/test/resources
[INFO]
[INFO] --- maven-compiler-plugin:3.8.1:testCompile (default-testCompile) @ model-inference ---
[INFO] Changes detected - recompiling the module!
[INFO] Compiling 4 source files to /home/djshin/dev/myapp/target/test-classes
[INFO]
[INFO] --- maven-surefire-plugin:2.22.0:test (default-test) @ model-inference ---
[INFO]
[INFO] -------------------------------------------------------
[INFO] T E S T S
[INFO] -------------------------------------------------------
[INFO] Running ai.vespa.example.MySearcherTest
[ERROR] Tests run: 1, Failures: 0, Errors: 1, Skipped: 0, Time elapsed: 0.321 s <<< FAILURE! - in ai.vespa.example.MySearcherTest
[ERROR] testMySearcher Time elapsed: 0.318 s <<< ERROR!
java.lang.IllegalArgumentException: A ONNX tensor with data type STRING cannot be converted to a Vespa tensor type
at com.yahoo.vespa.model.ml.OnnxModelInfo.onnxValueTypeToString(OnnxModelInfo.java:276)
at com.yahoo.vespa.model.ml.OnnxModelInfo.onnxTypeToJson(OnnxModelInfo.java:231)
at com.yahoo.vespa.model.ml.OnnxModelInfo.onnxModelToJson(OnnxModelInfo.java:190)
at com.yahoo.vespa.model.ml.OnnxModelInfo.loadFromFile(OnnxModelInfo.java:151)
at com.yahoo.vespa.model.ml.OnnxModelInfo.load(OnnxModelInfo.java:129)
at com.yahoo.vespa.model.VespaModel.loadOnnxModelInfo(VespaModel.java:362)
at com.yahoo.vespa.model.VespaModel.onnxModelInfoFromSource(VespaModel.java:343)
at com.yahoo.vespa.model.VespaModel.createGlobalRankProfiles(VespaModel.java:298)
at com.yahoo.vespa.model.VespaModel.<init>(VespaModel.java:183)
at com.yahoo.vespa.model.VespaModel.<init>(VespaModel.java:167)
at com.yahoo.vespa.model.VespaModel.<init>(VespaModel.java:145)
at com.yahoo.vespa.model.container.ml.ModelsEvaluatorTester.createRankProfileList(ModelsEvaluatorTester.java:113)
at com.yahoo.vespa.model.container.ml.ModelsEvaluatorTester.create(ModelsEvaluatorTester.java:76)
at ai.vespa.example.MySearcherTest.testMySearcher(MySearcherTest.java:24)
at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.base/java.lang.reflect.Method.invoke(Method.java:566)
at org.junit.platform.commons.util.ReflectionUtils.invokeMethod(ReflectionUtils.java:725)
at org.junit.jupiter.engine.execution.MethodInvocation.proceed(MethodInvocation.java:60)
at org.junit.jupiter.engine.execution.InvocationInterceptorChain$ValidatingInvocation.proceed(InvocationInterceptorChain.java:131)
at org.junit.jupiter.engine.extension.TimeoutExtension.intercept(TimeoutExtension.java:149)
at org.junit.jupiter.engine.extension.TimeoutExtension.interceptTestableMethod(TimeoutExtension.java:140)
at org.junit.jupiter.engine.extension.TimeoutExtension.interceptTestMethod(TimeoutExtension.java:84)
at org.junit.jupiter.engine.execution.ExecutableInvoker$ReflectiveInterceptorCall.lambda$ofVoidMethod$0(ExecutableInvoker.java:115)
at org.junit.jupiter.engine.execution.ExecutableInvoker.lambda$invoke$0(ExecutableInvoker.java:105)
at org.junit.jupiter.engine.execution.InvocationInterceptorChain$InterceptedInvocation.proceed(InvocationInterceptorChain.java:106)
at org.junit.jupiter.engine.execution.InvocationInterceptorChain.proceed(InvocationInterceptorChain.java:64)
at org.junit.jupiter.engine.execution.InvocationInterceptorChain.chainAndInvoke(InvocationInterceptorChain.java:45)
at org.junit.jupiter.engine.execution.InvocationInterceptorChain.invoke(InvocationInterceptorChain.java:37)
at org.junit.jupiter.engine.execution.ExecutableInvoker.invoke(ExecutableInvoker.java:104)
at org.junit.jupiter.engine.execution.ExecutableInvoker.invoke(ExecutableInvoker.java:98)
at org.junit.jupiter.engine.descriptor.TestMethodTestDescriptor.lambda$invokeTestMethod$7(TestMethodTestDescriptor.java:214)
at org.junit.platform.engine.support.hierarchical.ThrowableCollector.execute(ThrowableCollector.java:73)
at org.junit.jupiter.engine.descriptor.TestMethodTestDescriptor.invokeTestMethod(TestMethodTestDescriptor.java:210)
at org.junit.jupiter.engine.descriptor.TestMethodTestDescriptor.execute(TestMethodTestDescriptor.java:135)
at org.junit.jupiter.engine.descriptor.TestMethodTestDescriptor.execute(TestMethodTestDescriptor.java:66)
at org.junit.platform.engine.support.hierarchical.NodeTestTask.lambda$executeRecursively$6(NodeTestTask.java:151)
at org.junit.platform.engine.support.hierarchical.ThrowableCollector.execute(ThrowableCollector.java:73)
at org.junit.platform.engine.support.hierarchical.NodeTestTask.lambda$executeRecursively$8(NodeTestTask.java:141)
at org.junit.platform.engine.support.hierarchical.Node.around(Node.java:137)
at org.junit.platform.engine.support.hierarchical.NodeTestTask.lambda$executeRecursively$9(NodeTestTask.java:139)
at org.junit.platform.engine.support.hierarchical.ThrowableCollector.execute(ThrowableCollector.java:73)
at org.junit.platform.engine.support.hierarchical.NodeTestTask.executeRecursively(NodeTestTask.java:138)
at org.junit.platform.engine.support.hierarchical.NodeTestTask.execute(NodeTestTask.java:95)
at java.base/java.util.ArrayList.forEach(ArrayList.java:1541)
at org.junit.platform.engine.support.hierarchical.SameThreadHierarchicalTestExecutorService.invokeAll(SameThreadHierarchicalTestExecutorService.java:41)
at org.junit.platform.engine.support.hierarchical.NodeTestTask.lambda$executeRecursively$6(NodeTestTask.java:155)
at org.junit.platform.engine.support.hierarchical.ThrowableCollector.execute(ThrowableCollector.java:73)
at org.junit.platform.engine.support.hierarchical.NodeTestTask.lambda$executeRecursively$8(NodeTestTask.java:141)
at org.junit.platform.engine.support.hierarchical.Node.around(Node.java:137)
at org.junit.platform.engine.support.hierarchical.NodeTestTask.lambda$executeRecursively$9(NodeTestTask.java:139)
at org.junit.platform.engine.support.hierarchical.ThrowableCollector.execute(ThrowableCollector.java:73)
at org.junit.platform.engine.support.hierarchical.NodeTestTask.executeRecursively(NodeTestTask.java:138)
at org.junit.platform.engine.support.hierarchical.NodeTestTask.execute(NodeTestTask.java:95)
at java.base/java.util.ArrayList.forEach(ArrayList.java:1541)
at org.junit.platform.engine.support.hierarchical.SameThreadHierarchicalTestExecutorService.invokeAll(SameThreadHierarchicalTestExecutorService.java:41)
at org.junit.platform.engine.support.hierarchical.NodeTestTask.lambda$executeRecursively$6(NodeTestTask.java:155)
...