tracel-ai / burn

Burn is a new comprehensive dynamic Deep Learning Framework built using Rust with extreme flexibility, compute efficiency and portability as its primary goals.
https://burn.dev
Apache License 2.0

ONNX conversion: Only tensor input is valid Argument #1560

Open AdrianEddy opened 4 months ago

AdrianEddy commented 4 months ago

Describe the bug When trying to convert an ONNX model to Burn, I'm encountering this panic:

  WARN burn_import::onnx::dim_inference: Must implement dimension inference for Slice
  WARN burn_import::onnx::dim_inference: Must implement dimension inference for Slice
  ERROR burn_import::logger: PANIC => panicked at burn-import-0.12.1\src\onnx\dim_inference.rs:311:9:
  Only tensor input is valid Argument

  --- stderr
  thread 'main' panicked at burn-import-0.12.1\src\onnx\dim_inference.rs:311:9:
  Only tensor input is valid Argument

The node_input on that line is { name: "cast2_out1", ty: Scalar(Int64), value: None, passed: true }
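The panic comes from a dimension-inference path that only accepts tensor arguments, while the preceding Cast node produces a scalar. A minimal sketch of that pattern (illustrative names, not Burn's actual internals), with the scalar case returning None instead of panicking:

```rust
// Hypothetical sketch of the dim-inference check that panics.
// Names and types are illustrative, not Burn's real internals.
#[derive(Debug, Clone, PartialEq)]
enum ArgType {
    Tensor { dim: usize },
    Scalar, // e.g. the Int64 produced by the preceding Cast node
}

struct Argument {
    name: String,
    ty: ArgType,
}

/// Dimension inference for an op that only handles tensor inputs.
/// Returns the inferred rank, or None for non-tensor inputs
/// (where the current code path instead does
/// `panic!("Only tensor input is valid")`).
fn infer_dim(input: &Argument) -> Option<usize> {
    match input.ty {
        ArgType::Tensor { dim } => Some(dim),
        ArgType::Scalar => None,
    }
}
```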

The model I'm trying to convert is gmflow-scale2-regrefine6-mixdata-train320x576-4e7b215d.onnx

I also have an alternative ONNX model exported by someone else, which fails with a different error and might also be worth looking into: gmflow-scale1-mixdata-train320x576-4c3a6e9a_1x3x480x640_sim.onnx

To Reproduce

  1. Create a build.rs with:

    use burn_import::onnx::ModelGen;

    fn main() {
        ModelGen::new()
            .input("gmflow-scale2-regrefine6-mixdata-train320x576-4e7b215d.onnx")
            .out_dir("model/")
            .run_from_script();
    }
  2. Watch it fail
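For context, once conversion succeeds the generated module is typically included from OUT_DIR following the usual burn-import pattern; the exact file name below is an assumption based on the input model's name, not verified output:

```rust
// In src/lib.rs (or main.rs): include the Rust module that ModelGen
// writes into OUT_DIR. The file name is assumed here to mirror the
// input ONNX file name.
mod gmflow {
    include!(concat!(
        env!("OUT_DIR"),
        "/model/gmflow-scale2-regrefine6-mixdata-train320x576-4e7b215d.rs"
    ));
}
```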

Expected behavior I expected the conversion to finish successfully.


Additional context The model was exported using torch.onnx.export from this repo

laggui commented 4 months ago

According to the line that panics, it seems to be caused by an Unsqueeze op.

Looks like this op is not fully supported yet (#600).

Do you know what else is missing for full support, @antimora?

antimora commented 4 months ago

Thanks, @AdrianEddy, for filing this issue. This helps us prioritize the ONNX ops.

OK. I have checked the ONNX file: it contains a lot of nodes (> 10K), but only a few op types:

  1. Add
  2. BatchNormalization
  3. Cast
  4. Concat
  5. Constant
  6. ConstantOfShape (CURRENTLY NOT SUPPORTED)
  7. Conv
  8. Div
  9. Expand (CURRENTLY NOT SUPPORTED)
  10. Gather
  11. Identity
  12. Mul
  13. Relu
  14. Reshape
  15. Shape
  16. Slice (CURRENTLY NOT SUPPORTED)
  17. Sub
  18. Tile (CURRENTLY NOT SUPPORTED)
  19. Unsqueeze

Unsupported in ONNX import currently:

  1. ConstantOfShape
  2. Expand
  3. Slice
  4. Tile

The full list of supported ops: https://github.com/tracel-ai/burn/blob/main/crates/burn-import/SUPPORTED-ONNX-OPS.md
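For reference, the 1-D semantics of two of the missing ops can be sketched in plain Rust. This illustrates the ONNX operator spec only; it is not Burn code and says nothing about how the ops would be implemented in burn-import:

```rust
// ONNX Tile (1-D case): repeat the input `reps` times along the axis.
fn tile_1d(input: &[f32], reps: usize) -> Vec<f32> {
    input
        .iter()
        .copied()
        .cycle()
        .take(input.len() * reps)
        .collect()
}

// ONNX Slice (1-D case): take elements in [start, end) with the given
// step, clamping `end` to the input length as the spec requires.
fn slice_1d(input: &[f32], start: usize, end: usize, step: usize) -> Vec<f32> {
    input[start..end.min(input.len())]
        .iter()
        .step_by(step)
        .copied()
        .collect()
}
```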

These missing ops came up recently in this issue as well: https://github.com/tracel-ai/burn/issues/1544

Unsqueeze fails because it references ConstantOfShape, which isn't supported at the moment.
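To make the dependency concrete: per the ONNX spec, ConstantOfShape takes a 1-D shape tensor and produces a tensor of that shape filled with a constant value, and it is this output that Unsqueeze consumes. A plain-Rust sketch of the op's semantics, with the data flattened into a Vec for illustration only:

```rust
// Sketch of ONNX ConstantOfShape semantics (not Burn code): given a
// 1-D shape, produce (shape, flattened data) where every element is
// the constant `value`.
fn constant_of_shape(shape: &[usize], value: f32) -> (Vec<usize>, Vec<f32>) {
    let numel: usize = shape.iter().product();
    (shape.to_vec(), vec![value; numel])
}
```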

Tagging @nathanielsimard so he's aware of missing ONNX ops.

antimora commented 4 months ago

If we fully handle https://github.com/tracel-ai/burn/issues/1544, then this ticket will get resolved too.

antimora commented 3 months ago

I fixed the Unsqueeze scalar issue (https://github.com/tracel-ai/burn/pull/1690), but it's now blocked on converting the Gemm node to a Linear module:

[burn-import]% cr ./clip_image_model_vitb32_batchsize1.onnx ./out

....

DEBUG burn_import::onnx::from_onnx: output name: /visual/transformer/resblocks/resblocks.0/attn/Reshape_3_output_0
DEBUG burn_import::onnx::proto_conversion: Converting ONNX node with type "Gemm"
DEBUG burn_import::onnx::from_onnx: renaming node "/visual/transformer/resblocks/resblocks.0/attn/Gemm"
ERROR burn_import::logger: PANIC => panicked at crates/burn-import/src/onnx/dim_inference.rs:122:46:
called `Option::unwrap()` on a `None` value
thread 'main' panicked at crates/burn-import/src/onnx/dim_inference.rs:122:46:
called `Option::unwrap()` on a `None` value
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
[burn-import]%