microsoft / onnxscript

ONNX Script enables developers to naturally author ONNX functions and models using a subset of Python.
https://onnxscript.ai/

Optimizer fails on shape inference error over native_batch_norm #1443

Open xadupre opened 4 months ago

xadupre commented 4 months ago

The optimizer fails for the attached model (so dort fails as well). It was obtained with the latest onnx, onnxscript, and torch nightly.

dump3bug.zip

To replicate:

import onnx
from onnxscript import optimizer

onx = onnx.load("dump3bug.onnx")  # model from the attached dump3bug.zip
optimized = optimizer.optimize(onx)

It is coming from the following graph module.

graph():
    %primals_7 : [num_users=1] = placeholder[target=primals_7]
    %primals_1 : [num_users=1] = placeholder[target=primals_1]
    %primals_2 : [num_users=1] = placeholder[target=primals_2]
    %primals_3 : [num_users=1] = placeholder[target=primals_3]
    %primals_4 : [num_users=1] = placeholder[target=primals_4]
    %primals_5 : [num_users=1] = placeholder[target=primals_5]
    %add : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%primals_7, %primals_1), kwargs = {})
    %_native_batch_norm_legit_no_training : [num_users=1] = call_function[target=torch.ops.aten._native_batch_norm_legit_no_training.default](args = (%add, %primals_2, %primals_3, %primals_4, %primals_5, 0.1, 1e-05), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_no_training, 0), kwargs = {})
    return (add, getitem)

Error:

Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "check_model.py", line 43, in <module>
    optimized = optimizer.optimize(onx)
  File "onnxscript/onnxscript/optimizer/__init__.py", line 61, in optimize
    model = onnx.shape_inference.infer_shapes(
  File "onnx/onnx/shape_inference.py", line 46, in infer_shapes
    inferred_model_str = C.infer_shapes(
onnx.onnx_cpp2py_export.shape_inference.InferenceError: [ShapeInferenceError] Inference error(s): (op_type:_aten_native_batch_norm_inference_onnx, node name: _aten_native_batch_norm_inference_onnx_2): [ShapeInferenceError] Inferred shape and existing shape differ in dimension 0: (2) vs (0)
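For context, a minimal module with the same structure as the FX graph above (hypothetical names and shapes, a sketch of the pattern rather than the original repro from the zip) would look roughly like this:

import torch

class Repro(torch.nn.Module):
    # An add whose result is both returned and fed into an inference-mode
    # batch norm, mirroring the FX graph above.
    def __init__(self):
        super().__init__()
        self.offset = torch.nn.Parameter(torch.randn(2, 3, 4, 4))
        self.bn = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        add = x + self.offset
        # Both the intermediate `add` and the batch-norm output are graph outputs.
        return add, self.bn(add)

model = Repro().eval()  # eval mode so batch norm lowers to _native_batch_norm_legit_no_training
out_add, out_bn = model(torch.randn(2, 3, 4, 4))
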
justinchuby commented 4 months ago

@titaiwangms Do you have an idea? Looks like it is related to https://github.com/microsoft/onnxscript/blame/0d98619dee85025f8fb110864607f6f477c3d8ae/onnxscript/function_libs/torch_lib/ops/core.py#L5625

titaiwangms commented 4 months ago

I can take a look after I'm done with the work on my hands.

titaiwangms commented 4 months ago

@xadupre I think the optimizer should already be applied in torch-nightly? https://github.com/pytorch/pytorch/blob/d5182bb75bbc109cb327212e7205981fbf72cb5e/torch/onnx/_internal/exporter.py#L1274

Are you writing a new optimization? Just trying to understand the usage here.

Specifically, if we try:

import onnx
from onnxscript import optimizer
from onnxscript.rewriter import onnxruntime as ort_rewriter

onx = onnx.load("dump3bug.onnx")
onnx.checker.check_model(onx, full_check=True)
optimized = optimizer.optimize(onx)

The same error is caught by the checker.

xadupre commented 4 months ago

It does. I tweaked the torch code in onnxruntime.py to get the model before it gets optimized, so I could tell whether the error happens before or after optimization. It happens after.

titaiwangms commented 4 months ago

Could you update the attached model to the one from before the optimizer runs?

xadupre commented 4 months ago

I'll check again, but it should be the one from before the optimizer.

titaiwangms commented 4 months ago

I mean the model in the zip doesn't pass onnx.checker.check_model(model, full_check=True). That's why it gets the error message from https://github.com/microsoft/onnxscript/blob/2a78f61b5d4e31bd591cc020211aa903bc425270/onnxscript/optimizer/__init__.py#L61

It seems it isn't even reaching the constant folding and the general rewriter yet.

titaiwangms commented 4 months ago

I wonder if we should use onnx.checker to guard the models generated from the converter/dort. Or do we already?

gramalingam commented 4 months ago

I mean the model in zip doesn't pass onnx.checker.check_model(model, full_check=True).

True ... I tried it, and it does seem to fail.

xadupre commented 4 months ago

I would not call onnx.checker: the converter may introduce nodes from the com.microsoft domain. I created PR #1467 to replicate the issue.
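For illustration, a quick way to see which custom domains the attached model uses (a sketch, assuming the dump3bug.onnx file from the zip):

import onnx

onx = onnx.load("dump3bug.onnx")
# Collect the operator domains used by the graph and by any local functions.
# The standard checker has no schemas for custom domains such as com.microsoft,
# which the converter may emit.
domains = {n.domain or "ai.onnx" for n in onx.graph.node}
for f in onx.functions:
    domains.update(n.domain or "ai.onnx" for n in f.node)
print(domains)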

titaiwangms commented 4 months ago

So I think there are two issues here. The first one is that if we don't want to require our models to pass the checker before they are fed to the optimizer, we should turn off strict_mode in the ONNX shape/type inference inside the optimizer, since the two checks are basically the same. I will submit a PR for this to unblock this model.
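A minimal sketch of that change at the API level (onnx.shape_inference.infer_shapes exposes the strict_mode flag; the file name and the use of check_type are illustrative assumptions):

import onnx

onx = onnx.load("dump3bug.onnx")

# strict_mode=True (what the optimizer effectively does today) turns the
# mismatch between the inferred shape and the recorded value_info into an
# InferenceError; strict_mode=False tolerates the inconsistency.
try:
    onnx.shape_inference.infer_shapes(onx, check_type=True, strict_mode=True)
except Exception as e:  # onnx.onnx_cpp2py_export.shape_inference.InferenceError
    print("strict:", e)

relaxed = onnx.shape_inference.infer_shapes(onx, check_type=True, strict_mode=False)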

The other issue is that, in torchlib, we follow PyTorch's CUDA behavior for native_batch_norm and accept size-0 outputs at indices 1 and 2 (here); this originates from the PyTorch code. That's why the error message says the existing shape is 0, while ONNX shape/type inference infers it as 2. @justinchuby @xiaowuhu @gramalingam any suggestions on this?
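To see the PyTorch side of this, one can inspect the outputs of the decomposed op directly (a sketch with illustrative shapes; the size-0 claim for outputs 1 and 2 applies to the CUDA path per the torchlib comment referenced above):

import torch

x = torch.randn(2, 3, 4, 4)
weight = torch.ones(3)
bias = torch.zeros(3)
running_mean = torch.zeros(3)
running_var = torch.ones(3)

outs = torch.ops.aten._native_batch_norm_legit_no_training.default(
    x, weight, bias, running_mean, running_var, 0.1, 1e-5
)
# Outputs 1 and 2 are the saved mean/invstd; per the discussion above they have
# size 0 on CUDA, while ONNX shape inference derives a non-empty shape
# (the 2-vs-0 mismatch reported in the error).
print([tuple(t.shape) for t in outs])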

justinchuby commented 4 months ago

Do we know if this model was exported with CUDA or with CPU? Even though the model exported under CUDA is different from the one exported under CPU, each of them should pass shape inference, unless there is something I don't remember?

titaiwangms commented 4 months ago

Do we know if this model was exported with CUDA or with CPU? Even though the model exported under CUDA is different from the one exported under CPU, each of them should pass shape inference, unless there is something I don't remember?

I executed the tests from @xadupre with CUDA, and they repro the error. Could you point to the code that is "passing shape inference"? My guess is that it does not invoke strict mode.

justinchuby commented 4 months ago

I think it should be covered by the torchlib tests, but we don't run them with CUDA regularly.

gramalingam commented 4 months ago

Hi, is this related to https://github.com/microsoft/onnxscript/issues/1256 ?

gramalingam commented 4 months ago

Do we know if this model was exported with CUDA or with CPU? Even though the model exported under CUDA is different from the one exported under CPU, each of them should pass shape inference, unless there is something I don't remember?

Given the comment in the code that Titai links above, it appears that CUDA/CPU have different behavior? But the onnxscript encoding chooses one of the two behaviors (it says CUDA) ... now, if the actual shapes are being emitted as produced by the runtime, there is going to be a mismatch between the shape inferred by ONNX (the CUDA shape) and the embedded value_info shape (coming from CPU) ... that would explain it, right?
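One way to check that hypothesis on the attached model is to look at the value_info the exporter embedded (a sketch, assuming the dump3bug.onnx file from the zip):

import onnx

onx = onnx.load("dump3bug.onnx")
# Print the shapes recorded in value_info by the exporter. If these were taken
# from an actual run whose behavior differs from the torchlib encoding, they
# will disagree with what ONNX shape inference derives (the 2-vs-0 mismatch).
for vi in onx.graph.value_info:
    dims = [d.dim_value if d.HasField("dim_value") else d.dim_param
            for d in vi.type.tensor_type.shape.dim]
    print(vi.name, dims)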

gramalingam commented 4 months ago

But Titai also says the error is reproduced in a CUDA run, which seems strange (inconsistent with the message here).

titaiwangms commented 4 months ago

I guess we need to find out what happens in ONNX shape/type inference. One can try this out with the test cases from #1467, turning strict mode back to True in https://github.com/microsoft/onnxscript/pull/1472.

titaiwangms commented 4 months ago

Writing down some findings from today:

This is only reproducible with DORT. Dynamo_export does not support this case, because the op is decomposed at aot_autograd (functionalization). And ExportedProgram can't repro it because the unused outputs are trimmed.