jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Gradient of SVD with degenerate singular values becomes NaN #2311

Open · Jakob-Unfried opened this issue 4 years ago

Jakob-Unfried commented 4 years ago

If an SVD has degenerate singular values (multiple entries of S are exactly equal), the backward pass gives nan.

This is because the AD formula contains terms like 1 / (s[i] ** 2 - s[j] ** 2), which blow up if s[i] == s[j]. There are also terms with 1 / S, which blow up if any entry of S is zero.

In particular, this situation arises when multiple singular values are exactly 0.0, and it turns the entire gradient of a function into nan, even if the function does not depend on the columns of U and rows of V that correspond to the zero singular values (see example below). Essentially, even if the adjoint gradients of the corresponding slices of U and V are 0., they are multiplied by inf/nan and thus give nan.

I don't have a clear vision of what the best solution is here. I just wanted to share that this problem arises for me and how I will dirty-fix it. Maybe a fix along these lines would be of interest to implement in jax? Thoughts?

Example

I construct a matrix that has only a few non-zero singular values (actually, only the first is non-zero, but due to numerical error, the second is non-zero too):

from jax import value_and_grad
from jax import numpy as np  # the example below uses jax.numpy as np

mat_degenerate = np.ones([5, 5])
U, S, V = np.linalg.svd(mat_degenerate)
print(S)
>>> [4.9999995e+00 1.9792830e-07 0.0000000e+00 0.0000000e+00 0.0000000e+00]

I define a function that only depends on the first two singular values (which are clearly neither degenerate nor zero) and the corresponding columns of U and rows of V:

def truncated_svd(x, num):
    U, S, V = np.linalg.svd(x, full_matrices=False)
    return U[:, :num], S[:num], V[:num, :]

def foo(x):
    U, S, V = truncated_svd(x, 2)
    return np.real(np.trace(np.dot(V, U))) + np.sum(S)

And see that the entire gradient becomes nan

print(value_and_grad(foo)(mat_degenerate))
>>> (DeviceArray(6.9999995, dtype=float32), DeviceArray([[nan, nan, nan, nan, nan],
              [nan, nan, nan, nan, nan],
              [nan, nan, nan, nan, nan],
              [nan, nan, nan, nan, nan],
              [nan, nan, nan, nan, nan]], dtype=float32))
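
To make the source of the nan concrete: the matrix of 1 / (s[i] ** 2 - s[j] ** 2) terms built from the S printed above already contains inf on the degenerate pairs, and multiplying those entries by zero cotangents gives nan rather than zero. A minimal illustration (F and the hard-coded S are only for demonstration):

from jax import numpy as jnp

S = jnp.array([4.9999995e+00, 1.9792830e-07, 0.0, 0.0, 0.0])

s2 = S ** 2
F = 1.0 / (s2[:, None] - s2[None, :])  # entry (i, j) is 1 / (s[i]**2 - s[j]**2)

print(F)              # inf on the diagonal and wherever s[i] == s[j]
print(0.0 * F[2, 3])  # nan: a zero cotangent times inf does not cancel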

Dirty workaround

I have seen people use safe_inverse = lambda x, eps: x / (x**2 + eps) and replace 1 / (s[i] ** 2 - s[j] ** 2) by safe_inverse(s[i] ** 2 - s[j] ** 2, eps), as well as 1 / S by safe_inverse(S, eps), usually with eps == 1e-12. So I implemented that. Contact me if anyone is interested.
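
For reference, a minimal sketch of that regularization in jax.numpy (this is not Jakob's code; safe_f_matrix and the default eps are just illustrative):

from jax import numpy as jnp

def safe_inverse(x, eps=1e-12):
    # Behaves like 1 / x away from zero but stays finite (and tends to 0) as x -> 0.
    return x / (x ** 2 + eps)

def safe_f_matrix(s, eps=1e-12):
    # Regularized version of 1 / (s[i]**2 - s[j]**2) with a zeroed diagonal,
    # so degenerate (or zero) singular values no longer produce inf/nan.
    s2 = s ** 2
    F = safe_inverse(s2[:, None] - s2[None, :], eps)
    return jnp.where(jnp.eye(s.shape[0], dtype=bool), 0.0, F)

Inside a custom SVD VJP, safe_inverse(S, eps) would stand in for the 1 / S terms in the same way.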

Aminpdi commented 3 years ago

Hi @Jakob-Unfried

I have the same problem with backpropagation through the SVD, specifically when the matrix is large (you will most probably get identical singular values). Did your implementation work? I mean, did you have stable backpropagation, and did the loss decrease properly? If yes, is it possible to share your implementation or explain it?

Jakob-Unfried commented 3 years ago

@Aminpdi

Sending you a PM; I don't want to think about licensing and whether it is "clean enough", so I don't want to make it public.

quantshah commented 3 years ago

Hi, I have a somewhat similar issue with JAX's SVD for larger matrices. I wrote a simple test to compare against NumPy's SVD, and for larger matrices there is a discrepancy. If I check the decompositions and compare, there seem to be errors in the decomposed matrices (the errors depend on the singular values). I have a routine where I force an operator to become positive after the SVD by discarding the negative singular values, but this leads to some NaNs later in my code.

It seems that jnp.linalg.svd is tested for matrices of sizes 29 and 53 with a tolerance of 1e-4; see the test.

import numpy as np
from scipy.stats import unitary_group
from jax import numpy as jnp
np.random.seed(42)

def rand_positive(N):
    """
    Creates a random positive operator.

    Args:
        N (int): Size of the operator.

    Returns:
        op (ndarray): A positive operator of dimension (N, N)
    """
    U = unitary_group.rvs(N)
    V = unitary_group.rvs(N)
    s = np.zeros((N, N))
    np.fill_diagonal(s, np.random.uniform(0, 5, size = (N, N)))  # only the first N of these values end up on the diagonal
    op = np.dot(U@s, np.conjugate(V).T)

    return op

def test_svd(N):
    """
    Tests if SVD decomposition from Jax matches with Numpy for a
    random operator of size (N, N).

    Args:
        N (int): The size of the operator.

    Returns:
        assert_list ([bool, bool, bool]): The assertion for how close the decomposition
                                          O = USV matches with Numpy's decomposition.
    """
    E = rand_positive(N)

    u, s, vh = np.linalg.svd(E)
    ju, js, jvh = jnp.linalg.svd(E)

    return (np.allclose(u, ju, atol=1e-4, rtol=1e-4),
            np.allclose(s, js),
            np.allclose(vh, jvh, atol=1e-4, rtol=1e-4))

test_svd(2), test_svd(3), test_svd(5), test_svd(7), test_svd(29), test_svd(53)

which gives output

((True, True, True),
 (True, True, True),
 (True, True, True),
 (False, True, False),
 (False, True, False),
 (True, True, True))

However, setting the seed to

np.random.seed(29)

leads to the failure of test_svd(29).

Is this something that is expected? I am not an expert in the numerics of the SVD, but it seems that the SVD is incorrect compared to NumPy's implementation in certain cases. If you think this is an issue, I could open a new thread about it.

Meanwhile any help in dealing with this would be very much appreciated. Thanks.

hawkinsp commented 3 years ago

@quantshah Please open a new issue for a new problem.

That said, I suspect the issue is that you need to enable 64-bit precision in JAX (https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision). Does that help?

quantshah commented 3 years ago

@hawkinsp Thanks for the quick response. I wrote a test using complex128 with 64-bit precision enabled, and the errors still remain:

import numpy as np
from scipy.stats import unitary_group

from jax.config import config
from jax import numpy as jnp

config.update("jax_enable_x64", True)
np.random.seed(29)

def rand_positive(N):
    """
    Creates a random positive operator.

    Args:
        N (int): Size of the operator.

    Returns:
        op (ndarray): A positive operator of dimension (N, N)
    """
    U = unitary_group.rvs(N)
    V = unitary_group.rvs(N)
    s = np.zeros((N, N))
    np.fill_diagonal(s, np.random.uniform(1., 5., size = (N, N)))
    op = np.dot(np.dot(U, s), np.conjugate(V).T)

    return op

N = 20
E = rand_positive(N).astype(jnp.complex128)
print(E.dtype)

u, s, vh = np.linalg.svd(E)
ju, js, jvh = jnp.linalg.svd(E)

np.allclose(u, ju), np.allclose(s, js), np.allclose(vh, jvh)

If I actually look at the difference between the real parts of the matrices, it is quite significant compared to what NumPy gives. I will open a new issue for this if you think it is a bug that requires further discussion. Thanks.
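
One caveat when reading these element-wise comparisons: the SVD is only unique up to a sign (or, for complex matrices, a phase) per singular-vector pair, so u/ju and vh/jvh can legitimately differ entry by entry even when both decompositions are correct, especially when singular values are close. A reconstruction check avoids that ambiguity; a minimal sketch (the random E below is just a stand-in for the operators above):

import jax
import numpy as np
from jax import numpy as jnp

jax.config.update("jax_enable_x64", True)

rng = np.random.default_rng(0)
E = rng.normal(size=(20, 20)) + 1j * rng.normal(size=(20, 20))

u, s, vh = np.linalg.svd(E)
ju, js, jvh = jnp.linalg.svd(E)

# The factors may disagree element-wise because of the sign/phase freedom ...
print(np.allclose(u, ju), np.allclose(vh, jvh))

# ... but both factorizations should rebuild E to numerical precision.
print(np.allclose(u @ np.diag(s) @ vh, E))
print(np.allclose(ju @ jnp.diag(js) @ jvh, E))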

souryadey commented 2 years ago

@Jakob-Unfried : I am facing the same problem and like your safe_inverse solution, but I'm not sure how to apply such a thing when doing backpropagation in PyTorch. Do you have any idea on that?

wangleiphy commented 2 years ago

> @Jakob-Unfried : I am facing the same problem and like your safe_inverse solution, but I'm not sure how to apply such a thing when doing backpropagation in PyTorch. Do you have any idea on that?

see here https://github.com/wangleiphy/tensorgrad/blob/master/tensornets/adlib/svd.py

souryadey commented 2 years ago

> @Jakob-Unfried : I am facing the same problem and like your safe_inverse solution, but I'm not sure how to apply such a thing when doing backpropagation in PyTorch. Do you have any idea on that?
>
> see here https://github.com/wangleiphy/tensorgrad/blob/master/tensornets/adlib/svd.py

Thanks, but I'm not sure I understand how to use it. I'm trying to find the SVD of a matrix Y. Invoking as U, Sigma, Vt = SVD(Y) gives the error cannot unpack non-iterable SVD object.

wangleiphy commented 2 years ago

Here is an example of how to use that function https://github.com/wangleiphy/tensorgrad/blob/master/tensornets/trg.py#L3-L4
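
For what it's worth, the unpacking error above comes from calling the class itself: SVD in that file is, as far as I can tell, a torch.autograd.Function, and those are invoked through the static apply method rather than instantiated. A rough sketch of the usage pattern (the import path and the dummy Y are only illustrative):

import torch
from adlib import SVD  # import as in the linked trg.py; adjust the path to wherever adlib lives

Y = torch.rand(10, 10, dtype=torch.float64, requires_grad=True)

# U, Sigma, Vt = SVD(Y)       # wrong: this only constructs a Function object,
#                             # hence "cannot unpack non-iterable SVD object"
U, Sigma, Vt = SVD.apply(Y)   # right: apply() runs the forward pass and registers
                              # the custom (regularized) backward pass

loss = U.sum() + Sigma.sum() + Vt.sum()  # dummy scalar, just to exercise backward
loss.backward()
print(Y.grad.shape)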

souryadey commented 2 years ago

That worked, thank you so much!

JadM133 commented 1 week ago

Hello, It's been a while since this issue was discussed. But I wanted to share the following paper: Robust differentiable SVD. Which proposes a way of using Taylor's extension to avoid the explosion of gradients when eigenvalues are close (or exactly the same). We do lose some parts of the gradients, though. Maybe it is possible to use this implementation only when eigenvalues are close?