tracel-ai / burn

Burn is a new comprehensive dynamic Deep Learning Framework built using Rust with extreme flexibility, compute efficiency and portability as its primary goals.
https://burn.dev
Apache License 2.0

OnnxGraph does not set `passed` correctly and input names with numbers fail. #1983

Open antimora opened 2 weeks ago

antimora commented 2 weeks ago

Attaching the ONNX file and the graph output: face-detector-bug.zip

Here is a part of the parsed graph:

ParsedOnnxGraph(
     {
        nodes: [
            Node {
                node_type: Conv2d,
                name: "conv2d1",
                inputs: [
                    Argument {
                        name: "input1",
                        ty: Tensor(
                            TensorType {
                                elem_type: Float32,
                                dim: 4,
                                shape: Some(
                                    [
                                        1,
                                        3,
                                        480,
                                        640,
                                    ],
                                ),
                            },
                        ),
                        value: None,
                        passed: false,
                    },
                    Argument {
                        name: "517",
                        ty: Tensor(
                            TensorType {
                                elem_type: Float32,
                                dim: 4,
                                shape: Some(
                                    [
                                        16,
                                        3,
                                        3,
                                        3,
                                    ],
                                ),
                            },
                        ),
                        value: Some(
                            Float32s([-0.0046816375, 0.0037100762, 0.023904493, -0.0004699939, -0.01531022, -0.013523362, 0.034382608, ...]),
                        ),
                        passed: false,
                    },

Note that the `passed` flag is not set to true for the first input.

Also, `name: "517"` fails because it is not a valid Rust identifier. This results in the following error:

ERROR burn_import::logger: PANIC => panicked at crates/burn-import/src/burn/ty.rs:79:19:
Ident cannot be a number; use Literal instead
thread 'main' panicked at crates/burn-import/src/burn/ty.rs:79:19:
Ident cannot be a number; use Literal instead
stack backtrace:
   0: rust_begin_unwind
             at /rustc/07dca489ac2d933c78d3c5158e3f43beefeb02ce/library/std/src/panicking.rs:645:5
   1: core::panicking::panic_fmt
             at /rustc/07dca489ac2d933c78d3c5158e3f43beefeb02ce/library/core/src/panicking.rs:72:14
   2: proc_macro2::fallback::validate_ident
   3: proc_macro2::fallback::Ident::new_checked
   4: proc_macro2::imp::Ident::new_checked
   5: proc_macro2::Ident::new
   6: burn_import::burn::ty::ScalarType::new
   7: burn_import::onnx::to_burn::<impl core::convert::From<&onnx_ir::ir::Argument> for burn_import::burn::ty::Type>::from
   8: burn_import::onnx::to_burn::ParsedOnnxGraph::mul_conversion
   9: burn_import::onnx::to_burn::ParsedOnnxGraph::into_burn
  10: burn_import::onnx::to_burn::ModelGen::generate_model
  11: burn_import::onnx::to_burn::ModelGen::run
  12: burn_import::onnx::to_burn::ModelGen::run_from_cli
  13: onnx2burn::main
  14: core::ops::function::FnOnce::call_once
note: Some details are omitted, run with `RUST_BACKTRACE=full` for a verbose backtrace.

It looks like this code was removed, though I am not sure whether accidentally or intentionally:

`crates/burn-import/src/onnx/from_onnx.rs` in https://github.com/tracel-ai/burn/pull/1857/files
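The panic comes from `proc_macro2::Ident::new`, which rejects a string made only of digits. One way to avoid it would be to make sure every argument name is a valid Rust identifier before it reaches `Ident::new`. The sketch below is hypothetical: `sanitize_ident` and the `arg_` prefix are illustrative assumptions, not burn-import's actual code.

```rust
/// Hypothetical helper (not part of burn-import): turn an arbitrary ONNX
/// argument name into a string that is a valid ASCII Rust identifier.
fn sanitize_ident(name: &str) -> String {
    // Replace any character that cannot appear in an identifier with '_'.
    let mut ident: String = name
        .chars()
        .map(|c| if c.is_ascii_alphanumeric() || c == '_' { c } else { '_' })
        .collect();
    // An identifier cannot be empty or start with a digit (e.g. "517"),
    // so prepend a letter-based prefix in those cases.
    if ident.is_empty() || ident.starts_with(|c: char| c.is_ascii_digit()) {
        ident.insert_str(0, "arg_");
    }
    ident
}

fn main() {
    for name in ["517", "input1", "onnx::Conv_0", "1x1.weight"] {
        println!("{name:?} -> {:?}", sanitize_ident(name));
    }
}
```

A real fix would also need to keep renamed identifiers unique and avoid Rust keywords such as `type`, which this sketch ignores.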

antimora commented 2 weeks ago

@skewballfox, your input would be appreciated. This bug is currently blocking #1915. It seems to work for the mnist.onnx checked into the onnx-inference example, so I am not sure why it fails here. FYI, the attached ONNX file works with the fix in PR #1989.

skewballfox commented 2 weeks ago

Rephrasing what I mentioned on Discord: `passed` is only updated on the original graph inputs here, mainly because `passed` is only used for filtering graph inputs (graph outputs can be filtered if they exist as a node output). The copy of the input argument in the graph inputs is marked as passed.
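To illustrate why the node's copy of `input1` still shows `passed: false`, here is a minimal sketch assuming a simplified, hypothetical `Argument` struct (the real `onnx_ir::ir::Argument` has more fields) and a hypothetical `mark_graph_inputs_passed` helper. Because the graph inputs and the node each own a clone, setting the flag on one copy does not change the other.

```rust
// Hypothetical, simplified stand-in for onnx_ir's `Argument`.
#[derive(Clone)]
struct Argument {
    name: String,
    passed: bool,
}

/// Hypothetical helper: marks every graph input as passed. Only the
/// graph-input copies are touched, never the node's own copies.
fn mark_graph_inputs_passed(graph_inputs: &mut [Argument]) {
    for arg in graph_inputs.iter_mut() {
        arg.passed = true;
    }
}

fn main() {
    let original = Argument { name: "input1".to_string(), passed: false };
    // The graph inputs and the node input are independent clones.
    let mut graph_inputs = vec![original.clone()];
    let node_inputs = vec![original];

    mark_graph_inputs_passed(&mut graph_inputs);

    println!("{} in graph inputs: passed = {}", graph_inputs[0].name, graph_inputs[0].passed);
    println!("{} in node inputs:  passed = {}", node_inputs[0].name, node_inputs[0].passed);
}
```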

We could probably remove `passed` altogether and instead keep a separate `passed_input: Vec<bool>` in the graph data.
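The `passed_input: Vec<bool>` idea could look roughly like the sketch below. Everything here (`GraphData`, `mark_passed`, `used_inputs`) is a hypothetical simplification for discussion, not burn-import's actual graph data. The flag lives in one place, indexed in parallel with the graph inputs, so there is no second copy that can fall out of sync.

```rust
/// Hypothetical, simplified graph data that tracks "passed" per graph input
/// in a parallel vector instead of storing a flag on every `Argument`.
struct GraphData {
    input_names: Vec<String>,
    passed_input: Vec<bool>, // passed_input[i] corresponds to input_names[i]
}

impl GraphData {
    fn new(input_names: Vec<String>) -> Self {
        let passed_input = vec![false; input_names.len()];
        Self { input_names, passed_input }
    }

    /// Record that the graph input with this name is consumed by some node.
    /// Names that are not graph inputs are ignored.
    fn mark_passed(&mut self, name: &str) {
        if let Some(i) = self.input_names.iter().position(|n| n == name) {
            self.passed_input[i] = true;
        }
    }

    /// Graph inputs that are actually used, in their original order.
    fn used_inputs(&self) -> Vec<&str> {
        self.input_names
            .iter()
            .zip(&self.passed_input)
            .filter_map(|(name, &passed)| if passed { Some(name.as_str()) } else { None })
            .collect()
    }
}

fn main() {
    let mut graph = GraphData::new(vec!["input1".to_string(), "unused".to_string()]);
    graph.mark_passed("input1");
    println!("{:?}", graph.used_inputs()); // ["input1"]
}
```

A `HashSet<String>` of passed names would work too; the parallel vector simply preserves the original input order without extra bookkeeping.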

The initializer issue (the numbered variable) is odd, because some of the ONNX tests have arguments like that (I think I frequently see a tensor named `13` in the Add or Gemm test when debugging), but normally those names don't make it into the generated Burn code.