NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.
https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html
Apache License 2.0

[Feature Request] Please support the PyTorch and JAX FP8 types without a custom data type #391

Open MoFHeka opened 1 year ago

MoFHeka commented 1 year ago

The PyTorch FP8 data type may be released in version 2.1, and JAX FP8 support has already been released.
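
For reference, a minimal sketch of the native FP8 dtypes being discussed, assuming PyTorch >= 2.1 and a recent JAX build (the dtype names below come from upstream PyTorch/JAX, not from TE):

```python
import torch
import jax.numpy as jnp

# PyTorch 2.1 exposes two FP8 formats as ordinary tensor dtypes.
x_torch = torch.randn(4, 4).to(torch.float8_e4m3fn)  # 4-bit exponent, 3-bit mantissa
y_torch = torch.randn(4, 4).to(torch.float8_e5m2)    # 5-bit exponent, 2-bit mantissa
print(x_torch.dtype, y_torch.dtype)

# JAX ships matching FP8 dtypes (backed by ml_dtypes).
x_jax = jnp.ones((4, 4), dtype=jnp.float8_e4m3fn)
y_jax = jnp.ones((4, 4), dtype=jnp.float8_e5m2)
print(x_jax.dtype, y_jax.dtype)
```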

mingxu1067 commented 1 year ago

@nouiz Do you have any comments on the JAX side of this? @jeng1220 @zlsh80826 for visibility

nouiz commented 1 year ago

There is an effort to add native XLA support for FP8, but it is more complex and will take more time. TE is the fast path for FP8 kernels.
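
For comparison, a rough sketch of that fast path on the PyTorch side, following TE's documented `fp8_autocast` usage (exact recipe arguments may differ between releases). TE keeps its own FP8 scaling metadata rather than exposing a native FP8 tensor dtype:

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# Delayed-scaling recipe: TE tracks amax history and scaling factors internally.
fp8_recipe = recipe.DelayedScaling(fp8_format=recipe.Format.HYBRID)

model = te.Linear(1024, 1024, bias=True).cuda()
inp = torch.randn(16, 1024, device="cuda")

# Inputs and weights stay in higher precision; the GEMM runs with FP8 kernels.
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = model(inp)
out.sum().backward()
```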

MoFHeka commented 1 year ago

What about the PyTorch FP8 data type? I noticed that it has already been merged into the main branch.

Also, I found that TE still does not support the attention dot-product calculation for causal LLM text training. Would it work better with XLA? @nouiz
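
To illustrate the operation in question (plain PyTorch, not TE's API), the causal scaled dot-product attention used in decoder-only LLM training looks roughly like this:

```python
import torch
import torch.nn.functional as F

# (batch, heads, seq_len, head_dim)
q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.bfloat16)
v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.bfloat16)

# is_causal=True applies the lower-triangular mask so each token attends only
# to itself and earlier positions -- the kernel the question is about.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```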

Finally, I want to know whether there is a roadmap for TE FP8 LLM training. Should I choose JAX or PyTorch? Not for today, but for the next year.