Project-MONAI / MONAI

AI Toolkit for Healthcare Imaging
https://monai.io/
Apache License 2.0

Enable relative positional embedding in flash attention #7997

Open · KumoLiu opened 2 months ago

KumoLiu commented 2 months ago

From reading this thread: https://github.com/pytorch/pytorch/issues/96099#issuecomment-1480430583, it seems to me that the relative positional embedding can be integrated through scaled_dot_product_attention's attn_mask argument. However, it can be slow, as that path does not take the flash-attention "fast path".
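For illustration only (not code from the thread or from MONAI), a minimal sketch of that idea: a learned relative position bias table is expanded to a per-head (L, L) bias and passed as a float attn_mask, which PyTorch adds to the pre-softmax attention scores. The names `rel_bias_table` and `rel_pos_bias` are hypothetical.

```python
import torch
import torch.nn.functional as F

B, H, L, D = 2, 4, 128, 64
q = torch.randn(B, H, L, D)
k = torch.randn(B, H, L, D)
v = torch.randn(B, H, L, D)

# Hypothetical learned bias table: one value per head and relative offset in [-(L-1), L-1].
rel_bias_table = torch.nn.Parameter(torch.zeros(H, 2 * L - 1))

# Map each (query, key) pair to its relative-offset index, then gather the bias.
rel_index = torch.arange(L)[:, None] - torch.arange(L)[None, :] + L - 1  # (L, L)
rel_pos_bias = rel_bias_table[:, rel_index]  # (H, L, L), broadcast over batch

# A float attn_mask is added to the attention scores before softmax,
# but passing it routes SDPA away from the flash-attention fast path.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=rel_pos_bias)
```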

Do you think we can keep this option open for users who want to use flash_attention and rel_pos_embedding?

_Originally posted by @mingxin-zheng in https://github.com/Project-MONAI/MONAI/pull/7977#discussion_r1701825032_

vadimkantorov commented 2 months ago

I would think that https://github.com/Dao-AILab/flash-attention/pull/617 needs to be completed to get FlashAttention v2 support for an arbitrary attention bias. Then, depending on the relative-encoding formula that is actually needed, maybe https://github.com/Dao-AILab/flash-attention/pull/956 could be pushed forward as well.

Another way forward is to try PyTorch's flex_attention, which can fuse the modification of the attention matrix into the attention kernel.
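As a rough sketch of that route (illustrative only, assuming a recent PyTorch release where flex_attention is available under torch.nn.attention.flex_attention): the relative position bias is added inside a score_mod callback, so the extra term is fused into the attention kernel instead of materializing a full (L, L) mask. The names `rel_bias` and `rel_pos_score_mod` are hypothetical.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

B, H, L, D = 2, 4, 128, 64
q = torch.randn(B, H, L, D, device="cuda")
k = torch.randn(B, H, L, D, device="cuda")
v = torch.randn(B, H, L, D, device="cuda")

# Hypothetical per-head bias table indexed by relative offset (2L - 1 entries).
rel_bias = torch.randn(H, 2 * L - 1, device="cuda")

def rel_pos_score_mod(score, b, h, q_idx, kv_idx):
    # Shift the relative distance q_idx - kv_idx into [0, 2L - 2] and add the bias.
    return score + rel_bias[h, q_idx - kv_idx + L - 1]

out = flex_attention(q, k, v, score_mod=rel_pos_score_mod)
```

In practice flex_attention is usually wrapped in torch.compile so the score_mod is compiled into a single fused kernel rather than run eagerly.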