NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.
https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html
Apache License 2.0

[JAX] Added prepare phase for the FusedAttnForwardFFI #1313

Closed phu0ngng closed 3 weeks ago

phu0ngng commented 3 weeks ago

Description

In this PR, the FusedAttnForwardFFI is registered with XLA in two phases:

  1. prepare, where te_cudnn_handle_init_ffi is called.
  2. execute, where the actual fused attention function is called.

Previously, the default was to register only the execute phase. Splitting out a prepare phase avoids allocating memory for the cuDNN handle inside the cudaGraph capture region, which resolves the issue we hit in distributed runs with cudaGraph.
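The rationale above can be illustrated with a small, library-agnostic sketch (this is not the XLA FFI API; all class and method names here are hypothetical). The point is simply that handle allocation happens once in prepare, before any graph capture begins, and execute only reuses it:

```python
# Illustrative sketch only, not TransformerEngine/XLA code.
# Shows why splitting handle creation (prepare) from the kernel launch
# (execute) keeps allocation out of a CUDA-graph capture region.

class FakeCudnnHandle:
    """Stands in for a cuDNN handle whose creation allocates memory."""
    def __init__(self):
        self.allocated = True


class TwoPhaseCustomCall:
    def __init__(self):
        self.handle = None

    def prepare(self):
        # Runs once, before any graph capture: the safe place to allocate.
        if self.handle is None:
            self.handle = FakeCudnnHandle()

    def execute(self):
        # Runs inside the (possibly captured) region: must not allocate.
        assert self.handle is not None, "prepare() must run before execute()"
        return "launch_fused_attn"


call = TwoPhaseCustomCall()
call.prepare()            # outside capture: allocation happens here
result = call.execute()   # inside capture: only reuses the handle
```

With only an execute phase, the first call would have to lazily create the handle, and that allocation could land inside the captured region.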

In general, all custom calls that involve cuDNN kernels should be registered with these two phases going forward. Since the prepare function must have the same argument list as the execute function, te_cudnn_handle_init_ffi takes variadic arguments, keeping it generic enough to be reused by any custom call.


phu0ngng commented 3 weeks ago

/te-ci jax L1
