google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

forward and reverse mode gradients for eigvalsh + log or eigvalsh + sqrt return NaNs #21896

Open h-vijayakumaran opened 2 months ago

h-vijayakumaran commented 2 months ago

Description

This is probably related to #1383.

When I compose a function from eigvalsh and sqrt or log, I encounter NaNs, mainly in higher-order derivatives such as the Hessian. I have also written a "hardcoded" eigenvalue computation for the 2x2 case, which seems to give correct results.

Here is a minimal example:

import jax
jax.config.update('jax_enable_x64', True)
import jax.numpy as np

# Create a random scalar times the identity, i.e. a diagonal matrix with equal (degenerate) eigenvalues
seed = jax.random.PRNGKey(0)
# generate a random scalar value 
scalar = jax.random.uniform(seed, (1,1))

a = scalar[0]*np.diag(np.array([1.0 ,1.0]))

def eigen_decomposition_2x2_sym(inp):
    """
    Compute the eigenvalues of a 2x2 symmetric matrix.
    Args:
        inp (np.ndarray) : The 2x2 symmetric matrix
            inp = [[a, b],
                   [b, c]] where a, c are the diagonal elements and b is the off-diagonal element

    Returns:
        eigvals (np.ndarray) : The eigenvalues of the 2x2 symmetric matrix
    """

    a = inp.at[0,0].get()
    b = (inp.at[0,1].get() + inp.at[1,0].get())/2
    c = inp.at[1,1].get()

    # Closed-form eigenvalues of [[a, b], [b, c]]; the 1e-16 keeps the sqrt
    # argument away from zero when a == c and b == 0 (repeated eigenvalues).
    eigvals = np.array([0.5 * ((a + c) + np.sqrt((a - c)**2 + 4 * b**2 + 1e-16)),
                        0.5 * ((a + c) - np.sqrt((a - c)**2 + 4 * b**2 + 1e-16))])

    return eigvals

def func1_jax (inp):
    eig_vals = np.linalg.eigvalsh(inp)
    J = np.prod(eig_vals)

    return np.sqrt(J)

def func1_hardcoded (inp):
    eig_vals = eigen_decomposition_2x2_sym(inp)
    J = np.prod(eig_vals)

    return np.sqrt(J)

def func2_jax (inp):
    eig_vals = np.linalg.eigvalsh(inp)
    J = np.prod(eig_vals)

    return np.log(J)

def func2_hardcoded (inp):
    eig_vals = eigen_decomposition_2x2_sym(inp)
    J = np.prod(eig_vals)

    return np.log(J)

def test_deriv(f_jax,f_hardcoded, inp):
    grad_func = jax.jacfwd(f_jax)
    hess_func = jax.jacfwd(grad_func)

    grad_func_hard_coded = jax.jacrev(f_hardcoded)
    hess_func_hard_coded = jax.jacrev(grad_func_hard_coded)

    # Check if the gradient and hessian are the same
    print("Gradients the same as hardcoded?")
    print(np.allclose(grad_func(inp), grad_func_hard_coded(inp)))
    print("Any nans in gradients?")
    print(np.any(np.isnan(grad_func(inp))))

    print("Hessians the same as hardcoded?")
    print(np.allclose(hess_func(inp), hess_func_hard_coded(inp)))
    print("Any nans in hessians?")
    print(np.any(np.isnan(hess_func(inp))))
    if np.any(np.isnan(hess_func(inp))):
        print("Hessian hardcoded:")
        print(hess_func_hard_coded(inp))
        print("Hessian jax:")
        print(hess_func(inp))

print("Input matrix:")
print(a)

print("Sqrt with eigh")
test_deriv(func1_jax, func1_hardcoded, a)

print("log with eigh")
test_deriv(func2_jax, func2_hardcoded, a)
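
As an additional check (the perturbation below is something I added for illustration, not part of the minimal example above): breaking the degeneracy should give a finite Hessian, since the eigendecomposition derivative is only singular at repeated eigenvalues.

# Perturb one diagonal entry so the two eigenvalues are distinct,
# then take the same Hessian as above.
a_distinct = a.at[0, 0].add(0.1)
hess_distinct = jax.jacfwd(jax.jacfwd(func1_jax))(a_distinct)
print("Any nans in hessian with distinct eigenvalues?")
print(np.any(np.isnan(hess_distinct)))  # expected: False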

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.29
jaxlib: 0.4.29
numpy:  1.26.4

I am running on CPU.

superbobry commented 2 months ago

It looks like the following is sufficient to reproduce:

import jax.numpy as jnp

def f(x):
    return jnp.prod(jnp.linalg.eigvalsh(x))

I traced the NaNs to this multiplication, where Fmat has infinities, which turn into NaNs during the matmul.
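
For context, the usual symmetric-eigendecomposition JVP has off-diagonal factors 1/(lambda_j - lambda_i) (a sketch of the textbook rule; I'm assuming Fmat is built along these lines), and for a degenerate input like the scalar * identity matrix in the report those factors are 1/0 = inf:

import jax.numpy as jnp

# Sketch of the problematic factor, assuming Fmat follows the usual
# eigenvector-perturbation formula F[i, j] = 1 / (lam[j] - lam[i]) for i != j.
A = 0.5 * jnp.eye(2)                   # degenerate: both eigenvalues are 0.5
lam = jnp.linalg.eigvalsh(A)
diff = lam[None, :] - lam[:, None]     # all zeros for a repeated eigenvalue
Fmat = jnp.where(jnp.eye(2, dtype=bool), 0.0, 1.0 / diff)
print(Fmat)                            # off-diagonal entries are inf, and these
                                       # are what turn into NaNs in the matmul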

@hawkinsp is this WAI?