iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

[ROCM][gfx942] shared memory limit exceeded on elemwise broadcast (bf16) (flux-dev) #18254

Closed: monorimet closed this issue 2 months ago

monorimet commented 3 months ago

What happened?

Compiling a new SD3-like model for gfx942 is surfacing a few new issues that we may need to work through. The shapes don't seem too out of the ordinary, but the graph is in bf16. Given bf16, I'm not sure what to expect with respect to codegen support. I am running into a few compilation failures, which I have tracked down to a few dispatches:

On an elementwise broadcast (24x4608x64x2xbf16) dispatch we run into a shared memory issue:

iree-compile configured_compiled_flux_sampler_run_forward_async_dispatch_36.mlir --iree-input-type=torch --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=rocm --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-hal-target-backends=rocm --iree-rocm-target-chip=gfx942 --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-flow-enable-aggressive-fusion --iree-opt-aggressively-propagate-transposes=true --iree-codegen-llvmgpu-use-vector-distribution=true --iree-opt-outer-dim-concat=true --iree-opt-data-tiling=false --iree-codegen-gpu-native-math-precision=true --iree-vm-target-truncate-unsupported-floats --iree-codegen-llvmgpu-enable-transform-dialect-jit=false --iree-global-opt-propagate-transposes=true --iree-opt-const-eval=false --iree-llvmgpu-enable-prefetch=true --iree-execution-model=async-external --iree-preprocessing-pass-pipeline='builtin.module(iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics))' --iree-codegen-transform-dialect-library=attention_and_matmul_spec_mfma.mlir --mlir-print-debuginfo

./sampler_dps/compiled_flux_sampler_run_forward$async_dispatch_36.mlir:9:6: error: 'func.func' op uses 56623104 bytes of shared memory; exceeded the limit of 65536 bytes
      func.func @run_forward$async_dispatch_36_elementwise_broadcast_24x4608x64x2_f32xf32xf32xf32xbf16() {
     ^
./flux_sampler_dps/compiled_flux_sampler_run_forward$async_dispatch_36.mlir:2:2: error: failed to run translation of source executable to target executable for backend #hal.executable.target<"rocm", "rocm-hsaco-fb", {iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "", wgp = <compute =  fp64|fp32|fp16|int64|int32|int16|int8, storage =  b64|b32|b16|b8, subgroup =  shuffle|arithmetic, dot =  dp4xi8toi32, mma = [<MFMA_F16_16x16x16_F32>, <MFMA_F16_32x32x8_F32>, <MFMA_I8_16x16x32_I32>, <MFMA_I8_32x32x16_I32>], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>>, ukernels = "none"}>
  hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb", {iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "", wgp = <compute =  fp64|fp32|fp16|int64|int32|int16|int8, storage =  b64|b32|b16|b8, subgroup =  shuffle|arithmetic, dot =  dp4xi8toi32, mma = [<MFMA_F16_16x16x16_F32>, <MFMA_F16_32x32x8_F32>, <MFMA_I8_16x16x32_I32>, <MFMA_I8_32x32x16_I32>], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>>, ukernels = "none"}>) {
 ^
failed to translate executables

The attention dispatch that follows shortly after also fails, without any useful diagnostics:

iree-compile configured_compiled_flux_sampler_run_forward_async_dispatch_38.mlir --iree-input-type=torch --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=rocm --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-hal-target-backends=rocm --iree-rocm-target-chip=gfx942 --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-flow-enable-aggressive-fusion --iree-opt-aggressively-propagate-transposes=true --iree-codegen-llvmgpu-use-vector-distribution=true --iree-opt-outer-dim-concat=true --iree-opt-data-tiling=false --iree-codegen-gpu-native-math-precision=true --iree-vm-target-truncate-unsupported-floats --iree-codegen-llvmgpu-enable-transform-dialect-jit=false --iree-global-opt-propagate-transposes=true --iree-opt-const-eval=false --iree-llvmgpu-enable-prefetch=true --iree-execution-model=async-external --iree-preprocessing-pass-pipeline='builtin.module(iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics))' --iree-codegen-transform-dialect-library=attention_and_matmul_spec_mfma.mlir --mlir-print-debuginfo
./flux_sampler_dps/compiled_flux_sampler_run_forward$async_dispatch_38.mlir:2:2: error: failed to run translation of source executable to target executable for backend #hal.executable.target<"rocm", "rocm-hsaco-fb", {iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "", wgp = <compute =  fp64|fp32|fp16|int64|int32|int16|int8, storage =  b64|b32|b16|b8, subgroup =  shuffle|arithmetic, dot =  dp4xi8toi32, mma = [<MFMA_F16_16x16x16_F32>, <MFMA_F16_32x32x8_F32>, <MFMA_I8_16x16x32_I32>, <MFMA_I8_32x32x16_I32>], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>>, ukernels = "none"}>
  hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb", {iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "", wgp = <compute =  fp64|fp32|fp16|int64|int32|int16|int8, storage =  b64|b32|b16|b8, subgroup =  shuffle|arithmetic, dot =  dp4xi8toi32, mma = [<MFMA_F16_16x16x16_F32>, <MFMA_F16_32x32x8_F32>, <MFMA_I8_16x16x32_I32>, <MFMA_I8_32x32x16_I32>], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>>, ukernels = "none"}>) {
 ^
failed to translate executables

Steps to reproduce your issue

Download files:
attention_and_matmul_spec_mfma.mlir (shared/sdxl_quantized)
attention_and_matmul_spec.mlir (main)
configured_compiled_flux_sampler_run_forward_async_dispatch_36.mlir
configured_compiled_flux_sampler_run_forward_async_dispatch_38.mlir
Full model, if interested: flux_dev_bs1_512_1024x1024_bf16_sampler.mlir

compile command (shared/sdxl_quantized):

iree-compile --iree-input-type=torch --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=rocm --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-hal-target-backends=rocm --iree-rocm-target-chip=gfx942 --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-flow-enable-aggressive-fusion --iree-opt-aggressively-propagate-transposes=true --iree-codegen-llvmgpu-use-vector-distribution=true --iree-opt-outer-dim-concat=true --iree-opt-data-tiling=false --iree-codegen-gpu-native-math-precision=true --iree-vm-target-truncate-unsupported-floats --iree-codegen-llvmgpu-enable-transform-dialect-jit=false --iree-global-opt-propagate-transposes=true --iree-opt-const-eval=false --iree-llvmgpu-enable-prefetch=true --iree-execution-model=async-external --iree-preprocessing-pass-pipeline='builtin.module(iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics))' --iree-codegen-transform-dialect-library=attention_and_matmul_spec_mfma.mlir <input_mlir> -o out.vmfb

compile command (main):

iree-compile --iree-input-type=torch --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=rocm --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-hal-target-backends=rocm --iree-rocm-target-chip=gfx942 --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-flow-enable-aggressive-fusion --iree-opt-aggressively-propagate-transposes=true --iree-codegen-llvmgpu-use-vector-distribution=true --iree-opt-outer-dim-concat=true --iree-opt-data-tiling=false --iree-codegen-gpu-native-math-precision=true --iree-vm-target-truncate-unsupported-floats --iree-codegen-llvmgpu-enable-transform-dialect-jit=false --iree-global-opt-propagate-transposes=true --iree-opt-const-eval=false --iree-llvmgpu-enable-prefetch=true --iree-execution-model=async-external --iree-preprocessing-pass-pipeline='builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-preprocessing-pad-to-intrinsics)' --iree-codegen-transform-dialect-library=attention_and_matmul_spec.mlir <input_mlir> -o out.vmfb

What component(s) does this issue relate to?

Compiler

Version information

Tested on the following commits: tip of shared/sdxl_quantized, and tip of main.

Additional context

The shared memory issue seems unrelated to any changes on shared/sdxl_quantized or to TD spec compatibility, since both configurations (iree/main and iree/shared/sdxl_quantized, each with its respective TD spec) give the same error.

MaheshRavishankar commented 3 months ago

@nithinsubbiah or @andfau-amd, does one of you want to take this? cc @nirvedhmeshram for tracking.

monorimet commented 3 months ago

The failure is coming from this op in the torch dialect IR:

 %394:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%390, %393, %367, %float0.000000e00, %false_423, %none_424, %none_425) : (!torch.vtensor<[1,24,4608,128],bf16>, !torch.vtensor<[1,24,4608,128],bf16>, !torch.vtensor<[1,24,4608,128],bf16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[1,24,4608,128],bf16>, !torch.vtensor<[1,24,4608],f32>) 

which goes from bf16 inputs to bf16/f32 outputs. Do we have any lowering to MFMA/ROCDL that handles this dtype configuration? I'm not sure how it will translate through attention codegen.

Is it right to expect that this may take a nontrivial amount of work? I want to contribute here, but I don't have a clear picture of the current state of AMDGPU/MFMA codegen, and I might be better off looking for workarounds at the modeling level in the meantime.

monorimet commented 3 months ago

I managed to export the model with fp16 precision and ran into the same compile issue with shared memory:

<stdin>:1656:12: error: 'func.func' op uses 56623104 bytes of shared memory; exceeded the limit of 65536 bytes
    %390 = torch.prims.convert_element_type %389, %int5_418 : !torch.vtensor<[1,24,4608,128],f32>, !torch.int -> !torch.vtensor<[1,24,4608,128],f16>
           ^
<stdin>:1656:12: error: failed to run translation of source executable to target executable for backend #hal.executable.target<"rocm", "rocm-hsaco-fb", {iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "", wgp = <compute =  fp64|fp32|fp16|int64|int32|int16|int8, storage =  b64|b32|b16|b8, subgroup =  shuffle|arithmetic, dot =  dp4xi8toi32, mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647]>>, ukernels = "none"}>
    %390 = torch.prims.convert_element_type %389, %int5_418 : !torch.vtensor<[1,24,4608,128],f32>, !torch.int -> !torch.vtensor<[1,24,4608,128],f16>
           ^

FP16 flux sampler MLIR -- Azure

This is the elementwise broadcast dispatch (in fp16) on which the shared memory issue occurs:

hal.executable public @run_forward$async_dispatch_36 {
  hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb", {iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "", wgp = <compute =  fp64|fp32|fp16|int64|int32|int16|int8, storage =  b64|b32|b16|b8, subgroup =  shuffle|arithmetic, dot =  dp4xi8toi32, mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647]>>, ukernels = "none"}>) {
    hal.executable.export public @run_forward$async_dispatch_36_elementwise_broadcast_24x4608x64x2_f32xf32xf32xf32xf16 ordinal(0) layout(#hal.pipeline.layout<push_constants = 2, sets = [<0, bindings = [<0, storage_buffer, "ReadOnly|Indirect">, <1, storage_buffer, Indirect>], flags = Indirect>]>) attributes {hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>]} {
    ^bb0(%arg0: !hal.device):
      %x, %y, %z = flow.dispatch.workgroup_count_from_slice 
      hal.return %x, %y, %z : index, index, index
    }
    builtin.module {
      func.func @run_forward$async_dispatch_36_elementwise_broadcast_24x4608x64x2_f32xf32xf32xf32xf16() attributes {translation_info = #iree_codegen.translation_info<LLVMGPUVectorize workgroup_size = [2, 64, 1] subgroup_size = 64>} {
        %c28391424 = arith.constant 28391424 : index
        %0 = hal.interface.constant.load layout(<push_constants = 2, sets = [<0, bindings = [<0, storage_buffer, "ReadOnly|Indirect">, <1, storage_buffer, Indirect>], flags = Indirect>]>) ordinal(0) : i32
        %1 = hal.interface.constant.load layout(<push_constants = 2, sets = [<0, bindings = [<0, storage_buffer, "ReadOnly|Indirect">, <1, storage_buffer, Indirect>], flags = Indirect>]>) ordinal(1) : i32
        %2 = arith.index_castui %0 : i32 to index
        %3 = arith.index_castui %1 : i32 to index
        %4 = hal.interface.binding.subspan layout(<push_constants = 2, sets = [<0, bindings = [<0, storage_buffer, "ReadOnly|Indirect">, <1, storage_buffer, Indirect>], flags = Indirect>]>) set(0) binding(0) alignment(64) offset(%c28391424) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<1x1x4608x64x2x2xf32>>
        %5 = hal.interface.binding.subspan layout(<push_constants = 2, sets = [<0, bindings = [<0, storage_buffer, "ReadOnly|Indirect">, <1, storage_buffer, Indirect>], flags = Indirect>]>) set(0) binding(0) alignment(64) offset(%2) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<4608x1x24x64x1x2xf16>>
        %6 = hal.interface.binding.subspan layout(<push_constants = 2, sets = [<0, bindings = [<0, storage_buffer, "ReadOnly|Indirect">, <1, storage_buffer, Indirect>], flags = Indirect>]>) set(0) binding(1) alignment(64) offset(%3) flags(Indirect) : !flow.dispatch.tensor<writeonly:tensor<24x4608x64x2xf16>>
        %7 = flow.dispatch.tensor.load %5, offsets = [0, 0, 0, 0, 0, 0], sizes = [4608, 1, 24, 64, 1, 2], strides = [1, 1, 1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<4608x1x24x64x1x2xf16>> -> tensor<4608x1x24x64x1x2xf16>
        %8 = tensor.empty() : tensor<24x4608x64x2xf16>
        %9 = tensor.empty() : tensor<1x24x4608x64x1x2xf32>
        %10 = flow.dispatch.tensor.load %4, offsets = [0, 0, 0, 0, 0, 1], sizes = [1, 1, 4608, 64, 2, 1], strides = [1, 1, 1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<1x1x4608x64x2x2xf32>> -> tensor<4608x64x2xf32>
        %11 = flow.dispatch.tensor.load %4, offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 4608, 64, 2, 1], strides = [1, 1, 1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<1x1x4608x64x2x2xf32>> -> tensor<4608x64x2xf32>
        %12 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d0, d1, d3, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%7 : tensor<4608x1x24x64x1x2xf16>) outs(%9 : tensor<1x24x4608x64x1x2xf32>) attrs =  {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 1, 64, 0]]>} {
        ^bb0(%in: f16, %out: f32):
          %14 = arith.extf %in : f16 to f32
          linalg.yield %14 : f32
        } -> tensor<1x24x4608x64x1x2xf32>
        %extracted_slice = tensor.extract_slice %12[0, 0, 0, 0, 0, 0] [1, 24, 4608, 64, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x24x4608x64x1x2xf32> to tensor<24x4608x64xf32>
        %extracted_slice_0 = tensor.extract_slice %12[0, 0, 0, 0, 0, 1] [1, 24, 4608, 64, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x24x4608x64x1x2xf32> to tensor<24x4608x64xf32>
        %13 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%11, %extracted_slice, %10, %extracted_slice_0 : tensor<4608x64x2xf32>, tensor<24x4608x64xf32>, tensor<4608x64x2xf32>, tensor<24x4608x64xf32>) outs(%8 : tensor<24x4608x64x2xf16>) attrs =  {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 1, 64, 0]]>} {
        ^bb0(%in: f32, %in_1: f32, %in_2: f32, %in_3: f32, %out: f16):
          %14 = arith.mulf %in_2, %in_3 : f32
          %15 = arith.mulf %in, %in_1 : f32
          %16 = arith.addf %15, %14 : f32
          %17 = arith.truncf %16 : f32 to f16
          linalg.yield %17 : f16
        } -> tensor<24x4608x64x2xf16>
        flow.dispatch.tensor.store %13, %6, offsets = [0, 0, 0, 0], sizes = [24, 4608, 64, 2], strides = [1, 1, 1, 1] : tensor<24x4608x64x2xf16> -> !flow.dispatch.tensor<writeonly:tensor<24x4608x64x2xf16>>
        return
      }
    }
  }
}
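As a sanity check on the error message: the reported 56623104 bytes matches, to the byte, the size of the full f32 intermediate %12 (tensor<1x24x4608x64x1x2xf32>) produced by the arith.extf generic above. That suggests, as an inference not confirmed in this thread, that the whole untiled intermediate is being placed in workgroup memory. A minimal Python sketch of the arithmetic (the helper name is illustrative):

```python
from math import prod

F32_BYTES = 4
F16_BYTES = 2
# max_workgroup_memory_bytes from the gfx942 target attribute in the error output
SHARED_MEM_LIMIT = 65536


def tensor_bytes(shape, elem_bytes):
    """Bytes needed to hold a dense tensor of the given shape and element size."""
    return prod(shape) * elem_bytes


# %12: f32 result of the arith.extf generic, tensor<1x24x4608x64x1x2xf32>
extf_result = tensor_bytes([1, 24, 4608, 64, 1, 2], F32_BYTES)
# %13: f16 output of the dispatch, tensor<24x4608x64x2xf16>
output = tensor_bytes([24, 4608, 64, 2], F16_BYTES)

print(extf_result)                      # 56623104, the figure in the error
print(extf_result // SHARED_MEM_LIMIT)  # 864 times the 64 KiB limit
print(output)                           # 28311552
```

Nothing this size can fit in 64 KiB, so the fix presumably has to come from tiling or fusing the producer rather than from a larger allocation.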

MaheshRavishankar commented 3 months ago

cc @IanWood1. I think he has seen and fixed a similar issue before, though I don't know how he did it. In any case, extracting along the innermost dimension this way is an antipattern IMO. Maybe this is a complex number being handled as two separate floats?
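For reference, here is a small pure-Python sketch of what the %13 linalg.generic in the fp16 dispatch computes, read directly from its indexing maps (the function name and the tiny shapes are illustrative, not from the model). Each output pair out[h, l, k, :] is a 2x2 matrix, taken from the 4608x64x2x2 f32 input, applied to the input pair x[h, l, k, 0:2]; this is why the IR slices the innermost dimension apart. If the 2x2 factors hold cos/sin values (an assumption; the thread does not say what that tensor contains), the operation is exactly complex multiplication, which would fit the complex-number guess:

```python
import cmath
import math


def pairwise_2x2(freqs, x):
    """Reference for the %13 generic, following its indexing maps.

    freqs: nested lists shaped [L][K][2][2]  (the 1x1xLxKx2x2 f32 input)
    x:     nested lists shaped [H][L][K][2]  (the 1xHxLxKx1x2 f32 input)
    returns nested lists shaped [H][L][K][2] (the HxLxKx2 output, before truncf)
    """
    H, L, K = len(x), len(x[0]), len(x[0][0])
    out = [[[[0.0, 0.0] for _ in range(K)] for _ in range(L)] for _ in range(H)]
    for h in range(H):
        for l in range(L):
            for k in range(K):
                x0, x1 = x[h][l][k]
                for r in range(2):
                    # out[h,l,k,r] = freqs[l,k,r,0] * x[h,l,k,0] + freqs[l,k,r,1] * x[h,l,k,1]
                    out[h][l][k][r] = freqs[l][k][r][0] * x0 + freqs[l][k][r][1] * x1
    return out


# Hypothetical rotation factors for H = L = K = 1.
theta = 0.3
c, s = math.cos(theta), math.sin(theta)
out = pairwise_2x2([[[[c, -s], [s, c]]]], [[[[2.0, -1.5]]]])
z = complex(2.0, -1.5) * cmath.exp(1j * theta)
print(out[0][0][0], (z.real, z.imag))  # the two pairs agree up to rounding
```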

monorimet commented 3 months ago

The shared memory issue does not occur, and the full model compiles successfully, if the q, k, and v values are traced with flow.tensor_trace ops. Here is a link to the "traced" .mlir that compiles successfully: flux_dev_bs1_1024x1024_fp16_sampler_traced.mlir

The last time tracing around an op fixed a compile issue was with SD3, where tracing prevented some fusion from being applied. We may have a similar issue here; I'll take a closer look.

monorimet commented 3 months ago

In the traced module's dispatches, I located the same dispatch that originally failed:

hal.executable public @run_forward$async_dispatch_42 {
  hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb", {iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "", wgp = <compute =  fp64|fp32|fp16|int64|int32|int16|int8, storage =  b64|b32|b16|b8, subgroup =  shuffle|arithmetic, dot =  dp4xi8toi32, mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647]>>, ukernels = "none"}>) {
    hal.executable.export public @run_forward$async_dispatch_42_elementwise_broadcast_24x4608x64x2_f32xf32xf32xf32xf16 ordinal(0) layout(#hal.pipeline.layout<push_constants = 2, sets = [<0, bindings = [<0, storage_buffer, "ReadOnly|Indirect">, <1, storage_buffer, "ReadOnly|Indirect">, <2, storage_buffer, Indirect>], flags = Indirect>]>) attributes {hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>, #hal.interface.binding<0, 2>]} {
    ^bb0(%arg0: !hal.device loc("<stdin>":1773:12)):
      %x, %y, %z = flow.dispatch.workgroup_count_from_slice  loc("<stdin>":1773:12)
      hal.return %x, %y, %z : index, index, index loc("<stdin>":1773:12)
    } loc("<stdin>":1773:12)
    builtin.module {
      func.func @run_forward$async_dispatch_42_elementwise_broadcast_24x4608x64x2_f32xf32xf32xf32xf16() {
        %c28311552 = arith.constant 28311552 : index loc(unknown)
        %0 = hal.interface.constant.load layout(<push_constants = 2, sets = [<0, bindings = [<0, storage_buffer, "ReadOnly|Indirect">, <1, storage_buffer, "ReadOnly|Indirect">, <2, storage_buffer, Indirect>], flags = Indirect>]>) ordinal(0) : i32 loc("<stdin>":1723:12)
        %1 = hal.interface.constant.load layout(<push_constants = 2, sets = [<0, bindings = [<0, storage_buffer, "ReadOnly|Indirect">, <1, storage_buffer, "ReadOnly|Indirect">, <2, storage_buffer, Indirect>], flags = Indirect>]>) ordinal(1) : i32 loc("<stdin>":1773:12)
        %2 = arith.index_castui %0 : i32 to index loc("<stdin>":1723:12)
        %3 = arith.index_castui %1 : i32 to index loc("<stdin>":1773:12)
        %4 = hal.interface.binding.subspan layout(<push_constants = 2, sets = [<0, bindings = [<0, storage_buffer, "ReadOnly|Indirect">, <1, storage_buffer, "ReadOnly|Indirect">, <2, storage_buffer, Indirect>], flags = Indirect>]>) set(0) binding(0) alignment(64) offset(%2) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<1x24x4608x64x1x2xf32>> loc("<stdin>":1723:12)
        %5 = hal.interface.binding.subspan layout(<push_constants = 2, sets = [<0, bindings = [<0, storage_buffer, "ReadOnly|Indirect">, <1, storage_buffer, "ReadOnly|Indirect">, <2, storage_buffer, Indirect>], flags = Indirect>]>) set(0) binding(1) alignment(64) offset(%c28311552) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<1x1x4608x64x2x2xf32>> loc("<stdin>":1221:12)
        %6 = hal.interface.binding.subspan layout(<push_constants = 2, sets = [<0, bindings = [<0, storage_buffer, "ReadOnly|Indirect">, <1, storage_buffer, "ReadOnly|Indirect">, <2, storage_buffer, Indirect>], flags = Indirect>]>) set(0) binding(2) alignment(64) offset(%3) flags(Indirect) : !flow.dispatch.tensor<writeonly:tensor<24x4608x64x2xf16>> loc("<stdin>":1773:12)
        %7 = tensor.empty() : tensor<24x4608x64x2xf16> loc("<stdin>":1773:12)
        %8 = flow.dispatch.tensor.load %4, offsets = [0, 0, 0, 0, 0, 1], sizes = [1, 24, 4608, 64, 1, 1], strides = [1, 1, 1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<1x24x4608x64x1x2xf32>> -> tensor<24x4608x64xf32> loc("<stdin>":1746:12)
        %9 = flow.dispatch.tensor.load %5, offsets = [0, 0, 0, 0, 0, 1], sizes = [1, 1, 4608, 64, 2, 1], strides = [1, 1, 1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<1x1x4608x64x2x2xf32>> -> tensor<4608x64x2xf32> loc("<stdin>":1743:12)
        %10 = flow.dispatch.tensor.load %4, offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 24, 4608, 64, 1, 1], strides = [1, 1, 1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<1x24x4608x64x1x2xf32>> -> tensor<24x4608x64xf32> loc("<stdin>":1739:12)
        %11 = flow.dispatch.tensor.load %5, offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 4608, 64, 2, 1], strides = [1, 1, 1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<1x1x4608x64x2x2xf32>> -> tensor<4608x64x2xf32> loc("<stdin>":1736:12)
        %12 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%11, %10, %9, %8 : tensor<4608x64x2xf32>, tensor<24x4608x64xf32>, tensor<4608x64x2xf32>, tensor<24x4608x64xf32>) outs(%7 : tensor<24x4608x64x2xf16>) {
        ^bb0(%in: f32 loc("<stdin>":1736:12), %in_0: f32 loc("<stdin>":1739:12), %in_1: f32 loc("<stdin>":1743:12), %in_2: f32 loc("<stdin>":1746:12), %out: f16 loc("<stdin>":1773:12)):
          %13 = arith.mulf %in_1, %in_2 : f32 loc("<stdin>":1747:12)
          %14 = arith.mulf %in, %in_0 : f32 loc("<stdin>":1740:12)
          %15 = arith.addf %14, %13 : f32 loc("<stdin>":1749:12)
          %16 = arith.truncf %15 : f32 to f16 loc("<stdin>":1773:12)
          linalg.yield %16 : f16 loc("<stdin>":1773:12)
        } -> tensor<24x4608x64x2xf16> loc("<stdin>":1773:12)
        flow.dispatch.tensor.store %12, %6, offsets = [0, 0, 0, 0], sizes = [24, 4608, 64, 2], strides = [1, 1, 1, 1] : tensor<24x4608x64x2xf16> -> !flow.dispatch.tensor<writeonly:tensor<24x4608x64x2xf16>> loc("<stdin>":1773:12)
        return loc("<stdin>":1773:12)
      } loc("<stdin>":1773:12)
    } loc("<stdin>":1773:12)
  } loc("<stdin>":1773:12)
} loc("<stdin>":1773:12)

For some reason, without the traced ops, this elementwise broadcast has a few things that stick out:

monorimet commented 3 months ago

I managed to isolate a smaller reproducer with just the double-stream block from the flux implementation.

Here is the resulting flux_attn_repro_fp16.mlir, which reproduces the failure: Azure

Compile command:

iree-compile --iree-input-type=torch --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=rocm --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-hal-target-backends=rocm --iree-rocm-target-chip=gfx942 --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-global-opt-propagate-transposes=true --iree-opt-const-eval=false --iree-llvmgpu-enable-prefetch=true --iree-execution-model=async-external --iree-preprocessing-pass-pipeline='builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-preprocessing-pad-to-intrinsics)' --iree-codegen-transform-dialect-library=attention_and_matmul_spec.mlir --iree-codegen-llvmgpu-enable-transform-dialect-jit=false flux_attn_repro_fp16.mlir -o out.vmfb

This command and MLIR compile successfully if the --iree-opt-outer-dim-concat=true flag is added. I am therefore assuming that this flag has something to do with the discrepancy between the pass and fail cases.
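When two long command lines differ by a single flag, comparing them as sets of flags is more reliable than comparing them by eye. A minimal sketch using Python's standard shlex module; the helper names are hypothetical, and the example commands are shortened versions of the ones in this thread:

```python
import shlex


def flag_set(cmd):
    """Split a shell command line and keep only its --flags (program and positional args dropped)."""
    return {tok for tok in shlex.split(cmd) if tok.startswith("--")}


def diff_flags(failing_cmd, passing_cmd):
    """Return (flags only in the failing command, flags only in the passing command)."""
    a, b = flag_set(failing_cmd), flag_set(passing_cmd)
    return sorted(a - b), sorted(b - a)


failing = ("iree-compile --iree-hal-target-backends=rocm --iree-rocm-target-chip=gfx942 "
           "flux_attn_repro_fp16.mlir -o out.vmfb")
passing = failing + " --iree-opt-outer-dim-concat=true"
only_failing, only_passing = diff_flags(failing, passing)
print(only_failing, only_passing)  # [] ['--iree-opt-outer-dim-concat=true']
```

Because shlex honors shell quoting, a quoted value such as --iree-preprocessing-pass-pipeline='builtin.module(...)' stays a single token. Comparing sets also ignores repeated flags, which appear in several of the commands above.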

More notes:

monorimet commented 3 months ago

After comparing the print-after-all output of the two smaller cases described in the previous comment, I can see where things seem to blow up. With outer dims concatenated, vectorization produces a more reasonable output, which compiles successfully:

// -----// IR Dump After LLVMGPUVectorLoweringPass (iree-llvmgpu-vector-lowering) //----- //
func.func @main$async_dispatch_26_elementwise_broadcast_24x4608x64x2_f32xf32xf32xf32xf16() attributes {translation_info = #iree_codegen.translation_info<LLVMGPUVectorize workgroup_size = [2, 64, 1] subgroup_size = 64>} {
  %cst = arith.constant dense<0.000000e+00> : vector<1xf32>
  %c1 = arith.constant 1 : index
  %c0 = arith.constant 0 : index
  %0 = hal.interface.constant.load layout(<push_constants = 2, sets = [<0, bindings = [<0, storage_buffer, "ReadOnly|Indirect">, <1, storage_buffer, "ReadOnly|Indirect">, <2, storage_buffer, Indirect>], flags = Indirect>]>) ordinal(0) : i32
  %1 = hal.interface.constant.load layout(<push_constants = 2, sets = [<0, bindings = [<0, storage_buffer, "ReadOnly|Indirect">, <1, storage_buffer, "ReadOnly|Indirect">, <2, storage_buffer, Indirect>], flags = Indirect>]>) ordinal(1) : i32
  %2 = arith.index_castui %0 : i32 to index
  %3 = arith.index_castui %1 : i32 to index
  %4 = hal.interface.binding.subspan layout(<push_constants = 2, sets = [<0, bindings = [<0, storage_buffer, "ReadOnly|Indirect">, <1, storage_buffer, "ReadOnly|Indirect">, <2, storage_buffer, Indirect>], flags = Indirect>]>) set(0) binding(0) alignment(64) offset(%2) flags("ReadOnly|Indirect") : memref<1x24x4608x64x1x2xf32, strided<[14155776, 589824, 128, 2, 2, 1], offset: ?>, #gpu.address_space<global>>
  memref.assume_alignment %4, 1 : memref<1x24x4608x64x1x2xf32, strided<[14155776, 589824, 128, 2, 2, 1], offset: ?>, #gpu.address_space<global>>
  %5 = hal.interface.binding.subspan layout(<push_constants = 2, sets = [<0, bindings = [<0, storage_buffer, "ReadOnly|Indirect">, <1, storage_buffer, "ReadOnly|Indirect">, <2, storage_buffer, Indirect>], flags = Indirect>]>) set(0) binding(1) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : memref<1x1x4608x64x2x2xf32, #gpu.address_space<global>>
  memref.assume_alignment %5, 64 : memref<1x1x4608x64x2x2xf32, #gpu.address_space<global>>
  %6 = hal.interface.binding.subspan layout(<push_constants = 2, sets = [<0, bindings = [<0, storage_buffer, "ReadOnly|Indirect">, <1, storage_buffer, "ReadOnly|Indirect">, <2, storage_buffer, Indirect>], flags = Indirect>]>) set(0) binding(2) alignment(64) offset(%3) flags(Indirect) : memref<24x4608x64x2xf16, strided<[589824, 128, 2, 1], offset: ?>, #gpu.address_space<global>>
  memref.assume_alignment %6, 1 : memref<24x4608x64x2xf16, strided<[589824, 128, 2, 1], offset: ?>, #gpu.address_space<global>>
  %workgroup_id_y = hal.interface.workgroup.id[1] : index
  %workgroup_id_x = hal.interface.workgroup.id[0] : index
  %thread_id_x = gpu.thread_id  x
  %thread_id_y = gpu.thread_id  y
  %7 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (vector<1xf32>) {
    %16 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%arg0)[%thread_id_x]
    %17 = memref.load %5[%c0, %c0, %workgroup_id_x, %thread_id_y, %16, %c0] : memref<1x1x4608x64x2x2xf32, #gpu.address_space<global>>
    %18 = vector.insertelement %17, %arg1[%arg0 : index] : vector<1xf32>
    scf.yield %18 : vector<1xf32>
  }
  %8 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (vector<1xf32>) {
    %16 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%arg0)[%thread_id_y]
    %17 = memref.load %4[%c0, %workgroup_id_y, %workgroup_id_x, %16, %c0, %c0] : memref<1x24x4608x64x1x2xf32, strided<[14155776, 589824, 128, 2, 2, 1], offset: ?>, #gpu.address_space<global>>
    %18 = vector.insertelement %17, %arg1[%arg0 : index] : vector<1xf32>
    scf.yield %18 : vector<1xf32>
  }
  %9 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (vector<1xf32>) {
    %16 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%arg0)[%thread_id_x]
    %17 = memref.load %5[%c0, %c0, %workgroup_id_x, %thread_id_y, %16, %c1] : memref<1x1x4608x64x2x2xf32, #gpu.address_space<global>>
    %18 = vector.insertelement %17, %arg1[%arg0 : index] : vector<1xf32>
    scf.yield %18 : vector<1xf32>
  }
  %10 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (vector<1xf32>) {
    %16 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%arg0)[%thread_id_y]
    %17 = memref.load %4[%c0, %workgroup_id_y, %workgroup_id_x, %16, %c0, %c1] : memref<1x24x4608x64x1x2xf32, strided<[14155776, 589824, 128, 2, 2, 1], offset: ?>, #gpu.address_space<global>>
    %18 = vector.insertelement %17, %arg1[%arg0 : index] : vector<1xf32>
    scf.yield %18 : vector<1xf32>
  }
  %11 = arith.mulf %9, %10 : vector<1xf32>
  %12 = arith.mulf %7, %8 : vector<1xf32>
  %13 = arith.addf %12, %11 : vector<1xf32>
  %14 = arith.truncf %13 : vector<1xf32> to vector<1xf16>
  %15 = vector.extract %14[0] : f16 from vector<1xf16>
  memref.store %15, %6[%workgroup_id_y, %workgroup_id_x, %thread_id_y, %thread_id_x] : memref<24x4608x64x2xf16, strided<[589824, 128, 2, 1], offset: ?>, #gpu.address_space<global>>
  return
}
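The passing IR is easy to sanity-check. Each thread loads four f32 scalars from global memory and stores one f16, so no workgroup memory is needed. The Python sketch below checks that the store pattern %6[workgroup_id_y, workgroup_id_x, thread_id_y, thread_id_x], with workgroup_size = [2, 64, 1], covers the 24x4608x64x2 output exactly once. The workgroup count (4608 x 24) is not printed in the dump; it is inferred from the store indices and is an assumption, as are the helper names:

```python
from itertools import product

WORKGROUP_SIZE = (2, 64, 1)   # translation_info workgroup_size = [x, y, z]
WORKGROUP_COUNT = (4608, 24, 1)  # assumed: workgroup_id_x spans 4608, workgroup_id_y spans 24


def store_index(wg_x, wg_y, t_x, t_y):
    """Output index written by one thread: %6[wg_id_y, wg_id_x, thread_id_y, thread_id_x]."""
    return (wg_y, wg_x, t_y, t_x)


def covered_indices(wg_count, wg_size):
    """All output indices stored by a (wg_count x wg_size) grid, following the pattern above."""
    seen = []
    for wg_x, wg_y in product(range(wg_count[0]), range(wg_count[1])):
        for t_x, t_y in product(range(wg_size[0]), range(wg_size[1])):
            seen.append(store_index(wg_x, wg_y, t_x, t_y))
    return seen


threads_per_wg = WORKGROUP_SIZE[0] * WORKGROUP_SIZE[1] * WORKGROUP_SIZE[2]
num_wgs = WORKGROUP_COUNT[0] * WORKGROUP_COUNT[1] * WORKGROUP_COUNT[2]
total_elems = 24 * 4608 * 64 * 2
print(threads_per_wg, num_wgs, threads_per_wg * num_wgs == total_elems)

# Shrunk grid: 3 in place of 4608 and 2 in place of 24, same thread layout.
small = covered_indices((3, 2), (2, 64))
expected = set(product(range(2), range(3), range(64), range(2)))
print(len(small) == len(set(small)) == len(expected), set(small) == expected)
```

Note that 24 x 4608 = 110592 is also the element count of the vector<24x4608x1x1x1xf16> that the failing case builds up, which hints (an inference from the dumps, not a confirmed diagnosis) that the dimensions distributed across workgroups here end up unrolled inside a single thread there.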

Without concatenating outer dims, the dispatch starts out the same, except that it is prepended with arith.constant index values counting down from 4608. After the thread IDs are set up as in the IR above, the failing case (without outer dim concat) starts doing what looks like a botched elementwise broadcast: a load/broadcast/insert pattern that repeats for 350,000+ lines, followed by a similarly long chain of vector.extract ops:

 %7 = memref.load %5[%c0, %c0, %c0, %thread_id_y, %c0, %thread_id_x] : memref<1x24x4608x64x1x2xf16, strided<[14155776, 589824, 128, 2, 2, 1], offset: ?>, #gpu.address_space<global>>
  %8 = vector.broadcast %7 : f16 to vector<1xf16>
  %9 = vector.insert %8, %cst_0 [0, 0, 0, 0] : vector<1xf16> into vector<24x4608x1x1x1xf16>
  %10 = memref.load %5[%c0, %c0, %c1, %thread_id_y, %c0, %thread_id_x] : memref<1x24x4608x64x1x2xf16, strided<[14155776, 589824, 128, 2, 2, 1], offset: ?>, #gpu.address_space<global>>
  %11 = vector.broadcast %10 : f16 to vector<1xf16>
  %12 = vector.insert %11, %9 [0, 1, 0, 0] : vector<1xf16> into vector<24x4608x1x1x1xf16>
  %13 = memref.load %5[%c0, %c0, %c2, %thread_id_y, %c0, %thread_id_x] : memref<1x24x4608x64x1x2xf16, strided<[14155776, 589824, 128, 2, 2, 1], offset: ?>, #gpu.address_space<global>>
  %14 = vector.broadcast %13 : f16 to vector<1xf16>
  %15 = vector.insert %14, %12 [0, 2, 0, 0] : vector<1xf16> into vector<24x4608x1x1x1xf16>
  %16 = memref.load %5[%c0, %c0, %c3, %thread_id_y, %c0, %thread_id_x] : memref<1x24x4608x64x1x2xf16, strided<[14155776, 589824, 128, 2, 2, 1], offset: ?>, #gpu.address_space<global>>
  %17 = vector.broadcast %16 : f16 to vector<1xf16>
  %18 = vector.insert %17, %15 [0, 3, 0, 0] : vector<1xf16> into vector<24x4608x1x1x1xf16>
  %19 = memref.load %5[%c0, %c0, %c4, %thread_id_y, %c0, %thread_id_x] : memref<1x24x4608x64x1x2xf16, strided<[14155776, 589824, 128, 2, 2, 1], offset: ?>, #gpu.address_space<global>>
  %20 = vector.broadcast %19 : f16 to vector<1xf16>
  %21 = vector.insert %20, %18 [0, 4, 0, 0] : vector<1xf16> into vector<24x4608x1x1x1xf16>
  %22 = memref.load %5[%c0, %c0, %c5, %thread_id_y, %c0, %thread_id_x] : memref<1x24x4608x64x1x2xf16, strided<[14155776, 589824, 128, 2, 2, 1], offset: ?>, #gpu.address_space<global>>
  %23 = vector.broadcast %22 : f16 to vector<1xf16>
  %24 = vector.insert %23, %21 [0, 5, 0, 0] : vector<1xf16> into vector<24x4608x1x1x1xf16>
  %25 = memref.load %5[%c0, %c0, %c6, %thread_id_y, %c0, %thread_id_x] : memref<1x24x4608x64x1x2xf16, strided<[14155776, 589824, 128, 2, 2, 1], offset: ?>, #gpu.address_space<global>>
  %26 = vector.broadcast %25 : f16 to vector<1xf16>
  %27 = vector.insert %26, %24 [0, 6, 0, 0] : vector<1xf16> into vector<24x4608x1x1x1xf16>
  %28 = memref.load %5[%c0, %c0, %c7, %thread_id_y, %c0, %thread_id_x] : memref<1x24x4608x64x1x2xf16, strided<[14155776, 589824, 128, 2, 2, 1], offset: ?>, #gpu.address_space<global>>
  %29 = vector.broadcast %28 : f16 to vector<1xf16>
  %30 = vector.insert %29, %27 [0, 7, 0, 0] : vector<1xf16> into vector<24x4608x1x1x1xf16>
  %31 = memref.load %5[%c0, %c0, %c8, %thread_id_y, %c0, %thread_id_x] : memref<1x24x4608x64x1x2xf16, strided<[14155776, 589824, 128, 2, 2, 1], offset: ?>, #gpu.address_space<global>>
  %32 = vector.broadcast %31 : f16 to vector<1xf16>
  %33 = vector.insert %32, %30 [0, 8, 0, 0] : vector<1xf16> into vector<24x4608x1x1x1xf16>
  %34 = memref.load %5[%c0, %c0, %c9, %thread_id_y, %c0, %thread_id_x] : memref<1x24x4608x64x1x2xf16, strided<[14155776, 589824, 128, 2, 2, 1], offset: ?>, #gpu.address_space<global>>
  %35 = vector.broadcast %34 : f16 to vector<1xf16>
  %36 = vector.insert %35, %33 [0, 9, 0, 0] : vector<1xf16> into vector<24x4608x1x1x1xf16>

  (continues on for quite a while ... )

  %442372 = vector.extract %331783[23, 4604, 0, 0, 0] : f32 from vector<24x4608x1x1x1xf32>
  memref.store %442372, %alloc[%c0, %c23, %c4604, %thread_id_y, %c0, %thread_id_x] : memref<1x24x4608x64x1x2xf32, #gpu.address_space<workgroup>>
  %442373 = vector.extract %331783[23, 4605, 0, 0, 0] : f32 from vector<24x4608x1x1x1xf32>
  memref.store %442373, %alloc[%c0, %c23, %c4605, %thread_id_y, %c0, %thread_id_x] : memref<1x24x4608x64x1x2xf32, #gpu.address_space<workgroup>>
  %442374 = vector.extract %331783[23, 4606, 0, 0, 0] : f32 from vector<24x4608x1x1x1xf32>
  memref.store %442374, %alloc[%c0, %c23, %c4606, %thread_id_y, %c0, %thread_id_x] : memref<1x24x4608x64x1x2xf32, #gpu.address_space<workgroup>>
  %442375 = vector.extract %331783[23, 4607, 0, 0, 0] : f32 from vector<24x4608x1x1x1xf32>
  memref.store %442375, %alloc[%c0, %c23, %c4607, %thread_id_y, %c0, %thread_id_x] : memref<1x24x4608x64x1x2xf32, #gpu.address_space<workgroup>>
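
The size of this IR can be accounted for with some arithmetic. The sketch below assumes, inferring from the excerpt rather than from the full dump, that each element of the per-thread vector<24x4608x1x1x1xf16> costs exactly three SSA values (memref.load, vector.broadcast, vector.insert) starting at %7, that a single arith.extf over the whole vector follows the last insert, and that each element is then read back with one vector.extract (memref.store yields no SSA value). All names are illustrative.

```python
# Outer dims left untiled in the per-thread vector<24x4608x1x1x1xf16>.
HEADS, TOKENS = 24, 4608
ELEMENTS = HEADS * TOKENS  # one scalar load per (head, token) pair

FIRST_LOAD_ID = 7       # %7 is the first memref.load in the excerpt
VALUES_PER_ELEMENT = 3  # memref.load, vector.broadcast, vector.insert


def flat(head, token):
    """Row-major position of (head, token) within the 24x4608 outer dims."""
    return head * TOKENS + token


def load_id(head, token):
    return FIRST_LOAD_ID + VALUES_PER_ELEMENT * flat(head, token)


def insert_id(head, token):
    return load_id(head, token) + 2


# Assumption: one arith.extf over the whole vector right after the last insert.
EXTF_ID = insert_id(HEADS - 1, TOKENS - 1) + 1


def extract_id(head, token):
    """SSA id of the vector.extract for (head, token); stores add no ids."""
    return EXTF_ID + 1 + flat(head, token)


print(ELEMENTS)                          # 110592 scalar elements per thread
print(VALUES_PER_ELEMENT * ELEMENTS)     # 331776 ops in the load/broadcast/insert chain
print(EXTF_ID, extract_id(23, 4607))     # 331783 442375
```

Under those assumptions the predicted IDs match %10, %36, %331783 and %442372 to %442375 in the excerpt, which suggests that each thread's vector spans the entire 24x4608 outer extent rather than a tile of it.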

I don't really know what to make of this yet, but here is the same dispatch before vectorization in both cases.

Failing case:

// -----// IR Dump After FoldTensorExtractOpPass (iree-codegen-fold-tensor-extract-op) //----- //
func.func @main$async_dispatch_21_elementwise_broadcast_24x4608x64x2_f32xf32xf32xf32xf16() attributes {translation_info = #iree_codegen.translation_info<LLVMGPUVectorize workgroup_size = [2, 64, 1] subgroup_size = 64>} {
  %c1 = arith.constant 1 : index
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.000000e+00 : f32
  %cst_0 = arith.constant 0.000000e+00 : f16
  %0 = hal.interface.constant.load layout(<push_constants = 2, sets = [<0, bindings = [<0, storage_buffer, "ReadOnly|Indirect">, <1, storage_buffer, "ReadOnly|Indirect">, <2, storage_buffer, Indirect>], flags = Indirect>]>) ordinal(0) : i32
  %1 = hal.interface.constant.load layout(<push_constants = 2, sets = [<0, bindings = [<0, storage_buffer, "ReadOnly|Indirect">, <1, storage_buffer, "ReadOnly|Indirect">, <2, storage_buffer, Indirect>], flags = Indirect>]>) ordinal(1) : i32
  %2 = arith.index_castui %0 : i32 to index
  %3 = arith.index_castui %1 : i32 to index
  %4 = hal.interface.binding.subspan layout(<push_constants = 2, sets = [<0, bindings = [<0, storage_buffer, "ReadOnly|Indirect">, <1, storage_buffer, "ReadOnly|Indirect">, <2, storage_buffer, Indirect>], flags = Indirect>]>) set(0) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : memref<1x1x4608x64x2x2xf32, #gpu.address_space<global>>
  memref.assume_alignment %4, 64 : memref<1x1x4608x64x2x2xf32, #gpu.address_space<global>>
  %5 = hal.interface.binding.subspan layout(<push_constants = 2, sets = [<0, bindings = [<0, storage_buffer, "ReadOnly|Indirect">, <1, storage_buffer, "ReadOnly|Indirect">, <2, storage_buffer, Indirect>], flags = Indirect>]>) set(0) binding(1) alignment(64) offset(%2) flags("ReadOnly|Indirect") : memref<1x24x4608x64x1x2xf16, strided<[14155776, 589824, 128, 2, 2, 1], offset: ?>, #gpu.address_space<global>>
  memref.assume_alignment %5, 1 : memref<1x24x4608x64x1x2xf16, strided<[14155776, 589824, 128, 2, 2, 1], offset: ?>, #gpu.address_space<global>>
  %6 = hal.interface.binding.subspan layout(<push_constants = 2, sets = [<0, bindings = [<0, storage_buffer, "ReadOnly|Indirect">, <1, storage_buffer, "ReadOnly|Indirect">, <2, storage_buffer, Indirect>], flags = Indirect>]>) set(0) binding(2) alignment(64) offset(%3) flags(Indirect) : memref<24x4608x64x2xf16, strided<[589824, 128, 2, 1], offset: ?>, #gpu.address_space<global>>
  memref.assume_alignment %6, 1 : memref<24x4608x64x2xf16, strided<[589824, 128, 2, 1], offset: ?>, #gpu.address_space<global>>
  %workgroup_id_y = hal.interface.workgroup.id[1] : index
  %workgroup_id_x = hal.interface.workgroup.id[0] : index
  %alloc = memref.alloc() : memref<1x24x4608x64x1x2xf32, #gpu.address_space<workgroup>>
  %thread_id_x = gpu.thread_id  x
  %thread_id_y = gpu.thread_id  y
  %7 = vector.transfer_read %5[%c0, %c0, %c0, %thread_id_y, %c0, %thread_id_x], %cst_0 {in_bounds = [true, true, true, true, true]} : memref<1x24x4608x64x1x2xf16, strided<[14155776, 589824, 128, 2, 2, 1], offset: ?>, #gpu.address_space<global>>, vector<24x4608x1x1x1xf16>
  %8 = arith.extf %7 : vector<24x4608x1x1x1xf16> to vector<24x4608x1x1x1xf32>
  vector.transfer_write %8, %alloc[%c0, %c0, %c0, %thread_id_y, %c0, %thread_id_x] {in_bounds = [true, true, true, true, true]} : vector<24x4608x1x1x1xf32>, memref<1x24x4608x64x1x2xf32, #gpu.address_space<workgroup>>
  gpu.barrier
  %9 = vector.transfer_read %4[%c0, %c0, %workgroup_id_x, %thread_id_y, %thread_id_x, %c0], %cst {in_bounds = [true], permutation_map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4)>} : memref<1x1x4608x64x2x2xf32, #gpu.address_space<global>>, vector<1xf32>
  %10 = vector.transfer_read %alloc[%c0, %workgroup_id_y, %workgroup_id_x, %thread_id_y, %c0, %c0], %cst {in_bounds = [true], permutation_map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3)>} : memref<1x24x4608x64x1x2xf32, #gpu.address_space<workgroup>>, vector<1xf32>
  %11 = vector.transfer_read %4[%c0, %c0, %workgroup_id_x, %thread_id_y, %thread_id_x, %c1], %cst {in_bounds = [true], permutation_map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4)>} : memref<1x1x4608x64x2x2xf32, #gpu.address_space<global>>, vector<1xf32>
  %12 = vector.transfer_read %alloc[%c0, %workgroup_id_y, %workgroup_id_x, %thread_id_y, %c0, %c1], %cst {in_bounds = [true], permutation_map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3)>} : memref<1x24x4608x64x1x2xf32, #gpu.address_space<workgroup>>, vector<1xf32>
  %13 = arith.mulf %11, %12 : vector<1xf32>
  %14 = arith.mulf %9, %10 : vector<1xf32>
  %15 = arith.addf %14, %13 : vector<1xf32>
  %16 = arith.truncf %15 : vector<1xf32> to vector<1xf16>
  vector.transfer_write %16, %6[%workgroup_id_y, %workgroup_id_x, %thread_id_y, %thread_id_x] {in_bounds = [true]} : vector<1xf16>, memref<24x4608x64x2xf16, strided<[589824, 128, 2, 1], offset: ?>, #gpu.address_space<global>>
  return
}

Passing case:

// -----// IR Dump After FoldTensorExtractOpPass (iree-codegen-fold-tensor-extract-op) //----- //
func.func @main$async_dispatch_26_elementwise_broadcast_24x4608x64x2_f32xf32xf32xf32xf16() attributes {translation_info = #iree_codegen.translation_info<LLVMGPUVectorize workgroup_size = [2, 64, 1] subgroup_size = 64>} {
  %c1 = arith.constant 1 : index
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.000000e+00 : f32
  %0 = hal.interface.constant.load layout(<push_constants = 2, sets = [<0, bindings = [<0, storage_buffer, "ReadOnly|Indirect">, <1, storage_buffer, "ReadOnly|Indirect">, <2, storage_buffer, Indirect>], flags = Indirect>]>) ordinal(0) : i32
  %1 = hal.interface.constant.load layout(<push_constants = 2, sets = [<0, bindings = [<0, storage_buffer, "ReadOnly|Indirect">, <1, storage_buffer, "ReadOnly|Indirect">, <2, storage_buffer, Indirect>], flags = Indirect>]>) ordinal(1) : i32
  %2 = arith.index_castui %0 : i32 to index
  %3 = arith.index_castui %1 : i32 to index
  %4 = hal.interface.binding.subspan layout(<push_constants = 2, sets = [<0, bindings = [<0, storage_buffer, "ReadOnly|Indirect">, <1, storage_buffer, "ReadOnly|Indirect">, <2, storage_buffer, Indirect>], flags = Indirect>]>) set(0) binding(0) alignment(64) offset(%2) flags("ReadOnly|Indirect") : memref<1x24x4608x64x1x2xf32, strided<[14155776, 589824, 128, 2, 2, 1], offset: ?>, #gpu.address_space<global>>
  memref.assume_alignment %4, 1 : memref<1x24x4608x64x1x2xf32, strided<[14155776, 589824, 128, 2, 2, 1], offset: ?>, #gpu.address_space<global>>
  %5 = hal.interface.binding.subspan layout(<push_constants = 2, sets = [<0, bindings = [<0, storage_buffer, "ReadOnly|Indirect">, <1, storage_buffer, "ReadOnly|Indirect">, <2, storage_buffer, Indirect>], flags = Indirect>]>) set(0) binding(1) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : memref<1x1x4608x64x2x2xf32, #gpu.address_space<global>>
  memref.assume_alignment %5, 64 : memref<1x1x4608x64x2x2xf32, #gpu.address_space<global>>
  %6 = hal.interface.binding.subspan layout(<push_constants = 2, sets = [<0, bindings = [<0, storage_buffer, "ReadOnly|Indirect">, <1, storage_buffer, "ReadOnly|Indirect">, <2, storage_buffer, Indirect>], flags = Indirect>]>) set(0) binding(2) alignment(64) offset(%3) flags(Indirect) : memref<24x4608x64x2xf16, strided<[589824, 128, 2, 1], offset: ?>, #gpu.address_space<global>>
  memref.assume_alignment %6, 1 : memref<24x4608x64x2xf16, strided<[589824, 128, 2, 1], offset: ?>, #gpu.address_space<global>>
  %workgroup_id_y = hal.interface.workgroup.id[1] : index
  %workgroup_id_x = hal.interface.workgroup.id[0] : index
  %thread_id_x = gpu.thread_id  x
  %thread_id_y = gpu.thread_id  y
  %7 = vector.transfer_read %5[%c0, %c0, %workgroup_id_x, %thread_id_y, %thread_id_x, %c0], %cst {in_bounds = [true], permutation_map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4)>} : memref<1x1x4608x64x2x2xf32, #gpu.address_space<global>>, vector<1xf32>
  %8 = vector.transfer_read %4[%c0, %workgroup_id_y, %workgroup_id_x, %thread_id_y, %c0, %c0], %cst {in_bounds = [true], permutation_map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3)>} : memref<1x24x4608x64x1x2xf32, strided<[14155776, 589824, 128, 2, 2, 1], offset: ?>, #gpu.address_space<global>>, vector<1xf32>
  %9 = vector.transfer_read %5[%c0, %c0, %workgroup_id_x, %thread_id_y, %thread_id_x, %c1], %cst {in_bounds = [true], permutation_map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4)>} : memref<1x1x4608x64x2x2xf32, #gpu.address_space<global>>, vector<1xf32>
  %10 = vector.transfer_read %4[%c0, %workgroup_id_y, %workgroup_id_x, %thread_id_y, %c0, %c1], %cst {in_bounds = [true], permutation_map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3)>} : memref<1x24x4608x64x1x2xf32, strided<[14155776, 589824, 128, 2, 2, 1], offset: ?>, #gpu.address_space<global>>, vector<1xf32>
  %11 = arith.mulf %9, %10 : vector<1xf32>
  %12 = arith.mulf %7, %8 : vector<1xf32>
  %13 = arith.addf %12, %11 : vector<1xf32>
  %14 = arith.truncf %13 : vector<1xf32> to vector<1xf16>
  vector.transfer_write %14, %6[%workgroup_id_y, %workgroup_id_x, %thread_id_y, %thread_id_x] {in_bounds = [true]} : vector<1xf16>, memref<24x4608x64x2xf16, strided<[589824, 128, 2, 1], offset: ?>, #gpu.address_space<global>>
  return
}
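
The byte count in the error message can be reproduced from the workgroup allocation in the failing dump: memref.alloc() : memref<1x24x4608x64x1x2xf32, #gpu.address_space<workgroup>> is sized for the entire extf result, whereas the passing dump reads that operand straight from global memory and allocates no shared memory at all. A quick check, assuming 4 bytes per f32, a dense layout, and the 65536-byte limit quoted in the error (memref_bytes is a hypothetical helper):

```python
from math import prod

SHARED_MEMORY_LIMIT = 65536  # bytes, as reported in the gfx942 error message


def memref_bytes(shape, element_bytes):
    """Size in bytes of a dense (identity-layout) memref allocation."""
    return prod(shape) * element_bytes


alloc_bytes = memref_bytes((1, 24, 4608, 64, 1, 2), element_bytes=4)  # f32
print(alloc_bytes)                          # 56623104
print(alloc_bytes // SHARED_MEMORY_LIMIT)   # 864
```

The allocation is exactly 864 times the limit, consistent with the dispatch staging the whole tensor in shared memory instead of a per-workgroup tile.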

monorimet commented 3 months ago

My best guess is that, in the failing case, the element type conversion on this buffer:

memref.assume_alignment %5, 1 : memref<1x24x4608x64x1x2xf16, strided<[14155776, 589824, 128, 2, 2, 1], offset: ?>, #gpu.address_space<global>>

is being awkwardly fused with the elementwise broadcast (?), which sends it down a bad path during vectorization. In the passing case, the type conversion does not happen in the same dispatch. There's also some odd reordering of the buffers (compare bindings 0 and 1 in the two dumps above), which was observed earlier in the issue thread.

nirvedhmeshram commented 3 months ago

Hi @monorimet, just want to confirm that there is only one known issue here at the moment regarding shared memory exhaustion. If there are multiple issues, we should file separate issues so that we can assign a developer to each one.

To narrow down the failing IR from your last message, it would help to see a full print-IR-after-all dump for the main$async_dispatch_21_elementwise_broadcast_24x4608x64x2_f32xf32xf32xf32xf16 dispatch. To get one, compile the model with the flag --iree-hal-dump-executable-sources-to=$SOME_PATH. $SOME_PATH will then contain main$async_dispatch_21_elementwise_broadcast_24x4608x64x2_f32xf32xf32xf32xf16.mlir, which you can compile again with --mlir-print-ir-after-all, redirecting the output to a file (something like --mlir-print-ir-after-all &> output_dump.txt). Then upload that dump somewhere and we can take a look at it.
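
The two-step workflow above can be scripted. This is a sketch, not an official IREE tool: build_dump_commands and run_logged are hypothetical helpers, the file names are placeholders, and extra_flags should be whatever flags the model is normally compiled with. Passing argv lists instead of a shell string also sidesteps a shell pitfall: an unquoted `$` in names like main$async_dispatch_21_... would otherwise be expanded as a shell variable.

```python
import subprocess
from pathlib import Path


def build_dump_commands(model, dump_dir, dispatch_name, extra_flags):
    """Return argv lists for (1) dumping executable sources and
    (2) recompiling one dumped dispatch with per-pass IR printing."""
    dump_sources = [
        "iree-compile", str(model), *extra_flags,
        f"--iree-hal-dump-executable-sources-to={dump_dir}",
        "-o", str(Path(dump_dir) / "model.vmfb"),
    ]
    dispatch_src = Path(dump_dir) / f"{dispatch_name}.mlir"
    print_after_all = [
        "iree-compile", str(dispatch_src), *extra_flags,
        "--mlir-print-ir-after-all",
        "-o", str(Path(dump_dir) / f"{dispatch_name}.vmfb"),
    ]
    return dump_sources, print_after_all


def run_logged(argv, log_path):
    """Rough equivalent of `cmd &> log`: stdout and stderr both go to log_path."""
    with open(log_path, "w") as log:
        return subprocess.run(argv, stdout=log, stderr=subprocess.STDOUT).returncode


# Usage (not executed here):
#   flags = ["--iree-hal-target-backends=rocm", "--iree-rocm-target-chip=gfx942"]
#   dump, trace = build_dump_commands("model.mlir", "dump", "main$async_dispatch_21_...", flags)
#   run_logged(dump, "dump.log"); run_logged(trace, "output_dump.txt")
```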

monorimet commented 3 months ago

@nirvedhmeshram there is only one issue tracked here. I am going through all this because we have a success case and a failure case: the failure case hits the shared memory error, and in the success case the full model compiles.

The success case is: use --iree-opt-outer-dim-concat=true and inject trace tensors for the attention op's query and key inputs.

We cannot use the succeeding case because trace tensors aren't something you want to use outside of debugging.

So I've been trying to find out why compilation only succeeds under the above conditions. Since dispatch_21 differs depending on whether --iree-opt-outer-dim-concat=true is used, I have been comparing two IR dumps for the model excerpt. Looking at the dump for the individual dispatch doesn't seem to help, since dispatch formation plays a part in whether we run into the shared memory issue. Please take a look at the following IR dumps (they are large):

https://sharkpublic.blob.core.windows.net/sharkpublic/ean/flux-dev/doublestream_dump_fail.txt (trace tensors, no outer dim concat flag. The only reason this one is interesting beyond the next dump is that it shows the outer dim concat flag toggling the success case when trace tensors are present.)

https://sharkpublic.blob.core.windows.net/sharkpublic/ean/flux-dev/doublestream_dump_fail_2.txt (no trace tensors, with outer dim concat flag. This is the configuration that we want to fix.)

https://sharkpublic.blob.core.windows.net/sharkpublic/ean/flux-dev/doublestream_dump.txt (success, using trace tensors and outer dim concat flag.)

nirvedhmeshram commented 3 months ago

You can still isolate the failing dispatch from the failing case and produce a smaller dump using the steps I described in my previous message.

But the dumps you shared are good to start with too. Will take a look tomorrow.

monorimet commented 3 months ago

Gotcha. Without trace tensors, the failing dispatch index moves to 18 -- so here's that IR: https://sharkpublic.blob.core.windows.net/sharkpublic/ean/flux-dev/compiled_attn_main_async_dispatch_18.mlir

My compile command:

iree-compile --iree-input-type=torch --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=rocm --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-hal-target-backends=rocm --iree-rocm-target-chip=gfx942 --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-global-opt-propagate-transposes=true --iree-opt-const-eval=false --iree-llvmgpu-enable-prefetch=true --iree-execution-model=async-external --iree-preprocessing-pass-pipeline='builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-preprocessing-pad-to-intrinsics)' --iree-codegen-transform-dialect-library=attention_and_matmul_spec.mlir --iree-codegen-llvmgpu-enable-transform-dialect-jit=false --iree-opt-outer-dim-concat=true compiled_attn_main_async_dispatch_18.mlir --mlir-print-ir-after-all &> dispatch_18_ir_dump_fail.txt

and here's the IR dump:

https://sharkpublic.blob.core.windows.net/sharkpublic/ean/flux-dev/dispatch_18_ir_dump_fail.txt

monorimet commented 3 months ago

Given that essentially the same IR is generated in both failing configurations (no trace tensors with outer dim concat, and trace tensors without outer dim concat), I figure that the outer dim concatenation is somehow not working correctly for this IR without trace tensors: the IR for the passing and failing cases starts to diverge in the dumps at iree-global-opt-decompose-concat.

This may not be the most methodical way to find the root cause or suggest solutions. Perhaps the failing dispatch IR is OK and vectorization needs to be fixed, but it seems to me that the failing dispatch isn't ideal code to begin with, and it may be best to fix whatever is causing the "bad" dispatch to be formed. The odd behavior where adding trace tensors makes --iree-opt-outer-dim-concat actually do something seems symptomatic of a pre-dispatch codegen issue, but I don't know why adding trace tensors would make a difference, beyond a naive suspicion that they prevent some bad fusion.
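
Finding the first pass at which two --mlir-print-ir-after-all dumps diverge can be automated. A minimal sketch, assuming each section starts with a header of the form seen in the dumps above (// -----// IR Dump After PassName (pass-arg) //----- //); split_dump and first_divergence are hypothetical helpers, and location info that differs between runs (loc(...) with --mlir-print-debuginfo) will also count as a difference unless it is stripped first.

```python
import re

# Matches headers such as:
# // -----// IR Dump After FoldTensorExtractOpPass (iree-codegen-fold-tensor-extract-op) //----- //
HEADER = re.compile(r"^[ \t]*// -----// IR Dump After (.*?) //----- //[ \t]*$", re.MULTILINE)


def split_dump(text):
    """Split a print-ir-after-all log into (pass header, IR text) pairs, in order."""
    matches = list(HEADER.finditer(text))
    sections = []
    for i, match in enumerate(matches):
        end = matches[i + 1].start() if i + 1 < len(matches) else len(text)
        sections.append((match.group(1), text[match.end():end].strip()))
    return sections


def first_divergence(dump_a, dump_b):
    """Return (index, pass_a, pass_b) for the first differing section, or None."""
    a, b = split_dump(dump_a), split_dump(dump_b)
    for i, ((pass_a, ir_a), (pass_b, ir_b)) in enumerate(zip(a, b)):
        if pass_a != pass_b or ir_a != ir_b:
            return i, pass_a, pass_b
    if len(a) != len(b):
        i = min(len(a), len(b))
        return i, (a[i][0] if i < len(a) else None), (b[i][0] if i < len(b) else None)
    return None
```

Running it on the passing and failing dumps should report the pass name directly instead of requiring a manual scroll through multi-megabyte logs.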

IanWood1 commented 3 months ago

We were talking about this during the codegen sync, and I think Mahesh will add more info. But this seems related to what https://github.com/llvm/torch-mlir/pull/3483 is trying to address.

MaheshRavishankar commented 3 months ago

@monorimet adding tracing changes the way fusion works (at least today; maybe we can make that not be the case). Sorry I wasn't online over the weekend; it looks like you tried a bunch of different things. Reading through, the most important point is that this

        %12 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d0, d1, d3, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%7 : tensor<4608x1x24x64x1x2xf16>) outs(%9 : tensor<1x24x4608x64x1x2xf32>) attrs =  {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 1, 64, 0]]>} {
        ^bb0(%in: f16, %out: f32):
          %14 = arith.extf %in : f16 to f32
          linalg.yield %14 : f32
        } -> tensor<1x24x4608x64x1x2xf32>
        %extracted_slice = tensor.extract_slice %12[0, 0, 0, 0, 0, 0] [1, 24, 4608, 64, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x24x4608x64x1x2xf32> to tensor<24x4608x64xf32>
        %extracted_slice_0 = tensor.extract_slice %12[0, 0, 0, 0, 0, 1] [1, 24, 4608, 64, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x24x4608x64x1x2xf32> to tensor<24x4608x64xf32>
        %13 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%11, %extracted_slice, %10, %extracted_slice_0 : tensor<4608x64x2xf32>, tensor<24x4608x64xf32>, tensor<4608x64x2xf32>, tensor<24x4608x64xf32>) outs(%8 : tensor<24x4608x64x2xf16>) attrs =  {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 1, 64, 0]]>} {
        ^bb0(%in: f32, %in_1: f32, %in_2: f32, %in_3: f32, %out: f16):
          %14 = arith.mulf %in_2, %in_3 : f32
          %15 = arith.mulf %in, %in_1 : f32
          %16 = arith.addf %15, %14 : f32
          %17 = arith.truncf %16 : f32 to f16
          linalg.yield %17 : f16
        } -> tensor<24x4608x64x2xf16>

kind of IR (not even thinking of dispatches, just the lowered IR) is complicated for the backend to generate code for. (Side note: we should fix dispatch region formation to not create such dispatches.) Instead of the above, we should lower to

        %12 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d0, d1, d3, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%7 : tensor<4608x1x24x64x1x2xf16>) outs(%9 : tensor<1x24x4608x64x1x2xf32>) attrs =  {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 1, 64, 0]]>} {
        ^bb0(%in: f16, %out: f32):
          %14 = arith.extf %in : f16 to f32
          linalg.yield %14 : f32
        } -> tensor<24x4608x64x2xf32>
        %13 = linalg.generic 
            ins(%11, %12, %10: tensor<4608x64x2xf32>, tensor<1x24x4608x64x1x2xf32>, tensor<4608x64x2xf32>, tensor<24x4608x64xf32>) outs(%8 : tensor<24x4608x64x2xf16>) {...}

Slicing the innermost dimension into separate tensors that feed into another operation is a bad lowering sequence. --iree-opt-outer-dim-concat is a bit of a red herring, because it is meant for cases where concatenation happens on the innermost dim, which is not what is happening here (it has some similar artifacts, but it's not a good match).

Sorry you tried a whole bunch of different things and ran into various issues, but I think the easiest fix is to fix the lowering to Linalg. Recovering from this lowering later in the compiler is much harder.
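
The two lowerings compute the same values; they differ only in how the consumer reaches the two innermost elements. A plain-Python sketch of the semantics (not of any compiler transformation), with the issue's 24x4608x64 extents shrunk to tiny hypothetical sizes and made-up data: sliced_form splits the innermost dim into two separate tensors first, as the problematic IR does, while direct_form reads both elements of each pair from the full extended tensor, as suggested above.

```python
from itertools import product

H, S, D, J = 2, 3, 4, 2  # tiny stand-ins for 24 heads, 4608 tokens, 64 pairs, 2 outputs

# x[h][s][d][k]: the f16 operand with unit dims dropped; k picks the pair element.
x = [[[[h + s + d + k for k in range(2)] for d in range(D)] for s in range(S)]
     for h in range(H)]
# f[s][d][j][k]: the 4608x64x2x2 operand with unit dims dropped (made-up values).
f = [[[[0.5 * (s - d) + j - k for k in range(2)] for j in range(J)] for d in range(D)]
     for s in range(S)]


def extf(t):
    """Elementwise widening cast over the 4-D nested list (stands in for arith.extf)."""
    return [[[[float(v) for v in pair] for pair in row] for row in plane] for plane in t]


INDICES = list(product(range(H), range(S), range(D), range(J)))


def sliced_form():
    """Problematic lowering: split the innermost dim into two tensors, then combine."""
    xe = extf(x)
    x0 = [[[pair[0] for pair in row] for row in plane] for plane in xe]
    x1 = [[[pair[1] for pair in row] for row in plane] for plane in xe]
    return {(h, s, d, j): f[s][d][j][0] * x0[h][s][d] + f[s][d][j][1] * x1[h][s][d]
            for h, s, d, j in INDICES}


def direct_form():
    """Suggested lowering: the consumer indexes the innermost dim of the full tensor."""
    xe = extf(x)
    return {(h, s, d, j): f[s][d][j][0] * xe[h][s][d][0] + f[s][d][j][1] * xe[h][s][d][1]
            for h, s, d, j in INDICES}


assert sliced_form() == direct_form()
```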

monorimet commented 3 months ago

That clarifies it quite a bit, thanks. My efforts were just me flapping around trying to find something useful ahead of someone else picking this up -- call it an educational exercise. I will take a shot at hand-editing the linalg IR while the torchtolinalg fix lands.

MaheshRavishankar commented 3 months ago

If you want, I can write what the lowering should look like (or, if you take a stab and post it here, I can verify it).

monorimet commented 3 months ago

No, that's okay -- I'll give it a shot

monorimet commented 3 months ago

@MaheshRavishankar I'm a little confused about how the linalg lowering leads to this bad dispatch -- PTAL:

This is the problematic bit of flow IR that is generated (you've seen this already, I just attached debug info to see corresponding source lines):

%5 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d0, d1, d3, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%0 : tensor<4608x1x24x64x1x2xf16>) outs(%2 : tensor<1x24x4608x64x1x2xf32>) {
^bb0(%in: f16 loc("/home/eagarvey/SHARK-Turbine/models/turbine_models/custom_models/flux/flux_sampler_attn_repro_fp16.mlir":395:12), %out: f32 loc("/home/eagarvey/SHARK-Turbine/models/turbine_models/custom_models/flux/flux_sampler_attn_repro_fp16.mlir":403:12)):
  %7 = arith.extf %in : f16 to f32 loc("/home/eagarvey/SHARK-Turbine/models/turbine_models/custom_models/flux/flux_sampler_attn_repro_fp16.mlir":403:12)
  linalg.yield %7 : f32 loc("/home/eagarvey/SHARK-Turbine/models/turbine_models/custom_models/flux/flux_sampler_attn_repro_fp16.mlir":403:12)
} -> tensor<1x24x4608x64x1x2xf32> loc("/home/eagarvey/SHARK-Turbine/models/turbine_models/custom_models/flux/flux_sampler_attn_repro_fp16.mlir":403:12)
%extracted_slice = tensor.extract_slice %5[0, 0, 0, 0, 0, 0] [1, 24, 4608, 64, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x24x4608x64x1x2xf32> to tensor<24x4608x64xf32> loc("/home/eagarvey/SHARK-Turbine/models/turbine_models/custom_models/flux/flux_sampler_attn_repro_fp16.mlir":427:12)
%extracted_slice_0 = tensor.extract_slice %5[0, 0, 0, 0, 0, 1] [1, 24, 4608, 64, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x24x4608x64x1x2xf32> to tensor<24x4608x64xf32> loc("/home/eagarvey/SHARK-Turbine/models/turbine_models/custom_models/flux/flux_sampler_attn_repro_fp16.mlir":434:12)
%6 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%4, %extracted_slice, %3, %extracted_slice_0 : tensor<4608x64x2xf32>, tensor<24x4608x64xf32>, tensor<4608x64x2xf32>, tensor<24x4608x64xf32>) outs(%1 : tensor<24x4608x64x2xf16>) {
^bb0(%in: f32 loc("/home/eagarvey/SHARK-Turbine/models/turbine_models/custom_models/flux/flux_sampler_attn_repro_fp16.mlir":424:12), %in_1: f32 loc("/home/eagarvey/SHARK-Turbine/models/turbine_models/custom_models/flux/flux_sampler_attn_repro_fp16.mlir":427:12), %in_2: f32 loc("/home/eagarvey/SHARK-Turbine/models/turbine_models/custom_models/flux/flux_sampler_attn_repro_fp16.mlir":431:12), %in_3: f32 loc("/home/eagarvey/SHARK-Turbine/models/turbine_models/custom_models/flux/flux_sampler_attn_repro_fp16.mlir":434:12), %out: f16 loc("/home/eagarvey/SHARK-Turbine/models/turbine_models/custom_models/flux/flux_sampler_attn_repro_fp16.mlir":461:12)):
  %7 = arith.mulf %in_2, %in_3 : f32 loc("/home/eagarvey/SHARK-Turbine/models/turbine_models/custom_models/flux/flux_sampler_attn_repro_fp16.mlir":435:12)
  %8 = arith.mulf %in, %in_1 : f32 loc("/home/eagarvey/SHARK-Turbine/models/turbine_models/custom_models/flux/flux_sampler_attn_repro_fp16.mlir":428:12)
  %9 = arith.addf %8, %7 : f32 loc("/home/eagarvey/SHARK-Turbine/models/turbine_models/custom_models/flux/flux_sampler_attn_repro_fp16.mlir":437:12)
  %10 = arith.truncf %9 : f32 to f16 loc("/home/eagarvey/SHARK-Turbine/models/turbine_models/custom_models/flux/flux_sampler_attn_repro_fp16.mlir":461:12)
  linalg.yield %10 : f16 loc("/home/eagarvey/SHARK-Turbine/models/turbine_models/custom_models/flux/flux_sampler_attn_repro_fp16.mlir":461:12)
} -> tensor<24x4608x64x2xf16> loc("/home/eagarvey/SHARK-Turbine/models/turbine_models/custom_models/flux/flux_sampler_attn_repro_fp16.mlir":461:12)

In the flow IR, the debuginfo points to these lines in the source IR:

 %156 = torch.prims.convert_element_type %151, %int6_185 : !torch.vtensor<[1,24,4608,128],f16>, !torch.int -> !torch.vtensor<[1,24,4608,128],f32>
 <...>
 %163 = torch.aten.select.int %158, %int5_200, %int0_201 : !torch.vtensor<[1,24,4608,64,1,2],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,24,4608,64,1],f32>
 <...>
 %166 = torch.aten.select.int %158, %int5_204, %int1_205 : !torch.vtensor<[1,24,4608,64,1,2],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,24,4608,64,1],f32>
<...>
 %178 = torch.prims.convert_element_type %177, %int5_220 : !torch.vtensor<[1,24,4608,128],f32>, !torch.int -> !torch.vtensor<[1,24,4608,128],f16>

For reference, here is an excerpt of the linalg IR: https://gist.githubusercontent.com/monorimet/e7052186f2be9c947f1022627e455d5e/raw/590bff3dffe648b70541b20b9ee8d76843324e0b/pattern_linalg.mlir

Full linalg IR: https://gist.githubusercontent.com/monorimet/e7052186f2be9c947f1022627e455d5e/raw/9b32448dd959b30b24443e4e86f3d8b50eddcfdd/flux_sampler_attn_repro_fp16_linalg.mlir

And the source IR: https://gist.githubusercontent.com/monorimet/e7052186f2be9c947f1022627e455d5e/raw/590bff3dffe648b70541b20b9ee8d76843324e0b/flux_sampler_attn_repro_fp16.mlir

So I'm wondering: is a single torch-to-linalg lowering (i.e. the element type conversion) responsible, or the pattern of torch.aten.select.int ops followed by the truncf? In the latter case, is it really a TorchToLinalg issue, or is this just an ugly pattern to begin with (in the torch dialect)?

Also, since we want to do a batch size sweep for this model, hand-editing the IR is not sufficient. The goal would be, as suggested, to fix the TorchToLinalg lowering, but I am missing some context here and still haven't really figured out what needs to be changed.

Perhaps it would help after all to know exactly what we should be lowering this to... less room for me to get confused that way.

monorimet commented 3 months ago

@IanWood1 here is what I believe is responsible for this: https://github.com/black-forest-labs/flux/blob/main/src/flux/math.py#L6

As suspected, I was able to minimize the repro further by restricting it to the apply_rope function in the linked flux/math.py file: https://gist.github.com/monorimet/d39a0cfa7d9307e91a61da209452e2f6/raw/dc1b12ece62b75cd01db7b99615a95cc2019b6a0/flux_apply_rope_torch.mlir
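
The 24x4608x64x2 shape of the failing dispatch follows from the torch ops quoted earlier (convert_element_type to f32, select.int on the last dim, then convert back to f16). A sketch of the shape bookkeeping using NumPy/PyTorch-style broadcasting; it assumes, from the 1x1x4608x64x2x2 binding in the dispatch IR rather than from the flux source, that the frequency operand has shape [1, 1, 4608, 64, 2, 2], and broadcast_shape and select_last are hypothetical helpers.

```python
from math import prod


def broadcast_shape(a, b):
    """NumPy/PyTorch-style broadcasting of two shapes (right-aligned)."""
    out = []
    for i in range(1, max(len(a), len(b)) + 1):
        da = a[-i] if i <= len(a) else 1
        db = b[-i] if i <= len(b) else 1
        if da != db and da != 1 and db != 1:
            raise ValueError(f"cannot broadcast {a} and {b}")
        out.append(db if da == 1 else da)
    return tuple(reversed(out))


def select_last(shape):
    """Result shape of torch.aten.select.int on the last dimension."""
    return shape[:-1]


q = (1, 24, 4608, 128)         # query/key after convert_element_type to f32
q_pairs = q[:-1] + (64, 1, 2)  # reshaped into pairs: [1, 24, 4608, 64, 1, 2]
freqs = (1, 1, 4608, 64, 2, 2)  # assumed from the dispatch binding

term = broadcast_shape(select_last(freqs), select_last(q_pairs))
print(term)                    # (1, 24, 4608, 64, 2)
assert prod(term) == prod(q)   # same element count, so it reshapes back to [1, 24, 4608, 128]
```

The broadcast product has as many elements as the [1, 24, 4608, 128] input, so it can be reshaped back and truncated to f16, which matches the extf, the two select.int slices and the truncf seen in the flow IR above.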

monorimet commented 3 months ago

@IanWood1 Not sure if the fix was ready to try, but I have an updated IR dump and reproducer here: https://gist.github.com/monorimet/3a0a4310c1ed09265353ce747599d502

On version:

commit 02c4a74931e403a414db5c0c8ac7085d0a91189e (HEAD, ianwood/extract_slice_prop)

tl;dr: it seems BubbleUpExtractThroughDequantize had no effect on the problematic code; the IR before and after it is identical.

How can I help here? I have cycles to offer.
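
Bubbling an extract_slice up through an elementwise cast is legal because slicing commutes with elementwise ops: extract_slice(extf(x)) equals extf(extract_slice(x)). A plain-Python sketch of that identity on a hypothetical N x 2 input (ints stand in for f16 values; this is not IREE's pattern implementation, and the helper names are illustrative):

```python
def extf(t):
    """Elementwise widening cast over an N x 2 nested list (stands in for arith.extf)."""
    return [[float(v) for v in row] for row in t]


def slice_last(t, k):
    """extract_slice of index k along the innermost dim, with that unit dim dropped."""
    return [row[k] for row in t]


# Hypothetical N x 2 input, like the trailing x2 pair dim in the issue.
x = [[1, 2], [3, 4], [5, 6]]

for k in range(2):
    cast_then_slice = slice_last(extf(x), k)                # shape of the problematic IR
    slice_then_cast = [float(v) for v in slice_last(x, k)]  # slice moved above the cast
    assert cast_then_slice == slice_then_cast
```

After the rewrite, each consumer gets its own cast of just the slice it needs, so no full-size f32 intermediate has to be shared between them.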

monorimet commented 3 months ago

In the meantime, I have found a workaround in the torch model code. Simply removing the cast to float in these lines, https://github.com/black-forest-labs/flux/blob/main/src/flux/math.py#L26-27, gives identical results in PyTorch and saves us from having to deal with this troublesome pattern. That at least makes this less blocking, but naturally, supporting cases like this would be better in the long run.

MaheshRavishankar commented 2 months ago

https://github.com/iree-org/iree/pull/18332 should fix this. I tried the small repro from here and I get

#map = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
#map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
#map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
module {
  util.func public @repro(%arg0: tensor<4608x1x24x64x1x2xf16>, %arg1: tensor<4608x64x2xf32>, %arg2: tensor<4608x64x2xf32>) -> tensor<24x4608x64x2xf16> {
    %0 = tensor.empty() : tensor<1x24x4608x64x1x2xf32>
    %extracted_slice = tensor.extract_slice %arg0[0, 0, 0, 0, 0, 0] [4608, 1, 24, 64, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<4608x1x24x64x1x2xf16> to tensor<4608x24x64xf16>
    %extracted_slice_0 = tensor.extract_slice %0[0, 0, 0, 0, 0, 0] [1, 24, 4608, 64, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x24x4608x64x1x2xf32> to tensor<24x4608x64xf32>
    %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%extracted_slice : tensor<4608x24x64xf16>) outs(%extracted_slice_0 : tensor<24x4608x64xf32>) {
    ^bb0(%in: f16, %out: f32):
      %5 = arith.extf %in : f16 to f32
      linalg.yield %5 : f32
    } -> tensor<24x4608x64xf32>
    %extracted_slice_1 = tensor.extract_slice %arg0[0, 0, 0, 0, 0, 1] [4608, 1, 24, 64, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<4608x1x24x64x1x2xf16> to tensor<4608x24x64xf16>
    %extracted_slice_2 = tensor.extract_slice %0[0, 0, 0, 0, 0, 1] [1, 24, 4608, 64, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x24x4608x64x1x2xf32> to tensor<24x4608x64xf32>
    %2 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%extracted_slice_1 : tensor<4608x24x64xf16>) outs(%extracted_slice_2 : tensor<24x4608x64xf32>) {
    ^bb0(%in: f16, %out: f32):
      %5 = arith.extf %in : f16 to f32
      linalg.yield %5 : f32
    } -> tensor<24x4608x64xf32>
    %3 = tensor.empty() : tensor<24x4608x64x2xf16>
    %4 = linalg.generic {indexing_maps = [#map2, #map3, #map2, #map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg1, %1, %arg2, %2 : tensor<4608x64x2xf32>, tensor<24x4608x64xf32>, tensor<4608x64x2xf32>, tensor<24x4608x64xf32>) outs(%3 : tensor<24x4608x64x2xf16>) {
    ^bb0(%in: f32, %in_3: f32, %in_4: f32, %in_5: f32, %out: f16):
      %5 = arith.mulf %in_4, %in_5 : f32
      %6 = arith.mulf %in, %in_3 : f32
      %7 = arith.addf %6, %5 : f32
      %8 = arith.truncf %7 : f32 to f16
      linalg.yield %8 : f16
    } -> tensor<24x4608x64x2xf16>
    util.return %4 : tensor<24x4608x64x2xf16>
  }
}

which seems right.
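To make the indexing maps of the final linalg.generic (%4) concrete: #map2 reads %arg1 and %arg2 at (d1, d2, d3), #map3 reads the widened halves %1 and %2 at (d0, d1, d2), and #map4 writes the result at (d0, d1, d2, d3). Because d3 is absent from #map3, each half is broadcast along the last dimension. The plain-Python loop nest below sketches those semantics on toy sizes (the real op is 24x4608x64x2). The function name and toy shapes are hypothetical, and the final arith.truncf is omitted, so results stay as Python floats.

```python
import itertools

def final_generic(arg1, x0, arg2, x1, dims):
    """Loop-nest reading of %4, with one loop per parallel iterator d0..d3.

    arg1, arg2: nested lists indexed [d1][d2][d3]      (#map2)
    x0, x1:     nested lists indexed [d0][d1][d2]      (#map3), i.e. %1 and %2
    result:     nested list indexed [d0][d1][d2][d3]   (#map4)
    """
    D0, D1, D2, D3 = dims
    out = [[[[0.0] * D3 for _ in range(D2)] for _ in range(D1)] for _ in range(D0)]
    for d0, d1, d2, d3 in itertools.product(range(D0), range(D1), range(D2), range(D3)):
        # Region body: %6 = %in * %in_3, %5 = %in_4 * %in_5, %7 = %6 + %5.
        # x0 and x1 ignore d3, which is the broadcast along the last dimension.
        out[d0][d1][d2][d3] = (arg1[d1][d2][d3] * x0[d0][d1][d2]
                               + arg2[d1][d2][d3] * x1[d0][d1][d2])
    return out
```

For example, choosing arg1 as 1 where d3 == 0 (0 elsewhere) and arg2 as the opposite makes out[..., 0] equal x0 and out[..., 1] equal x1. This is an easy way to see that the two sliced halves are recombined into the trailing pair dimension.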

MaheshRavishankar commented 2 months ago

I tried compiling the original repro with #18332. It now hits the same error as #18325.