Closed by denisbeslic 2 months ago
Tested with the PyTorch implementation:
```python
import torch
import torch.nn as nn


class ScaledDotProductAttention(nn.Module):
    """Scaled Dot-Product Attention"""

    def __init__(self, temperature):
        super().__init__()
        self.temperature = temperature  # unused here; SDPA defaults to 1/sqrt(d_k) scaling
        self.softmax = nn.Softmax(dim=2)  # unused: the fused kernel applies softmax internally
        # Allow the FlashAttention backend for scaled_dot_product_attention.
        torch.backends.cuda.enable_flash_sdp(True)

    def forward(self, q, k, v, mask=None):
        # Fused attention: scaling, softmax, and the value product run inside one
        # kernel; attention weights are not returned, hence None.
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=mask
        )
        return attn_output, None
```
No large difference in performance; possibly due to the very short sequence length (https://github.com/Dao-AILab/flash-attention/issues/403#issuecomment-1658844143)?
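For reference, a rough timing sketch to test the short-sequence hypothesis (shapes, dtype, and iteration count are assumptions, not the original benchmark setup), restricting SDPA to the flash backend so a silent fallback to the math kernel doesn't hide the difference:

```python
import time
import torch

# Assumed shapes: batch=8, heads=16, head_dim=64, fp16 on CUDA.
attn = ScaledDotProductAttention(temperature=64 ** 0.5).cuda()

for seq_len in (64, 2048):
    q = torch.randn(8, 16, seq_len, 64, device="cuda", dtype=torch.float16)
    # Force the flash backend so SDPA cannot silently fall back to the math
    # kernel (PyTorch 2.x; newer releases prefer torch.nn.attention.sdpa_kernel).
    with torch.backends.cuda.sdp_kernel(
        enable_flash=True, enable_math=False, enable_mem_efficient=False
    ):
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(100):
            out, _ = attn(q, q, q)
        torch.cuda.synchronize()
        elapsed = time.perf_counter() - start
    print(f"seq_len={seq_len}: {elapsed:.4f} s")
```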
Leave this for now.
It should improve runtime; see https://benjaminwarner.dev/2023/08/16/flash-attention-compile
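If this gets picked up later, a minimal sketch of the torch.compile route from that post (wrapping only the attention module here is an assumption; compiling the full model works the same way):

```python
import torch

# Sketch: wrap the module (or the whole model) with torch.compile, as the
# linked post suggests for additional runtime gains (PyTorch >= 2.0).
attn = ScaledDotProductAttention(temperature=64 ** 0.5).cuda()
compiled_attn = torch.compile(attn)

q = torch.randn(8, 16, 64, 64, device="cuda", dtype=torch.float16)
out, _ = compiled_attn(q, q, q)  # first call triggers compilation
```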