jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

celu activation function sometimes produces nans where it shouldn't #9873

Closed. VolodyaCO closed this issue 11 months ago.

VolodyaCO commented 2 years ago

I have a fairly simple feed-forward neural network written in Flax:

from typing import Tuple

import jax.numpy as jnp
from flax import linen as nn

class MLP(nn.Module):
    features: Tuple[int, ...]

    @nn.compact
    def __call__(self, x):
        for f in self.features:
            x = nn.celu(nn.Dense(f)(x), alpha=2)
        return jnp.squeeze(nn.Dense(1)(x))

My problem is fairly simple: I standard-scale a dataset with 5 features and one target, and I minimise the mean squared error. Somehow, I get nans in the loss function, triggered by a call to jnp.expm1 inside celu.
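
For context, a rough sketch of the kind of setup described (layer sizes, array shapes, and data here are illustrative placeholders, not the original training code):

import jax
import jax.numpy as jnp

model = MLP(features=(32, 32))                    # hypothetical layer sizes
x = jnp.ones((16, 5))                             # 16 standard-scaled samples, 5 features
y = jnp.zeros((16,))                              # one target per sample
variables = model.init(jax.random.PRNGKey(0), x)

def mse(variables, x, y):
    pred = model.apply(variables, x)
    return jnp.mean((pred - y) ** 2)

# During training, the loss (and gradients) eventually came back as nan.
loss, grads = jax.value_and_grad(mse)(variables, x, y)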

Now, nn.celu calls jax.nn.celu. Checking the code, I found that celu is implemented like this:

def celu(x: Array, alpha: Array = 1.0) -> Array:
    return jnp.where(x > 0, x, alpha * jnp.expm1(x / alpha))

Fortunately, I checked the code of elu, which is implemented like this:

def elu(x: Array, alpha: Array = 1.0) -> Array:
    safe_x = jnp.where(x > 0, 0., x)
    return jnp.where(x > 0, x, alpha * jnp.expm1(safe_x))

I modified celu by adding the safe_x line, and the nan error completely disappeared.

My question is: why is there a need for this safe_x? Shouldn't jnp.where make x safe by default? This looks like a bug in jnp.where, as the third argument should only matter where x <= 0. There should be no need for the safe_x.

jakevdp commented 2 years ago

This sounds like https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where, and the fix recommended there is similar to your safe_x fix. See that answer for more details about why this happens.
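
The gist of that FAQ entry, roughly: both branches of a jnp.where are differentiated, and the unselected branch's gradient gets multiplied by a zero cotangent, so any inf it produces turns into nan (0 * inf = nan). A small illustration (not the FAQ's exact code):

import jax
import jax.numpy as jnp

# The log branch is never selected at x = 0, but it is still differentiated,
# and its infinite derivative meets the zero from where's backward pass.
print(jax.grad(lambda x: jnp.where(x > 0., jnp.log(x), 0.))(0.))  # nan

# The recommended fix: feed the "dangerous" branch an input that is harmless
# wherever that branch is not selected.
def safe_log_or_zero(x):
    safe_x = jnp.where(x > 0., x, 1.)
    return jnp.where(x > 0., jnp.log(safe_x), 0.)

print(jax.grad(safe_log_or_zero)(0.))  # 0.0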

VolodyaCO commented 2 years ago

It does look related. Would you like me to add a fix to those activation functions? @jakevdp

jakevdp commented 2 years ago

Yes, that would be great - thanks!

YouJiacheng commented 2 years ago

I couldn't reproduce the nan behaviour at first, and I believed that celu itself did not need safe_x, since expm1 is smooth everywhere. (Why does elu use safe_x? I don't know.)

Edit: okay, I can reproduce the nan by using a large positive input (>100). (I looked at elu's fix to learn how to reproduce it.) You are right.
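
For reference, a minimal repro sketch of what is described above (float32 defaults; the threshold scales with alpha, since exp(x / alpha) overflows float32 once x / alpha exceeds roughly 88.7):

import jax
import jax.numpy as jnp

# celu as implemented at the time of this issue (no safe_x):
def celu_old(x, alpha=1.0):
    return jnp.where(x > 0, x, alpha * jnp.expm1(x / alpha))

# The forward pass is fine (the x branch is selected), but the backward pass
# multiplies the zero cotangent of the unselected branch by exp(x / alpha),
# which overflows to inf, and 0 * inf = nan.
print(jax.grad(celu_old)(200.0))  # nan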

jakevdp commented 2 years ago

I did some digging - the elu version of this safe_x change was added in #1538; it seems that the gradient of expm1 is not stable for very large inputs. So yes, I think the same treatment should be given to celu, but it doesn't look like other functions in this file need a fix.
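
Roughly, giving celu the same treatment would look like this (a sketch following elu's pattern from #1538, not necessarily the exact change that eventually landed):

def celu(x: Array, alpha: Array = 1.0) -> Array:
    safe_x = jnp.where(x > 0, 0., x)
    return jnp.where(x > 0, x, alpha * jnp.expm1(safe_x / alpha))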

YouJiacheng commented 2 years ago

The gradient of expm1 for large inputs should be inf, IMO.
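
For instance (with float32 defaults), differentiating expm1 on its own does give inf rather than nan; the nan only appears once that inf meets the zero coming out of jnp.where's backward pass:

import jax
import jax.numpy as jnp

# d/dx expm1(x) = exp(x), which overflows float32 for x above roughly 88.7.
print(jax.grad(jnp.expm1)(200.0))  # inf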

YouJiacheng commented 2 years ago

But I have seen this pattern many times; maybe we can extract it into a helper function?
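
For example, a purely hypothetical helper (the name safe_branch and its parameters are illustrative, not anything that exists in JAX) capturing the pattern might look like this:

import jax.numpy as jnp

def safe_branch(cond, x, on_true, on_false, safe_true=0.0, safe_false=0.0):
    # Like jnp.where(cond, on_true(x), on_false(x)), except each branch only
    # ever sees inputs for which it is actually selected; elsewhere it gets a
    # harmless placeholder, so a gradient that blows up on the other branch's
    # inputs cannot leak inf/nan through jnp.where's backward pass.
    x_true = jnp.where(cond, x, safe_true)
    x_false = jnp.where(cond, safe_false, x)
    return jnp.where(cond, on_true(x_true), on_false(x_false))

# celu written in terms of the helper:
def celu(x, alpha=1.0):
    return safe_branch(x > 0, x,
                       on_true=lambda t: t,
                       on_false=lambda t: alpha * jnp.expm1(t / alpha))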

jakevdp commented 11 months ago

It looks like this was fixed by https://github.com/google/jax/commit/b0805a8a318696cc39eb53a05dc10f2e9cee6a29