google-deepmind / optax

Optax is a gradient processing and optimization library for JAX.
https://optax.readthedocs.io
Apache License 2.0

softmax_cross_entropy handles -inf logits incorrectly #896

Closed carlosgmartin closed 6 months ago

carlosgmartin commented 6 months ago

softmax_cross_entropy handles a -inf logit incorrectly when the corresponding label is 0. For example:

import optax
from jax import nn, numpy as jnp
from jax.scipy.special import xlogy

def softmax_cross_entropy_alt_0(logits, labels):
    return -(xlogy(labels, nn.softmax(logits, -1))).sum(-1)

def softmax_cross_entropy_alt_1(logits, labels):
    return -(labels * nn.log_softmax(logits, -1)).sum(-1, where=labels != 0)

def softmax_cross_entropy_alt_2(logits, labels):
    x = labels * nn.log_softmax(logits, -1)
    return -jnp.where(labels == 0, 0, x).sum(-1)

logits = jnp.array([-jnp.inf, 0])
labels = jnp.array([0, 1])
print(optax.softmax_cross_entropy(logits, labels))  # nan
print(optax.softmax_cross_entropy(logits.clip(-1e10), labels))  # -0.0
print(softmax_cross_entropy_alt_0(logits, labels))  # -0.0
print(softmax_cross_entropy_alt_1(logits, labels))  # -0.0
print(softmax_cross_entropy_alt_2(logits, labels))  # -0.0

In principle, it should do what softmax_cross_entropy_alt_0 does: that version uses xlogy, which treats x * log(y) as 0 when x == 0.
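As a quick illustration of the xlogy convention, here is a minimal sketch comparing it with the naive product at x = 0, y = 0. The variable names are only for illustration.

```python
import jax.numpy as jnp
from jax.scipy.special import xlogy

x = jnp.array(0.0)
y = jnp.array(0.0)

# Naive product: 0 * log(0) = 0 * (-inf), which is nan in floating point.
naive = x * jnp.log(y)

# xlogy returns 0 whenever x == 0, regardless of y.
safe = xlogy(x, y)

print(naive, safe)
```

This is the same nan that shows up in the first print statement above, where a zero label multiplies a -inf log-probability.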

Two possible solutions are given by softmax_cross_entropy_alt_1 and softmax_cross_entropy_alt_2 above.

I can submit a PR for either.

(For context, I encountered this issue when doing action masking of logits in an RL setting.)

vroulet commented 6 months ago

Hello @carlosgmartin, great catch! If you are up for doing a PR, that would be great. Why not use xlogy, that is, softmax_cross_entropy_alt_0? This seems to me to be the safest option.

carlosgmartin commented 6 months ago

@vroulet I was guessing there was some reason why optax chose log_softmax over log (or now xlogy) of softmax. Perhaps log_softmax is supposed to have better numerical properties/stability? If there's no advantage, softmax_cross_entropy_alt_0 would be fine.

carlosgmartin commented 6 months ago

If my math is correct, symbolically,

softmax_cross_entropy(x, p) == logsumexp(x, -1) - (x * p).sum(-1)

which, with "hard zeroing" by the label probabilities p, becomes

logsumexp(x, -1) - jnp.where(p == 0, 0, x * p).sum(-1)

So perhaps this would be the best way to implement it?
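For reference, the identity can be derived in a few lines, under the assumption that the label probabilities p sum to 1 along the last axis (this assumption is not stated explicitly above):

```latex
\begin{aligned}
\mathrm{CE}(x, p)
  &= -\sum_i p_i \log \mathrm{softmax}(x)_i \\
  &= -\sum_i p_i \bigl(x_i - \operatorname{logsumexp}(x)\bigr) \\
  &= \operatorname{logsumexp}(x) \sum_i p_i - \sum_i p_i x_i \\
  &= \operatorname{logsumexp}(x) - \sum_i p_i x_i
  \qquad \text{when } \sum_i p_i = 1 .
\end{aligned}
```

If the labels do not sum to 1, the logsumexp term keeps its factor of \sum_i p_i and the simplified form no longer matches.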

On a related note, for anyone who's interested, the standard convention in measure theory is to let ∞ × 0 = 0, which provides some justification/explanation for the "hard zeroing" approach.

vroulet commented 6 months ago

Good point that we want to handle the logsumexp with a dedicated implementation and never go through an actual softmax. Quick question though:

vroulet commented 6 months ago

In other words, I agree with you that y log x = 0 for x=0, y=0 if both x and y are real numbers (the convention stems from taking the limit as both x and y go to 0). But here y is a categorical variable that is handily mapped to a set of integers. It seems to me that what you really want is to encode the support of a reference probability distribution against which you compute a KL divergence. You'll see that the code for kl_divergence handles exactly what you want (in terms of y log x). On the other hand, it may not handle the softmax properly, so one could create a dedicated loss for that case.

carlosgmartin commented 6 months ago

@vroulet Not sure I understood you correctly. Can you explain what you mean by "permute the labels" with a concrete example? If a logit is -inf but the corresponding label is not 0, the correct cross entropy in that case should be +inf, I think. Example:

import optax
from jax import nn, numpy as jnp
from jax.scipy.special import logsumexp, xlogy

def softmax_cross_entropy_alt_0(logits, labels):
    return -(xlogy(labels, nn.softmax(logits, -1))).sum(-1)

def softmax_cross_entropy_alt_1(logits, labels):
    return -(labels * nn.log_softmax(logits, -1)).sum(-1, where=labels != 0)

def softmax_cross_entropy_alt_2(logits, labels):
    x = labels * nn.log_softmax(logits, -1)
    return -jnp.where(labels == 0, 0, x).sum(-1)

def softmax_cross_entropy_alt_3(x, p):
    return logsumexp(x, -1) - jnp.where(p == 0, 0, x * p).sum(-1)

logits = jnp.array([-jnp.inf, 0, 1])
labels = jnp.array([0.2, 0.4, 0.4])
print(optax.softmax_cross_entropy(logits, labels))  # inf
print(optax.softmax_cross_entropy(logits.clip(-1e10), labels))  # 2000000000.0
print(softmax_cross_entropy_alt_0(logits, labels))  # inf
print(softmax_cross_entropy_alt_1(logits, labels))  # inf
print(softmax_cross_entropy_alt_2(logits, labels))  # inf
print(softmax_cross_entropy_alt_3(logits, labels))  # inf

So I think all implementations handle this case correctly.

Please correct me if I misinterpreted.

vroulet commented 6 months ago

Absolutely, but that creates strange behaviors such as (for logits=[-inf, 2] on data of {cats, dogs}, where y in {0, 1} encodes the class membership and we are getting y=dog)

You see that the choice of labels is completely arbitrary in this case. It's rather strange to get a bug in one case and not in the other one while these labels were chosen arbitrarily.

I agree that one would also permute the logits in this case. I just wanted to point out the arbitrary nature of the labels in classification which is the usual setting of this loss.

Another argument: imagine that, for some reason, in your usual classification pipeline, you got an inf in the logit just for the label 0 (highly improbable, but it could happen). In that case, the user would weirdly see a non-infinite loss while the gradients may be infinite. The bug would be quite hard to trace because the behavior is somewhat unexpected in a classification setting.

That said, I may be over-conservative (softmax_cross_entropy is a widely used function so any change must be taken with extra care), @mtthss may have another take on this.

vroulet commented 6 months ago

Isn't your objective to properly handle the support of a target distribution as in the kl_divergence?

carlosgmartin commented 6 months ago

I wouldn't expect -inf logits to arise in practice (barring severe training pathologies, in which case all bets are off anyway) unless some explicit masking is being applied to the logits, like action masking in RL or, in your example, if the classifier knows ahead of time that the image is definitely not a dog, due to some external "hard" piece of information being passed to the classifier.

I don't think it's necessarily an issue (or avoidable) that there are situations where a value is finite but the gradient is infinite. An example of this is jnp.sqrt(0.) == 0. but jax.grad(jnp.sqrt)(0.) == jnp.inf.
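The sqrt example can be reproduced with a minimal sketch like the following (variable names are only for illustration):

```python
import jax
import jax.numpy as jnp

# The value at 0 is finite...
value = jnp.sqrt(0.0)

# ...but the derivative 1 / (2 * sqrt(x)) diverges at x = 0.
grad = jax.grad(jnp.sqrt)(0.0)

print(value, grad)
```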

vroulet commented 6 months ago

OK, you've won me over. If you can make a PR with softmax_cross_entropy_alt_1, that would be great. We'll then see whether any integration tests break.

carlosgmartin commented 6 months ago

Did you mean alt_3?

vroulet commented 6 months ago

alt_3, alt_2, alt_1 look fine to me. The function nn.log_softmax carefully handles the logsumexp: it uses the so-called "log-sum-exp trick", which shifts by the maximal value before applying the log-sum-exp (https://en.wikipedia.org/wiki/LogSumExp). One may benchmark/test them to select the best one.
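To make the shift concrete, here is a minimal sketch of the log-sum-exp trick. The helpers naive_logsumexp and shifted_logsumexp are hypothetical names written for this illustration, not the code used inside jax.nn.log_softmax; the sketch also ignores the corner case where the maximum itself is infinite.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def naive_logsumexp(x):
    # Direct formula: exp overflows for large entries.
    return jnp.log(jnp.sum(jnp.exp(x)))

def shifted_logsumexp(x):
    # Subtract the maximum so the largest exponent is exp(0) = 1,
    # then add the maximum back outside the log.
    m = jnp.max(x)
    return m + jnp.log(jnp.sum(jnp.exp(x - m)))

x = jnp.array([1000.0, 1000.0])

naive = naive_logsumexp(x)      # exp(1000) overflows in float32
shifted = shifted_logsumexp(x)  # mathematically 1000 + log(2)
reference = logsumexp(x)        # library implementation, for comparison

print(naive, shifted, reference)
```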

carlosgmartin commented 6 months ago

@vroulet Done: #898

vroulet commented 6 months ago

Solved in #898, with #916 ensuring correct gradients w.r.t. labels. Thanks again @carlosgmartin !

carlosgmartin commented 5 months ago

@vroulet Hi Vincent, I'm a little confused by https://github.com/google-deepmind/optax/commit/9f7446eba856108b98600ff070a2fda6048266af. Does this mean that the standard softmax_cross_entropy will not handle the -inf case correctly? Is there any reason to create two separate functions, one that works correctly and one that doesn't? To my understanding, that just seems to sow confusion. Thanks for the clarification.

vroulet commented 5 months ago

Hello Carlos, the first implementation, #898, broke an internal test because the derivatives with respect to the labels were not properly handled; see #916. (We could also have raised an error if the user tried to differentiate w.r.t. labels, but this would probably need to be done for all losses that have similar logic. I preferred to simply fix the function so that it handles derivatives properly.) The fix in #916 triggered instabilities in internal experiments. I haven't been able to identify why. I had to revert it to avoid breaking everyone's experiments. I then separated your proposal from the original code to keep access to both.

Now, I can remove your proposal if you think this creates confusion. To integrate your proposal, we need proper tests that can tell us whether such a change would break other people's code/experiments (I may have access to codebases that would make it possible to check such behaviors, but instabilities can happen late in training, so this would not be a simple unit test). If you have leads on why the instabilities happened (it could be the custom jvp, for example), we can work on that.

carlosgmartin commented 5 months ago

@vroulet

Thanks for letting me know. I think I might know what the problem is, and how to solve it.

My guess is that the gradient mismatch arises because the hard zeroing (taking the constant 0 branch of the jnp.where) forces the gradient with respect to a zero label to be zero, which is wrong whenever the corresponding log_prob is nonzero.

Below is a script showing the issue, as well as a potential solution:

import jax
from jax import nn, numpy as jnp
from jax.test_util import check_grads

def softmax_cross_entropy_1(logits, labels, axis=-1, where=None):
    log_probs = nn.log_softmax(logits, axis=axis, where=where)
    return -(labels * log_probs).sum(axis=axis, where=where)

def softmax_cross_entropy_2(logits, labels, axis=-1, where=None):
    log_probs = nn.log_softmax(logits, axis=axis, where=where)
    force_zero = labels == 0
    x = jnp.where(force_zero, 0, labels * log_probs)
    return -x.sum(axis=axis, where=where)

def softmax_cross_entropy_3(logits, labels, axis=-1, where=None):
    log_probs = nn.log_softmax(logits, axis=axis, where=where)
    force_zero = (labels == 0) & jnp.isneginf(log_probs)
    x = jnp.where(force_zero, 0, labels * log_probs)
    return -x.sum(axis=axis, where=where)

def main():
    labels = jnp.zeros(4).at[0].set(1)

    for f in [
        softmax_cross_entropy_1,
        softmax_cross_entropy_2,
        softmax_cross_entropy_3,
    ]:
        print(f.__name__)

        for logits, target in [
            (jnp.zeros(4), 1.3862944),
            # (jnp.zeros(4).at[1].set(-jnp.inf), 1.0986123),
            # (jnp.zeros(4).at[0].set(-jnp.inf), jnp.inf),
        ]:
            print("logits:", logits)

            output = f(logits, labels)

            if output == target:
                print(f"val ok: {output}")
            else:
                print(f"val not ok: {output} vs. {target}")

            try:
                check_grads(f, (logits, labels), 1)
                print("grad ok")
            except Exception as e:
                print("grad not ok", end="")
                print(e)

        print()

if __name__ == "__main__":
    main()

Output:

softmax_cross_entropy_1
logits: [0. 0. 0. 0.]
val ok: 1.3862943649291992
grad ok

softmax_cross_entropy_2
logits: [0. 0. 0. 0.]
val ok: 1.3862943649291992
grad not ok
Not equal to tolerance rtol=0.002, atol=0.002
JVP tangent
Mismatched elements: 1 / 1 (100%)
Max absolute difference: 0.24685502
Max relative difference: 0.1283005
 x: array(2.170893, dtype=float32)
 y: array(1.924038, dtype=float32)

softmax_cross_entropy_3
logits: [0. 0. 0. 0.]
val ok: 1.3862943649291992
grad ok

vroulet commented 5 months ago

Yes, that's what https://github.com/google-deepmind/optax/pull/916 was about, using a custom JVP (see also the tests that were added). We can indeed simply target the corner case. The main issue is the instability we then get in experiments, and how to test for it.

carlosgmartin commented 5 months ago

Relevant thread on the issues NaNs cause with autodiff: https://github.com/google/jax/issues/1052#issuecomment-514083352.
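For readers unfamiliar with the problem, here is a minimal sketch of how a nan can leak through the untaken branch of jnp.where under autodiff, together with the usual "double where" workaround. The functions masked_log and masked_log_safe are hypothetical names for this illustration and are not part of optax.

```python
import jax
import jax.numpy as jnp

def masked_log(x):
    # The value is fine at x = 0, but the backward pass still evaluates
    # the derivative of log at 0 and multiplies it by a zero cotangent.
    return jnp.where(x > 0.0, jnp.log(x), 0.0)

def masked_log_safe(x):
    # "Double where": replace the problematic input before calling log,
    # so the untaken branch never sees x = 0.
    safe_x = jnp.where(x > 0.0, x, 1.0)
    return jnp.where(x > 0.0, jnp.log(safe_x), 0.0)

g_naive = jax.grad(masked_log)(0.0)
g_safe = jax.grad(masked_log_safe)(0.0)

print(g_naive, g_safe)
```

Both functions return 0 at x = 0; only their gradients differ.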

Could you give some more details about the instabilities you're referring to?

vroulet commented 5 months ago

I don't have access to full logs of such instabilities. The repo responsible for the alarm is the following one: https://github.com/google/aqt/blob/main/aqt/jax/v2/numerics/fp8_numerics_test.py. These tests were failing half the time after https://github.com/google-deepmind/optax/commit/9f7446eba856108b98600ff070a2fda6048266af .

carlosgmartin commented 5 months ago

Thanks. To be clear, are you sure that this new version

def softmax_cross_entropy(logits, labels, axis=-1, where=None):
    log_probs = nn.log_softmax(logits, axis=axis, where=where)
    force_zero = (labels == 0) & jnp.isneginf(log_probs)
    x = jnp.where(force_zero, 0, labels * log_probs)
    return -x.sum(axis=axis, where=where)

causes issues? The forced-zero branch of where is now only taken when log_probs isn't finite, so it seems to me that, on finite inputs, it should always yield the same values and gradients as the current softmax_cross_entropy.
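One way to check this claim on a small example is sketched below. The names baseline and targeted are hypothetical stand-ins for the current implementation and the version above (with the axis and where arguments dropped for brevity), and the inputs are arbitrary. This only compares values and first-order gradients on one finite input, so it says nothing about the training instabilities discussed earlier.

```python
import jax
import jax.numpy as jnp
from jax import nn

def baseline(logits, labels):
    # Plain labels * log_softmax, with no special casing.
    return -(labels * nn.log_softmax(logits, -1)).sum(-1)

def targeted(logits, labels):
    # Zero out only entries where the label is 0 AND the log-prob is -inf.
    log_probs = nn.log_softmax(logits, -1)
    force_zero = (labels == 0) & jnp.isneginf(log_probs)
    return -jnp.where(force_zero, 0, labels * log_probs).sum(-1)

logits = jnp.array([0.3, -1.2, 2.0, 0.0])
labels = jnp.array([1.0, 0.0, 0.0, 0.0])

# Finite inputs: compare values and gradients w.r.t. both arguments.
val_b, grads_b = jax.value_and_grad(baseline, argnums=(0, 1))(logits, labels)
val_t, grads_t = jax.value_and_grad(targeted, argnums=(0, 1))(logits, labels)

# Masked input: a -inf logit whose label is 0.
masked = logits.at[1].set(-jnp.inf)
masked_val_b = baseline(masked, labels)
masked_val_t = targeted(masked, labels)

print(val_b, val_t, masked_val_b, masked_val_t)
```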