Pax is a JAX-based machine learning framework for training large-scale models. Pax allows for advanced and fully configurable experimentation and parallelization, and has demonstrated industry-leading model FLOP utilization rates.
Apache License 2.0
[NVIDIA] Add config option to use cudnn flash attention #73
This PR allows users to enable cuDNN flash attention. It depends on https://github.com/google/praxis/pull/53.

In preliminary results for GPT-3 5B, we observe a ~30% performance improvement on 8x H100 GPUs.

With this PR, users can simply set `USE_CUDNN_FLASH_ATTENTION=True` in their config, and the attention part will be replaced with cuDNN flash attention.

cc @nluehr @zhangqiaorjc
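A minimal sketch of how the flag would be set on an experiment config. Only the `USE_CUDNN_FLASH_ATTENTION` name comes from this PR; the class names below are illustrative stand-ins for a real Pax GPT-3 5B experiment, not actual Paxml classes:

```python
# Hypothetical sketch: in Pax, an experiment is a Python class, and this PR
# reads a class-level flag to decide which attention implementation to build.
# The class names here are illustrative stand-ins.

class Gpt35BExperiment:
    """Stand-in for an existing GPT-3 5B experiment config."""
    USE_CUDNN_FLASH_ATTENTION = False  # default: original attention path


class Gpt35BCudnnFlash(Gpt35BExperiment):
    """Same experiment, with the attention part swapped for cuDNN flash attention."""
    USE_CUDNN_FLASH_ATTENTION = True
```

Subclassing an existing experiment and overriding one attribute keeps the change opt-in: existing configs keep the original attention path unless they set the flag.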