Closed carlosgmartin closed 6 months ago
Hello @carlosgmartin, Great catch! If you are up for doing a PR that would be great. Why wouldn't you use the function xlogy, that is softmax_cross_entropy_alt_0? This seems to me to be the safest option.
@vroulet I was guessing there was some reason why optax chose log_softmax over log (or now xlogy) of softmax. Perhaps log_softmax is supposed to have better numerical properties/stability? If there's no advantage, softmax_cross_entropy_alt_0 would be fine.
If my math is correct, symbolically,
softmax_cross_entropy(x, p) == logsumexp(x, -1) - (x * p).sum(-1)
which, with "hard zeroing" by the label probabilities p, becomes
logsumexp(x, -1) - jnp.where(p == 0, 0, x * p).sum(-1)
So perhaps this would be the best way to implement it?
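The identity above can be checked numerically. The following is a minimal sketch, not the optax implementation; the helper names ce_log_softmax and ce_logsumexp are made up for illustration, and the identity is assumed to hold only when p sums to 1 along the last axis.

```python
import jax.numpy as jnp
from jax import nn
from jax.scipy.special import logsumexp


def ce_log_softmax(x, p):
    # Textbook form: -sum_i p_i * log_softmax(x)_i
    return -(p * nn.log_softmax(x, -1)).sum(-1)


def ce_logsumexp(x, p):
    # Rewritten form with hard zeroing. It matches the textbook form
    # only if p sums to 1 along the last axis.
    return logsumexp(x, -1) - jnp.where(p == 0, 0, x * p).sum(-1)


# Finite logits: both forms should agree up to rounding.
x = jnp.array([0.5, -1.0, 2.0])
p = jnp.array([0.2, 0.3, 0.5])
a = ce_log_softmax(x, p)
b = ce_logsumexp(x, p)
print(a, b)

# Masked logit with a zero label: hard zeroing keeps the loss finite.
x_masked = jnp.array([-jnp.inf, 0.0, 1.0])
p_masked = jnp.array([0.0, 0.5, 0.5])
c = ce_logsumexp(x_masked, p_masked)
print(c)
```

If p is not normalized, the two forms differ by (1 - p.sum(-1)) * logsumexp(x, -1), so the rewritten form is not a drop-in replacement for unnormalized labels.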
On a related note, for anyone who's interested, the standard convention in measure theory is to let ∞ × 0 = 0, which provides some justification/explanation for the "hard zeroing" approach:
Good point: we want to handle the logsumexp with a dedicated implementation and never go through an actual softmax.
Quick question though:
In other words, I agree with you that y log x = 0 for x=0, y=0 if both x and y are real numbers (the convention stems from looking at the limit towards 0 for both x and y). But here y is a categorical variable that is handily mapped to a set of integers. It seems to me that what you really want is to encode the support of a reference probability distribution against which you compute a KL divergence. You'll see that the code for kl_divergence handles exactly what you want (in terms of y log x). On the other hand, it may not handle the softmax properly, so one may create a loss for such a case.
@vroulet Not sure I understood you correctly. Can you explain what you mean by "permute the labels" with a concrete example? If a logit is -inf but the corresponding label is not 0, the correct cross entropy in that case should be +inf, I think. Example:
import optax
from jax import nn, numpy as jnp
from jax.scipy.special import logsumexp, xlogy

# xlogy(labels, probs) is 0 wherever labels == 0.
def softmax_cross_entropy_alt_0(logits, labels):
    return -(xlogy(labels, nn.softmax(logits, -1))).sum(-1)

# Skip zero-label entries in the sum.
def softmax_cross_entropy_alt_1(logits, labels):
    return -(labels * nn.log_softmax(logits, -1)).sum(-1, where=labels != 0)

# Replace zero-label entries by 0 before summing.
def softmax_cross_entropy_alt_2(logits, labels):
    x = labels * nn.log_softmax(logits, -1)
    return -jnp.where(labels == 0, 0, x).sum(-1)

# logsumexp form with hard zeroing.
def softmax_cross_entropy_alt_3(x, p):
    return logsumexp(x, -1) - jnp.where(p == 0, 0, x * p).sum(-1)

logits = jnp.array([-jnp.inf, 0, 1])
labels = jnp.array([0.2, 0.4, 0.4])

print(optax.softmax_cross_entropy(logits, labels))  # inf
print(optax.softmax_cross_entropy(logits.clip(-1e10), labels))  # 2000000000.0
print(softmax_cross_entropy_alt_0(logits, labels))  # inf
print(softmax_cross_entropy_alt_1(logits, labels))  # inf
print(softmax_cross_entropy_alt_2(logits, labels))  # inf
print(softmax_cross_entropy_alt_3(logits, labels))  # inf
So I think all implementations handle this case correctly.
Please correct me if I misinterpreted.
Absolutely, but that creates strange behavior, for example with logits=[-inf, 2] on data from {cats, dogs}, where y in {0, 1} encodes the class and the observed label is y=dog.
You see that the choice of labels is completely arbitrary in this case. It's rather strange to get a bug in one case and not in the other one while these labels were chosen arbitrarily.
I agree that one would also permute the logits in this case. I just wanted to point out the arbitrary nature of the labels in classification which is the usual setting of this loss.
Another argument: imagine that, for some reason, in your usual classification pipeline, you got an inf in the logit just for the label 0 (highly improbable, but it could happen). In that case, the user would weirdly see a non-infinite loss while the gradients may be infinite. The bug would be quite hard to trace because the behavior is somewhat unexpected in a classification setting.
That said, I may be over-conservative (softmax_cross_entropy is a widely used function, so any change must be made with extra care). @mtthss may have another take on this.
Isn't your objective to properly handle the support of a target distribution as in the kl_divergence?
I wouldn't expect -inf logits to arise in practice (barring severe training pathologies, in which case all bets are off anyway) unless some explicit masking is being applied to the logits, like action masking in RL or, in your example, if the classifier knows ahead of time that the image is definitely not a dog, due to some external "hard" piece of information being passed to the classifier.
I don't think it's necessarily an issue (or avoidable) that there are situations where a value is finite but the gradient is infinite. An example of this is jnp.sqrt(0.) == 0. but jax.grad(jnp.sqrt)(0.) == jnp.inf.
OK, you've won me over. If you can make a PR with softmax_cross_entropy_alt_1, that would be great. We'll then see whether any integration tests break.
Did you mean alt_3?
alt_3, alt_2, alt_1 look fine to me. The function nn.log_softmax carefully handles the logsumexp: it uses the so-called "log-sum-exp trick", which shifts by the maximal value before applying the log-sum-exp (https://en.wikipedia.org/wiki/LogSumExp).
One may benchmark/test them to select the best one.
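For readers unfamiliar with the log-sum-exp trick mentioned above, here is a minimal sketch of the idea. The function names naive_logsumexp and shifted_logsumexp are hypothetical, and this is not the code JAX uses internally; in particular the sketch does not guard against a row where every entry is -inf, which library implementations handle separately.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def naive_logsumexp(x):
    # Direct formula; exp overflows for large inputs.
    return jnp.log(jnp.exp(x).sum(-1))


def shifted_logsumexp(x):
    # Subtract the maximum so the largest exponent is exp(0) = 1,
    # then add the maximum back outside the log.
    m = jnp.max(x, axis=-1, keepdims=True)
    return jnp.squeeze(m, -1) + jnp.log(jnp.exp(x - m).sum(-1))


x = jnp.array([1000.0, 1001.0, 1002.0])
naive = naive_logsumexp(x)        # overflows in float32
shifted = shifted_logsumexp(x)    # stays finite
reference = logsumexp(x, -1)      # library implementation
print(naive, shifted, reference)
```

The shift does not change the result mathematically, since log(sum(exp(x))) = m + log(sum(exp(x - m))) for any constant m.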
@vroulet Done: #898
Solved in #898, with #916 ensuring correct gradients w.r.t. labels. Thanks again @carlosgmartin !
@vroulet Hi Vincent, I'm a little confused by https://github.com/google-deepmind/optax/commit/9f7446eba856108b98600ff070a2fda6048266af. Does this mean that the standard softmax_cross_entropy will not handle the -inf case correctly? Is there any reason to create two separate functions, one that works correctly and one that doesn't? To my understanding, that just seems to sow confusion. Thanks for the clarification.
Hello Carlos, The first implementation, #898, broke an internal test because the derivatives with respect to the labels were not properly handled, see #916. (We could also have raised an error if the user tried to differentiate w.r.t. labels. But then this would probably need to be done for all losses that have a similar logic. I preferred to simply fix the function so that it handles derivatives properly.) The fix in #916 then triggered instabilities in internal experiments. I haven't been able to identify why. I had to revert it to avoid breaking everyone's experiments. I then separated your proposal from the original code to keep access to both.
Now, I can remove your proposal if you think this creates confusion. To integrate your proposal, we need proper tests that could tell us whether such a change can break other people's code/experiments (I may have access to codebases that would allow checking such behaviors, but instabilities can happen late in training, so this would not be a simple unit test). If you have leads on why the instabilities happened (it could be the custom JVP, for example), we can work on that.
@vroulet Thanks for letting me know. I think I might know what the problem is, and how to solve it.
My guess is that the gradient mismatch arises because the hard zeroing (taking the constant 0 branch of the jnp.where) causes the gradient with respect to a zero label to be zero, which might not be the case (if the corresponding log_prob is nonzero).
Below is a script showing the issue, as well as a potential solution:
import jax
from jax import nn, numpy as jnp
from jax.test_util import check_grads

# Plain version: no special handling of zero labels.
def softmax_cross_entropy_1(logits, labels, axis=-1, where=None):
    log_probs = nn.log_softmax(logits, axis=axis, where=where)
    return -(labels * log_probs).sum(axis=axis, where=where)

# Hard zeroing on every zero label.
def softmax_cross_entropy_2(logits, labels, axis=-1, where=None):
    log_probs = nn.log_softmax(logits, axis=axis, where=where)
    force_zero = labels == 0
    x = jnp.where(force_zero, 0, labels * log_probs)
    return -x.sum(axis=axis, where=where)

# Hard zeroing only where the label is zero and the log-probability is -inf.
def softmax_cross_entropy_3(logits, labels, axis=-1, where=None):
    log_probs = nn.log_softmax(logits, axis=axis, where=where)
    force_zero = (labels == 0) & jnp.isneginf(log_probs)
    x = jnp.where(force_zero, 0, labels * log_probs)
    return -x.sum(axis=axis, where=where)

def main():
    labels = jnp.zeros(4).at[0].set(1)
    for f in [
        softmax_cross_entropy_1,
        softmax_cross_entropy_2,
        softmax_cross_entropy_3,
    ]:
        print(f.__name__)
        for logits, target in [
            (jnp.zeros(4), 1.3862944),
            # (jnp.zeros(4).at[1].set(-jnp.inf), 1.0986123),
            # (jnp.zeros(4).at[0].set(-jnp.inf), jnp.inf),
        ]:
            print("logits:", logits)
            output = f(logits, labels)
            if output == target:
                print(f"val ok: {output}")
            else:
                print(f"val not ok: {output} vs. {target}")
            try:
                check_grads(f, (logits, labels), 1)
                print("grad ok")
            except Exception as e:
                print("grad not ok", end="")
                print(e)
            print()

if __name__ == "__main__":
    main()
Output:
softmax_cross_entropy_1
logits: [0. 0. 0. 0.]
val ok: 1.3862943649291992
grad ok
softmax_cross_entropy_2
logits: [0. 0. 0. 0.]
val ok: 1.3862943649291992
grad not ok
Not equal to tolerance rtol=0.002, atol=0.002
JVP tangent
Mismatched elements: 1 / 1 (100%)
Max absolute difference: 0.24685502
Max relative difference: 0.1283005
x: array(2.170893, dtype=float32)
y: array(1.924038, dtype=float32)
softmax_cross_entropy_3
logits: [0. 0. 0. 0.]
val ok: 1.3862943649291992
grad ok
Yes, that's what https://github.com/google-deepmind/optax/pull/916 was about, using a custom JVP (see also the tests that were added). We can indeed simply target the corner case. The main issue is the instability we then get in experiments, and how to test for it.
Relevant thread regarding the issues NaNs cause with autodiff: https://github.com/google/jax/issues/1052#issuecomment-514083352.
Could you give some more details about the instabilities you're referring to?
I don't have access to full logs of such instabilities. The repo responsible for the alarm is the following one: https://github.com/google/aqt/blob/main/aqt/jax/v2/numerics/fp8_numerics_test.py. These tests were failing half the time after https://github.com/google-deepmind/optax/commit/9f7446eba856108b98600ff070a2fda6048266af .
Thanks. To be clear, are you sure that this new version
def softmax_cross_entropy(logits, labels, axis=-1, where=None):
    log_probs = nn.log_softmax(logits, axis=axis, where=where)
    force_zero = (labels == 0) & jnp.isneginf(log_probs)
    x = jnp.where(force_zero, 0, labels * log_probs)
    return -x.sum(axis=axis, where=where)
causes issues? The forced-zero branch of where is now only taken when log_probs isn't finite, so it seems to me that, on finite inputs, it should always yield the same values and gradients as the current softmax_cross_entropy.
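One way to probe the claim about finite inputs is a small comparison of values and gradients. This is only a sketch on a single hand-picked input, with hypothetical helper names ce_plain and ce_corner_case; it simplifies the signatures (no axis or where arguments) and says nothing about the late-training instabilities mentioned earlier in the thread.

```python
import jax
import jax.numpy as jnp
from jax import nn


def ce_plain(logits, labels):
    # No special handling of zero labels.
    return -(labels * nn.log_softmax(logits, -1)).sum(-1)


def ce_corner_case(logits, labels):
    # Force a zero only where the label is 0 and the log-probability is -inf.
    log_probs = nn.log_softmax(logits, -1)
    force_zero = (labels == 0) & jnp.isneginf(log_probs)
    return -jnp.where(force_zero, 0, labels * log_probs).sum(-1)


# Finite logits with several zero labels.
logits = jnp.array([0.3, -1.2, 2.0, 0.0])
labels = jnp.array([1.0, 0.0, 0.0, 0.0])

same_value = jnp.allclose(
    ce_plain(logits, labels), ce_corner_case(logits, labels)
)

# Gradients w.r.t. both logits (argnum 0) and labels (argnum 1).
g_plain = jax.grad(ce_plain, argnums=(0, 1))(logits, labels)
g_corner = jax.grad(ce_corner_case, argnums=(0, 1))(logits, labels)
same_grads = all(jnp.allclose(a, b) for a, b in zip(g_plain, g_corner))

print(same_value, same_grads)
```

On finite logits force_zero is False everywhere, so jnp.where always selects labels * log_probs and both functions trace the same computation for the selected branch.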
softmax_cross_entropy handles a -inf logit incorrectly when the corresponding label is 0. For example:
In principle, it should be doing what softmax_cross_entropy_alt_0 does, which uses xlogy, which treats x * log(y) as 0 when x == 0.
Two possible solutions are given by softmax_cross_entropy_alt_1 and softmax_cross_entropy_alt_2 above.
I can submit a PR for either.
(For context, I encountered this issue when doing action masking of logits in an RL setting.)
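The failure mode described in this report can be illustrated with a small sketch. This is not the original example from the issue, and ce_plain is only a stand-in written in the style discussed in the thread, not necessarily the current optax code. The point is that 0 * (-inf) evaluates to nan in floating point, while xlogy(0, 0) is defined as 0.

```python
import jax.numpy as jnp
from jax import nn
from jax.scipy.special import xlogy


def ce_plain(logits, labels):
    # labels * log_softmax gives 0 * (-inf) = nan for a masked logit.
    return -(labels * nn.log_softmax(logits, -1)).sum(-1)


def ce_xlogy(logits, labels):
    # xlogy(0, 0) is 0, so the masked entry contributes nothing.
    return -xlogy(labels, nn.softmax(logits, -1)).sum(-1)


logits = jnp.array([-jnp.inf, 0.0, 1.0])  # first entry masked out
labels = jnp.array([0.0, 0.5, 0.5])       # zero label on the masked entry

plain = ce_plain(logits, labels)
fixed = ce_xlogy(logits, labels)
print(plain, fixed)
```

Note that ce_xlogy goes through an explicit softmax, which is the numerical concern raised in the replies; the log_softmax-based alternatives avoid that.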