NVIDIA / apex

A PyTorch Extension: Tools for easy mixed precision and distributed training in PyTorch
BSD 3-Clause "New" or "Revised" License

Add xentropy bf16 support #1790

Open · zyeric opened 3 months ago

zyeric commented 3 months ago

The xentropy kernel can already dispatch to both fp16 and bf16, but two incorrect checks in the interface reject any input tensor that is not fp16. As a result, bf16 inputs fail before they ever reach the kernel.
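To make the mismatch concrete, here is a minimal, library-free sketch that contrasts a check accepting only fp16 with one derived from the set of dtypes the kernel dispatches to. Everything in it is hypothetical: `ScalarType`, `KERNEL_DISPATCH_DTYPES`, `check_input_fp16_only`, and `check_input_dtype` are illustrative names, not apex's actual C++/CUDA interface. `FLOAT32` appears only as an example of a rejected dtype. The issue does not say whether the real kernel accepts fp32.

```python
from enum import Enum


class ScalarType(Enum):
    """Hypothetical stand-in for a tensor's element dtype (not a PyTorch type)."""
    FLOAT32 = "float32"
    FLOAT16 = "float16"
    BFLOAT16 = "bfloat16"


# Dtypes the (hypothetical) kernel dispatch covers: fp16 and bf16, per the issue.
KERNEL_DISPATCH_DTYPES = frozenset({ScalarType.FLOAT16, ScalarType.BFLOAT16})


def check_input_fp16_only(dtype: ScalarType) -> None:
    """Overly strict interface check of the kind the issue describes: bf16 is rejected."""
    if dtype is not ScalarType.FLOAT16:
        raise TypeError(f"expected float16 input, got {dtype.value}")


def check_input_dtype(dtype: ScalarType) -> None:
    """Relaxed check: accept exactly the dtypes the kernel can dispatch to."""
    if dtype not in KERNEL_DISPATCH_DTYPES:
        allowed = ", ".join(sorted(d.value for d in KERNEL_DISPATCH_DTYPES))
        raise TypeError(f"expected one of [{allowed}], got {dtype.value}")


if __name__ == "__main__":
    # Print which dtypes each check accepts or rejects.
    for dt in ScalarType:
        for check in (check_input_fp16_only, check_input_dtype):
            try:
                check(dt)
                result = "accepted"
            except TypeError:
                result = "rejected"
            print(f"{check.__name__:<22} {dt.value:<9} {result}")
```

Building the check from the same dtype set that drives dispatch means the interface and the kernel cannot drift apart again when a new dtype is added to one but not the other.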