Motivation
tt.addptr produces scalar pointers, which the pointer analysis pass keeps intact. In the latest fused-attention model, block pointers take the result of a tt.addptr as the base pointer. This behaviour contradicts the current assumption that block pointers always take their base pointers directly from the kernel arguments; it exposes a bug where we ignore the offset of the buffer (because pointers from kernel arguments are assumed to have offset 0).
Additionally, even though we already lower tt.addptr with scalars to a pair of memref.reinterpret_cast and memref.extract_strided_metadata, lowering memref.extract_strided_metadata to work with Microsoft Maia is rather tricky because we don't have primitives to keep track of the current offset of a buffer.
To fix both issues, I have reworked the StructuredToMemref pass to no longer use memref.extract_strided_metadata.
Background
Lowering a sequence of tt.addptr to memref.reinterpret_cast is tricky because memref.reinterpret_cast does not "remember" the offset of the input buffer. Currently, we leverage memref.extract_strided_metadata to carry the offset over and simplify the lowering of tt.addptr in loops.
Ideally, each result produced by tt.addptr needs to be converted to a pair of memref and index values that keep track of the buffer and its current offset, respectively. Fortunately, implementing this approach is much simpler with the introduction of the 1->N type conversion infrastructure. This PR removes the usage of memref.extract_strided_metadata and computes the offsets directly.
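As an illustrative sketch (plain Python modeling the idea, not the actual pass), representing each scalar pointer as a (buffer, offset) pair makes a chain of tt.addptr ops trivial to track:

```python
# Hypothetical Python model of the 1->N conversion: a scalar tt.ptr value
# becomes a (buffer, offset) pair, and tt.addptr simply adds to the offset.
def addptr(ptr, delta):
    buffer, offset = ptr
    return (buffer, offset + delta)

# A chain of three addptr ops starting from a kernel argument at offset 0.
p = ("arg0", 0)
for _ in range(3):
    p = addptr(p, 4)

print(p)  # ('arg0', 12)
```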
Technical details
We leverage the 1->N conversion infrastructure to convert tt.addptr on
scalar pointers to memref.reinterpret_cast.
A tt.addptr has the following form:
%new_ptr = tt.addptr %ptr %offset
where %new_ptr and %ptr have tt.ptr type, and %offset is of index type.
With this form, there can be a chain of tt.addptr ops where we keep adding
offsets to an existing pointer:
%ptr_1 = tt.addptr %arg0 %offset
%ptr_2 = tt.addptr %ptr_1 %offset
%ptr_3 = tt.addptr %ptr_2 %offset
Now, we want to lower each tt.addptr to a memref.reinterpret_cast so that
the pointers can be used by affine.load and affine.store (lowered from
tt.load and tt.store).
A memref.reinterpret_cast op also takes an offset and returns a memref in a
similar fashion to tt.addptr:
%reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [0], sizes: [1], strides: [1] : memref<*xf32> to memref<1xf32, strided<[1], offset: ?>>
However, since the semantics of memref.reinterpret_cast are different,
the following lowering would be incorrect for the sequence of tt.addptr
above:
%cast_1 = memref.reinterpret_cast %arg0 to offset [%offset]
%cast_2 = memref.reinterpret_cast %cast_1 to offset [%offset]
%cast_3 = memref.reinterpret_cast %cast_2 to offset [%offset]
The above sequence is equivalent to:
%cast_1 = memref.reinterpret_cast %arg0 to offset [%offset]
%cast_2 = memref.reinterpret_cast %arg0 to offset [%offset]
%cast_3 = memref.reinterpret_cast %arg0 to offset [%offset]
In other words, memref.reinterpret_cast ignores the current offset of the
input buffer.
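To make the difference concrete, here is a small Python model (an illustration only, not real MLIR semantics) contrasting the two behaviours: reinterpret_cast always re-offsets from the underlying allocation, so chaining casts does not accumulate offsets the way chained tt.addptr ops do:

```python
# Model: a view is (base, offset). reinterpret_cast applies its offset to
# the base allocation, ignoring the input view's current offset.
def reinterpret_cast(view, offset):
    base, _ignored = view  # the incoming view's offset is dropped
    return (base, offset)

# tt.addptr, by contrast, adds to the current offset.
def addptr(view, offset):
    base, cur = view
    return (base, cur + offset)

v = ("arg0", 0)
casted = reinterpret_cast(reinterpret_cast(v, 4), 4)
added = addptr(addptr(v, 4), 4)
print(casted, added)  # ('arg0', 4) ('arg0', 8)
```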
Therefore, we have to manually track the offset for each tt.addptr by
lowering to the following form:
%c0 = arith.constant 0 : index
%offset_1 = arith.addi %c0, %offset
%cast_1 = memref.reinterpret_cast %arg0 to offset [%offset_1]
%offset_2 = arith.addi %offset_1, %offset
%cast_2 = memref.reinterpret_cast %arg0 to offset [%offset_2]
%offset_3 = arith.addi %offset_2, %offset
%cast_3 = memref.reinterpret_cast %arg0 to offset [%offset_3]
Each tt.addptr is thus lowered to a pair of arith.addi and memref.reinterpret_cast: the arith.addi accumulates the current offset, which the reinterpret_cast then uses.
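The accumulation scheme can be sketched in Python (hypothetical helper names; the real pass emits arith.addi and memref.reinterpret_cast ops):

```python
# For each tt.addptr in a chain, accumulate the running offset with an add
# (the arith.addi) and re-cast from the original base allocation (the
# memref.reinterpret_cast), so each cast sees the full accumulated offset.
def lower_addptr_chain(base, deltas):
    current = 0
    casts = []
    for delta in deltas:
        current += delta               # arith.addi
        casts.append((base, current))  # memref.reinterpret_cast from base
    return casts

print(lower_addptr_chain("arg0", [4, 4, 4]))
# [('arg0', 4), ('arg0', 8), ('arg0', 12)]
```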