nbigaouette / onnxruntime-rs

Rust wrapper for Microsoft's ONNX Runtime (version 1.8)
Apache License 2.0

Does onnxruntime-rs support dynamic dimensions? #84

Closed: HaoboGu closed this issue 3 years ago

HaoboGu commented 3 years ago

Hello, I am new to Rust. I'm trying to load my model, which has dynamic dimensions like:

[Image: screenshot of the model's dynamic dimensions, not preserved]

The actual sizes are defined by batch_size and length: [image not preserved]

When I load the model with onnxruntime-rs, the inputs and outputs I get are:

Inputs:
  0:
    name = input
    type = Int64
    dimensions = [None, None]
  1:
    name = past_1
    type = Float
    dimensions = [Some(2), None, Some(4), None, Some(96)]
  2:
    name = past_2
    type = Float
    dimensions = [Some(2), None, Some(4), None, Some(96)]
Outputs:
  0:
    name = output
    type = Float
    dimensions = [None, None, Some(40015)]
  1:
    name = out_past_1
    type = Float
    dimensions = [Some(2), None, Some(4), None, Some(96)]
  2:
    name = out_past_2
    type = Float
    dimensions = [Some(2), None, Some(4), None, Some(96)]

I cannot create a test input for my model because Rust panics on this line:

let input0_shape: Vec<usize> = session.inputs[0].dimensions().map(|d| d.unwrap()).collect();
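The panic happens because `dimensions()` yields `None` for every dynamic axis, and calling `unwrap()` on `None` panics. One way around it, sketched here in plain Rust with no onnxruntime dependency, is to substitute a size of your choosing for each `None`. The helper `concrete_shape` and the example sizes are hypothetical; the declared dimensions would come from something like `session.inputs[1].dimensions().collect::<Vec<_>>()`, which, going by the line above, yields `Option<usize>` items.

```rust
/// Build a concrete shape from a model's declared dimensions by replacing
/// each dynamic (`None`) axis, in order, with a caller-chosen size.
/// Static (`Some(n)`) axes are kept exactly as the model declares them.
fn concrete_shape(dims: &[Option<usize>], dynamic_sizes: &[usize]) -> Result<Vec<usize>, String> {
    let mut sizes = dynamic_sizes.iter();
    let shape = dims
        .iter()
        .map(|d| match d {
            Some(n) => Ok(*n),
            None => sizes
                .next()
                .copied()
                .ok_or_else(|| "not enough sizes for the dynamic axes".to_string()),
        })
        .collect::<Result<Vec<usize>, String>>()?;
    // Reject leftover sizes so a mismatch is reported instead of silently ignored.
    if sizes.next().is_some() {
        return Err("more sizes than dynamic axes".to_string());
    }
    Ok(shape)
}

fn main() {
    // Dimensions of `past_1` as printed in the listing above.
    let past_dims: [Option<usize>; 5] = [Some(2), None, Some(4), None, Some(96)];
    // Example sizes for the two dynamic axes (assumption): 1 and 2.
    let shape = concrete_shape(&past_dims, &[1, 2]).unwrap();
    println!("{:?}", shape); // prints [2, 1, 4, 2, 96]
}
```

The resulting `Vec<usize>` can be passed to `IxDyn(&shape)` to create an `ArrayD`, which is the approach taken later in this thread.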

I also tried to reshape an array into the input shapes using

let array = Array::linspace(0.0_f32, 1.0, 2*2*1*4*1*96*2*1*4*1*96 as usize).into_shape(([2], [2, 1, 4, 1, 96], [2, 1, 4, 1, 96]));

but I got:

the trait bound `([{integer}; 1], [{integer}; 5], [{integer}; 5]): onnxruntime::ndarray::Dimension` is not satisfied

the trait `onnxruntime::ndarray::Dimension` is not implemented for `([{integer}; 1], [{integer}; 5], [{integer}; 5])`

note: required because of the requirements on the impl of `onnxruntime::ndarray::IntoDimension` for `([{integer}; 1], [{integer}; 5], [{integer}; 5])`rustc(E0277)
main.rs(77, 93): the trait `onnxruntime::ndarray::Dimension` is not implemented for `([{integer}; 1], [{integer}; 5], [{integer}; 5])`

The question is, how can I build a test input for this model?
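The `into_shape` attempt fails because `into_shape` reshapes one array into one shape, and a tuple of three shapes is not a `Dimension`. Each model input needs its own array, sized to the product of its own dimensions (the `linspace` length above also multiplies all three shapes together instead of treating them separately). A std-only sketch of that bookkeeping, using hypothetical helpers `numel` and `zero_buffers` and the example shapes from this thread:

```rust
/// Number of elements in a tensor with the given shape (product of its dims).
fn numel(shape: &[usize]) -> usize {
    shape.iter().product()
}

/// Allocate one zero-filled, row-major buffer per input shape.
fn zero_buffers(shapes: &[&[usize]]) -> Vec<Vec<f32>> {
    shapes.iter().map(|s| vec![0.0_f32; numel(s)]).collect()
}

fn main() {
    // Example shapes (assumption): `input` is [1, 2]; both pasts are [2, 1, 4, 1, 96].
    let input: &[usize] = &[1, 2];
    let past: &[usize] = &[2, 1, 4, 1, 96];
    let buffers = zero_buffers(&[input, past, past]);
    for b in &buffers {
        println!("{}", b.len()); // prints 2, then 768, then 768
    }
}
```

With ndarray, the equivalent is one `ArrayD::<f32>::zeros(IxDyn(&shape))` per input, which is what the follow-up comment ends up doing.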

HaoboGu commented 3 years ago

I managed to build a valid input using IxDyn, but an error occurred during inference:

Jul 05 15:46:50.651 DEBUG drop{self=OrtTensor { c_ptr: 0x7fce039b0140, array: [[[[[0.0, 0.0, 0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0, 0.0, 0.0]],

   [[0.0, 0.0, 0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0, 0.0, 0.0]],

   [[0.0, 0.0, 0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0, 0.0, 0.0]],

   [[0.0, 0.0, 0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0, 0.0, 0.0]]]],

 [[[[0.0, 0.0, 0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0, 0.0, 0.0]],

   [[0.0, 0.0, 0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0, 0.0, 0.0]],

   [[0.0, 0.0, 0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0, 0.0, 0.0]],

   [[0.0, 0.0, 0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0, 0.0, 0.0]]]]], shape=[2, 1, 4, 1, 96], strides=[384, 384, 96, 96, 1], layout=Cc (0x5), dynamic ndim=5, memory_info: MemoryInfo { ptr: 0x7fce02406310 } }}: onnxruntime::tensor::ort_tensor: Dropping Tensor.
thread 'main' panicked at 'called `Result::unwrap()` on an `Err` value: Run(Msg("Unexpected input data type. Actual: (N11onnxruntime17PrimitiveDataTypeIfEE) , expected: (N11onnxruntime17PrimitiveDataTypeIxEE)"))', src/main.rs:34:24
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
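The two type names in the panic message are C++ types mangled per the Itanium ABI. A small std-only sketch (the function `decode_primitive_type` is hypothetical and covers only a handful of builtin type codes) shows that they read `PrimitiveDataType<float>` for the actual data and `PrimitiveDataType<long long>`, a 64-bit integer, for the expected data. That lines up with input 0 (`input`) being declared as Int64 in the listing above.

```rust
/// Decode the template argument of a mangled `onnxruntime::PrimitiveDataType<T>`
/// name as it appears in ONNX Runtime error messages (Itanium C++ ABI).
/// Only a few common builtin type codes are handled.
fn decode_primitive_type(mangled: &str) -> Option<&'static str> {
    // N = nested name, 11onnxruntime / 17PrimitiveDataType = length-prefixed
    // identifiers, I = start of template arguments.
    const PREFIX: &str = "N11onnxruntime17PrimitiveDataTypeI";
    let rest = mangled.strip_prefix(PREFIX)?;
    // E closes the template arguments, the second E closes the nested name.
    let code = rest.strip_suffix("EE")?;
    let name = match code {
        "f" => "float",
        "d" => "double",
        "a" => "signed char (int8)",
        "h" => "unsigned char (uint8)",
        "s" => "short (int16)",
        "i" => "int (int32)",
        "l" => "long",
        "x" => "long long (int64)",
        "b" => "bool",
        _ => return None,
    };
    Some(name)
}

fn main() {
    // Strings copied from the panic message above.
    let actual = "N11onnxruntime17PrimitiveDataTypeIfEE";
    let expected = "N11onnxruntime17PrimitiveDataTypeIxEE";
    println!("actual:   {}", decode_primitive_type(actual).unwrap());
    println!("expected: {}", decode_primitive_type(expected).unwrap());
}
```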

The following is the source:

use ndarray::{ArrayD, IxDyn};
use onnxruntime::{environment::Environment, tensor::OrtOwnedTensor, GraphOptimizationLevel};
use tracing::Level;
use tracing_subscriber::FmtSubscriber;

fn main() {
    // a builder for `FmtSubscriber`.
    let subscriber = FmtSubscriber::builder()
        // all spans/events with a level higher than TRACE (e.g. debug, info, warn, etc.)
        // will be written to stdout.
        .with_max_level(Level::TRACE)
        .finish();

    tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed");

    let env = Environment::builder().with_name("env").build().unwrap();
    let mut session = env
        .new_session_builder()
        .unwrap()
        .with_optimization_level(GraphOptimizationLevel::Basic)
        .unwrap()
        .with_model_from_file("model.onnx")
        .unwrap();

    // println!("{:#?}", session.inputs);
    // println!("{:#?}", session.outputs);

    let inputs = ArrayD::<f32>::ones(IxDyn(&[1,2]));
    let past1 = ArrayD::<f32>::zeros(IxDyn(&[2,1,4,1,96]));
    let past2 = ArrayD::<f32>::zeros(IxDyn(&[2,1,4,1,96]));
    // Collect the three inputs into the Vec that `session.run` takes.
    let a = vec![inputs, past1, past2];
    let outputs: Vec<OrtOwnedTensor<f32, _>> = session.run(a).unwrap();
    println!("outputs: {:#?}", outputs);
}

@nbigaouette do you have any idea about this?

haixuanTao commented 3 years ago

Your dimensions seem off. Shouldn't you be passing the following?

    let inputs = ArrayD::<f32>::ones(IxDyn(&[1,2]));
    let past1 = ArrayD::<f32>::zeros(IxDyn(&[2,1,4,2,96]));
    let past2 = ArrayD::<f32>::zeros(IxDyn(&[2,1,4,2,96]));

HaoboGu commented 3 years ago

@haixuanTao Thanks, you're right.