microsoft / triton-shared

Shared Middle-Layer for Triton Compilation
MIT License
132 stars 26 forks source link

Introducing Arm SME/SVE2 Optimization pass #109

Open danikhan632 opened 4 months ago

danikhan632 commented 4 months ago

Since this is an optimization pass to existing ttshare output, I decided to make it its own binary, currently gets past optimization phase and fails on _ttsharedir_to_llir which is to be expected since it needs different mlir-opt flags. These flags have also just been recently updated.

Also trying to introduce bf16/f16 support as well as make the current optimization passes only apply to hardware that can support it.

There are more plans for optimization than just the tile and outerproduct approach as seen here but the current build does produce valid MLIR. Based this off the example shown here.

As of now only SVE2 can tested on real hardware which I don't have access to. SME will have to be emulated. Not yet anywhere ready in a state to be merged but feedback would be appreciated.

Instructions to build

Same as normal however to see the optimized MLIR, Usage

#dependent on cmake/python/arch details
export TRITON_SME_PATH="$(pwd)/python/build/cmake.linux-aarch64-cpython-3.11/third_party/triton_shared/tools/triton-sme-opt/triton-sme-opt"
 cd ./third_party/triton_shared/python/examples
rm -rf ~/.triton/cache
python3 test_matmul.py

this is should cause the test to not compile and fail but the optimized MLIR should be printed in blue test to turn this off just set

export TRITON_SME_PATH=""

Below is the optimized MLIR produced from test_matmul.py

