Hello! I was testing the `defjvp` feature to work around some numerical stability issues with the SVD, but I noticed that `defjvp` doesn't behave as expected when the output of the function being differentiated also depends on the primal input. Here is some code. I started by defining `new_SVD`, which should behave exactly like `jnp.linalg.svd` but with its derivatives defined via `defjvp` (the rule is copied from `jax/_src/lax/linalg.py` with minor modifications):
```python
import jax
import jax.lax as lax
import jax.numpy as jnp
import jax.random as jrandom
from jax import custom_jvp
from jax._src.lax import lax as lax_internal


def _extract_diagonal(s):
    # Diagonal of the last two axes of a (batch of) matrices.
    i = lax.iota("int32", min(s.shape[-2], s.shape[-1]))
    return s[..., i, i]


def _construct_diagonal(s):
    # (Batch of) diagonal matrices built from a (batch of) vectors.
    i = lax.iota("int32", s.shape[-1])
    return lax.full((*s.shape, s.shape[-1]), 0, s.dtype).at[..., i, i].set(s)


def _H(x):
    # Conjugate transpose of the last two axes.
    return _T(x).conj()


def _T(x):
    # Transpose of the last two axes.
    return lax.transpose(x, (*range(x.ndim - 2), x.ndim - 1, x.ndim - 2))


@custom_jvp
def new_SVD(x):
    # Unpacked into a plain tuple so the output has the same structure
    # as the (U, s, Vt) tuple returned by the JVP rule below.
    U, s, Vt = jnp.linalg.svd(x, full_matrices=False)
    return U, s, Vt


@new_SVD.defjvp
def _svd_jvp_rule(primals, tangents):
    (A,) = primals
    (dA,) = tangents
    # Primal SVD (the original rule calls svd_p.bind here instead).
    U, s, Vt = jnp.linalg.svd(A, full_matrices=False)

    Ut, V = _H(U), _H(Vt)
    s_dim = s[..., None, :]
    dS = Ut @ dA @ V
    ds = _extract_diagonal(dS.real)

    # F[i, j] = 1 / (s_j**2 - s_i**2) off the diagonal, 0 on the diagonal.
    s_diffs = (s_dim + _T(s_dim)) * (s_dim - _T(s_dim))
    s_diffs_zeros = lax_internal._eye(s.dtype, (s.shape[-1], s.shape[-1]))
    s_diffs_zeros = lax.expand_dims(s_diffs_zeros, range(s_diffs.ndim - 2))
    F = 1 / (s_diffs + s_diffs_zeros) - s_diffs_zeros
    dSS = s_dim.astype(A.dtype) * dS
    SdS = _T(s_dim.astype(A.dtype)) * dS

    # 1 / s, with zero singular values mapped to 0 instead of inf.
    s_zeros = (s == 0).astype(s.dtype)
    s_inv = 1 / (s + s_zeros) - s_zeros
    s_inv_mat = _construct_diagonal(s_inv)
    dUdV_diag = 0.5 * (dS - _H(dS)) * s_inv_mat.astype(A.dtype)
    dU = U @ (F.astype(A.dtype) * (dSS + _H(dSS)) + dUdV_diag)
    dV = V @ (F.astype(A.dtype) * (SdS + _H(SdS)))

    # Extra terms for non-square inputs.
    m, n = A.shape[-2:]
    if m > n:
        dAV = dA @ V
        dU = dU + (dAV - U @ (Ut @ dAV)) * s_inv.astype(A.dtype)
    if n > m:
        dAHU = _H(dA) @ U
        dV = dV + (dAHU - V @ (Vt @ dAHU)) * s_inv.astype(A.dtype)

    return (U, s, Vt), (dU, ds, _H(dV))
```
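For reference, here is what the rule above computes, read directly off the code. The symbols are introduced here only for readability: $S = \operatorname{diag}(s)$, $P$ is the matrix called `dS` in the code, $\circ$ is the elementwise product, and $S^{-1}$ has $0$ wherever $s_i = 0$ (the `s_inv` trick).

```math
\begin{aligned}
P &= U^H\, dA\, V, \qquad ds = \operatorname{Re}\,\operatorname{diag}(P), \\
F_{ij} &= \begin{cases} 1/(s_j^2 - s_i^2), & i \neq j, \\ 0, & i = j, \end{cases} \\
D &= \tfrac{1}{2}\,(P - P^H) \circ S^{-1}, \\
dU &= U\left(F \circ (P S + S P^H) + D\right), \\
dV &= V\left(F \circ (S P + P^H S)\right), \\
dU &\leftarrow dU + (I_m - U U^H)\, dA\, V S^{-1} \quad \text{if } m > n, \\
dV &\leftarrow dV + (I_n - V V^H)\, dA^H\, U S^{-1} \quad \text{if } n > m.
\end{aligned}
```

The rule returns $(dU,\ ds,\ dV^H)$ as the tangents of $(U,\ s,\ V^H)$.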
Then I compared the results of `jax.value_and_grad` for both versions, which should be identical since I used the same JVP rule. The code is as follows:
```python
def new_SVD_to_scalar(A):
    U, s, Vt = new_SVD(A)
    # Case 1:
    # return jnp.linalg.norm((U*s) @ Vt - A)
    # Case 2:
    # return jnp.linalg.norm((U*s) @ Vt)


def normal_SVD_to_scalar(A):
    U, s, Vt = jnp.linalg.svd(A, full_matrices=False)
    # Case 1:
    # return jnp.linalg.norm((U*s) @ Vt - A)
    # Case 2:
    # return jnp.linalg.norm((U*s) @ Vt)


def test_random_normal(length, width):
    A = jrandom.uniform(jrandom.PRNGKey(0), (length, width))
    new_res = jax.value_and_grad(new_SVD_to_scalar)(A)
    normal_res = jax.value_and_grad(normal_SVD_to_scalar)(A)
    assert jnp.allclose(new_res[0], normal_res[0])
    assert jnp.allclose(new_res[1], normal_res[1])  # Fails in Case 1
```
To switch between cases, uncomment the `return` statement of the same case (Case 1 or Case 2) in both functions.

With Case 2, where the scalar output does not use `A` directly, both versions agree. With Case 1, where `A` also appears in the output, the gradients computed with the custom `defjvp` rule differ from those of the original SVD.
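One factor that may be worth ruling out: Case 1 is the Frobenius norm of a matrix that is mathematically zero for every `A`, so its computed value is pure rounding error, and its gradient is built from `residual / ||residual||`, i.e. from normalized rounding noise. Case 2, by contrast, has a well-defined gradient, `A / ||A||_F`. The sketch below uses only the stock `jnp.linalg.svd` (not `new_SVD`) and an arbitrarily chosen 5x3 input; the function names `residual_norm` and `reconstruction_norm` are hypothetical stand-ins for the two cases.

```python
import jax
import jax.numpy as jnp
import jax.random as jrandom


def residual_norm(A):
    # Case 1: ||U S Vt - A||_F is mathematically zero for every A,
    # so the computed value is floating-point rounding error only.
    U, s, Vt = jnp.linalg.svd(A, full_matrices=False)
    return jnp.linalg.norm((U * s) @ Vt - A)


def reconstruction_norm(A):
    # Case 2: ||U S Vt||_F equals ||A||_F, whose gradient is A / ||A||_F.
    U, s, Vt = jnp.linalg.svd(A, full_matrices=False)
    return jnp.linalg.norm((U * s) @ Vt)


# Same construction as test_random_normal; the 5x3 shape is an arbitrary choice.
A = jrandom.uniform(jrandom.PRNGKey(0), (5, 3))

val1, grad1 = jax.value_and_grad(residual_norm)(A)
val2, grad2 = jax.value_and_grad(reconstruction_norm)(A)

# Case 1: a rounding-error-sized value; the gradient direction comes from
# residual / ||residual||, so it carries no meaningful information.
print("Case 1 value:", val1, "max |grad|:", jnp.abs(grad1).max())

# Case 2: the gradient has a closed form to compare against.
print("Case 2 value:", val2, "||A||_F:", jnp.linalg.norm(A))
print("Case 2 grad matches A/||A||_F:",
      jnp.allclose(grad2, A / jnp.linalg.norm(A), atol=1e-3))
```

If the Case 1 gradients of both implementations are of the same tiny magnitude, the default tolerances of `jnp.allclose` (`rtol=1e-5`, `atol=1e-8`) may be what flags them as different.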
Note: the only difference between my JVP rule and the original is this:

```python
# Original
s, U, Vt = svd_p.bind(
    A, full_matrices=False, compute_uv=True, subset_by_index=subset_by_index,
    algorithm=algorithm,
)
# In my jvp:
U, s, Vt = jnp.linalg.svd(A, full_matrices=False)
```
In my opinion this line has to differ, since we don't have access to the primitive `svd_p` when using `defjvp`.
Note: I know the outputs of `svd_p.bind` are permuted (`s, U, Vt` instead of `U, s, Vt`), but I don't think that is the cause of the difference.
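On access to `svd_p`: as far as I can tell from the JAX source, the public wrapper `jax.lax.linalg.svd` binds `svd_p` and returns the factors reordered as `(U, s, Vt)`, so it could stand in for `svd_p.bind` inside a `defjvp` rule if one wants to match the original line more literally. A small sketch comparing it with `jnp.linalg.svd` on the same input (the 5x3 shape is arbitrary; whether this swap changes the Case 1 gradients is not verified here):

```python
import jax.numpy as jnp
import jax.random as jrandom
from jax.lax import linalg as lax_linalg

A = jrandom.uniform(jrandom.PRNGKey(0), (5, 3))

# Public wrapper around the svd_p primitive; returns (U, s, Vt).
U1, s1, Vt1 = lax_linalg.svd(A, full_matrices=False)

# NumPy-style API used inside the custom rule above.
U2, s2, Vt2 = jnp.linalg.svd(A, full_matrices=False)

print("singular values agree:", jnp.allclose(s1, s2, atol=1e-5))
print("lax_linalg.svd reconstructs A:",
      jnp.allclose((U1 * s1) @ Vt1, A, atol=1e-5))
```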
Am I missing something? Is the output not allowed to depend on the primal input? If so, shouldn't an error or warning be raised?
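As a point of comparison for this question, a minimal scalar sketch where the output of a `custom_jvp` function is combined with the primal input itself, analogous to the `- A` in Case 1. The function `my_sin` is hypothetical, made up for this check; in this simple case the custom rule yields the expected gradient `cos(x) - 1` without any error. It only shows the pattern works in a trivial setting and does not explain the SVD discrepancy.

```python
import jax
import jax.numpy as jnp


# Hypothetical example: sin with a hand-written JVP rule.
@jax.custom_jvp
def my_sin(x):
    return jnp.sin(x)


@my_sin.defjvp
def my_sin_jvp(primals, tangents):
    (x,) = primals
    (dx,) = tangents
    # As in the SVD rule, the primal output is recomputed inside the rule.
    return jnp.sin(x), jnp.cos(x) * dx


def loss(x):
    # The custom_jvp output is combined with the primal input itself.
    return my_sin(x) - x


val, g = jax.value_and_grad(loss)(0.3)
print(val, g)  # mathematically, g = cos(0.3) - 1
```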
System info (python version, jaxlib version, accelerator, etc.)
```
jax:    0.4.35
jaxlib: 0.4.35
numpy:  2.1.2
python: 3.11.0 (main, Oct 24 2022, 18:26:48) [MSC v.1933 64 bit (AMD64)]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Windows', node='MSI', release='10', version='10.0.22631', machine='AMD64')

$ nvidia-smi
Tue Nov 26 00:13:02 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 561.09                 Driver Version: 561.09         CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                  Driver-Model | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 4080 ...  WDDM  |   00000000:01:00.0 Off |                  N/A |
| N/A   44C    P4             20W /   40W |       0MiB /  12282MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+
```