pgsrv closed this issue 3 years ago.
I recommend calculating ASL in fp32.
Basically, instead of doing:
```python
from torch.cuda.amp import autocast

with autocast():
    self.pred = self.model(*self.xb)
    self.loss = self.loss_func(self.pred, *self.yb)
```
try:
```python
with autocast():
    self.pred = self.model(*self.xb)
# no mixed precision here.
self.loss = self.loss_func(self.pred.float(), *self.yb)
```
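The thread doesn't say why fp32 helps. One plausible mechanism, stated here as an assumption rather than a diagnosis: ASL clamps probabilities with a small `eps` (assumed here to be the default 1e-8) before taking the log. 1e-8 is below the smallest fp16 subnormal (about 6e-8), so in half precision it rounds to 0 and the clamp no longer keeps the log finite. For a negative label, `0 * log(0)` then becomes `0 * -inf = NaN`. The sketch below emulates this in pure Python using the `struct` half-precision format `'e'`. `to_fp16`, `to_fp32`, `log_like_torch` and `loss_term` are hypothetical helpers, and `loss_term` is a simplified single-element stand-in for one ASL term, not the library's code.

```python
import math
import struct
from typing import Callable


def to_fp16(x: float) -> float:
    """Round a Python float to IEEE 754 half precision (struct format 'e') and back."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_fp32(x: float) -> float:
    """Round a Python float to IEEE 754 single precision (struct format 'f') and back."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def log_like_torch(x: float) -> float:
    """Natural log that returns -inf at 0 (as torch.log does) instead of raising."""
    return math.log(x) if x > 0 else float("-inf")


def loss_term(p: float, y: float, eps: float, cast: Callable[[float], float]) -> float:
    """Hypothetical single-element term y * log(clamp(p, min=eps)), with p, y and
    eps all rounded to the precision implemented by `cast`."""
    p, y, eps = cast(p), cast(y), cast(eps)
    clamped = max(p, eps)  # clamp(min=eps); useless if eps itself rounded to 0
    return y * log_like_torch(clamped)


if __name__ == "__main__":
    eps = 1e-8
    p = 1e-9  # tiny sigmoid output for a confidently negative logit
    y = 0.0   # negative label
    print("eps in fp16:", to_fp16(eps))                 # 0.0
    print("fp16 term:", loss_term(p, y, eps, to_fp16))  # nan (0 * -inf)
    print("fp32 term:", loss_term(p, y, eps, to_fp32))  # finite
```

Under this assumption, only elements whose probability underflows produce NaN, which would match seeing a mix of NaNs and regular values. Computing the loss outside `autocast` on `self.pred.float()` keeps `eps` representable.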
It works! Thank you!
Hi! I use the multi-label AsymmetricLoss with default args in a modified version of timm's train script. When I turn on native AMP, ASL returns a mix of NaNs and regular floats. In fp32 it works fine.
From the log: