Alibaba-MIIL / ASL

Official PyTorch Implementation of: "Asymmetric Loss For Multi-Label Classification" (ICCV 2021) paper

NaNs with fp16 #53

Closed · pgsrv closed this 3 years ago

pgsrv commented 3 years ago

Hi! I'm using the multi-label AsymmetricLoss with default args in a modified version of timm's training script. When I turn on native AMP, ASL returns NaN on some batches and regular float values on others; with fp32 everything is fine.

From the log:

pid 4361 INFO: AsymmetricLoss x tensor([[ 73.8750, -10.7578,  64.6250, -20.7031,  81.4375,  42.9688],
        [ 62.3438,  -9.9453,  58.5000, -17.8594,  71.8750,  37.4375],
        [ 38.2500,  -4.2578,  36.7188, -12.7344,  43.1250,  25.2188],
        [ 54.4062,  -6.5781,  49.5000, -14.1250,  64.7500,  29.8281],
        [ 50.4688,  -8.9766,  44.2500, -14.5938,  56.2812,  31.2031],
        [ 59.1875,  -5.0039,  52.4375, -19.5781,  64.8750,  32.9688]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) y tensor([[0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0., 1.]], device='cuda:0')
pid 4361 INFO: AsymmetricLoss loss nan <================================================== NAN
pid 4361 INFO: AsymmetricLoss x tensor([[ 43.3438,  -6.1172,  40.2500, -12.6875,  49.8750,  26.6406],
        [ 46.7188,  -6.6211,  40.7188, -12.7188,  54.7500,  27.3438],
        [ 47.8438,  -4.1523,  44.6250, -15.6172,  51.8750,  27.2344],
        [ 51.0312,  -7.9258,  49.4375, -13.4922,  56.9062,  29.8906],
        [ 50.7500,  -6.8281,  43.8125, -16.8125,  52.2500,  31.2969],
        [ 53.7500, -10.8438,  48.0312, -14.3750,  57.1875,  30.2500]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddmmBackward>) y tensor([[0., 0., 0., 0., 1., 0.],
        [1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0., 1.]], device='cuda:0')
pid 4361 INFO: AsymmetricLoss loss 54.76890563964844
...
mrT23 commented 3 years ago

I recommend computing ASL in fp32.
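
Most likely the log term is underflowing: a logit of -20.7 from your trace has sigmoid ≈ 1e-9, which is representable in fp32 but rounds to 0 in fp16, and the default eps=1e-8 clamp doesn't help because 1e-8 itself rounds to 0 in fp16. log(0) is -inf, and a zero target multiplied by -inf gives NaN. A quick check of that rounding (plain PyTorch, just for illustration):

import torch

p16 = torch.sigmoid(torch.tensor([-20.7031])).half()  # ~1e-9 underflows to 0
print(p16)                                     # tensor([0.], dtype=torch.float16)
print(torch.log(p16.clamp(min=1e-8).float()))  # tensor([-inf])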

Basically, instead of doing:

with autocast():
    self.pred = self.model(*self.xb)
    self.loss = self.loss_func(self.pred, *self.yb)

try this:

with autocast():
    self.pred = self.model(*self.xb)
# no mixed precision here: cast the logits to fp32 before the loss
self.loss = self.loss_func(self.pred.float(), *self.yb)
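
For completeness, here is a self-contained sketch of a full training step using this pattern together with a GradScaler (the Linear model, data shapes, and optimizer are placeholders, not from the actual script; the AsymmetricLoss import assumes this repo's src/ layout):

import torch
from torch.cuda.amp import autocast, GradScaler
from src.loss_functions.losses import AsymmetricLoss  # this repo's ASL

model = torch.nn.Linear(128, 6).cuda()      # placeholder multi-label head
criterion = AsymmetricLoss()                # default args, as in the report
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = GradScaler()

x = torch.randn(6, 128, device='cuda')
y = torch.randint(0, 2, (6, 6), device='cuda').float()

optimizer.zero_grad()
with autocast():
    pred = model(x)                         # fp16 forward pass
loss = criterion(pred.float(), y)           # loss in fp32, outside autocast
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()

The key point is that only the forward pass runs under autocast; the logits are cast back to fp32 before the loss, so the log terms never underflow, while GradScaler still handles the scaled fp16 gradients.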
pgsrv commented 3 years ago

It works! Thank you!