Open Adrien-Kahn opened 4 months ago
Hi - thanks for the clear report! Indeed this looks like a bug in the logsumexp
implementation; I think the root cause is this line: https://github.com/google/jax/blob/46103f6ff3f4cbac2881fece37d145bba7d3e60d/jax/_src/ops/special.py#L74
It removes the dependence of the output on `b` for `b = 0`. I suspect the best way of fixing this will be to add a custom JVP rule.
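A minimal sketch of that idea (not the actual fix; `logsumexp_with_b`, the 1-D shape handling, and the reduction form `log(sum_j b_j * exp(a_j))` are assumptions made here for illustration):

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def logsumexp_with_b(a, b):
    # Stable forward pass for log(sum_j b_j * exp(a_j)) over a 1-D array.
    amax = jnp.max(a)
    return jnp.log(jnp.sum(b * jnp.exp(a - amax))) + amax

@logsumexp_with_b.defjvp
def _logsumexp_with_b_jvp(primals, tangents):
    a, b = primals
    da, db = tangents
    out = logsumexp_with_b(a, b)
    # Analytic derivatives: d(out)/d(a_j) = b_j * exp(a_j - out) and
    # d(out)/d(b_j) = exp(a_j - out), so the b-derivative no longer
    # vanishes where b_j == 0.
    w = jnp.exp(a - out)
    return out, jnp.sum(b * w * da) + jnp.sum(w * db)

# With this rule, the gradient along b at (1, 0) is (1, 1), as expected
# for f(x1, x2) = log(x1 + x2):
print(jax.grad(logsumexp_with_b, argnums=1)(jnp.zeros(2), jnp.array([1.0, 0.0])))
```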
Hello @superbobry, though a beginner, I would like to give this a look to make my first contribution. Any documentation or ideas to get me started would be appreciated. I have a clear understanding of how the Jacobian is computed.
Reopening because the fix had to be rolled back
What was the problem?
The change led to failing tests in other projects.
Not a contribution. Could this implementation https://github.com/ott-jax/ott/blob/561adbbda9a88ea11d42d3262dbb5ce81bd482e8/src/ott/math/utils.py#L142 be useful?
Description
The gradient of `jax.scipy.special.logsumexp` along `b` is wrong when `b` contains a `0`. In the following example, the gradient of $f(x_1, x_2) = \log(x_1 + x_2)$ should be

$$\nabla f(x_1, x_2) = \left(\frac{1}{x_1 + x_2}, \frac{1}{x_1 + x_2}\right),$$

but JAX outputs that the gradient at `(1, 0)` is `(1, 0)` when it should be `(1, 1)`.
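A minimal sketch of such a reproduction (assuming `f` is built from the `b` argument of `logsumexp` with `a = (0, 0)`, so that $f(x_1, x_2) = \log(x_1 + x_2)$):

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def f(b):
    # With a = (0, 0): logsumexp(a, b=b) = log(b1 * e^0 + b2 * e^0) = log(b1 + b2).
    return logsumexp(jnp.zeros(2), b=b)

print(jax.grad(f)(jnp.array([1.0, 0.0])))
# reported gradient: (1, 0); expected: (1, 1)
```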
There is only a problem for an exact `0`: the gradient taken at `(1, 1e-30)` outputs the correct value of `(1, 1)`.
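For example, with the same hypothetical `f` as in the sketch above:

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

# Same sketch as above, with 1e-30 in place of the exact zero in b.
print(jax.grad(lambda b: logsumexp(jnp.zeros(2), b=b))(jnp.array([1.0, 1e-30])))
# reported gradient: (1, 1), which is correct
```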
Furthermore, if the `logsumexp` is implemented manually, the gradient at `(1, 0)` is also correct.
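A sketch of such a manual implementation (assuming the straightforward `log(sum(b * exp(a)))` form; `manual_logsumexp` is a hypothetical name):

```python
import jax
import jax.numpy as jnp

def manual_logsumexp(a, b):
    # Naive version with no special handling of b == 0, so autodiff sees the
    # full dependence on b and differentiates it correctly.
    return jnp.log(jnp.sum(b * jnp.exp(a)))

print(jax.grad(lambda b: manual_logsumexp(jnp.zeros(2), b))(jnp.array([1.0, 0.0])))
# reported gradient: (1, 1), which is correct
```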
cc @PhilipVinc
System info (python version, jaxlib version, accelerator, etc.)