huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Switch to sdpa_kernel api with newer torch version #34411

Open · mobicham opened this issue 1 month ago

mobicham commented 1 month ago

I noticed that many files in transformers still use the older SDPA API, torch.backends.cuda.sdp_kernel. We just discovered a bug in PyTorch 2.5.0 affecting the old API that makes it run slower: https://github.com/pytorch/pytorch/issues/138386

It would be a good idea to switch to the new API (from torch.nn.attention import sdpa_kernel, SDPBackend) and set the appropriate compile flag, to avoid losing as much as 20% of the performance!
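As a rough illustration of the migration, here is a minimal sketch of a compatibility helper. It assumes the legacy flags enable_flash, enable_mem_efficient and enable_math correspond to SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION and SDPBackend.MATH. The names legacy_flags_to_backend_names and sdpa_backend_context are hypothetical and are not part of transformers or PyTorch. The helper prefers torch.nn.attention.sdpa_kernel and falls back to torch.backends.cuda.sdp_kernel only when the new API cannot be imported. It does not set the compile flag mentioned above, since that flag is not named here.

```python
# Hypothetical mapping from the legacy keyword flags of
# torch.backends.cuda.sdp_kernel(...) to SDPBackend member names.
_FLAG_TO_BACKEND = {
    "enable_flash": "FLASH_ATTENTION",
    "enable_mem_efficient": "EFFICIENT_ATTENTION",
    "enable_math": "MATH",
}


def legacy_flags_to_backend_names(enable_flash=True, enable_math=True, enable_mem_efficient=True):
    """Return the SDPBackend names that the given legacy flags leave enabled."""
    flags = {
        "enable_flash": enable_flash,
        "enable_mem_efficient": enable_mem_efficient,
        "enable_math": enable_math,
    }
    return [_FLAG_TO_BACKEND[name] for name, enabled in flags.items() if enabled]


def sdpa_backend_context(enable_flash=True, enable_math=True, enable_mem_efficient=True):
    """Return a context manager that restricts SDPA backends (hypothetical helper)."""
    import torch  # imported lazily so the pure mapping above works without torch

    try:
        # Newer API: takes a list of SDPBackend members that are allowed.
        from torch.nn.attention import SDPBackend, sdpa_kernel
    except ImportError:
        # Older torch without torch.nn.attention.sdpa_kernel: use the legacy API.
        return torch.backends.cuda.sdp_kernel(
            enable_flash=enable_flash,
            enable_math=enable_math,
            enable_mem_efficient=enable_mem_efficient,
        )

    names = legacy_flags_to_backend_names(enable_flash, enable_math, enable_mem_efficient)
    # Backends not in this list (e.g. cuDNN attention, where available) are disabled.
    return sdpa_kernel([getattr(SDPBackend, name) for name in names])
```

Usage would look like: with sdpa_backend_context(enable_math=False): out = torch.nn.functional.scaled_dot_product_attention(q, k, v). Because sdpa_kernel disables every backend that is not listed, this sketch also leaves cuDNN attention off; add SDPBackend.CUDNN_ATTENTION to the list if your torch version provides it and you want it enabled.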

Reproduction

A reference example is available in this gist: https://gist.github.com/mobicham/aa1e77689d9cf866cbea2cb75a53a9e4. More details are in the PyTorch issue: https://github.com/pytorch/pytorch/issues/138386

Expected behavior

Examples using SDPA with torch 2.5.0 should run at least as fast as they do with torch 2.4.1.

github-actions[bot] commented 3 hours ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.