monorimet closed this issue 2 months ago.
@nithinsubbiah or @andfau-amd, does one of you want to take this? cc @nirvedhmeshram for tracking.
The failure is coming from this op in the torch dialect IR:
%394:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%390, %393, %367, %float0.000000e00, %false_423, %none_424, %none_425) : (!torch.vtensor<[1,24,4608,128],bf16>, !torch.vtensor<[1,24,4608,128],bf16>, !torch.vtensor<[1,24,4608,128],bf16>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[1,24,4608,128],bf16>, !torch.vtensor<[1,24,4608],f32>)
which is doing bf16 -> bf16/f32. Do we have any lowering to MFMA/ROCDL that handles this dtype configuration? I'm not sure how it will translate through attention codegen.
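For reference, here is a minimal NumPy sketch of what this op computes. It assumes the standard scaled dot-product attention definition: default scale 1/sqrt(D), no mask and no dropout, matching the 0.0 / false / none arguments above. The sketch returns the attention output in the input dtype and a per-query-row log-sum-exp in f32, corresponding to the [1,24,4608,128] bf16 and [1,24,4608] f32 results. NumPy has no bfloat16, so float16 stands in for bf16. Computing internally in float32 is this sketch's own choice and does not describe how IREE or PyTorch actually accumulate.

```python
import numpy as np


def sdpa_reference(q, k, v, scale=None):
    """Scaled dot-product attention returning (output, logsumexp).

    q, k, v have shape [B, H, S, D]. Scores, softmax and the weighted sum are
    computed in float32; the output is cast back to q.dtype and the per-row
    logsumexp is kept in float32 with shape [B, H, S].
    """
    if scale is None:
        scale = 1.0 / float(np.sqrt(q.shape[-1]))
    q32, k32, v32 = (t.astype(np.float32) for t in (q, k, v))
    scores = np.einsum("bhqd,bhkd->bhqk", q32, k32) * np.float32(scale)
    # Subtract the row max before exponentiating for numerical stability.
    row_max = scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores - row_max)
    denom = weights.sum(axis=-1, keepdims=True)
    out = np.einsum("bhqk,bhkd->bhqd", weights / denom, v32)
    lse = (row_max + np.log(denom))[..., 0]
    return out.astype(q.dtype), lse.astype(np.float32)


rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((1, 2, 5, 4)).astype(np.float16) for _ in range(3))
out, lse = sdpa_reference(q, k, v)
print(out.dtype, out.shape, lse.dtype, lse.shape)
```

The mixed result types are the point here: whatever lowering handles this op has to produce a low-precision main output alongside an f32 side output.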
Is it right to expect that this may take a nontrivial amount of work? I want to contribute here, but I don't have a clear picture of the current state of AMDGPU/MFMA codegen and might be better off looking for workarounds at the modeling level in the meantime.
I managed to export the model with fp16 precision and ran into the same compile issue with shared memory:
<stdin>:1656:12: error: 'func.func' op uses 56623104 bytes of shared memory; exceeded the limit of 65536 bytes
%390 = torch.prims.convert_element_type %389, %int5_418 : !torch.vtensor<[1,24,4608,128],f32>, !torch.int -> !torch.vtensor<[1,24,4608,128],f16>
^
<stdin>:1656:12: error: failed to run translation of source executable to target executable for backend #hal.executable.target<"rocm", "rocm-hsaco-fb", {iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "", wgp = <compute = fp64|fp32|fp16|int64|int32|int16|int8, storage = b64|b32|b16|b8, subgroup = shuffle|arithmetic, dot = dp4xi8toi32, mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647]>>, ukernels = "none"}>
%390 = torch.prims.convert_element_type %389, %int5_418 : !torch.vtensor<[1,24,4608,128],f32>, !torch.int -> !torch.vtensor<[1,24,4608,128],f16>
^
FP16 flux sampler MLIR -- Azure
This is the elementwise broadcast dispatch (in fp16) on which the shared memory issue occurs:
hal.executable public @run_forward$async_dispatch_36 {
hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb", {iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "", wgp = <compute = fp64|fp32|fp16|int64|int32|int16|int8, storage = b64|b32|b16|b8, subgroup = shuffle|arithmetic, dot = dp4xi8toi32, mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647]>>, ukernels = "none"}>) {
hal.executable.export public @run_forward$async_dispatch_36_elementwise_broadcast_24x4608x64x2_f32xf32xf32xf32xf16 ordinal(0) layout(#hal.pipeline.layout<push_constants = 2, sets = [<0, bindings = [<0, storage_buffer, "ReadOnly|Indirect">, <1, storage_buffer, Indirect>], flags = Indirect>]>) attributes {hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>]} {
^bb0(%arg0: !hal.device):
%x, %y, %z = flow.dispatch.workgroup_count_from_slice
hal.return %x, %y, %z : index, index, index
}
builtin.module {
func.func @run_forward$async_dispatch_36_elementwise_broadcast_24x4608x64x2_f32xf32xf32xf32xf16() attributes {translation_info = #iree_codegen.translation_info<LLVMGPUVectorize workgroup_size = [2, 64, 1] subgroup_size = 64>} {
%c28391424 = arith.constant 28391424 : index
%0 = hal.interface.constant.load layout(<push_constants = 2, sets = [<0, bindings = [<0, storage_buffer, "ReadOnly|Indirect">, <1, storage_buffer, Indirect>], flags = Indirect>]>) ordinal(0) : i32
%1 = hal.interface.constant.load layout(<push_constants = 2, sets = [<0, bindings = [<0, storage_buffer, "ReadOnly|Indirect">, <1, storage_buffer, Indirect>], flags = Indirect>]>) ordinal(1) : i32
%2 = arith.index_castui %0 : i32 to index
%3 = arith.index_castui %1 : i32 to index
%4 = hal.interface.binding.subspan layout(<push_constants = 2, sets = [<0, bindings = [<0, storage_buffer, "ReadOnly|Indirect">, <1, storage_buffer, Indirect>], flags = Indirect>]>) set(0) binding(0) alignment(64) offset(%c28391424) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<1x1x4608x64x2x2xf32>>
%5 = hal.interface.binding.subspan layout(<push_constants = 2, sets = [<0, bindings = [<0, storage_buffer, "ReadOnly|Indirect">, <1, storage_buffer, Indirect>], flags = Indirect>]>) set(0) binding(0) alignment(64) offset(%2) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<4608x1x24x64x1x2xf16>>
%6 = hal.interface.binding.subspan layout(<push_constants = 2, sets = [<0, bindings = [<0, storage_buffer, "ReadOnly|Indirect">, <1, storage_buffer, Indirect>], flags = Indirect>]>) set(0) binding(1) alignment(64) offset(%3) flags(Indirect) : !flow.dispatch.tensor<writeonly:tensor<24x4608x64x2xf16>>
%7 = flow.dispatch.tensor.load %5, offsets = [0, 0, 0, 0, 0, 0], sizes = [4608, 1, 24, 64, 1, 2], strides = [1, 1, 1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<4608x1x24x64x1x2xf16>> -> tensor<4608x1x24x64x1x2xf16>
%8 = tensor.empty() : tensor<24x4608x64x2xf16>
%9 = tensor.empty() : tensor<1x24x4608x64x1x2xf32>
%10 = flow.dispatch.tensor.load %4, offsets = [0, 0, 0, 0, 0, 1], sizes = [1, 1, 4608, 64, 2, 1], strides = [1, 1, 1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<1x1x4608x64x2x2xf32>> -> tensor<4608x64x2xf32>
%11 = flow.dispatch.tensor.load %4, offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 4608, 64, 2, 1], strides = [1, 1, 1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<1x1x4608x64x2x2xf32>> -> tensor<4608x64x2xf32>
%12 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d0, d1, d3, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%7 : tensor<4608x1x24x64x1x2xf16>) outs(%9 : tensor<1x24x4608x64x1x2xf32>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 1, 64, 0]]>} {
^bb0(%in: f16, %out: f32):
%14 = arith.extf %in : f16 to f32
linalg.yield %14 : f32
} -> tensor<1x24x4608x64x1x2xf32>
%extracted_slice = tensor.extract_slice %12[0, 0, 0, 0, 0, 0] [1, 24, 4608, 64, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x24x4608x64x1x2xf32> to tensor<24x4608x64xf32>
%extracted_slice_0 = tensor.extract_slice %12[0, 0, 0, 0, 0, 1] [1, 24, 4608, 64, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x24x4608x64x1x2xf32> to tensor<24x4608x64xf32>
%13 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%11, %extracted_slice, %10, %extracted_slice_0 : tensor<4608x64x2xf32>, tensor<24x4608x64xf32>, tensor<4608x64x2xf32>, tensor<24x4608x64xf32>) outs(%8 : tensor<24x4608x64x2xf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 1, 64, 0]]>} {
^bb0(%in: f32, %in_1: f32, %in_2: f32, %in_3: f32, %out: f16):
%14 = arith.mulf %in_2, %in_3 : f32
%15 = arith.mulf %in, %in_1 : f32
%16 = arith.addf %15, %14 : f32
%17 = arith.truncf %16 : f32 to f16
linalg.yield %17 : f16
} -> tensor<24x4608x64x2xf16>
flow.dispatch.tensor.store %13, %6, offsets = [0, 0, 0, 0], sizes = [24, 4608, 64, 2], strides = [1, 1, 1, 1] : tensor<24x4608x64x2xf16> -> !flow.dispatch.tensor<writeonly:tensor<24x4608x64x2xf16>>
return
}
}
}
}
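To make the data movement concrete, here is a NumPy model of the two linalg.generic ops in this dispatch, with small sizes S, H and P standing in for 4608, 24 and 64. It is a sketch written from the indexing maps above, not IREE code. The first generic is a transpose plus extf. The two extract_slices pick element 0 and element 1 of the innermost size-2 dimension. The second generic combines them with the two slices of the 1x1x4608x64x2x2 operand.

```python
import numpy as np


def dispatch_reference(x_in, freqs):
    """NumPy model of the two linalg.generic ops in the dispatch above.

    x_in:  float16, shape (S, 1, H, P, 1, 2)  ~ tensor<4608x1x24x64x1x2xf16>
    freqs: float32, shape (1, 1, S, P, 2, 2)  ~ tensor<1x1x4608x64x2x2xf32>
    returns float16, shape (H, S, P, 2)       ~ tensor<24x4608x64x2xf16>
    """
    # First generic: input map (d2, d0, d1, d3, d4, d5) -> identity output, plus extf.
    x = np.transpose(x_in, (1, 2, 0, 3, 4, 5)).astype(np.float32)  # (1, H, S, P, 1, 2)
    # The two rank-reducing extract_slices along the innermost dimension.
    x0 = x[0, :, :, :, 0, 0]  # (H, S, P)
    x1 = x[0, :, :, :, 0, 1]
    # The two loads of freqs with offset 0 and offset 1 in the last dimension.
    f0 = freqs[0, 0, :, :, :, 0]  # (S, P, 2)
    f1 = freqs[0, 0, :, :, :, 1]
    # Second generic: out[h, s, p, r] = f0[s, p, r] * x0[h, s, p] + f1[s, p, r] * x1[h, s, p]
    out = f0[None] * x0[..., None] + f1[None] * x1[..., None]
    return out.astype(np.float16)  # truncf


rng = np.random.default_rng(0)
S, H, P = 3, 2, 4
x_in = rng.standard_normal((S, 1, H, P, 1, 2)).astype(np.float16)
freqs = rng.standard_normal((1, 1, S, P, 2, 2)).astype(np.float32)
out = dispatch_reference(x_in, freqs)
print(out.shape, out.dtype)
```

Written this way, each output pair is a 2x2 matrix applied to an (x0, x1) pair taken from the innermost dimension of the transposed input.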
Cc @IanWood1. I think he has seen and fixed a similar issue previously, though I don't know how he did it. But extracting along the innermost dimension this way is an antipattern IMO. Maybe it is something like a complex number being handled as two separate floats?
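If the complex-number guess is right, the dispatch body out[r] = f[r][0] * x0 + f[r][1] * x1 is a complex multiply spelled out with two floats, provided each 2x2 entry of the 1x1x4608x64x2x2 operand holds a rotation matrix [[cos t, -sin t], [sin t, cos t]]. That layout is an assumption and is not confirmed by the IR. It would be consistent with a rotary-embedding style update on q and k. The following Python check compares the two forms for one element. The helper names are hypothetical.

```python
import cmath
import math


def two_float_form(f, x0, x1):
    """Mirror of the dispatch body for one (head, position, pair) element.

    Returns (out[0], out[1]) with out[r] = f[r][0] * x0 + f[r][1] * x1.
    """
    return (f[0][0] * x0 + f[0][1] * x1, f[1][0] * x0 + f[1][1] * x1)


def rotation(theta):
    """Hypothetical 2x2 layout for one entry of the ...x2x2 operand."""
    c, s = math.cos(theta), math.sin(theta)
    return ((c, -s), (s, c))


theta, x0, x1 = 0.3, 1.5, -0.25
re, im = two_float_form(rotation(theta), x0, x1)
# The same update written as (x0 + i*x1) * e^(i*theta).
z = complex(x0, x1) * cmath.exp(1j * theta)
print(re, z.real)
print(im, z.imag)
```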
The shared memory issue does not occur, and the full model compiles successfully, if the q, k, v values are traced with flow.tensor_trace ops.
Here is a link to the "traced" .mlir that compiles successfully: flux_dev_bs1_1024x1024_fp16_sampler_traced.mlir
The last time tracing around an op fixed a compile issue was with SD3, where tracing prevented some fusion from being applied. We may have a similar issue here; I'll take a closer look.
In the traced module's dispatches, I located the same dispatch that originally failed:
hal.executable public @run_forward$async_dispatch_42 {
hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb", {iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "", wgp = <compute = fp64|fp32|fp16|int64|int32|int16|int8, storage = b64|b32|b16|b8, subgroup = shuffle|arithmetic, dot = dp4xi8toi32, mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647]>>, ukernels = "none"}>) {
hal.executable.export public @run_forward$async_dispatch_42_elementwise_broadcast_24x4608x64x2_f32xf32xf32xf32xf16 ordinal(0) layout(#hal.pipeline.layout<push_constants = 2, sets = [<0, bindings = [<0, storage_buffer, "ReadOnly|Indirect">, <1, storage_buffer, "ReadOnly|Indirect">, <2, storage_buffer, Indirect>], flags = Indirect>]>) attributes {hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>, #hal.interface.binding<0, 2>]} {
^bb0(%arg0: !hal.device loc("<stdin>":1773:12)):
%x, %y, %z = flow.dispatch.workgroup_count_from_slice loc("<stdin>":1773:12)
hal.return %x, %y, %z : index, index, index loc("<stdin>":1773:12)
} loc("<stdin>":1773:12)
builtin.module {
func.func @run_forward$async_dispatch_42_elementwise_broadcast_24x4608x64x2_f32xf32xf32xf32xf16() {
%c28311552 = arith.constant 28311552 : index loc(unknown)
%0 = hal.interface.constant.load layout(<push_constants = 2, sets = [<0, bindings = [<0, storage_buffer, "ReadOnly|Indirect">, <1, storage_buffer, "ReadOnly|Indirect">, <2, storage_buffer, Indirect>], flags = Indirect>]>) ordinal(0) : i32 loc("<stdin>":1723:12)
%1 = hal.interface.constant.load layout(<push_constants = 2, sets = [<0, bindings = [<0, storage_buffer, "ReadOnly|Indirect">, <1, storage_buffer, "ReadOnly|Indirect">, <2, storage_buffer, Indirect>], flags = Indirect>]>) ordinal(1) : i32 loc("<stdin>":1773:12)
%2 = arith.index_castui %0 : i32 to index loc("<stdin>":1723:12)
%3 = arith.index_castui %1 : i32 to index loc("<stdin>":1773:12)
%4 = hal.interface.binding.subspan layout(<push_constants = 2, sets = [<0, bindings = [<0, storage_buffer, "ReadOnly|Indirect">, <1, storage_buffer, "ReadOnly|Indirect">, <2, storage_buffer, Indirect>], flags = Indirect>]>) set(0) binding(0) alignment(64) offset(%2) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<1x24x4608x64x1x2xf32>> loc("<stdin>":1723:12)
%5 = hal.interface.binding.subspan layout(<push_constants = 2, sets = [<0, bindings = [<0, storage_buffer, "ReadOnly|Indirect">, <1, storage_buffer, "ReadOnly|Indirect">, <2, storage_buffer, Indirect>], flags = Indirect>]>) set(0) binding(1) alignment(64) offset(%c28311552) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<1x1x4608x64x2x2xf32>> loc("<stdin>":1221:12)
%6 = hal.interface.binding.subspan layout(<push_constants = 2, sets = [<0, bindings = [<0, storage_buffer, "ReadOnly|Indirect">, <1, storage_buffer, "ReadOnly|Indirect">, <2, storage_buffer, Indirect>], flags = Indirect>]>) set(0) binding(2) alignment(64) offset(%3) flags(Indirect) : !flow.dispatch.tensor<writeonly:tensor<24x4608x64x2xf16>> loc("<stdin>":1773:12)
%7 = tensor.empty() : tensor<24x4608x64x2xf16> loc("<stdin>":1773:12)
%8 = flow.dispatch.tensor.load %4, offsets = [0, 0, 0, 0, 0, 1], sizes = [1, 24, 4608, 64, 1, 1], strides = [1, 1, 1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<1x24x4608x64x1x2xf32>> -> tensor<24x4608x64xf32> loc("<stdin>":1746:12)
%9 = flow.dispatch.tensor.load %5, offsets = [0, 0, 0, 0, 0, 1], sizes = [1, 1, 4608, 64, 2, 1], strides = [1, 1, 1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<1x1x4608x64x2x2xf32>> -> tensor<4608x64x2xf32> loc("<stdin>":1743:12)
%10 = flow.dispatch.tensor.load %4, offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 24, 4608, 64, 1, 1], strides = [1, 1, 1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<1x24x4608x64x1x2xf32>> -> tensor<24x4608x64xf32> loc("<stdin>":1739:12)
%11 = flow.dispatch.tensor.load %5, offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 4608, 64, 2, 1], strides = [1, 1, 1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<1x1x4608x64x2x2xf32>> -> tensor<4608x64x2xf32> loc("<stdin>":1736:12)
%12 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%11, %10, %9, %8 : tensor<4608x64x2xf32>, tensor<24x4608x64xf32>, tensor<4608x64x2xf32>, tensor<24x4608x64xf32>) outs(%7 : tensor<24x4608x64x2xf16>) {
^bb0(%in: f32 loc("<stdin>":1736:12), %in_0: f32 loc("<stdin>":1739:12), %in_1: f32 loc("<stdin>":1743:12), %in_2: f32 loc("<stdin>":1746:12), %out: f16 loc("<stdin>":1773:12)):
%13 = arith.mulf %in_1, %in_2 : f32 loc("<stdin>":1747:12)
%14 = arith.mulf %in, %in_0 : f32 loc("<stdin>":1740:12)
%15 = arith.addf %14, %13 : f32 loc("<stdin>":1749:12)
%16 = arith.truncf %15 : f32 to f16 loc("<stdin>":1773:12)
linalg.yield %16 : f16 loc("<stdin>":1773:12)
} -> tensor<24x4608x64x2xf16> loc("<stdin>":1773:12)
flow.dispatch.tensor.store %12, %6, offsets = [0, 0, 0, 0], sizes = [24, 4608, 64, 2], strides = [1, 1, 1, 1] : tensor<24x4608x64x2xf16> -> !flow.dispatch.tensor<writeonly:tensor<24x4608x64x2xf16>> loc("<stdin>":1773:12)
return loc("<stdin>":1773:12)
} loc("<stdin>":1773:12)
} loc("<stdin>":1773:12)
} loc("<stdin>":1773:12)
} loc("<stdin>":1773:12)
For some reason, without the traced ops, this elementwise broadcast has a few things that stick out:
The arith.constant offset for the buffer load is different (28391424 without tracing vs. 28311552 with tracing), and 28311552 is exactly half of the 56623104 bytes of shared memory reported by the original error. I don't know what this means or if it's useful.
I managed to isolate a smaller reproducer with just the double-stream block from the flux implementation.
Here is the resulting flux_attn_repro_fp16.mlir that reproduces the failure: Azure
Compile command:
iree-compile --iree-input-type=torch --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=rocm --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-rocm-target-chip=gfx942 --iree-global-opt-propagate-transposes=true --iree-opt-const-eval=false --iree-llvmgpu-enable-prefetch=true --iree-execution-model=async-external --iree-preprocessing-pass-pipeline="builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-preprocessing-pad-to-intrinsics)" --iree-codegen-transform-dialect-library=attention_and_matmul_spec.mlir --iree-codegen-llvmgpu-enable-transform-dialect-jit=false flux_attn_repro_fp16.mlir -o out.vmfb
This command and MLIR compile successfully if the --iree-opt-outer-dim-concat=true flag is added.
I am therefore assuming that this flag has something to do with the discrepancy between the passing and failing cases.
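To A/B the flag without retyping the command, a small Python driver could look like the following. It assumes iree-compile is on PATH and that flux_attn_repro_fp16.mlir and attention_and_matmul_spec.mlir from this thread are in the working directory. The flags are the ones from the command above with the duplicated ones dropped. The function name build_cmd and the output file names are hypothetical.

```python
import shutil
import subprocess

# Flags from the compile command above, duplicates dropped, original order kept.
BASE_FLAGS = [
    "--iree-input-type=torch",
    "--iree-vm-bytecode-module-output-format=flatbuffer-binary",
    "--iree-hal-target-backends=rocm",
    "--mlir-print-debuginfo",
    "--mlir-print-op-on-diagnostic=false",
    "--iree-rocm-target-chip=gfx942",
    "--iree-global-opt-propagate-transposes=true",
    "--iree-opt-const-eval=false",
    "--iree-llvmgpu-enable-prefetch=true",
    "--iree-execution-model=async-external",
    "--iree-preprocessing-pass-pipeline=builtin.module("
    "iree-preprocessing-transpose-convolution-pipeline, iree-preprocessing-pad-to-intrinsics)",
    "--iree-codegen-transform-dialect-library=attention_and_matmul_spec.mlir",
    "--iree-codegen-llvmgpu-enable-transform-dialect-jit=false",
]


def build_cmd(src, out, outer_dim_concat):
    """Return the iree-compile argv, optionally with outer-dim concat enabled."""
    cmd = ["iree-compile", *BASE_FLAGS]
    if outer_dim_concat:
        cmd.append("--iree-opt-outer-dim-concat=true")
    return cmd + [src, "-o", out]


if __name__ == "__main__":
    if shutil.which("iree-compile") is None:
        print("iree-compile not found on PATH; skipping the A/B run")
    else:
        for flag in (False, True):
            cmd = build_cmd("flux_attn_repro_fp16.mlir", f"out_concat_{flag}.vmfb", flag)
            result = subprocess.run(cmd, capture_output=True, text=True)
            print(f"outer-dim-concat={flag}: exit code {result.returncode}")
```

Passing an argv list to subprocess.run also sidesteps shell quoting of the parentheses and the space in the preprocessing pipeline value.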
More notes:
Of the (q, k, v) inputs to sdpa here, only the query and key tensors need to be traced for compilation to succeed. These two, unlike the value tensor, need to be converted from f32 to f16.
I see where something seems to be blowing up after comparing the print-after-all output of the two smaller cases described in the previous comment. With outer dims concatenated, we get a more reasonable output from vectorization, which compiles successfully:
// -----// IR Dump After LLVMGPUVectorLoweringPass (iree-llvmgpu-vector-lowering) //----- //
func.func @main$async_dispatch_26_elementwise_broadcast_24x4608x64x2_f32xf32xf32xf32xf16() attributes {translation_info = #iree_codegen.translation_info<LLVMGPUVectorize workgroup_size = [2, 64, 1] subgroup_size = 64>} {
%cst = arith.constant dense<0.000000e+00> : vector<1xf32>
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%0 = hal.interface.constant.load layout(<push_constants = 2, sets = [<0, bindings = [<0, storage_buffer, "ReadOnly|Indirect">, <1, storage_buffer, "ReadOnly|Indirect">, <2, storage_buffer, Indirect>], flags = Indirect>]>) ordinal(0) : i32
%1 = hal.interface.constant.load layout(<push_constants = 2, sets = [<0, bindings = [<0, storage_buffer, "ReadOnly|Indirect">, <1, storage_buffer, "ReadOnly|Indirect">, <2, storage_buffer, Indirect>], flags = Indirect>]>) ordinal(1) : i32
%2 = arith.index_castui %0 : i32 to index
%3 = arith.index_castui %1 : i32 to index
%4 = hal.interface.binding.subspan layout(<push_constants = 2, sets = [<0, bindings = [<0, storage_buffer, "ReadOnly|Indirect">, <1, storage_buffer, "ReadOnly|Indirect">, <2, storage_buffer, Indirect>], flags = Indirect>]>) set(0) binding(0) alignment(64) offset(%2) flags("ReadOnly|Indirect") : memref<1x24x4608x64x1x2xf32, strided<[14155776, 589824, 128, 2, 2, 1], offset: ?>, #gpu.address_space<global>>
memref.assume_alignment %4, 1 : memref<1x24x4608x64x1x2xf32, strided<[14155776, 589824, 128, 2, 2, 1], offset: ?>, #gpu.address_space<global>>
%5 = hal.interface.binding.subspan layout(<push_constants = 2, sets = [<0, bindings = [<0, storage_buffer, "ReadOnly|Indirect">, <1, storage_buffer, "ReadOnly|Indirect">, <2, storage_buffer, Indirect>], flags = Indirect>]>) set(0) binding(1) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : memref<1x1x4608x64x2x2xf32, #gpu.address_space<global>>
memref.assume_alignment %5, 64 : memref<1x1x4608x64x2x2xf32, #gpu.address_space<global>>
%6 = hal.interface.binding.subspan layout(<push_constants = 2, sets = [<0, bindings = [<0, storage_buffer, "ReadOnly|Indirect">, <1, storage_buffer, "ReadOnly|Indirect">, <2, storage_buffer, Indirect>], flags = Indirect>]>) set(0) binding(2) alignment(64) offset(%3) flags(Indirect) : memref<24x4608x64x2xf16, strided<[589824, 128, 2, 1], offset: ?>, #gpu.address_space<global>>
memref.assume_alignment %6, 1 : memref<24x4608x64x2xf16, strided<[589824, 128, 2, 1], offset: ?>, #gpu.address_space<global>>
%workgroup_id_y = hal.interface.workgroup.id[1] : index
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%thread_id_x = gpu.thread_id x
%thread_id_y = gpu.thread_id y
%7 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (vector<1xf32>) {
%16 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%arg0)[%thread_id_x]
%17 = memref.load %5[%c0, %c0, %workgroup_id_x, %thread_id_y, %16, %c0] : memref<1x1x4608x64x2x2xf32, #gpu.address_space<global>>
%18 = vector.insertelement %17, %arg1[%arg0 : index] : vector<1xf32>
scf.yield %18 : vector<1xf32>
}
%8 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (vector<1xf32>) {
%16 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%arg0)[%thread_id_y]
%17 = memref.load %4[%c0, %workgroup_id_y, %workgroup_id_x, %16, %c0, %c0] : memref<1x24x4608x64x1x2xf32, strided<[14155776, 589824, 128, 2, 2, 1], offset: ?>, #gpu.address_space<global>>
%18 = vector.insertelement %17, %arg1[%arg0 : index] : vector<1xf32>
scf.yield %18 : vector<1xf32>
}
%9 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (vector<1xf32>) {
%16 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%arg0)[%thread_id_x]
%17 = memref.load %5[%c0, %c0, %workgroup_id_x, %thread_id_y, %16, %c1] : memref<1x1x4608x64x2x2xf32, #gpu.address_space<global>>
%18 = vector.insertelement %17, %arg1[%arg0 : index] : vector<1xf32>
scf.yield %18 : vector<1xf32>
}
%10 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (vector<1xf32>) {
%16 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%arg0)[%thread_id_y]
%17 = memref.load %4[%c0, %workgroup_id_y, %workgroup_id_x, %16, %c0, %c1] : memref<1x24x4608x64x1x2xf32, strided<[14155776, 589824, 128, 2, 2, 1], offset: ?>, #gpu.address_space<global>>
%18 = vector.insertelement %17, %arg1[%arg0 : index] : vector<1xf32>
scf.yield %18 : vector<1xf32>
}
%11 = arith.mulf %9, %10 : vector<1xf32>
%12 = arith.mulf %7, %8 : vector<1xf32>
%13 = arith.addf %12, %11 : vector<1xf32>
%14 = arith.truncf %13 : vector<1xf32> to vector<1xf16>
%15 = vector.extract %14[0] : f16 from vector<1xf16>
memref.store %15, %6[%workgroup_id_y, %workgroup_id_x, %thread_id_y, %thread_id_x] : memref<24x4608x64x2xf16, strided<[589824, 128, 2, 1], offset: ?>, #gpu.address_space<global>>
return
}
Without concatenating outer dims, we get the same dispatch result, but prepended with arith.constant index values counting down from 4608. Then, after setting thread ids as in the above IR, the failing case (without outer dim concat) starts doing what looks like a botched elementwise broadcast: a load/broadcast/insert pattern that repeats for some 350000+ lines, followed by a similarly long chain of vector.extracts:
%7 = memref.load %5[%c0, %c0, %c0, %thread_id_y, %c0, %thread_id_x] : memref<1x24x4608x64x1x2xf16, strided<[14155776, 589824, 128, 2, 2, 1], offset: ?>, #gpu.address_space<global>>
%8 = vector.broadcast %7 : f16 to vector<1xf16>
%9 = vector.insert %8, %cst_0 [0, 0, 0, 0] : vector<1xf16> into vector<24x4608x1x1x1xf16>
%10 = memref.load %5[%c0, %c0, %c1, %thread_id_y, %c0, %thread_id_x] : memref<1x24x4608x64x1x2xf16, strided<[14155776, 589824, 128, 2, 2, 1], offset: ?>, #gpu.address_space<global>>
%11 = vector.broadcast %10 : f16 to vector<1xf16>
%12 = vector.insert %11, %9 [0, 1, 0, 0] : vector<1xf16> into vector<24x4608x1x1x1xf16>
%13 = memref.load %5[%c0, %c0, %c2, %thread_id_y, %c0, %thread_id_x] : memref<1x24x4608x64x1x2xf16, strided<[14155776, 589824, 128, 2, 2, 1], offset: ?>, #gpu.address_space<global>>
%14 = vector.broadcast %13 : f16 to vector<1xf16>
%15 = vector.insert %14, %12 [0, 2, 0, 0] : vector<1xf16> into vector<24x4608x1x1x1xf16>
%16 = memref.load %5[%c0, %c0, %c3, %thread_id_y, %c0, %thread_id_x] : memref<1x24x4608x64x1x2xf16, strided<[14155776, 589824, 128, 2, 2, 1], offset: ?>, #gpu.address_space<global>>
%17 = vector.broadcast %16 : f16 to vector<1xf16>
%18 = vector.insert %17, %15 [0, 3, 0, 0] : vector<1xf16> into vector<24x4608x1x1x1xf16>
%19 = memref.load %5[%c0, %c0, %c4, %thread_id_y, %c0, %thread_id_x] : memref<1x24x4608x64x1x2xf16, strided<[14155776, 589824, 128, 2, 2, 1], offset: ?>, #gpu.address_space<global>>
%20 = vector.broadcast %19 : f16 to vector<1xf16>
%21 = vector.insert %20, %18 [0, 4, 0, 0] : vector<1xf16> into vector<24x4608x1x1x1xf16>
%22 = memref.load %5[%c0, %c0, %c5, %thread_id_y, %c0, %thread_id_x] : memref<1x24x4608x64x1x2xf16, strided<[14155776, 589824, 128, 2, 2, 1], offset: ?>, #gpu.address_space<global>>
%23 = vector.broadcast %22 : f16 to vector<1xf16>
%24 = vector.insert %23, %21 [0, 5, 0, 0] : vector<1xf16> into vector<24x4608x1x1x1xf16>
%25 = memref.load %5[%c0, %c0, %c6, %thread_id_y, %c0, %thread_id_x] : memref<1x24x4608x64x1x2xf16, strided<[14155776, 589824, 128, 2, 2, 1], offset: ?>, #gpu.address_space<global>>
%26 = vector.broadcast %25 : f16 to vector<1xf16>
%27 = vector.insert %26, %24 [0, 6, 0, 0] : vector<1xf16> into vector<24x4608x1x1x1xf16>
%28 = memref.load %5[%c0, %c0, %c7, %thread_id_y, %c0, %thread_id_x] : memref<1x24x4608x64x1x2xf16, strided<[14155776, 589824, 128, 2, 2, 1], offset: ?>, #gpu.address_space<global>>
%29 = vector.broadcast %28 : f16 to vector<1xf16>
%30 = vector.insert %29, %27 [0, 7, 0, 0] : vector<1xf16> into vector<24x4608x1x1x1xf16>
%31 = memref.load %5[%c0, %c0, %c8, %thread_id_y, %c0, %thread_id_x] : memref<1x24x4608x64x1x2xf16, strided<[14155776, 589824, 128, 2, 2, 1], offset: ?>, #gpu.address_space<global>>
%32 = vector.broadcast %31 : f16 to vector<1xf16>
%33 = vector.insert %32, %30 [0, 8, 0, 0] : vector<1xf16> into vector<24x4608x1x1x1xf16>
%34 = memref.load %5[%c0, %c0, %c9, %thread_id_y, %c0, %thread_id_x] : memref<1x24x4608x64x1x2xf16, strided<[14155776, 589824, 128, 2, 2, 1], offset: ?>, #gpu.address_space<global>>
%35 = vector.broadcast %34 : f16 to vector<1xf16>
%36 = vector.insert %35, %33 [0, 9, 0, 0] : vector<1xf16> into vector<24x4608x1x1x1xf16>
(continues on for quite a while ... )
%442372 = vector.extract %331783[23, 4604, 0, 0, 0] : f32 from vector<24x4608x1x1x1xf32>
memref.store %442372, %alloc[%c0, %c23, %c4604, %thread_id_y, %c0, %thread_id_x] : memref<1x24x4608x64x1x2xf32, #gpu.address_space<workgroup>>
%442373 = vector.extract %331783[23, 4605, 0, 0, 0] : f32 from vector<24x4608x1x1x1xf32>
memref.store %442373, %alloc[%c0, %c23, %c4605, %thread_id_y, %c0, %thread_id_x] : memref<1x24x4608x64x1x2xf32, #gpu.address_space<workgroup>>
%442374 = vector.extract %331783[23, 4606, 0, 0, 0] : f32 from vector<24x4608x1x1x1xf32>
memref.store %442374, %alloc[%c0, %c23, %c4606, %thread_id_y, %c0, %thread_id_x] : memref<1x24x4608x64x1x2xf32, #gpu.address_space<workgroup>>
%442375 = vector.extract %331783[23, 4607, 0, 0, 0] : f32 from vector<24x4608x1x1x1xf32>
memref.store %442375, %alloc[%c0, %c23, %c4607, %thread_id_y, %c0, %thread_id_x] : memref<1x24x4608x64x1x2xf32, #gpu.address_space<workgroup>>
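The length of that unrolled code follows from the vector shape. Each thread transfers a vector<24x4608x1x1x1xf16>, which is 110592 elements. The excerpt shows one load/broadcast/insert triple per element on the way in and one extract/store pair per element on the way out. The Python check below confirms that these per-element counts account for the SSA numbers shown: %331783 as the f32 vector, and %442372 to %442375 as the last extracts. The per-element op counts are read off the excerpt rather than taken from IREE's vector unrolling code.

```python
# Per-thread vector shape in the failing transfer_read: vector<24x4608x1x1x1xf16>.
H, S = 24, 4608
ELEMS = H * S

# Ops per element, read off the excerpt above (not from IREE's source):
INSERT_OPS_PER_ELEM = 3   # memref.load + vector.broadcast + vector.insert
EXTRACT_OPS_PER_ELEM = 2  # vector.extract + memref.store

FIRST_CHAIN_SSA = 7  # the load/broadcast/insert chain starts at %7
# Each of the three chain ops defines one SSA value; the next value is the
# f32 vector that every vector.extract reads from.
F32_VECTOR_SSA = FIRST_CHAIN_SSA + INSERT_OPS_PER_ELEM * ELEMS


def extract_ssa(h, s):
    """SSA number of the vector.extract for element [h, s, 0, 0, 0].

    memref.store defines no value, so the extracts are numbered consecutively.
    """
    return F32_VECTOR_SSA + 1 + h * S + s


print(ELEMS, F32_VECTOR_SSA, extract_ssa(H - 1, S - 1))
print("ops in both chains:", (INSERT_OPS_PER_ELEM + EXTRACT_OPS_PER_ELEM) * ELEMS)
```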
I don't really know what to make of this yet, but here is the same dispatch before vectorization in both cases.
Failing case:
// -----// IR Dump After FoldTensorExtractOpPass (iree-codegen-fold-tensor-extract-op) //----- //
func.func @main$async_dispatch_21_elementwise_broadcast_24x4608x64x2_f32xf32xf32xf32xf16() attributes {translation_info = #iree_codegen.translation_info<LLVMGPUVectorize workgroup_size = [2, 64, 1] subgroup_size = 64>} {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
%cst_0 = arith.constant 0.000000e+00 : f16
%0 = hal.interface.constant.load layout(<push_constants = 2, sets = [<0, bindings = [<0, storage_buffer, "ReadOnly|Indirect">, <1, storage_buffer, "ReadOnly|Indirect">, <2, storage_buffer, Indirect>], flags = Indirect>]>) ordinal(0) : i32
%1 = hal.interface.constant.load layout(<push_constants = 2, sets = [<0, bindings = [<0, storage_buffer, "ReadOnly|Indirect">, <1, storage_buffer, "ReadOnly|Indirect">, <2, storage_buffer, Indirect>], flags = Indirect>]>) ordinal(1) : i32
%2 = arith.index_castui %0 : i32 to index
%3 = arith.index_castui %1 : i32 to index
%4 = hal.interface.binding.subspan layout(<push_constants = 2, sets = [<0, bindings = [<0, storage_buffer, "ReadOnly|Indirect">, <1, storage_buffer, "ReadOnly|Indirect">, <2, storage_buffer, Indirect>], flags = Indirect>]>) set(0) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : memref<1x1x4608x64x2x2xf32, #gpu.address_space<global>>
memref.assume_alignment %4, 64 : memref<1x1x4608x64x2x2xf32, #gpu.address_space<global>>
%5 = hal.interface.binding.subspan layout(<push_constants = 2, sets = [<0, bindings = [<0, storage_buffer, "ReadOnly|Indirect">, <1, storage_buffer, "ReadOnly|Indirect">, <2, storage_buffer, Indirect>], flags = Indirect>]>) set(0) binding(1) alignment(64) offset(%2) flags("ReadOnly|Indirect") : memref<1x24x4608x64x1x2xf16, strided<[14155776, 589824, 128, 2, 2, 1], offset: ?>, #gpu.address_space<global>>
memref.assume_alignment %5, 1 : memref<1x24x4608x64x1x2xf16, strided<[14155776, 589824, 128, 2, 2, 1], offset: ?>, #gpu.address_space<global>>
%6 = hal.interface.binding.subspan layout(<push_constants = 2, sets = [<0, bindings = [<0, storage_buffer, "ReadOnly|Indirect">, <1, storage_buffer, "ReadOnly|Indirect">, <2, storage_buffer, Indirect>], flags = Indirect>]>) set(0) binding(2) alignment(64) offset(%3) flags(Indirect) : memref<24x4608x64x2xf16, strided<[589824, 128, 2, 1], offset: ?>, #gpu.address_space<global>>
memref.assume_alignment %6, 1 : memref<24x4608x64x2xf16, strided<[589824, 128, 2, 1], offset: ?>, #gpu.address_space<global>>
%workgroup_id_y = hal.interface.workgroup.id[1] : index
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%alloc = memref.alloc() : memref<1x24x4608x64x1x2xf32, #gpu.address_space<workgroup>>
%thread_id_x = gpu.thread_id x
%thread_id_y = gpu.thread_id y
%7 = vector.transfer_read %5[%c0, %c0, %c0, %thread_id_y, %c0, %thread_id_x], %cst_0 {in_bounds = [true, true, true, true, true]} : memref<1x24x4608x64x1x2xf16, strided<[14155776, 589824, 128, 2, 2, 1], offset: ?>, #gpu.address_space<global>>, vector<24x4608x1x1x1xf16>
%8 = arith.extf %7 : vector<24x4608x1x1x1xf16> to vector<24x4608x1x1x1xf32>
vector.transfer_write %8, %alloc[%c0, %c0, %c0, %thread_id_y, %c0, %thread_id_x] {in_bounds = [true, true, true, true, true]} : vector<24x4608x1x1x1xf32>, memref<1x24x4608x64x1x2xf32, #gpu.address_space<workgroup>>
gpu.barrier
%9 = vector.transfer_read %4[%c0, %c0, %workgroup_id_x, %thread_id_y, %thread_id_x, %c0], %cst {in_bounds = [true], permutation_map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4)>} : memref<1x1x4608x64x2x2xf32, #gpu.address_space<global>>, vector<1xf32>
%10 = vector.transfer_read %alloc[%c0, %workgroup_id_y, %workgroup_id_x, %thread_id_y, %c0, %c0], %cst {in_bounds = [true], permutation_map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3)>} : memref<1x24x4608x64x1x2xf32, #gpu.address_space<workgroup>>, vector<1xf32>
%11 = vector.transfer_read %4[%c0, %c0, %workgroup_id_x, %thread_id_y, %thread_id_x, %c1], %cst {in_bounds = [true], permutation_map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4)>} : memref<1x1x4608x64x2x2xf32, #gpu.address_space<global>>, vector<1xf32>
%12 = vector.transfer_read %alloc[%c0, %workgroup_id_y, %workgroup_id_x, %thread_id_y, %c0, %c1], %cst {in_bounds = [true], permutation_map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3)>} : memref<1x24x4608x64x1x2xf32, #gpu.address_space<workgroup>>, vector<1xf32>
%13 = arith.mulf %11, %12 : vector<1xf32>
%14 = arith.mulf %9, %10 : vector<1xf32>
%15 = arith.addf %14, %13 : vector<1xf32>
%16 = arith.truncf %15 : vector<1xf32> to vector<1xf16>
vector.transfer_write %16, %6[%workgroup_id_y, %workgroup_id_x, %thread_id_y, %thread_id_x] {in_bounds = [true]} : vector<1xf16>, memref<24x4608x64x2xf16, strided<[589824, 128, 2, 1], offset: ?>, #gpu.address_space<global>>
return
}
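The failing dump accounts for the error message directly: it allocates the entire extf result, memref<1x24x4608x64x1x2xf32>, in workgroup memory before the barrier. This quick Python check confirms that this single allocation is exactly the 56623104 bytes reported, which is 864 times the 65536-byte max_workgroup_memory_bytes of the gfx942 target. It uses 4 bytes per f32 element.

```python
import math

# max_workgroup_memory_bytes from the gfx942 target attribute in the error.
MAX_WORKGROUP_MEMORY_BYTES = 65536


def memref_bytes(shape, elem_bytes):
    """Static size in bytes of a memref with the given shape and element width."""
    return math.prod(shape) * elem_bytes


# %alloc = memref.alloc() : memref<1x24x4608x64x1x2xf32, #gpu.address_space<workgroup>>
alloc_bytes = memref_bytes((1, 24, 4608, 64, 1, 2), 4)
print(alloc_bytes, alloc_bytes / MAX_WORKGROUP_MEMORY_BYTES)
```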
Passing case:
// -----// IR Dump After FoldTensorExtractOpPass (iree-codegen-fold-tensor-extract-op) //----- //
func.func @main$async_dispatch_26_elementwise_broadcast_24x4608x64x2_f32xf32xf32xf32xf16() attributes {translation_info = #iree_codegen.translation_info<LLVMGPUVectorize workgroup_size = [2, 64, 1] subgroup_size = 64>} {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
%0 = hal.interface.constant.load layout(<push_constants = 2, sets = [<0, bindings = [<0, storage_buffer, "ReadOnly|Indirect">, <1, storage_buffer, "ReadOnly|Indirect">, <2, storage_buffer, Indirect>], flags = Indirect>]>) ordinal(0) : i32
%1 = hal.interface.constant.load layout(<push_constants = 2, sets = [<0, bindings = [<0, storage_buffer, "ReadOnly|Indirect">, <1, storage_buffer, "ReadOnly|Indirect">, <2, storage_buffer, Indirect>], flags = Indirect>]>) ordinal(1) : i32
%2 = arith.index_castui %0 : i32 to index
%3 = arith.index_castui %1 : i32 to index
%4 = hal.interface.binding.subspan layout(<push_constants = 2, sets = [<0, bindings = [<0, storage_buffer, "ReadOnly|Indirect">, <1, storage_buffer, "ReadOnly|Indirect">, <2, storage_buffer, Indirect>], flags = Indirect>]>) set(0) binding(0) alignment(64) offset(%2) flags("ReadOnly|Indirect") : memref<1x24x4608x64x1x2xf32, strided<[14155776, 589824, 128, 2, 2, 1], offset: ?>, #gpu.address_space<global>>
memref.assume_alignment %4, 1 : memref<1x24x4608x64x1x2xf32, strided<[14155776, 589824, 128, 2, 2, 1], offset: ?>, #gpu.address_space<global>>
%5 = hal.interface.binding.subspan layout(<push_constants = 2, sets = [<0, bindings = [<0, storage_buffer, "ReadOnly|Indirect">, <1, storage_buffer, "ReadOnly|Indirect">, <2, storage_buffer, Indirect>], flags = Indirect>]>) set(0) binding(1) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : memref<1x1x4608x64x2x2xf32, #gpu.address_space<global>>
memref.assume_alignment %5, 64 : memref<1x1x4608x64x2x2xf32, #gpu.address_space<global>>
%6 = hal.interface.binding.subspan layout(<push_constants = 2, sets = [<0, bindings = [<0, storage_buffer, "ReadOnly|Indirect">, <1, storage_buffer, "ReadOnly|Indirect">, <2, storage_buffer, Indirect>], flags = Indirect>]>) set(0) binding(2) alignment(64) offset(%3) flags(Indirect) : memref<24x4608x64x2xf16, strided<[589824, 128, 2, 1], offset: ?>, #gpu.address_space<global>>
memref.assume_alignment %6, 1 : memref<24x4608x64x2xf16, strided<[589824, 128, 2, 1], offset: ?>, #gpu.address_space<global>>
%workgroup_id_y = hal.interface.workgroup.id[1] : index
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%thread_id_x = gpu.thread_id x
%thread_id_y = gpu.thread_id y
%7 = vector.transfer_read %5[%c0, %c0, %workgroup_id_x, %thread_id_y, %thread_id_x, %c0], %cst {in_bounds = [true], permutation_map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4)>} : memref<1x1x4608x64x2x2xf32, #gpu.address_space<global>>, vector<1xf32>
%8 = vector.transfer_read %4[%c0, %workgroup_id_y, %workgroup_id_x, %thread_id_y, %c0, %c0], %cst {in_bounds = [true], permutation_map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3)>} : memref<1x24x4608x64x1x2xf32, strided<[14155776, 589824, 128, 2, 2, 1], offset: ?>, #gpu.address_space<global>>, vector<1xf32>
%9 = vector.transfer_read %5[%c0, %c0, %workgroup_id_x, %thread_id_y, %thread_id_x, %c1], %cst {in_bounds = [true], permutation_map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4)>} : memref<1x1x4608x64x2x2xf32, #gpu.address_space<global>>, vector<1xf32>
%10 = vector.transfer_read %4[%c0, %workgroup_id_y, %workgroup_id_x, %thread_id_y, %c0, %c1], %cst {in_bounds = [true], permutation_map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3)>} : memref<1x24x4608x64x1x2xf32, strided<[14155776, 589824, 128, 2, 2, 1], offset: ?>, #gpu.address_space<global>>, vector<1xf32>
%11 = arith.mulf %9, %10 : vector<1xf32>
%12 = arith.mulf %7, %8 : vector<1xf32>
%13 = arith.addf %12, %11 : vector<1xf32>
%14 = arith.truncf %13 : vector<1xf32> to vector<1xf16>
vector.transfer_write %14, %6[%workgroup_id_y, %workgroup_id_x, %thread_id_y, %thread_id_x] {in_bounds = [true]} : vector<1xf16>, memref<24x4608x64x2xf16, strided<[589824, 128, 2, 1], offset: ?>, #gpu.address_space<global>>
return
}
My best guess is that, in the failing case, the element type conversion on this buffer:
memref.assume_alignment %5, 1 : memref<1x24x4608x64x1x2xf16, strided<[14155776, 589824, 128, 2, 2, 1], offset: ?>, #gpu.address_space<global>>
is being awkwardly fused with the elementwise broadcast (?) which is causing it to take a bad path during vectorization. In the passing case, the type conversion does not happen in the same dispatch. There's some weird reordering happening to the buffers as well, which was observed earlier in the issue thread.
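As a rough cross-check of that guess: the 56623104-byte figure in the shared-memory error reported earlier is exactly the size of one dense f32 buffer with the shape of the failing dispatch's workgroup %alloc (memref<1x24x4608x64x1x2xf32, #gpu.address_space<workgroup>>). The sketch below only does that arithmetic. It assumes the reported usage comes entirely from that single allocation, which is an inference from the numbers and not something the compiler output states.

```python
# Workgroup-memory footprint of the failing dispatch's %alloc (sketch).
# Assumption: the reported shared-memory usage comes entirely from the single
# memref<1x24x4608x64x1x2xf32, #gpu.address_space<workgroup>> buffer.
from math import prod

SHAPE = (1, 24, 4608, 64, 1, 2)  # dims of the workgroup %alloc
F32_BYTES = 4                    # element type is f32
LIMIT_BYTES = 65536              # max_workgroup_memory_bytes in the gfx942 target attrs


def alloc_bytes(shape, elem_bytes):
    """Bytes needed to hold a dense buffer of the given shape."""
    return prod(shape) * elem_bytes


if __name__ == "__main__":
    used = alloc_bytes(SHAPE, F32_BYTES)
    print(f"{used} bytes requested, limit {LIMIT_BYTES} ({used // LIMIT_BYTES}x over)")
```

In other words, the whole f32-extended tensor appears to be staged in workgroup memory, roughly 864 times the per-workgroup limit, rather than a per-workgroup tile of it.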
Hi @monorimet, just want to confirm there is only one known issue here at the moment regarding shared memory exhaustion. If there are multiple issues, we should file separate issues so that a developer can be assigned to each one.
To narrow down the failing IR in your last message, it would help to see a full IR dump (after all passes) for the main$async_dispatch_21_elementwise_broadcast_24x4608x64x2_f32xf32xf32xf32xf16 dispatch.
To get this, compile the model with the flag --iree-hal-dump-executable-sources-to=$SOME_PATH; $SOME_PATH will then contain main$async_dispatch_21_elementwise_broadcast_24x4608x64x2_f32xf32xf32xf32xf16.mlir.
Compile that file again with --mlir-print-ir-after-all and redirect the output to a file, e.g. &> output_dump.txt.
Then upload that dump somewhere and we can take a look at it.
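The two-step workflow above can be scripted. This is a hypothetical Python helper, not an IREE tool: the function names and the -o outputs are illustrative choices, and only --iree-hal-dump-executable-sources-to and --mlir-print-ir-after-all come from the instructions in this thread. Target flags (for example the rocm/gfx942 flags used for the full model) are passed through unchanged to both steps, and the dumped dispatch file is taken as an argument because its exact name depends on the module being compiled.

```python
# Hypothetical helper (not part of IREE) that builds the two iree-compile
# invocations described above. Only the two flags quoted in the thread are
# IREE-specific; the rest is illustrative.
import subprocess


def dump_sources_cmd(model, dump_dir, output, extra_flags=()):
    """Step 1: compile the full model and dump each executable's source MLIR."""
    return ["iree-compile", str(model), *extra_flags,
            f"--iree-hal-dump-executable-sources-to={dump_dir}",
            "-o", str(output)]


def print_ir_cmd(dispatch_src, output, extra_flags=()):
    """Step 2: recompile one dumped dispatch, printing the IR after every pass."""
    return ["iree-compile", str(dispatch_src), *extra_flags,
            "--mlir-print-ir-after-all",
            "-o", str(output)]


def run_logged(cmd, log_path):
    """Run cmd with stdout and stderr merged into one file (like `&> log`)."""
    with open(log_path, "w") as log:
        return subprocess.run(cmd, stdout=log, stderr=subprocess.STDOUT).returncode
```

For example, run_logged(print_ir_cmd(dumped_dispatch_path, "dispatch.vmfb", target_flags), "output_dump.txt") would produce the per-pass dump to upload. The pass-by-pass IR is printed on stderr, which is why the log merges both streams.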
@nirvedhmeshram there is only one issue tracked here. I am going through all this because we have a success case and a failure case: the failure case is the shared memory error, and in the success case full model compilation succeeds.
The success case is: use --iree-opt-outer-dim-concat=true and inject trace tensors for the attention op's query and key inputs.
We cannot use the succeeding case because trace tensors aren't something you want to use outside of debugging.
So I've been trying to find out why we only compile successfully under the above conditions. Since dispatch_21 differs depending on whether --iree-opt-outer-dim-concat=true is used, I have been comparing two IR dumps for the model excerpt. Looking at the dump for the individual dispatch doesn't seem to help, since dispatch formation plays a part in whether we run into the shared memory issue. Please take a look at the following IR dumps (they are large):
https://sharkpublic.blob.core.windows.net/sharkpublic/ean/flux-dev/doublestream_dump_fail.txt (failure: trace tensors, no outer dim concat flag. Beyond what the other two dumps show, this one is mainly interesting because it shows that, with trace tensors present, the outer dim concat flag is what toggles between failure and success.)
https://sharkpublic.blob.core.windows.net/sharkpublic/ean/flux-dev/doublestream_dump_fail_2.txt (failure: no trace tensors, with outer dim concat flag. This is the configuration we want to fix.)
https://sharkpublic.blob.core.windows.net/sharkpublic/ean/flux-dev/doublestream_dump.txt (success: trace tensors and outer dim concat flag.)
Looking at the dump for the individual dispatch doesn't seem to help since dispatch formation plays a part in whether we run into the shared memory issue. Please take a look at the following IR dumps (they are large):
You can still isolate the failing dispatch from the failing case and produce a smaller dump as I described in my previous message.
But the dumps you shared are a good start too. Will take a look tomorrow.
Gotcha. Without trace tensors, the failing dispatch index moves to 18 -- so here's that IR: https://sharkpublic.blob.core.windows.net/sharkpublic/ean/flux-dev/compiled_attn_main_async_dispatch_18.mlir
My compile command:
iree-compile --iree-input-type=torch --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=rocm --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-hal-target-backends=rocm --iree-rocm-target-chip=gfx942 --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-global-opt-propagate-transposes=true --iree-opt-const-eval=false --iree-llvmgpu-enable-prefetch=true --iree-execution-model=async-external --iree-preprocessing-pass-pipeline='builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-preprocessing-pad-to-intrinsics)' --iree-codegen-transform-dialect-library=attention_and_matmul_spec.mlir --iree-codegen-llvmgpu-enable-transform-dialect-jit=false --iree-opt-outer-dim-concat=true compiled_attn_main_async_dispatch_18.mlir --mlir-print-ir-after-all &> dispatch_18_ir_dump_fail.txt
and here's the IR dump:
https://sharkpublic.blob.core.windows.net/sharkpublic/ean/flux-dev/dispatch_18_ir_dump_fail.txt
Given that (basically) the same IR is generated in the two failing configurations (no trace tensors + outer dim concat, and trace tensors + no outer dim concat), I figure that the outer dim concatenation is somehow not working correctly for this IR without trace tensors. The IR dumps for the passing and failing cases start to diverge at iree-global-opt-decompose-concat.
This may not be the most methodical way to find the root cause or suggest solutions. Perhaps the failing dispatch IR is OK and vectorization needs to be fixed, but it seems to me that the failing dispatch isn't ideal code to begin with, and it might be best to fix whatever is causing the "bad" dispatch to be formed. The odd behavior where adding trace tensors causes --iree-opt-outer-dim-concat to actually do something seems symptomatic of an issue upstream of dispatch formation, but I don't know why adding trace tensors would make a difference, beyond naively suspecting that they prevent some bad fusion.
We were talking about this during codegen sync and I think Mahesh will add more info. But this seems related to what https://github.com/llvm/torch-mlir/pull/3483 is trying to address.
@monorimet adding tracing changes the way fusion works (at least today; maybe we can make that not be the case). Sorry I wasn't online over the weekend; it looks like you tried a bunch of different things. Reading through, the most important point is that this
%12 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d0, d1, d3, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%7 : tensor<4608x1x24x64x1x2xf16>) outs(%9 : tensor<1x24x4608x64x1x2xf32>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 1, 64, 0]]>} {
^bb0(%in: f16, %out: f32):
%14 = arith.extf %in : f16 to f32
linalg.yield %14 : f32
} -> tensor<1x24x4608x64x1x2xf32>
%extracted_slice = tensor.extract_slice %12[0, 0, 0, 0, 0, 0] [1, 24, 4608, 64, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x24x4608x64x1x2xf32> to tensor<24x4608x64xf32>
%extracted_slice_0 = tensor.extract_slice %12[0, 0, 0, 0, 0, 1] [1, 24, 4608, 64, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x24x4608x64x1x2xf32> to tensor<24x4608x64xf32>
%13 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%11, %extracted_slice, %10, %extracted_slice_0 : tensor<4608x64x2xf32>, tensor<24x4608x64xf32>, tensor<4608x64x2xf32>, tensor<24x4608x64xf32>) outs(%8 : tensor<24x4608x64x2xf16>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 1, 64, 0]]>} {
^bb0(%in: f32, %in_1: f32, %in_2: f32, %in_3: f32, %out: f16):
%14 = arith.mulf %in_2, %in_3 : f32
%15 = arith.mulf %in, %in_1 : f32
%16 = arith.addf %15, %14 : f32
%17 = arith.truncf %16 : f32 to f16
linalg.yield %17 : f16
} -> tensor<24x4608x64x2xf16>
kind of IR (not even thinking of dispatches, just the lowered form of the IR) is complicated for the backend to generate code for. (Side note: we should fix dispatch region formation to not create such dispatches.) Instead of the above, we should lower to
%12 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d0, d1, d3, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%7 : tensor<4608x1x24x64x1x2xf16>) outs(%9 : tensor<1x24x4608x64x1x2xf32>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 1, 64, 0]]>} {
^bb0(%in: f16, %out: f32):
%14 = arith.extf %in : f16 to f32
linalg.yield %14 : f32
} -> tensor<1x24x4608x64x1x2xf32>
%13 = linalg.generic
ins(%11, %12, %10 : tensor<4608x64x2xf32>, tensor<1x24x4608x64x1x2xf32>, tensor<4608x64x2xf32>) outs(%8 : tensor<24x4608x64x2xf16>) {...}
Slicing the innermost dimension into separate tensors that feed into another operation is a bad lowering sequence. The --iree-opt-outer-dim-concat flag is a bit of a red herring, because it is meant for cases where concatenation happens on the innermost dim; that is not what is happening here (there are some similar artifacts, but it's not a good match).
Sorry you tried a whole bunch of different things and ran into various issues, but the easiest fix for me is to fix the lowering to Linalg. I think recovering from the lowering in the compiler is much harder.
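The semantic half of this point can be checked at the index level: the consumer generic quoted above only reads the innermost size-2 dimension at the constant indices 0 and 1, so it can read the full extended tensor directly instead of two extracted slices. The sketch models the ops as Python dicts over tiny stand-in dimensions (D0..D3, a, b and x are illustrative names; a and b play the roles of %11 and %10). It says nothing about code generation cost, only that the slices are not needed to express the computation.

```python
# Index-level model of the consumer linalg.generic (sketch with tiny dims).
from itertools import product

D0, D1, D2, D3 = 2, 3, 2, 2  # stand-ins for 24, 4608, 64, 2


def make(shape, seed):
    """Deterministic dense tensor stored as a dict keyed by index tuples."""
    return {idx: float((sum(i * (k + 2) for k, i in enumerate(idx)) + seed) % 7)
            for idx in product(*map(range, shape))}


x = make((D0, D1, D2, 2), seed=1)  # extf result, unit dims dropped
a = make((D1, D2, D3), seed=2)     # plays the role of %11 (4608x64x2)
b = make((D1, D2, D3), seed=3)     # plays the role of %10 (4608x64x2)
OUT = list(product(range(D0), range(D1), range(D2), range(D3)))


def with_slices():
    # tensor.extract_slice at offsets 0 and 1 of the innermost dim
    s0 = {k[:3]: v for k, v in x.items() if k[3] == 0}
    s1 = {k[:3]: v for k, v in x.items() if k[3] == 1}
    # maps: a, b -> (d1, d2, d3); slices -> (d0, d1, d2); out -> identity
    return {(d0, d1, d2, d3): a[(d1, d2, d3)] * s0[(d0, d1, d2)]
            + b[(d1, d2, d3)] * s1[(d0, d1, d2)]
            for d0, d1, d2, d3 in OUT}


def without_slices():
    # Same computation, reading the full tensor at constant innermost indices.
    return {(d0, d1, d2, d3): a[(d1, d2, d3)] * x[(d0, d1, d2, 0)]
            + b[(d1, d2, d3)] * x[(d0, d1, d2, 1)]
            for d0, d1, d2, d3 in OUT}
```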
That clarifies it quite a bit, thanks. My efforts were just me flapping around trying to find something useful ahead of someone else picking this up -- call it an educational exercise. I will take a shot at hand-editing the linalg IR while the torchtolinalg fix lands.
If you want I can write what the lowering should look like (or if you take a stab and post it here, I can verify)
No, that's okay -- I'll give it a shot
@MaheshRavishankar I'm a little confused about how the linalg lowering leads to this bad dispatch -- PTAL:
This is the problematic bit of flow IR that is generated (you've seen this already, I just attached debug info to see corresponding source lines):
%5 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d0, d1, d3, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%0 : tensor<4608x1x24x64x1x2xf16>) outs(%2 : tensor<1x24x4608x64x1x2xf32>) {
^bb0(%in: f16 loc("/home/eagarvey/SHARK-Turbine/models/turbine_models/custom_models/flux/flux_sampler_attn_repro_fp16.mlir":395:12), %out: f32 loc("/home/eagarvey/SHARK-Turbine/models/turbine_models/custom_models/flux/flux_sampler_attn_repro_fp16.mlir":403:12)):
%7 = arith.extf %in : f16 to f32 loc("/home/eagarvey/SHARK-Turbine/models/turbine_models/custom_models/flux/flux_sampler_attn_repro_fp16.mlir":403:12)
linalg.yield %7 : f32 loc("/home/eagarvey/SHARK-Turbine/models/turbine_models/custom_models/flux/flux_sampler_attn_repro_fp16.mlir":403:12)
} -> tensor<1x24x4608x64x1x2xf32> loc("/home/eagarvey/SHARK-Turbine/models/turbine_models/custom_models/flux/flux_sampler_attn_repro_fp16.mlir":403:12)
%extracted_slice = tensor.extract_slice %5[0, 0, 0, 0, 0, 0] [1, 24, 4608, 64, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x24x4608x64x1x2xf32> to tensor<24x4608x64xf32> loc("/home/eagarvey/SHARK-Turbine/models/turbine_models/custom_models/flux/flux_sampler_attn_repro_fp16.mlir":427:12)
%extracted_slice_0 = tensor.extract_slice %5[0, 0, 0, 0, 0, 1] [1, 24, 4608, 64, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x24x4608x64x1x2xf32> to tensor<24x4608x64xf32> loc("/home/eagarvey/SHARK-Turbine/models/turbine_models/custom_models/flux/flux_sampler_attn_repro_fp16.mlir":434:12)
%6 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%4, %extracted_slice, %3, %extracted_slice_0 : tensor<4608x64x2xf32>, tensor<24x4608x64xf32>, tensor<4608x64x2xf32>, tensor<24x4608x64xf32>) outs(%1 : tensor<24x4608x64x2xf16>) {
^bb0(%in: f32 loc("/home/eagarvey/SHARK-Turbine/models/turbine_models/custom_models/flux/flux_sampler_attn_repro_fp16.mlir":424:12), %in_1: f32 loc("/home/eagarvey/SHARK-Turbine/models/turbine_models/custom_models/flux/flux_sampler_attn_repro_fp16.mlir":427:12), %in_2: f32 loc("/home/eagarvey/SHARK-Turbine/models/turbine_models/custom_models/flux/flux_sampler_attn_repro_fp16.mlir":431:12), %in_3: f32 loc("/home/eagarvey/SHARK-Turbine/models/turbine_models/custom_models/flux/flux_sampler_attn_repro_fp16.mlir":434:12), %out: f16 loc("/home/eagarvey/SHARK-Turbine/models/turbine_models/custom_models/flux/flux_sampler_attn_repro_fp16.mlir":461:12)):
%7 = arith.mulf %in_2, %in_3 : f32 loc("/home/eagarvey/SHARK-Turbine/models/turbine_models/custom_models/flux/flux_sampler_attn_repro_fp16.mlir":435:12)
%8 = arith.mulf %in, %in_1 : f32 loc("/home/eagarvey/SHARK-Turbine/models/turbine_models/custom_models/flux/flux_sampler_attn_repro_fp16.mlir":428:12)
%9 = arith.addf %8, %7 : f32 loc("/home/eagarvey/SHARK-Turbine/models/turbine_models/custom_models/flux/flux_sampler_attn_repro_fp16.mlir":437:12)
%10 = arith.truncf %9 : f32 to f16 loc("/home/eagarvey/SHARK-Turbine/models/turbine_models/custom_models/flux/flux_sampler_attn_repro_fp16.mlir":461:12)
linalg.yield %10 : f16 loc("/home/eagarvey/SHARK-Turbine/models/turbine_models/custom_models/flux/flux_sampler_attn_repro_fp16.mlir":461:12)
} -> tensor<24x4608x64x2xf16> loc("/home/eagarvey/SHARK-Turbine/models/turbine_models/custom_models/flux/flux_sampler_attn_repro_fp16.mlir":461:12)
In the flow IR, the debuginfo points to these lines in the source IR:
%156 = torch.prims.convert_element_type %151, %int6_185 : !torch.vtensor<[1,24,4608,128],f16>, !torch.int -> !torch.vtensor<[1,24,4608,128],f32>
<...>
%163 = torch.aten.select.int %158, %int5_200, %int0_201 : !torch.vtensor<[1,24,4608,64,1,2],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,24,4608,64,1],f32>
<...>
%166 = torch.aten.select.int %158, %int5_204, %int1_205 : !torch.vtensor<[1,24,4608,64,1,2],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,24,4608,64,1],f32>
<...>
%178 = torch.prims.convert_element_type %177, %int5_220 : !torch.vtensor<[1,24,4608,128],f32>, !torch.int -> !torch.vtensor<[1,24,4608,128],f16>
For reference, here is an excerpt of the linalg IR: https://gist.githubusercontent.com/monorimet/e7052186f2be9c947f1022627e455d5e/raw/590bff3dffe648b70541b20b9ee8d76843324e0b/pattern_linalg.mlir
Full linalg IR: https://gist.githubusercontent.com/monorimet/e7052186f2be9c947f1022627e455d5e/raw/9b32448dd959b30b24443e4e86f3d8b50eddcfdd/flux_sampler_attn_repro_fp16_linalg.mlir
And the source IR: https://gist.githubusercontent.com/monorimet/e7052186f2be9c947f1022627e455d5e/raw/590bff3dffe648b70541b20b9ee8d76843324e0b/flux_sampler_attn_repro_fp16.mlir
So I'm wondering: is a single torch-to-linalg lowering (i.e. the element type conversion) responsible, or is it the pattern of torch.aten.select.int ops followed by the truncf?
In that case, is it really a torchtolinalg issue, or is this just an ugly pattern to begin with (in the torch dialect)?
Also, since we want to do a batch size sweep for this model, hand-editing the IR is not sufficient. The goal would be, as suggested, to fix the torchtolinalg lowering, but I am missing some context here and still haven't really figured out what needs to be changed.
Perhaps it would help after all to know exactly what we should be lowering this to... less room for me to get confused that way.
@IanWood1 here is what I believe is responsible for this: https://github.com/black-forest-labs/flux/blob/main/src/flux/math.py#L6
As suspected, I was able to minimize further by restricting the repro to the apply_rope function in the linked flux/math.py file:
https://gist.github.com/monorimet/d39a0cfa7d9307e91a61da209452e2f6/raw/dc1b12ece62b75cd01db7b99615a95cc2019b6a0/flux_apply_rope_torch.mlir
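For readers following along, here is a pure-Python sketch of what apply_rope computes on one head-dim row. It assumes, as in the upstream flux code (check flux/math.py for the exact source), that each 2x2 block of freqs_cis holds [[cos t, -sin t], [sin t, cos t]] and that the head dimension is viewed as (..., 64, 1, 2). The operands pair[0] and pair[1] correspond to the torch.aten.select.int results at indices 0 and 1 of the innermost size-2 dimension; in the model they are the values cast to f32, which is where the extf/truncf pair in the IR comes from. The helper names are hypothetical.

```python
# Pure-Python sketch of the RoPE update behind the select.int / extf pattern.
# Assumption: freqs_cis blocks are [[cos t, -sin t], [sin t, cos t]].
import math


def rope_pair(freqs, pair):
    """out[i] = freqs[i][0] * pair[0] + freqs[i][1] * pair[1].

    freqs[i][0] / freqs[i][1] mirror freqs_cis[..., 0] / freqs_cis[..., 1];
    pair[0] / pair[1] mirror the two select.int results on the size-2 dim.
    """
    return [freqs[i][0] * pair[0] + freqs[i][1] * pair[1] for i in range(2)]


def freqs_block(theta):
    """One 2x2 rotation block for angle theta."""
    c, s = math.cos(theta), math.sin(theta)
    return [[c, -s], [s, c]]


def apply_rope_row(row, thetas):
    """Rotate consecutive (even, odd) pairs of one head-dim row."""
    out = []
    for p, theta in enumerate(thetas):
        out.extend(rope_pair(freqs_block(theta), row[2 * p:2 * p + 2]))
    return out
```

Each pair is simply rotated by its angle, so the computation itself is tiny; the difficulty in this issue comes from how the select-then-combine structure is lowered, not from the math.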
@IanWood1 Not sure if the fix was ready to try, but I have an updated IR dump and reproducer here: https://gist.github.com/monorimet/3a0a4310c1ed09265353ce747599d502
On version:
commit 02c4a74931e403a414db5c0c8ac7085d0a91189e (HEAD, ianwood/extract_slice_prop)
tl;dr: it seems BubbleUpExtractThroughDequantize had no effect on the problematic code; the before and after IR are identical.
How can I help here? I have cycles to offer.
In the meantime, I have found a workaround in the torch model code -- simply removing the cast to float in these lines: https://github.com/black-forest-labs/flux/blob/main/src/flux/math.py#L26-27 gives identical results in pytorch and prevents us having to deal with this troublesome pattern. That at least makes this less blocking, but naturally supporting cases like this would be better in the long run.
https://github.com/iree-org/iree/pull/18332 should fix this. I tried the small repro from here and I get
#map = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
#map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
#map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
module {
util.func public @repro(%arg0: tensor<4608x1x24x64x1x2xf16>, %arg1: tensor<4608x64x2xf32>, %arg2: tensor<4608x64x2xf32>) -> tensor<24x4608x64x2xf16> {
%0 = tensor.empty() : tensor<1x24x4608x64x1x2xf32>
%extracted_slice = tensor.extract_slice %arg0[0, 0, 0, 0, 0, 0] [4608, 1, 24, 64, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<4608x1x24x64x1x2xf16> to tensor<4608x24x64xf16>
%extracted_slice_0 = tensor.extract_slice %0[0, 0, 0, 0, 0, 0] [1, 24, 4608, 64, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x24x4608x64x1x2xf32> to tensor<24x4608x64xf32>
%1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%extracted_slice : tensor<4608x24x64xf16>) outs(%extracted_slice_0 : tensor<24x4608x64xf32>) {
^bb0(%in: f16, %out: f32):
%5 = arith.extf %in : f16 to f32
linalg.yield %5 : f32
} -> tensor<24x4608x64xf32>
%extracted_slice_1 = tensor.extract_slice %arg0[0, 0, 0, 0, 0, 1] [4608, 1, 24, 64, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<4608x1x24x64x1x2xf16> to tensor<4608x24x64xf16>
%extracted_slice_2 = tensor.extract_slice %0[0, 0, 0, 0, 0, 1] [1, 24, 4608, 64, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x24x4608x64x1x2xf32> to tensor<24x4608x64xf32>
%2 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%extracted_slice_1 : tensor<4608x24x64xf16>) outs(%extracted_slice_2 : tensor<24x4608x64xf32>) {
^bb0(%in: f16, %out: f32):
%5 = arith.extf %in : f16 to f32
linalg.yield %5 : f32
} -> tensor<24x4608x64xf32>
%3 = tensor.empty() : tensor<24x4608x64x2xf16>
%4 = linalg.generic {indexing_maps = [#map2, #map3, #map2, #map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg1, %1, %arg2, %2 : tensor<4608x64x2xf32>, tensor<24x4608x64xf32>, tensor<4608x64x2xf32>, tensor<24x4608x64xf32>) outs(%3 : tensor<24x4608x64x2xf16>) {
^bb0(%in: f32, %in_3: f32, %in_4: f32, %in_5: f32, %out: f16):
%5 = arith.mulf %in_4, %in_5 : f32
%6 = arith.mulf %in, %in_3 : f32
%7 = arith.addf %6, %5 : f32
%8 = arith.truncf %7 : f32 to f16
linalg.yield %8 : f16
} -> tensor<24x4608x64x2xf16>
util.return %4 : tensor<24x4608x64x2xf16>
}
}
which seems right
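Why this rewrite is legal can be checked at the index level: taking one innermost-lane slice of the transposed, extended tensor gives the same values as slicing the f16 input first and then applying a 3-D transpose-and-extend generic with map (d1, d0, d2), which is the form of the IR above. This sketch uses tiny stand-in dimensions (N, H, D are illustrative), drops the unit dims, and models extf as float() because f16-to-f32 extension is exact; it illustrates the rewrite, not the PR's implementation.

```python
# Sketch: bubbling tensor.extract_slice above a transposing elementwise extf.
from itertools import product

N, H, D = 3, 2, 2  # stand-ins for 4608 (tokens), 24 (heads), 64 (pairs)

# f16 input laid out as (n, h, d, lane); every entry gets a distinct value.
src = {idx: sum(v * 10 ** k for k, v in enumerate(idx))
       for idx in product(range(N), range(H), range(D), range(2))}


def extf_then_slice(lane):
    # Original: transpose to (h, n, d, lane) + extf on the full tensor, then slice.
    full = {(h, n, d, l): float(src[(n, h, d, l)]) for (n, h, d, l) in src}
    return {(h, n, d): full[(h, n, d, lane)]
            for h, n, d in product(range(H), range(N), range(D))}


def slice_then_extf(lane):
    # After bubbling: slice the f16 input, then a 3-D generic with map (d1, d0, d2).
    sliced = {(n, h, d): src[(n, h, d, lane)]
              for n, h, d in product(range(N), range(H), range(D))}
    return {(h, n, d): float(sliced[(n, h, d)])
            for h, n, d in product(range(H), range(N), range(D))}
```

Because each extf consumer then works on a plain 3-D tensor, the slicing no longer forces the full 1x24x4608x64x1x2 f32 intermediate to be materialized.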
Tried compiling the original repro with #18332. It now hits the same error as #18325.
What happened?
Compiling a new SD3-like model for gfx942 is introducing a few new things we may need to massage through; the shapes don't seem too out of the ordinary, but the graph is in bf16. Given bf16, I'm not sure what to expect w.r.t. support in codegen. I am running into a few compilation failures, which I have tracked down to a few dispatches:
On an elementwise broadcast (24x4608x64x2xbf16) dispatch we run into a shared memory issue:
And the attention dispatch that soon follows fails without useful diagnostics:
Steps to reproduce your issue
Download files:
attention_and_matmul_spec_mfma.mlir (shared/sdxl_quantized)
attention_and_matmul_spec.mlir (main)
configured_compiled_flux_sampler_run_forward_async_dispatch_36.mlir
configured_compiled_flux_sampler_run_forward_async_dispatch_38.mlir
full model, if interested: flux_dev_bs1_512_1024x1024_bf16_sampler.mlir
compile command (shared/sdxl_quantized):
compile command (main):
What component(s) does this issue relate to?
Compiler
Version information
Tested on the following commits: tip of shared/sdxl_quantized and tip of main
Additional context
The shared memory issue seems unrelated to any changes on shared/sdxl_quantized or to TD spec compatibility, as both configurations (iree/main vs. iree/shared/sdxl_quantized, each with its respective TD spec) give the same error.