jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Weird defjvp behavior when finding grad of a scalar that depends on the primal #25101

Open JadM133 opened 3 days ago

JadM133 commented 3 days ago

Description

Hello! I was testing the defjvp feature to deal with some stability issues of the SVD, but I noticed that defjvp doesn't behave as expected when the scalar being differentiated depends on the primal input directly, and not only through the custom function. Let me share some code. I started by defining new_SVD, which should behave exactly like jax.numpy.linalg.svd but with its derivatives defined through defjvp (code adapted from jax._src.lax.linalg with minor modifications):

import jax.lax as lax
from jax._src.lax import lax as lax_internal
from jax import custom_jvp
import jax.numpy as jnp
import jax.random as jrandom
import jax

def _extract_diagonal(s):
    i = lax.iota("int32", min(s.shape[-2], s.shape[-1]))
    return s[..., i, i]

def _construct_diagonal(s):
    i = lax.iota("int32", s.shape[-1])
    return lax.full((*s.shape, s.shape[-1]), 0, s.dtype).at[..., i, i].set(s)

def _H(x):
    return _T(x).conj()

def _T(x):
    return lax.transpose(x, (*range(x.ndim - 2), x.ndim - 1, x.ndim - 2))

@custom_jvp
def new_SVD(x):
    return jnp.linalg.svd(x, full_matrices=False)

@new_SVD.defjvp
def _svd_jvp_rule(primals, tangents):
    (A,) = primals
    (dA,) = tangents
    U, s, Vt = jnp.linalg.svd(A, full_matrices=False)

    Ut, V = _H(U), _H(Vt)
    s_dim = s[..., None, :]
    dS = Ut @ dA @ V
    ds = _extract_diagonal(dS.real)

    s_diffs = (s_dim + _T(s_dim)) * (s_dim - _T(s_dim))
    s_diffs_zeros = lax_internal._eye(s.dtype, (s.shape[-1], s.shape[-1]))
    s_diffs_zeros = lax.expand_dims(s_diffs_zeros, range(s_diffs.ndim - 2))
    F = 1 / (s_diffs + s_diffs_zeros) - s_diffs_zeros
    dSS = s_dim.astype(A.dtype) * dS
    SdS = _T(s_dim.astype(A.dtype)) * dS

    s_zeros = (s == 0).astype(s.dtype)
    s_inv = 1 / (s + s_zeros) - s_zeros
    s_inv_mat = _construct_diagonal(s_inv)
    dUdV_diag = 0.5 * (dS - _H(dS)) * s_inv_mat.astype(A.dtype)
    dU = U @ (F.astype(A.dtype) * (dSS + _H(dSS)) + dUdV_diag)
    dV = V @ (F.astype(A.dtype) * (SdS + _H(SdS)))

    m, n = A.shape[-2:]
    if m > n:
        dAV = dA @ V
        dU = dU + (dAV - U @ (Ut @ dAV)) * s_inv.astype(A.dtype)
    if n > m:
        dAHU = _H(dA) @ U
        dV = dV + (dAHU - V @ (Vt @ dAHU)) * s_inv.astype(A.dtype)

    return (U, s, Vt), (dU, ds, _H(dV))
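As a side check, the raw jax.jvp outputs of new_SVD can be compared against those of jnp.linalg.svd directly, which should show whether it is the primal outputs or the tangents that differ. A minimal sketch (check_jvp_outputs is just a helper name used here for illustration):

def check_jvp_outputs(A, dA):
    # Forward-mode through the custom rule vs. the built-in rule.
    new_primals, new_tangents = jax.jvp(new_SVD, (A,), (dA,))
    ref_primals, ref_tangents = jax.jvp(
        lambda x: jnp.linalg.svd(x, full_matrices=False), (A,), (dA,)
    )
    for new_p, ref_p in zip(new_primals, ref_primals):
        print("primal match:", jnp.allclose(new_p, ref_p, atol=1e-6))
    for new_t, ref_t in zip(new_tangents, ref_tangents):
        print("tangent match:", jnp.allclose(new_t, ref_t, atol=1e-6))

A = jrandom.uniform(jrandom.PRNGKey(0), (4, 3))
dA = jrandom.normal(jrandom.PRNGKey(1), (4, 3))
check_jvp_outputs(A, dA)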

Then I wanted to compare the results of jax.value_and_grad, which should be the same since I used the same jvp rule. The code is as follows:

def new_SVD_to_scalar(A):
    U, s, Vt = new_SVD(A)
    # Case 1:
    # return jnp.linalg.norm((U*s) @ Vt - A)
    # Case 2:
    # return jnp.linalg.norm((U*s) @ Vt)

def normal_SVD_to_scalar(A):
    U, s, Vt = jnp.linalg.svd(A, full_matrices=False)
    # Case 1:
    # return jnp.linalg.norm((U*s) @ Vt - A)
    # Case 2:
    # return jnp.linalg.norm((U*s) @ Vt)

def test_random_normal(length, width):
    A = jrandom.uniform(jrandom.PRNGKey(0), (length, width))
    new_res = jax.value_and_grad(new_SVD_to_scalar)(A)
    normal_res = jax.value_and_grad(normal_SVD_to_scalar)(A)
    assert jnp.allclose(new_res[0], normal_res[0])
    assert jnp.allclose(new_res[1], normal_res[1])  # Fails in Case 1

Uncomment the return statement for the same case in both functions to switch between Case 1 and Case 2. No problem arises when the scalar output does not use A directly (Case 2), but once A is used directly (Case 1), the gradients computed through the defjvp-based SVD and through the original SVD differ.
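For reference, the same pattern with a simple toy custom_jvp should give matching gradients when the rule is correct, so using the primal downstream does not seem to be disallowed in general (toy_f below is just an illustrative example, unrelated to the SVD):

@custom_jvp
def toy_f(x):
    return jnp.sin(x) * x

@toy_f.defjvp
def toy_f_jvp(primals, tangents):
    (x,), (dx,) = primals, tangents
    y = jnp.sin(x) * x
    dy = (jnp.cos(x) * x + jnp.sin(x)) * dx  # d/dx [sin(x) * x]
    return y, dy

def loss_custom(x):
    return jnp.sum(toy_f(x) - x)  # the loss also uses the primal x directly, as in Case 1

def loss_plain(x):
    return jnp.sum(jnp.sin(x) * x - x)

x = jnp.linspace(0.1, 1.0, 5)
print(jnp.allclose(jax.grad(loss_custom)(x), jax.grad(loss_plain)(x)))  # True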

Note: The only difference between my jvp and the original is this:

# Original
s, U, Vt = svd_p.bind(
    A, full_matrices=False, compute_uv=True, subset_by_index=subset_by_index,
    algorithm=algorithm,
)

# In my jvp:
U, s, Vt = jnp.linalg.svd(A, full_matrices=False)

This difference is unavoidable, in my opinion, since we don't have access to the primitive svd_p when using defjvp.

Note: I know the outputs of svd_p.bind come back in a different order, but I don't think that is the reason for the difference.

Am I missing something? Are we not allowed to define an output as a function of the primal? Shouldn't an error/warning be raised if that is the case?
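If it helps, both scalar functions can also be checked against numerical differentiation with jax.test_util.check_grads (with one of the two cases uncommented in each); presumably only one of them should fail. A rough sketch:

from jax.test_util import check_grads

A = jrandom.uniform(jrandom.PRNGKey(0), (4, 3))
# Compare each implementation against finite differences in reverse mode.
check_grads(normal_SVD_to_scalar, (A,), order=1, modes=["rev"])
check_grads(new_SVD_to_scalar, (A,), order=1, modes=["rev"])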

System info (python version, jaxlib version, accelerator, etc.)

jax: 0.4.35
jaxlib: 0.4.35
numpy: 2.1.2
python: 3.11.0 (main, Oct 24 2022, 18:26:48) [MSC v.1933 64 bit (AMD64)]
device info: cpu-1, 1 local devices
process_count: 1
platform: uname_result(system='Windows', node='MSI', release='10', version='10.0.22631', machine='AMD64')

$ nvidia-smi (Tue Nov 26 00:13:02 2024)
NVIDIA-SMI 561.09, Driver Version 561.09, CUDA Version 12.6
GPU 0: NVIDIA GeForce RTX 4080 ... (WDDM), Bus-Id 00000000:01:00.0, 0MiB / 12282MiB, 0% utilization
No running processes found