jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Differentiation failing on real function #17401

Closed: leviyevalex closed this issue 1 year ago

leviyevalex commented 1 year ago

Description

First off, thank you all for your work on this incredible project!

We are trying to compute the derivative of the potential function below, which arises in gravitational-wave parameter estimation.

We have a complex-valued inner product $<\cdot,\cdot> : \mathbb{C}^n \times \mathbb{C}^n \to \mathbb{C}$, and a probability distribution defined by

$$ p(x) = e^{-\frac{1}{2} ||h(x)-d||^2 } $$

where $d\in \mathbb{C}^n$ is a complex-valued waveform (the data), $h(x) \in \mathbb{C}^n$ is a complex-valued waveform model, and the norm is the one induced by the inner product in the standard way.

We have a closed-form expression for $h$, and we have confirmed that the derivatives obtained with JAX agree with the analytical answer (see code below).

However, we find disagreement when comparing the gradient of the potential $V(x) = -\ln p(x)$ computed by JAX with the analytically obtained gradient.
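Written out explicitly from the definitions above,

$$V(x) = -\ln p(x) = \frac{1}{2} ||h(x)-d||^2 = \frac{1}{2} <h(x)-d, h(x)-d>,$$

which is what potential_single computes in the code below (the inner product of a vector with itself is real, so the .real there only discards a zero imaginary part).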

See the following example:

import jax
import jax.numpy as jnp

class taylorf2:
    def __init__(self, injection):
        self.injection = injection 

        # Frequency grid
        self.frequency = jnp.linspace(10, 1000, num=1000)

        # Constants
        self.m_sun_sec = 1

        # Data
        self.data = self.strain(self.injection, self.frequency)

    def strain(self, x, frequencies):
        time_coalescence = x[0]
        phase_coalescence = x[1]
        chirp_mass = x[2]
        symmetric_mass_ratio = x[3]
        Amplitude = x[4]

        # Redefined for cleaner expression
        f = frequencies
        eta = symmetric_mass_ratio
        Mc = chirp_mass
        phi = phase_coalescence

        expr = (Amplitude * jnp.exp(-1j * (-(jnp.pi/4) + 2 * f * jnp.pi * time_coalescence + (3 * (1 + jnp.pi**(2/3) * ((f * Mc * self.m_sun_sec) / eta**(3/5))**(2/3) * (3715/756 + (55 * eta)/9))) / (128 * jnp.pi**(5/3) * ((f * Mc * self.m_sun_sec) / eta**(3/5))**(5/3) * eta) - phi))) / f**(7/6)

        return expr

    def gradient_strain(self, x, frequencies):
        time_coalescence = x[0]
        phase_coalescence = x[1]
        chirp_mass = x[2]
        symmetric_mass_ratio = x[3]
        Amplitude = x[4]

        # Redefined for cleaner expression
        f = frequencies
        eta = symmetric_mass_ratio
        Mc = chirp_mass
        S = self.strain(x, frequencies)

        expr1 = -2j * f * jnp.pi * S
        expr2 = 1j * S
        expr3 = (5j * S * (252 + jnp.pi**(2/3) * ((f * Mc * self.m_sun_sec) / eta**(3/5))**(2/3) * (743 + 924 * eta))) / (32256 * Mc * jnp.pi**(5/3) * ((f * Mc * self.m_sun_sec) / eta**(3/5))**(5/3) * eta)
        expr4 = -((1j * S * (-743 + 1386 * eta)) / (16128 * f * Mc * self.m_sun_sec * jnp.pi * eta**(7/5)))
        expr5 = S / Amplitude

        return jnp.array([expr1, expr2, expr3, expr4, expr5])

    def inner_product(self, a, b):
        return jnp.sum(a.conjugate() * b, axis=-1).T

    def potential_single(self, x):
        residual = self.strain(x, self.frequency) - self.data
        return 0.5 * self.inner_product(residual, residual).real

    def gradient_potential_single(self, x):
        residual = self.strain(x, self.frequency) - self.data
        gradient_residual = self.gradient_strain(x, self.frequency)
        return self.inner_product(gradient_residual, residual).real

# Instantiate class
injection = jnp.array([0, 0, 30.0, 0.24, 2e-22])
model = taylorf2(injection)

# Get a point to evaluate at
import numpy as np
x = model.injection + np.random.uniform(low=0, high=0.0001, size=5)

# The model and its derivative are calculated correctly
test1 = jax.jacfwd(model.strain)(x, 10)
test2 = model.gradient_strain(x, 10)

print(test1)
print(test2)

test3 = jax.jacrev(model.potential_single)(x)
test4 = model.gradient_potential_single(x)

# Last component agrees, but all others disagree
print(test3)
print(test4)

Any thoughts on what may be causing this behavior?

What jax/jaxlib version are you using?

jax 0.4.13, jaxlib 0.4.12

Which accelerator(s) are you using?

GPU

Additional system info

WSL

NVIDIA GPU info

+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.54.04              Driver Version: 536.23       CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 3070 ...    On  | 00000000:01:00.0 Off |                  N/A |
| N/A   46C    P8              13W / 132W |   8011MiB /  8192MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A        29      G   /Xwayland                                 N/A      |
|    0   N/A  N/A      5252      C   /python3.11                               N/A      |
|    0   N/A  N/A     11134      C   /python3.11                               N/A      |
+---------------------------------------------------------------------------------------+
jakevdp commented 1 year ago

Thanks for the report! It's definitely possible that something in JAX's autodiff machinery is producing incorrect gradients, but my first guess on seeing the code would be that your analytical expression is incorrect: it's a relatively complicated expression, and after staring at it for two minutes it's not obvious to me that the gradient is correct (not obvious that it's wrong either). Still, it wouldn't be hard to make a mistake there.

If you think JAX is producing the wrong answer, I'd suggest trying to come up with a reproduction in which it is more straightforward to see and confirm that JAX is producing the wrong value.

leviyevalex commented 1 year ago

Thank you for such a quick reply!

So the method gradient_strain, which implements the derivative analytically, is confirmed to yield the same result as jax.jacfwd(model.strain); that portion of the code doesn't have any bugs in it! It's only when we differentiate through the inner product that we see a discrepancy, and that part of the code is very clean.

Since the potential is given by $$V(x) = \frac{1}{2} Re \lt h(x) - d, h(x) - d \gt$$ and since $d$ is a constant, the gradient has the following form $$\nabla V(x) = Re \lt \nabla h(x) , h(x) - d \gt$$ And for reference, the inner product is defined as $$\lt a,b \gt = \sum_i a_i^*b_i$$
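For completeness, the one-step justification: the parameters are real and the two product-rule terms are complex conjugates of each other, so

$$\nabla V(x) = \frac{1}{2} Re \left( \lt \nabla h(x), h(x)-d \gt + \lt h(x)-d, \nabla h(x) \gt \right) = Re \lt \nabla h(x), h(x)-d \gt$$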

If you have any hypotheses on what may be causing this bug I can reply with a more minimal example, but I'm stumped.

Update 1: I got rid of a bunch of fluff in the original code I posted, and the bug remains.

Update 2: Just an observation, but the last component of the gradient agrees while all the others disagree.

mattjj commented 1 year ago

I haven't looked into the details, but I just wanted to point out that numerical checking of gradients (to convince yourself that autodiff is doing the right thing) is pretty easy: just compare to finite differences, like $f'(x) \approx \frac{f(x + \epsilon) - f(x)}{\epsilon}$ (along a random direction). That's what the internal tool jax._src.test_util.check_grads does automatically.
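Concretely, such a check against the reproduction above could look something like this (just a sketch; eps may need tuning given the very different parameter scales, and check_grads lives in an internal module, so its signature may change):

import numpy as np
import jax
import jax.numpy as jnp

# Pick a random unit direction in parameter space
v = np.random.normal(size=x.shape)
v = v / np.linalg.norm(v)

# Central finite difference of the potential along v
eps = 1e-6
fd = (model.potential_single(x + eps * v) - model.potential_single(x - eps * v)) / (2 * eps)

# Directional derivative from the JAX gradient
ad = jnp.dot(jax.grad(model.potential_single)(x), v)

print(fd, ad)  # these should agree to several digits if autodiff is correct

# Or use the helper directly (may require test dependencies to be installed)
from jax._src import test_util as jtu
jtu.check_grads(model.potential_single, (x,), order=1, modes=("rev",))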

(Also I recommend using \langle and \rangle to typeset inner products in TeX!)

jakevdp commented 1 year ago

One thing that immediately sticks out to me is that you have a 10^22 range in the magnitude of your inputs – it wouldn't be surprising to me if floating point roundoff errors are coming into play with any computation involving numbers at such different scales.

leviyevalex commented 1 year ago

I believe you have the right idea, Jake. It appears that rescaling the amplitude variable fixes the issue. Specifically, changing

self.PSD = 1e-40 * jnp.ones_like(self.frequency)
injection = jnp.array([0, 0, 30.0, 0.24, 2e-22])

to

self.PSD = jnp.ones_like(self.frequency)
injection = jnp.array([0, 0, 30.0, 0.24, 2])

resolves the issue.

Could you explain why this would be an issue?

Also, would it be possible to have JAX throw a warning in a situation like this?

jakevdp commented 1 year ago

Could you explain why this would be an issue?

For example:

>>> 1.0 + 1E-22 == 1.0
True

Floating point math is only an approximation of real math. 64-bit floating point can only represent about 16 significant decimal digits in a single value; 32-bit can only represent about 8. So if you're doing arithmetic that mixes numbers of very different scales, floating point expressions are likely to lose precision.
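You can see the same thing within JAX itself; switching to 64-bit floats (the standard jax_enable_x64 flag) buys roughly twice as many digits, but still nowhere near 22 orders of magnitude:

import jax
import jax.numpy as jnp

# Default JAX arrays are float32: ~7-8 significant digits
print(jnp.array(1.0) + 1e-8 == 1.0)    # True -- the 1e-8 is lost

# Enable 64-bit floats (~15-16 significant digits); usually set at program startup
jax.config.update("jax_enable_x64", True)
print(jnp.array(1.0) + 1e-8 == 1.0)    # False -- representable in float64
print(jnp.array(1.0) + 1e-22 == 1.0)   # True -- still lost at this scale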

Also, would it be possible to have JAX throw a warning in a situation like this?

No, this is just something you have to be aware of and work around when doing floating point operations on any modern system. There's a good reference on this topic here: https://stackoverflow.com/q/588004

leviyevalex commented 1 year ago

Makes sense! I appreciate the help.