cvg / LightGlue

LightGlue: Local Feature Matching at Light Speed (ICCV 2023)
Apache License 2.0

FlashAttention actually does not support attention mask #116

Open HJoonKwon opened 9 months ago

HJoonKwon commented 9 months ago

Thanks for your great work!

I'm just curious whether your code here uses FlashAttention when mask is not None. My guess is that it uses memory-efficient attention instead, since the PyTorch flash attention kernel does not support an attention mask. In addition, if memory-efficient attention was used, half() would not have been needed when mask is not None. Thank you!
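A minimal sketch of the dispatch this question implies, assuming the PyTorch 2.1-era kernel rules reported in this thread: flash takes no attn_mask and needs fp16/bf16, while the memory-efficient kernel accepts attn_mask and fp32. `attend` is a hypothetical helper, not LightGlue's code. `torch.backends.cuda.sdp_kernel` is the context manager used in the issue; later releases deprecate it in favor of `torch.nn.attention.sdpa_kernel`.

```python
import torch
import torch.nn.functional as F


def attend(q, k, v, attn_mask=None):
    """Hypothetical helper (not LightGlue's code).

    Unmasked calls are pinned to the flash kernel, which needs fp16/bf16 inputs;
    masked calls are pinned to the memory-efficient kernel, which accepts a mask
    and fp32 inputs, so no half() cast is needed there.
    """
    if not q.is_cuda:
        # CPU: the CUDA kernel flags do not apply; let PyTorch choose a backend.
        out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        return out if attn_mask is None else out.nan_to_num()
    if attn_mask is None:
        # No mask: flash is eligible, but only for half-precision inputs.
        with torch.backends.cuda.sdp_kernel(
            enable_flash=True, enable_math=False, enable_mem_efficient=False
        ):
            out = F.scaled_dot_product_attention(q.half(), k.half(), v.half())
        return out.to(q.dtype)
    # Mask given: flash would be rejected, so use the memory-efficient kernel in fp32.
    with torch.backends.cuda.sdp_kernel(
        enable_flash=False, enable_math=False, enable_mem_efficient=True
    ):
        out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
    # Fully masked query rows can come out as NaN; zero them as the issue's code does.
    return out.nan_to_num()
```

If the GPU or head dimension does not qualify for flash, the forced context on the unmasked path raises the same `RuntimeError: No available kernel` reported in this issue, so a practical version would probably keep the math fallback enabled.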

Update: I ran some experiments. Even if sdp_flash is enabled, the flash kernel is not executed when mask is not None. If we force PyTorch to use flash, it raises an error like the one below.

import torch
from torch import nn

class Attention(nn.Module):
    def __init__(self, attn_dropout=0.0):
        super().__init__()
        self.attn_dropout = attn_dropout

    def forward(self, q, k, v, q_mask=None, kv_mask=None):
        if kv_mask is not None:
            attn_mask = q_mask[:, None, :, None] * kv_mask[:, None, None, :]
        else:
            attn_mask = None
        with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
            y = torch.nn.functional.scaled_dot_product_attention(
                q, k, v, attn_mask=attn_mask, dropout_p=self.attn_dropout, is_causal=False
            )

        return y if attn_mask is None else y.nan_to_num()

device = 'cuda'
attn = Attention().to(device)
B = 4
L = 32 * 32
S = 24 * 24
n_embd = 32
n_heads = 4
q = torch.randn(B, n_heads, L, n_embd // n_heads).to(device)
k = torch.randn(B, n_heads, S, n_embd // n_heads).to(device)
v = torch.randn(B, n_heads, S, n_embd // n_heads).to(device)
q_mask = (torch.rand(B, L) > 0.1).to(device)
kv_mask = (torch.rand(B, S) > 0.1).to(device)
x = [x.half() for x in [q, k, v]]
y = attn(*x, q_mask, kv_mask)
/tmp/ipykernel_467687/3943656874.py:12: UserWarning: Memory efficient kernel not used because: (Triggered internally at /opt/conda/conda-bld/pytorch_1702400410390/work/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:367.)
  y = torch.nn.functional.scaled_dot_product_attention(
/tmp/ipykernel_467687/3943656874.py:12: UserWarning: Memory Efficient attention has been runtime disabled. (Triggered internally at /opt/conda/conda-bld/pytorch_1702400410390/work/aten/src/ATen/native/transformers/sdp_utils_cpp.h:437.)
  y = torch.nn.functional.scaled_dot_product_attention(
/tmp/ipykernel_467687/3943656874.py:12: UserWarning: Flash attention kernel not used because: (Triggered internally at /opt/conda/conda-bld/pytorch_1702400410390/work/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:369.)
  y = torch.nn.functional.scaled_dot_product_attention(
/tmp/ipykernel_467687/3943656874.py:12: UserWarning: Both fused kernels do not support non-null attn_mask. (Triggered internally at /opt/conda/conda-bld/pytorch_1702400410390/work/aten/src/ATen/native/transformers/sdp_utils_cpp.h:261.)
  y = torch.nn.functional.scaled_dot_product_attention(
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[34], line 12
     10 kv_mask = (torch.rand(B, S) > 0.1).to(device)
     11 x = [x.half() for x in [q, k, v]]
---> 12 y = attn(*x, q_mask, kv_mask)

File ~/miniconda3/envs/torch212/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/torch212/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

Cell In[32], line 12, in TorchNativeAttention.forward(self, q, k, v, q_mask, kv_mask)
     10     attn_mask = None
     11 with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
---> 12     y = torch.nn.functional.scaled_dot_product_attention(
     13         q, k, v, attn_mask=attn_mask, dropout_p=self.attn_dropout, is_causal=False
     14     )
     16 return y if attn_mask is None else y.nan_to_num()

RuntimeError: No available kernel.  Aborting execution.

while forcing the memory-efficient kernel does not:

with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True):
    y = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=attn_mask, dropout_p=self.attn_dropout, is_causal=False
    )
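To check that a masked call computes ordinary masked softmax attention whichever backend ends up running it, one can compare it against an explicit reference on CPU in float32. A minimal sketch: `reference_attention` is a hypothetical helper, and the mask is built as in `forward` above (for boolean tensors, `&` is equivalent to the `*` used there). True in the boolean mask means "may attend".

```python
import math

import torch
import torch.nn.functional as F


def reference_attention(q, k, v, attn_mask):
    """Explicit masked softmax attention; True in attn_mask means 'may attend'."""
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    scores = scores.masked_fill(~attn_mask, float("-inf"))
    # A query whose keys are all masked gets a row of -inf, i.e. NaN after softmax.
    return (scores.softmax(dim=-1) @ v).nan_to_num()


torch.manual_seed(0)
B, H, L, S, D = 2, 4, 6, 5, 8
q, k, v = (torch.randn(B, H, n, D) for n in (L, S, S))
q_mask = torch.rand(B, L) > 0.3
kv_mask = torch.rand(B, S) > 0.3
# Same construction as in the issue: shape (B, 1, L, S), broadcast over heads.
attn_mask = q_mask[:, None, :, None] & kv_mask[:, None, None, :]

fused = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask).nan_to_num()
ref = reference_attention(q, k, v, attn_mask)
print(torch.allclose(fused, ref, atol=1e-5))
```

Depending on the PyTorch version, fully masked rows come back from SDPA either as NaN or as zeros; applying `nan_to_num()` to both sides makes the comparison independent of that.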
Phil26AT commented 9 months ago

Hey @HJoonKwon! Damn, very good find, thank you! I guess this matters in the compiled forward pass, where we pad inputs to static dimensions. We'd need to run the benchmarks, but maybe avoiding the call to half() could improve throughput then.
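Why padding forces the masked path can be illustrated with a minimal sketch, assuming inputs are zero-padded along the keypoint dimension to hypothetical static sizes `L_MAX` and `S_MAX` (LightGlue's actual padding code may differ). Once the inputs are padded, a mask is required to keep padded keys out of the softmax, which is exactly the case the flash kernel rejects. `pad_to` is a hypothetical helper.

```python
import torch
import torch.nn.functional as F


def pad_to(x, length):
    """Zero-pad the sequence dimension (dim -2) of a (B, H, N, D) tensor to `length`."""
    return F.pad(x, (0, 0, 0, length - x.shape[-2]))


torch.manual_seed(0)
B, H, D = 1, 4, 8
L, S = 5, 7             # actual numbers of keypoints in each image
L_MAX, S_MAX = 16, 16   # hypothetical static sizes used for compilation
q = torch.randn(B, H, L, D)
k = torch.randn(B, H, S, D)
v = torch.randn(B, H, S, D)

# Unpadded: no mask needed, so the flash kernel would be eligible on GPU.
out = F.scaled_dot_product_attention(q, k, v)

# Padded to static shapes: padded positions must be masked out.
qp, kp, vp = pad_to(q, L_MAX), pad_to(k, S_MAX), pad_to(v, S_MAX)
q_mask = torch.arange(L_MAX)[None] < L    # (B, L_MAX), True for real keypoints
kv_mask = torch.arange(S_MAX)[None] < S   # (B, S_MAX)
attn_mask = q_mask[:, None, :, None] & kv_mask[:, None, None, :]
out_padded = F.scaled_dot_product_attention(qp, kp, vp, attn_mask=attn_mask).nan_to_num()

# Real queries get the same result as before; padded queries are zeroed.
print(torch.allclose(out_padded[:, :, :L], out, atol=1e-5))
```

Masking the padded query rows is not strictly needed, since their outputs are discarded anyway; the sketch does it only to mirror the mask construction in the issue.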

HJoonKwon commented 8 months ago

@Phil26AT Great! Thank you again for your great work. It inspired me a lot.

LudvigDillen commented 8 months ago

On the topic of FlashAttention: you link to FlashAttention and not FlashAttention2 here. Isn't the second version used? If not, why? It seems considerably faster.

FlashAttention: https://arxiv.org/abs/2205.14135
FlashAttention2: https://arxiv.org/pdf/2307.08691.pdf?trk=public_post_comment-text
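For the PyTorch SDPA path specifically, the answer depends on the installed version: according to the PyTorch 2.2 release notes, the flash backend of `scaled_dot_product_attention` was upgraded to FlashAttention-2, so earlier 2.x releases use the first version. The environment in the traceback above is named `torch212`, which suggests PyTorch 2.1.2. A minimal sketch with a hypothetical helper, `flash_backend_generation`; it says nothing about whether LightGlue routes through SDPA or another FlashAttention implementation.

```python
def flash_backend_generation(torch_version: str) -> int:
    """Hypothetical helper: FlashAttention generation behind SDPA's flash backend.

    Assumes 2.0 and 2.1 use FlashAttention (v1), 2.2+ uses FlashAttention-2,
    and releases before 2.0 have no scaled_dot_product_attention at all (returns 0).
    """
    base = torch_version.split("+")[0]  # drop a local tag such as "+cu121"
    major, minor = (int(part) for part in base.split(".")[:2])
    if (major, minor) < (2, 0):
        return 0
    return 2 if (major, minor) >= (2, 2) else 1
```

Usage: `flash_backend_generation(torch.__version__)` after `import torch`.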