DanPuzzuoli opened this issue 3 years ago
Code samples for the points listed below:
setup:

```python
import jax.numpy as jnp
from jax import jit, grad
from jax.experimental import sparse as jsparse

# sparse versions of jax.numpy operations
jsparse_sum = jsparse.sparsify(jnp.sum)
jsparse_matmul = jsparse.sparsify(jnp.matmul)
jsparse_add = jsparse.sparsify(jnp.add)
jsparse_subtract = jsparse.sparsify(jnp.subtract)

# a 1d coefficient array and a batch of 2x2 matrices, with the leading
# dimension of the BCOO array treated as a batch dimension (n_batch=1)
coeffs = jnp.array([1., 2., 3.])
dense_array = jnp.array([[[0., 1.], [1., 0.]], [[0., 1.], [1j, 0.]], [[0., 1.], [0., 1.]]])
sparse_array = jsparse.BCOO.fromdense(dense_array, n_batch=1)
```
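As a quick check (not part of the original snippets) that the sparsified wrappers accept `BCOO` inputs directly, assuming sparse-sparse addition is supported by `sparsify` as in recent JAX versions:

```python
# adding the sparse array to itself returns another BCOO array rather
# than a densified result
doubled = jsparse_add(sparse_array, sparse_array)
print(type(doubled))  # jax.experimental.sparse.BCOO
```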
Test code for linear combo:

```python
def jsparse_linear_combo(coeffs, mats):
    return jsparse_sum(coeffs[:, None, None] * mats, axis=0)

jsparse_linear_combo(coeffs, sparse_array)
```
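As a quick sanity check (my own addition, not from the issue), the result can be compared against the same contraction done densely with `jnp.tensordot`:

```python
# densify the sparse result and compare to a dense tensordot over the
# batch dimension
result = jsparse_linear_combo(coeffs, sparse_array)
expected = jnp.tensordot(coeffs, dense_array, axes=1)
assert jnp.allclose(jsparse.todense(result), expected)
```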
Triple product reverse-mode differentiation test:

```python
jsparse_triple_product = jsparse.sparsify(lambda A, X, B: A @ X @ B)

def f(X):
    return jsparse_triple_product(sparse_array, X, sparse_array).real.sum()

jit_grad_f = jit(grad(f))
jit_grad_f(jnp.eye(2, dtype=float))
```
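A similar sanity check (again not from the issue): the reverse-mode gradient through the sparse triple product should match the gradient of the same computation done densely:

```python
# reference computation with the dense version of the array
def f_dense(X):
    return (dense_array @ X @ dense_array).real.sum()

X0 = jnp.eye(2, dtype=float)
assert jnp.allclose(jit_grad_f(X0), grad(f_dense)(X0))
```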
Update:

As of jax 0.2.26 and jaxlib 0.1.75 the above code snippets work. PR #69 now removes the caveat that `LindbladModel.evaluate_rhs` cannot be reverse-mode autodiffed when in sparse mode, and changes the autodiff test case to revert to testing reverse-mode autodiff.

Updating `jsparse_linear_combo` in `operator_collections.py` still needs to be done: while the above snippet works, simply updating `jsparse_linear_combo` results in several test failures, and it still needs to be figured out why these are happening. It's possible they're all just `numpy.array` vs. `jax.numpy.array` type errors in the test case setups; a sketch of that suspected mismatch is given below.
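To illustrate the suspected mismatch (a hypothetical sketch, not taken from the actual test suite; `np_coeffs` is an invented name), fixtures built with plain `numpy` arrays can be converted explicitly before hitting the sparse code path:

```python
import numpy as np

# hypothetical test-setup style: coefficients constructed with numpy
# rather than jax.numpy
np_coeffs = np.array([1., 2., 3.])

# an explicit jnp.asarray conversion rules out any numpy vs. jax.numpy
# type mismatch before the broadcasting multiply against the BCOO array
jsparse_linear_combo(jnp.asarray(np_coeffs), sparse_array)
```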
PR #51 introduces JAX sparse evaluation for models. Due to some limited functionality for the jax `BCOO` sparse array type, some workarounds were required to implement a few things. Subsequent releases of JAX are expected to eliminate the need for these workarounds, and this issue is a reminder of what they are.

- `jsparse_linear_combo` is defined at the beginning of `operator_collections.py`; it takes a linear combination of sparse arrays (specified as a 3d `BCOO` array), with the coefficients given in a dense 1d array. This cannot be achieved using a sparsified version of `jnp.tensordot`, as the design convention jax is using is that such operations always output dense arrays if at least one input is dense. Hence, `jsparse_linear_combo` multiplies the coefficients against the sparse array directly via broadcasting. However, sparse-dense element-wise multiplication, at the time of writing, is limited to arrays of the same shape, so it is necessary to explicitly blow the coefficient array up to a dense array with the same shape as the sparse array (which is huge). I'm not sure if this is done via views, in which case it's okay, but this should be changed when possible regardless. A sketch of this workaround is given below.
- ~~`LindbladModel` in jax-sparse mode not being reverse-mode differentiable. It is, however, forward-mode differentiable. At some point this will change, and we will need to remove the caveat in `LindbladModel.evaluation_mode` that sparse mode with jax is not reverse-mode differentiable.~~
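A minimal sketch of the blow-up workaround described in the first point (the exact implementation in `operator_collections.py` may differ; `jsparse_mul` and the function name here are illustrative):

```python
# sparsified elementwise multiply, in the same style as the sparsified
# ops defined in the setup above
jsparse_mul = jsparse.sparsify(jnp.multiply)

def jsparse_linear_combo_blown_up(coeffs, mats):
    # explicitly broadcast the 1d coefficients to a dense array with the
    # same shape as the 3d BCOO array, since sparse-dense elementwise
    # multiplication is limited to equal-shape operands
    blown_up_coeffs = jnp.broadcast_to(coeffs[:, None, None], mats.shape)
    return jsparse_sum(jsparse_mul(blown_up_coeffs, mats), axis=0)
```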