lucidrains / x-transformers

A concise but complete full-attention transformer with a set of promising experimental features from various papers
MIT License

Upgrade sdpa kernel #260

Closed Ryu1845 closed 4 weeks ago

Ryu1845 commented 2 months ago

torch.backends.cuda.sdp_kernel is deprecated; this PR adds the replacement without changing the API (essentially what PyTorch is currently doing internally). The main effect is removing the following warning:

FutureWarning: torch.backends.cuda.sdp_kernel() is deprecated. In the future, this context manager will be removed. Please see, torch.nn.attention.sdpa_kernel() for the new context manager, with updated signature.
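For context, a minimal sketch of the migration this describes. The old context manager takes boolean flags, while the new `torch.nn.attention.sdpa_kernel` (PyTorch >= 2.3) takes a list of `SDPBackend` enums. The helper names below (`flags_to_backend_names`, `sdp_context`) are illustrative, not the actual x-transformers code; the version-gated fallback is one of the options discussed later in this thread.

```python
# Illustrative migration sketch, assuming PyTorch >= 2.3 for the new API.
# Helper names are hypothetical, not the actual x-transformers change.
from contextlib import contextmanager


def flags_to_backend_names(enable_flash=True, enable_math=True, enable_mem_efficient=True):
    # translate the deprecated boolean flags of torch.backends.cuda.sdp_kernel
    # into the SDPBackend names that torch.nn.attention.sdpa_kernel expects
    mapping = [
        (enable_flash, "FLASH_ATTENTION"),
        (enable_math, "MATH"),
        (enable_mem_efficient, "EFFICIENT_ATTENTION"),
    ]
    return [name for enabled, name in mapping if enabled]


try:
    # new context manager, available from PyTorch 2.3 on
    from torch.nn.attention import SDPBackend, sdpa_kernel

    @contextmanager
    def sdp_context(**flags):
        names = flags_to_backend_names(**flags)
        with sdpa_kernel([getattr(SDPBackend, name) for name in names]):
            yield

except ImportError:
    try:
        # older PyTorch 2.x: fall back to the deprecated context manager,
        # which accepts the boolean flags directly
        from torch.backends.cuda import sdp_kernel as _legacy_sdp_kernel

        @contextmanager
        def sdp_context(**flags):
            with _legacy_sdp_kernel(**flags):
                yield

    except ImportError:
        sdp_context = None  # torch unavailable in this environment
```

Call sites then stay unchanged across versions, e.g. `with sdp_context(enable_flash=True, enable_math=False, enable_mem_efficient=True): F.scaled_dot_product_attention(q, k, v)`.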
lucidrains commented 2 months ago

@Ryu1845 nice! will this break for some versions of pytorch 2.x?

Ryu1845 commented 2 months ago

Yes, I think this requires >=2.3. Should I add a fallback, bump the minimum PyTorch version, or do you think just keeping the deprecated kernel is better?

AugustDev commented 4 weeks ago

Could you please finalise this? There are a lot of warnings about torch.backends.cuda.sdp_kernel() currently when using x-transformers.

lucidrains commented 4 weeks ago

@AugustDev @Ryu1845 hey Augustinas and Sofian

been procrastinating on this, but decided to make a move just now

could you let me know if it is fixed on the latest version?