lessw2020 / Ranger21

Ranger deep learning optimizer rewrite to use newest components
Apache License 2.0

hit nan for variance_normalized #19

Closed jimmiebtlr closed 3 years ago

jimmiebtlr commented 3 years ago

Not certain this is a bug yet, but I'm getting this rarely after a while of training and am not finding an issue on my side. Input to the loss function looks good (no nan's). I'm working with a fairly complex loss function though, so it's very possible I have a rare bug in my code.

I'm using the following options

Ranger21(
    params=params, lr=3e-4,
    num_epochs=1e12, num_batches_per_epoch=1, num_warmup_iterations=1000,
    using_gc=True, weight_decay=1e-4, use_madgrad=True,
)

I've seen this with batch sizes from 4 to 128 so far, so it doesn't seem to depend on that.
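One generic way to confirm that the loss inputs really are clean is to guard them with a finiteness check before computing the loss. This is a sketch of that idea, not code from the original report; the `assert_finite` helper and the example tensors are illustrative:

```python
import torch

def assert_finite(name, t):
    # Fail fast if a tensor feeding the loss contains nan or inf,
    # so the first bad batch is caught at the point of entry.
    if not torch.isfinite(t).all():
        raise ValueError(f"{name} contains nan/inf")

preds = torch.tensor([0.1, 0.9])
targets = torch.tensor([0.0, 1.0])
assert_finite("preds", preds)
assert_finite("targets", targets)
```

If this never fires but training still hits NaN, the problem is more likely inside the backward pass or the optimizer update itself.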

lessw2020 commented 3 years ago

Hi @jimmiebtlr, thanks for opening the issue! One quick test: would you mind running with AdamW (i.e. use_madgrad=False, or just omit the param)? The MADGRAD side of the code is less explored at the moment than the AdamW core, so knowing whether it reproduces on only one side would help narrow down where to investigate.

jimmiebtlr commented 3 years ago

I'll give that a shot. It takes some time to reproduce, so I'll post back once I know one way or the other. Thanks!

jimmiebtlr commented 3 years ago

It's still coming up (rarely) without the use_madgrad param as well. I'll make a local copy and see if I can track it down further.

jimmiebtlr commented 3 years ago

Definitely happening somewhere in the gradient calculation.
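For localizing this kind of problem, PyTorch's autograd anomaly detection raises an error at the first backward function that produces a NaN, pointing to the offending op. This is a generic debugging sketch (not from this thread); `torch.sqrt` of a negative input is used just to manufacture a NaN:

```python
import torch

# With anomaly mode on, autograd checks every backward op for nan
# and raises a RuntimeError naming the function that produced it.
torch.autograd.set_detect_anomaly(True)

x = torch.tensor([-1.0], requires_grad=True)
y = torch.sqrt(x).sum()  # sqrt of a negative value yields nan

raised = False
try:
    y.backward()
except RuntimeError:
    raised = True
```

Anomaly mode slows training down noticeably, so it is best enabled only while hunting the bug.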

lessw2020 commented 3 years ago

Hi @jimmiebtlr - I've just checked in a softplus transform update along with gradient normalization. It should be helpful here, since the softplus boosts very small values, and it set a new high on our small benchmark dataset.
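The softplus idea can be sketched as follows. This is an illustrative sketch of why it helps with NaN, not the exact code checked into Ranger21; the `beta` value and the `denom` helper are assumptions:

```python
import math

def softplus(x, beta=50.0):
    # Smooth approximation of max(x, 0): strictly positive everywhere,
    # so a near-zero denominator gets lifted instead of causing a
    # divide-by-~0 blowup. For large beta*x it converges to x itself.
    if beta * x > 20.0:  # avoid overflow in exp for large inputs
        return x
    return math.log1p(math.exp(beta * x)) / beta

def denom(v, eps=1e-8):
    # variance_normalized-style denominator: sqrt of the
    # second-moment estimate, passed through softplus
    return softplus(math.sqrt(v) + eps)
```

Even a zero variance estimate then yields a safe positive denominator (`denom(0.0)` is a small positive number), while normal-sized denominators pass through essentially unchanged.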

twmht commented 2 years ago

@jimmiebtlr and @lessw2020

I've also run into this issue. Running with SGD works without any problem, so my loss function should be fine.

Any update on this?

jimmiebtlr commented 2 years ago

From what I recall, the issue was in my own code.