google / paxml

Pax is a Jax-based machine learning framework for training large scale models. Pax allows for advanced and fully configurable experimentation and parallelization, and has demonstrated industry leading model flop utilization rates.
Apache License 2.0

[NVIDIA] Add config option to use cudnn flash attention #73

Closed kaixih closed 2 months ago

kaixih commented 6 months ago

This PR allows users to enable cuDNN flash attention. It depends on https://github.com/google/praxis/pull/53.

In preliminary results for GPT3-5B, we observe a ~30% performance improvement on 8xH100 GPUs.

With this PR, users can simply set USE_CUDNN_FLASH_ATTENTION=True in their config, and the attention computation will then be replaced with cuDNN flash attention.
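To illustrate the intended usage, here is a minimal, self-contained sketch of the class-attribute override pattern. The class names (`BaseExperiment`, `Gpt5BWithFlashAttention`) and the `attention_impl` helper are hypothetical stand-ins for illustration only, not the actual paxml experiment classes; only the flag name comes from this PR.

```python
class BaseExperiment:
    """Hypothetical stand-in for an experiment config base class."""

    # Flag name taken from this PR; assumed here to default to off.
    USE_CUDNN_FLASH_ATTENTION = False

    def attention_impl(self) -> str:
        # Hypothetical helper: pick the attention path based on the flag.
        if self.USE_CUDNN_FLASH_ATTENTION:
            return "cudnn_flash_attention"
        return "default_attention"


class Gpt5BWithFlashAttention(BaseExperiment):
    """Hypothetical user config: a single override turns the feature on."""

    USE_CUDNN_FLASH_ATTENTION = True


baseline = BaseExperiment().attention_impl()
flash = Gpt5BWithFlashAttention().attention_impl()
print(baseline, flash)
```

The point of the sketch is only that the switch is a single config attribute, so existing experiment definitions stay unchanged unless they opt in.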

cc. @nluehr @zhangqiaorjc

kaixih commented 2 months ago

The SDPA (scaled dot-product attention) op is now in the JAX public API (see this PR), and we can use it through this custom praxis layer in this PR.

This PR then introduces a Fiddle config option, USE_CUDNN_FLASH_ATTENTION, to turn it on.

cc. @abhinavgoel95 for viz.

kaixih commented 2 months ago

Gentle ping. @zhangqiaorjc