pytorch-labs / attention-gym

Helpful tools and examples for working with flex-attention

Clarification on torch.compile behavior with flex_attention #35

Closed kebijuelun closed 2 months ago

kebijuelun commented 2 months ago

Hi,

I have a question regarding the use of torch.compile with the flex_attention function. I would like to confirm whether the following two approaches are equivalent in terms of functionality and performance:

1.
```python
from torch.nn.attention.flex_attention import flex_attention

flex_attention = torch.compile(flex_attention, dynamic=False)

o, m = flex_attention(q, k, v, block_mask=block_mask, return_lse=return_max_logits)
```


2.
```python
from torch.nn.attention.flex_attention import flex_attention
from functools import partial

opt_flex_attention = torch.compile(partial(flex_attention, block_mask=block_mask))

o, m = opt_flex_attention(q, k, v, return_lse=return_max_logits)
```

In the second approach, I use functools.partial to predefine the block_mask. Are there any significant differences between these two methods in how torch.compile optimizes the function, or any performance or behavioral differences I should be aware of?

Thank you for your assistance!

Chillee commented 2 months ago

@kebijuelun Nope! Both should work fine.
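
For reference, a minimal sketch comparing the two call patterns side by side; the shapes, dtype, causal mask_mod, and CUDA device below are illustrative assumptions, not values taken from this thread:

```python
import torch
from functools import partial
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

# Illustrative shapes/dtype/device; these are assumptions, not from the issue.
B, H, S, D = 2, 4, 256, 64
q, k, v = (torch.randn(B, H, S, D, device="cuda", dtype=torch.float16) for _ in range(3))

def causal(b, h, q_idx, kv_idx):
    # Standard causal mask: a query attends only to keys at or before its position.
    return q_idx >= kv_idx

block_mask = create_block_mask(causal, None, None, S, S, device="cuda")

# Approach 1: compile flex_attention itself and pass block_mask at call time.
compiled = torch.compile(flex_attention, dynamic=False)
out1, lse1 = compiled(q, k, v, block_mask=block_mask, return_lse=True)

# Approach 2: bind block_mask with functools.partial, then compile the partial.
compiled_partial = torch.compile(partial(flex_attention, block_mask=block_mask))
out2, lse2 = compiled_partial(q, k, v, return_lse=True)

# Both compiled callables should produce matching outputs and log-sum-exp values.
torch.testing.assert_close(out1, out2)
torch.testing.assert_close(lse1, lse2)
```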