#map = affine_map<(d0, d1) -> (-d0 + 32, d1)>
#map1 = affine_map<(d0, d1) -> (-d0 + 64, d1)>
#map2 = affine_map<(d0, d1) -> (d0, 0, d1)>
#map3 = affine_map<(d0, d1) -> (0, d1, d0)>
#map4 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map5 = affine_map<(d0, d1, d2) -> (d0, d1)>
#map6 = affine_map<(d0, d1) -> (d0, d1)>
module {
  func.func @matmul_kernel_0d1d2d34567c89c1011c(%arg0: memref<*xf16> {tt.divisibility = 16 : i32}, %arg1: memref<*xf16> {tt.divisibility = 16 : i32}, %arg2: memref<*xf16> {tt.divisibility = 16 : i32}, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32, %arg12: i32, %arg13: i32, %arg14: i32) {
    %cst = arith.constant dense<false> : vector<1x[4]xi1>
    %c4 = arith.constant 4 : index
    %c8_i32 = arith.constant 8 : i32
    %c32_i32 = arith.constant 32 : i32
    %c64_i32 = arith.constant 64 : i32
    %c16_i32 = arith.constant 16 : i32
    %cst_0 = arith.constant 0.000000e+00 : f32
    %c0_i32 = arith.constant 0 : i32
    %c1_i32 = arith.constant 1 : i32
    %c31_i32 = arith.constant 31 : i32
    %c63_i32 = arith.constant 63 : i32
    %c15_i32 = arith.constant 15 : i32
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c16 = arith.constant 16 : index
    %cst_1 = arith.constant 0.000000e+00 : f16
    %c32 = arith.constant 32 : index
    %c64 = arith.constant 64 : index
    %alloc = memref.alloc() {alignment = 64 : i64} : memref<32x64xf32>
    linalg.fill ins(%cst_0 : f32) outs(%alloc : memref<32x64xf32>)
    %0 = bufferization.to_tensor %alloc : memref<32x64xf32>
    %1 = arith.addi %arg3, %c31_i32 : i32
    %2 = arith.divsi %1, %c32_i32 : i32
    %3 = arith.addi %arg4, %c63_i32 : i32
    %4 = arith.divsi %3, %c64_i32 : i32
    %5 = arith.muli %4, %c8_i32 : i32
    %6 = arith.divsi %arg12, %5 : i32
    %7 = arith.muli %6, %c8_i32 : i32
    %8 = arith.subi %2, %7 : i32
    %9 = arith.minsi %8, %c8_i32 : i32
    %10 = arith.remsi %arg12, %9 : i32
    %11 = arith.addi %7, %10 : i32
    %12 = arith.remsi %arg12, %5 : i32
    %13 = arith.divsi %12, %9 : i32
    %14 = arith.muli %11, %c32_i32 : i32
    %15 = arith.index_cast %14 : i32 to index
    %16 = arith.muli %13, %c64_i32 : i32
    %17 = arith.index_cast %16 : i32 to index
    %18 = arith.index_cast %arg3 : i32 to index
    %19 = arith.index_cast %arg6 : i32 to index
    %20 = arith.muli %15, %19 : index
    %21 = arith.muli %18, %19 : index
    %22 = arith.index_cast %arg7 : i32 to index
    %23 = arith.index_cast %arg4 : i32 to index
    %24 = arith.addi %arg5, %c15_i32 : i32
    %25 = arith.divsi %24, %c16_i32 : i32
    %26 = arith.muli %arg7, %c16_i32 : i32
    %27 = arith.index_cast %26 : i32 to index
    %28:3 = scf.for %arg15 = %c0_i32 to %25 step %c1_i32 iter_args(%arg16 = %0, %arg17 = %20, %arg18 = %c0) -> (tensor<32x64xf32>, index, index)  : i32 {
      %41 = bufferization.to_memref %arg16 : memref<32x64xf32>
      %42 = arith.addi %arg18, %17 : index
      %43 = arith.remsi %42, %23 : index
      %44 = arith.subi %42, %43 : index
      %45 = arith.addi %43, %c64 : index
      %46 = arith.minsi %45, %23 : index
      %47 = arith.subi %46, %43 : index
      %reinterpret_cast_4 = memref.reinterpret_cast %arg1 to offset: [%42], sizes: [%c16, %47], strides: [%22, %c1] : memref<*xf16> to memref<16x?xf16, strided<[?, ?], offset: ?>>
      %48 = arith.subi %c64, %47 : index
      %reinterpret_cast_5 = memref.reinterpret_cast %arg1 to offset: [%44], sizes: [%c16, %48], strides: [%22, %c1] : memref<*xf16> to memref<16x?xf16, strided<[?, ?], offset: ?>>
      %49 = arith.remsi %arg17, %19 : index
      %50 = arith.addi %21, %49 : index
      %51 = arith.subi %50, %arg17 : index
      %52 = arith.divsi %51, %19 : index
      %reinterpret_cast_6 = memref.reinterpret_cast %arg0 to offset: [%arg17], sizes: [%52, %c16], strides: [%19, %c1] : memref<*xf16> to memref<?x16xf16, strided<[?, ?], offset: ?>>
      %53 = arith.subi %c32, %52 : index
      %reinterpret_cast_7 = memref.reinterpret_cast %arg0 to offset: [%49], sizes: [%53, %c16], strides: [%19, %c1] : memref<*xf16> to memref<?x16xf16, strided<[?, ?], offset: ?>>
      %54 = arith.muli %arg15, %c16_i32 : i32
      %55 = arith.subi %arg5, %54 : i32
      %56 = arith.index_cast %55 : i32 to index
      %57 = arith.minsi %56, %c16 : index
      %alloc_8 = memref.alloc() : memref<32x16xf16>
      %58 = arith.cmpi slt, %57, %c16 : index
      scf.if %58 {
        linalg.fill ins(%cst_1 : f16) outs(%alloc_8 : memref<32x16xf16>)
      }
      %59 = arith.minsi %52, %c32 : index
      %60 = arith.subi %c32, %59 : index
      %subview_9 = memref.subview %reinterpret_cast_6[0, 0] [%59, %57] [1, 1] : memref<?x16xf16, strided<[?, ?], offset: ?>> to memref<?x?xf16, strided<[?, ?], offset: ?>>
      %subview_10 = memref.subview %reinterpret_cast_7[0, 0] [%60, %57] [1, 1] : memref<?x16xf16, strided<[?, ?], offset: ?>> to memref<?x?xf16, strided<[?, ?], offset: ?>>
      %subview_11 = memref.subview %alloc_8[0, 0] [%59, %57] [1, 1] : memref<32x16xf16> to memref<?x?xf16, strided<[16, 1]>>
      %subview_12 = memref.subview %alloc_8[%59, 0] [%60, %57] [1, 1] : memref<32x16xf16> to memref<?x?xf16, strided<[16, 1], offset: ?>>
      memref.copy %subview_9, %subview_11 : memref<?x?xf16, strided<[?, ?], offset: ?>> to memref<?x?xf16, strided<[16, 1]>>
      memref.copy %subview_10, %subview_12 : memref<?x?xf16, strided<[?, ?], offset: ?>> to memref<?x?xf16, strided<[16, 1], offset: ?>>
      %alloc_13 = memref.alloc() : memref<16x64xf16>
      %61 = arith.cmpi slt, %57, %c16 : index
      scf.if %61 {
        linalg.fill ins(%cst_1 : f16) outs(%alloc_13 : memref<16x64xf16>)
      }
      %62 = arith.minsi %47, %c64 : index
      %63 = arith.subi %c64, %62 : index
      %subview_14 = memref.subview %reinterpret_cast_4[0, 0] [%57, %62] [1, 1] : memref<16x?xf16, strided<[?, ?], offset: ?>> to memref<?x?xf16, strided<[?, ?], offset: ?>>
      %subview_15 = memref.subview %reinterpret_cast_5[0, 0] [%57, %63] [1, 1] : memref<16x?xf16, strided<[?, ?], offset: ?>> to memref<?x?xf16, strided<[?, ?], offset: ?>>
      %subview_16 = memref.subview %alloc_13[0, 0] [%57, %62] [1, 1] : memref<16x64xf16> to memref<?x?xf16, strided<[64, 1]>>
      %subview_17 = memref.subview %alloc_13[0, %62] [%57, %63] [1, 1] : memref<16x64xf16> to memref<?x?xf16, strided<[64, 1], offset: ?>>
      memref.copy %subview_14, %subview_16 : memref<?x?xf16, strided<[?, ?], offset: ?>> to memref<?x?xf16, strided<[64, 1]>>
      memref.copy %subview_15, %subview_17 : memref<?x?xf16, strided<[?, ?], offset: ?>> to memref<?x?xf16, strided<[64, 1], offset: ?>>
      %64 = vector.vscale
      %65 = arith.muli %64, %c4 : index
      %66 = arith.muli %64, %c4 : index
      %67 = scf.for %arg19 = %c0 to %c32 step %65 iter_args(%arg20 = %0) -> (tensor<32x64xf32>) {
        %72 = scf.for %arg21 = %c0 to %c64 step %66 iter_args(%arg22 = %arg20) -> (tensor<32x64xf32>) {
          %73 = scf.for %arg23 = %c0 to %c16 step %c1 iter_args(%arg24 = %arg22) -> (tensor<32x64xf32>) {
            %74 = bufferization.to_memref %arg24 : memref<32x64xf32>
            %75 = bufferization.to_memref %arg24 : memref<32x64xf32>
            %76 = affine.min #map(%arg19, %65)
            %77 = affine.min #map1(%arg21, %66)
            %78 = affine.min #map(%arg19, %65)
            %79 = affine.min #map1(%arg21, %66)
            %subview_19 = memref.subview %alloc_8[%arg19, %arg23] [%76, 1] [1, 1] : memref<32x16xf16> to memref<?x1xf16, strided<[16, 1], offset: ?>>
            %80 = bufferization.to_tensor %subview_19 : memref<?x1xf16, strided<[16, 1], offset: ?>>
            %subview_20 = memref.subview %alloc_13[%arg23, %arg21] [1, %77] [1, 1] : memref<16x64xf16> to memref<1x?xf16, strided<[64, 1], offset: ?>>
            %81 = bufferization.to_tensor %subview_20 : memref<1x?xf16, strided<[64, 1], offset: ?>>
            %subview_21 = memref.subview %75[%arg19, %arg21] [%78, %79] [1, 1] : memref<32x64xf32> to memref<?x?xf32, strided<[64, 1], offset: ?>>
            %82 = bufferization.to_tensor %subview_21 : memref<?x?xf32, strided<[64, 1], offset: ?>>
            %83 = vector.create_mask %76, %c1 : vector<[4]x1xi1>
            %84 = vector.transfer_read %80[%c0, %c0], %cst_1, %83 {in_bounds = [true, true, true], permutation_map = #map2} : tensor<?x1xf16>, vector<[4]x[4]x1xf16>
            %85 = vector.create_mask %77 : vector<[4]xi1>
            %86 = vector.insert %85, %cst [0] : vector<[4]xi1> into vector<1x[4]xi1>
            %87 = vector.transfer_read %81[%c0, %c0], %cst_1, %86 {in_bounds = [true, true, true], permutation_map = #map3} : tensor<1x?xf16>, vector<[4]x[4]x1xf16>
            %88 = vector.create_mask %76, %77 : vector<[4]x[4]xi1>
            %89 = vector.transfer_read %82[%c0, %c0], %cst_0, %88 {in_bounds = [true, true]} : tensor<?x?xf32>, vector<[4]x[4]xf32>
            %90 = arith.extf %84 : vector<[4]x[4]x1xf16> to vector<[4]x[4]x1xf32>
            %91 = arith.extf %87 : vector<[4]x[4]x1xf16> to vector<[4]x[4]x1xf32>
            %92 = vector.create_mask %76, %77, %c1 : vector<[4]x[4]x1xi1>
            %93 = vector.mask %92 { vector.contract {indexing_maps = [#map4, #map4, #map5], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %90, %91, %89 : vector<[4]x[4]x1xf32>, vector<[4]x[4]x1xf32> into vector<[4]x[4]xf32> } : vector<[4]x[4]x1xi1> -> vector<[4]x[4]xf32>
            %94 = vector.transfer_write %93, %82[%c0, %c0], %88 {in_bounds = [true, true]} : vector<[4]x[4]xf32>, tensor<?x?xf32>
            %95 = bufferization.to_memref %94 : memref<?x?xf32>
            %alloc_22 = memref.alloc() {alignment = 64 : i64} : memref<32x64xf32>
            memref.copy %74, %alloc_22 : memref<32x64xf32> to memref<32x64xf32>
            %subview_23 = memref.subview %alloc_22[%arg19, %arg21] [%78, %79] [1, 1] : memref<32x64xf32> to memref<?x?xf32, strided<[64, 1], offset: ?>>
            memref.copy %95, %subview_23 : memref<?x?xf32> to memref<?x?xf32, strided<[64, 1], offset: ?>>
            %96 = bufferization.to_tensor %alloc_22 : memref<32x64xf32>
            scf.yield %96 : tensor<32x64xf32>
          }
          scf.yield %73 : tensor<32x64xf32>
        }
        scf.yield %72 : tensor<32x64xf32>
      }
      %68 = bufferization.to_memref %67 : memref<32x64xf32>
      %alloc_18 = memref.alloc() {alignment = 64 : i64} : memref<32x64xf32>
      linalg.generic {indexing_maps = [#map6, #map6, #map6], iterator_types = ["parallel", "parallel"]} ins(%68, %41 : memref<32x64xf32>, memref<32x64xf32>) outs(%alloc_18 : memref<32x64xf32>) {
      ^bb0(%in: f32, %in_19: f32, %out: f32):
        %72 = arith.addf %in, %in_19 : f32
        linalg.yield %72 : f32
      }
      %69 = bufferization.to_tensor %alloc_18 : memref<32x64xf32>
      %70 = arith.addi %arg17, %c16 : index
      %71 = arith.addi %arg18, %27 : index
      scf.yield %69, %70, %71 : tensor<32x64xf32>, index, index
    }
    %29 = bufferization.to_memref %28#0 : memref<32x64xf32>
    %30 = arith.index_cast %arg8 : i32 to index
    %31 = arith.muli %15, %30 : index
    %32 = arith.addi %31, %17 : index
    %reinterpret_cast = memref.reinterpret_cast %arg2 to offset: [%32], sizes: [32, 64], strides: [%30, 1] : memref<*xf16> to memref<32x64xf16, strided<[?, 1], offset: ?>>
    %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<32x64xf16>
    linalg.generic {indexing_maps = [#map6, #map6], iterator_types = ["parallel", "parallel"]} ins(%29 : memref<32x64xf32>) outs(%alloc_2 : memref<32x64xf16>) {
    ^bb0(%in: f32, %out: f16):
      %41 = arith.truncf %in : f32 to f16
      linalg.yield %41 : f16
    }
    %33 = arith.addi %15, %c32 : index
    %34 = arith.minsi %33, %18 : index
    %35 = arith.subi %34, %15 : index
    %36 = arith.addi %17, %c64 : index
    %37 = arith.minsi %36, %23 : index
    %38 = arith.subi %37, %17 : index
    %39 = arith.minsi %35, %c32 : index
    %40 = arith.minsi %38, %c64 : index
    %subview = memref.subview %alloc_2[0, 0] [%39, %40] [1, 1] : memref<32x64xf16> to memref<?x?xf16, strided<[64, 1]>>
    %subview_3 = memref.subview %reinterpret_cast[0, 0] [%39, %40] [1, 1] : memref<32x64xf16, strided<[?, 1], offset: ?>> to memref<?x?xf16, strided<[?, 1], offset: ?>>
    memref.copy %subview, %subview_3 : memref<?x?xf16, strided<[64, 1]>> to memref<?x?xf16, strided<[?, 1], offset: ?>>
    return
  }
}
danikhan632 commented 4 months ago

Also I know that we just merged the arm-workflow runner but I might want to get rid of it since I have working changes for float16 and bfloat16. For quite sometime, triton-shared has been swapping out bf16/fp16 for f32 and I am working on optional support if the current system supports avx512_bf16(x86) or sve-bf16 (arm) or fp16 instructions. Wondering if the runner could changed at some point in the future

danikhan632 commented 4 months ago

getting error when trying to pass the IR through mlir-opt, I tried some of the flags used in the SME example but their for the transform interpreter, anybody got any ideas?

 error: failed to legalize operation 'arm_sme.tile_load' that was explicitly marked illegal
            %114 = vector.mask %113 { vector.transfer_read %extracted_slice_21[%c0, %c0], %cst_0 {in_bounds = [true, true]} : tensor<?x?xf32>, vector<[4]x[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>

I think its got something todo with the flags being passed in

def _ttsharedir_to_llir(ttsharedir: str):
    with tempfile.TemporaryDirectory() as tmpdir:
        ttshared_path = os.path.join(tmpdir, "ttshared.mlir")
        llmlir_path = os.path.join(tmpdir, "ll.mlir")
        llir_path = os.path.join(tmpdir, "ll.ir")
        Path(ttshared_path).write_text(ttsharedir)
        mlir_opt_path = _get_llvm_bin_path("mlir-opt")

        # TritonShared-MLIR to LLVM-MLIR
        subprocess.check_call([
            mlir_opt_path,
            ttshared_path,
            "--convert-linalg-to-affine-loops",
            "--eliminate-empty-tensors",
            "--arm-sve-legalize-vector-storage",
            "--allocate-arm-sme-tiles",
            "--empty-tensor-to-alloc-tensor",
            "--one-shot-bufferize=allow-return-allocs-from-loops=true",
            "--lower-affine",
            "--convert-linalg-to-loops",
            "--convert-arm-sme-to-scf",
            "--convert-scf-to-cf",
            "--convert-cf-to-llvm",
            "--convert-arith-to-llvm",
            "--convert-math-to-llvm",
            "--convert-complex-to-llvm",
            "--convert-vector-to-arm-sme",
            "--convert-arm-sme-to-llvm",
            "--convert-index-to-llvm",
            "--memref-expand",
            "-convert-vector-to-llvm=enable-arm-sve",
            "--expand-strided-metadata",
            "--finalize-memref-to-llvm",
            "--convert-func-to-llvm",
            # Lowering memrefs creates more affine.apply ops.
            # Lowering these affine ops again creates further arith ops,
            # so we have to run these two passes again here.
            "--lower-affine",
            "--convert-arith-to-llvm",
            # Remove all unrealized casts created
            "--canonicalize",
            "-o",
            llmlir_path,
        ])

        # LLVM-MLIR to LLVM-IR
        mlir_translate_path = _get_llvm_bin_path("mlir-translate")
        subprocess.check_call([mlir_translate_path, llmlir_path,
            "--mlir-to-llvmir",
            "-o",
            llir_path])
        return Path(llir_path).read_text()
zhaoshiz commented 4 months ago

You maybe missing lowering masked vector transfers. https://github.com/llvm/llvm-project/blob/839a8fecb4c5dfe1b4484d5fc942a9490867c47a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul.mlir#L79 The flag is -lower-vector-mask, if that doesn't work, you can call it in C++. https://github.com/llvm/llvm-project/blob/839a8fecb4c5dfe1b4484d5fc942a9490867c47a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp#L101-L104

danikhan632 commented 3 months ago

You maybe missing lowering masked vector transfers. https://github.com/llvm/llvm-project/blob/839a8fecb4c5dfe1b4484d5fc942a9490867c47a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul.mlir#L79 The flag is -lower-vector-mask, if that doesn't work, you can call it in C++. https://github.com/llvm/llvm-project/blob/839a8fecb4c5dfe1b4484d5fc942a9490867c47a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp#L101-L104

that seemed to change the IR but didn't seem to fix the issue completely

this is the IR now being generated

#map = affine_map<(d0, d1) -> (-d0 + 32, d1)>
#map1 = affine_map<(d0, d1) -> (-d0 + 64, d1)>
#map2 = affine_map<(d0)[s0] -> (d0 * 16 + s0)>
#map3 = affine_map<(d0)[s0] -> (d0 + s0)>
#map4 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map5 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map6 = affine_map<(d0, d1, d2) -> (d0, d1)>
#map7 = affine_map<(d0, d1) -> (d0, d1)>
module {
  func.func @matmul_kernel_0d1d2d34567c89c1011c(%arg0: memref<*xf16> {tt.divisibility = 16 : i32}, %arg1: memref<*xf16> {tt.divisibility = 16 : i32}, %arg2: memref<*xf16> {tt.divisibility = 16 : i32}, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32, %arg12: i32, %arg13: i32, %arg14: i32) {
    %c4 = arith.constant 4 : index
    %c8_i32 = arith.constant 8 : i32
    %c32_i32 = arith.constant 32 : i32
    %c64_i32 = arith.constant 64 : i32
    %c16_i32 = arith.constant 16 : i32
    %cst = arith.constant 0.000000e+00 : f32
    %c0_i32 = arith.constant 0 : i32
    %c1_i32 = arith.constant 1 : i32
    %c31_i32 = arith.constant 31 : i32
    %c63_i32 = arith.constant 63 : i32
    %c15_i32 = arith.constant 15 : i32
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c16 = arith.constant 16 : index
    %cst_0 = arith.constant 0.000000e+00 : f16
    %c32 = arith.constant 32 : index
    %c64 = arith.constant 64 : index
    %alloc = memref.alloc() {alignment = 64 : i64} : memref<32x64xf32>
    linalg.fill ins(%cst : f32) outs(%alloc : memref<32x64xf32>)
    %0 = arith.addi %arg3, %c31_i32 : i32
    %1 = arith.divsi %0, %c32_i32 : i32
    %2 = arith.addi %arg4, %c63_i32 : i32
    %3 = arith.divsi %2, %c64_i32 : i32
    %4 = arith.muli %3, %c8_i32 : i32
    %5 = arith.divsi %arg12, %4 : i32
    %6 = arith.muli %5, %c8_i32 : i32
    %7 = arith.subi %1, %6 : i32
    %8 = arith.minsi %7, %c8_i32 : i32
    %9 = arith.remsi %arg12, %8 : i32
    %10 = arith.addi %6, %9 : i32
    %11 = arith.remsi %arg12, %4 : i32
    %12 = arith.divsi %11, %8 : i32
    %13 = arith.muli %10, %c32_i32 : i32
    %14 = arith.index_cast %13 : i32 to index
    %15 = arith.muli %12, %c64_i32 : i32
    %16 = arith.index_cast %15 : i32 to index
    %17 = arith.index_cast %arg3 : i32 to index
    %18 = arith.index_cast %arg6 : i32 to index
    %19 = arith.muli %14, %18 : index
    %20 = arith.muli %17, %18 : index
    %21 = arith.index_cast %arg7 : i32 to index
    %22 = arith.index_cast %arg4 : i32 to index
    %23 = arith.addi %arg5, %c15_i32 : i32
    %24 = arith.divsi %23, %c16_i32 : i32
    %25 = arith.muli %arg7, %c16_i32 : i32
    %26 = arith.index_cast %25 : i32 to index
    %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<32x64xf32>
    memref.copy %alloc, %alloc_1 : memref<32x64xf32> to memref<32x64xf32>
    %27:3 = scf.for %arg15 = %c0_i32 to %24 step %c1_i32 iter_args(%arg16 = %alloc_1, %arg17 = %19, %arg18 = %c0) -> (memref<32x64xf32>, index, index)  : i32 {
      %39 = arith.addi %arg18, %16 : index
      %40 = arith.remsi %39, %22 : index
      %41 = arith.subi %39, %40 : index
      %42 = arith.addi %40, %c64 : index
      %43 = arith.minsi %42, %22 : index
      %44 = arith.subi %43, %40 : index
      %reinterpret_cast_4 = memref.reinterpret_cast %arg1 to offset: [%39], sizes: [%c16, %44], strides: [%21, %c1] : memref<*xf16> to memref<16x?xf16, strided<[?, ?], offset: ?>>
      %45 = arith.subi %c64, %44 : index
      %reinterpret_cast_5 = memref.reinterpret_cast %arg1 to offset: [%41], sizes: [%c16, %45], strides: [%21, %c1] : memref<*xf16> to memref<16x?xf16, strided<[?, ?], offset: ?>>
      %46 = arith.remsi %arg17, %18 : index
      %47 = arith.addi %20, %46 : index
      %48 = arith.subi %47, %arg17 : index
      %49 = arith.divsi %48, %18 : index
      %reinterpret_cast_6 = memref.reinterpret_cast %arg0 to offset: [%arg17], sizes: [%49, %c16], strides: [%18, %c1] : memref<*xf16> to memref<?x16xf16, strided<[?, ?], offset: ?>>
      %50 = arith.subi %c32, %49 : index
      %reinterpret_cast_7 = memref.reinterpret_cast %arg0 to offset: [%46], sizes: [%50, %c16], strides: [%18, %c1] : memref<*xf16> to memref<?x16xf16, strided<[?, ?], offset: ?>>
      %51 = arith.muli %arg15, %c16_i32 : i32
      %52 = arith.subi %arg5, %51 : i32
      %53 = arith.index_cast %52 : i32 to index
      %54 = arith.minsi %53, %c16 : index
      %alloc_8 = memref.alloc() : memref<32x16xf16>
      %55 = arith.cmpi slt, %54, %c16 : index
      scf.if %55 {
        linalg.fill ins(%cst_0 : f16) outs(%alloc_8 : memref<32x16xf16>)
      }
      %56 = arith.minsi %49, %c32 : index
      %57 = arith.subi %c32, %56 : index
      %subview_9 = memref.subview %reinterpret_cast_6[0, 0] [%56, %54] [1, 1] : memref<?x16xf16, strided<[?, ?], offset: ?>> to memref<?x?xf16, strided<[?, ?], offset: ?>>
      %subview_10 = memref.subview %reinterpret_cast_7[0, 0] [%57, %54] [1, 1] : memref<?x16xf16, strided<[?, ?], offset: ?>> to memref<?x?xf16, strided<[?, ?], offset: ?>>
      %subview_11 = memref.subview %alloc_8[0, 0] [%56, %54] [1, 1] : memref<32x16xf16> to memref<?x?xf16, strided<[16, 1]>>
      %subview_12 = memref.subview %alloc_8[%56, 0] [%57, %54] [1, 1] : memref<32x16xf16> to memref<?x?xf16, strided<[16, 1], offset: ?>>
      memref.copy %subview_9, %subview_11 : memref<?x?xf16, strided<[?, ?], offset: ?>> to memref<?x?xf16, strided<[16, 1]>>
      memref.copy %subview_10, %subview_12 : memref<?x?xf16, strided<[?, ?], offset: ?>> to memref<?x?xf16, strided<[16, 1], offset: ?>>
      %alloc_13 = memref.alloc() : memref<16x64xf16>
      %58 = arith.cmpi slt, %54, %c16 : index
      scf.if %58 {
        linalg.fill ins(%cst_0 : f16) outs(%alloc_13 : memref<16x64xf16>)
      }
      %59 = arith.minsi %44, %c64 : index
      %60 = arith.subi %c64, %59 : index
      %subview_14 = memref.subview %reinterpret_cast_4[0, 0] [%54, %59] [1, 1] : memref<16x?xf16, strided<[?, ?], offset: ?>> to memref<?x?xf16, strided<[?, ?], offset: ?>>
      %subview_15 = memref.subview %reinterpret_cast_5[0, 0] [%54, %60] [1, 1] : memref<16x?xf16, strided<[?, ?], offset: ?>> to memref<?x?xf16, strided<[?, ?], offset: ?>>
      %subview_16 = memref.subview %alloc_13[0, 0] [%54, %59] [1, 1] : memref<16x64xf16> to memref<?x?xf16, strided<[64, 1]>>
      %subview_17 = memref.subview %alloc_13[0, %59] [%54, %60] [1, 1] : memref<16x64xf16> to memref<?x?xf16, strided<[64, 1], offset: ?>>
      memref.copy %subview_14, %subview_16 : memref<?x?xf16, strided<[?, ?], offset: ?>> to memref<?x?xf16, strided<[64, 1]>>
      memref.copy %subview_15, %subview_17 : memref<?x?xf16, strided<[?, ?], offset: ?>> to memref<?x?xf16, strided<[64, 1], offset: ?>>
      %61 = vector.vscale
      %62 = arith.muli %61, %c4 : index
      %63 = arith.muli %61, %c4 : index
      %alloc_18 = memref.alloc() {alignment = 64 : i64} : memref<32x64xf32>
      memref.copy %alloc, %alloc_18 : memref<32x64xf32> to memref<32x64xf32>
      %64 = scf.for %arg19 = %c0 to %c32 step %62 iter_args(%arg20 = %alloc_18) -> (memref<32x64xf32>) {
        %67 = scf.for %arg21 = %c0 to %c64 step %63 iter_args(%arg22 = %arg20) -> (memref<32x64xf32>) {
          %68 = scf.for %arg23 = %c0 to %c16 step %c1 iter_args(%arg24 = %arg22) -> (memref<32x64xf32>) {
            %69 = affine.min #map(%arg19, %62)
            %70 = affine.min #map1(%arg21, %63)
            %71 = affine.min #map(%arg19, %62)
            %72 = affine.min #map1(%arg21, %63)
            %subview_19 = memref.subview %alloc_8[%arg19, %arg23] [%69, 1] [1, 1] : memref<32x16xf16> to memref<?x1xf16, strided<[16, 1], offset: ?>>
            %subview_20 = memref.subview %alloc_13[%arg23, %arg21] [1, %70] [1, 1] : memref<16x64xf16> to memref<1x?xf16, strided<[64, 1], offset: ?>>
            %subview_21 = memref.subview %arg24[%arg19, %arg21] [%71, %72] [1, 1] : memref<32x64xf32> to memref<?x?xf32, strided<[64, 1], offset: ?>>
            %73 = vector.create_mask %69 : vector<[4]xi1>
            %subview_22 = memref.subview %subview_19[0, 0] [%69, 1] [1, 1] : memref<?x1xf16, strided<[16, 1], offset: ?>> to memref<?xf16, #map2>
            %74 = vector.transfer_read %subview_22[%c0], %cst_0, %73 {in_bounds = [true]} : memref<?xf16, #map2>, vector<[4]xf16>
            %75 = vector.shape_cast %74 : vector<[4]xf16> to vector<[4]x1xf16>
            %76 = vector.create_mask %70 : vector<[4]xi1>
            %subview_23 = memref.subview %subview_20[0, 0] [1, %70] [1, 1] : memref<1x?xf16, strided<[64, 1], offset: ?>> to memref<?xf16, #map3>
            %77 = vector.transfer_read %subview_23[%c0], %cst_0, %76 {in_bounds = [true]} : memref<?xf16, #map3>, vector<[4]xf16>
            %78 = vector.shape_cast %77 : vector<[4]xf16> to vector<1x[4]xf16>
            %79 = vector.create_mask %69, %70 : vector<[4]x[4]xi1>
            %80 = vector.transfer_read %subview_21[%c0, %c0], %cst, %79 {in_bounds = [true, true]} : memref<?x?xf32, strided<[64, 1], offset: ?>>, vector<[4]x[4]xf32>
            %81 = arith.extf %75 : vector<[4]x1xf16> to vector<[4]x1xf32>
            %82 = arith.extf %78 : vector<1x[4]xf16> to vector<1x[4]xf32>
            %83 = vector.create_mask %69, %70, %c1 : vector<[4]x[4]x1xi1>
            %84 = vector.mask %83 { vector.contract {indexing_maps = [#map4, #map5, #map6], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %81, %82, %80 : vector<[4]x1xf32>, vector<1x[4]xf32> into vector<[4]x[4]xf32> } : vector<[4]x[4]x1xi1> -> vector<[4]x[4]xf32>
            vector.transfer_write %84, %subview_21[%c0, %c0], %79 {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32, strided<[64, 1], offset: ?>>
            %subview_24 = memref.subview %arg24[%arg19, %arg21] [%71, %72] [1, 1] : memref<32x64xf32> to memref<?x?xf32, strided<[64, 1], offset: ?>>
            memref.copy %subview_21, %subview_24 : memref<?x?xf32, strided<[64, 1], offset: ?>> to memref<?x?xf32, strided<[64, 1], offset: ?>>
            scf.yield %arg24 : memref<32x64xf32>
          }
          scf.yield %68 : memref<32x64xf32>
        }
        scf.yield %67 : memref<32x64xf32>
      }
      linalg.generic {indexing_maps = [#map7, #map7, #map7], iterator_types = ["parallel", "parallel"]} ins(%64, %arg16 : memref<32x64xf32>, memref<32x64xf32>) outs(%64 : memref<32x64xf32>) {
      ^bb0(%in: f32, %in_19: f32, %out: f32):
        %67 = arith.addf %in, %in_19 : f32
        linalg.yield %67 : f32
      }
      %65 = arith.addi %arg17, %c16 : index
      %66 = arith.addi %arg18, %26 : index
      scf.yield %64, %65, %66 : memref<32x64xf32>, index, index
    }
    %28 = arith.index_cast %arg8 : i32 to index
    %29 = arith.muli %14, %28 : index
    %30 = arith.addi %29, %16 : index
    %reinterpret_cast = memref.reinterpret_cast %arg2 to offset: [%30], sizes: [32, 64], strides: [%28, 1] : memref<*xf16> to memref<32x64xf16, strided<[?, 1], offset: ?>>
    %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<32x64xf16>
    linalg.generic {indexing_maps = [#map7, #map7], iterator_types = ["parallel", "parallel"]} ins(%27#0 : memref<32x64xf32>) outs(%alloc_2 : memref<32x64xf16>) {
    ^bb0(%in: f32, %out: f16):
      %39 = arith.truncf %in : f32 to f16
      linalg.yield %39 : f16
    }
    %31 = arith.addi %14, %c32 : index
    %32 = arith.minsi %31, %17 : index
    %33 = arith.subi %32, %14 : index
    %34 = arith.addi %16, %c64 : index
    %35 = arith.minsi %34, %22 : index
    %36 = arith.subi %35, %16 : index
    %37 = arith.minsi %33, %c32 : index
    %38 = arith.minsi %36, %c64 : index
    %subview = memref.subview %alloc_2[0, 0] [%37, %38] [1, 1] : memref<32x64xf16> to memref<?x?xf16, strided<[64, 1]>>
    %subview_3 = memref.subview %reinterpret_cast[0, 0] [%37, %38] [1, 1] : memref<32x64xf16, strided<[?, 1], offset: ?>> to memref<?x?xf16, strided<[?, 1], offset: ?>>
    memref.copy %subview, %subview_3 : memref<?x?xf16, strided<[64, 1]>> to memref<?x?xf16, strided<[?, 1], offset: ?>>
    return
  }
}
zhaoshiz commented 3 months ago

There's a masked vector.contract in the IR

            %117 = vector.create_mask %105, %106, %c1 : vector<[4]x[4]x1xi1>
            %118 = vector.mask %117 { vector.contract {indexing_maps = [#map4, #map4, #map5], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %115, %116, %114 : vector<[4]x[4]x1xf32>, vector<[4]x[4]x1xf32> into vector<[4]x[4]xf32> } : vector<[4]x[4]x1xi1> -> vector<[4]x[4]xf32>

Maybe lowering vector mask can help: https://github.com/llvm/llvm-project/blob/839a8fecb4c5dfe1b4484d5fc942a9490867c47a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul.mlir#L89

danikhan632 commented 3 months ago

There's a masked vector.contract in the IR

            %117 = vector.create_mask %105, %106, %c1 : vector<[4]x[4]x1xi1>
            %118 = vector.mask %117 { vector.contract {indexing_maps = [#map4, #map4, #map5], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %115, %116, %114 : vector<[4]x[4]x1xf32>, vector<[4]x[4]x1xf32> into vector<[4]x[4]xf32> } : vector<[4]x[4]x1xi1> -> vector<[4]x[4]xf32>

Maybe lowering vector mask can help: https://github.com/llvm/llvm-project/blob/839a8fecb4c5dfe1b4484d5fc942a9490867c47a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul.mlir#L89

I figured out the issue is that the outerproduct part seems to be having no effect on the MLIR output. I'm trying to figure out why this does nothing

struct OuterProductVectorizationPass
    : public PassWrapper<OuterProductVectorizationPass,
                         OperationPass<func::FuncOp>> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<vector::VectorDialect, func::FuncDialect>();
  }

  void runOnOperation() override {
    func::FuncOp funcOp = getOperation();
    MLIRContext *context = funcOp.getContext();
    RewritePatternSet patterns(context);
    ConversionTarget target(*context);

      // Apply patterns for lowering masked transfers
    transform::ApplyLowerMaskedTransfersPatternsOp lowerMaskedTransfersPatterns;
    lowerMaskedTransfersPatterns.populatePatterns(patterns);

    // Apply patterns for transfer permutation
    transform::ApplyTransferPermutationPatternsOp transferPermutationPatterns;
    transferPermutationPatterns.populatePatterns(patterns);

    // Apply patterns for reduction to contract
    transform::ApplyVectorReductionToContractPatternsOp reductionToContractPatterns;
    reductionToContractPatterns.populatePatterns(patterns);

    // Apply patterns for lowering contraction using outer product
    transform::ApplyLowerOuterProductPatternsOp lowerOuterProductPatterns;
    lowerOuterProductPatterns.populatePatterns(patterns);

    // Apply patterns for lowering masks
    transform::ApplyLowerMasksPatternsOp lowerMasksPatterns;
    lowerMasksPatterns.populatePatterns(patterns);

    // Apply patterns for rank-reducing subview
    transform::ApplyRankReducingSubviewPatternsOp rankReducingSubviewPatterns;
    rankReducingSubviewPatterns.populatePatterns(patterns);

    if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
      return signalPassFailure();
    }

  }

};
danikhan632 commented 3 months ago

the sme-opt seems to work now, producing a valid output for sme.outerproduct however I encounter an odd issue when lowering to llvm, perhaps "--convert-cf-to-llvm" fails?

If you look at the snippet, you can see that cf.br is still present. @banach-space could it some complication with --convert-arm-sme-to-scf?

ll.mlir:75:5: error: Dialect `cf' not found for custom op 'cf.br' 
    cf.br ^bb1(%37 : index)
    ^
.../ll.mlir:75:5: note: Registered dialects: acc, amx, arm_neon, arm_sme, arm_sve, builtin, dlti, func, gpu, llvm, nvvm, omp, rocdl, spirv, x86vector ; for more info on dialect registration see https://mlir.llvm.org/getting_started/Faq...

I ran with and without and as you can see in that cf.br is still present when it should have been lowered

image

current output from sme-opt:

#map = affine_map<()[s0] -> (s0 * 16)>
#map1 = affine_map<(d0, d1) -> (-d0 + 32, d1)>
#map2 = affine_map<(d0, d1) -> (-d0 + 64, d1)>
#map3 = affine_map<()[s0, s1] -> (s0 * 16 + s1)>
#map4 = affine_map<()[s0, s1] -> (s0 * 64 + s1)>
#map5 = affine_map<(d0)[s0] -> (d0 * 16 + s0)>
#map6 = affine_map<(d0)[s0] -> (d0 + s0)>
module {
  func.func @matmul_kernel_0d1d2d34567c89c1011c(%arg0: memref<*xf16> {tt.divisibility = 16 : i32}, %arg1: memref<*xf16> {tt.divisibility = 16 : i32}, %arg2: memref<*xf16> {tt.divisibility = 16 : i32}, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32, %arg12: i32, %arg13: i32, %arg14: i32) attributes {arm_sme.tiles_in_use = 34952 : i32} {
    %cst = arith.constant dense<0.000000e+00> : vector<[4]xf16>
    %c4 = arith.constant 4 : index
    %c8_i32 = arith.constant 8 : i32
    %c32_i32 = arith.constant 32 : i32
    %c64_i32 = arith.constant 64 : i32
    %c16_i32 = arith.constant 16 : i32
    %cst_0 = arith.constant 0.000000e+00 : f32
    %c0_i32 = arith.constant 0 : i32
    %c1_i32 = arith.constant 1 : i32
    %c31_i32 = arith.constant 31 : i32
    %c63_i32 = arith.constant 63 : i32
    %c15_i32 = arith.constant 15 : i32
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c16 = arith.constant 16 : index
    %cst_1 = arith.constant 0.000000e+00 : f16
    %c32 = arith.constant 32 : index
    %c64 = arith.constant 64 : index
    %alloc = memref.alloc() {alignment = 64 : i64} : memref<32x64xf32>
    scf.for %arg15 = %c0 to %c32 step %c1 {
      scf.for %arg16 = %c0 to %c64 step %c1 {
        memref.store %cst_0, %alloc[%arg15, %arg16] : memref<32x64xf32>
      }
    }
    %0 = arith.addi %arg3, %c31_i32 : i32
    %1 = arith.divsi %0, %c32_i32 : i32
    %2 = arith.addi %arg4, %c63_i32 : i32
    %3 = arith.divsi %2, %c64_i32 : i32
    %4 = arith.muli %3, %c8_i32 : i32
    %5 = arith.divsi %arg12, %4 : i32
    %6 = arith.muli %5, %c8_i32 : i32
    %7 = arith.subi %1, %6 : i32
    %8 = arith.minsi %7, %c8_i32 : i32
    %9 = arith.remsi %arg12, %8 : i32
    %10 = arith.addi %6, %9 : i32
    %11 = arith.remsi %arg12, %4 : i32
    %12 = arith.divsi %11, %8 : i32
    %13 = arith.muli %10, %c32_i32 : i32
    %14 = arith.index_cast %13 : i32 to index
    %15 = arith.muli %12, %c64_i32 : i32
    %16 = arith.index_cast %15 : i32 to index
    %17 = arith.index_cast %arg3 : i32 to index
    %18 = arith.index_cast %arg6 : i32 to index
    %19 = arith.muli %14, %18 : index
    %20 = arith.muli %17, %18 : index
    %21 = arith.index_cast %arg7 : i32 to index
    %22 = arith.index_cast %arg4 : i32 to index
    %23 = arith.addi %arg5, %c15_i32 : i32
    %24 = arith.divsi %23, %c16_i32 : i32
    %25 = arith.muli %arg7, %c16_i32 : i32
    %26 = arith.index_cast %25 : i32 to index
    %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<32x64xf32>
    memref.copy %alloc, %alloc_2 : memref<32x64xf32> to memref<32x64xf32>
    %27:3 = scf.for %arg15 = %c0_i32 to %24 step %c1_i32 iter_args(%arg16 = %alloc_2, %arg17 = %19, %arg18 = %c0) -> (memref<32x64xf32>, index, index)  : i32 {
      %39 = arith.addi %arg18, %16 : index
      %40 = arith.remsi %39, %22 : index
      %41 = arith.subi %39, %40 : index
      %42 = arith.addi %40, %c64 : index
      %43 = arith.minsi %42, %22 : index
      %44 = arith.subi %43, %40 : index
      %reinterpret_cast_6 = memref.reinterpret_cast %arg1 to offset: [%39], sizes: [%c16, %44], strides: [%21, %c1] : memref<*xf16> to memref<16x?xf16, strided<[?, ?], offset: ?>>
      %45 = arith.subi %c64, %44 : index
      %reinterpret_cast_7 = memref.reinterpret_cast %arg1 to offset: [%41], sizes: [%c16, %45], strides: [%21, %c1] : memref<*xf16> to memref<16x?xf16, strided<[?, ?], offset: ?>>
      %46 = arith.remsi %arg17, %18 : index
      %47 = arith.addi %20, %46 : index
      %48 = arith.subi %47, %arg17 : index
      %49 = arith.divsi %48, %18 : index
      %reinterpret_cast_8 = memref.reinterpret_cast %arg0 to offset: [%arg17], sizes: [%49, %c16], strides: [%18, %c1] : memref<*xf16> to memref<?x16xf16, strided<[?, ?], offset: ?>>
      %50 = arith.subi %c32, %49 : index
      %reinterpret_cast_9 = memref.reinterpret_cast %arg0 to offset: [%46], sizes: [%50, %c16], strides: [%18, %c1] : memref<*xf16> to memref<?x16xf16, strided<[?, ?], offset: ?>>
      %51 = arith.muli %arg15, %c16_i32 : i32
      %52 = arith.subi %arg5, %51 : i32
      %53 = arith.index_cast %52 : i32 to index
      %54 = arith.minsi %53, %c16 : index
      %alloc_10 = memref.alloc() : memref<32x16xf16>
      %55 = arith.cmpi slt, %54, %c16 : index
      scf.if %55 {
        scf.for %arg19 = %c0 to %c32 step %c1 {
          scf.for %arg20 = %c0 to %c16 step %c1 {
            memref.store %cst_1, %alloc_10[%arg19, %arg20] : memref<32x16xf16>
          }
        }
      }
      %56 = arith.minsi %49, %c32 : index
      %57 = arith.subi %c32, %56 : index
      %base_buffer_11, %offset_12, %sizes_13:2, %strides_14:2 = memref.extract_strided_metadata %reinterpret_cast_8 : memref<?x16xf16, strided<[?, ?], offset: ?>> -> memref<f16>, index, index, index, index, index
      %reinterpret_cast_15 = memref.reinterpret_cast %base_buffer_11 to offset: [%offset_12], sizes: [%56, %54], strides: [%strides_14#0, %strides_14#1] : memref<f16> to memref<?x?xf16, strided<[?, ?], offset: ?>>
      %base_buffer_16, %offset_17, %sizes_18:2, %strides_19:2 = memref.extract_strided_metadata %reinterpret_cast_9 : memref<?x16xf16, strided<[?, ?], offset: ?>> -> memref<f16>, index, index, index, index, index
      %reinterpret_cast_20 = memref.reinterpret_cast %base_buffer_16 to offset: [%offset_17], sizes: [%57, %54], strides: [%strides_19#0, %strides_19#1] : memref<f16> to memref<?x?xf16, strided<[?, ?], offset: ?>>
      %reinterpret_cast_21 = memref.reinterpret_cast %alloc_10 to offset: [0], sizes: [%56, %54], strides: [16, 1] : memref<32x16xf16> to memref<?x?xf16, strided<[16, 1]>>
      %58 = affine.apply #map()[%56]
      %reinterpret_cast_22 = memref.reinterpret_cast %alloc_10 to offset: [%58], sizes: [%57, %54], strides: [16, 1] : memref<32x16xf16> to memref<?x?xf16, strided<[16, 1], offset: ?>>
      memref.copy %reinterpret_cast_15, %reinterpret_cast_21 : memref<?x?xf16, strided<[?, ?], offset: ?>> to memref<?x?xf16, strided<[16, 1]>>
      memref.copy %reinterpret_cast_20, %reinterpret_cast_22 : memref<?x?xf16, strided<[?, ?], offset: ?>> to memref<?x?xf16, strided<[16, 1], offset: ?>>
      %alloc_23 = memref.alloc() : memref<16x64xf16>
      %59 = arith.cmpi slt, %54, %c16 : index
      scf.if %59 {
        scf.for %arg19 = %c0 to %c16 step %c1 {
          scf.for %arg20 = %c0 to %c64 step %c1 {
            memref.store %cst_1, %alloc_23[%arg19, %arg20] : memref<16x64xf16>
          }
        }
      }
      %60 = arith.minsi %44, %c64 : index
      %61 = arith.subi %c64, %60 : index
      %base_buffer_24, %offset_25, %sizes_26:2, %strides_27:2 = memref.extract_strided_metadata %reinterpret_cast_6 : memref<16x?xf16, strided<[?, ?], offset: ?>> -> memref<f16>, index, index, index, index, index
      %reinterpret_cast_28 = memref.reinterpret_cast %base_buffer_24 to offset: [%offset_25], sizes: [%54, %60], strides: [%strides_27#0, %strides_27#1] : memref<f16> to memref<?x?xf16, strided<[?, ?], offset: ?>>
      %base_buffer_29, %offset_30, %sizes_31:2, %strides_32:2 = memref.extract_strided_metadata %reinterpret_cast_7 : memref<16x?xf16, strided<[?, ?], offset: ?>> -> memref<f16>, index, index, index, index, index
      %reinterpret_cast_33 = memref.reinterpret_cast %base_buffer_29 to offset: [%offset_30], sizes: [%54, %61], strides: [%strides_32#0, %strides_32#1] : memref<f16> to memref<?x?xf16, strided<[?, ?], offset: ?>>
      %reinterpret_cast_34 = memref.reinterpret_cast %alloc_23 to offset: [0], sizes: [%54, %60], strides: [64, 1] : memref<16x64xf16> to memref<?x?xf16, strided<[64, 1]>>
      %reinterpret_cast_35 = memref.reinterpret_cast %alloc_23 to offset: [%60], sizes: [%54, %61], strides: [64, 1] : memref<16x64xf16> to memref<?x?xf16, strided<[64, 1], offset: ?>>
      memref.copy %reinterpret_cast_28, %reinterpret_cast_34 : memref<?x?xf16, strided<[?, ?], offset: ?>> to memref<?x?xf16, strided<[64, 1]>>
      memref.copy %reinterpret_cast_33, %reinterpret_cast_35 : memref<?x?xf16, strided<[?, ?], offset: ?>> to memref<?x?xf16, strided<[64, 1], offset: ?>>
      %62 = vector.vscale
      %63 = arith.muli %62, %c4 : index
      %64 = arith.muli %62, %c4 : index
      %alloc_36 = memref.alloc() {alignment = 64 : i64} : memref<32x64xf32>
      memref.copy %alloc, %alloc_36 : memref<32x64xf32> to memref<32x64xf32>
      scf.for %arg19 = %c0 to %c32 step %63 {
        scf.for %arg20 = %c0 to %c64 step %64 {
          scf.for %arg21 = %c0 to %c16 step %c1 {
            %67 = affine.min #map1(%arg19, %63)
            %68 = affine.min #map2(%arg20, %64)
            %69 = affine.min #map1(%arg19, %63)
            %70 = affine.min #map2(%arg20, %64)
            %71 = affine.apply #map3()[%arg19, %arg21]
            %72 = affine.apply #map4()[%arg21, %arg20]
            %73 = affine.apply #map4()[%arg19, %arg20]
            %reinterpret_cast_37 = memref.reinterpret_cast %alloc_36 to offset: [%73], sizes: [%69, %70], strides: [64, 1] : memref<32x64xf32> to memref<?x?xf32, strided<[64, 1], offset: ?>>
            %74 = vector.create_mask %67 : vector<[4]xi1>
            %reinterpret_cast_38 = memref.reinterpret_cast %alloc_10 to offset: [%71], sizes: [%67], strides: [16] : memref<32x16xf16> to memref<?xf16, #map5>
            %75 = vector.vscale
            %76 = arith.muli %75, %c4 : index
            %77 = scf.for %arg22 = %c0 to %76 step %c1 iter_args(%arg23 = %cst) -> (vector<[4]xf16>) {
              %108 = vector.extractelement %74[%arg22 : index] : vector<[4]xi1>
              %109 = scf.if %108 -> (vector<[4]xf16>) {
                %110 = memref.load %reinterpret_cast_38[%arg22] : memref<?xf16, #map5>
                %111 = vector.insertelement %110, %arg23[%arg22 : index] : vector<[4]xf16>
                scf.yield %111 : vector<[4]xf16>
              } else {
                scf.yield %arg23 : vector<[4]xf16>
              }
              scf.yield %109 : vector<[4]xf16>
            }
            %78 = vector.shape_cast %77 : vector<[4]xf16> to vector<[4]x1xf16>
            %79 = vector.create_mask %68 : vector<[4]xi1>
            %reinterpret_cast_39 = memref.reinterpret_cast %alloc_23 to offset: [%72], sizes: [%68], strides: [1] : memref<16x64xf16> to memref<?xf16, #map6>
            %80 = vector.transfer_read %reinterpret_cast_39[%c0], %cst_1, %79 {in_bounds = [true]} : memref<?xf16, #map6>, vector<[4]xf16>
            %81 = vector.shape_cast %80 : vector<[4]xf16> to vector<1x[4]xf16>
            %82 = vector.create_mask %67, %68 : vector<[4]x[4]xi1>
            %83 = arm_sme.zero {tile_id = 0 : i32} : vector<[4]x[4]xf32>
            %c4_40 = arith.constant 4 : index
            %84 = vector.vscale
            %85 = arith.muli %c4_40, %84 : index
            %86 = arith.index_cast %67 : index to i64
            %87 = arith.index_cast %85 : index to i64
            %88 = arith.minsi %86, %87 : i64
            %89 = arith.index_cast %88 : i64 to index
            %90 = vector.create_mask %68 : vector<[4]xi1>
            %c0_41 = arith.constant 0 : index
            %c1_42 = arith.constant 1 : index
            %91 = scf.for %arg22 = %c0_41 to %89 step %c1_42 iter_args(%arg23 = %83) -> (vector<[4]x[4]xf32>) {
              %108 = arith.addi %c0, %arg22 : index
              %109 = arm_sme.load_tile_slice %reinterpret_cast_37[%108, %c0], %90, %arg23, %arg22 {tile_id = 0 : i32} : memref<?x?xf32, strided<[64, 1], offset: ?>>, vector<[4]xi1>, vector<[4]x[4]xf32>
              scf.yield %109 : vector<[4]x[4]xf32>
            }
            %92 = arith.extf %78 : vector<[4]x1xf16> to vector<[4]x1xf32>
            %93 = arith.extf %81 : vector<1x[4]xf16> to vector<1x[4]xf32>
            %94 = vector.transpose %92, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
            %95 = vector.extract %94[0] : vector<[4]xf32> from vector<1x[4]xf32>
            %96 = vector.extract %93[0] : vector<[4]xf32> from vector<1x[4]xf32>
            %97 = vector.create_mask %67 : vector<[4]xi1>
            %98 = vector.create_mask %68 : vector<[4]xi1>
            %99 = arm_sme.outerproduct %95, %96 acc(%91) masks(%97, %98) {tile_id = 0 : i32} : vector<[4]xf32>, vector<[4]xf32>
            %c4_43 = arith.constant 4 : index
            %100 = vector.vscale
            %101 = arith.muli %c4_43, %100 : index
            %102 = arith.index_cast %67 : index to i64
            %103 = arith.index_cast %101 : index to i64
            %104 = arith.minsi %102, %103 : i64
            %105 = arith.index_cast %104 : i64 to index
            %106 = vector.create_mask %68 : vector<[4]xi1>
            %c0_44 = arith.constant 0 : index
            %c1_45 = arith.constant 1 : index
            scf.for %arg22 = %c0_44 to %105 step %c1_45 {
              %108 = arith.addi %c0, %arg22 : index
              arm_sme.store_tile_slice %99, %arg22, %106, %reinterpret_cast_37[%108, %c0] {tile_id = 0 : i32} : memref<?x?xf32, strided<[64, 1], offset: ?>>, vector<[4]xi1>, vector<[4]x[4]xf32>
            }
            %107 = affine.apply #map4()[%arg19, %arg20]
            %reinterpret_cast_46 = memref.reinterpret_cast %alloc_36 to offset: [%107], sizes: [%69, %70], strides: [64, 1] : memref<32x64xf32> to memref<?x?xf32, strided<[64, 1], offset: ?>>
            memref.copy %reinterpret_cast_37, %reinterpret_cast_46 : memref<?x?xf32, strided<[64, 1], offset: ?>> to memref<?x?xf32, strided<[64, 1], offset: ?>>
          }
        }
      }
      scf.for %arg19 = %c0 to %c32 step %c1 {
        scf.for %arg20 = %c0 to %c64 step %c1 {
          %67 = memref.load %alloc_36[%arg19, %arg20] : memref<32x64xf32>
          %68 = memref.load %arg16[%arg19, %arg20] : memref<32x64xf32>
          %69 = arith.addf %67, %68 : f32
          memref.store %69, %alloc_36[%arg19, %arg20] : memref<32x64xf32>
        }
      }
      %65 = arith.addi %arg17, %c16 : index
      %66 = arith.addi %arg18, %26 : index
      scf.yield %alloc_36, %65, %66 : memref<32x64xf32>, index, index
    }
    %28 = arith.index_cast %arg8 : i32 to index
    %29 = arith.muli %14, %28 : index
    %30 = arith.addi %29, %16 : index
    %reinterpret_cast = memref.reinterpret_cast %arg2 to offset: [%30], sizes: [32, 64], strides: [%28, 1] : memref<*xf16> to memref<32x64xf16, strided<[?, 1], offset: ?>>
    %alloc_3 = memref.alloc() {alignment = 64 : i64} : memref<32x64xf16>
    scf.for %arg15 = %c0 to %c32 step %c1 {
      scf.for %arg16 = %c0 to %c64 step %c1 {
        %39 = memref.load %27#0[%arg15, %arg16] : memref<32x64xf32>
        %40 = arith.truncf %39 : f32 to f16
        memref.store %40, %alloc_3[%arg15, %arg16] : memref<32x64xf16>
      }
    }
    %31 = arith.addi %14, %c32 : index
    %32 = arith.minsi %31, %17 : index
    %33 = arith.subi %32, %14 : index
    %34 = arith.addi %16, %c64 : index
    %35 = arith.minsi %34, %22 : index
    %36 = arith.subi %35, %16 : index
    %37 = arith.minsi %33, %c32 : index
    %38 = arith.minsi %36, %c64 : index
    %reinterpret_cast_4 = memref.reinterpret_cast %alloc_3 to offset: [0], sizes: [%37, %38], strides: [64, 1] : memref<32x64xf16> to memref<?x?xf16, strided<[64, 1]>>
    %base_buffer, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %reinterpret_cast : memref<32x64xf16, strided<[?, 1], offset: ?>> -> memref<f16>, index, index, index, index, index
    %reinterpret_cast_5 = memref.reinterpret_cast %base_buffer to offset: [%offset], sizes: [%37, %38], strides: [%strides#0, 1] : memref<f16> to memref<?x?xf16, strided<[?, 1], offset: ?>>
    memref.copy %reinterpret_cast_4, %reinterpret_cast_5 : memref<?x?xf16, strided<[64, 1]>> to memref<?x?xf16, strided<[?, 1], offset: ?>>
    return
  }
}

function that compiles the kernel above:

def _ttsharedir_to_llir(ttsharedir: str):
    with tempfile.TemporaryDirectory() as tmpdir:
        ttshared_path = os.path.join(tmpdir, "ttshared.mlir")
        llmlir_path = os.path.join(tmpdir, "ll.mlir")
        llir_path = os.path.join(tmpdir, "ll.ir")
        Path(ttshared_path).write_text(ttsharedir)
        mlir_opt_path = _get_llvm_bin_path("mlir-opt")
        # TritonShared-MLIR to LLVM-MLIR
        subprocess.check_call([
            mlir_opt_path,
            ttshared_path,
            "--one-shot-bufferize=allow-return-allocs-from-loops=true",
            "--convert-arm-sme-to-llvm", 
            "--convert-vector-to-llvm=enable-arm-sve",
            "--convert-arith-to-llvm",
            "--convert-math-to-llvm",
            "--convert-complex-to-llvm",
            "--convert-func-to-llvm",
            "--convert-index-to-llvm",
            "--finalize-memref-to-llvm",
            "--convert-scf-to-cf",
            "--convert-cf-to-llvm", 
            "-o", llmlir_path
        ])
        # LLVM-MLIR to LLVM-IR
        mlir_translate_path = _get_llvm_bin_path("mlir-translate")
        subprocess.check_call([mlir_translate_path, llmlir_path, "--mlir-to-llvmir", "-o", llir_path])
        return Path(llir_path).read_text()

output kernel before mlir-translate:

sme_matmul_lowered.mlir.txt

MacDue commented 3 months ago

the sme-opt seems to work now, producing a valid output for sme.outerproduct however I encounter an odd issue when lowering to llvm, perhaps "--convert-cf-to-llvm" fails?

You appear to be doing a f16 -> f32 matmul, that (for this example) requires slightly different tiling and passes. I think you need tile sizes of [4 x vscale, 4 x vscale, 2] (i.e. the reduction dimension is unrolled by two). Then you need to apply the arm-sme-vector-legalization pass fairly early (before convert-vector/arith-to-arm-sme), and the arm-sme-outerproduct-fusion pass just after convert-vector-to-arm-sme). This should result in arm_sme.fmopa_2way operations (rather than arm_sme.outerproduct ops).

danikhan632 commented 3 months ago

the sme-opt seems to work now, producing a valid output for sme.outerproduct however I encounter an odd issue when lowering to llvm, perhaps "--convert-cf-to-llvm" fails?

You appear to be doing a f16 -> f32 matmul, that (for this example) requires slightly different tiling and passes. I think you need tile sizes of [4 x vscale, 4 x vscale, 2] (i.e. the reduction dimension is unrolled by two). Then you need to apply the arm-sme-vector-legalization pass fairly early (before convert-vector/arith-to-arm-sme), and the arm-sme-outerproduct-fusion pass just after convert-vector-to-arm-sme). This should result in arm_sme.fmopa_2way operations (rather than arm_sme.outerproduct ops).

yea I might have forgotten that we are going from f16 ->f32

understood I guess the tiling logic has to be a bit different since this kernel uses a f32 acculumator. btw here is orginial kernel for refrence btw before any sme/llvm lowerings are applied

edit:

was kind of confused about the outer-product-fusion thing, turns out these are pretty new and not in llvm commit 4017f04e that current triton branch uses @nhat-nguyen can this be bumped to triton hash ea9777d?

danikhan632 commented 3 months ago

the sme-opt seems to work now, producing a valid output for sme.outerproduct however I encounter an odd issue when lowering to llvm, perhaps "--convert-cf-to-llvm" fails?

You appear to be doing a f16 -> f32 matmul, that (for this example) requires slightly different tiling and passes. I think you need tile sizes of [4 x vscale, 4 x vscale, 2] (i.e. the reduction dimension is unrolled by two). Then you need to apply the arm-sme-vector-legalization pass fairly early (before convert-vector/arith-to-arm-sme), and the arm-sme-outerproduct-fusion pass just after convert-vector-to-arm-sme). This should result in arm_sme.fmopa_2way operations (rather than arm_sme.outerproduct ops).

got this working now too, have some concerns about future when I have to change dims from 1 -> 2 for widening but that can be worried about later.

fmopa2_way is being produced right now kernel.mlir.txt

have a minor issue with this:

ll.mlir:8:10: error: LLVM Translation failed for operation: builtin.unrealized_conversion_cast
    %3 = builtin.unrealized_conversion_cast %2 : !llvm.struct<(i64, ptr)> to memref<*xf16>
         ^

want to know if you have any idea how to remedy this and is the caused by something SME related or is it just an triton-shared/mlir thing?

here my first mlir conversion before lowerings to llvm, think the order is right

        subprocess.check_call([mlir_opt_path, sme_first_pass,
        "--canonicalize", 
            "--eliminate-empty-tensors",
            "--convert-linalg-to-loops",
            "--empty-tensor-to-alloc-tensor",
            "--expand-strided-metadata",
            "--arm-sme-vector-legalization",
            "--convert-vector-to-arm-sme",
            "--arm-sme-outer-product-fusion",
            "--arm-sve-legalize-vector-storage",
            "--convert-arith-to-arm-sme",
            "--allocate-arm-sme-tiles",
            "--convert-arm-sme-to-scf",
            "--convert-vector-to-scf",
            "-o",
            mlir_sme_pass])
nhat-nguyen commented 3 months ago

@danikhan632 unrealized_conversion_cast ops are inserted automatically by TypeConverters during the dialect conversion when resulting types are incompatible. This is unrelated to triton-shared. One way to debug this is to first find out at which pass these unrealized_conversion_cast ops start appearing.

danikhan632 commented 3 months ago

@danikhan632 unrealized_conversion_cast ops are inserted automatically by TypeConverters during the dialect conversion when resulting types are incompatible. This is unrelated to triton-shared. One way to debug this is to first find out at which pass these unrealized_conversion_cast ops start appearing.

I figured that much, think its got something todo with the way inputs are passed

  llvm.func @matmul_kernel_0d1d2d34567c89c1011c(%arg0: i64, %arg1: !llvm.ptr, %arg2: i64, %arg3: !llvm.ptr, %arg4: i64, %arg5: !llvm.ptr, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32, %arg12: i32, %arg13: i32, %arg14: i32, %arg15: i32, %arg16: i32, %arg17: i32) attributes {arm_sme.tiles_in_use = 34952 : i32} {
    %0 = llvm.mlir.undef : !llvm.struct<(i64, ptr)>
    %1 = llvm.insertvalue %arg4, %0[0] : !llvm.struct<(i64, ptr)> 
    %2 = llvm.insertvalue %arg5, %1[1] : !llvm.struct<(i64, ptr)> 
    %3 = builtin.unrealized_conversion_cast %2 : !llvm.struct<(i64, ptr)> to memref<*xf16>

like I think the kernel is expecting an i64 value and a pointer to the inputs, but it gets a memref

zhaoshiz commented 3 months ago
  llvm.func @matmul_kernel_0d1d2d34567c89c1011c(%arg0: i64, %arg1: !llvm.ptr, %arg2: i64, %arg3: !llvm.ptr, %arg4: i64, %arg5: !llvm.ptr, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32, %arg12: i32, %arg13: i32, %arg14: i32, %arg15: i32, %arg16: i32, %arg17: i32) attributes {arm_sme.tiles_in_use = 34952 : i32} {
    %0 = llvm.mlir.undef : !llvm.struct<(i64, ptr)>
    %1 = llvm.insertvalue %arg4, %0[0] : !llvm.struct<(i64, ptr)> 
    %2 = llvm.insertvalue %arg5, %1[1] : !llvm.struct<(i64, ptr)> 
    %3 = builtin.unrealized_conversion_cast %2 : !llvm.struct<(i64, ptr)> to memref<*xf16>

I see unrealized_conversion_cast errors when lowering to the llvm dialect. In my case, it's caused by that the user of this cast (%3 above) is not lowered to llvm dialect. I would check the dialect/op of the user and try find the pass(es) to lower it.

danikhan632 commented 3 months ago

zhaoshiz I think the is issue is with the memref allocation of scalable vectors here, any ideas on how to fix?


...
%62 = arith.muli %vscale, %c4 : index
%63 = arith.muli %vscale, %c4 : index
%alloc_37 = memref.alloc() {alignment = 64 : i64} : memref<32x64xf32>
memref.copy %alloc, %alloc_37 : memref<32x64xf32> to memref<32x64xf32>
scf.for %arg19 = %c0 to %c32 step %62 {
  scf.for %arg20 = %c0 to %c64 step %63 {
    scf.for %arg21 = %c0 to %c16 step %c2 {
      %alloca = memref.alloca() : memref<vector<2x[4]xf16>>
      %alloca_38 = memref.alloca() : memref<vector<2x[4]xi1>>

...

AFTER MLIR PASSES

  %33 = "arith.constant"() <{value = 2 : index}> : () -> index
  %34 = "builtin.unrealized_conversion_cast"(%33) : (index) -> i64
  %35 = "arith.constant"() <{value = dense<0.000000e+00> : vector<[4]xf16>}> : () -> vector<[4]xf16>
  %36 = "arith.constant"() <{value = -1 : index}> : () -> index
  %37 = "arith.constant"() <{value = dense<false> : vector<2x[4]xi1>}> : () -> vector<2x[4]xi1>
  %38 = "builtin.unrealized_conversion_cast"(%37) : (vector<2x[4]xi1>) -> !llvm.array<2 x vector<[4]xi1>>
  %39 = "builtin.unrealized_conversion_cast"(%21) : (index) -> i64
  %40 = "builtin.unrealized_conversion_cast"(%21) : (index) -> i64
  %41 = "llvm.mlir.constant"() <{value = 32 : index}> : () -> i64

...
AND MORE MLIR PASSES  
  %33 = "arith.constant"() <{value = 64 : i32}> : () -> i32
  %34 = "arith.constant"() <{value = 32 : i32}> : () -> i32
  %35 = "arith.constant"() <{value = 8 : i32}> : () -> i32
  %36 = "arith.constant"() <{value = 4 : index}> : () -> index
  %37 = "arith.constant"() <{value = 2 : index}> : () -> index
  %38 = "builtin.unrealized_conversion_cast"(%37) : (index) -> i64
  %39 = "arith.constant"() <{value = dense<0.000000e+00> : vector<[4]xf16>}> : () -> vector<[4]xf16>
  %40 = "arith.constant"() <{value = -1 : index}> : () -> index
  %41 = "arith.constant"() <{value = dense<false> : vector<2x[4]xi1>}> : () -> vector<2x[4]xi1>
  %42 = "builtin.unrealized_conversion_cast"(%41) : (vector<2x[4]xi1>) -> !llvm.array<2 x vector<[4]xi1>>
  %43 = "builtin.unrealized_conversion_cast"(%24) : (index) -> i64

  <unknown>:0: error: failed to legalize operation 'builtin.unrealized_conversion_cast' that was explicitly marked illegal
<unknown>:0: note: see current operation: %38 = "builtin.unrealized_conversion_cast"(%37) : (i64) -> index
MacDue commented 3 months ago

I think it's more helpful to look at the users of an unrealized_conversion_cast rather than the cast (especially when posting snippets like the above). The users will be the thing that's keeping the casts around (likely because they've not been lowered correctly).

It looks to me like the arith dialect has not been lowered, but also stuff like making allocas for predicates (i.e. memref<vector<2x[4]xi1>>), is something you generally want to avoid. But if you have to keep them around, they need to be legalised by -arm-sve-legalize-vector-storage (but I don't think that's the cause of the issue here).

danikhan632 commented 3 months ago

I think it's more helpful to look at the users of an unrealized_conversion_cast rather than the cast (especially when posting snippets like the above). The users will be the thing that's keeping the casts around (likely because they've has not been lowered correctly).

It looks to me like the arith dialect has not been lowered, but also stuff like making allocas for predicates (i.e. memref<vector<2x[4]xi1>>), is something you generally want to avoid. But if you have to keep them around, they need to be legalised by -arm-sve-legalize-vector-storage (but I don't think that's the cause of the issue here).

yeah I don't think '--convert-arith-to-arm-sme' is really doing anything here, I wanted to vet btw that the kernel that I generated is legitimate and that the only thing that I should have to do is run it through mlir-opt and then through mlir-translate and it should be fine.

I also figured that memref<vector<2x[4]xi1>> is not great since these vectors sizes aren't known till run time.

matmul.mlir.txt

MacDue commented 3 months ago

I think the allocas for the scalable vectors come from using default lowering of --convert-vector-to-scf. If you do --convert-vector-to-scf=full-unroll, those should be avoided.

danikhan632 commented 3 months ago

I think the allocas for the scalable vectors come from using default lowering of --convert-vector-to-scf. If you do --convert-vector-to-scf=full-unroll, those should be avoided. it lowers to llir successfully now, this issue it when compiling using llc any ideas why? sme_matmul_lowered.llir.txt

LLVM ERROR: Cannot select: t155: i64 = vscale Constant:i64<1024>
  t154: i64 = Constant<1024>
In function: matmul_kernel_0d1d2d34567c89c1011c
PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
Stack dump:
0.      Program arguments: /home/green/.triton/llvm/llvm-6f44bb77-ubuntu-x64/bin/llc /tmp/tmp7lnmobsz/kernel.ll -o /tmp/tmp7lnmobsz/kernel.o
1.      Running pass 'Function Pass Manager' on module '/tmp/tmp7lnmobsz/kernel.ll'.
2.      Running pass 'X86 DAG->DAG Instruction Selection' on function '@matmul_kernel_0d1d2d34567c89c1011c'
MacDue commented 3 months ago

What flags are you using? To compile with llc (for example) you'd need to pass -mattr=+sve,+sme when using SVE and SME.

danikhan632 commented 3 months ago

What flags are you using? To compile with llc (for example) you'd need to pass -mattr=+sve,+sme when using SVE and SME.

ah I see, I think this is where the sme userspace emulator is needed,

I think this is correct, going to switch over to my arm system to test it

def _llir_to_bin(llir: str, metadata):
    pattern = r"define void @(\w+)\(.+"
    matches = re.findall(pattern, llir)
    assert len(matches) == 1
    metadata["name"] = matches[0]
    with tempfile.TemporaryDirectory() as tmpdir:
        src_path = os.path.join(tmpdir, "kernel.ll")
        dst_path = os.path.join(tmpdir, "kernel.o")
        Path(src_path).write_text(llir)
        llc_path = _get_llvm_bin_path("llc")
        subprocess.check_call(["/usr/bin/qemu-aarch64-static",llc_path, src_path, "--mattr=+sve,+sme","-o", dst_path])
        # Actually it's text-format assembly.  Use read_text().
        return Path(dst_path).read_text()
danikhan632 commented 3 months ago

What flags are you using? To compile with llc (for example) you'd need to pass -mattr=+sve,+sme when using SVE and SME.

I did that and get this error,

        subprocess.check_call([ "/usr/bin/qemu-aarch64-static",llc_path, src_path, "--mattr=+sve,+sme","-o", dst_path])

output:

/tmp/tmp1j2jbysw/kernel.s: Assembler messages:
/tmp/tmp1j2jbysw/kernel.s:126: Error: selected processor does not support `rdvl x8,#1'
/tmp/tmp1j2jbysw/kernel.s:130: Error: selected processor does not support `cntw x24'
/tmp/tmp1j2jbysw/kernel.s:409: Error: selected processor does not support `ptrue p2.s'
/tmp/tmp1j2jbysw/kernel.s:410: Error: selected processor does not support `index z6.s,#0,#1'
/tmp/tmp1j2jbysw/kernel.s:417: Error: selected processor does not support `incw x22'
/tmp/tmp1j2jbysw/kernel.s:420: Error: selected processor does not support `addvl x8,x8,#8'
/tmp/tmp1j2jbysw/kernel.s:438: Error: selected processor does not support `incw x25'
/tmp/tmp1j2jbysw/kernel.s:439: Error: selected processor does not support `addvl x20,x20,#1'
/tmp/tmp1j2jbysw/kernel.s:486: Error: selected processor does not support `index z6.s,#0,#1'
/tmp/tmp1j2jbysw/kernel.s:490: Error: selected processor does not support `ptrue p2.s'
/tmp/tmp1j2jbysw/kernel.s:513: Error: selected processor does not support `mov z0.s,w9'
/tmp/tmp1j2jbysw/kernel.s:514: Error: selected processor does not support `cmpgt p0.s,p2/z,z0.s,z6.s'
/tmp/tmp1j2jbysw/kernel.s:516: Error: selected processor does not support `ld1h {z0.s},p0/z,[x13,x10,lsl#1]'
/tmp/tmp1j2jbysw/kernel.s:519: Error: selected processor does not support `ld1h {z1.s},p0/z,[x11,x10,lsl#1]'
/tmp/tmp1j2jbysw/kernel.s:522: Error: unknown mnemonic `zero' -- `zero {za0.s}'
/tmp/tmp1j2jbysw/kernel.s:530: Error: operand 1 must be a list of SVE vector registers -- `ld1w {za0h.s[w12,0]},p0/z,[x13]'
/tmp/tmp1j2jbysw/kernel.s:536: Error: selected processor does not support `mov z2.s,w8'
/tmp/tmp1j2jbysw/kernel.s:539: Error: selected processor does not support `cmpgt p0.s,p2/z,z2.s,z6.s'
/tmp/tmp1j2jbysw/kernel.s:540: Error: selected processor does not support `mov z2.h,#0'
/tmp/tmp1j2jbysw/kernel.s:541: Error: selected processor does not support `mov z3.s,p0/z,#1'
/tmp/tmp1j2jbysw/kernel.s:554: Error: selected processor does not support `whilels p0.s,xzr,x11'
/tmp/tmp1j2jbysw/kernel.s:555: Error: selected processor does not support `lastb w13,p0,z3.s'
/tmp/tmp1j2jbysw/kernel.s:558: Error: selected processor does not support `mov z4.s,w11'
/tmp/tmp1j2jbysw/kernel.s:559: Error: selected processor does not support `cmpeq p0.s,p2/z,z6.s,z4.s'
/tmp/tmp1j2jbysw/kernel.s:561: Error: selected processor does not support `mov z2.h,p0/m,h4'
/tmp/tmp1j2jbysw/kernel.s:564: Error: selected processor does not support `mov z4.h,#0'
/tmp/tmp1j2jbysw/kernel.s:579: Error: selected processor does not support `whilels p0.s,xzr,x11'
/tmp/tmp1j2jbysw/kernel.s:580: Error: selected processor does not support `lastb w13,p0,z3.s'
/tmp/tmp1j2jbysw/kernel.s:583: Error: selected processor does not support `mov z5.s,w11'
/tmp/tmp1j2jbysw/kernel.s:584: Error: selected processor does not support `cmpeq p0.s,p2/z,z6.s,z5.s'
/tmp/tmp1j2jbysw/kernel.s:586: Error: selected processor does not support `mov z4.h,p0/m,h5'
/tmp/tmp1j2jbysw/kernel.s:589: Error: selected processor does not support `mov z3.s,w8'
/tmp/tmp1j2jbysw/kernel.s:590: Error: selected processor does not support `mov z5.s,w9'
/tmp/tmp1j2jbysw/kernel.s:593: Error: selected processor does not support `cmpgt p1.s,p2/z,z3.s,z6.s'
/tmp/tmp1j2jbysw/kernel.s:594: Error: selected processor does not support `cmpgt p0.s,p2/z,z5.s,z6.s'
/tmp/tmp1j2jbysw/kernel.s:595: Error: selected processor does not support `zip2 z3.s,z2.s,z4.s'
/tmp/tmp1j2jbysw/kernel.s:596: Error: selected processor does not support `zip1 z2.s,z2.s,z4.s'
/tmp/tmp1j2jbysw/kernel.s:597: Error: selected processor does not support `zip2 z4.s,z0.s,z1.s'
/tmp/tmp1j2jbysw/kernel.s:598: Error: selected processor does not support `zip1 z0.s,z0.s,z1.s'
/tmp/tmp1j2jbysw/kernel.s:600: Error: selected processor does not support `zip2 p2.s,p1.s,p1.s'
/tmp/tmp1j2jbysw/kernel.s:602: Error: selected processor does not support `zip1 p1.s,p1.s,p1.s'
/tmp/tmp1j2jbysw/kernel.s:603: Error: selected processor does not support `zip2 p3.s,p0.s,p0.s'
/tmp/tmp1j2jbysw/kernel.s:604: Error: selected processor does not support `uzp1 z1.h,z2.h,z3.h'
/tmp/tmp1j2jbysw/kernel.s:605: Error: selected processor does not support `uzp1 z0.h,z0.h,z4.h'
/tmp/tmp1j2jbysw/kernel.s:606: Error: selected processor does not support `zip1 p4.s,p0.s,p0.s'
/tmp/tmp1j2jbysw/kernel.s:607: Error: selected processor does not support `uzp1 p1.h,p1.h,p2.h'
/tmp/tmp1j2jbysw/kernel.s:608: Error: selected processor does not support `uzp1 p2.h,p4.h,p3.h'
/tmp/tmp1j2jbysw/kernel.s:609: Error: unknown mnemonic `fmopa' -- `fmopa za0.s,p1/m,p2/m,z1.h,z0.h'
/tmp/tmp1j2jbysw/kernel.s:617: Error: operand 1 must be a list of SVE vector registers -- `st1w {za0h.s[w12,0]},p0,[x13]'

I took a look at the sme_matmul.mlir and trying to figure out where these shared libs are

// %mcr_aarch64_cmd \
//-e=main -entry-point-result=void \
//-march=aarch64 -mattr="+sve,+sme" \
//-shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%arm_sme_abi_shlib | \
// RUN: FileCheck %s
banach-space commented 2 months ago

Generic Advice

It's best to run the tests as part of the build process of MLIR (or afterwards) and then to copy the build commands from tests. CMake flags to run the SME integration tests are documented here:

  -DMLIR_INCLUDE_INTEGRATION_TESTS=On
  -DMLIR_RUN_ARM_SME_TESTS=On
  -DARM_EMULATOR_EXECUTABLE=<path-to-emulator> 

Then, during/after the build, you can either run all the tests:

ninja check-mlir

or just selected integration tests:

cd <llvm-build-dir>
# Please adjust paths to match your system
bin/llvm-lit -va ../../mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/matmul.mlir

Note that I am using -va - this will make LIT print the RUN commands. You can extract what's needed from those RUN lines. I would use these as your reference commands.

I would make sure that these tests work for you before trying to run things manually.

Specific advice

As you have noted, the tests will contain sth like this:

// %mcr_aarch64_cmd \
//-e=main -entry-point-result=void \
//-march=aarch64 -mattr="+sve,+sme" \
//-shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%arm_sme_abi_shlib | \
// RUN: FileCheck %s

%mcr_aarch64_cmd is a convenience wrapper for mlir-cpu-runner:

This is important - it means that ^^^ defines flags to be passed to mlir-cpu-runner. However, you are passing these flags to qemu-aarch64-static:

        subprocess.check_call([ "/usr/bin/qemu-aarch64-static",llc_path, src_path, "--mattr=+sve,+sme","-o", dst_path])

That's incorrect and won't work :)

Now, it also looks like you are passing llc to qemu-aarch64-static (guessing based on llc_path above). That's not required (*) - llc is a driver for the LLVM backend that lowers LLVM IR to Machine Code.

Also, we don't really use llc for the integration tests. Instead, we rely on mlir-cpu-runner to drive that part of the compilation (the name is a bit confusing).

As for -march=aarch64 -mattr="+sve,+sme", those flags are passed to mlir-cpu-runner (i.e. %mcr_aarch64_cmd) - that's to inform the compilation pipeline (driven by mlir-cpu-runner) to target SVE.

Suggestion

  1. Try running SME integration tests in MLIR. This will give you a working reference.
  2. Share your build step and try to run your binary from command line rather than via Python. In particular, what is it that you are trying to run? An MLIR file? An LLVM IR file? A binary? What do you get at the end of your compilation?

HTH :) -Andrzej

(*) Unless you've cross-compiled it, but I highly doubt it.

danikhan632 commented 2 months ago

Generic Advice

It's best to run the tests as part of the build process of MLIR (or afterwards) and then to copy the build commands from tests. CMake flags to run the SME integration tests are documented here:

  -DMLIR_INCLUDE_INTEGRATION_TESTS=On
  -DMLIR_RUN_ARM_SME_TESTS=On
  -DARM_EMULATOR_EXECUTABLE=<path-to-emulator> 

Then, during/after the build, you can either run all the tests:

ninja check-mlir

or just selected integration tests:

cd <llvm-build-dir>
# Please adjust paths to match your system
bin/llvm-lit -va ../../mlir/test/Integration/Dialect/Linalg/CPU/ArmSVE/matmul.mlir

Note that I am using -va - this will make LIT print the RUN commands. You can extract what's needed from those RUN lines. I would use these as your reference commands.

I would make sure that these tests work for you before trying to run things manually.

Specific advice

As you have noted, the tests will contain sth like this:

// %mcr_aarch64_cmd \
//-e=main -entry-point-result=void \
//-march=aarch64 -mattr="+sve,+sme" \
//-shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%arm_sme_abi_shlib | \
// RUN: FileCheck %s

%mcr_aarch64_cmd is a convenience wrapper for mlir-cpu-runner:

This is important - it means that ^^^ defines flags to be passed to mlir-cpu-runner. However, you are passing these flags to qemu-aarch64-static:

        subprocess.check_call([ "/usr/bin/qemu-aarch64-static",llc_path, src_path, "--mattr=+sve,+sme","-o", dst_path])

That's incorrect and won't work :)

Now, it also looks like you are passing llc to qemu-aarch64-static (guessing based on llc_path above). That's not required (*) - llc is a driver for the LLVM backend that lowers LLVM IR to Machine Code.

Also, we don't really use llc for the integration tests. Instead, we rely on mlir-cpu-runner to drive that part of the compilation (the name is a bit confusing).

As for -march=aarch64 -mattr="+sve,+sme", those flags are passed to mlir-cpu-runner (i.e. %mcr_aarch64_cmd) - that's to inform the compilation pipeline (driven by mlir-cpu-runner) to target SVE.

Suggestion

  1. Try running SME integration tests in MLIR. This will give you a working reference.
  2. Share your build step and try to run your binary from command line rather than via Python. In particular, what is it that you are trying to run? An MLIR file? An LLVM IR file? A binary? What do you get at the end of your compilation?

HTH :) -Andrzej

(*) Unless you've cross-compiled it, but I highly doubt it.

got it, I'm trying to pass llir through llc to compile to binary. I think mlir-cpu-runner will be good for IR tests but looking to run this E2E.

Not sure if the MLIR CPU runner can do this


  -DARM_EMULATOR_EXECUTABLE=<path-to-emulator> 

also is this the instruction emulator?

https://developer.arm.com/Tools%20and%20Software/Arm%20Instruction%20Emulator

banach-space commented 2 months ago

also is this the instruction emulator?

ArmIE is one emulator, but based on the website it only support SVE and SVE2 (so no SME):

Arm Instruction Emulator (ArmIE) emulates Scalable Vector Extension (SVE) and SVE2 instructions on AArch64 platforms.

QEMU does support SME: https://qemu-project.gitlab.io/qemu/system/arm/cpu-features.html

Btw, I forgot to answer your other question:

I took a look at the sme_matmul.mlir and trying to figure out where these shared libs are

These are MLIR runtime libs - you will find them in the LLVM build directory under the lib directory. Also, note that:

:)

steplong commented 1 week ago

Hi, @danikhan632 I'm trying this PR out and I'm seeing a correctness issue where the result of the matmul doesn't match the expected output. I'm looking into it right now, but I was wondering if you were able to run the compiled test_matmul with QEMU.

danikhan632 commented 1 week ago

Hi, @danikhan632 I'm trying this PR out and I'm seeing a correctness issue where the result of the matmul doesn't match the expected output. I'm looking into it right now, but I was wondering if you were able to run the compiled test_matmul with QEMU.

Yeah I've been able to get it to work with some caveats. Only on ubuntu 22.04 not 20.04 Qemu must be built from source some more changes I need to push but yes I have confirmed this to work

steplong commented 1 week ago

Hi, @danikhan632 I'm trying this PR out and I'm seeing a correctness issue where the result of the matmul doesn't match the expected output. I'm looking into it right now, but I was wondering if you were able to run the compiled test_matmul with QEMU.

Yeah I've been able to get it to work with some caveats. Only on ubuntu 22.04 not 20.04 Qemu must be built from source some more changes I need to push but yes I have confirmed this to work

Could you share those changes? Right now, I modified the generated launcher.cpp to build as an executable and then linking with the generated kernel.o, passing two tensors to matmul_kernel, and then comparing the output to an expected result.

danikhan632 commented 1 week ago

Hi, @danikhan632 I'm trying this PR out and I'm seeing a correctness issue where the result of the matmul doesn't match the expected output. I'm looking into it right now, but I was wondering if you were able to run the compiled test_matmul with QEMU.

Yeah I've been able to get it to work with some caveats. Only on ubuntu 22.04 not 20.04 Qemu must be built from source some more changes I need to push but yes I have confirmed this to work

Could you share those changes? Right now, I modified the generated launcher.cpp to build as an executable and then linking with the generated kernel.o, passing two tensors to matmul_kernel, and then comparing the output to an expected result.

Sure , give me a sec to push those changes, wanted to know what your output looks like btw, like is it off by a small amount or larger?

steplong commented 1 week ago

Sure , give me a sec to push those changes, wanted to know what your output looks like btw, like is it off by a small amount or larger?

Hmm, a lot of the elements, that differ from the expected output, are 0, but some aren't.

danikhan632 commented 1 week ago

Sure , give me a sec to push those changes, wanted to know what your output looks like btw, like is it off by a small amount or larger?

Hmm, a lot of the elements, that differ from the expected output, are 0, but some aren't.

I don't the widening works as of now so fp16 -> fp32 doesn't work. Change accumulatior to fp16

steplong commented 5 days ago

Sure , give me a sec to push those changes, wanted to know what your output looks like btw, like is it off by a small amount or larger?

Hmm, a lot of the elements, that differ from the expected output, are 0, but some aren't.

I don't the widening works as of now so fp16 -> fp32 doesn't work. Change accumulatior to fp16

diff --git a/backend/compiler.py b/backend/compiler.py
index a0965e9..0535430 100644
--- a/backend/compiler.py
+++ b/backend/compiler.py
@@ -184,8 +184,8 @@ def _llir_to_bin(llir: str, metadata):
     assert len(matches) == 1
     metadata["name"] = matches[0]
     with tempfile.TemporaryDirectory() as tmpdir:
-        src_path = os.path.join(tmpdir, "kernel.ll")
-        dst_path = os.path.join(tmpdir, "kernel.o")
+        src_path = os.path.join(os.getcwd(), "kernel.ll")
+        dst_path = os.path.join(os.getcwd(), "kernel.o")
         Path(src_path).write_text(llir)
         llc_path = _get_llvm_bin_path("llc")
         if  _get_triton_SME_path() == "":
diff --git a/python/examples/test_matmul.py b/python/examples/test_matmul.py
index caa072c..7ef952d 100644
--- a/python/examples/test_matmul.py
+++ b/python/examples/test_matmul.py
@@ -85,14 +85,14 @@ def matmul_kernel(
     # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
     # of fp32 values for higher accuracy.
     # `accumulator` will be converted back to fp16 after the loop.
-    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float16)
     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
         # Load the next block of A and B, generate a mask by checking the K dimension.
         # If it is out of bounds, set it to 0.
         a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
         b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
         # We accumulate along the K dimension.
-        accumulator += tl.dot(a, b)
+        accumulator += tl.dot(a, b).to(tl.float16)
         # Advance the ptrs to the next K block.
         a_ptrs += BLOCK_SIZE_K * stride_ak
         b_ptrs += BLOCK_SIZE_K * stride_bk
@@ -100,7 +100,7 @@ def matmul_kernel(
     # while the accumulator is still in FP32!
     if ACTIVATION == "leaky_relu":
         accumulator = leaky_relu(accumulator)
-    c = accumulator.to(tl.float32)
+    c = accumulator.to(tl.float16)

This is the change I'm trying and I'm not seeing any changes in the output.

danikhan632 commented 3 days ago

Sure , give me a sec to push those changes, wanted to know what your output looks like btw, like is it off by a small amount or larger?

Hmm, a lot of the elements, that differ from the expected output, are 0, but some aren't.

I don't the widening works as of now so fp16 -> fp32 doesn't work. Change accumulatior to fp16


diff --git a/backend/compiler.py b/backend/compiler.py

index a0965e9..0535430 100644

--- a/backend/compiler.py

+++ b/backend/compiler.py

@@ -184,8 +184,8 @@ def _llir_to_bin(llir: str, metadata):

     assert len(matches) == 1

     metadata["name"] = matches[0]

     with tempfile.TemporaryDirectory() as tmpdir:

-        src_path = os.path.join(tmpdir, "kernel.ll")

-        dst_path = os.path.join(tmpdir, "kernel.o")

+        src_path = os.path.join(os.getcwd(), "kernel.ll")

+        dst_path = os.path.join(os.getcwd(), "kernel.o")

         Path(src_path).write_text(llir)

         llc_path = _get_llvm_bin_path("llc")

         if  _get_triton_SME_path() == "":

diff --git a/python/examples/test_matmul.py b/python/examples/test_matmul.py

index caa072c..7ef952d 100644

--- a/python/examples/test_matmul.py

+++ b/python/examples/test_matmul.py

@@ -85,14 +85,14 @@ def matmul_kernel(

     # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block

     # of fp32 values for higher accuracy.

     # `accumulator` will be converted back to fp16 after the loop.

-    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float16)

     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):

         # Load the next block of A and B, generate a mask by checking the K dimension.

         # If it is out of bounds, set it to 0.

         a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)

         b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)

         # We accumulate along the K dimension.

-        accumulator += tl.dot(a, b)

+        accumulator += tl.dot(a, b).to(tl.float16)

         # Advance the ptrs to the next K block.

         a_ptrs += BLOCK_SIZE_K * stride_ak

         b_ptrs += BLOCK_SIZE_K * stride_bk

@@ -100,7 +100,7 @@ def matmul_kernel(

     # while the accumulator is still in FP32!

     if ACTIVATION == "leaky_relu":

         accumulator = leaky_relu(accumulator)

-    c = accumulator.to(tl.float32)

+    c = accumulator.to(tl.float16)

This is the change I'm trying and I'm not seeing any changes in the output.

Ok let me try and fix that, broke my env so taking me longer than it should

danikhan632 commented 2 days ago

Sure , give me a sec to push those changes, wanted to know what your output looks like btw, like is it off by a small amount or larger?

Hmm, a lot of the elements, that differ from the expected output, are 0, but some aren't.

I don't the widening works as of now so fp16 -> fp32 doesn't work. Change accumulatior to fp16

diff --git a/backend/compiler.py b/backend/compiler.py
index a0965e9..0535430 100644
--- a/backend/compiler.py
+++ b/backend/compiler.py
@@ -184,8 +184,8 @@ def _llir_to_bin(llir: str, metadata):
     assert len(matches) == 1
     metadata["name"] = matches[0]
     with tempfile.TemporaryDirectory() as tmpdir:
-        src_path = os.path.join(tmpdir, "kernel.ll")
-        dst_path = os.path.join(tmpdir, "kernel.o")
+        src_path = os.path.join(os.getcwd(), "kernel.ll")
+        dst_path = os.path.join(os.getcwd(), "kernel.o")
         Path(src_path).write_text(llir)
         llc_path = _get_llvm_bin_path("llc")
         if  _get_triton_SME_path() == "":
diff --git a/python/examples/test_matmul.py b/python/examples/test_matmul.py
index caa072c..7ef952d 100644
--- a/python/examples/test_matmul.py
+++ b/python/examples/test_matmul.py
@@ -85,14 +85,14 @@ def matmul_kernel(
     # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
     # of fp32 values for higher accuracy.
     # `accumulator` will be converted back to fp16 after the loop.
-    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float16)
     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
         # Load the next block of A and B, generate a mask by checking the K dimension.
         # If it is out of bounds, set it to 0.
         a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
         b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
         # We accumulate along the K dimension.
-        accumulator += tl.dot(a, b)
+        accumulator += tl.dot(a, b).to(tl.float16)
         # Advance the ptrs to the next K block.
         a_ptrs += BLOCK_SIZE_K * stride_ak
         b_ptrs += BLOCK_SIZE_K * stride_bk
@@ -100,7 +100,7 @@ def matmul_kernel(
     # while the accumulator is still in FP32!
     if ACTIVATION == "leaky_relu":
         accumulator = leaky_relu(accumulator)
-    c = accumulator.to(tl.float32)
+    c = accumulator.to(tl.float16)

This is the change I'm trying and I'm not seeing any changes in the output.

I've had issues recreating this behavior, could you reach out to danikhan632@gmail.com with more details?