Open · grazder opened this issue 1 year ago
Surely the graph break around memory_efficient_attention is a good thing? The xformers implementation of memory_efficient_attention is better than what torch.compile can currently come up with on its own, so the user should want torch.compile to optimize only the parts of the logic that sit outside memory_efficient_attention.
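A rough way to see this in practice (a sketch only, assuming xformers is installed and a CUDA device is available; note that the torch._dynamo.explain calling convention has shifted between PyTorch 2.x releases, the form below is roughly the 2.1+ one):

```python
import torch
import torch._dynamo as dynamo
import xformers.ops as xops

def block(q, k, v, w):
    # Work surrounding the attention call that torch.compile can still capture
    # and fuse, even if memory_efficient_attention itself falls out of the graph.
    q = torch.nn.functional.gelu(q @ w)
    out = xops.memory_efficient_attention(q, k, v)
    return out * 2.0 + 1.0

q = k = v = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
w = torch.randn(64, 64, device="cuda", dtype=torch.float16)

explanation = dynamo.explain(block)(q, k, v, w)
print(explanation.graph_break_count)  # how many breaks were hit
print(explanation.break_reasons)      # and why
```

Even when a break is reported around the attention call, the matmul/gelu prologue and the elementwise epilogue still land in compiled subgraphs of their own, which is the point above.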
Is there also a graph break around standard PyTorch's F.scaled_dot_product_attention?
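One way to check that directly (just a sketch, assuming PyTorch >= 2.0 and a CUDA device): fullgraph=True makes torch.compile raise as soon as any graph break occurs, so the call below either surfaces a break as an error or completes silently.

```python
import torch
import torch.nn.functional as F

@torch.compile(fullgraph=True)
def sdpa(q, k, v):
    return F.scaled_dot_product_attention(q, k, v)

q = k = v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
sdpa(q, k, v)  # raises if the SDPA call triggers a graph break
```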
If torch.compile is able to preserve the calls to xFormers' memory_efficient and flash_attention implementations, it would be less confusing for end users if these (and xFormers' optimized unbind) also did not trigger graph breaks, since for now the rule of thumb seems to be that fewer graph breaks means better performance. Alternatively, it should at least be documented that these particular graph breaks are okay. At the moment it looks like they break only because these ops use custom autograd functions, and it is not clear whether all custom autograd functions cause graph breaks or whether there are special requirements/conditions (https://github.com/pytorch/pytorch/issues/103318).
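A minimal probe for that last question might look like the following (a sketch only; whether it compiles cleanly or reports a break depends on the PyTorch version and on what the Function does internally):

```python
import torch

class MyScale(torch.autograd.Function):
    """Deliberately trivial custom autograd function."""

    @staticmethod
    def forward(ctx, x):
        return x * 2.0

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out * 2.0

@torch.compile(fullgraph=True)  # fullgraph=True turns any graph break into an error
def f(x):
    return MyScale.apply(x).sum()

x = torch.randn(4, requires_grad=True)
f(x).backward()
```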
🐛 Bug
I am trying to use memory_efficient_attention with torch.compile(). But it seems that memory_efficient_attention leads to graph breaks. xformers.ops.unbind also causes graph breaks.

Command
To Reproduce
Model
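The original snippet is not preserved in this copy of the issue. As a stand-in, a minimal module exercising both memory_efficient_attention and xformers.ops.unbind (all names, shapes, and hyperparameters below are placeholders, not the actual model from the report) could look like:

```python
import torch
import xformers.ops as xops

class TinyAttention(torch.nn.Module):
    def __init__(self, dim: int = 64, heads: int = 8):
        super().__init__()
        self.heads = heads
        self.qkv = torch.nn.Linear(dim, dim * 3)
        self.proj = torch.nn.Linear(dim, dim)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.heads, C // self.heads)
        q, k, v = xops.unbind(qkv, 2)                  # xFormers' optimized unbind
        out = xops.memory_efficient_attention(q, k, v)
        return self.proj(out.reshape(B, N, C))

model = TinyAttention().cuda().half()
compiled = torch.compile(model)
x = torch.randn(2, 128, 64, device="cuda", dtype=torch.float16)
compiled(x)  # run with e.g. TORCH_LOGS=graph_breaks to see where the breaks happen
```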
No torch.compile() behaviour:

torch.compile behavior output:
Expected behavior
It would be nice if xformers didn't create graph breaks.

Environment
Current colab
This problem also reproduces in my local setup.