A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.
This PR introduces a prototype for converting TE/JAX custom calls to the new XLA Custom Call with FFI.
The following primitives are implemented with the new custom calls:
- CastTranspose
- ActLu
- DActLu
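As a rough illustration of what the first of these primitives computes, here is a plain-NumPy reference sketch of CastTranspose semantics: cast the input to a lower-precision dtype, produce its transpose, and report the max absolute value (amax) used for FP8 scaling-factor bookkeeping. The function name and signature are illustrative only, not the TE/JAX API.

```python
import numpy as np

def cast_transpose(x, out_dtype=np.float16):
    """Hypothetical reference semantics for a fused cast + transpose.

    Returns the casted tensor, its transpose, and the amax of the
    input (in FP8 workflows amax feeds the scaling-factor update).
    """
    amax = np.max(np.abs(x))          # reduction over the full tensor
    casted = x.astype(out_dtype)       # cast to the lower-precision dtype
    return casted, casted.T, amax

x = np.arange(6, dtype=np.float32).reshape(2, 3)
c, ct, amax = cast_transpose(x)
# c: (2, 3) float16, ct: (3, 2) float16, amax: 5.0
```

The fused TE kernel performs the cast and transpose in a single pass over the data; the sketch above only pins down the input/output contract.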
An FFI datatype for the FP8 types was added to XLA in openxla/xla#13856.
All related tests passed.
WIP: Verifying that these custom calls are captured in CUDA graphs.
Type of change
- [ ] Documentation change (change only to the documentation, either a fix or new content)
- [ ] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [x] Breaking change (fix or feature that would cause existing functionality to not work as expected)
Checklist: