tracel-ai / burn

Burn is a new comprehensive dynamic Deep Learning Framework built using Rust with extreme flexibility, compute efficiency and portability as its primary goals.
Apache License 2.0

Error when importing onnx of transformers bert model #1811

Open 7kanak opened 3 months ago

7kanak commented 3 months ago

I was following the ONNX import example in burn (examples/onnx-inference), but instead of generating the ONNX model the way the example does, I am using the transformers Python package to generate it:

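import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_id = "albert/albert-base-v2"
feature = "sequence-classification"

model = AutoModelForSequenceClassification.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("onnx is wonderful", return_tensors="pt")

# Export to ONNX with dynamic batch/sequence axes ("model.onnx" is a placeholder path)
torch.onnx.export(
    model,
    (inputs["input_ids"], inputs["attention_mask"]),
    "model.onnx",
    input_names=['input_ids', 'attention_mask'],
    output_names=['logits'],
    dynamic_axes={'input_ids': {0: 'batch_size', 1: 'sequence'},
                  'attention_mask': {0: 'batch_size', 1: 'sequence'},
                  'logits': {0: 'batch_size', 1: 'sequence'}},
)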

and I'm getting this error:

  checking outputs    
  DEBUG burn_import::onnx::from_onnx: it's a constant    
  DEBUG burn_import::onnx::proto_conversion: Converting ONNX node with type "Unsqueeze"    
  DEBUG burn_import::onnx::from_onnx: renaming node "/albert/embeddings/Unsqueeze"    
  DEBUG burn_import::onnx::from_onnx: checking node unsqueeze3 for constants    
  DEBUG burn_import::onnx::from_onnx: checking input Argument { name: "/albert/embeddings/Constant_4_output_0", ty: Tensor(TensorType { elem_type: Int64, dim: 1, shape: Some([1]) }), value: None, passed: false } for const    
  DEBUG burn_import::onnx::from_onnx: input /albert/embeddings/Constant_4_output_0 matched constant node constant10    
  ERROR burn_import::logger: PANIC => panicked at /home/kanak/.cargo/registry/src/
  called `Option::unwrap()` on a `None` value    

  --- stderr
  thread 'main' panicked at /home/kanak/.cargo/registry/src/
  called `Option::unwrap()` on a `None` value

You can find the ONNX model file at

laggui commented 3 months ago

Looking at the panicking statement in question, it seems to be an issue with the Unsqueeze op.

Right now, most ONNX ops are parsed during import with the expectation that some of the shapes are available in the protobuf metadata, but in practice that is not always the case (as seen here).
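A minimal sketch of this failure mode, using hypothetical simplified types rather than burn-import's actual ones. The field names `dim` (rank) and `shape` (optional static shape) are borrowed from the debug log above. Reading the optional shape with `unwrap()` panics with the same message whenever the exporter did not record a static shape, while handling the `Option` explicitly lets the caller fall back to rank-only information. Which tensor in the ALBERT graph actually lacks its shape is not stated in the thread; the "dynamic" case below is only an assumed example.

```rust
use std::panic;

// Hypothetical, simplified stand-in for an imported tensor type; NOT burn-import's type.
// `dim` is the rank; `shape` is only Some(..) if the exporter recorded a static shape.
#[derive(Debug, Clone)]
struct TensorType {
    dim: usize,
    shape: Option<Vec<usize>>,
}

/// Strict version: assumes the static shape is always present.
/// Panics with "called `Option::unwrap()` on a `None` value" when it is not.
fn strict_num_elements(t: &TensorType) -> usize {
    t.shape.as_ref().unwrap().iter().product()
}

/// Tolerant version: reports "unknown" (None) instead of panicking.
fn tolerant_num_elements(t: &TensorType) -> Option<usize> {
    t.shape.as_ref().map(|s| s.iter().product())
}

fn main() {
    // Like the constant axes tensor in the log: rank 1, shape Some([1]).
    let known = TensorType { dim: 1, shape: Some(vec![1]) };
    // Assumed example: rank known, static shape missing (e.g. symbolic batch/sequence dims).
    let dynamic = TensorType { dim: 2, shape: None };

    assert_eq!(strict_num_elements(&known), 1);
    assert_eq!(tolerant_num_elements(&known), Some(1));
    assert_eq!(tolerant_num_elements(&dynamic), None);
    // The rank is still available even without a static shape.
    assert_eq!(dynamic.dim, 2);

    // The strict version panics on the tensor without a static shape.
    let result = panic::catch_unwind(|| strict_num_elements(&dynamic));
    assert!(result.is_err());
}
```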

sky-2002 commented 3 weeks ago

I am also having the same error @laggui

laggui commented 3 weeks ago

In previous versions, the Unsqueeze ONNX op also checked the input shapes, even though they weren't required. This has been changed on main if you want to try it out; it should no longer give you this error 🙂
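To illustrate why Unsqueeze doesn't need the input shape: per the ONNX operator spec, the output rank is the input rank plus the number of axes, and negative axes are normalized against that output rank. Rank propagation therefore always works, and the static input shape is only needed to compute concrete output dimensions. The sketch below assumes opset 13+ semantics (axes supplied as a tensor input, which matches the Constant feeding "/albert/embeddings/Unsqueeze" in the log). The function names are hypothetical, and this is not the code that landed on main.

```rust
/// Output rank of ONNX Unsqueeze: only the input rank and the number of axes matter.
fn unsqueeze_rank(input_rank: usize, num_axes: usize) -> usize {
    input_rank + num_axes
}

/// Concrete output shape, computable only when the input's static shape is known.
/// Negative axes are normalized against the *output* rank, as in the ONNX spec.
fn unsqueeze_shape(input_shape: &[usize], axes: &[i64]) -> Result<Vec<usize>, String> {
    let out_rank = unsqueeze_rank(input_shape.len(), axes.len());
    // Mark which output positions are newly inserted size-1 dimensions.
    let mut inserted = vec![false; out_rank];
    for &axis in axes {
        let normalized = if axis < 0 { axis + out_rank as i64 } else { axis };
        if normalized < 0 || normalized >= out_rank as i64 {
            return Err(format!("axis {axis} out of range for output rank {out_rank}"));
        }
        let idx = normalized as usize;
        if inserted[idx] {
            return Err(format!("duplicate axis {axis}"));
        }
        inserted[idx] = true;
    }
    // Fill the remaining positions with the input dimensions, in order.
    // The unwrap cannot fail: exactly input_shape.len() positions are not inserted.
    let mut input_dims = input_shape.iter().copied();
    Ok(inserted
        .iter()
        .map(|&is_new| if is_new { 1 } else { input_dims.next().unwrap() })
        .collect())
}

fn main() {
    // Rank-only inference works even when the static shape is unknown.
    assert_eq!(unsqueeze_rank(2, 1), 3);

    // With a known shape, concrete dimensions can be computed too.
    assert_eq!(unsqueeze_shape(&[2, 3], &[0]), Ok(vec![1, 2, 3]));
    assert_eq!(unsqueeze_shape(&[2, 3], &[-1]), Ok(vec![2, 3, 1]));
    assert_eq!(unsqueeze_shape(&[2, 3], &[1, 3]), Ok(vec![2, 1, 3, 1]));
    assert!(unsqueeze_shape(&[2, 3], &[5]).is_err());
}
```

Splitting rank inference from shape inference like this means an importer can keep going with only `dim` when `shape` is missing, instead of unwrapping metadata it does not actually need.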