What does this PR do?
There are two ways to cast to `bfloat16`:

1. Use the `torch_xla` casting system via the environment variables `XLA_DOWNCAST_BF16` or `XLA_USE_BF16`.
2. Use the native `torch.autocast` feature.

The first approach was already supported; this PR adds support for the second.

It also fixes issues related to how `NEURON_CC_FLAGS` is set. If the flags are set too late (e.g., after the process group initialization), the compiler ignores them. This PR makes sure they are set at the right time.
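
The ordering constraint on `NEURON_CC_FLAGS` can be illustrated with a minimal sketch. This is not the code from this PR: the helper names `append_neuron_cc_flag` and `setup_environment` are hypothetical, the `init_process_group` argument is a stand-in for whatever call initializes the process group, and `--auto-cast=none` is only an example value that should be checked against the Neuron compiler documentation.

```python
import os


def append_neuron_cc_flag(flag: str) -> str:
    """Append `flag` to NEURON_CC_FLAGS unless it is already present.

    Hypothetical helper: it splits on whitespace, so it assumes that
    individual flags do not contain spaces.
    """
    flags = os.environ.get("NEURON_CC_FLAGS", "").split()
    if flag not in flags:
        flags.append(flag)
    os.environ["NEURON_CC_FLAGS"] = " ".join(flags)
    return os.environ["NEURON_CC_FLAGS"]


def setup_environment(init_process_group) -> None:
    """Set the compiler flags first, then initialize the process group.

    `init_process_group` is any zero-argument callable (an assumption made
    for illustration); the point is only that it runs after the flags are set.
    """
    # Example flag value; treat it as an assumption, not as what the PR sets.
    append_neuron_cc_flag("--auto-cast=none")
    init_process_group()
```

Passing the initialization step as a callable keeps the sketch free of any dependency on `torch` while making the order explicit: by the time the callable runs, the environment variable already holds the desired flags.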