iree-org / iree-turbine

IREE's PyTorch Frontend, based on Torch Dynamo.
Apache License 2.0

[TKW] Modify Index Seq Analysis to handle "detours" #246

Open raikonenfnu opened 2 weeks ago

raikonenfnu commented 2 weeks ago

Our current index_seq_analysis does a backward pass on lhs, rhs, and acc, and then a forward pass on their consumers. This works for now, but for more complex cases we may need to extend it to also handle "detours". Consider this case:

```
lhs = read
rhs = read
bias = read
mma = mma(lhs, rhs)
res = mma + bias
write(res)
```

In the case above, if we want the read of bias to also pick up the layout from mma's acc, we need to take a detour during layout setting, i.e. mma -> res -(detour)-> bias read.
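To make the detour concrete, here is a minimal sketch of the idea on a toy dataflow graph. All names (`Node`, `connect`, `propagate`, the layout strings) are illustrative inventions, not iree-turbine's actual API: the forward pass pushes a layout to consumers, and when it reaches a consumer it also walks backward into that consumer's other, not-yet-visited operands, which is how the bias read ends up with mma's acc layout.

```python
# Hypothetical sketch of layout propagation with "detours".
# Names are illustrative only, not iree-turbine's real data structures.
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Node:
    name: str
    operands: List["Node"] = field(default_factory=list)
    users: List["Node"] = field(default_factory=list)
    layout: Optional[str] = None


def connect(producer: Node, consumer: Node) -> None:
    consumer.operands.append(producer)
    producer.users.append(consumer)


def propagate(node: Node, layout: str) -> None:
    """Forward pass: push `layout` to consumers; detour backward into
    any sibling operands that still lack a layout (e.g. the bias read)."""
    if node.layout is not None:
        return
    node.layout = layout
    for user in node.users:
        propagate(user, layout)   # forward to consumers
    for op in node.operands:
        propagate(op, layout)     # detour backward into untouched operands


# Build the graph from the example above.
lhs, rhs, bias = Node("lhs_read"), Node("rhs_read"), Node("bias_read")
mma, res, out = Node("mma"), Node("add"), Node("write")
connect(lhs, mma)
connect(rhs, mma)
connect(mma, res)
connect(bias, res)
connect(res, out)

# Backward pass seeds lhs/rhs; the forward pass with detours then
# reaches bias_read through res: mma -> res -(detour)-> bias_read.
lhs.layout = rhs.layout = "mma_operand_layout"
propagate(mma, "mma_acc_layout")
assert bias.layout == "mma_acc_layout"
```

Without the detour loop over `node.operands`, the forward pass would stop at `res` and `bias_read` would never receive a layout, which is exactly the gap described above.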

This is also evident in our attention kernel. Currently, we manually set vector_shapes for M and N on the attention kernel (https://github.com/iree-org/iree-turbine/blob/2b45c0fdec21f69b9cc088ec9852e98f5219c37c/tests/kernel/wave/wave_attention_test.py#L244) such that the partial sum/reduction gets the correct expansion and indexing. In reality, we should be able to handle this by doing the detour.