BodeTobias opened this issue 2 months ago (status: Open)
Hi - thanks for the question! I think this is behaving as expected.
When it comes to autodiff, sparse matrices are fundamentally different than dense matrices in that zero elements do not enter the computation. That means that when you take the gradient with respect to a sparse matrix, tangent values are only defined for specified matrix elements. On the other hand, when you take the gradient with respect to a dense matrix, tangents are defined for all matrix elements. In particular, this means that when you do something like A_dot @ primals, the result will differ depending on whether A_dot is a sparse or a dense matrix.
Another way to think about it: when you're differentiating with respect to a sparse matrix, you're differentiating only with respect to its specified elements. When you're differentiating with respect to a dense matrix, you're differentiating with respect to all of its elements.
We deliberately chose to define autodiff for sparse matrices in this way because it means that sparse operations have sparse gradients – if it were not the case, then JAX's autodiff would be useless in the context of large sparse computations, because the gradients would be dense and blow up the memory.
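For example, here's a minimal sketch of that difference (a hypothetical 2x2 matrix where only the diagonal is specified, using the built-in BCOO matvec rather than any custom rule):

import jax
import jax.numpy as jnp
from jax.experimental import sparse

# Hypothetical 2x2 matrix with only the diagonal specified.
data = jnp.array([1.0, 1.0])
indices = jnp.array([[0, 0], [1, 1]])
b = jnp.array([0.1, 1.0])

def f_sparse(data):
    # Differentiating with respect to the sparse matrix means differentiating
    # only with respect to its two specified elements.
    A = sparse.BCOO((data, indices), shape=(2, 2))
    return (A @ b).sum()

def f_dense(A):
    # Differentiating with respect to the dense matrix covers all four elements.
    return (A @ b).sum()

A_dense = sparse.BCOO((data, indices), shape=(2, 2)).todense()
print(jax.jacfwd(f_sparse)(data))    # shape (2,):   [0.1 1. ]
print(jax.jacfwd(f_dense)(A_dense))  # shape (2, 2): [[0.1 1. ] [0.1 1. ]]

The sparse derivative has one entry per specified element, while the dense derivative has an entry for every element of the matrix, including the structural zeros.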
Does that make sense?
Hi Jake, thanks a lot for the explanation!
I totally agree with your explanation. However, if a dense matrix were treated as a sparse matrix (i.e. with all of its elements specified), the values in the derivative should still be the same as when it's defined as a dense matrix, right?
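(As a sanity check with the built-in BCOO matvec and no custom rule, a fully specified sparse matrix does seem to give the same derivatives for its elements as the dense version; a small sketch of what I mean:)

import jax
import jax.numpy as jnp
from jax.experimental import sparse

# All four elements of the 2x2 matrix are specified, so the "sparse" matrix
# carries the same information as the dense one.
data = jnp.array([1., 0., 0., 1.])
indices = jnp.array([[0, 0], [0, 1], [1, 0], [1, 1]])
b = jnp.array([0.1, 1.0])

def f_sparse(data):
    A = sparse.BCOO((data, indices), shape=(2, 2))
    return (A @ b).sum()

def f_dense(data):
    A = sparse.BCOO((data, indices), shape=(2, 2)).todense()
    return (A @ b).sum()

print(jax.jacfwd(f_sparse)(data))  # [0.1 1.  0.1 1. ]
print(jax.jacfwd(f_dense)(data))   # [0.1 1.  0.1 1. ]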
I tried to further simplify the example and wrote a custom matrix-vector multiplication. In A_dot, all indices seem to be set to [0,0] independently of A. When I transfer them from A to A_dot using A_dot = sparse.BCOO((A_dot.data, A.indices), shape=(2, 2)), the jvp rule works as I would expect. When all indices in A_dot are [0,0], the derivatives with respect to the other components are all overwritten by that of the first element. Is there a reason for all indices being [0,0]?
import jax
import jax.numpy as jnp
from jax.experimental import sparse


@jax.custom_jvp
def matvec(A, b):
    return A @ b


@matvec.defjvp
def matvec_jvp(primals, tangents):
    A, b = primals
    A_dot, b_dot = tangents
    # This resolves my issue. The indices of A_dot were all [0,0]...
    A_dot = sparse.BCOO((A_dot.data, A.indices), shape=(2, 2))
    primal_result = matvec(A, b)
    tangent_result = matvec(A_dot, b) + matvec(A, b_dot)
    # jax.debug.print("A_dot: {x}", x=A_dot.data)
    # jax.debug.print("indices: {x}", x=A_dot.indices)
    # jax.debug.print("b: {x}", x=b)
    # jax.debug.print("A_dot @ b: {x}", x=A_dot @ b)
    return primal_result, tangent_result


# Test matrix
data = jnp.array([1., 0., 0., 1.])
indices = jnp.array([[0, 0], [0, 1], [1, 0], [1, 1]])
b = jnp.array([0.1, 1.0])

# Sparse matrix
def fun(data, indices, b):
    A = sparse.BCOO((data, indices), shape=(2, 2))
    return matvec(A, b).sum()

print("sparse: ", jax.jacfwd(fun)(data, indices, b))  # sparse: [0.1 0.1 0.1 0.1]

# # Dense matrix
# def fun(data, indices, b):
#     A = sparse.BCOO((data, indices), shape=(2, 2)).todense()
#     return matvec(A, b).sum()
# print("dense: ", jax.jacfwd(fun)(data, indices, b))  # dense: [0.1 1. 0.1 1. ]
Description
While trying to equip an external sparse linear solver with a JVP rule, I encountered unexpected behavior related to the sparse BCOO matrix. I'm not sure if it's a bug or if I've overlooked something, but the derivatives differ from my checks. It also works if I pass the data and indices separately and only construct a BCOO matrix within the solver's definition. Attached is the code: