Closed by wujingyue 1 week ago
My current plan is to change allocation_order_inference to maintain the same allocation order between the inputs and outputs of SdpaFwdOp. The code around this line needs to change: instead of trying to find one reference TV and copying its order, it will try to merge multiple reference TVs (in this case Q, K, and V, assuming they have the same non_trivial_iter_count) into one allocation order. Below is an illustration of the merging step:
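A minimal sketch of the idea, with made-up names rather than the actual allocation_order_inference API:

```cpp
#include <optional>
#include <vector>

// An allocation order is a permutation of logical-dimension indices,
// major to minor. The alias and the helper below are illustrative only.
using AllocationOrder = std::vector<int64_t>;

std::optional<AllocationOrder> mergeReferenceOrders(
    const std::vector<AllocationOrder>& reference_orders) {
  if (reference_orders.empty()) {
    return std::nullopt;
  }
  // Start from the first reference (e.g. Q) and require the other
  // references (K, V) to propose the same permutation. A real
  // implementation would merge partial orders rather than insist on
  // exact agreement.
  const AllocationOrder& merged = reference_orders.front();
  for (const auto& order : reference_orders) {
    if (order != merged) {
      return std::nullopt;  // conflicting preferences; fall back to default
    }
  }
  return merged;
}
```

The merged order would then be propagated to SdpaFwdOp's outputs, so consumer kernels see the same layout as the inputs.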
Is @Priya2698 working on something similar for Linears and Matmuls?
For the record, I didn't follow the plan in https://github.com/NVIDIA/Fuser/issues/3386#issuecomment-2465816028. Instead, #3399 special-cased SdpaFwdOp and SdpaBwdOp to make their stride order identical to ATen's implementation. I understand this is fragile and am open to better solutions. It was just difficult to find a general, IdModel-based algorithm to handle special cases like this, where attn_out and softmax_lse want to stride-order the head and sequence dimensions differently.
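For concreteness, the special-casing boils down to something like the following sketch (illustrative only, not the exact code in #3399): give attn_out an allocation domain matching ATen's typical [b, s, h, e] memory format while leaving softmax_lse contiguous in its logical [b, h, s] order.

```cpp
#include <ir/interface_nodes.h>  // nvFuser's TensorView (include path assumed)

using namespace nvfuser;

// Illustrative only: reorder attn_out's allocation domain to [b, s, h, e]
// while its logical domain stays [b, h, s, e]. softmax_lse is left alone,
// i.e. contiguous in its logical [b, h, s] order.
void setSdpaFwdOutAllocation(TensorView* attn_out) {
  attn_out->setAllocationDomain(
      {attn_out->axis(0),   // b
       attn_out->axis(2),   // s
       attn_out->axis(1),   // h
       attn_out->axis(3)},  // e
      /*new_contiguity=*/true);
}
```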
I've yet to write a simple repro, but the root cause seems pretty clear now:

1. The output TV of SdpaFwdOp expects the output to be contiguous and major-to-minor: https://github.com/NVIDIA/Fuser/blob/96d64b61612f3248062c7d17c68582f94b92e37c/csrc/ops/composite.cpp#L492.
2. However, the output tensor of SdpaFwdOp::evaluate isn't guaranteed to be contiguous or major-to-minor: at::_scaled_dot_product_flash_attention propagates the input's stride order. In forward prop, the input Q/K/V are sliced, reshaped, and transposed from the output of MHA's first linear layer, so the typical memory format is [b, s, h, e], not [b, h, s, e] as specified in the output TV.

Therefore, the consumer kernel of that output can read its input data in the wrong order and generate wrong results. I suspect this has caused #3194, but I've yet to confirm that.

Related: this is similar to the problem discussed in https://github.com/NVIDIA/Fuser/issues/2425, but they are not quite the same. Hence, a different ticket.
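In case it helps, a rough way to observe the stride behavior described above outside nvFuser might look like this (hypothetical and untested; assumes a GPU and PyTorch build where the flash backend runs): allocate Q/K/V in the [b, s, h, e] memory format that the MHA linear typically produces, transpose them to the logical [b, h, s, e] shape, and inspect the output strides.

```cpp
#include <ATen/ATen.h>
#include <iostream>
#include <tuple>

int main() {
  const int64_t b = 2, h = 8, s = 128, e = 64;
  const auto options = at::dtype(at::kHalf).device(at::kCUDA);

  // Allocate as [b, s, h, e] and transpose dims 1 and 2, so the logical
  // shape is [b, h, s, e] but the memory format stays [b, s, h, e].
  at::Tensor q = at::randn({b, s, h, e}, options).transpose(1, 2);
  at::Tensor k = at::randn({b, s, h, e}, options).transpose(1, 2);
  at::Tensor v = at::randn({b, s, h, e}, options).transpose(1, 2);

  auto outputs = at::_scaled_dot_product_flash_attention(q, k, v);
  at::Tensor attn_out = std::get<0>(outputs);
  at::Tensor softmax_lse = std::get<1>(outputs);

  // Expectation per this issue: attn_out's strides follow Q's [b, s, h, e]
  // memory format rather than a contiguous, major-to-minor [b, h, s, e],
  // while softmax_lse ([b, h, s]) comes back contiguous.
  auto print_strides = [](const char* name, const at::Tensor& t) {
    std::cout << name << " strides:";
    for (int64_t stride : t.strides()) {
      std::cout << " " << stride;
    }
    std::cout << "\n";
  };
  print_strides("attn_out", attn_out);
  print_strides("softmax_lse", softmax_lse);
  return 0;
}
```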