sonos / tract

Tiny, no-nonsense, self-contained, Tensorflow and ONNX inference

Thread panicked: assertion failed index < dim #653

Closed SuperFluffy closed 2 years ago

SuperFluffy commented 2 years ago

I can't get my model, which expects an input of shape 1x51, to take an array of the same shape. Instead the thread panics. Am I doing something obviously wrong?

I define a model and load a CSV file, the first row of which I want to feed to the model:

    let model = tract_onnx::onnx()
        .model_for_path("../model.onnx")?
        .with_input_fact(0, InferenceFact::dt_shape(i64::datum_type(), tvec!(1, 51)))?
        .into_optimized()?
        .into_runnable()?;

    let file = std::fs::File::open("data.csv")?;
    let mut reader = csv::ReaderBuilder::new().has_headers(true).from_reader(file);
    let all_data: Array2<i64> = reader.deserialize_array2_dynamic()?;
    let first_row = all_data.row(0).to_owned().into_shape((1, all_data.ncols())).unwrap();
    let tensor: Tensor = first_row.into();
    let result = model.run(tvec!(tensor))?;

For reference, this is what println!("{tensor:?}") gives:

1,51,I64 24, 24, 22, 23, 24, 25, 25, 22, 23, 23, 23, 23...

However, I am immediately confronted with this error:

thread 'main' panicked at 'assertion failed: index < dim', /Users/janis/.cargo/registry/src/github.com-1ecc6299db9ec823/ndarray-0.15.4/src/dimension/mod.rs:361:5

I am not quite sure where to start investigating this. When inspecting the model in Python, I see this:

In [1]: import onnx

In [2]: model = onnx.load("../model.onnx")

In [3]: model.graph.input
Out[3]:
[name: "x_categorical"
type {
  tensor_type {
    elem_type: 7
    shape {
      dim {
        dim_value: 1
      }
      dim {
        dim_value: 51
      }
    }
  }
}
]

Similarly, looking at the model loaded through tract reveals this:

SimplePlan {
    model: Graph {
        nodes: [
            Node {
                id: 0,
                name: "embeddings.0.weight.0",
                inputs: [],
                op: Const(
                    17,9,F32 0.56551576, -1.2209389, 0.6316886, 0.39847645, 1.5291617, 0.08389956, -0.606807, -1.2726634, -0.70934343, 0.10189452, -0.8492845, -1.3055661...,
                ),
                outputs: [
                    17,9,F32 0.56551576, -1.2209389, 0.6316886, 0.39847645, 1.5291617, 0.08389956, -0.606807, -1.2726634, -0.70934343, 0.10189452, -0.8492845, -1.3055661... >4/0,
                ],
            },
            Node {
                id: 1,
                name: "x_categorical",
                inputs: [],
                op: TypedSource {
                    fact: 1,51,I64,
                },
                outputs: [
                    1,51,I64 >2/0 >7/0 >13/0 >19/0 >25/0 >31/0 >37/0 >43/0 >49/0 >55/0 >61/0 >67/0 >73/0 >79/0 >85/0 >91/0 >97/0 >103/0 >109/0 >115/0 >121/0 >127/0 >133/0 >139/0 >145/0 >151/0 >157/0 >163/0 >169/0 >175/0 >181/0 >187/0 >193/0 >199/0 >205/0 >211/0 >217/0 >223/0 >229/0 >235/0 >241/0 >247/0 >253/0 >259/0 >265/0 >271/0 >277/0 >283/0 >289/0 >295/0 >301/0,
                ],
            },
            Node {
                id: 2,
                name: "Gather_1.slice",
                inputs: [
                    1/0>,
                ],
                op: Slice {
                    axis: 1,
                    start: Val(
                        0,
                    ),
                    end: Val(
                        1,
                    ),
                },
                outputs: [
                    1,1,I64 >3/0,
                ],
            },
            Node {
                id: 3,
                name: "Gather_1.rm_axis",
                inputs: [
                    2/0>,
                ],
                op: Rm(
                    1,
                ),
                outputs: [
                    1,I64 >4/1,
                ],
            },

Here is RUST_BACKTRACE=1:

stack backtrace:
   0: rust_begin_unwind
             at /rustc/9d1b2106e23b1abd32fce1f17267604a5102f57a/library/std/src/panicking.rs:498:5
   1: core::panicking::panic_fmt
             at /rustc/9d1b2106e23b1abd32fce1f17267604a5102f57a/library/core/src/panicking.rs:116:14
   2: core::panicking::panic
             at /rustc/9d1b2106e23b1abd32fce1f17267604a5102f57a/library/core/src/panicking.rs:48:5
   3: ndarray::impl_methods::<impl ndarray::ArrayBase<S,D>>::index_axis
   4: tract_core::ops::array::gather::Gather::eval_t
   5: <tract_core::ops::array::gather::Gather as tract_core::ops::EvalOp>::eval
   6: tract_core::plan::SimplePlan<F,O,M>::run
   7: tract_load::main

EDIT: I have been investigating this. The error comes from tract_core::ops::array::gather, specifically Gather::eval_t:

            to_update.assign(&data_view.index_axis(Axis(self.axis), index_value));

This is the input to index_axis:

data_view.shape(): [17, 9]
axis: 0,
index_value: 24

This value of 24 comes directly from the input data (I checked by changing the first element to 23).
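
To make the failing condition concrete, here is a minimal std-only sketch of what the first few nodes appear to do: take column 0 of the 1x51 input (Gather_1.slice plus Gather_1.rm_axis) and use it to pick rows from the 17x9 embedding constant. The helper names `column` and `gather_rows` are hypothetical and are not part of tract or ndarray. The gather node itself is not shown in the dump, so its wiring is an assumption based on the backtrace. Negative indices are simply rejected here, which is a simplification.

```rust
/// Hypothetical stand-in for `Slice { axis: 1, .. }` followed by `Rm(1)`:
/// picks column `col` from every row of a row-major 2-D input.
fn column(input: &[Vec<i64>], col: usize) -> Vec<i64> {
    input.iter().map(|row| row[col]).collect()
}

/// Hypothetical stand-in for a gather along axis 0. `table` plays the role
/// of the 17x9 embedding constant; `indices` are the values taken from the
/// input column. Returns an error instead of panicking on a bad index.
fn gather_rows(table: &[Vec<f32>], indices: &[i64]) -> Result<Vec<Vec<f32>>, String> {
    let rows = table.len();
    let mut out = Vec::with_capacity(indices.len());
    for &i in indices {
        // Simplification: anything outside 0..rows is treated as an error.
        if i < 0 || i as usize >= rows {
            return Err(format!(
                "index {} is out of range for a table with {} rows",
                i, rows
            ));
        }
        out.push(table[i as usize].clone());
    }
    Ok(out)
}
```

With a table of 17 rows and an input whose first value is 24, `gather_rows(&table, &column(&input, 0))` returns the error branch, which corresponds to the `index < dim` assertion in the panic. Any first value from 0 to 16 passes the check.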

Eyeballing the ONNX graph, it looks like Node 2 (Gather_1.slice) is trying to slice across axis 0 using the first value of Node 1, which happens to be 24?

So is the model bad?
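
One way to narrow down whether the data or the model is at fault would be to compare every value of the input row against the number of rows in the embedding table that its column indexes. The sketch below assumes those row counts are known and passed in as `table_rows`. Only the first one (17) is visible in the dump above, so any other sizes are placeholders. The function name `find_out_of_range` is hypothetical.

```rust
/// Returns `(column, value, table_rows)` for every value in `row` that could
/// not index a table with the matching number of rows. If the two slices
/// differ in length, only the overlapping prefix is checked.
fn find_out_of_range(row: &[i64], table_rows: &[usize]) -> Vec<(usize, i64, usize)> {
    let mut offenders = Vec::new();
    for (column, (&value, &rows)) in row.iter().zip(table_rows.iter()).enumerate() {
        // A valid index must lie in 0..rows.
        if value < 0 || value as usize >= rows {
            offenders.push((column, value, rows));
        }
    }
    offenders
}
```

If the first column's table has 17 rows and the row starts with 24, column 0 is reported. A non-empty result suggests the CSV values use a different encoding than the one the model expects.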

kali commented 2 years ago

Hello! Thanks for your interest in tract!

At this stage, I need to have a look at a test-case with model and data. Can you provide them as:

You can "check" the test case with tract model.onnx --input-bundle io.npz run --assert-output-bundle io.npz

NB:

I need to write a command line cookbook.

kali commented 2 years ago

I'm closing this as nothing is happening, will reopen if relevant.