NVIDIA / Fuser

A Fusion Code Generator for NVIDIA GPUs (commonly known as "nvFuser")

SdpaFwdOp::evaluate produces a tensor whose stride order doesn't match the output allocation domain. #3386

Closed · wujingyue closed this 1 week ago

wujingyue commented 1 week ago

As a result, the consumer kernel of that output can read its input data in the wrong order and generate wrong results. I suspect this has caused #3194, but I have yet to confirm that.

I've yet to write a simple repro, but the root cause seems pretty clear now:

The output TV of SdpaFwdOp expects the output to be contiguous and major-to-minor: https://github.com/NVIDIA/Fuser/blob/96d64b61612f3248062c7d17c68582f94b92e37c/csrc/ops/composite.cpp#L492.

However, the output tensor of SdpaFwdOp::evaluate isn't guaranteed to be contiguous or to be major-to-minor:

  1. Regarding contiguity: the logic at https://github.com/NVIDIA/Fuser/blob/8bd9984d0125a09a13c170a507a2deff7ae748cf/csrc/ir/nodes.cpp#L4564, if it kicks in, makes the output tensor non-contiguous in the last dimension.
  2. Regarding stride order: at::_scaled_dot_product_flash_attention propagates the input's stride order. In forward prop, Q/K/V are sliced, reshaped, and transposed from the output of MHA's first linear layer, so their typical memory format is [b,s,h,e], not the [b,h,s,e] specified in the output TV (see the sketch after this list).
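
A minimal sketch of item 2, assuming a CUDA build of PyTorch where the flash attention kernel is available; the sizes b/h/s/e are illustrative:

```python
import torch

b, h, s, e = 2, 8, 128, 64

# Typical MHA forward prop: one linear layer produces QKV packed as [b, s, 3*h*e].
qkv = torch.randn(b, s, 3 * h * e, dtype=torch.half, device="cuda")

# Slice, reshape, and transpose: logical shape [b, h, s, e], but the
# underlying memory format is still [b, s, h, e].
q, k, v = [t.view(b, s, h, e).transpose(1, 2) for t in qkv.chunk(3, dim=-1)]

out = torch.ops.aten._scaled_dot_product_flash_attention(q, k, v)[0]

# The output has logical shape [b, h, s, e], but its strides follow the
# inputs' [b, s, h, e] memory format rather than being contiguous
# major-to-minor as the output TV declares.
print(out.shape, out.stride())
print(out.is_contiguous())  # typically False
```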

Related: this is similar to the problem discussed in https://github.com/NVIDIA/Fuser/issues/2425, but the two are not quite the same; hence a separate ticket.

wujingyue commented 1 week ago

My current plan is to change allocation_order_inference to maintain the same allocation order between the inputs and outputs of SdpaFwdOp. The code around this line needs to change: instead of trying to find one reference TV and copying from it, it will try to merge multiple reference TVs (in this case Q, K, and V, assuming the same non_trivial_iter_count) into one allocation order. Below is an illustration: [image attachment]
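
For illustration only, a deliberately simplified sketch of that merge; the function name and the agree-or-bail policy are hypothetical, not nvFuser code:

```python
from typing import Optional, Sequence

def merge_allocation_orders(
    orders: Sequence[tuple[int, ...]],
) -> Optional[tuple[int, ...]]:
    """Merge the allocation orders inferred from several reference TVs.

    Simplest possible policy: propagate an order only when every
    reference agrees; on any conflict, give up and leave the output's
    allocation order unspecified.
    """
    if not orders:
        return None
    first = orders[0]
    if all(order == first for order in orders[1:]):
        return first
    return None  # conflicting references; don't propagate

# Q, K, and V all carry the [b, s, h, e] order, i.e. the permutation
# (0, 2, 1, 3) of the logical [b, h, s, e] domain:
assert merge_allocation_orders([(0, 2, 1, 3)] * 3) == (0, 2, 1, 3)
assert merge_allocation_orders([(0, 2, 1, 3), (0, 1, 2, 3)]) is None
```

A real implementation would presumably resolve partial agreement rather than bailing outright, which is what the illustration above gestures at.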

kevinstephano commented 1 week ago

Is @Priya2698 working on something similar for Linears and Matmuls?

wujingyue commented 1 week ago

For the record, I didn't take the plan in https://github.com/NVIDIA/Fuser/issues/3386#issuecomment-2465816028. Instead, #3399 special-cased SdpaFwdOp and SdpaBwdOp to make their stride order identical to ATen's implementation. I understand this is fragile and am open to better solutions; it was just difficult to find a general algorithm based on IdModel to handle special cases like this, where attn_out and softmax_lse want to stride-order head and sequence differently (see the probe below).
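
For concreteness, a small probe of the two ATen outputs showing that tension, with the same caveats as the earlier sketch (CUDA build with flash attention available, illustrative sizes):

```python
import torch

b, h, s, e = 2, 8, 128, 64
qkv = torch.randn(b, s, 3 * h * e, dtype=torch.half, device="cuda")
q, k, v = [t.view(b, s, h, e).transpose(1, 2) for t in qkv.chunk(3, dim=-1)]

attn_out, softmax_lse = torch.ops.aten._scaled_dot_product_flash_attention(q, k, v)[:2]

# attn_out: logical shape [b, h, s, e] with strides that order sequence
# before head (the [b, s, h, e] memory format inherited from the inputs).
print(attn_out.shape, attn_out.stride())

# softmax_lse: shape [b, h, s] and contiguous, so it orders head before
# sequence -- a different stride order than attn_out.
print(softmax_lse.shape, softmax_lse.stride())
```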