Dao-AILab / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

Cannot trace or Script flash attention v2 #431

Open thomassajot opened 1 year ago

thomassajot commented 1 year ago

It is currently not possible to script the FlashAttnFunc layer. A traced model can be used within the same session, but saving the traced model to a file and loading it back fails with the same error as scripting.

Could not export Python function call 'FlashAttnFunc'. Remove calls to Python functions before export. Did you forget to add @script or @script_method annotation? If this is a nn.ModuleList, add it to __constants__:

Is there a workaround for this? Should I open an issue in PyTorch instead?
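
For context, a minimal sketch of the trace-then-save path described above (the shapes and dtype here are illustrative assumptions; flash-attn expects fp16/bf16 tensors on CUDA):

import torch
from flash_attn import flash_attn_func

class Attn(torch.nn.Module):
    def forward(self, x):
        return flash_attn_func(x, x, x, 0.1)

# Illustrative input: (batch_size, seqlen, nheads, headdim) in fp16 on CUDA.
x = torch.rand((2, 128, 4, 64), dtype=torch.float16, device='cuda:0')

traced = torch.jit.trace(Attn().eval(), x)  # tracing succeeds
traced(x)                                   # and the traced module runs in-session
traced.save('model_traced.pt')              # per the report above, saving/reloading fails with the export error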

thomassajot commented 1 year ago

In the example above I called FlashAttnFunc.apply directly, which produces the error above. When using flash_attn_func instead, I get the following error:

Expected a default value of type Tensor (inferred) on parameter "dropout_p".Because "dropout_p" was not annotated with an explicit type it is assumed to be type 'Tensor'.

This suggests that adding type annotations may help.
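
The inference rule behind this error can be seen on a plain PyTorch function (a minimal sketch, independent of flash-attn): an unannotated parameter is assumed to be a Tensor, so a non-Tensor default is rejected, and an explicit annotation fixes it.

import torch

# Without an annotation, TorchScript assumes 'factor' is a Tensor, so the
# float default 0.5 triggers the same "Expected a default value of type
# Tensor (inferred)" error:
#
# @torch.jit.script
# def scale(x, factor=0.5):
#     return x * factor

# Annotating the parameter resolves it:
@torch.jit.script
def scale(x: torch.Tensor, factor: float = 0.5) -> torch.Tensor:
    return x * factor

print(scale(torch.ones(2)))  # tensor([0.5000, 0.5000])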

tridao commented 1 year ago

I don't know how TorchScript or type annotations work. Do you have a short script to reproduce the issue, and suggestions for the type annotations?

thomassajot commented 1 year ago

To replicate the typing exception:

import torch
from flash_attn import flash_attn_func

@torch.jit.script
def foo(x):
    return flash_attn_func(x, x, x, 0.1)

which generates the following exception:

RuntimeError: 
Expected a default value of type Tensor (inferred) on parameter "dropout_p".Because "dropout_p" was not annotated with an explicit type it is assumed to be type 'Tensor'.:
  File "/home/thomassajot/.cache/bazel/_bazel_thomassajot/354fe3065f670718205b055cf58a4dae/execroot/WayveCode/bazel-out/k8-opt/bin/tools/jupyter_ai.runfiles/pip-dl_flash_attn/site-packages/flash_attn/flash_attn_interface.py", line 385
def flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False,
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
                    return_attn_probs=False):
                    ~~~~~~~~~~~~~~~~~~~~~~~~~
    """dropout_p should be set to 0.0 during evaluation
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    than Q. Note that the number of heads in KV must be divisible by the number of heads in Q.
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    Arguments:
    ~~~~~~~~~~
        q: (batch_size, seqlen, nheads, headdim)
        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        k: (batch_size, seqlen, nheads_k, headdim)
        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        v: (batch_size, seqlen, nheads_k, headdim)
        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        dropout_p: float. Dropout probability.
        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        softmax_scale: float. The scaling of QK^T before applying softmax.
        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            Default to 1 / sqrt(headdim).
            ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
           testing only. The returned probabilities are not guaranteed to be correct
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
           (they might not have the right scaling).
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    Return:
    ~~~~~~~
        out: (batch_size, seqlen, nheads, headdim).
        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
            ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            normalization factor).
            ~~~~~~~~~~~~~~~~~~~~~~
        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            The output of softmax (possibly with different scaling). It also encodes the dropout
            ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            pattern (negative means that location was dropped, nonnegative means it was kept).
            ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    """
    ~~~
    return FlashAttnFunc.apply(q, k, v, dropout_p, softmax_scale, causal, return_attn_probs)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE

It might help to add type annotations to the function signature:

def flash_attn_func(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                    dropout_p: float = 0.0, softmax_scale: Optional[float] = None,
                    causal: bool = False, return_attn_probs: bool = False):

But I am not sure if this will solve the issue.

thomassajot commented 1 year ago

To replicate the primary error:

import torch
from flash_attn.flash_attn_interface import FlashAttnFunc

class Module(torch.nn.Module):
    def forward(self, x):
        return FlashAttnFunc.apply(x, x, x, 0.1, None, False, False)

x = torch.rand((20, 10, 5, 4), dtype=torch.float16, device='cuda:0')  # (batch_size, seqlen, nheads, headdim), fp16 on CUDA
module = torch.jit.script(Module().eval())
module.save('model.pt')
module = torch.jit.load('model.pt')
module(x)

Which produces the error message:

RuntimeError: 
Could not export Python function call 'FlashAttnFunc'. Remove calls to Python functions before export. Did you forget to add @script or @script_method annotation? If this is a nn.ModuleList, add it to __constants__:
  File "/tmp/ipykernel_2517720/1458385368.py", line 7
    def forward(self, x):
        return FlashAttnFunc.apply(x, x, x, 0.1, None, False, False)
               ~~~~~~~~~~~~~~~~~~~ <--- HERE
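
For what it's worth, the failure does not look specific to flash-attn: any custom torch.autograd.Function applied from a scripted module should hit the same export limitation (a minimal sketch, assuming the same behaviour as the repro above):

import torch

class ToyFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def backward(ctx, grad):
        return 2 * grad

class ToyModule(torch.nn.Module):
    def forward(self, x):
        return ToyFunc.apply(x)

m = torch.jit.script(ToyModule().eval())
m.save('toy.pt')  # expected to fail: Could not export Python function call 'ToyFunc'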

ymwangg commented 1 year ago

@thomassajot My understanding is that TorchScript cannot serialize arbitrary Python functions. You may want to check this out: https://pytorch.org/tutorials/advanced/torch_script_custom_ops.html
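
That tutorial exposes a kernel as an operator registered with TorchScript, so calls go through torch.ops.* rather than a Python function. A tiny illustration of the distinction, using a built-in registered op (only a sketch of the mechanism, not a flash-attn integration):

import torch

@torch.jit.script
def registered_op_only(x: torch.Tensor) -> torch.Tensor:
    # Calls routed through torch.ops.* resolve to registered operators,
    # which TorchScript can both compile and serialize.
    return torch.ops.aten.relu(x)

registered_op_only.save('registered_op.pt')  # serializes fine, unlike FlashAttnFunc.apply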

thomassajot commented 1 year ago

Thank you, I will be looking into this.