nod-ai / SHARK-ModelDev

Unified compiler/runtime for interfacing with PyTorch Dynamo.
Apache License 2.0

(torch-to-onnx) FLUX.1 - bf16 onnx.LayerNormalization failing to legalize #888

Open monorimet opened 3 days ago

monorimet commented 3 days ago

Hi all, I'm trying to compile the bf16 FLUX MMDiT from an ONNX export.

Running into the following torch-to-onnx legalization error:

iree-compile --iree-hal-target-device=amdgpu --iree-hip-target=gfx942 --iree-hal-target-backends=rocm --iree-hip-target=gfx942 --iree-execution-model=async-external flux_1_dev_static_bf16.mlir -o flux-dev_sampler_bs1_512_1024x1024_bf16_amdgpu-gfx942.vmfb

flux_1_dev_static_bf16.mlir:3873:13: error: failed to legalize operation 'torch.operator' that was explicitly marked illegal
    %1315 = torch.operator "onnx.LayerNormalization"(%1161, %1313, %1314) {torch.onnx.axis = -1 : si64, torch.onnx.epsilon = 9.99999997E-7 : f32} : (!torch.vtensor<[1,4096,3072],bf16>, !torch.vtensor<[3072],bf16>, !torch.vtensor<[3072],bf16>) -> !torch.vtensor<[1,4096,3072],bf16> 
            ^
flux_1_dev_static_bf16.mlir:3873:13: note: see current operation: %3010 = "torch.operator"(%2325, %3007, %3009) <{name = "onnx.LayerNormalization"}> {torch.onnx.axis = -1 : si64, torch.onnx.epsilon = 9.99999997E-7 : f32} : (!torch.vtensor<[1,4096,3072],bf16>, !torch.vtensor<[3072],bf16>, !torch.vtensor<[3072],bf16>) -> !torch.vtensor<[1,4096,3072],bf16>

reproducible with the following MLIR and compile command:

onnxln_test.mlir

func.func @test_layer_norm_single_result(%arg0: !torch.vtensor<[1,4,768],bf16>, %arg1: !torch.vtensor<[768],bf16>, %arg2: !torch.vtensor<[768],bf16>) -> (!torch.vtensor<[1,4,768], bf16>) 
                           attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { 
  %0 = torch.operator "onnx.LayerNormalization"(%arg0, %arg1, %arg2) {torch.onnx.axis = -1 : si64, torch.onnx.epsilon = 9.99999974E-6 : bf16} : (!torch.vtensor<[1,4,768],bf16>, !torch.vtensor<[768],bf16>, !torch.vtensor<[768],bf16>) -> !torch.vtensor<[1,4,768],bf16>
  return %0 : !torch.vtensor<[1,4,768],bf16>
}

compile command:

iree-compile onnxln_test.mlir -o ln.vmfb --iree-hal-target-device=hip --iree-hip-target=gfx942

The minimized reproducer may take some liberties with "correct" bf16 layernorm usage -- I took our f32 test in torch-mlir and find+replaced "f32" with "bf16". I'm not confident that's valid, but it does reproduce the same error.

jinchen62 commented 3 days ago

There are two issues.

  1. We are missing TorchToLinalg lowering support for the LayerNorm op. I will work on it.
  2. To make the bf16 case lower to the torch level like the f32 case does, the reproducer should look like:
    func.func @test_layer_norm_single_result(%arg0: !torch.vtensor<[1,4,768],bf16>, %arg1: !torch.vtensor<[768],bf16>, %arg2: !torch.vtensor<[768],bf16>) -> (!torch.vtensor<[1,4,768], bf16>) attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { 
    %0 = torch.operator "onnx.LayerNormalization"(%arg0, %arg1, %arg2) {torch.onnx.axis = -1 : si64, torch.onnx.epsilon = 9.99999974E-6 : f32, torch.onnx.stash_type = 16 : si64} : (!torch.vtensor<[1,4,768],bf16>, !torch.vtensor<[768],bf16>, !torch.vtensor<[768],bf16>) -> !torch.vtensor<[1,4,768],bf16>
    return %0 : !torch.vtensor<[1,4,768],bf16>
    }

    The epsilon attribute should be f32, and the op should carry a stash_type attribute matching the bf16 type.
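For intuition, the semantics that stash_type selects can be sketched in NumPy. This is a rough emulation only, not the torch-mlir lowering: bf16 is faked here by truncating float32 mantissa bits, the mean/variance stage runs in the stash precision, and epsilon stays an f32 scalar since the ONNX spec defines that attribute as FLOAT regardless of input type.

```python
import numpy as np

def bf16(x):
    # Emulate bfloat16 by truncating the low 16 mantissa bits of a float32
    # (real bf16 rounds to nearest-even; truncation is close enough here).
    x = np.asarray(x, dtype=np.float32)
    return (x.view(np.uint32) & np.uint32(0xFFFF0000)).view(np.float32)

def layer_norm(x_bf16, scale_bf16, bias_bf16, epsilon=1e-5, stash_in_bf16=True):
    # ONNX-style LayerNormalization over the last axis.
    # stash_in_bf16=True mimics stash_type = 16 (BFLOAT16): the mean/variance
    # stage is carried out in (emulated) bf16 instead of the default f32.
    stash = bf16 if stash_in_bf16 else (lambda t: np.asarray(t, np.float32))
    x = stash(x_bf16)
    mean = stash(x.mean(axis=-1, keepdims=True))
    var = stash(((x - mean) ** 2).mean(axis=-1, keepdims=True))
    inv_std = stash(1.0 / np.sqrt(var + np.float32(epsilon)))
    y = (x - mean) * inv_std * stash(scale_bf16) + stash(bias_bf16)
    return bf16(y)  # output dtype matches the bf16 input

# Same shapes and epsilon as the reproducer above.
x = bf16(np.random.default_rng(0).standard_normal((1, 4, 768)))
y = layer_norm(x, bf16(np.ones(768)), bf16(np.zeros(768)), epsilon=9.99999974e-6)
```

After normalization the output should have roughly zero mean and unit variance per row, up to bf16 truncation error.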

jinchen62 commented 3 days ago

@monorimet Actually, we do have a decomposition for torch.aten.native_layer_norm, which onnx.LayerNormalization is lowered to. I was able to compile the bf16 test I posted above. In your case, does the op not come with the stash_type attribute?

monorimet commented 3 days ago

@jinchen62 Thanks, I forgot to include the MLIR: https://sharkpublic.blob.core.windows.net/sharkpublic/flux.1/flux_1_dev_static_bf16.mlir It does not come with the stash_type attribute -- the op only carries {torch.onnx.axis = -1 : si64, torch.onnx.epsilon = 9.99999997E-7 : f32}.
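As an aside, the odd-looking epsilon constants in these files are just round decimal values after float32 rounding: 9.99999997E-7 is 1e-6 stored as an IEEE-754 float32, and 9.99999974E-6 is 1e-5. A quick check (NumPy used only for its float32 type):

```python
import numpy as np

# 1e-6 and 1e-5 are not exactly representable in float32; these are the
# nearest float32 values, matching the constants printed in the MLIR above.
eps_model = float(np.float32(1e-6))   # epsilon in the FLUX model
eps_repro = float(np.float32(1e-5))   # epsilon in the minimized reproducer
```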