pytorch-labs / attention-gym

Helpful tools and examples for working with flex-attention
BSD 3-Clause "New" or "Revised" License

How to implement Bidirectional Alibi with padding using flex attention? #74

Open sphmel opened 10 hours ago

sphmel commented 10 hours ago

Hi, I want to use FlexAttention for ALiBi with padding (no bias on padded positions).

For example, if seq_len is 5 and the last position is padding, I want an ALiBi tensor like the one below: the usual bias for the real tokens, with the last (padded) row and column not penalized.

0 -1 -2 -3 0
-1 0 -1 -2 0
-2 -1 0 -1 0
-3 -2 -1 0 0
0 0 0 0 0

How can I implement a score_mod like this? seq_len can be different on every forward pass. This kind of ALiBi is used in the Voicebox paper. I'm new to BatchedTensor (or maybe the vmap API?) and don't know how to approach this at all. Can you help?
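A minimal sketch of one way to write this score_mod, assuming PyTorch 2.5 or later (where `flex_attention` lives in `torch.nn.attention.flex_attention`) and assuming the number of real tokens per batch element is available as a tensor, here called `seq_lens`. The helper names and the `slopes` tensor are hypothetical, chosen for illustration. FlexAttention calls `score_mod(score, b, h, q_idx, kv_idx)` with scalar index tensors and vectorizes it internally, so no hand-written vmap or BatchedTensor code is needed; the lengths are simply looked up by batch index `b`, so each forward pass can use different lengths. With one head and a slope of 1.0 this reproduces the matrix above.

```python
import torch


def make_padded_alibi_score_mod(seq_lens: torch.Tensor, slopes: torch.Tensor):
    """Hypothetical helper (not part of attention-gym): bidirectional ALiBi with padding.

    seq_lens: [B] number of real (unpadded) tokens in each batch element.
    slopes:   [H] ALiBi slope per head.
    Any query or key position at or beyond seq_lens[b] gets no bias.
    """

    def score_mod(score, b, h, q_idx, kv_idx):
        # Bidirectional ALiBi: the penalty grows with |distance| in both directions.
        bias = -slopes[h] * (q_idx - kv_idx).abs()
        # Apply the bias only when both the query and the key are real tokens.
        length = seq_lens[b]
        is_real = (q_idx < length) & (kv_idx < length)
        return score + torch.where(is_real, bias, 0.0)

    return score_mod


def padded_alibi_attention(q, k, v, seq_lens, slopes):
    """Hypothetical wrapper. q, k, v: [B, H, S, D]. Requires PyTorch >= 2.5."""
    # Imported here so the score_mod above can be used on older PyTorch too.
    from torch.nn.attention.flex_attention import flex_attention

    score_mod = make_padded_alibi_score_mod(seq_lens, slopes)
    return flex_attention(q, k, v, score_mod=score_mod)
```

Note that padded keys still receive attention weight here (just with zero bias), which is what the matrix above asks for. If padded keys should be ignored entirely, that is usually expressed as a `mask_mod` passed through `create_block_mask` rather than in the score_mod.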

sphmel commented 10 hours ago

`-abs(q_idx - kv_idx)` gives the tensor below, but I want the last row and column to be unbiased:

0 -1 -2 -3 -4
-1 0 -1 -2 -3
-2 -1 0 -1 -2
-3 -2 -1 0 -1
-4 -3 -2 -1 0
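The tensor above is the unmasked bidirectional bias, `-abs(q_idx - kv_idx)`; zeroing it wherever the query or key index is past the real length gives the padded version. As a dense cross-check, here is a minimal sketch that materializes that bias as a `[B, H, S, S]` tensor and passes it to `torch.nn.functional.scaled_dot_product_attention`, which adds a float `attn_mask` to the attention scores. The helper names, `seq_lens`, and `slopes` are hypothetical; storing the full matrix is only practical for short sequences, so this is meant for testing a score_mod, not as a replacement for it.

```python
import torch
import torch.nn.functional as F


def padded_alibi_bias(seq_lens: torch.Tensor, slopes: torch.Tensor, total_len: int) -> torch.Tensor:
    """Hypothetical helper: dense bidirectional ALiBi bias, zero where query or key is padding.

    seq_lens: [B] number of real tokens per batch element.
    slopes:   [H] ALiBi slope per head.
    Returns:  [B, H, S, S] with S = total_len.
    """
    idx = torch.arange(total_len, device=seq_lens.device)
    dist = (idx[:, None] - idx[None, :]).abs()           # [S, S], |q_idx - kv_idx|
    bias = -slopes[:, None, None] * dist                  # [H, S, S]
    real = idx[None, :] < seq_lens[:, None]               # [B, S], True for real tokens
    pair_real = real[:, :, None] & real[:, None, :]       # [B, S, S], query and key both real
    return torch.where(pair_real[:, None], bias[None], 0.0)


def dense_padded_alibi_attention(q, k, v, seq_lens, slopes):
    """Hypothetical reference. q, k, v: [B, H, S, D]; the float mask is added to the scores."""
    bias = padded_alibi_bias(seq_lens, slopes, q.shape[-2]).to(q.dtype)
    return F.scaled_dot_product_attention(q, k, v, attn_mask=bias)
```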