Open · achalddave opened 8 months ago
Hi, thanks for reporting this! A lot of operators in xFormers don't support torch.compile at the moment. This is on our roadmap, but it might take ~months to get there (we might also need to fix some bugs in PyTorch...)
Ah, okay, thanks! Is there an issue that tracks this that we could follow? We'd love to support torch.compile+xformers attention in our repo.
We can use this issue to track it. However, this particular error might be related to FSDP ... the xFormers operator will most likely incur a graph break (which will hurt performance), but it shouldn't cause an exception or error.
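As a minimal sketch (not from this issue) of how one could confirm such a graph break, `torch.compile(..., fullgraph=True)` raises an error wherever Dynamo would otherwise split the graph; the shapes and dtype below are placeholders:

```python
# Minimal sketch, not from the issue: fullgraph=True makes torch.compile
# raise where it would otherwise insert a graph break (e.g. around an op
# such as xformers.ops.memory_efficient_attention that Dynamo cannot trace).
import torch
import xformers.ops as xops

def attn(q, k, v):
    return xops.memory_efficient_attention(q, k, v)

# (batch, seq_len, num_heads, head_dim) layout expected by xformers.
q = k = v = torch.randn(1, 16, 4, 32, device="cuda", dtype=torch.bfloat16)

compiled = torch.compile(attn, fullgraph=True)
out = compiled(q, k, v)  # errors here if the op forces a graph break
```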
🐛 Bug
Using xformers.memory_efficient_attention with FSDP and torch.compile fails with bfloat16 but works with float32. It's unclear to me whether this is an xformers bug, an FSDP bug, or a torch.compile bug. It might be related to https://github.com/pytorch/pytorch/issues/112164, and it came up in our codebase, which uses xformers: https://github.com/mlfoundations/open_lm/issues/72
Command
torchrun --nproc_per_node 2 script.py
To Reproduce
Steps to reproduce the behavior:
torchrun --nproc_per_node 2 script.py
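The original script.py is not included here; the following is a hypothetical minimal sketch of what it could look like, assuming 2 GPUs with NCCL, and with the module, dimensions, and layer names invented purely for illustration:

```python
# Hypothetical minimal script.py sketch (the real one is not shown in the
# issue): memory_efficient_attention inside an FSDP-wrapped, compiled model.
import os

import torch
import torch.distributed as dist
import torch.nn as nn
import xformers.ops as xops
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


class Attention(nn.Module):
    def __init__(self, dim: int = 64, heads: int = 4):
        super().__init__()
        self.heads = heads
        self.qkv = nn.Linear(dim, 3 * dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        b, s, d = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # xformers expects (batch, seq_len, num_heads, head_dim).
        q = q.view(b, s, self.heads, d // self.heads)
        k = k.view(b, s, self.heads, d // self.heads)
        v = v.view(b, s, self.heads, d // self.heads)
        out = xops.memory_efficient_attention(q, k, v)
        return self.proj(out.reshape(b, s, d))


def main():
    dist.init_process_group("nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    # bfloat16 triggers the failure; float32 reportedly works.
    model = Attention().cuda().to(torch.bfloat16)
    model = FSDP(model)
    model = torch.compile(model)

    x = torch.randn(2, 16, 64, device="cuda", dtype=torch.bfloat16)
    model(x).sum().backward()
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```

Launched with the torchrun command above, this reproduces the configuration described in the bug report (FSDP + torch.compile + memory_efficient_attention in bfloat16).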
Expected behavior
Code runs without error.
Environment
Please copy and paste the output from the environment collection script from PyTorch (or fill out the checklist below manually).
You can run the script with:
python -m torch.utils.collect_env
Additional context
xformers version: 0.0.22.