microsoft / triton-shared

Shared Middle-Layer for Triton Compilation
MIT License

[StructuredToMemref] Only use `memref.extract_strided_metadata` in loops #123

Closed. nhat-nguyen closed this 3 months ago.

nhat-nguyen commented 3 months ago

Intro

We currently lower a sequence of scalar tt.addptr ops using memref.extract_strided_metadata and memref.reinterpret_cast. To improve codegen in certain cases, this PR changes the lowering to use only memref.reinterpret_cast for tt.addptr ops outside of loops. For tt.addptr ops inside loops, memref.extract_strided_metadata is still used so that the pointer offsets do not have to be carried in the loops' iter args.

Background

In triton-to-linalg-experimental, when lowering %new_ptr = tt.addptr %base_ptr, %add_offset for scalars, %base_ptr is lowered to a memref using a memref.reinterpret_cast op with some offset %curr_offset. If %base_ptr comes from the function arguments, we assume that %curr_offset is 0. %new_ptr is then lowered to another memref using memref.reinterpret_cast with offset %curr_offset + %add_offset.

With this approach, each memref.reinterpret_cast lowered from a tt.addptr has to know the current offset of %base_ptr and add %add_offset to it to compute the new offset.
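To make that bookkeeping concrete, here is a toy Python model of the offset-tracking approach. It assumes integer offsets for simplicity (in the real lowering they are SSA index values), and LoweredPtr, lower_func_arg, and lower_addptr are hypothetical names, not part of triton-shared.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class LoweredPtr:
    """Hypothetical record of a lowered scalar pointer: the memref it is
    cast from and the offset the lowering has to remember for it."""
    source: str
    offset: int


def lower_func_arg(name: str) -> LoweredPtr:
    # Pointers coming from function arguments are assumed to start at offset 0.
    return LoweredPtr(source=name, offset=0)


def lower_addptr(base: LoweredPtr, add_offset: int) -> LoweredPtr:
    # The new reinterpret_cast must know the base pointer's current offset
    # and add the new offset to it.
    return LoweredPtr(source=base.source, offset=base.offset + add_offset)


p0 = lower_func_arg("%arg0")
p1 = lower_addptr(p0, 4)  # %p1 = tt.addptr %p0, 4
p2 = lower_addptr(p1, 3)  # %p2 = tt.addptr %p1, 3
print(p2)  # LoweredPtr(source='%arg0', offset=7)
```

Every lowered pointer has to carry its offset along with it, which is the state that becomes awkward to thread through loops.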

Alternatively, memref.extract_strided_metadata can be used to extract the current offset from a memref, so the pointer addition becomes a simple arith.addi.
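A sketch of this alternative, assuming a hypothetical emitter that prints the abbreviated IR notation used in this PR ("with offset ...", most types omitted). Each lowering reads the offset back from the memref instead of consulting a side table; lower_addptr_via_metadata and the %baseN / %ptrN naming scheme are illustrative only.

```python
import itertools

# Hypothetical SSA-name counter so repeated lowerings do not reuse names.
_ids = itertools.count()


def lower_addptr_via_metadata(base_memref: str, add_offset: str) -> tuple[str, list[str]]:
    """Emit abbreviated IR for a scalar tt.addptr using extract_strided_metadata.

    No offset bookkeeping is needed: the current offset is read back
    from the memref itself.
    """
    n = next(_ids)
    base, off, size, stride = f"%base{n}", f"%offset{n}", f"%size{n}", f"%stride{n}"
    new_off, result = f"%new_offset{n}", f"%ptr{n}"
    lines = [
        f"{base}, {off}, {size}, {stride} = memref.extract_strided_metadata {base_memref}",
        f"{new_off} = arith.addi {off}, {add_offset}",
        f"{result} = memref.reinterpret_cast {base} with offset {new_off}",
    ]
    return result, lines


# Two chained addptrs: the second only needs the memref produced by the first.
p, ir1 = lower_addptr_via_metadata("%0", "%a")
q, ir2 = lower_addptr_via_metadata(p, "%b")
print("\n".join(ir1 + ir2))
```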

This is particularly helpful when lowering tt.addptr inside loops: the buffer offsets do not have to be carried as loop-carried values (iter args), because they can be retrieved at any point by calling memref.extract_strided_metadata.
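The loop case can be modeled in a few lines. This toy simulation assumes a memref is just a (base, offset) pair; MemRef, extract_strided_metadata, reinterpret_cast, and advance_in_loop are hypothetical Python stand-ins for the MLIR ops, not real bindings. The loop carries only the memref, because its offset can be recovered on each iteration.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class MemRef:
    """Toy rank-1 memref descriptor: a base buffer name plus an element offset."""
    base: str
    offset: int


def extract_strided_metadata(m: MemRef) -> tuple[str, int]:
    # Toy analogue of memref.extract_strided_metadata (sizes/strides omitted).
    return m.base, m.offset


def reinterpret_cast(base: str, offset: int) -> MemRef:
    # Toy analogue of memref.reinterpret_cast: same buffer, new offset.
    return MemRef(base, offset)


def advance_in_loop(ptr: MemRef, step: int, iters: int) -> MemRef:
    # `ptr` is the loop's only carried value; its offset is recovered on
    # every iteration instead of being carried as a second iter arg.
    for _ in range(iters):
        base, offset = extract_strided_metadata(ptr)
        ptr = reinterpret_cast(base, offset + step)
    return ptr


final = advance_in_loop(MemRef("%arg0", 0), step=2, iters=5)
print(final)  # MemRef(base='%arg0', offset=10)
```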

Improvements

With the current approach, the generated code can be suboptimal, and it can keep certain canonicalization-based optimizations from kicking in. For example:

%0 = memref.reinterpret_cast %arg0 with offset 0
%base, %offset, %size, %stride = memref.extract_strided_metadata %0 : memref<1xf32> -> memref<f32>, index, index, index
%new_offset = arith.addi %offset, %add_offset
%1 = memref.reinterpret_cast %base with offset %new_offset

Because %0 was cast with offset 0, %offset is always 0, so this can be simplified to:

%1 = memref.reinterpret_cast %base with offset %add_offset

When lowering a tt.addptr that is not inside any loop, we can get the current offset of the buffer by looking at the buffer's defining op, which must be a memref.reinterpret_cast. For the example above, we therefore end up with the following code:

%c0 = arith.constant 0 : index
%0 = memref.reinterpret_cast %arg0 with offset 0
%new_offset = arith.addi %c0, %add_offset
%1 = memref.reinterpret_cast %arg0 with offset %new_offset

After canonicalization, we get:

%1 = memref.reinterpret_cast %arg0 with offset %add_offset
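The out-of-loop lowering and the follow-up fold can be sketched together. This is a toy Python model over the same abbreviated IR text, assuming the base pointer's defining reinterpret_cast is tracked in a dictionary and that %c0 names the zero offset of a function-argument pointer. ReinterpretCast, lower_scalar_addptr, and fold_add_zero are hypothetical; fold_add_zero imitates only the one rewrite that matters here, not MLIR's canonicalizer.

```python
import re
from dataclasses import dataclass


@dataclass
class ReinterpretCast:
    """Hypothetical record of a memref.reinterpret_cast: source memref and offset operand."""
    source: str
    offset: str


def lower_scalar_addptr(result: str, base_ptr: str, add_offset: str,
                        defs: dict[str, ReinterpretCast], in_loop: bool) -> list[str]:
    """Emit abbreviated IR for `result = tt.addptr base_ptr, add_offset`."""
    new_off = f"{result}_offset"
    if in_loop:
        # Inside loops: read the current offset back from the memref.
        base, cur = f"{result}_base", f"{result}_cur"
        defs[result] = ReinterpretCast(base, new_off)
        return [
            f"{base}, {cur}, {result}_size, {result}_stride = "
            f"memref.extract_strided_metadata {base_ptr}",
            f"{new_off} = arith.addi {cur}, {add_offset}",
            f"{result} = memref.reinterpret_cast {base} with offset {new_off}",
        ]
    # Outside loops: the base must be defined by a reinterpret_cast,
    # so its offset operand is reused directly.
    cast = defs.get(base_ptr)
    if cast is None:
        raise ValueError(f"{base_ptr} is not defined by memref.reinterpret_cast")
    defs[result] = ReinterpretCast(cast.source, new_off)
    return [
        f"{new_off} = arith.addi {cast.offset}, {add_offset}",
        f"{result} = memref.reinterpret_cast {cast.source} with offset {new_off}",
    ]


ADDI = re.compile(r"^(%\w+) = arith\.addi (%\w+), (%\w+)$")


def fold_add_zero(lines: list[str], zeros: set[str]) -> list[str]:
    """Toy fold: drop `x = arith.addi z, y` when z is a known zero, then rename uses of x to y."""
    renames: dict[str, str] = {}
    kept = []
    for line in lines:
        m = ADDI.match(line)
        if m and (m.group(2) in zeros or m.group(3) in zeros):
            res, lhs, rhs = m.groups()
            renames[res] = rhs if lhs in zeros else lhs
            continue
        kept.append(line)
    out = []
    for line in kept:
        for old, new in renames.items():
            line = re.sub(re.escape(old) + r"\b", new, line)
        out.append(line)
    return out


defs = {"%p0": ReinterpretCast("%arg0", "%c0")}  # %p0 lowered from a function argument
ir = lower_scalar_addptr("%p1", "%p0", "%add_offset", defs, in_loop=False)
for line in fold_add_zero(ir, zeros={"%c0"}):
    print(line)  # %p1 = memref.reinterpret_cast %arg0 with offset %add_offset
```

With in_loop=True the same function falls back to memref.extract_strided_metadata, mirroring the split this PR describes; in that case the offset is not a known constant, so the toy fold leaves the IR unchanged.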