Closed by VolodyaCO 11 months ago
This sounds like https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where, and the fix recommended there is similar to your safe_x fix. See that answer for more details about why this happens.
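In short, both branches of a jnp.where are differentiated, so a non-finite gradient from the untaken branch still poisons the result; the recommended fix is to feed that branch a "safe" operand. A minimal sketch of the idea (illustrative only, not the FAQ's exact example):

```python
import jax
import jax.numpy as jnp

# Both branches of jnp.where are differentiated. At x = 0 the log branch
# is not selected, but its gradient 1/x is inf, and 0 * inf = nan.
def f_unsafe(x):
  return jnp.where(x > 0.0, jnp.log(x), 0.0)

# FAQ-style fix: clamp the operand of the risky branch so its gradient
# stays finite where that branch is unused.
def f_safe(x):
  safe_x = jnp.where(x > 0.0, x, 1.0)
  return jnp.where(x > 0.0, jnp.log(safe_x), 0.0)

print(jax.grad(f_unsafe)(0.0))  # nan
print(jax.grad(f_safe)(0.0))    # 0.0
```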
It looked related. Would you want me to add a fix to those activation functions? @jakevdp
Yes, that would be great - thanks!
Okay, I can reproduce the nan by using a large positive input (>100). (I looked at elu's fix to learn how to reproduce it.)
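A minimal repro along those lines (the celu body here is an assumed sketch of the pre-fix shape, not the exact source):

```python
import jax
import jax.numpy as jnp

# Assumed sketch of celu without the safe_x guard (pre-fix shape).
def celu_unsafe(x, alpha=1.0):
  return jnp.where(x > 0, x, alpha * jnp.expm1(x / alpha))

# Once exp(x / alpha) overflows (roughly x > 88 in float32), the untaken
# expm1 branch contributes 0 * inf = nan to the gradient.
print(jax.grad(celu_unsafe)(200.0))  # nan
print(jax.grad(celu_unsafe)(1.0))    # 1.0
```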
You are right.
I can't reproduce the nan behavior. And I believe that celu itself does not need safe_x, since expm1 is smooth everywhere. (Why does elu use safe_x? I don't know.)
I did some digging - the elu version of this safe_x change was added in #1538; it seems that the gradient of expm1 is not stable for very large inputs. So yes, I think the same treatment should be given to celu, but it doesn't look like other functions in this file need a fix.
The gradient of expm1 for large inputs should be inf, IMO. But I have seen this pattern many times; maybe we can extract the pattern into a function?
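For illustration, such a helper might look like this (the name where_with_safe_operand and its signature are hypothetical, not something JAX provides):

```python
import jax.numpy as jnp

# Hypothetical helper: evaluate branch_fn on a clamped operand so the
# untaken branch cannot produce a non-finite gradient.
def where_with_safe_operand(cond, true_val, branch_fn, x, safe_value=0.0):
  safe_x = jnp.where(cond, safe_value, x)
  return jnp.where(cond, true_val, branch_fn(safe_x))

# celu rewritten in terms of the helper (sketch only):
def celu(x, alpha=1.0):
  return where_with_safe_operand(
      x > 0, x, lambda s: alpha * jnp.expm1(s / alpha), x)
```

The idea is just to bundle the safe-operand clamp together with the jnp.where so individual activation functions cannot forget it.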
It looks like this was fixed by https://github.com/google/jax/commit/b0805a8a318696cc39eb53a05dc10f2e9cee6a29
I have a fairly simple feed-forward neural network written in flax. My problem is fairly simple: I standard-scale a dataset of 5 features with one target, and I perform mean squared error minimisation. Somehow, I get nans in the loss function, triggered by a call to jnp.expm1 inside celu. Now, nn.celu calls jax.nn.celu. Checking the code, I found out that celu is implemented like this:
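(Reconstructed sketch of that definition; the exact source may differ, but the relevant shape is a single jnp.where over expm1 with no safe_x:)

```python
import jax.numpy as jnp

# Reconstructed sketch of jax.nn.celu at the time, not the exact source.
def celu(x, alpha=1.0):
  return jnp.where(x > 0, x, alpha * jnp.expm1(x / alpha))
```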
Fortunately, I checked the code of elu, which is implemented like this:
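(Again a reconstructed sketch rather than the exact source; the relevant difference is the extra safe_x line:)

```python
import jax.numpy as jnp

# Reconstructed sketch of jax.nn.elu, not the exact source: note safe_x.
def elu(x, alpha=1.0):
  safe_x = jnp.where(x > 0, 0.0, x)
  return jnp.where(x > 0, x, alpha * jnp.expm1(safe_x))
```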
I modified celu, adding the safe_x line, and the nan error completely disappeared. My question is: why is there a need to have this safe_x? Shouldn't the jnp.where make x safe by default? This looks like a bug in jnp.where, as the third argument should only be evaluated if x > 0. There should not be a need for the safe_x.