webonnx / wonnx

A WebGPU-accelerated ONNX inference run-time written 100% in Rust, ready for native and the web

Split into tensors with different dimensions seems to work incorrectly #205

Open daniilsunyaev opened 3 months ago

daniilsunyaev commented 3 months ago

Describe the bug: We're having issues executing a YOLOv5 model. After some days of debugging, it looks like the Split operation works incorrectly when splitting a 1x3x8x13x11 tensor into 1x3x8x13x7 + 1x3x8x13x2 + 1x3x8x13x2. I tried to reproduce this with a minimal example in the tests, but ran into runtime errors.
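For reference, the behaviour expected from Split can be sketched on the CPU in plain Rust. This is only an illustration of the ONNX semantics on a row-major buffer, not wonnx code: `split_along_axis` is a hypothetical helper, and the `sizes` argument is assumed to hold per-output sizes along `axis`.

```rust
/// Reference split of a row-major tensor along `axis` (hypothetical helper,
/// not part of wonnx). `sizes` holds the size of each output along `axis`.
fn split_along_axis(
    data: &[f32],
    shape: &[usize],
    axis: usize,
    sizes: &[usize],
) -> Vec<Vec<f32>> {
    // Number of index combinations before and after the split axis.
    let outer: usize = shape[..axis].iter().product();
    let inner: usize = shape[axis + 1..].iter().product();
    let dim = shape[axis];

    // The sizes must cover the split axis exactly.
    assert_eq!(sizes.iter().sum::<usize>(), dim, "split sizes must sum to the axis dimension");
    assert_eq!(data.len(), outer * dim * inner, "data length must match shape");

    let mut outputs: Vec<Vec<f32>> = sizes
        .iter()
        .map(|&s| Vec::with_capacity(outer * s * inner))
        .collect();

    for o in 0..outer {
        // Offset along the split axis where the next output starts.
        let mut start = 0;
        for (out, &s) in outputs.iter_mut().zip(sizes) {
            // For a fixed outer index, each output's chunk is contiguous.
            let begin = (o * dim + start) * inner;
            let end = begin + s * inner;
            out.extend_from_slice(&data[begin..end]);
            start += s;
        }
    }
    outputs
}
```

For the 1x3x8x13x11 case with sizes `[7, 2, 2]` on axis 4, the three outputs hold 2184, 624, and 624 elements respectively, which gives something to compare the GPU results against.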

To Reproduce: Consider a 1-D tensor $v = [1, 2, 3, 4, 5, 6]$. We want to split it into $v_1 = [1, 2]$ and $v_2 = [3, 4, 5, 6]$. The corresponding test:

use std::collections::HashMap;
use wonnx::utils::{attribute, graph, initializer_int64, model, node, tensor};
mod common;

#[test]
fn test_split() {
    let _ = env_logger::builder().is_test(true).try_init();
    let mut input_data = HashMap::new();
    let data = (1..=6).map(|x| x as f32).collect::<Vec<f32>>();
    input_data.insert("input".to_string(), data.as_slice().into());

    // Split "input" (6 elements) along axis 0 into "Y" (2 elements) and "W" (4 elements).
    let model = model(graph(
        vec![tensor("input", &[6])],
        vec![tensor("Y", &[2]), tensor("W", &[4])],
        vec![],
        vec![initializer_int64("split", vec![2, 4], vec![2])],
        // Node name and op type "Split" match the label in the error output.
        vec![node(
            vec!["input", "split"],
            vec!["Y", "W"],
            "Split",
            "Split",
            vec![attribute("axis", 0)],
        )],
    ));

    let session =
        pollster::block_on(wonnx::Session::from_model(model)).expect("session did not create");
    let result = pollster::block_on(session.run(&input_data)).unwrap();

    let test_y = vec![1., 2.];
    common::assert_eq_vector((&result["Y"]).try_into().unwrap(), &test_y);
    let test_w = vec![3., 4., 5., 6.];
    common::assert_eq_vector((&result["W"]).try_into().unwrap(), &test_w);
}


test test_split ... FAILED


---- test_split stdout ----
[2024-03-20T10:59:24Z ERROR wgpu::backend::wgpu_core] Handling wgpu errors as fatal by default
thread 'test_split' panicked at /home/daniil/.cargo/registry/src/index.crates.io-6f17d22bba15001f/wgpu-0.19.3/src/backend/wgpu_core.rs:3006:5:
wgpu error: Validation Error

Caused by:
    In Device::create_bind_group
      note: label = `Split`
    Number of bindings in bind group descriptor (4) does not match the number of bindings defined in the bind group layout (3)

note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace

After some debugging, it looks like the code here is executed twice during the test. The first time, the split attribute is correctly set (probably by the optimizer) to [4,2]. On the second execution it is set to the default value [3,6], which is confusing on its own, since the values of the split input should sum to the dimension of the specified axis (6 in our case). Am I initializing the split input incorrectly?
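One possible reading of the default [3,6], not confirmed against the wonnx source, is that it holds running end offsets of an even two-way split rather than sizes. The sketch below only illustrates the difference between the two conventions; `even_sizes`, `end_offsets`, and `is_valid_sizes` are hypothetical helpers, not wonnx functions.

```rust
/// Even split sizes for `parts` outputs (assumes `dim` is divisible by `parts`).
fn even_sizes(dim: usize, parts: usize) -> Vec<usize> {
    vec![dim / parts; parts]
}

/// Running end offsets of consecutive chunks: sizes [a, b, c] -> [a, a+b, a+b+c].
fn end_offsets(sizes: &[usize]) -> Vec<usize> {
    sizes
        .iter()
        .scan(0usize, |acc, &s| {
            *acc += s;
            Some(*acc)
        })
        .collect()
}

/// In the ONNX convention, `split` lists sizes, so it must sum to the axis dimension.
fn is_valid_sizes(sizes: &[usize], dim: usize) -> bool {
    sizes.iter().sum::<usize>() == dim
}
```

Under this reading, an even split of a 6-element axis into two outputs has sizes `[3, 3]` and end offsets `[3, 6]`, while the sizes `[2, 4]` from the test correspond to end offsets `[2, 6]`. If the shader expects offsets but the initializer supplies sizes (or the reverse), the outputs would be cut at the wrong positions.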

Expected behavior: The test should pass.