Thanks for the report! It's definitely possible that something in JAX's autodiff machinery is producing incorrect gradients, but my first guess on seeing the code would be that your analytical expression is incorrect – it's a relatively complicated expression, and after staring at it for two minutes it's not obvious to me that the gradient is correct (not obvious that it's wrong, either). Still, it wouldn't be hard to make a mistake there.
If you think JAX is producing the wrong answer, I'd suggest trying to come up with a reproduction in which it is more straightforward to see and confirm that JAX is producing the wrong value.
Thank you for such a quick reply!
So the method `gradient_strain`, which implements the derivative analytically, is confirmed to yield the same result as `jax.jacfwd(model.strain)`. So this portion of the code doesn't have any bugs in it! It's only when we try to differentiate through the inner product that we get bugs, and that part of the code is very simple.
Since the potential is given by

$$V(x) = \frac{1}{2} \mathrm{Re} \lt h(x) - d, h(x) - d \gt$$

and since $d$ is a constant, the gradient has the following form:

$$\nabla V(x) = \mathrm{Re} \lt \nabla h(x), h(x) - d \gt$$

And for reference, the inner product is defined as

$$\lt a, b \gt = \sum_i a_i^* b_i$$
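As a sanity check of this identity, the two sides can be compared directly in JAX on a toy complex-valued model (a stand-in, not the waveform model from this issue):

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Toy complex-valued model h: R -> C^3 (a stand-in for the waveform model).
def h(x):
    return jnp.exp(1j * x * jnp.array([1.0, 2.0, 3.0]))

d = jnp.array([0.5 + 0.1j, 1.0 - 0.2j, 2.0 + 0.3j])  # arbitrary "data"

def V(x):
    r = h(x) - d
    # <a, b> = sum_i conj(a_i) * b_i, so <r, r> is real and nonnegative
    return 0.5 * jnp.real(jnp.vdot(r, r))

x = 0.7
dh = jax.jacfwd(h)(x)                         # dh/dx, complex, shape (3,)
analytic = jnp.real(jnp.vdot(dh, h(x) - d))   # Re <grad h, h - d>
print(jax.grad(V)(x), analytic)               # should agree to float64 precision
```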
If you have any hypotheses on what may be causing this bug, I can reply with a more minimal example, but I'm stumped.
Update 1: I got rid of a bunch of fluff in the original code I posted, and the bug remains.

Update 2: Just an observation, but the last component of the gradient agrees while all the others disagree.
I haven't looked into the details, but I just wanted to point out that numerically checking gradients (to convince yourself that autodiff is doing the right thing) is pretty easy: just compare against finite differences, $f'(x) \approx \frac{f(x + \epsilon) - f(x)}{\epsilon}$, along a random direction. That's what the internal tool `jax._src.test_util.check_grads` does automatically.
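For instance, a minimal sketch of that check (`f` here is a stand-in for the actual potential):

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # finite differences want precision

def f(x):
    return jnp.sum(jnp.sin(x) ** 2)  # stand-in for the actual potential

x = jnp.array([0.1, 0.2, 0.3])
v = jax.random.normal(jax.random.PRNGKey(0), x.shape)  # random direction
eps = 1e-6

fd = (f(x + eps * v) - f(x)) / eps    # finite-difference directional derivative
ad = jnp.vdot(jax.grad(f)(x), v)      # autodiff directional derivative
print(fd, ad)                         # should agree to roughly O(eps)

# Or let JAX do the comparison (checks fwd and rev modes, up to 2nd order):
from jax._src.test_util import check_grads
check_grads(f, (x,), order=2)
```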
(Also, I recommend using `\langle` and `\rangle` to typeset inner products in TeX!)
One thing that immediately sticks out to me is that you have a 10^22 range in the magnitude of your inputs – it wouldn't be surprising to me if floating point roundoff errors are coming into play with any computation involving numbers at such different scales.
I believe you have the right idea, Jake. It appears that changing the scale of the amplitude variable fixes the issue. Specifically, changing

```python
self.PSD = 1e-40 * jnp.ones_like(self.frequency)
injection = jnp.array([0, 0, 30.0, 0.24, 2e-22])
```

to

```python
self.PSD = jnp.ones_like(self.frequency)
injection = jnp.array([0, 0, 30.0, 0.24, 2])
```

resolves the issue.
Could you explain why this would be an issue?
Also, would it be possible to have JAX throw a warning in a situation like this?
> Could you explain why this would be an issue?

Floating-point math is only an approximation of real math: 64-bit floating point can represent only about 16 significant decimal digits in a single value, and 32-bit only about 8. So arithmetic operations involving numbers of very different scales are likely to lose precision. For example:

```python
>>> 1.0 + 1E-22 == 1.0
True
```
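To see how this plays out beyond a single addition, here is a minimal sketch (with hypothetical numbers chosen to mimic the $10^{-40}$ PSD and $10^{-22}$ amplitude above, not the actual code from this issue) of a gradient that is mathematically independent of scale but degrades in float32 when the scale is extreme:

```python
import jax
import jax.numpy as jnp

def V(a, scale):
    # Residual ~ scale and "PSD" ~ scale**2, so the factors of scale cancel
    # exactly and grad V with respect to a is the same at every scale.
    r = a * scale - 2.0 * scale
    psd = scale**2 * 1e4
    return 0.5 * jnp.sum(r * r / psd)

# At scale 1.0 every intermediate is O(1), and float32 is accurate:
print(jax.grad(V)(jnp.float32(1.5), jnp.float32(1.0)))    # -5e-05

# At scale 1e-22, intermediates like scale**2 = 1e-44 fall into (or below)
# float32's subnormal range, so the result is inaccurate or non-finite:
print(jax.grad(V)(jnp.float32(1.5), jnp.float32(1e-22)))
```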
> Also, would it be possible to have JAX throw a warning in a situation like this?

No, this is just something you have to work with and be aware of when doing floating-point operations on all modern systems. There's a good reference on this topic here: https://stackoverflow.com/q/588004
Makes sense! I appreciate the help.
Description
First off, thank you all for your work on this incredible project!
We are trying to compute the derivative of the potential function below (found in gravitational wave parameter estimation).
We have a complex-valued inner product $\lt \cdot,\cdot \gt : \mathbb{C}^n \times \mathbb{C}^n \to \mathbb{C}$, and a probability distribution defined by

$$p(x) = e^{-\frac{1}{2} \|h(x)-d\|^2}$$

where $d \in \mathbb{C}^n$ is a complex-valued waveform, $h$ is a complex-valued waveform model with values in $\mathbb{C}^n$, and the norm is induced by the inner product in the standard way.
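In JAX terms, the potential being differentiated looks like this sketch (`strain` is a placeholder name for the waveform model $h$):

```python
import jax.numpy as jnp

def potential(x, d, strain):
    # V(x) = 0.5 * Re<h(x) - d, h(x) - d>, with <a, b> = sum_i conj(a_i) b_i.
    # <r, r> is real and nonnegative, so V(x) = -log p(x) up to a constant.
    r = strain(x) - d
    return 0.5 * jnp.real(jnp.vdot(r, r))
```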
We have a closed-form expression for $h$, and have confirmed that the derivatives of $h$ obtained using JAX agree with the analytical answer (see code below).
However, we find disagreement when comparing the gradient of the potential $V(x) = -\ln p(x)$ computed by JAX against the analytically obtained gradient.
See the following example:
Any thoughts on what may be causing this behavior?
What jax/jaxlib version are you using?
jax 0.4.13, jaxlib 0.4.12
Which accelerator(s) are you using?
GPU
Additional system info
WSL
NVIDIA GPU info