KumoLiu opened this issue 2 months ago
I would think that https://github.com/Dao-AILab/flash-attention/pull/617 needs to be completed before FAv2 can support an arbitrary attention bias. Then, depending on the relative encoding formula actually needed, maybe https://github.com/Dao-AILab/flash-attention/pull/956 could be pushed forward as well.
Another way forward is to try PyTorch's flex_attention, which can fuse modifications of the attention score matrix into the attention kernel.
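A minimal sketch of that idea, assuming PyTorch >= 2.5 and an illustrative per-head relative bias table (`rel_bias`, `rel_pos_score_mod`, and the shapes below are hypothetical, not MONAI code):

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

B, H, L, D = 2, 4, 128, 64
q = torch.randn(B, H, L, D, device="cuda", dtype=torch.float16)
k = torch.randn(B, H, L, D, device="cuda", dtype=torch.float16)
v = torch.randn(B, H, L, D, device="cuda", dtype=torch.float16)

# Hypothetical learnable per-head bias table indexed by relative offset (q_idx - kv_idx).
rel_bias = torch.randn(H, 2 * L - 1, device="cuda", dtype=torch.float16)

def rel_pos_score_mod(score, b, h, q_idx, kv_idx):
    # Add the relative positional bias for this (query, key) offset to the raw score.
    return score + rel_bias[h, q_idx - kv_idx + L - 1]

# torch.compile is what lets flex_attention generate a fused kernel for score_mod.
compiled_flex_attention = torch.compile(flex_attention)
out = compiled_flex_attention(q, k, v, score_mod=rel_pos_score_mod)
```

The appeal here is that the bias is applied inside the fused kernel, so we avoid both materializing an (L, L) bias tensor and falling off the fast path.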
From reading this thread: https://github.com/pytorch/pytorch/issues/96099#issuecomment-1480430583, it seems to me that the relative positional embedding can be passed through scaled_dot_product_attention's attn_mask argument. However, it can be slow because supplying attn_mask prevents the flash-attention "fast path" from being used. Do you think we can keep this option open for users who want to use flash_attention together with rel_pos_embedding?
_Originally posted by @mingxin-zheng in https://github.com/Project-MONAI/MONAI/pull/7977#discussion_r1701825032_