A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.
In this PR, FusedAttnForwardFFI is registered with XLA in two phases:
1. the prepare phase, where te_cudnn_handle_init_ffi is called;
2. the execute phase, where the actual fused attention function is called.
Previously, only the execute phase was registered (the default). Creating the cuDNN handle in the prepare phase avoids allocating memory for it inside the cudaGraph capture region, which resolves the issue we hit in distributed runs with cudaGraph.
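The reason the split matters can be modeled in a few lines. The sketch below is a plain-Python toy, not the XLA FFI or Transformer Engine API: ToyRuntime, HandleCache, make_handlers and run are hypothetical names, and the assumption (taken from the description above) is that allocating during graph capture is the failure mode. The handle is created lazily, so it is allocated in whichever stage first asks for it.

```python
from typing import Callable, List, Optional


class CaptureError(RuntimeError):
    """Raised when the toy runtime sees an allocation during graph capture."""


class ToyRuntime:
    """Hypothetical stand-in for a runtime that captures execute calls into a graph."""

    def __init__(self) -> None:
        self.capturing = False
        self.allocations: List[str] = []

    def allocate(self, what: str) -> str:
        # Models the constraint behind this PR: no allocations while capturing.
        if self.capturing:
            raise CaptureError(f"allocation of {what!r} inside the capture region")
        self.allocations.append(what)
        return what


class HandleCache:
    """Creates the (hypothetical) cuDNN handle once and reuses it afterwards."""

    def __init__(self) -> None:
        self.handle: Optional[str] = None

    def get(self, rt: ToyRuntime) -> str:
        if self.handle is None:
            self.handle = rt.allocate("cudnn_handle")
        return self.handle


Stage = Callable[[ToyRuntime], str]


def make_handlers(cache: HandleCache):
    def prepare(rt: ToyRuntime) -> str:
        # Prepare stage: create the handle before capture begins.
        return cache.get(rt)

    def execute(rt: ToyRuntime) -> str:
        # Execute stage: reuses the handle; it allocates only if prepare never ran.
        return f"fused_attn_fwd using {cache.get(rt)}"

    return prepare, execute


def run(rt: ToyRuntime, prepare: Optional[Stage], execute: Stage) -> str:
    if prepare is not None:
        prepare(rt)  # runs outside the capture region
    rt.capturing = True
    try:
        return execute(rt)  # runs inside the capture region
    finally:
        rt.capturing = False
```

With both stages, run(rt, prepare, execute) returns "fused_attn_fwd using cudnn_handle" and the single allocation happens before capture. Passing None for prepare reproduces the old execute-only setup: the first allocation then lands inside the capture region and raises CaptureError.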
In general, every custom call that involves cuDNN kernels should be registered with these two phases going forward. Because the prepare handler must have the same argument list as the execute handler, te_cudnn_handle_init_ffi takes variadic arguments, so the same function can serve as the prepare handler for any custom call.
Type of change
[ ] Documentation change (change only to the documentation, either a fix or new content)
[x] Bug fix (non-breaking change which fixes an issue)
[ ] New feature (non-breaking change which adds functionality)
[ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
Checklist: