pykeio / ort

Fast ML inference & training for Rust with ONNX Runtime
https://ort.pyke.io/
Apache License 2.0

Support for scalar inputs #12

Closed by travismorton 1 year ago

travismorton commented 1 year ago

Right now, running a session on a model that has a scalar input (a 0-dimensional array) fails. I think these are rare, but one example is the silero-vad ONNX model, which takes the sample rate as a scalar input. Here's a minimal reproducible example:

use std::sync::Arc;

use ndarray::{arr0, Array};
use ort::{
    tensor::{DynOrtTensor, FromArray, InputTensor, OrtOwnedTensor},
    Environment, ExecutionProvider, GraphOptimizationLevel, OrtResult, SessionBuilder,
};

fn main() -> OrtResult<()> {
    let environment = Arc::new(
        Environment::builder()
            .with_name("silero-vad")
            .with_execution_providers([ExecutionProvider::cpu()])
            .build()?,
    );

    let session = SessionBuilder::new(&environment)?
        .with_optimization_level(GraphOptimizationLevel::Level1)?
        .with_intra_threads(1)?
        .with_model_from_file("./silero-vad.onnx")?;
    let inputs = vec![
        InputTensor::from_array(Array::<f32, _>::zeros([1, 512]).into_dyn()),
        // 0-dim input //
        InputTensor::from_array(arr0::<i64>(16000).into_dyn()),
        InputTensor::from_array(Array::<f32, _>::zeros([2, 1, 64]).into_dyn()),
        InputTensor::from_array(Array::<f32, _>::zeros([2, 1, 64]).into_dyn()),
    ];

    let result: Vec<DynOrtTensor<ndarray::Dim<ndarray::IxDynImpl>>> = session.run(inputs).unwrap();
    let vad: OrtOwnedTensor<f32, _> = result[0].try_extract().unwrap();
    println!("VAD: {:?}", vad);

    Ok(())
}

Running this will result in a runtime error:

thread 'main' panicked at 'assertion failed: `(left != right)`
  left: `0`,
 right: `0`', /home/travis/.cargo/registry/src/github.com-1ecc6299db9ec823/ort-1.13.3/src/session.rs:627:5

After removing the dimension assertion at that line, it runs correctly:

VAD: OrtOwnedTensor { data: TensorPtr { ptr: TensorPointerHolder { tensor_ptr: 0x55c22ac452d0 }, array_view: [[0.041475803]], shape=[1, 1], strides=[1, 1], layout=CFcf (0xf), dynamic ndim=2 } }

I'm not really sure what other effects removing that assertion would have here, but I'm happy to open a PR. The assertion in question: https://github.com/pykeio/ort/blob/e4376dc97a3e5282cd74c139d1fa069730521720/src/session.rs#L624-L632
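
To illustrate why a rank check rejects a valid scalar, here is a minimal std-only sketch. It does not reproduce ort's code: `element_count` and `validate_shape` are hypothetical helpers, and the strict branch only assumes that the assertion compares a dimension count against 0, which is what the `left != right` panic with both sides equal to `0` suggests.

```rust
/// Number of elements implied by a shape.
/// The product over an empty shape is 1, so a 0-dimensional (scalar)
/// tensor with shape `[]` still holds exactly one element.
fn element_count(shape: &[usize]) -> usize {
    shape.iter().product()
}

/// Hypothetical shape check (not ort's API).
/// With `allow_scalar == false` it mimics an assertion that the rank is
/// non-zero; with `allow_scalar == true` it accepts any rank, which is
/// the behavior after the assertion is removed.
fn validate_shape(shape: &[usize], allow_scalar: bool) -> Result<(), String> {
    if shape.is_empty() && !allow_scalar {
        return Err(String::from("rank-0 (scalar) input rejected"));
    }
    Ok(())
}
```

Under these assumptions, `validate_shape(&[], false)` returns an error even though `element_count(&[])` is 1, so the scalar carries real data. Passing `true` lets the same shape through, while shapes such as `[1, 512]` or `[2, 1, 64]` are accepted either way.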

decahedron1 commented 1 year ago

Yes, removing the assertion is fine. Not sure why it was there to begin with 😅

travismorton commented 1 year ago

Great, thank you!