[Open] paranjapeved15 opened this issue 1 year ago
`[J@ba1f559` is how an array reference is printed if you don't use `Arrays.toString()`. The `getValue()` call is returning an array of longs representing the predicted classes; for that input it's probably a single-element array, but it's still an array.
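To see the difference in isolation (plain Java, no ONNX Runtime required), the snippet below prints a `long[]` both ways. `[J` is the JVM's type descriptor for "array of long", followed by an identity hash code, which is why the raw `toString()` looks like gibberish:

```java
import java.util.Arrays;

public class ArrayPrinting {
    public static void main(String[] args) {
        long[] classes = {2L};  // e.g. a single predicted class label

        // Default toString(): type descriptor plus identity hash, e.g. [J@ba1f559
        System.out.println(classes.toString());

        // Arrays.toString() renders the actual elements.
        System.out.println(Arrays.toString(classes));  // prints [2]
    }
}
```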
@Craigacp thanks for your reply. I tried:

```java
session.run(inputs).get(0).getValue().toString()
session.run(inputs).get(1).getValue().toString()
```

but it still gives me similar results. Can you please show me in code how to get the actual long/float values?
It's a regular Java array, so you can get the String representation of the array values with:

```java
try (OrtSession.Result r = session.run(inputs)) {
    OnnxValue first = r.get(0);
    long[] classes = (long[]) first.getValue();
    System.out.println(Arrays.toString(classes));
}
```
You should ensure that all input and output values are closed, as otherwise you'll leak memory on the native heap.
There is a complete test example using MNIST here - https://github.com/microsoft/onnxruntime/blob/main/java/src/test/java/sample/ScoreMNIST.java#L259, and there are plenty of test cases showing how to get data into and out of ONNX Runtime. Note the MNIST example uses arrays, but that's not the most efficient way of using ORT: multidimensional arrays in Java are collections of pointers, so they have poor locality and cache behaviour. It's therefore better to use `java.nio.Buffer`s as the input and output types when creating and accessing `OnnxTensor` objects. There are examples of doing that in the tests. We're working on a complete example application showing how to use ORT effectively from Java, but it's not finished yet.
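A minimal stdlib-only sketch of the layout work this implies: flattening a 2D Java array into a direct `FloatBuffer` in row-major order. The `OnnxTensor.createTensor(env, buf, shape)` call in the comment is how I'd expect the buffer to be handed to ORT, but treat it as an assumption to check against the `OnnxTensor` javadoc:

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;

public class BufferInput {
    public static void main(String[] args) {
        // A 2x4 batch of iris-style features as a Java 2D array.
        float[][] batch = {
            {5.1f, 3.5f, 1.4f, 0.2f},
            {6.2f, 2.9f, 4.3f, 1.3f}
        };
        long[] shape = {batch.length, batch[0].length};

        // Direct buffers let native code read the data without an extra
        // copy; tensor data is stored flat in row-major order.
        FloatBuffer buf = ByteBuffer
            .allocateDirect(batch.length * batch[0].length * Float.BYTES)
            .order(ByteOrder.nativeOrder())
            .asFloatBuffer();
        for (float[] row : batch) {
            buf.put(row);  // rows laid out one after another
        }
        buf.rewind();

        // With ONNX Runtime on the classpath, this buffer plus the shape
        // could then be wrapped into a tensor, e.g.
        // OnnxTensor.createTensor(env, buf, shape)  -- assumed API, verify.
        System.out.println(buf.get(4));  // prints 6.2, first element of row 1
    }
}
```

The same pattern works in reverse for outputs: reading the result as a flat buffer and indexing with `row * rowLength + col` avoids allocating a fresh multidimensional array per inference.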
Great, thanks a lot @Craigacp! I am eagerly waiting for the complete example. Any ETAs on that? I am working on migrating from using jpmml to onnx for our ML models and would be glad to use the document.
No ETA yet, the code is complete but there are some other processes I need to work through before it can be made public.
One more question @Craigacp. When I run inference on the above model like:

```java
OrtSession.Result output = session.run(inputs);
System.out.println(output.get(1));
```

it gives me an object of type `OnnxTensor`. But another model I have gives an object of type `OnnxSequence`. When I call `getValue()` on that, it gives me an `OnnxMap`, which I don't know how to traverse. I think it would be helpful to document the nuances of the result types, how to parse each type of result, etc.
How do I parse an `OnnxMap`, anyway?
I recommend you have a look through the javadoc. Each `OnnxValue` has an associated `ValueInfo` which describes the type & shape of the value. The `getValue()` method on tensors and maps returns basic Java objects, and `OnnxSequence.getValue()` returns a list of `OnnxTensor` or `OnnxMap`. It's the same structures you get out of the Python or C# ORT APIs, just mapped into Java types. The scikit-learn converters produce this combination of tensor and sequence outputs; in deep learning models you're much more likely to just get tensors out. You can see an example of parsing it into another structure here, though that's for a slightly older version of ONNX Runtime, before a recent change to how `OnnxSequence` works.
Parsing the outputs is ONNX model dependent, but the outputs are always made up of `OnnxTensor`, `OnnxSparseTensor`, `OnnxSequence` and `OnnxMap`, and those classes document the possible values they can contain.
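Once a map has been unpacked into plain Java objects, traversing it is ordinary `Map` iteration. The sketch below uses a stand-in `java.util.Map<Long, Float>` (class label to score, the shape a scikit-learn classifier's output would plausibly take after `getValue()`; the exact key/value types depend on the model, so check the `ValueInfo`):

```java
import java.util.LinkedHashMap;
import java.util.Map;

public class MapTraversal {
    public static void main(String[] args) {
        // Stand-in for an unpacked classifier output: label -> probability.
        Map<Long, Float> scores = new LinkedHashMap<>();
        scores.put(0L, 0.10f);
        scores.put(1L, 0.75f);
        scores.put(2L, 0.15f);

        // Pick the highest-scoring class.
        long bestClass = -1;
        float bestScore = Float.NEGATIVE_INFINITY;
        for (Map.Entry<Long, Float> e : scores.entrySet()) {
            if (e.getValue() > bestScore) {
                bestScore = e.getValue();
                bestClass = e.getKey();
            }
        }
        System.out.println("predicted class = " + bestClass);  // prints predicted class = 1
    }
}
```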
This is very helpful @Craigacp! However, it would be very useful to provide these code examples as tutorials on the official page https://onnxruntime.ai/docs/get-started/with-java.html#sample. I feel the documentation and tutorials for ONNX Java (at least) have much room to grow before wider adoption by the ML community. I hope your complete example will cover all these nuances :)
I agree that more tutorials would be useful, but the time I can put into ORT goes to maintaining the Java API, and I haven't got any spare bandwidth for writing tutorials. When the example code is released it will focus on efficient use of `OnnxTensor` and memory management, rather than maps or sequences, and we will keep it updated with best practices for ORT in Java. In general, as I maintain both projects, Tribuo will show how to use most ORT features, though not necessarily in the most efficient way for other projects, since Tribuo imposes a bunch of restrictions on how libraries can be used.
@Craigacp do you recommend trying out Tribuo wrappers over the raw onnx runtime Java API?
Tribuo is good if you want to build ML models, but also have an ONNX model accessible through the same interface. It's focused on traditional ML & NLP use cases, and less optimal for large dense inputs like images or text embedding vectors. So if you want to work on vision or transformer based NLP tasks exclusively then you should use ONNX Runtime directly as it's lower overhead.
Got it, thanks @Craigacp. I am currently working with an XGBoost model exported in ONNX format, so I guess I will give Tribuo a try.
Tribuo can directly load in an XGBoost model trained in Python without exporting it to ONNX first if that is useful.
Describe the issue
Hi, I am getting a wrong result when I try to run the Java code below. I get a gibberish result like `[J@ba1f559`. Am I messing up something in my code?
To reproduce
Inference Code
```java
package src.main.java.ml;

import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;

import java.util.HashMap;
import java.util.Map;

import static javax.swing.UIManager.put;

public class BaseModelONNX {
}
```
Training Code
```python
import numpy
from sklearn.datasets import load_iris
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestClassifier

data = load_iris()
X = data.data[:, :4]
y = data.target

ind = numpy.arange(X.shape[0])
numpy.random.shuffle(ind)
X = X[ind, :].copy()
y = y[ind].copy()

pipe = Pipeline([('scaler', StandardScaler()), ('clr', RandomForestClassifier())])
pipe.fit(X, y)

from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType

options = {id(pipe): {'zipmap': False}}

initial_types = [
    ('sepal_length', FloatTensorType([None, 1])),
    ('sepal_width', FloatTensorType([None, 1])),
    ('petal_length', FloatTensorType([None, 1])),
    ('petal_width', FloatTensorType([None, 1])),
]

model_onnx = convert_sklearn(
    pipe, 'pipeline_rf', initial_types=initial_types, options=options
)

with open('pipeline_rf.onnx', 'wb') as f:
    f.write(model_onnx.SerializeToString())
```
Urgency
No response
Platform
Mac
OS Version
13.4.1
ONNX Runtime Installation
Built from Source
ONNX Runtime Version or Commit ID
1.15.1
ONNX Runtime API
Java
Architecture
X86
Execution Provider
Default CPU
Execution Provider Library Version
No response