jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Cannot use grad with sparse matrix multiplication #13118

Closed: akr97 closed this issue 1 year ago

akr97 commented 2 years ago

Description

I'm working on an optimization problem over a large sparse system. When the sparse (BCOO) matrix describing the system is created by addition, or constructed directly from appropriate indices and values, there is no problem. However, the most natural way to compose the matrix is by multiplying other sparse matrices, and in that case evaluating the gradient fails. Below is a minimal example:

import jax
import jax.numpy as jnp
from jax.experimental.sparse import BCOO
from jax import grad

def f1(k):
    # Build A directly from values and indices, then return ||A b||^2
    n = jnp.size(k)

    v = jnp.hstack([k, -k[1:]])
    r = jnp.hstack([jnp.arange(n), jnp.arange(1,n)])
    c = jnp.hstack([jnp.arange(n), jnp.arange(n-1)])
    A = BCOO((v, jnp.vstack([r, c]).T), shape=(n,n))
    #print(A.todense())

    b = jnp.ones(n)

    x = A @ b

    return jnp.linalg.norm(x)**2

def f2(k):
    # Build A as the sparse-sparse product D @ G, then return ||A b||^2
    n = jnp.size(k)

    D = BCOO((k, jnp.tile(jnp.arange(n), (2, 1)).T), shape=(n, n))  # D = diag(k)
    # G: 1 on the diagonal, -1 on the subdiagonal
    g_vals = jnp.hstack([jnp.ones(n), -jnp.ones(n - 1)])
    g_rows = jnp.hstack([jnp.arange(n), jnp.arange(1, n)])
    g_cols = jnp.hstack([jnp.arange(n), jnp.arange(n - 1)])
    G = BCOO((g_vals, jnp.vstack([g_rows, g_cols]).T), shape=(n, n))
    A = D @ G
    #print(A.todense())

    b = jnp.ones(n)

    x = A @ b

    return jnp.linalg.norm(x)**2

df1 = grad(f1)
df2 = grad(f2)

K0 = jnp.array([5., 10, 15, 20, 25])

print("f1(K0): ", f1(K0))
print("f2(K0): ", f2(K0))

print("grad_f1(K0): ", df1(K0))
print("grad_f2(K0): ", df2(K0))

In this toy example, evaluating both f1 and f2 is fine:

f1(K0): 25.0
f2(K0): 25.0

Calculating the gradient of f1 is fine:

grad_f1(K0): [10. 0. 0. 0. 0.]

Calculating the gradient of f2 fails with the following error:

jax._src.source_info_util.JaxStackTraceBeforeTransformation: NotImplementedError: Transpose rule (for reverse-mode differentiation) for 'bcoo_spdot_general' not implemented

Are there any plans to extend JAX's sparse support? Otherwise, is there a better approach than constructing the matrix with the SciPy sparse implementation and then defining a custom_vjp for it?
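One way to follow the custom_vjp route without leaving JAX is to keep the sparse-sparse product in the forward pass and write the backward pass by hand. This is a minimal sketch, not a general transpose rule for `bcoo_spdot_general`: it assumes the structure of the toy example (D = diag(k), G fixed), so that A(k) b = k * (G b). The helpers `make_D`, `make_G`, and `apply_A` are hypothetical names introduced for illustration.

```python
import jax
import jax.numpy as jnp
from jax.experimental.sparse import BCOO

def make_D(k):
    # Hypothetical helper: D = diag(k) stored as BCOO
    n = k.shape[0]
    return BCOO((k, jnp.tile(jnp.arange(n), (2, 1)).T), shape=(n, n))

def make_G(n):
    # Hypothetical helper: 1 on the diagonal, -1 on the subdiagonal (as in f2)
    vals = jnp.hstack([jnp.ones(n), -jnp.ones(n - 1)])
    rows = jnp.hstack([jnp.arange(n), jnp.arange(1, n)])
    cols = jnp.hstack([jnp.arange(n), jnp.arange(n - 1)])
    return BCOO((vals, jnp.vstack([rows, cols]).T), shape=(n, n))

@jax.custom_vjp
def apply_A(k, b):
    # Forward pass keeps the sparse-sparse product A = D @ G
    A = make_D(k) @ make_G(k.shape[0])
    return A @ b

def apply_A_fwd(k, b):
    # Save k and G @ b (a sparse-dense product) for the backward pass
    return apply_A(k, b), (k, make_G(k.shape[0]) @ b)

def apply_A_bwd(res, ct):
    k, Gb = res
    # Assumes D is diagonal, so A @ b == k * (G @ b):
    #   cotangent for k is ct * (G @ b)
    #   cotangent for b is G^T @ (k * ct)
    G = make_G(k.shape[0])
    return ct * Gb, G.T @ (k * ct)

apply_A.defvjp(apply_A_fwd, apply_A_bwd)

K0 = jnp.array([5., 10, 15, 20, 25])
loss = lambda k: jnp.sum(apply_A(k, jnp.ones(k.shape[0])) ** 2)
print(jax.grad(loss)(K0))  # reverse mode now goes through the custom rule
```

The backward pass only needs sparse-dense products, which is what makes it differentiable in reverse mode; the price is that the rule must be re-derived whenever the structure of D or G changes.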

What jax/jaxlib version are you using?

v0.3.14

Which accelerator(s) are you using?

CPU

Additional system info

Mac M1

NVIDIA GPU info

No response

jakevdp commented 2 years ago

Thanks for the question – the transpose rule (i.e. reverse-mode autodiff) for sparse-sparse matrix multiplication hasn't yet been implemented: https://github.com/google/jax/blob/2ce7eb5b5cd724588dfdf6b78760905579621d85/jax/experimental/sparse/bcoo.py#L1289 It's actually quite a complicated operation, which is why I haven't gotten around to implementing it yet.
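A rough sketch of why this is more involved than the dense case (a generic derivation, not a description of how JAX implements it): for a product C = AB with loss L, reverse mode needs the cotangents of both factors. When A and B are BCOO matrices, only their stored values are differentiable, so each dense cotangent must be sampled at the factor's sparsity pattern, written here with the 0/1 masks $M_A$ and $M_B$ of stored entries.

```latex
\begin{aligned}
C &= AB, \qquad \bar{C} = \frac{\partial L}{\partial C}, \\
\text{dense: } \bar{A} &= \bar{C}\,B^{\top}, \qquad \bar{B} = A^{\top}\bar{C}, \\
\text{sparse: } \bar{A} &= M_A \odot \bigl(\bar{C}\,B^{\top}\bigr), \qquad \bar{B} = M_B \odot \bigl(A^{\top}\bar{C}\bigr).
\end{aligned}
```

Since $\bar{C}$ carries the sparsity pattern of $C$, both masked products are themselves sparse-sparse products that should be evaluated only at the stored entries of $A$ or $B$, which a transpose rule for `bcoo_spdot_general` would presumably have to do efficiently.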

In the meantime, as long as you stick with forward-mode autodiff, it should work:

df2 = jax.jacfwd(f2)
print("grad_f2(K0): ", df2(K0))
grad_f2(K0):  [10.  0.  0.  0.  0.]
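`jax.jacfwd` assembles the gradient from one forward-mode JVP per input coordinate, which is why its cost grows with the number of inputs. A single directional derivative needs only one `jax.jvp` call. The minimal sketch below restates `f2` compactly for self-containment (same D and G as above, with `jnp.sum(x ** 2)` in place of the squared norm) and pushes the basis tangent e_0 through it.

```python
import jax
import jax.numpy as jnp
from jax.experimental.sparse import BCOO

def f2(k):
    # Same construction as in the issue: A = D @ G is a sparse-sparse product
    n = k.shape[0]
    D = BCOO((k, jnp.tile(jnp.arange(n), (2, 1)).T), shape=(n, n))
    g_vals = jnp.hstack([jnp.ones(n), -jnp.ones(n - 1)])
    g_rows = jnp.hstack([jnp.arange(n), jnp.arange(1, n)])
    g_cols = jnp.hstack([jnp.arange(n), jnp.arange(n - 1)])
    G = BCOO((g_vals, jnp.vstack([g_rows, g_cols]).T), shape=(n, n))
    x = (D @ G) @ jnp.ones(n)
    return jnp.sum(x ** 2)  # equal to ||x||^2

K0 = jnp.array([5., 10, 15, 20, 25])

# One forward pass gives the directional derivative grad(f2)(K0) . v
v = jnp.zeros_like(K0).at[0].set(1.0)
value, directional = jax.jvp(f2, (K0,), (v,))

# jacfwd effectively runs one such JVP per input coordinate
full = jax.jacfwd(f2)(K0)
print(value, directional, full)
```

With a scalar output and hundreds of inputs, this means hundreds of forward passes, whereas a working reverse-mode rule would need a single backward pass.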
dawsonc commented 1 year ago

Hello! Any chance of a fix for this issue? I'm facing this in a setting where it would be a real pain to pay the extra cost of forward-mode autodiff (one scalar output and several hundred inputs). Thanks!

Update: I found a workaround. I only need sparse-dense matrix-vector products, and using the sparsify transform with one sparse and one dense input seems to work, so I don't think this issue blocks us.

jakevdp commented 1 year ago

Hi, glad you solved it. Yes, sparse-dense matmul has fully implemented autodiff. It's only reverse-mode sparse-sparse matmul that remains unimplemented.
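If only the product A b is needed, as in the toy example, one way to stay within fully differentiable sparse-dense operations is to never form D @ G and instead apply the factors right to left. A minimal sketch, assuming the same D and G as in `f2`; the name `f2_matvec` is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.experimental.sparse import BCOO

def f2_matvec(k):
    n = k.shape[0]
    D = BCOO((k, jnp.tile(jnp.arange(n), (2, 1)).T), shape=(n, n))
    g_vals = jnp.hstack([jnp.ones(n), -jnp.ones(n - 1)])
    g_rows = jnp.hstack([jnp.arange(n), jnp.arange(1, n)])
    g_cols = jnp.hstack([jnp.arange(n), jnp.arange(n - 1)])
    G = BCOO((g_vals, jnp.vstack([g_rows, g_cols]).T), shape=(n, n))
    b = jnp.ones(n)
    # (D @ G) @ b == D @ (G @ b): two sparse-dense products, no sparse-sparse one
    x = D @ (G @ b)
    return jnp.sum(x ** 2)

K0 = jnp.array([5., 10, 15, 20, 25])
grad_f2 = jax.grad(f2_matvec)(K0)  # reverse mode works on this formulation
print(grad_f2)
```

The trade-off is that A is never materialized, so this only helps when downstream code needs products with A rather than A itself.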