NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.
https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html
Apache License 2.0

[TE/JAX] Prototype for New XLA Custom Calls with FFI #946

Open phu0ngng opened 2 weeks ago


Description

This PR introduces a prototype for converting the TE/JAX custom calls to the new XLA custom calls based on the FFI. The following primitives are implemented with the new custom calls:

An FFI DataType covering the FP8 types was added to XLA in openxla/xla#13856. All related tests pass.
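For context, a new-style XLA FFI custom call binds a typed C++ handler and registers it under a name and platform. The sketch below follows the public `xla/ffi/api/ffi.h` binding API; the handler name `te_cast_to_fp8`, the `scale` attribute, and the kernel body are hypothetical illustrations, not code from this PR, and the `F8E4M3FN` FFI data type is assumed to be available per openxla/xla#13856:

```cpp
#include <cuda_runtime_api.h>

#include "xla/ffi/api/c_api.h"
#include "xla/ffi/api/ffi.h"

namespace ffi = xla::ffi;

// Hypothetical FP8 cast: reads an F32 buffer and writes an F8E4M3FN buffer
// scaled by `scale`. A real implementation would launch a CUDA kernel on
// the stream that XLA passes in via the execution context.
static ffi::Error CastToFp8Impl(cudaStream_t stream, float scale,
                                ffi::Buffer<ffi::F32> x,
                                ffi::Result<ffi::Buffer<ffi::F8E4M3FN>> y) {
  // ... launch the cast kernel on `stream` ...
  return ffi::Error::Success();
}

// Bind the typed signature above to an FFI handler: one context argument
// (the platform stream), one attribute, one input buffer, one result buffer.
XLA_FFI_DEFINE_HANDLER(kCastToFp8, CastToFp8Impl,
                       ffi::Ffi::Bind()
                           .Ctx<ffi::PlatformStream<cudaStream_t>>()
                           .Attr<float>("scale")
                           .Arg<ffi::Buffer<ffi::F32>>()
                           .Ret<ffi::Buffer<ffi::F8E4M3FN>>());

// Register the handler for the CUDA platform under a custom-call name that
// the JAX primitive lowering will reference.
XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "te_cast_to_fp8", "CUDA",
                         kCastToFp8);
```

Compared with the legacy custom-call path, which passed an opaque descriptor string, the FFI binding declares the argument types, result types, and attributes up front, so XLA can type-check the call and decode attributes for the handler.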

WIP: verifying that these custom calls are correctly captured by CUDA Graphs.

Type of change

Checklist: