pytorch / glow

Compiler for Neural Network hardware accelerators
Apache License 2.0

Call _sdp_attention in nn.functional.mha (#89470) #6038

Closed: drisspg closed this 1 year ago

drisspg commented 1 year ago

Summary: Replaces the inline block of code in nn.functional.mha with _scaled_dot_product_attention. This function allows the fused kernels to be called whenever all of their required input conditions are met.
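For context, the inline block being replaced computes standard scaled dot-product attention, softmax(Q K^T / sqrt(d)) V. The following is a minimal plain-Python sketch of that unfused computation for a single head. It is not PyTorch's code. It omits batching, multiple heads, attention masks and dropout, and the names reference_sdpa and softmax are hypothetical helpers chosen for illustration. A fused kernel is expected to produce the same result without running these steps as separate operations. In PyTorch 2.0 and later, the public entry point for this computation is torch.nn.functional.scaled_dot_product_attention.

```python
import math
from typing import List

Matrix = List[List[float]]


def softmax(row: List[float]) -> List[float]:
    # Subtract the row maximum before exponentiating for numerical stability.
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def reference_sdpa(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    """Hypothetical unfused reference: softmax(q @ k^T / sqrt(d)) @ v.

    q has shape (L, d), k has shape (S, d), v has shape (S, d_v).
    Returns a list of L rows, each of length d_v.
    """
    d = len(q[0])
    scale = 1.0 / math.sqrt(d)
    out: Matrix = []
    for qi in q:
        # Scaled dot product of this query with every key.
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        # Normalize the scores into attention weights that sum to 1.
        weights = softmax(scores)
        # Output row = attention-weighted sum of the value rows.
        out.append(
            [sum(w * vj[c] for w, vj in zip(weights, v)) for c in range(len(v[0]))]
        )
    return out


if __name__ == "__main__":
    # With two identical keys, the weights are uniform, so the output
    # is the mean of the value rows.
    print(reference_sdpa([[1.0, 0.0]], [[1.0, 1.0], [1.0, 1.0]], [[0.0, 2.0], [4.0, 6.0]]))
```

Keeping the reference path explicit like this is useful when checking a fused implementation: both should agree to within floating-point tolerance on the same inputs.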

cc VitalyFedyunin ngimel

X-link: https://github.com/pytorch/pytorch/pull/89470

Reviewed By: cpuhrsch

Differential Revision: D41625335

Pulled By: drisspg

facebook-github-bot commented 1 year ago

This pull request was exported from Phabricator. Differential Revision: D41625335
