Priya2698 opened this issue 1 month ago
Summarizing discussion from today's meeting and an offline discussion with @jjsjann123:
For training, we need two nodes: `SdpaOpFwd` and `SdpaOpBwd`. PR #2294 currently uses `at::scaled_dot_product_attention`, which does not return any intermediate values to be stored for backward and is therefore an inference-only node. We can merge this with `SdpaOpFwd` and potentially expose a different API if we don't want to return all the outputs.
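For reference, a minimal PyTorch sketch of the difference, assuming a CUDA device with flash attention available; the tuple layout of the underlying ATen op is taken from the linked `attention.cpp` and should be treated as an assumption:

```python
import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Composite, inference-style API: only the attention output is returned,
# so there is nothing to stash for a backward node.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# The flash-attention forward used by autograd additionally returns the
# row-wise logsumexp, sequence lengths, and RNG state that the backward
# kernel consumes (output order assumed from the linked attention.cpp).
outs = torch.ops.aten._scaled_dot_product_flash_attention(q, k, v, 0.0, True, False)
attn_out, logsumexp = outs[0], outs[1]
```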
There are different variants of SDPA in use (flash attention, memory-efficient) with slightly different function signatures. We will initially target one variant (likely flash attention, after verifying that it is indeed the one used in models like nanoGPT).
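One way to check which variant a nanoGPT-style call actually uses is to restrict SDPA to a single backend and see whether it still runs. A minimal sketch, assuming PyTorch >= 2.3 (older releases expose a similar `torch.backends.cuda.sdp_kernel` context manager) and a CUDA device:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# nanoGPT-like shapes: (batch, heads, seq, head_dim), half precision on CUDA.
q = torch.randn(4, 12, 1024, 64, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Allow only the flash-attention backend; if these inputs were not eligible
# for flash attention (e.g. fp32 tensors or an explicit attn_mask), this
# call would raise instead of silently falling back to another variant.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```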
CC: @IvanYashchuk
References:
Which signature is the right one to target here? PyTorch itself has so many different variants, and each has a different signature:
https://github.com/pytorch/pytorch/blob/a6b994ed5467d4df8320cbae51cba6a98ffb139c/aten/src/ATen/native/transformers/attention.cpp#L665-L706
https://github.com/pytorch/pytorch/blob/a6b994ed5467d4df8320cbae51cba6a98ffb139c/tools/autograd/derivatives.yaml#L2806-L2829
There's no right or wrong signature; it depends on how you want to do the backward computation, and that dictates the output signature of the forward function. You need to decide for yourself what needs to be stashed for backward and what can be recomputed. If you want fallbacks to ATen, then there is no choice other than directly mimicking ATen's function signatures. It's important to remember that Flash Attention doesn't work for all input cases; the "memory efficient" variant doesn't cover all input cases either.
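As an illustration of that trade-off, here is a minimal, unfused PyTorch sketch of a flash-attention-style split: the forward stashes only the output and the row-wise logsumexp, and the backward recomputes the attention probabilities from them. The shapes and scale factor are illustrative assumptions, not the proposed nvFuser signature:

```python
import torch

def sdpa_fwd(q, k, v):
    # Stash only what a flash-attention-style backward needs: the output
    # and the per-row logsumexp; the probability matrix is recomputed later.
    scale = q.shape[-1] ** -0.5
    s = (q @ k.transpose(-2, -1)) * scale
    lse = torch.logsumexp(s, dim=-1)              # [N, H, L]
    out = torch.exp(s - lse.unsqueeze(-1)) @ v    # [N, H, L, Ev]
    return out, lse

def sdpa_bwd(grad_out, q, k, v, out, lse):
    scale = q.shape[-1] ** -0.5
    s = (q @ k.transpose(-2, -1)) * scale
    p = torch.exp(s - lse.unsqueeze(-1))          # recomputed, not stashed
    grad_v = p.transpose(-2, -1) @ grad_out
    grad_p = grad_out @ v.transpose(-2, -1)
    # Softmax backward; rowsum(dO * O) equals rowsum(dP * P), so the
    # probabilities never need to be materialized in the forward.
    d = (grad_out * out).sum(dim=-1, keepdim=True)
    grad_s = p * (grad_p - d)
    grad_q = grad_s @ k * scale
    grad_k = grad_s.transpose(-2, -1) @ q * scale
    return grad_q, grad_k, grad_v
```

Whatever subset of these intermediates `SdpaOpFwd` ends up returning is exactly what fixes the `SdpaOpBwd` signature.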
Is the flash attention kernel representable in nvFuser primitives?
> There's no right or wrong signature; it depends on how you want to do the backward computation, and that dictates the output signature of the forward function. You need to decide for yourself what needs to be stashed for backward and what can be recomputed.
Yes. The question here is mostly for @cowanmeg, i.e., which implementation we are targeting in codegen will determine what signature we want to have.
[Update from the Jun 4 meeting] At the moment we only plan to support Flash Attention, to unblock multi-GPU development. Once Flash Attention is supported, we can revisit whether we also need to add Memory Efficient Attention. There could be a few ways:
Note to self: next few PRs:
- `max_q/k` as scalar CPU tensors, to support adding them as fusion outputs: PR #2531
- `SdpaFwdOp` to enforce `E = Ev` and modify the corresponding pairwise mapping
- `SdpaBwdOp` and scheduling support
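For illustration, a small sketch of the two constraints mentioned in that list; `validate_sdpa_fwd_inputs` is a hypothetical helper, not an existing nvFuser API:

```python
import torch

def validate_sdpa_fwd_inputs(q, k, v):
    """Hypothetical helper mirroring the constraint SdpaFwdOp would enforce:
    the flash-attention kernel requires the head dim of q/k (E) to match
    the head dim of v (Ev)."""
    e, ev = q.shape[-1], v.shape[-1]
    assert k.shape[-1] == e and e == ev, (
        f"flash attention requires E == Ev, got E={e}, Ev={ev}"
    )
    # max_q / max_k are the (maximum) sequence lengths reported by the ATen
    # flash-attention forward; PR #2531 stores them as scalar CPU tensors so
    # they can also be bound as fusion outputs.
    max_q = torch.scalar_tensor(q.shape[-2], dtype=torch.int64, device="cpu")
    max_k = torch.scalar_tensor(k.shape[-2], dtype=torch.int64, device="cpu")
    return max_q, max_k
```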
Add a new IR node for SDPA, which is currently not supported within nvFuser. CC: @cowanmeg @kevinstephano