
[CPU] SDXL UNet gives NaN output on X86 #16529

Open monorimet opened 9 months ago

monorimet commented 9 months ago

What happened?

The SDXL 1.0 1024x1024 UNet e2e inference outputs NaN values on the CPU backend. This is with iree_linalg_ext.attention tiled and decomposed.

In my dispatch breakdown, I found that:

hal compiled_unet_main_dispatch_131.mlir.txt = flow dispatch 184 -> starts outputting zeroes (see the mmt4d sketch after the IR below):

hal.executable public @main_dispatch_131 {
  hal.executable.variant public @embedded_elf_x86_64 target(<"llvm-cpu", "embedded-elf-x86_64", {cpu = "znver3", cpu_features = "+prfchw,-cldemote,+avx,+aes,+sahf,+pclmul,-xop,+crc32,+xsaves,-avx512fp16,-usermsr,-sm4,+sse4.1,-avx512ifma,+xsave,-avx512pf,+sse4.2,-tsxldtrk,-ptwrite,-widekl,-sm3,+invpcid,+64bit,+xsavec,-avx10.1-512,-avx512vpopcntdq,+cmov,-avx512vp2intersect,-avx512cd,+movbe,-avxvnniint8,-avx512er,-amx-int8,-kl,-avx10.1-256,-sha512,-avxvnni,-rtm,+adx,+avx2,-hreset,-movdiri,-serialize,+vpclmulqdq,-avx512vl,-uintr,+clflushopt,-raoint,-cmpccxadd,+bmi,-amx-tile,+sse,-gfni,-avxvnniint16,-amx-fp16,+xsaveopt,+rdrnd,-avx512f,-amx-bf16,-avx512bf16,-avx512vnni,+cx8,-avx512bw,+sse3,-pku,+fsgsbase,+clzero,+mwaitx,-lwp,+lzcnt,+sha,-movdir64b,+wbnoinvd,-enqcmd,-prefetchwt1,-avxneconvert,-tbm,-pconfig,-amx-complex,+ssse3,+cx16,+bmi2,+fma,+popcnt,-avxifma,+f16c,-avx512bitalg,+rdpru,+clwb,+mmx,+sse2,+rdseed,-avx512vbmi2,-prefetchi,+rdpid,-fma4,-avx512vbmi,+shstk,+vaes,-waitpkg,-sgx,+fxsr,-avx512dq,+sse4a", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", native_vector_size = 32 : index, target_triple = "x86_64-unknown-unknown-eabi-elf", ukernels = "default"}>) {
    hal.executable.export public @main_dispatch_131_mmt4d_16x80x2048x8x8x1_f16 ordinal(0) layout(#hal.pipeline.layout<push_constants = 1, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>]>]>) {
    ^bb0(%arg0: !hal.device):
      %x, %y, %z = flow.dispatch.workgroup_count_from_slice 
      hal.return %x, %y, %z : index, index, index
    }
    builtin.module {
      func.func @main_dispatch_131_mmt4d_16x80x2048x8x8x1_f16() {
        %c84454464 = arith.constant 84454464 : index
        %cst = arith.constant 0.000000e+00 : f16
        %c0 = arith.constant 0 : index
        %0 = hal.interface.constant.load[0] : i32
        %1 = arith.index_castui %0 : i32 to index
        %2 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c84454464) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<16x2048x8x1xf16>>
        %3 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%1) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<80x2048x8x1xf16>>
        %4 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<16x80x8x8xf16>>
        %5 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0, 0], sizes = [16, 2048, 8, 1], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<16x2048x8x1xf16>> -> tensor<16x2048x8x1xf16>
        %6 = flow.dispatch.tensor.load %3, offsets = [0, 0, 0, 0], sizes = [80, 2048, 8, 1], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<80x2048x8x1xf16>> -> tensor<80x2048x8x1xf16>
        %7 = tensor.empty() : tensor<16x80x8x8xf16>
        %8 = linalg.fill ins(%cst : f16) outs(%7 : tensor<16x80x8x8xf16>) -> tensor<16x80x8x8xf16>
        %9 = linalg.mmt4d ins(%5, %6 : tensor<16x2048x8x1xf16>, tensor<80x2048x8x1xf16>) outs(%8 : tensor<16x80x8x8xf16>) -> tensor<16x80x8x8xf16>
        flow.dispatch.tensor.store %9, %4, offsets = [0, 0, 0, 0], sizes = [16, 80, 8, 8], strides = [1, 1, 1, 1] : tensor<16x80x8x8xf16> -> !flow.dispatch.tensor<writeonly:tensor<16x80x8x8xf16>>
        return
      }
    }
  }
}
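
For context, here is a minimal NumPy sketch (illustrative only, not IREE code) of what this mmt4d dispatch computes: linalg.mmt4d contracts pre-packed operands, out[m1,n1,m0,n0] = sum over k1,k0 of lhs[m1,k1,m0,k0] * rhs[n1,k1,n0,k0]. The einsum spelling and the f16-vs-f32 comparison are my own, chosen to show how a K=2048 reduction is sensitive to accumulator precision; the actual accumulator behavior of the IREE ukernel is not asserted here.

    import numpy as np

    rng = np.random.default_rng(0)
    # Shapes from the dispatch: lhs 16x2048x8x1, rhs 80x2048x8x1 -> out 16x80x8x8.
    lhs = rng.standard_normal((16, 2048, 8, 1)).astype(np.float16)
    rhs = rng.standard_normal((80, 2048, 8, 1)).astype(np.float16)

    # Reference semantics of linalg.mmt4d, accumulated in f32.
    ref_f32 = np.einsum("mkac,nkbc->mnab",
                        lhs.astype(np.float32), rhs.astype(np.float32))
    # Same contraction left in f16, roughly modeling an f16 accumulator.
    acc_f16 = np.einsum("mkac,nkbc->mnab", lhs, rhs)

    # The gap grows with the reduction length K.
    print(np.abs(ref_f32 - acc_f16.astype(np.float32)).max())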

hal compiled_unet_main_dispatch_113.mlir.txt = flow dispatch 186 -> outputs all NaNs (see the transpose sketch after the IR below):

hal.executable public @main_dispatch_113 {
  hal.executable.variant public @embedded_elf_x86_64 target(<"llvm-cpu", "embedded-elf-x86_64", {cpu = "znver3", cpu_features = "+prfchw,-cldemote,+avx,+aes,+sahf,+pclmul,-xop,+crc32,+xsaves,-avx512fp16,-usermsr,-sm4,+sse4.1,-avx512ifma,+xsave,-avx512pf,+sse4.2,-tsxldtrk,-ptwrite,-widekl,-sm3,+invpcid,+64bit,+xsavec,-avx10.1-512,-avx512vpopcntdq,+cmov,-avx512vp2intersect,-avx512cd,+movbe,-avxvnniint8,-avx512er,-amx-int8,-kl,-avx10.1-256,-sha512,-avxvnni,-rtm,+adx,+avx2,-hreset,-movdiri,-serialize,+vpclmulqdq,-avx512vl,-uintr,+clflushopt,-raoint,-cmpccxadd,+bmi,-amx-tile,+sse,-gfni,-avxvnniint16,-amx-fp16,+xsaveopt,+rdrnd,-avx512f,-amx-bf16,-avx512bf16,-avx512vnni,+cx8,-avx512bw,+sse3,-pku,+fsgsbase,+clzero,+mwaitx,-lwp,+lzcnt,+sha,-movdir64b,+wbnoinvd,-enqcmd,-prefetchwt1,-avxneconvert,-tbm,-pconfig,-amx-complex,+ssse3,+cx16,+bmi2,+fma,+popcnt,-avxifma,+f16c,-avx512bitalg,+rdpru,+clwb,+mmx,+sse2,+rdseed,-avx512vbmi2,-prefetchi,+rdpid,-fma4,-avx512vbmi,+shstk,+vaes,-waitpkg,-sgx,+fxsr,-avx512dq,+sse4a", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", native_vector_size = 32 : index, target_triple = "x86_64-unknown-unknown-eabi-elf", ukernels = "default"}>) {
    hal.executable.export public @main_dispatch_113_transpose_2x4096x10x64_f16 ordinal(0) layout(#hal.pipeline.layout<push_constants = 2, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>]>]>) {
    ^bb0(%arg0: !hal.device):
      %x, %y, %z = flow.dispatch.workgroup_count_from_slice 
      hal.return %x, %y, %z : index, index, index
    }
    builtin.module {
      func.func @main_dispatch_113_transpose_2x4096x10x64_f16() {
        %0 = hal.interface.constant.load[0] : i32
        %1 = hal.interface.constant.load[1] : i32
        %2 = arith.index_castui %0 : i32 to index
        %3 = arith.index_castui %1 : i32 to index
        %4 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%2) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2x4096x10x64xf16>>
        %5 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%3) : !flow.dispatch.tensor<writeonly:tensor<2x10x4096x64xf16>>
        %6 = flow.dispatch.tensor.load %4, offsets = [0, 0, 0, 0], sizes = [2, 4096, 10, 64], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x4096x10x64xf16>> -> tensor<2x4096x10x64xf16>
        %7 = tensor.empty() : tensor<2x10x4096x64xf16>
        %8 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%6 : tensor<2x4096x10x64xf16>) outs(%7 : tensor<2x10x4096x64xf16>) {
        ^bb0(%in: f16, %out: f16):
          linalg.yield %in : f16
        } -> tensor<2x10x4096x64xf16>
        flow.dispatch.tensor.store %8, %5, offsets = [0, 0, 0, 0], sizes = [2, 10, 4096, 64], strides = [1, 1, 1, 1] : tensor<2x10x4096x64xf16> -> !flow.dispatch.tensor<writeonly:tensor<2x10x4096x64xf16>>
        return
      }
    }
  }
}
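
Note that this dispatch is pure data movement: the indexing maps say out[d0,d2,d1,d3] = in[d0,d1,d2,d3]. A one-line NumPy equivalent (illustrative only):

    import numpy as np

    # out[d0, d2, d1, d3] = in[d0, d1, d2, d3], i.e. swap axes 1 and 2.
    x = np.ones((2, 4096, 10, 64), dtype=np.float16)
    y = np.ascontiguousarray(x.transpose(0, 2, 1, 3))
    assert y.shape == (2, 10, 4096, 64)

Since a transpose cannot introduce NaNs from finite inputs, any NaNs observed in this dispatch's output must already be present in its input, pointing at an earlier dispatch as the source.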

With this UNet .mlir file: stable_diffusion_xl_base_1_0_64_1024x1024_fp16_unet.mlir

And this weights file: SDXL1_unet.safetensors

The numerics issue should be reproducible with the following instructions.

Steps to reproduce your issue

  1. Download files.

  2. Compile to reproduce the NaN output:

    iree-compile ./stable_diffusion_xl_base_1_0_64_1024x1024_fp16_unet.mlir --iree-input-type=torch --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=llvm-cpu --mlir-print-debuginfo=false --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-triple=x86_64-linux-gnu --iree-llvmcpu-target-cpu-features=host --iree-opt-const-eval=False --iree-opt-const-expr-hoisting=False --iree-codegen-linalg-max-constant-fold-elements=9223372036854775807 --iree-opt-strip-assertions=true --verify=false --iree-llvmcpu-enable-ukernels=all --iree-llvmcpu-fail-on-out-of-bounds-stack-allocation=false -o stable_diffusion_xl_base_1_0_64_1024x1024_fp16_unet_cpu.vmfb --iree-flow-trace-dispatch=@main:186 --iree-flow-break-dispatch=@main:186
  3. Run:

    iree-run-module --device=local-task --module=stable_diffusion_xl_base_1_0_64_1024x1024_fp16_unet_cpu.vmfb --parameters=model=./stable_diffusion_xl_base_1_0_unet.safetensors --function=main --input=1x4x128x128xf16 --input=1xf16 --input=2x64x2048xf16 --input=2x1280xf16 --input=2x6xf16 --input=1xf16 2> unet_out.txt
  4. Check unet_out.txt and confirm it contains NaN values (a quick programmatic check is sketched after this list)

  5. Compile to reproduce the zeroes output:

    iree-compile ./stable_diffusion_xl_base_1_0_64_1024x1024_fp16_unet.mlir --iree-input-type=torch --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=llvm-cpu --mlir-print-debuginfo=false --mlir-print-op-on-diagnostic=false --iree-llvmcpu-target-triple=x86_64-linux-gnu --iree-llvmcpu-target-cpu-features=host --iree-opt-const-eval=False --iree-opt-const-expr-hoisting=False --iree-codegen-linalg-max-constant-fold-elements=9223372036854775807 --iree-opt-strip-assertions=true --verify=false --iree-llvmcpu-enable-ukernels=all --iree-llvmcpu-fail-on-out-of-bounds-stack-allocation=false -o stable_diffusion_xl_base_1_0_64_1024x1024_fp16_unet_cpu.vmfb --iree-flow-trace-dispatch=@main:184 --iree-flow-break-dispatch=@main:184
  6. Repeat steps 3 and 4 to see mostly-zero output. (This could be a red herring, but it seems to be the first time any outputs are mostly zero; the next flow dispatch is all zeroes, and the dispatch after that is NaNs, which is the one we reproduced first.)
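
To scan the dumped output for NaNs without eyeballing it, a quick Python check can be used. This assumes unet_out.txt from step 3 is a plain-text dump in which NaNs print literally as "nan"/"NaN"; adjust the parsing to the actual format if needed:

    import re

    # Assumes NaNs appear literally as "nan"/"NaN" in the text dump from step 3.
    text = open("unet_out.txt").read()
    print("nan tokens:", len(re.findall(r"nan", text, flags=re.IGNORECASE)))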

What component(s) does this issue relate to?

No response

Version information

Latest IREE (https://github.com/openxla/iree/commit/5d8907e82fc1eb741a4d4d27f5cae865323fd1d7)

(Notably enabled by https://github.com/openxla/iree/commit/946375cad71786462bcfd63dde6fe305d1e3b9ff)

Additional context

No response

hanhanW commented 8 months ago

I got some ideas today and will take a look at the details tomorrow. It seems to happen in the attention op on f16 types. We have some known numerical issues in the f16 power approximation (see https://github.com/openxla/iree/issues/15936), and we could have similar issues in the f16 exp approximation.

The workaround is running ExpandF16OpToF32Pass on some f16 ops, though I have wanted to remove that pass for a long time. The attention decomposition happens after the f16->f32 conversion, so the exp ops it creates could hit problems in the f16 polynomial approximation. I'll take a look at this and see if it is the same issue as https://github.com/openxla/iree/issues/16544
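
As an illustration of the failure mode being described (my own NumPy example, not the IREE approximation code): f16 overflows past ~65504, so exp of even moderate arguments saturates to inf, and a softmax-style normalization then turns inf/inf into NaN. Widening to f32 first, which is roughly what ExpandF16OpToF32Pass does, avoids it:

    import numpy as np

    logits = np.array([12.0, 11.0, 10.0], dtype=np.float16)

    e = np.exp(logits)        # exp(12) ~ 162754 > f16 max (~65504) -> inf
    print(e / e.sum())        # inf / inf -> [nan, 0, 0]

    # Widening to f32 before exp keeps everything finite.
    e32 = np.exp(logits.astype(np.float32))
    print((e32 / e32.sum()).astype(np.float16))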

hanhanW commented 8 months ago

https://github.com/openxla/iree/pull/16577 should address the issue.

hanhanW commented 8 months ago

All the patches have landed in IREE. @monorimet could you verify whether this is fixed when you're available?