NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including support for 8-bit floating point (FP8) precision on Hopper and Ada GPUs, providing better performance with lower memory utilization in both training and inference.
https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html
Apache License 2.0

Fix compilation bug with CUDA 12.1 #949

Closed. Edenzzzz closed this 1 week ago

Edenzzzz commented 1 week ago


Changes

This has been mentioned in #560, but somehow it was changed back. Importing one header, which in turn imports another, after nv_bfloat16 is defined triggers a re-declaration error.
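For context, here is a minimal sketch of this kind of re-declaration failure. The local fallback struct and the file are hypothetical, not taken from the TransformerEngine sources (the actual headers involved are elided in the description above); the only real header referenced is cuda_bf16.h, which declares nv_bfloat16 for C++ translation units. The snippet deliberately fails to compile in order to show the failure mode.

```cpp
// redeclaration_sketch.cu -- hypothetical minimal reproducer, not TE source;
// it deliberately fails to compile to illustrate the failure mode.

// A local fallback definition of the name, seen before the CUDA bfloat16
// header (hypothetical stand-in for whatever definition the affected
// translation unit provides):
struct nv_bfloat16 {
  unsigned short bits;  // placeholder 16-bit storage
};

// Any later include that reaches <cuda_bf16.h>, directly or transitively,
// declares nv_bfloat16 again (as a typedef of __nv_bfloat16 in C++ builds),
// so the compiler reports a re-declaration / conflicting-declaration error:
#include <cuda_bf16.h>
```

Ensuring that only one definition of the name is visible in a given translation unit, for example by adjusting the include order or guarding the fallback definition, avoids the error.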


timmoon10 commented 1 week ago

/te-ci pytorch

Edenzzzz commented 1 week ago

Thanks for catching this bug again. We missed the change in #757.

Please sign the commit (instructions) and I think this is ready to go.

Sounds good!