pytorch-labs / attention-gym

Helpful tools and examples for working with flex-attention
BSD 3-Clause "New" or "Revised" License

Rope2d #75

Open bhack opened 1 week ago

bhack commented 1 week ago

Can you add an example of Rope2d as in Meta's SAM2? https://github.com/facebookresearch/sam2/blob/main/sam2/modeling/sam/transformer.py#L289

drisspg commented 3 days ago

Just to confirm: do you mean an example where RoPE is fused into FlexAttention, as opposed to how it's done in SAM2, where q and k are rotated first and then run through Flash?

bhack commented 3 days ago

Yes. Does this fit in the Flex API or not?

drisspg commented 1 day ago

This currently does not fit within the Flex API, since RoPE is typically implemented by pre-mutating Q and K, and we don't provide any way to mutate Q and K before the dot-product operation.
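
For reference, a minimal sketch of the pre-mutation approach that does work today: build 2D (axial) rotary embeddings in the style of SAM2, rotate q and k, then hand the already-rotated tensors to flex_attention. The helpers `build_axial_cis` and `apply_rotary` are hypothetical illustrations, not attention-gym or SAM2 APIs.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention


def build_axial_cis(head_dim: int, grid_h: int, grid_w: int, theta: float = 10000.0):
    # Axial 2D RoPE: half of the rotated channels track the y coordinate,
    # half track x. Returns complex rotations of shape (H*W, head_dim // 2).
    n = head_dim // 4
    freqs = 1.0 / (theta ** (torch.arange(n).float() / n))
    freqs_y = torch.outer(torch.arange(grid_h).float(), freqs)  # (H, n)
    freqs_x = torch.outer(torch.arange(grid_w).float(), freqs)  # (W, n)
    angles = torch.cat(
        [
            freqs_y[:, None, :].expand(grid_h, grid_w, -1),
            freqs_x[None, :, :].expand(grid_h, grid_w, -1),
        ],
        dim=-1,
    ).reshape(grid_h * grid_w, -1)
    return torch.polar(torch.ones_like(angles), angles)


def apply_rotary(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # x: (B, n_heads, S, head_dim); rotate channel pairs as complex numbers.
    x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    x_ = x_ * freqs_cis  # broadcasts over batch and head dims
    return torch.view_as_real(x_).flatten(-2).type_as(x)


B, H, Hg, Wg, D = 2, 4, 8, 8, 64
S = Hg * Wg
q, k, v = (torch.randn(B, H, S, D) for _ in range(3))

freqs_cis = build_axial_cis(D, Hg, Wg)
q, k = apply_rotary(q, freqs_cis), apply_rotary(k, freqs_cis)

# RoPE is already baked into q/k, so no score_mod is needed for it here.
out = flex_attention(q, k, v)
```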

bhack commented 1 day ago

Is it in the roadmap?

drisspg commented 1 day ago

Not currently. From what I know, fusion ends up not being beneficial in training, though it can be beneficial for memory-bound cases in decoding.

I will leave this open, though. We have a few other things that are higher priority, like the learnable biases I am working on, but I will think about how this could be supported.
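
For context, a learnable bias in FlexAttention would go through `score_mod`, which receives `(score, b, h, q_idx, kv_idx)` and returns a modified score. A hedged sketch of what a learnable relative-position bias might look like, assuming gradients flow through captured tensors (`bias_table` is a hypothetical parameter, not an existing API):

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

S = 64
# One learnable bias per relative offset in [-(S-1), S-1].
bias_table = torch.zeros(2 * S - 1, requires_grad=True)


def rel_bias(score, b, h, q_idx, kv_idx):
    # Shift the offset into the valid index range of bias_table.
    return score + bias_table[q_idx - kv_idx + S - 1]


q, k, v = (torch.randn(2, 4, S, 16) for _ in range(3))
out = flex_attention(q, k, v, score_mod=rel_bias)
```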

bhack commented 1 day ago

Do you have an alternative SOTA 2D learnable bias on the roadmap?