jazzblue opened this issue 8 months ago
Cast it to `long[]`, not `long`. You can always reflectively inspect the type of the object returned by `getValue`, e.g. `results.get("output_label").get().getValue().getClass()` will return `long[].class`. Scalars have shape `[]`, whereas this model produces something of shape `[batch_size]`, so while there is only a single element in this case, it's a 1d vector, which we return as a 1d array.
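A minimal, self-contained sketch of the casting rules described above. The ONNX Runtime classes are not needed to show them, so a plain `long[]` stands in for the value returned by `getValue()`. This assumes `batch_size` = 1 and an arbitrary predicted label of 2. The helper `asLongVector` is hypothetical and is not part of the ONNX Runtime API.

```java
final class OutputLabelCast {
    // Hypothetical helper (not part of ONNX Runtime): unwraps the Object returned by
    // getValue() for an INT64 output such as "output_label".
    static long[] asLongVector(Object value) {
        if (value instanceof long[]) {
            // Non-scalar output, e.g. shape [batch_size]: already a primitive array.
            return (long[]) value;
        }
        if (value instanceof Long) {
            // Scalar output (shape []): boxed Long, wrapped here as a one-element vector.
            return new long[] {(Long) value};
        }
        String type = (value == null) ? "null" : value.getClass().getSimpleName();
        throw new IllegalArgumentException("Unexpected output type: " + type);
    }

    public static void main(String[] args) {
        // Stand-in for results.get("output_label").get().getValue(), assuming batch_size = 1
        // and an arbitrary predicted label of 2.
        Object value = new long[] {2L};
        System.out.println(value.getClass() == long[].class); // prints "true"

        try {
            long label = (long) value; // compiles (Object -> Long -> long), fails at run time
            System.out.println(label);
        } catch (ClassCastException e) {
            System.out.println("ClassCastException"); // this branch runs
        }

        long label = asLongVector(value)[0];
        System.out.println(label); // prints "2"
    }
}
```

Accepting both `Long` and `long[]` keeps the caller working whether a model emits a scalar or a `[batch_size]` vector. For a batch larger than one, iterate over the returned array instead of reading only element 0.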
@Craigacp casting to `long[]` worked, thanks!
This issue has been automatically marked as stale due to inactivity and will be closed in 30 days if no further activity occurs. If further support is needed, please provide an update and/or more details.
Describe the issue
This is probably the same issue as https://github.com/microsoft/onnxruntime/issues/16781.
I am using ONNX Runtime to serve a model trained with scikit-learn inside Java code. The output is returned as an `OnnxValue` object, and I call `getValue()` to retrieve the output value. According to the [API documentation](https://onnxruntime.ai/docs/api/java/ai/onnxruntime/OnnxValue.html#getValue()), it returns the value as a Java object, so I understand I should be able to extract a primitive value, such as a `float`, or an array. For `OnnxTensor` in particular, the [API doc](https://onnxruntime.ai/docs/api/java/ai/onnxruntime/OnnxTensor.html#getValue()) says: "Either returns a boxed primitive if the Tensor is a scalar, or a multidimensional array of primitives if it has multiple dimensions."
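The documented rule can be made concrete by computing which Java class `getValue()` should return for an INT64 tensor of a given rank. This sketch assumes that any non-scalar shape, including a one-dimensional one, maps to a primitive array with one level per dimension. That is consistent with the `long[]` result for `shape=[1]` in this thread. The class and method names are hypothetical, and the shape literal `{1}` is copied from the logged `TensorInfo` rather than read from a live session.

```java
import java.lang.reflect.Array;

final class Int64ValueTypes {
    // Hypothetical helper: the Java class getValue() is expected to return for an INT64
    // tensor of the given rank (number of dimensions in the shape).
    static Class<?> expectedInt64Class(int rank) {
        if (rank < 0) {
            throw new IllegalArgumentException("rank must be non-negative: " + rank);
        }
        if (rank == 0) {
            return Long.class; // scalar, shape [] -> boxed primitive
        }
        // An empty array with `rank` dimensions has exactly the class we want,
        // e.g. long[] for rank 1 and long[][] for rank 2.
        return Array.newInstance(long.class, new int[rank]).getClass();
    }

    public static void main(String[] args) {
        long[] shape = {1}; // shape=[1] as logged in the issue
        Class<?> expected = expectedInt64Class(shape.length);
        System.out.println(expected.getSimpleName()); // prints "long[]"
    }
}
```

The point the sketch makes is that the rank, not the element count, decides the Java type. `shape=[1]` has one dimension holding one element, so the value is a `long[]` of length 1, not a `Long`.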
Logging the returned tensor shows the expected type information: `OnnxTensor(info=TensorInfo(javaType=INT64,onnxType=ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64,shape=[1]))`. However, casting the value to `long` or `int` throws an exception, and I do not see any other method to get the primitive or the array. How do I extract the value from the Java object?
To reproduce
pip install skl2onnx
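import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

# Train a model.
iris = load_iris()
X, y = iris.data, iris.target
X = X.astype(np.float32)
X_train, X_test, y_train, y_test = train_test_split(X, y)
clr = RandomForestClassifier()
clr.fit(X_train, y_train)

# Convert into ONNX format.
from skl2onnx import to_onnx
onx = to_onnx(clr, X)
with open("rf_iris.onnx", "wb") as f:
    f.write(onx.SerializeToString())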
OnnxRf.java
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OnnxValue;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import java.util.*;
public class OnnxRf {
}
mvn package
java -jar target/onnx-example-1.0.jar