jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Wrong derivatives when using defjvp for sparse.BCOO matrices #23235

Open BodeTobias opened 2 months ago

BodeTobias commented 2 months ago

Description

While trying to equip an external sparse linear solver with a JVP rule, I encountered unexpected behavior related to sparse BCOO matrices. I'm not sure whether it's a bug or whether I've overlooked something, but the derivatives differ from my checks. The rule does work if I pass the data and indices separately and only construct the BCOO matrix inside the solver's definition. Attached is the code:


import jax
import jax.numpy as jnp
from jax.experimental import sparse

## Solver definitions
# Sparse solver equipped with jvp rule, here seems to be a problem
@jax.custom_jvp
def solve_fun_sparse(A, b):
    return jnp.linalg.solve(A.todense(), b)

@solve_fun_sparse.defjvp
def solve_fun_jvp(primals, tangents):
    A, b = primals
    A_dot, b_dot = tangents
    primal_result = solve_fun_sparse(A, b)
    result_dot = solve_fun_sparse(A, b_dot - A_dot @ primal_result)
    return primal_result, result_dot

# Using data and indices manually works
@jax.custom_jvp
def solve_fun_just_data(data, indices, b):
    A_mat = sparse.BCOO((data, indices), shape=(2, 2))
    return jnp.linalg.solve(A_mat.todense(), b)
@solve_fun_just_data.defjvp
def solve_fun_jvp(primals, tangents):
    data, indices, b = primals
    data_dot, _, b_dot = tangents
    primal_result = solve_fun_just_data(data, indices, b)
    A_dot_mat = sparse.BCOO((data_dot, indices), shape=(2, 2))
    rhs = b_dot - A_dot_mat @ primal_result
    result_dot = solve_fun_just_data(data, indices, rhs)
    return primal_result, result_dot

# Dense solver, just as reference
@jax.custom_jvp
def solve_fun_dense(A, b):
    return jnp.linalg.solve(A, b)
@solve_fun_dense.defjvp
def solve_fun_jvp(primals, tangents):
    A, b = primals
    A_dot, b_dot = tangents
    primal_result = solve_fun_dense(A, b)
    result_dot = solve_fun_dense(A, b_dot - A_dot @ primal_result)
    return primal_result, result_dot

## Test data
# Test with duplicate entries
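# (note: BCOO keeps duplicate indices and sums them when densified, so this A densifies to [[2., 4.], [0., 8.]])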
data = jnp.array([2., 3., 4., 5.])
indices = jnp.array([[0, 0], [1, 1], [0, 1], [1, 1]])
b = jnp.array([1.0, 1.0])

# # Test with unique entries
# data = jnp.array([2., 3., 4., 5.])
# indices = jnp.array([[0, 0], [0, 1], [1, 0], [1, 1]])
# b = jnp.array([1.0, 1.0])

# # Test with unique entries which are symmetric
# data = jnp.array([2., 3., 3., 5.])
# indices = jnp.array([[0, 0], [0, 1], [1, 0], [1, 1]])
# b = jnp.array([1.0, 1.0])

## Tests
# With sparse matrices
def loss(data, indices, b):
    A = sparse.BCOO((data, indices), shape=(2, 2))
    return jnp.sum(solve_fun_sparse(A, b))
derivative = jax.jacfwd(loss)(data, indices, b)
print("Derivative with sparse matrix:", derivative)

# With dense matrix
def loss(data, indices, b):
    A = sparse.BCOO((data, indices), shape=(2, 2))
    return jnp.sum(solve_fun_dense(A.todense(), b))
derivative = jax.jacfwd(loss)(data, indices, b)
print("Derivative with dense matrix:", derivative)

# Direct check
def loss(data, indices, b):
    A = sparse.BCOO((data, indices), shape=(2, 2))
    return jnp.sum(jnp.linalg.solve(A.todense(), b))
derivative = jax.jacfwd(loss)(data, indices, b)
print("Derivative using jax solver:", derivative)

# Just with data
def loss(data, indices, b):
    return jnp.sum(solve_fun_just_data(data, indices, b))
derivative = jax.jacfwd(loss)(data, indices, b)
print("Derivative with data and indices:", derivative)

# Here is the output:
# Derivative with sparse matrix: [-0.125 -0.125 -0.125 -0.125]
# Derivative with dense matrix: [-0.125     0.015625 -0.0625    0.015625]
# Derivative using jax solver: [-0.125     0.015625 -0.0625    0.015625]
# Derivative with data and indices: [-0.125     0.015625 -0.0625    0.015625]

### System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.31
jaxlib: 0.4.31
numpy:  1.26.4
python: 3.11.9 (tags/v3.11.9:de54cf5, Apr  2 2024, 10:12:12) [MSC v.1938 64 bit (AMD64)]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Windows', node='levy', release='10', version='10.0.22631', machine='AMD64')
jakevdp commented 2 months ago

Hi - thanks for the question! I think this is behaving as expected.

When it comes to autodiff, sparse matrices are fundamentally different from dense matrices in that unspecified (zero) elements do not enter the computation. That means that when you take the gradient with respect to a sparse matrix, tangent values are only defined for the specified matrix elements. On the other hand, when you take the gradient with respect to a dense matrix, tangents are defined for all matrix elements. In particular, this means that when you do something like A_dot @ primal_result, the result will differ depending on whether A_dot is a sparse or a dense matrix.

Another way to think about it: when you're differentiating with respect to a sparse matrix, you're differentiating only with respect to its specified elements. When you're differentiating with respect to a dense matrix, you're differentiating with respect to all of its elements.

We deliberately chose to define autodiff for sparse matrices in this way because it means that sparse operations have sparse gradients – if it were not the case, then JAX's autodiff would be useless in the context of large sparse computations, because the gradients would be dense and blow up the memory.
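
To make that concrete, here is a minimal sketch with a toy 2x2 system (not taken from your example): the derivative with respect to the sparse representation has one slot per stored entry, while the derivative with respect to the dense matrix has a slot for every element, including positions the sparse matrix does not store at all.

import jax
import jax.numpy as jnp
from jax.experimental import sparse

# Only the diagonal is specified in the sparse representation.
indices = jnp.array([[0, 0], [1, 1]])
data = jnp.array([2., 3.])
b = jnp.array([1.0, 1.0])

# Differentiate w.r.t. the sparse data: one derivative per stored entry.
def f_sparse(data):
    A = sparse.BCOO((data, indices), shape=(2, 2))
    return jnp.sum(A.todense() @ b)

# Differentiate w.r.t. the full dense matrix: one derivative per element.
def f_dense(A):
    return jnp.sum(A @ b)

A_dense = sparse.BCOO((data, indices), shape=(2, 2)).todense()
print(jax.jacfwd(f_sparse)(data))    # shape (2,): only the stored diagonal entries
print(jax.jacfwd(f_dense)(A_dense))  # shape (2, 2): every element
# The dense result has nonzero derivatives at (0, 1) and (1, 0), positions which
# don't exist in the sparse data, so the sparse derivative cannot include them.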

Does that make sense?

BodeTobias commented 2 months ago

Hi Jake, thanks a lot for the explanation!

I totally agree with your explanation. However, if a dense matrix were represented as a sparse matrix with all of its elements specified, the values in the derivative should still be the same as when it is defined as a dense matrix, right?

I tried to simplify the example further and wrote a custom matrix-vector multiplication. In A_dot, all indices seem to be set to [0, 0], independently of A. When I transfer them from A to A_dot using A_dot = sparse.BCOO((A_dot.data, A.indices), shape=(2, 2)), the JVP rule works as I would expect. When all the indices in A_dot are [0, 0], the derivatives with respect to the other components are all overwritten by the one for the first entry. Is there a reason for all the indices being [0, 0]?


import jax
import jax.numpy as jnp
from jax.experimental import sparse

@jax.custom_jvp
def matvec(A, b):
    return A @ b

@matvec.defjvp
def matvec_jvp(primals, tangents):
    A, b = primals
    A_dot, b_dot = tangents

    # This resolves my issue. The indices of A_dot were all [0,0]...
    A_dot = sparse.BCOO((A_dot.data, A.indices), shape=(2, 2))

    primal_result = matvec(A, b)
    tangent_result = matvec(A_dot, b) + matvec(A, b_dot)

    # jax.debug.print("A_dot: {x}", x = A_dot.data)
    # jax.debug.print("indices: {x}", x = A_dot.indices)
    # jax.debug.print("b: {x}", x = b)
    # jax.debug.print("A_dot @ b: {x}", x = A_dot @ b)

    return primal_result, tangent_result

# Test matrix
data = jnp.array([1., 0., 0., 1.])
indices = jnp.array([[0, 0], [0, 1], [1, 0], [1, 1]])
b = jnp.array([0.1, 1.0])

# Sparse matrix
def fun(data, indices, b):
    A = sparse.BCOO((data, indices), shape=(2, 2))
    return matvec(A, b).sum()
print("sparse: ", jax.jacfwd(fun)(data, indices, b)) # sparse:  [0.1 0.1 0.1 0.1]

# # Dense matrix
# def fun(data, indices, b):    
#     A = sparse.BCOO((data, indices), shape=(2, 2)).todense()
#     return matvec(A, b).sum()
# print("dense: ", jax.jacfwd(fun)(data, indices, b)) # dense:  [0.1 1.  0.1 1. ]