Closed: akr97 closed this issue 1 year ago.
Thanks for the question – the transpose rule (i.e. reverse-mode autodiff) for sparse-sparse matrix multiplication hasn't yet been implemented: https://github.com/google/jax/blob/2ce7eb5b5cd724588dfdf6b78760905579621d85/jax/experimental/sparse/bcoo.py#L1289 It's actually quite a complicated operation, which is why I haven't gotten around to implementing it yet.
In the meantime, as long as you stick with forward-mode autodiff, it should work:
df2 = jax.jacfwd(f2)  # forward-mode Jacobian of f2
print("grad_f2(K0): ", df2(K0))
# prints: grad_f2(K0):  [10. 0. 0. 0. 0.]
Hello! Any chance of a fix for this issue? I'm facing this in a setting where it would be a real pain to pay the extra cost of forward-mode autodiff (one scalar output and several hundred inputs). Thanks!
Update: I found a workaround. I only need sparse-dense matrix-vector products, and using the sparsify transform with one sparse and one dense input seems to work, so I don't think this issue blocks us.
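For reference, a minimal sketch of that kind of workaround (the matrix, vector, and function names here are assumed for illustration, not taken from the original code):

import jax
import jax.numpy as jnp
from jax.experimental import sparse

# One sparse and one dense operand; only a sparse-dense matvec is needed.
A = sparse.BCOO.fromdense(jnp.array([[1., 0., 2.],
                                     [0., 3., 0.],
                                     [0., 0., 4.]]))

def loss(M, x):
    return jnp.sum(M @ x)   # sparse @ dense inside the sparsified function

sparse_loss = sparse.sparsify(loss)

x0 = jnp.ones(3)
# Reverse-mode autodiff with respect to the dense argument goes through,
# since sparse-dense matmul has autodiff support.
print(jax.grad(sparse_loss, argnums=1)(A, x0))   # expected: column sums [1. 3. 6.]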
Hi - glad you solved it. Yes, sparse-dense matmul has fully implemented autodiff. It's only reverse-mode sparse-sparse matmul that remains unimplemented.
Description
I'm working on an optimization problem over a large sparse system. When the (sparse BCOO) matrix describing the system is created by addition, or simply constructed with appropriate indices and values, there is no problem. However, the most natural way of composing the matrix is by multiplying other sparse matrices, and this causes gradient evaluation to fail. Below is a minimal example.
In this toy example, evaluating both f1 and f2 is fine, and calculating the gradient of f1 is fine, but calculating the gradient of f2 fails with the underlying issue.
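(The original snippets aren't reproduced in this excerpt. A minimal sketch of the kind of setup being described, with assumed names f1, f2, and K0 and an assumed 5x5 diagonal parameterization, might look like this; the essential difference is that f2 routes through a sparse-sparse product.)

import jax
import jax.numpy as jnp
from jax.experimental import sparse

# Hypothetical reconstruction; names and shapes are assumed, not from the report.
B = sparse.BCOO.fromdense(2.0 * jnp.eye(5))    # a fixed sparse matrix
v = jnp.ones(5)
idx = jnp.arange(5)
diag_indices = jnp.stack([idx, idx], axis=1)   # fixed sparsity pattern (the diagonal)

def f1(k):
    # matrix constructed directly from indices and values
    K = sparse.BCOO((k, diag_indices), shape=(5, 5))
    return (K @ v)[0]

def f2(k):
    # matrix composed by multiplying sparse matrices
    K = sparse.BCOO((k, diag_indices), shape=(5, 5))
    M = K @ B                                  # sparse @ sparse
    return (M @ v)[0]

K0 = jnp.arange(1.0, 6.0)
print(f1(K0), f2(K0))      # evaluating both is fine
print(jax.grad(f1)(K0))    # reverse mode works for f1
# jax.grad(f2)(K0)         # fails: the transpose rule for sparse-sparse
                           # matmul (bcoo.py, linked above) is not implemented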
Are there any plans to extend JAX sparse support? Otherwise, is there a better approach than doing the actual matrix construction with the scipy sparse implementation and then defining a custom_vjp for it?
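For what it's worth, one rough shape of that custom_vjp route (this is only a sketch: the names, the fixed diagonal sparsity pattern, and the dense fallback in the backward pass are all assumptions, and it wraps the sparse-sparse product in JAX rather than going through scipy):

import jax
import jax.numpy as jnp
from jax.experimental import sparse

B = sparse.BCOO.fromdense(2.0 * jnp.eye(5))    # fixed right operand
idx = jnp.arange(5)
K_indices = jnp.stack([idx, idx], axis=1)      # fixed sparsity pattern of K

@jax.custom_vjp
def spsp_matmul_dense(data):
    # forward pass: build K from its nonzero values, multiply sparse @ sparse,
    # and densify the (here small) result
    K = sparse.BCOO((data, K_indices), shape=(5, 5))
    return (K @ B).todense()

def _fwd(data):
    return spsp_matmul_dense(data), None

def _bwd(_, g):
    # d(K @ B)/d data[i] has row r_i equal to B[c_i, :], so the cotangent of
    # data[i] is (g @ B^T)[r_i, c_i]
    gBt = g @ B.todense().T
    return (gBt[K_indices[:, 0], K_indices[:, 1]],)

spsp_matmul_dense.defvjp(_fwd, _bwd)

def f2(data):
    v = jnp.ones(5)
    return (spsp_matmul_dense(data) @ v)[0]

print(jax.grad(f2)(jnp.arange(1.0, 6.0)))      # reverse mode now uses the custom rule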
What jax/jaxlib version are you using?
v0.3.14
Which accelerator(s) are you using?
CPU
Additional system info
Mac M1
NVIDIA GPU info
No response