Our current index_seq_analysis does a backward pass on lhs, rhs, and acc, and then a forward pass on its consumers. This works for now, but for more complex cases we may need to extend it to also take detours. Consider this case:
In the case above, if we want the read of bias to also have the layouts from mma's acc, we would need to take a detour during layout setting, i.e.
mma -> res -(detour)-> bias read.
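A minimal sketch of what such a detour could look like, on a toy graph. All names here (`Node`, `propagate_backward`, `propagate_forward`) are hypothetical and only illustrate the idea, not the actual index_seq_analysis implementation; for simplicity a single layout value stands in for the real per-operand layouts.

```python
# Hypothetical sketch: propagate a layout backward from mma, then forward
# to consumers, detouring backward into each consumer's other operands
# (e.g. the bias read feeding res) so they pick up the same layout.

class Node:
    def __init__(self, name, inputs=()):
        self.name = name
        self.inputs = list(inputs)
        self.users = []
        self.layout = None
        for inp in self.inputs:
            inp.users.append(self)

def propagate_backward(node, layout):
    # Backward pass: walk producers (lhs/rhs/acc) and set the layout.
    if node.layout is None:
        node.layout = layout
        for inp in node.inputs:
            propagate_backward(inp, layout)

def propagate_forward(node, layout):
    # Forward pass over consumers. When a consumer has other operands whose
    # layout is still unset, take a "detour" backward through them.
    for user in node.users:
        if user.layout is None:
            user.layout = layout
            for operand in user.inputs:
                propagate_backward(operand, layout)  # the detour
            propagate_forward(user, layout)

# Toy graph: mma(lhs, rhs, acc); res = add(mma, bias_read)
lhs, rhs, acc = Node("lhs"), Node("rhs"), Node("acc")
bias_read = Node("bias_read")
mma = Node("mma", [lhs, rhs, acc])
res = Node("add", [mma, bias_read])

propagate_backward(mma, "acc_layout")  # covers lhs, rhs, acc
propagate_forward(mma, "acc_layout")   # covers res, detours into bias_read
print(bias_read.layout)  # -> "acc_layout"
```

Without the detour step inside `propagate_forward`, `bias_read` would be left with no layout, which is exactly the mma -> res -(detour)-> bias read case above.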
This is also evident in our attention kernel. Currently, we manually set vector_shapes for M and N on the attention kernel (https://github.com/iree-org/iree-turbine/blob/2b45c0fdec21f69b9cc088ec9852e98f5219c37c/tests/kernel/wave/wave_attention_test.py#L244) so that the partial_sum/reduction of sum has the correct expansion and indexing. Ideally, we should be able to handle this automatically by taking the detour.