
Rework StructuredToMemref pass to no longer use memref.extract_strided_metadata #140


nhat-nguyen commented 3 months ago

Motivation

tt.addptr produces scalar pointers, which are kept intact by the pointer analysis pass. In the latest fused-attention kernel, block pointers take the result of a tt.addptr as their base pointer. This behaviour contradicts the current assumption that block pointers always take their base pointers directly from the kernel arguments; it exposes a bug where we ignore the offset of the buffer (because pointers coming from kernel arguments are assumed to have offset 0).
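For illustration, the pattern looks roughly like this in Triton IR (a sketch; the value names, shapes, and strides here are made up):

  // Hypothetical: the block pointer's base is itself a tt.addptr result
  // rather than a raw kernel argument.
  %base = tt.addptr %arg0, %batch_offset : !tt.ptr<f16>, i64
  %blk_ptr = tt.make_tensor_ptr %base, [%c128_i64, %c64_i64], [%c64_i64, %c1_i64],
      [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<128x64xf16>>

Here %base already carries a non-zero offset relative to %arg0, which the old assumption silently dropped.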

Additionally, even though we already lower scalar tt.addptr to a pair of memref.reinterpret_cast and memref.extract_strided_metadata ops, lowering memref.extract_strided_metadata for Microsoft Maia is rather tricky because we don't have primitives to keep track of the current offset of a buffer.

To fix both issues, I have reworked the StructuredToMemref pass to no longer use memref.extract_strided_metadata.

Background

Lowering a sequence of tt.addptr to memref.reinterpret_cast is tricky because memref.reinterpret_cast does not "remember" the offset of the input buffer. Currently, we leverage memref.extract_strided_metadata to carry the offset forward and simplify the lowering of tt.addptr in loops.
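As a rough sketch of that approach (the op names are real, the value names and types assumed), each step re-extracts the offset accumulated so far before applying the next one:

  // Recover the offset baked into the previous cast.
  %base, %offset, %size, %stride = memref.extract_strided_metadata %cast_prev
      : memref<1xf32, strided<[1], offset: ?>> -> memref<f32>, index, index, index
  %offset_new = arith.addi %offset, %step : index
  %cast = memref.reinterpret_cast %base to offset: [%offset_new], sizes: [1], strides: [1]
      : memref<f32> to memref<1xf32, strided<[1], offset: ?>>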

Ideally, each result produced by tt.addptr should be converted to a pair of memref and index values that keeps track of both the buffer and its current offset. Fortunately, implementing this approach is much simpler with the introduction of the 1->N type conversion infrastructure. This PR removes the usage of memref.extract_strided_metadata and computes the offsets directly.

Technical details

We leverage the 1->N conversion infrastructure to convert scalar tt.addptr ops into memref.reinterpret_cast.

A tt.addptr has the following form:

 %new_ptr = tt.addptr %ptr, %offset

where %new_ptr and %ptr have tt.ptr type, and %offset is an integer offset.
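Concretely, with the types spelled out (assumed here for illustration):

  %new_ptr = tt.addptr %ptr, %offset : !tt.ptr<f32>, i64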

With this form, there can be a chain of tt.addptr where we keep adding offsets to an existing pointer:

 %ptr_1 = tt.addptr %ptr_0, %offset
 %ptr_2 = tt.addptr %ptr_1, %offset
 %ptr_3 = tt.addptr %ptr_2, %offset

Now, we want to lower each tt.addptr to a memref.reinterpret_cast so that the pointers can be used by affine.load and affine.store (lowered from tt.load and tt.store).

A memref.reinterpret_cast op also takes an offset and returns a memref in a similar fashion to tt.addptr:

  %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [%offset], sizes: [1], strides: [1]
      : memref<*xf32> to memref<1xf32, strided<[1], offset: ?>>
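A load or store lowered from tt.load or tt.store can then address the cast result directly, for example (value names assumed):

  %val = affine.load %reinterpret_cast[0] : memref<1xf32, strided<[1], offset: ?>>
  affine.store %val, %reinterpret_cast[0] : memref<1xf32, strided<[1], offset: ?>>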

However, since the semantics of memref.reinterpret_cast differ, the following lowering (sizes, strides, and types omitted for brevity) would be incorrect for the sequence of tt.addptr above:

  %cast_1 = memref.reinterpret_cast %arg0 to offset: [%offset]
  %cast_2 = memref.reinterpret_cast %cast_1 to offset: [%offset]
  %cast_3 = memref.reinterpret_cast %cast_2 to offset: [%offset]

The above sequence is equivalent to:

  %cast_1 = memref.reinterpret_cast %arg0 to offset: [%offset]
  %cast_2 = memref.reinterpret_cast %arg0 to offset: [%offset]
  %cast_3 = memref.reinterpret_cast %arg0 to offset: [%offset]

In other words, memref.reinterpret_cast ignores the current offset of the input buffer: the offset it applies is always relative to the underlying base allocation, not to the view it receives.

Therefore, we have to manually track the offset for each tt.addptr by lowering to the following form, where %cst_0 is the initial offset 0:

 %offset_1 = arith.addi %cst_0, %offset
 %cast_1 = memref.reinterpret_cast %arg0 to offset: [%offset_1]

 %offset_2 = arith.addi %offset_1, %offset
 %cast_2 = memref.reinterpret_cast %arg0 to offset: [%offset_2]

 %offset_3 = arith.addi %offset_2, %offset
 %cast_3 = memref.reinterpret_cast %arg0 to offset: [%offset_3]

Each tt.addptr is thus lowered to a pair of ops: an arith.addi that accumulates the running offset, followed by a memref.reinterpret_cast that applies the accumulated offset to the original base buffer.
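This (memref, index) pair representation also composes with loops, which is where the old memref.extract_strided_metadata approach was needed most. A sketch (value names and the 1-element view assumed) of a loop-carried pointer after conversion:

  // The single !tt.ptr iter_arg becomes a (buffer, offset) pair.
  %res:2 = scf.for %i = %c0 to %n step %c1
      iter_args(%buf = %cast_0, %off = %offset_0)
      -> (memref<1xf32, strided<[1], offset: ?>>, index) {
    // Loads and stores in the body would use %buf.
    %off_next = arith.addi %off, %step_offset : index
    %cast_next = memref.reinterpret_cast %arg0 to offset: [%off_next], sizes: [1], strides: [1]
        : memref<*xf32> to memref<1xf32, strided<[1], offset: ?>>
    scf.yield %cast_next, %off_next : memref<1xf32, strided<[1], offset: ?>>, index
  }

Because the offset is carried explicitly as an index value, no metadata extraction is needed inside the loop body.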