ZKI-PH-ImageAnalysis / seq2squiggle

End-to-end simulation of nanopore sequencing signals with feed-forward transformers

Use FlashAttention #2

Closed denisbeslic closed 2 months ago

denisbeslic commented 3 months ago

Should improve runtime.

See https://benjaminwarner.dev/2023/08/16/flash-attention-compile for background on FlashAttention and torch.compile.
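For context, a minimal sketch of forcing the flash kernel through PyTorch's built-in scaled_dot_product_attention; the shapes, the fp16/CUDA requirement, and the torch.backends.cuda.sdp_kernel context manager are assumptions about recent PyTorch 2.x, not anything taken from seq2squiggle:

import torch
import torch.nn.functional as F

# Illustrative shapes: (batch, heads, seq_len, head_dim). The flash kernel is
# only eligible for fp16/bf16 CUDA tensors with small head dimensions.
q = torch.randn(2, 4, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 4, 128, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 4, 128, 64, device="cuda", dtype=torch.float16)

# Restrict SDPA to FlashAttention so it raises an error instead of silently
# falling back to the math or memory-efficient backends.
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    out = F.scaled_dot_product_attention(q, k, v)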

denisbeslic commented 2 months ago

Tested with the built-in PyTorch implementation (torch.nn.functional.scaled_dot_product_attention):

import torch
import torch.nn as nn
import torch.nn.functional as F


class ScaledDotProductAttention(nn.Module):
    """Scaled Dot-Product Attention delegating to torch's fused SDPA kernel."""

    def __init__(self, temperature):
        super().__init__()
        # Kept for interface compatibility with the previous manual implementation;
        # SDPA applies its own 1/sqrt(d_k) scaling and softmax internally.
        self.temperature = temperature
        self.softmax = nn.Softmax(dim=2)
        # Allow SDPA to pick the FlashAttention kernel on CUDA.
        torch.backends.cuda.enable_flash_sdp(True)

    def forward(self, q, k, v, mask=None):
        # NOTE: mask is not forwarded in this test; supplying an attn_mask would
        # make SDPA fall back to a non-flash backend anyway.
        attn_output = F.scaled_dot_product_attention(q, k, v)
        # The fused kernel does not expose attention weights, so None is returned.
        return attn_output, None
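A quick smoke test of the module above; tensor shapes are illustrative assumptions, not the shapes used in seq2squiggle (on CPU, SDPA simply falls back to the math backend):

import torch

attn = ScaledDotProductAttention(temperature=32 ** 0.5)
q = torch.randn(2, 16, 32)  # (batch, seq_len, head_dim)
k = torch.randn(2, 16, 32)
v = torch.randn(2, 16, 32)
out, weights = attn(q, k, v)
print(out.shape)  # torch.Size([2, 16, 32])
print(weights)    # None -- the fused kernel does not return attention weights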

No large difference in performance, possibly because of the very short sequence lengths (https://github.com/Dao-AILab/flash-attention/issues/403#issuecomment-1658844143)?
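If this gets revisited, one way to test the sequence-length explanation would be a per-backend timing sweep. The sketch below is illustrative only (a CUDA device, fp16 inputs, and the sdp_kernel context manager are assumed; nothing here comes from the seq2squiggle code):

import time

import torch
import torch.nn.functional as F


def time_sdpa(seq_len, flash, iters=50):
    """Average SDPA latency for one backend at a given sequence length."""
    q = k = v = torch.randn(2, 4, seq_len, 64, device="cuda", dtype=torch.float16)
    with torch.backends.cuda.sdp_kernel(
        enable_flash=flash, enable_math=not flash, enable_mem_efficient=False
    ):
        F.scaled_dot_product_attention(q, k, v)  # warm-up
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(iters):
            F.scaled_dot_product_attention(q, k, v)
        torch.cuda.synchronize()
        return (time.perf_counter() - start) / iters


for seq_len in (64, 512, 4096):
    flash_ms = time_sdpa(seq_len, flash=True) * 1e3
    math_ms = time_sdpa(seq_len, flash=False) * 1e3
    print(f"seq_len={seq_len}: flash {flash_ms:.3f} ms vs math {math_ms:.3f} ms")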

Leave this for now.