microsoft / triton-shared

Shared Middle-Layer for Triton Compilation
MIT License

Introduce `structured-to-memref` pass #102

Closed nhat-nguyen closed 4 months ago

nhat-nguyen commented 4 months ago

This PR adds the final pass structured-to-memref as outlined in the discussion on making the triton-to-linalg pass more modular: https://github.com/microsoft/triton-shared/discussions/81

Details

The pass expects to run after triton-to-structured and triton-arith-to-linalg, and lowers all ops in the TritonStructured dialect to memref ops.

Differences

Overall, the resulting IR is cleaner than what the original monolithic triton-to-linalg pass produced, with very few changes aside from the handling of scalar loads and stores.

Since triton-to-structured leaves tt.addptr, tt.load, and tt.store untouched for unstructured pointers and scalars, this pass also handles the scalar cases through a combination of memref.reinterpret_cast and memref.extract_strided_metadata, as seen below:

  func.func @kernel(%arg0: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}, %arg1: i32, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32) {
    %c8_i32 = arith.constant 8 : i32
    %c0_i32 = arith.constant 0 : i32
    %c1_i32 = arith.constant 1 : i32
    %0 = scf.for %arg7 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg8 = %arg0) -> (!tt.ptr<f32, 1>)  : i32 {
      %1 = arith.sitofp %arg7 : i32 to f32
      tt.store %arg8, %1 {cache = 1 : i32, evict = 1 : i32} : f32
      %2 = tt.addptr %arg8, %c1_i32 : !tt.ptr<f32, 1>, i32
      scf.yield %2 : !tt.ptr<f32, 1>
    }
    return
  }

will become

  func.func @kernel(%arg0: memref<*xf32> {tt.divisibility = 16 : i32}, %arg1: i32, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32) {
    %c1 = arith.constant 1 : index
    %c1_i32 = arith.constant 1 : i32
    %c0_i32 = arith.constant 0 : i32
    %c8_i32 = arith.constant 8 : i32
    %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [0], sizes: [1], strides: [1] : memref<*xf32> to memref<1xf32, strided<[1], offset: ?>>
    %0 = scf.for %arg7 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg8 = %reinterpret_cast) -> (memref<1xf32, strided<[1], offset: ?>>)  : i32 {
      %1 = arith.sitofp %arg7 : i32 to f32
      affine.store %1, %arg8[0] : memref<1xf32, strided<[1], offset: ?>>
      %base_buffer, %offset, %sizes, %strides = memref.extract_strided_metadata %arg8 : memref<1xf32, strided<[1], offset: ?>> -> memref<f32>, index, index, index
      %2 = arith.addi %offset, %c1 : index
      %reinterpret_cast_0 = memref.reinterpret_cast %base_buffer to offset: [%2], sizes: [1], strides: [1] : memref<f32> to memref<1xf32, strided<[1], offset: ?>>
      scf.yield %reinterpret_cast_0 : memref<1xf32, strided<[1], offset: ?>>
    }
    return
  }
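The lowering above turns scalar pointer arithmetic into bookkeeping on a (base, offset, sizes, strides) tuple: tt.addptr becomes an extract_strided_metadata to recover the base and offset, an index add, and a reinterpret_cast with the new offset. A minimal Python sketch of that same bookkeeping (all names here are illustrative, not part of triton-shared):

```python
# Sketch of the (base, offset, sizes, strides) bookkeeping that
# memref.extract_strided_metadata + memref.reinterpret_cast perform above.
from dataclasses import dataclass

@dataclass
class StridedMemref:
    base: list       # underlying buffer (the memref<f32> base_buffer)
    offset: int      # element offset into the buffer
    sizes: tuple
    strides: tuple

def extract_strided_metadata(m):
    # memref.extract_strided_metadata: split a view into its components
    return m.base, m.offset, m.sizes, m.strides

def reinterpret_cast(base, offset, sizes, strides):
    # memref.reinterpret_cast: rebuild a view from components
    return StridedMemref(base, offset, sizes, strides)

def addptr(m, delta):
    # Scalar tt.addptr: extract metadata, bump the offset, re-cast
    base, offset, _, _ = extract_strided_metadata(m)
    return reinterpret_cast(base, offset + delta, (1,), (1,))

buf = [0.0] * 8
ptr = reinterpret_cast(buf, 0, (1,), (1,))
for i in range(8):                     # mirrors the scf.for loop above
    ptr.base[ptr.offset] = float(i)    # affine.store
    ptr = addptr(ptr, 1)               # tt.addptr lowering
print(buf)  # -> [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
```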

How to use

With this pass, the whole pipeline for lowering triton IR to linalg is:

--triton-to-structured --canonicalize --triton-arith-to-linalg --structured-to-memref
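For example, with an opt-style driver built from this repo (the `triton-shared-opt` tool name and the `kernel.mlir` input file here are assumptions for illustration):

```shell
# Hypothetical invocation; tool and file names are assumptions.
triton-shared-opt --triton-to-structured --canonicalize \
  --triton-arith-to-linalg --structured-to-memref kernel.mlir
```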

Testing

lit tests

I have added all the existing lit tests from triton-to-linalg and manually verified that the outputs make sense.

CPU backend testing

I have updated our CPU backend to use the new pipeline and verified that all tests ran correctly.

nhat-nguyen commented 4 months ago

@haishanzzzz I still can't add you as a reviewer so tagging you here instead.

haishanzzzz commented 4 months ago

> @haishanzzzz I still can't add you as a reviewer so tagging you here instead.

Thank you @nhat-nguyen. For some reason GitHub told me the invite didn't take effect. Will review early next week!