I have a question regarding the use of `torch.compile` with the `flex_attention` function. I would like to confirm whether the following two approaches are equivalent in terms of functionality and performance:
1.
```python
from torch.nn.attention.flex_attention import flex_attention

flex_attention = torch.compile(flex_attention, dynamic=False)
o, m = flex_attention(q, k, v, block_mask=block_mask, return_lse=return_max_logits)
```
2.
```python
from torch.nn.attention.flex_attention import flex_attention
from functools import partial

opt_flex_attention = torch.compile(partial(flex_attention, block_mask=block_mask))
o, m = opt_flex_attention(q, k, v, return_lse=return_max_logits)
```
In the second approach, I am using `partial` to predefine the `block_mask`. I would like to know if there are any significant differences between these two methods regarding how `torch.compile` optimizes the function, and if there is any potential performance impact or behavioral difference I should be aware of.
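For context, here is a minimal, self-contained sketch of how I am comparing the two variants. The tensor shapes, dtype, and the causal `mask_mod` below are placeholders rather than my actual setup:

```python
import torch
from functools import partial
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

def causal(b, h, q_idx, kv_idx):
    # Standard causal mask: each query position attends to itself and earlier positions.
    return q_idx >= kv_idx

# Placeholder shapes and dtype (not my real workload).
B, H, S, D = 2, 8, 1024, 64
q, k, v = (torch.randn(B, H, S, D, device="cuda", dtype=torch.float16) for _ in range(3))
block_mask = create_block_mask(causal, B, H, S, S, device="cuda")

# Approach 1: compile flex_attention directly and pass block_mask at call time.
compiled_fa = torch.compile(flex_attention, dynamic=False)
o1, lse1 = compiled_fa(q, k, v, block_mask=block_mask, return_lse=True)

# Approach 2: bake block_mask in via functools.partial, then compile the partial.
opt_fa = torch.compile(partial(flex_attention, block_mask=block_mask))
o2, lse2 = opt_fa(q, k, v, return_lse=True)

# Sanity check that both variants produce matching outputs.
print(torch.allclose(o1, o2, atol=1e-3, rtol=1e-3))
print(torch.allclose(lse1, lse2, atol=1e-3, rtol=1e-3))
```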
Thank you for your assistance!