We currently lower a sequence of `tt.addptr` ops for scalars using `memref.extract_strided_metadata` and `memref.reinterpret_cast`. To improve codegen in certain cases, this PR changes the lowering to use just `memref.reinterpret_cast` for `tt.addptr` outside of loops. For `tt.addptr` inside loops, `memref.extract_strided_metadata` is still used to avoid having to carry the pointer offset in the loops' iter args.
## Background
In `triton-to-linalg-experimental`, when lowering `%new_ptr = tt.addptr %base_ptr, %add_offset` for scalars, `%base_ptr` is lowered to a memref using a `memref.reinterpret_cast` op with a certain offset `%curr_offset`. If `%base_ptr` comes from the function arguments, we assume that `%curr_offset` is `0`. Then `%new_ptr` is lowered to another memref using `memref.reinterpret_cast` with offset `%curr_offset + %add_offset`.
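Concretely, the lowering described above looks roughly like this (illustrative IR only; the sizes, strides, and types are abbreviated and not taken from the actual pass output):

```mlir
// %base_ptr, already lowered to a memref at the current offset
%base = memref.reinterpret_cast %arg0 to offset: [%curr_offset], sizes: [1], strides: [1]
// %new_ptr = tt.addptr %base_ptr, %add_offset becomes another cast at the summed offset
%new_offset = arith.addi %curr_offset, %add_offset : index
%new = memref.reinterpret_cast %arg0 to offset: [%new_offset], sizes: [1], strides: [1]
```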
With this approach, each `memref.reinterpret_cast` that is lowered from a `tt.addptr` has to know the current offset of `%base_ptr` and perform a pointer addition to get the new offset.
`memref.extract_strided_metadata` can be used to extract the offset from a memref. Then the pointer addition can be done by a simple `arith.addi`.
Doing so is particularly helpful when lowering `tt.addptr` in loops, because we don't have to carry the buffer offsets as extra loop-carried iter args -- the offsets can be retrieved just by calling `memref.extract_strided_metadata`.
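As an illustration, inside an `scf.for` the only loop-carried value needs to be the memref itself; the offset is recovered on each iteration (a sketch only, types abbreviated and names hypothetical):

```mlir
%res = scf.for %i = %lb to %ub step %c1 iter_args(%cur = %init) -> (memref<1xf32>) {
  // recover the current offset from the memref instead of carrying it as an extra iter arg
  %base, %offset, %size, %stride = memref.extract_strided_metadata %cur
  %next_offset = arith.addi %offset, %add_offset : index
  %next = memref.reinterpret_cast %arg0 to offset: [%next_offset], sizes: [1], strides: [1]
  scf.yield %next : memref<1xf32>
}
```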
## Improvements
With the current approach, the generated code might not be optimal and prevents certain canonicalization-based optimizations from kicking in:
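For instance, a single `tt.addptr` on a function argument currently lowers to something like (an illustrative sketch based on the Background section, types abbreviated):

```mlir
%0 = memref.reinterpret_cast %arg0 to offset: [0], sizes: [1], strides: [1]
%base, %offset, %size, %stride = memref.extract_strided_metadata %0
%new_offset = arith.addi %offset, %add_offset : index
%1 = memref.reinterpret_cast %arg0 to offset: [%new_offset], sizes: [1], strides: [1]
```

this can actually be simplified to: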
```mlir
%1 = memref.reinterpret_cast %arg0 with offset %add_offset
```
When lowering a `tt.addptr` that is not inside any loop, we can get the current offset of the buffer by looking at the defining op of the buffer, which must be a `memref.reinterpret_cast`. So, in the example, we will end up with the following code:
```mlir
%0 = memref.reinterpret_cast %arg0 with offset 0
%new_offset = arith.addi 0, %add_offset  // offset 0 read off the defining cast of %0
%1 = memref.reinterpret_cast %arg0 with offset %new_offset
```
After canonicalization, we will get:
```mlir
%1 = memref.reinterpret_cast %arg0 with offset %add_offset
```