mfrancepillois opened 2 weeks ago
The goal of the raising pass is to lower memory accesses to 2D block load/store ops. To do so, we need:
Our previous attempt faced two main problems: 1) the brittleness of several analyses, the most problematic being the offset analysis (especially detecting the axis to which the offset should be applied); 2) missing information, typically the parent tensor shape.
Since we designed the pass, other projects have made progress in detecting memory access structures. The Cambricon project extends the AxisInfoAnalysis to recover tensor strides (see https://github.com/Cambricon/triton-linalg/blob/master/include/triton-linalg/Dialect/Triton/Interfaces/InferAxisInfoInterface.h#L176-L216). Because it is based on data-flow analysis (DFA), this analysis seems less fragile than the one implemented in our previous proposal.
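To illustrate the kind of information a DFA-based stride analysis propagates, here is a toy sketch. It is not the Cambricon code nor Triton's AxisInfoAnalysis: it assumes strides are compile-time constants and tracks each value as an affine form `const + sum(strides[d] * idx_d)` over the block dimensions, falling back to "unknown" (`None`) as soon as an operation leaves the affine domain. The helpers `arange`, `const`, `add` and `mul` are hypothetical names modelling the corresponding Triton ops.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class Affine:
    # Abstract value: const + sum_d strides[d] * idx_d over the block dims.
    const: int
    strides: Tuple[int, ...]


Value = Optional[Affine]  # None means "not affine" (unknown, lattice top)


def arange(ndim: int, dim: int) -> Affine:
    # Models tl.arange(0, BLOCK) broadcast along `dim`: unit stride on that dim.
    strides = [0] * ndim
    strides[dim] = 1
    return Affine(0, tuple(strides))


def const(ndim: int, c: int) -> Affine:
    # A uniform scalar: zero stride on every dim.
    return Affine(c, (0,) * ndim)


def add(a: Value, b: Value) -> Value:
    if a is None or b is None:
        return None
    return Affine(a.const + b.const,
                  tuple(x + y for x, y in zip(a.strides, b.strides)))


def mul(a: Value, b: Value) -> Value:
    if a is None or b is None:
        return None
    # The product stays affine only if one operand is uniform (a scalar).
    if not any(a.strides):
        return Affine(a.const * b.const, tuple(a.const * s for s in b.strides))
    if not any(b.strides):
        return Affine(a.const * b.const, tuple(b.const * s for s in a.strides))
    return None


# offs = offs_m[:, None] * stride_am + offs_k[None, :] + k * BLOCK_K,
# with stride_am = 64 and k * BLOCK_K = 32 assumed constant for this toy.
offs = add(add(mul(arange(2, 0), const(2, 64)), arange(2, 1)), const(2, 32))
print(offs)  # Affine(const=32, strides=(64, 1))
```

The recovered per-dimension strides `(64, 1)` are exactly what a block descriptor needs, without having to pattern-match the IR.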
Regarding offsets, we do not actually need to detect which axis is being incremented: we could simply take the “global” offset computed by the kernel for the regular pointer and apply it directly to the base pointer. This could simplify the pass and remove sources of fragility.
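A minimal sketch of why folding the scalar increment into the base is sufficient, assuming a hypothetical `BlockPtr` descriptor (base offset, block shape, element strides) rather than Triton's actual block-pointer type. Whatever axis the kernel's `ptrs += delta` advances along, adding `delta` to the base yields the same addresses as the regular pointer tensor.

```python
from dataclasses import dataclass, replace
from typing import List, Tuple


@dataclass(frozen=True)
class BlockPtr:
    # Hypothetical 2D block descriptor: base offset, block shape, element strides.
    base: int
    shape: Tuple[int, int]
    strides: Tuple[int, int]

    def addresses(self) -> List[List[int]]:
        rows, cols = self.shape
        return [[self.base + i * self.strides[0] + j * self.strides[1]
                 for j in range(cols)] for i in range(rows)]


def advance_by_scalar(bp: BlockPtr, delta: int) -> BlockPtr:
    # Fold the kernel's scalar pointer increment into the base: no axis needed.
    return replace(bp, base=bp.base + delta)


def bump(ptrs: List[List[int]], delta: int) -> List[List[int]]:
    # Regular-pointer semantics: ptrs += delta, element-wise.
    return [[p + delta for p in row] for row in ptrs]


# ptrs = base + offs_m[:, None] * 64 + offs_k[None, :]
bp = BlockPtr(base=0, shape=(2, 4), strides=(64, 1))
ptrs = bp.addresses()
```

The design trade-off is that the offset is no longer tracked per axis, which is precisely the information the previous attempt struggled to recover.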
The block shape can easily be extracted from the size of the offset vector, and then compared with the length of the group detected by the stride analysis to check that the full block can be accessed with one operation.
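A sketch of this check, for illustration only: here the offsets are materialized and the contiguous-group length is computed directly, whereas in the pass that length would come from the stride analysis. `contiguous_run` and `infer_block_shape` are hypothetical helpers; the extra even-row-spacing check is an assumption about what a single 2D block op needs.

```python
from typing import List, Optional, Tuple


def contiguous_run(row: List[int]) -> int:
    # Length of the leading run of consecutive (+1) offsets in a row.
    n = 1
    while n < len(row) and row[n] == row[n - 1] + 1:
        n += 1
    return n


def infer_block_shape(offsets: List[List[int]]) -> Optional[Tuple[int, int]]:
    # Block shape is simply the shape of the offset tensor...
    rows, cols = len(offsets), len(offsets[0])
    # ...but it is only usable if each row is one contiguous group.
    group = min(contiguous_run(r) for r in offsets)
    if group < cols:
        return None  # no single 2D block op covers the whole block
    starts = [r[0] for r in offsets]
    pitches = {b - a for a, b in zip(starts, starts[1:])}
    if len(pitches) > 1:
        return None  # rows are not evenly spaced
    return (rows, cols)


# Row-major 8x16 tile from a matrix with row stride 64.
row_major = [[i * 64 + j for j in range(16)] for i in range(8)]
# Same tile but read column-major: inner dimension has stride 64.
col_major = [[i + j * 64 for j in range(16)] for i in range(8)]
```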
So, removing some of the complexity from our pass and relying on a more advanced analysis could help solve the fragility problem.
As for the missing data, the backend could automatically add to the kernel function signature the shapes of the torch tensors passed to the kernel as arguments. These shapes would be additional parameters managed by the backend, so the pass could retrieve this information from the kernel.
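A minimal sketch of the launcher-side step, under stated assumptions: `append_shape_args` is a hypothetical hook, tensors are recognized by duck typing on `.shape` (which `torch.Tensor` exposes), and the shapes are appended after the user arguments in argument order. The actual placement and naming of these hidden parameters would be a backend design decision.

```python
from typing import Any, List, Sequence


def append_shape_args(args: Sequence[Any]) -> List[Any]:
    """Hypothetical backend hook: append the dimensions of every tensor
    argument (anything exposing a `.shape`) as extra integer arguments."""
    extra: List[int] = []
    for a in args:
        shape = getattr(a, "shape", None)
        if shape is not None:
            extra.extend(int(d) for d in shape)
    return list(args) + extra


class FakeTensor:
    # Minimal stand-in for torch.Tensor so the sketch runs without torch.
    def __init__(self, shape):
        self.shape = tuple(shape)


a = FakeTensor((128, 64))
b = FakeTensor((64, 32))
launch_args = append_shape_args([a, 3, b])
```

The kernel would then see `128, 64, 64, 32` as trailing parameters, giving the raising pass the parent shapes it previously lacked.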
Another point concerning regular memory accesses is masks. One way of dealing with them could be to keep the logic that computes the mask for the regular access and only “filter” the data returned by the 2D load using this mask.
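A sketch of the "load the full block, then filter" idea on a 3x5 row-major matrix. It assumes the unmasked 2D block read is itself safe to issue; `block_load` is a hypothetical stand-in that reads out-of-range addresses as 0. The mask is computed exactly as in the regular access (`offs_n < N`) and applied afterwards, like `other` in `tl.load`.

```python
from typing import List


def block_load(memory: List[int], base: int, rows: int, cols: int,
               row_stride: int) -> List[List[int]]:
    # Stand-in for an unmasked 2D block read (out-of-range reads as 0, assumed).
    out = []
    for i in range(rows):
        row = []
        for j in range(cols):
            addr = base + i * row_stride + j
            row.append(memory[addr] if 0 <= addr < len(memory) else 0)
        out.append(row)
    return out


def apply_mask(block: List[List[int]], mask: List[List[bool]],
               other: int) -> List[List[int]]:
    # Keep loaded values where the mask holds, `other` elsewhere.
    return [[v if m else other for v, m in zip(vr, mr)]
            for vr, mr in zip(block, mask)]


N = 5                                # number of columns in the matrix
memory = list(range(15))             # 3x5 matrix, row-major, row stride 5
col0 = 3                             # block starts at column 3, width 4
block = block_load(memory, col0, rows=2, cols=4, row_stride=N)
mask = [[(col0 + j) < N for j in range(4)] for _ in range(2)]
result = apply_mask(block, mask, other=-1)
```

Without the mask, the block read would return elements of the next row (for example `5, 6` in the first row); the regular-access mask removes them.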
By doing so, we might be able to improve our raising pass and provide a reliable pass without needing additional compiler hints.
Thanks @mfrancepillois for the analysis. I think another possible solution to making the pass safe would be to version the kernel (kernel cloning) or the loop containing the loads we want to codegen to 2D block reads.
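A sketch of the versioning idea in plain Python, not compiler IR: when a property cannot be proven statically (here, assumed to be a unit inner stride), a runtime check dispatches to a version using the 2D block read or to the generic gather. `gather_load`, `block_load` and `versioned_load` are hypothetical names; both paths must compute the same result, which is what makes the transformation safe.

```python
from typing import List, Tuple


def gather_load(memory: List[int], base: int, rows: int, cols: int,
                strides: Tuple[int, int]) -> List[List[int]]:
    # Generic path: one address per element, valid for any strides.
    return [[memory[base + i * strides[0] + j * strides[1]]
             for j in range(cols)] for i in range(rows)]


def block_load(memory: List[int], base: int, rows: int, cols: int,
               strides: Tuple[int, int]) -> List[List[int]]:
    # Fast path: each row is one contiguous slice (2D block read stand-in).
    assert strides[1] == 1, "block path requires a unit inner stride"
    return [memory[base + i * strides[0]: base + i * strides[0] + cols]
            for i in range(rows)]


def versioned_load(memory, base, rows, cols, strides):
    # Runtime guard selecting the cloned version of the load.
    if strides[1] == 1:
        return "block", block_load(memory, base, rows, cols, strides)
    return "gather", gather_load(memory, base, rows, cols, strides)


memory = list(range(100))
```

In a compiler the guard would typically be hoisted out of the loop, so the check is paid once per loop (or per kernel launch) rather than per load.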
Given that our GPU performance heavily depends on 2D block reads, I think we should continue to experiment with this approach.
I am moving this work item toward the end of the year; I would be interested in experimenting with the kernel/loop versioning idea.
Raising regular pointers into structured accesses, and more specifically into block pointers, remains a topic of interest, as this feature could enable optimized memory accesses even when users do not use the block-pointer API (or similar).
Our first attempt to design such a transformation pass faced two main issues:
Progress in the Triton environment might help us mitigate these issues. An in-depth study of the current limitations and possible improvements is required.