microsoft / onnxruntime

ONNX Runtime: cross-platform, high performance ML inferencing and training accelerator
https://onnxruntime.ai
MIT License

<OnnxValue>.getValue() returns non-parseable java object #19440

Open jazzblue opened 8 months ago

jazzblue commented 8 months ago

Describe the issue

This is probably the same issue as https://github.com/microsoft/onnxruntime/issues/16781.

I am using ONNX Runtime to serve a scikit-learn trained model from Java code. The output is returned as an OnnxValue object and I call getValue() to retrieve the output value. As per the [API documentation](https://onnxruntime.ai/docs/api/java/ai/onnxruntime/OnnxValue.html#getValue()) it is supposed to return the value as a Java object, and I understand I should be able to extract the primitive value, such as a float or an array. At least for OnnxTensor, the [API doc](https://onnxruntime.ai/docs/api/java/ai/onnxruntime/OnnxTensor.html#getValue()) says it "either returns a boxed primitive if the Tensor is a scalar, or a multidimensional array of primitives if it has multiple dimensions".

Logging the type with getType() shows the expected type: OnnxTensor(info=TensorInfo(javaType=INT64,onnxType=ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64,shape=[1])). However, casting the returned object to long or int throws an exception, and I do not see any other method or way to get the primitive or the array. How would I extract the value from the Java object?
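Condensed, the failing step looks like this (`results` is the `OrtSession.Result` returned by `session.run` in the full reproduction below; `value` is just a local variable introduced for illustration):

```java
Object value = results.get("output_label").get().getValue();
System.out.println(value.getClass());  // prints the runtime type of the returned object
long label = (long) value;             // compiles, but fails at runtime with a ClassCastException
```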

To reproduce

  1. Java version:
    openjdk version "11.0.21" 2023-10-17
    OpenJDK Runtime Environment (build 11.0.21+9-post-Ubuntu-0ubuntu122.04)
    OpenJDK 64-Bit Server VM (build 11.0.21+9-post-Ubuntu-0ubuntu122.04, mixed mode)
  2. Directory tree:
    |-- pom.xml
    |-- src
    |   `-- main
    |       |-- java
    |       |   `-- onnx
    |       |       `-- example
    |       |           `-- OnnxRf.java
    |       `-- resources
    |           `-- rf_iris.onnx
  3. Training script (in Python). Install the converter first with `pip install skl2onnx`:
    
    import numpy as np
    from sklearn.datasets import load_iris
    from sklearn.model_selection import train_test_split
    from sklearn.ensemble import RandomForestClassifier

    iris = load_iris()
    X, y = iris.data, iris.target
    X = X.astype(np.float32)
    X_train, X_test, y_train, y_test = train_test_split(X, y)
    clr = RandomForestClassifier()
    clr.fit(X_train, y_train)

Convert into ONNX format.

    from skl2onnx import to_onnx

    onx = to_onnx(clr, X)
    with open("rf_iris.onnx", "wb") as f:
        f.write(onx.SerializeToString())

Save the exported ONNX model `rf_iris.onnx` under `src/main/resources`.
4. File `pom.xml`
```xml
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>onnx.example</groupId>
    <artifactId>onnx-example</artifactId>
    <version>1.0</version>

    <properties>
        <maven.compiler.source>11</maven.compiler.source>
        <maven.compiler.target>11</maven.compiler.target>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <main.class>onnx.example.OnnxRf</main.class>
        <onnxruntime.version>1.16.3</onnxruntime.version>
    </properties>
    <dependencies>
        <dependency>
            <groupId>com.microsoft.onnxruntime</groupId>
            <artifactId>onnxruntime</artifactId>
            <version>${onnxruntime.version}</version>
        </dependency>
    </dependencies>
    <build>
        <plugins>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-shade-plugin</artifactId>
                <version>3.1.1</version>
                <executions>
                    <execution>
                        <phase>package</phase>
                        <goals>
                            <goal>shade</goal>
                        </goals>
                        <configuration>
                            <artifactSet>
                                <excludes>
                                    <exclude>com.google.code.findbugs:jsr305</exclude>
                                </excludes>
                            </artifactSet>
                            <filters>
                                <filter>
                                    <!-- Do not copy the signatures in the META-INF folder.
                                    Otherwise, this might cause SecurityExceptions when using the JAR. -->
                                    <artifact>*:*</artifact>
                                    <excludes>
                                        <exclude>META-INF/*.SF</exclude>
                                        <exclude>META-INF/*.DSA</exclude>
                                        <exclude>META-INF/*.RSA</exclude>
                                    </excludes>
                                </filter>
                            </filters>
                            <transformers>
                                <transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
                                    <!-- Replace this with the main class of your job -->
                                    <mainClass>onnx.example.OnnxRf</mainClass>
                                </transformer>
                                <transformer implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer"/>
                            </transformers>
                        </configuration>
                    </execution>
                </executions>
            </plugin>
        </plugins>
    </build>
</project>
```
5. File `OnnxRf.java`
    
    package onnx.example;

import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OnnxValue;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;

import java.util.*;

public class OnnxRf {

public static void main(String[] args) throws Exception {

    OrtSession session = null;
    OrtEnvironment env;
    Set<String> outputNames;

    Map<String, OnnxTensor> inputs = new HashMap<>();
    Iterator<String> namesIterator;
    String inFieldName;
    float[][] inputShaped;
    inputShaped = new float[1][1];

    env = OrtEnvironment.getEnvironment();
    session = env.createSession("src/main/resources/rf_iris.onnx", new OrtSession.SessionOptions());
    outputNames = session.getOutputNames();

    float[] inVal = new float[] {1.1f, 2.3f, 3.4f, 5.6f};
    inputShaped[0] = inVal;

    inputs.put("X", OnnxTensor.createTensor(env, inputShaped));
    System.out.println ("outputNames: " + outputNames);

    try (var results = session.run(inputs, outputNames)) {
        System.out.println ("----- output types -----");
        for (String fieldName : outputNames) {
            System.out.println(fieldName + ": " + results.get(fieldName).get() + " : " + results.get(fieldName).get().getType());
        }
        System.out.println ("----------");
        System.out.println("output_label, class: " + results.get("output_label").get().getValue().getClass());
        System.out.println("output_label, str: " + results.get("output_label").get().getValue().toString());
        System.out.println("output_label, getType: " + results.get("output_label").get().getType());
        System.out.println("output_label, getInfo: " + results.get("output_label").get().getInfo());

        // Trying to cast to long here since output_label output type is shown as INT64, but int did not work either
        System.out.println("output_label, long: " + (long) results.get("output_label").get().getValue());

    } catch (OrtException e) {
        // e.printStackTrace();
        System.out.println (">>>>>" + e.getCode() + ": " + e.getMessage());
    }
}

}

6. Build Java package

mvn package

7. Run the Java inference code using the ONNX model

java -jar target/onnx-example-1.0.jar



### Urgency

_No response_

### Platform

Windows

### OS Version

Ubuntu 22.04.3 LTS

### ONNX Runtime Installation

Built from Source

### ONNX Runtime Version or Commit ID

1.16.3

### ONNX Runtime API

Python

### Architecture

X64

### Execution Provider

Default CPU

### Execution Provider Library Version

_No response_
Craigacp commented 8 months ago

Cast it to `long[]`, not `long`. You can always reflectively inspect the type of the object returned by `getValue`, e.g. `results.get("output_label").get().getValue().getClass()` will return `long[].class`. Scalars have shape `[]`, whereas this model produces something of shape `[batch_size]`, so while there is only a single element in this case, it's a 1-d vector, which we return as a 1-d array.
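For example, a minimal sketch that slots into the `try` block of the reproduction above (the output name `output_label` and the single-row batch come from that program):

```java
// output_label has shape [batch_size], so getValue() returns a long[]
long[] labels = (long[]) results.get("output_label").get().getValue();
// with a single input row there is exactly one predicted class label
System.out.println("output_label, long: " + labels[0]);
```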

jazzblue commented 8 months ago

> Cast it to `long[]`, not `long`. You can always reflectively inspect the type of the object returned by `getValue`, e.g. `results.get("output_label").get().getValue().getClass()` will return `long[].class`. Scalars have shape `[]`, whereas this model produces something of shape `[batch_size]`, so while there is only a single element in this case, it's a 1-d vector, which we return as a 1-d array.

@Craigacp casting to long[] worked, thanks!

github-actions[bot] commented 7 months ago

This issue has been automatically marked as stale due to inactivity and will be closed in 30 days if no further activity occurs. If further support is needed, please provide an update and/or more details.