Jakob-Unfried opened this issue 4 years ago
Hi @Jakob-Unfried
I have the same problem with backpropagation through the SVD, specifically when the matrix is large (where it is likely that some singular values coincide). Did your implementation work, i.e. was backpropagation stable and did the loss decrease properly? If yes, would it be possible to share your implementation or explain it?
@Aminpdi
Sending you a PM. I don't want to think about licensing and whether the code is "clean enough", so I'd rather not make it public.
Hi, I have a somewhat similar issue with JAX's SVD for larger matrices. I wrote a simple test to compare against NumPy's SVD, and for larger matrices there is a discrepancy: if I check the decomposition and compare, there are errors in the decomposed matrices (the errors depend on the singular values). I have a routine where I force an operator to become positive after an SVD by discarding the negative singular values, and this leads to some NaNs later in my code.
It seems that jnp.linalg.svd is tested only for matrices of sizes 29 and 53, with a tolerance of 1e-4; see the test.
import numpy as np
from scipy.stats import unitary_group
from jax import numpy as jnp

np.random.seed(42)


def rand_positive(N):
    """
    Creates a random positive operator.

    Args:
        N (int): Size of the operator.

    Returns:
        op (ndarray): A positive operator of dimension (N, N)
    """
    U = unitary_group.rvs(N)
    V = unitary_group.rvs(N)
    s = np.zeros((N, N))
    # note: fill_diagonal flattens its second argument, so only the
    # first N of the N*N uniform draws end up on the diagonal
    np.fill_diagonal(s, np.random.uniform(0, 5, size=(N, N)))
    op = np.dot(U @ s, np.conjugate(V).T)
    return op
def test_svd(N):
    """
    Tests if the SVD from JAX matches NumPy's SVD for a
    random operator of size (N, N).

    Args:
        N (int): The size of the operator.

    Returns:
        (bool, bool, bool): Whether each factor of the decomposition
        O = U S Vh matches NumPy's decomposition.
    """
    E = rand_positive(N)
    u, s, vh = np.linalg.svd(E)
    ju, js, jvh = jnp.linalg.svd(E)
    return (np.allclose(u, ju, atol=1e-4, rtol=1e-4),
            np.allclose(s, js),
            np.allclose(vh, jvh, atol=1e-4, rtol=1e-4))


test_svd(2), test_svd(3), test_svd(5), test_svd(7), test_svd(29), test_svd(53)
which gives the output
((True, True, True),
(True, True, True),
(True, True, True),
(False, True, False),
(False, True, False),
(True, True, True))
However, setting the seed to np.random.seed(29) leads to the failure of test_svd(29).
Is this expected? I am not an expert in the numerics of the SVD, but it seems that the SVD is incorrect compared to NumPy's implementation in certain cases. If you think this is an issue, I could open a new thread about it.
Meanwhile, any help in dealing with this would be very much appreciated. Thanks.
@quantshah Please open a new issue for a new problem.
That said, I suspect the issue is that you need to enable 64-bit precision in JAX (https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision). Does that help?
@hawkinsp Thanks for the quick response. I wrote a test using complex128 with 64-bit precision enabled, and the errors still remain:
import numpy as np
from scipy.stats import unitary_group
from jax.config import config
from jax import numpy as jnp

config.update("jax_enable_x64", True)
np.random.seed(29)


def rand_positive(N):
    """
    Creates a random positive operator.

    Args:
        N (int): Size of the operator.

    Returns:
        op (ndarray): A positive operator of dimension (N, N)
    """
    U = unitary_group.rvs(N)
    V = unitary_group.rvs(N)
    s = np.zeros((N, N))
    # fill_diagonal flattens its second argument, so only the first N
    # of the uniform draws end up on the diagonal
    np.fill_diagonal(s, np.random.uniform(1., 5., size=(N, N)))
    op = np.dot(np.dot(U, s), np.conjugate(V).T)
    return op


N = 20
E = rand_positive(N).astype(jnp.complex128)
print(E.dtype)
u, s, vh = np.linalg.svd(E)
ju, js, jvh = jnp.linalg.svd(E)
np.allclose(u, ju), np.allclose(s, js), np.allclose(vh, jvh)
If I look at the difference between the real parts of the matrices, it is quite significant compared to what NumPy gives. I will open a new issue for this if you think it is a bug that requires further discussion. Thanks.
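One caveat worth noting here, independent of precision: the SVD is only unique up to a per-pair sign (or complex phase) of the singular vectors, and up to arbitrary mixing within groups of degenerate singular values, so two correct implementations can legitimately return different U and Vh for the same input. A comparison that is invariant under this freedom, such as checking the reconstruction, avoids false alarms. A minimal sketch (the random matrix here is just illustrative, not the rand_positive construction above):

import numpy as np
from jax.config import config
from jax import numpy as jnp
config.update("jax_enable_x64", True)

np.random.seed(29)
E = np.random.randn(20, 20) + 1j * np.random.randn(20, 20)

u, s, vh = np.linalg.svd(E)
ju, js, jvh = jnp.linalg.svd(E)

# The factors themselves may differ elementwise (phase freedom) ...
print(np.allclose(u, ju), np.allclose(vh, jvh))
# ... yet both factorizations reconstruct E to machine precision.
print(np.allclose(u @ np.diag(s) @ vh, E))
print(np.allclose(ju @ jnp.diag(js) @ jvh, E))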
@Jakob-Unfried: I am facing the same problem and like your safe_inverse solution, but I'm not sure how to apply such a thing when doing backpropagation in PyTorch. Do you have any idea on that?
see here https://github.com/wangleiphy/tensorgrad/blob/master/tensornets/adlib/svd.py
Thanks, but I'm not sure I understand how to use it. I'm trying to find the SVD of a matrix Y. Invoking it as U, Sigma, Vt = SVD(Y) gives the error cannot unpack non-iterable SVD object.
Here is an example of how to use that function https://github.com/wangleiphy/tensorgrad/blob/master/tensornets/trg.py#L3-L4
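For context: SVD in that repository is a torch.autograd.Function subclass, and in PyTorch such Functions are invoked through their .apply attribute rather than by calling the class directly, which is why SVD(Y) returns a non-iterable object. A minimal sketch of the calling pattern, assuming the adlib package from that repository is importable and returns the three factors as in the comment above:

import torch
from adlib import SVD  # custom autograd Function from wangleiphy/tensorgrad

svd = SVD.apply  # autograd Functions are called via .apply, not via SVD(Y)

Y = torch.randn(5, 5, dtype=torch.float64, requires_grad=True)
U, Sigma, Vt = svd(Y)
Sigma.sum().backward()  # gradients flow through the custom backward pass
print(Y.grad)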
That worked, thank you so much!
Hello, it's been a while since this issue was discussed, but I wanted to share the following paper: Robust Differentiable SVD. It proposes using a Taylor expansion to avoid the explosion of gradients when eigenvalues are close (or exactly the same). We do lose some parts of the gradients, though. Maybe it is possible to use this implementation only when eigenvalues are close?
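For illustration, here is one way such a Taylor-based regularization can look. This is a sketch of the general idea only, not necessarily the paper's exact formulation, and the function name and interface are hypothetical: the problematic factor 1 / (s_i**2 - s_j**2) is replaced by a truncated geometric series, which stays bounded as the two singular values approach each other.

import jax.numpy as jnp

def taylor_inv_gap(s_big, s_small, order=9):
    """Truncated-series approximation of 1 / (s_big**2 - s_small**2).

    Uses 1 / (a - b) = (1/a) * sum_n (b/a)**n with a = s_big**2 and
    b = s_small**2, which converges to the exact value for
    s_small < s_big. Truncating keeps the result bounded by
    (order + 1) / s_big**2 even as s_small -> s_big, so the gradient
    no longer explodes; the price is an approximation error.
    """
    r = (s_small / s_big) ** 2
    return sum(r**n for n in range(order + 1)) / s_big**2

Here s_big is assumed to be the larger of the two singular values; the antisymmetric sign of the original term has to be restored by the caller.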
If an SVD has degenerate singular values (multiple entries of S are exactly equal), the gradient pass will give nan. This is because the AD formula contains something like 1 / (s[i] ** 2 - s[j] ** 2), which becomes nan if s[i] == s[j]. There are also terms with 1 / S, which become nan if any entry of S is zero.

In particular, this situation arises when multiple singular values are 0.0, and it evaluates the entire gradient of a function to nan, even if the function does not depend on the columns of U and rows of Vh that correspond to the 0.0 singular values (see the example below). Essentially, in this case, even if the adjoint gradients of the corresponding slices of U and Vh are 0., they are multiplied by nan and thus give nan.

I don't have a clear vision of what the best solution is here. I just wanted to share that this problem arises for me and how I will dirty-fix it. Maybe the safe_inverse fix from the dirty workaround below would be of interest to implement in jax? Thoughts?
Example
I construct a matrix that has only a few non-zero singular values (actually, only the first is non-zero, but due to numerical error, the second is non-zero too). I then design a function that only depends on the first two singular values (which are clearly neither degenerate nor zero) and on the corresponding columns (rows) of U (Vh), and see that the entire gradient becomes nan.
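The code blocks of the original example are not preserved in this thread; the following is a minimal sketch of the construction described. The matrix and the function f are guesses at the shape of the original, using exact zeros for the degenerate singular values for clarity:

import jax
from jax.config import config
from jax import numpy as jnp
config.update("jax_enable_x64", True)

# Singular values (3, 1, 0, 0, 0, 0): the first two are distinct and
# non-zero, the trailing zeros are exactly degenerate.
M = jnp.diag(jnp.array([3., 1., 0., 0., 0., 0.]))

def f(m):
    # Depends only on the first two singular values and the
    # corresponding columns of U / rows of Vh.
    u, s, vh = jnp.linalg.svd(m)
    return s[0] + s[1] + jnp.sum(u[:, :2]) + jnp.sum(vh[:2, :])

print(f(M))            # finite, well-defined value
print(jax.grad(f)(M))  # entire gradient is nan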
Dirty workaround
I have seen people use safe_inverse = lambda x, eps: x / (x**2 + eps) and replace 1 / (s[i] ** 2 - s[j] ** 2) by safe_inverse(s[i] ** 2 - s[j] ** 2, eps), as well as 1 / S by safe_inverse(S, eps), usually with eps == 1e-12. So I implemented that. Contact me if anyone is interested.
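For reference, a minimal sketch of that workaround; the eps value and the replaced terms follow the description above, and wiring it into a full custom-gradient SVD rule is omitted:

import jax.numpy as jnp

def safe_inverse(x, eps=1e-12):
    # Regularized reciprocal: approximately 1/x for |x| >> sqrt(eps),
    # but finite everywhere, with safe_inverse(0.0) == 0.0, so that
    # zero adjoints multiplied by it stay zero instead of becoming nan.
    return x / (x**2 + eps)

# Inside the SVD backward pass, one would then replace, e.g.,
#     F = 1.0 / (s[None, :]**2 - s[:, None]**2)   # inf/nan for s_i == s_j
# with
#     F = safe_inverse(s[None, :]**2 - s[:, None]**2)
# and similarly 1.0 / S with safe_inverse(S).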