thomassajot opened this issue 1 year ago
In the example above I called FlashAttnFunc.apply directly to get the above error. When using flash_attn_func instead, I get the following error:
Expected a default value of type Tensor (inferred) on parameter "dropout_p".Because "dropout_p" was not annotated with an explicit type it is assumed to be type 'Tensor'.
This suggests that adding type annotations may help.
I don't know how TorchScript or type annotations work. Do you have a short script to reproduce the issue, and suggestions for the type annotations?
To replicate the typing exception:
import torch
from flash_attn import flash_attn_func
@torch.jit.script
def foo(x):
    return flash_attn_func(x, x, x, 0.1)
which generates the following exception:
RuntimeError:
Expected a default value of type Tensor (inferred) on parameter "dropout_p".Because "dropout_p" was not annotated with an explicit type it is assumed to be type 'Tensor'.:
File "/home/thomassajot/.cache/bazel/_bazel_thomassajot/354fe3065f670718205b055cf58a4dae/execroot/WayveCode/bazel-out/k8-opt/bin/tools/jupyter_ai.runfiles/pip-dl_flash_attn/site-packages/flash_attn/flash_attn_interface.py", line 385
def flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False,
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
return_attn_probs=False):
~~~~~~~~~~~~~~~~~~~~~~~~~
"""dropout_p should be set to 0.0 during evaluation
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
than Q. Note that the number of heads in KV must be divisible by the number of heads in Q.
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Arguments:
~~~~~~~~~~
q: (batch_size, seqlen, nheads, headdim)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
k: (batch_size, seqlen, nheads_k, headdim)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
v: (batch_size, seqlen, nheads_k, headdim)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
dropout_p: float. Dropout probability.
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
softmax_scale: float. The scaling of QK^T before applying softmax.
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Default to 1 / sqrt(headdim).
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
testing only. The returned probabilities are not guaranteed to be correct
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
(they might not have the right scaling).
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Return:
~~~~~~~
out: (batch_size, seqlen, nheads, headdim).
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
normalization factor).
~~~~~~~~~~~~~~~~~~~~~~
S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The output of softmax (possibly with different scaling). It also encodes the dropout
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
pattern (negative means that location was dropped, nonnegative means it was kept).
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
"""
~~~
return FlashAttnFunc.apply(q, k, v, dropout_p, softmax_scale, causal, return_attn_probs)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
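For context, the error itself is not specific to flash_attn: when TorchScript recursively compiles a called function, any parameter without an explicit type annotation is assumed to be a Tensor, and a non-Tensor default value on such a parameter is then rejected. A minimal standalone sketch of my understanding (no flash_attn involved):

import torch

def scaled(x, scale=0.5):  # 'scale' is assumed to be a Tensor, but its default is a float
    return x * scale

@torch.jit.script
def foo(x: torch.Tensor) -> torch.Tensor:
    # Scripting foo recursively compiles scaled and raises:
    # Expected a default value of type Tensor (inferred) on parameter "scale".
    return scaled(x)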
It might help to add type annotations to the function signature:
def flash_attn_func(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, dropout_p: float = 0.0,
                    softmax_scale: Optional[float] = None, causal: bool = False,
                    return_attn_probs: bool = False):
But I am not sure if this will solve the issue.
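As a side note, a minimal sketch of a partial workaround (my own, not verified against the library): torch.jit.ignore keeps the flash-attn call as plain Python, so the TorchScript compiler never descends into flash_attn_func and never hits the dropout_p inference error. The module can then be scripted and run in the same process, but saving it will still refuse to serialize the ignored Python call, so this does not solve the export problem below.

import torch
from flash_attn import flash_attn_func

class FlashAttnBlock(torch.nn.Module):
    # Ignored methods are left as plain Python and dispatched to the
    # interpreter at runtime; they cannot be exported with save().
    @torch.jit.ignore
    def _attn(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        return flash_attn_func(q, k, v, 0.1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x must be a fp16/bf16 CUDA tensor for the flash-attn kernel to run
        return self._attn(x, x, x)

module = torch.jit.script(FlashAttnBlock().eval())  # scripting succeeds
# module.save('model.pt')                           # still fails on the ignored Python call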
To replicate the primary error:
import torch
from flash_attn.flash_attn_interface import FlashAttnFunc
class Module(torch.nn.Module):
    def forward(self, x):
        return FlashAttnFunc.apply(x, x, x, 0.1, None, False, False)
x = torch.rand((20, 10, 5, 4), dtype=torch.float16, device='cuda:0')
module = torch.jit.script(Module().eval())
module.save('model.pt')
module = torch.jit.load('model.pt')
module(x)
Which produces the error message:
RuntimeError:
Could not export Python function call 'FlashAttnFunc'. Remove calls to Python functions before export. Did you forget to add @script or @script_method annotation? If this is a nn.ModuleList, add it to __constants__:
File "/tmp/ipykernel_2517720/1458385368.py", line 7
def forward(self, x):
return FlashAttnFunc.apply(x, x, x, 0.1, None, False, False)
~~~~~~~~~~~~~~~~~~~ <--- HERE
@thomassajot My understanding is that TorchScript cannot serialize arbitrary Python functions. You may want to check this out: https://pytorch.org/tutorials/advanced/torch_script_custom_ops.html
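For what it's worth, a rough sketch of what the custom-op route would look like from the TorchScript side. This is purely hypothetical: flash_attn does not ship such an op today, and both the shared-library name and the flash_attn_ext::fwd op name below are placeholders for illustration.

import torch

# Assumes a C++ extension has registered the op (e.g. via TORCH_LIBRARY),
# as described in the tutorial linked above.
torch.ops.load_library("libflash_attn_ext.so")  # placeholder shared library

@torch.jit.script
def scripted_attn(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Registered custom ops are visible to the TorchScript compiler, so code
    # calling them can be scripted, saved, and loaded, unlike a call into a
    # Python autograd.Function.
    return torch.ops.flash_attn_ext.fwd(q, k, v, 0.1)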
Thank you, I will be looking into this.
It is currently not possible to script the FlashAttnFunc layer. It is possible to use a traced model within the same session, but saving the traced model to a file and loading it fails with the same error as scripting (see the sketch below).
Is there a workaround for this? Should I open an issue in PyTorch instead?
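A minimal sketch of the tracing behaviour described above, reusing the repro from earlier in the thread (the exact failure point may depend on the flash_attn and PyTorch versions):

import torch
from flash_attn.flash_attn_interface import FlashAttnFunc

class Module(torch.nn.Module):
    def forward(self, x):
        return FlashAttnFunc.apply(x, x, x, 0.1, None, False, False)

x = torch.rand((20, 10, 5, 4), dtype=torch.float16, device='cuda:0')
traced = torch.jit.trace(Module().eval(), x)  # tracing succeeds
traced(x)                                     # runs fine within the same session
traced.save('model.pt')                       # fails: the traced graph still contains a Python function call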