pykeio / ort

Fast ML inference & training for Rust with ONNX Runtime
https://ort.pyke.io/
Apache License 2.0

Inference fails when the input contains tensors of two different data types. #298

Open VanderBieu opened 19 hours ago

VanderBieu commented 19 hours ago

I have an ONNX model that takes one float tensor and three int tensors as input. Inference works fine with the Python version of onnxruntime, but fails with ort version 2.0.0-rc.6. The inputs are:

inputs! {
    "image_embeddings" => image_embeddings.view(), // f32
    "graph_points" => collated_points.view(),      // i64
    "pairs" => collated_pairs.view(),              // i64
    "valid" => collated_valid.view(),              // i64
}.unwrap();

in Rust

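import numpy as np

# rand() returns values in [0, 1), so the int64 casts below are all zeros;
# only the dtypes and shapes matter for this test.
image_embeddings = np.random.rand(1, 256, 64, 64).astype(np.float32)
graph_points = np.random.rand(1, 40, 2).astype(np.int64)
pairs = np.random.rand(1, 40, 16, 2).astype(np.int64)
valid = np.random.rand(1, 40, 16).astype(np.int64)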
inputs = {
    'image_embeddings': image_embeddings,
    'graph_points': graph_points,
    'pairs': pairs,
    'valid': valid
}

in Python
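For comparison, here is a minimal std-only Rust sketch that mirrors the Python test inputs with explicit element types and the same shapes. The `HostTensor` struct and its `filled` constructor are hypothetical illustrations, not part of ort. Because `np.random.rand(...)` yields values in [0, 1), `.astype(np.int64)` produces all zeros, which is what the zero-filled `i64` buffers reproduce; the `f32` input uses a placeholder constant instead of random data.

```rust
/// Hypothetical host-side tensor: a shape plus a flat, row-major buffer.
#[derive(Debug)]
struct HostTensor<T> {
    shape: Vec<usize>,
    data: Vec<T>,
}

impl<T: Clone> HostTensor<T> {
    /// Allocate a tensor of `shape` with every element set to `value`.
    fn filled(shape: &[usize], value: T) -> Self {
        let len: usize = shape.iter().product();
        HostTensor { shape: shape.to_vec(), data: vec![value; len] }
    }
}

fn main() {
    // f32 input: placeholder constant instead of random values.
    let image_embeddings = HostTensor::<f32>::filled(&[1, 256, 64, 64], 0.5);
    // i64 inputs: zeros, matching what rand().astype(np.int64) yields in Python.
    let graph_points = HostTensor::<i64>::filled(&[1, 40, 2], 0);
    let pairs = HostTensor::<i64>::filled(&[1, 40, 16, 2], 0);
    let valid = HostTensor::<i64>::filled(&[1, 40, 16], 0);

    println!("image_embeddings {:?}: {} f32 values", image_embeddings.shape, image_embeddings.data.len());
    println!("graph_points {:?}: {} i64 values", graph_points.shape, graph_points.data.len());
    println!("pairs {:?}: {} i64 values", pairs.shape, pairs.data.len());
    println!("valid {:?}: {} i64 values", valid.shape, valid.data.len());
}
```

Keeping the shape next to a typed buffer makes the element type part of the Rust type, so an `f32` buffer cannot silently end up where an `i64` one was intended.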

No matter how I convert between float and int, inference keeps failing with { code: InvalidArgument, msg: "Unexpected input data type. Actual: (tensor(float)) , expected: (tensor(int64))" }
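The error means that at least one input the graph declares as int64 received a float tensor. If coordinate data starts out as `f32`, it has to be converted element by element into a new `i64` buffer before the tensor is built. A std-only sketch, where `to_i64_coords` is a hypothetical helper: Rust's `as` cast from `f32` to `i64` truncates toward zero and saturates at the `i64` bounds (NaN becomes 0), so the sketch rounds first.

```rust
/// Hypothetical helper: convert f32 coordinates to i64 by rounding to the
/// nearest integer (halves round away from zero). `as` alone would truncate.
fn to_i64_coords(points: &[f32]) -> Vec<i64> {
    points.iter().map(|&x| x.round() as i64).collect()
}

fn main() {
    let points_f32: Vec<f32> = vec![12.4, 12.6, -3.5, 0.0];
    let points_i64 = to_i64_coords(&points_f32);
    // The element type is now i64, matching an int64 graph input.
    println!("{:?}", points_i64);
}
```

Whether rounding or truncation is appropriate depends on how the model was trained; the point is that the resulting buffer has `i64` elements, not just integral `f32` values.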

VanderBieu commented 18 hours ago

I think the problem might be that SessionInput is statically typed. Specifically, SessionInput is constructed from a HashMap<K,V> or a Vec<(K,V)>, so the values are automatically converted to f32 or i64. My suggestion is to rework input.rs to use a more flexible structure.
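On the static-typing point: a `Vec<(K, V)>` only forces every value into one element type if `V` is a single concrete tensor type. A minimal std-only sketch, where the `InputTensor` enum and `build_inputs` are hypothetical and not ort's actual types, showing that an enum-valued `V` lets `f32` and `i64` inputs sit in the same list while each keeps its own dtype:

```rust
/// Hypothetical type-erased input: each variant carries its own element type.
#[derive(Debug)]
enum InputTensor {
    F32 { shape: Vec<usize>, data: Vec<f32> },
    I64 { shape: Vec<usize>, data: Vec<i64> },
}

impl InputTensor {
    /// ONNX-style dtype name, for display and checking.
    fn dtype(&self) -> &'static str {
        match self {
            InputTensor::F32 { .. } => "tensor(float)",
            InputTensor::I64 { .. } => "tensor(int64)",
        }
    }
}

/// Build a named input list that mixes f32 and i64 tensors in one Vec.
fn build_inputs() -> Vec<(&'static str, InputTensor)> {
    vec![
        ("image_embeddings", InputTensor::F32 { shape: vec![1, 256, 64, 64], data: vec![0.0; 256 * 64 * 64] }),
        ("graph_points", InputTensor::I64 { shape: vec![1, 40, 2], data: vec![0; 80] }),
        ("pairs", InputTensor::I64 { shape: vec![1, 40, 16, 2], data: vec![0; 1280] }),
        ("valid", InputTensor::I64 { shape: vec![1, 40, 16], data: vec![0; 640] }),
    ]
}

fn main() {
    for (name, tensor) in build_inputs() {
        println!("{name}: {}", tensor.dtype());
    }
}
```

The design choice is type erasure at the container level: the list has one static type, but the element type is decided per entry at runtime.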

decahedron1 commented 14 hours ago

Check the expected type of the graph's inputs using a program like Netron.

VanderBieu commented 3 hours ago


I double-checked; the input types are correct.

decahedron1 commented 2 hours ago

How are you creating the tensors? Are you certain they are the expected dtype? (You can print Value::dtype to check.)
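Along those lines, a std-only sketch of auditing every input before running the session: compare each input's actual element type with the expected types read off the graph (e.g., in Netron) and report mismatches by name. `ElemType`, `find_mismatches`, and both signature tables are hypothetical; the `actual` table deliberately shows one plausible bug (`graph_points` built as f32) and is not the reporter's data. With ort, the actual types would come from printing each value's dtype, as suggested above.

```rust
/// Hypothetical element types, named the way ONNX Runtime reports them.
#[derive(Debug, Clone, Copy, PartialEq)]
enum ElemType {
    Float,
    Int64,
}

impl ElemType {
    fn onnx_name(self) -> &'static str {
        match self {
            ElemType::Float => "tensor(float)",
            ElemType::Int64 => "tensor(int64)",
        }
    }
}

/// One message per expected input whose actual type differs or is missing.
fn find_mismatches(expected: &[(&str, ElemType)], actual: &[(&str, ElemType)]) -> Vec<String> {
    let mut problems = Vec::new();
    for &(name, want) in expected {
        match actual.iter().find(|(n, _)| *n == name) {
            Some(&(_, got)) if got != want => problems.push(format!(
                "{name}: actual {}, expected {}",
                got.onnx_name(),
                want.onnx_name()
            )),
            Some(_) => {}
            None => problems.push(format!("{name}: missing")),
        }
    }
    problems
}

fn main() {
    // Expected signature, as read from the graph in Netron.
    let expected = [
        ("image_embeddings", ElemType::Float),
        ("graph_points", ElemType::Int64),
        ("pairs", ElemType::Int64),
        ("valid", ElemType::Int64),
    ];
    // Hypothetical actual types: graph_points accidentally built as f32.
    let actual = [
        ("image_embeddings", ElemType::Float),
        ("graph_points", ElemType::Float),
        ("pairs", ElemType::Int64),
        ("valid", ElemType::Int64),
    ];
    for msg in find_mismatches(&expected, &actual) {
        println!("{msg}");
    }
}
```

Reporting by input name narrows a generic "Actual: (tensor(float)), expected: (tensor(int64))" error down to the specific tensor that was constructed with the wrong dtype.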