bhack opened this issue 1 week ago
Just to confirm, do you mean an example where RoPE is fused into FlashAttention, as opposed to how it's done in SAM2, where RoPE is applied to q and k beforehand and they are then run through Flash?
Yes. Does this fit into the Flex API or not?
This currently does not fit within the Flex API, since RoPE is typically implemented by pre-mutating Q and K, and we don't provide any way to mutate Q and K before the dot-product operation.
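For illustration, here is a minimal sketch (my own, not SAM2's exact code) of the path that does work today: a 2D axial RoPE applied to q and k up front, with the rotated tensors then passed to `flex_attention`. The helper name `axial_rope_2d` and the shapes are assumptions for the example.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

def axial_rope_2d(x, h, w, theta=10000.0):
    # x: (B, H, h*w, D). Axial 2D RoPE: the first half of the channel pairs
    # is rotated by the x-coordinate, the second half by the y-coordinate.
    B, H, N, D = x.shape
    assert N == h * w and D % 4 == 0
    freqs = 1.0 / (theta ** (torch.arange(0, D // 2, 2, device=x.device).float() / (D // 2)))
    ys, xs = torch.meshgrid(
        torch.arange(h, device=x.device).float(),
        torch.arange(w, device=x.device).float(),
        indexing="ij",
    )
    ang = torch.cat([xs.flatten()[:, None] * freqs,
                     ys.flatten()[:, None] * freqs], dim=-1)   # (N, D/2) rotation angles
    xc = torch.view_as_complex(x.float().reshape(B, H, N, D // 2, 2))
    rot = torch.polar(torch.ones_like(ang), ang)               # unit complex rotations
    return torch.view_as_real(xc * rot).reshape(B, H, N, D).type_as(x)

B, H, h, w, D = 1, 2, 8, 8, 64
q, k, v = (torch.randn(B, H, h * w, D) for _ in range(3))
q, k = axial_rope_2d(q, h, w), axial_rope_2d(k, h, w)  # pre-mutate Q and K
out = flex_attention(q, k, v)                          # RoPE is not fused here
```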
Is it on the roadmap?
Not currently. From what I know, fusing RoPE ends up not being beneficial in training, but it can be beneficial for memory-bound cases in decoding.
I will leave this open, though. We have a few other things that are higher priority, like the learnable biases I am working on, but I will think about how this can be supported.
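To make the bias point concrete, here is a minimal sketch (my own assumption about usage, not a confirmed feature of the learnable-bias work) of adding a relative-position bias through `flex_attention`'s `score_mod`; making the captured table a trainable `nn.Parameter`, with gradients flowing into it, is the in-progress support mentioned above.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

B, H, S, D = 1, 2, 64, 64
q, k, v = (torch.randn(B, H, S, D) for _ in range(3))

# Relative-position bias table captured by the score_mod closure.
# Turning this into a trainable nn.Parameter is the learnable-bias
# support described as in progress above.
rel_bias_table = torch.zeros(2 * S - 1)

def rel_bias(score, b, h, q_idx, kv_idx):
    # Add a bias looked up by relative distance to the raw attention score.
    return score + rel_bias_table[q_idx - kv_idx + S - 1]

out = flex_attention(q, k, v, score_mod=rel_bias)
```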
Do you have some alternative SOTA 2D learnable bias on the roadmap?
Can you add an example of RoPE 2D as in Meta's SAM2? https://github.com/facebookresearch/sam2/blob/main/sam2/modeling/sam/transformer.py#L289