ROCm / apex

A PyTorch Extension: Tools for easy mixed precision and distributed training in PyTorch
BSD 3-Clause "New" or "Revised" License

Enable --focal_loss and --index_mul_2d extensions for ROCm #91

Closed hubertlu-tw closed 1 year ago

jithunnair-amd commented 1 year ago

Both extensions pass the --ftz=false compiler flag to nvcc, which disables flushing of denormals to zero. I'm not sure what the default behavior is on ROCm.
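
For context, here is a minimal CPU-side sketch of what flushing denormals to zero means for a float32 value. It does not query nvcc or ROCm; flush_to_zero is a hypothetical helper that only emulates the effect, assuming IEEE 754 single precision, where the smallest positive normal value is 2**-126.

```python
import math
import struct

# Smallest positive normal float32; anything nonzero below this is a denormal.
FLT_MIN_NORMAL = 2.0 ** -126


def to_float32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def flush_to_zero(x: float) -> float:
    """Hypothetical helper: emulate FTZ by replacing float32 denormals with signed zero."""
    x32 = to_float32(x)
    if x32 != 0.0 and abs(x32) < FLT_MIN_NORMAL:
        return math.copysign(0.0, x32)
    return x32


tiny = to_float32(1e-40)      # representable only as a float32 denormal
kept = tiny                   # behavior with flushing disabled: value survives
flushed = flush_to_zero(tiny) # behavior with flushing enabled: value becomes 0.0

print("denormal kept:   ", kept)
print("denormal flushed:", flushed)
```

With flushing disabled (the --ftz=false case), small intermediate values like tiny are preserved; with flushing enabled they collapse to zero, which is why the default matters for numerically sensitive kernels.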