qiskit-community / qiskit-dynamics

Tools for building and solving models of quantum systems in Qiskit
https://qiskit-community.github.io/qiskit-dynamics/
Apache License 2.0

Update jax sparse operations when possible #52

Status: Open. DanPuzzuoli opened this issue 3 years ago.

DanPuzzuoli commented 3 years ago

PR #51 introduces JAX sparse evaluation for models. Because JAX's BCOO sparse array type had limited functionality, a few operations required workarounds. Subsequent JAX releases are expected to remove the need for these workarounds, and this issue records what they are so they can be updated when that happens.

DanPuzzuoli commented 2 years ago

Code samples for the above points:

Setup:

```python
import jax.numpy as jnp
from jax import jit, grad
from jax.experimental import sparse as jsparse

# sparse versions of jax.numpy operations
jsparse_sum = jsparse.sparsify(jnp.sum)
jsparse_matmul = jsparse.sparsify(jnp.matmul)
jsparse_add = jsparse.sparsify(jnp.add)
jsparse_subtract = jsparse.sparsify(jnp.subtract)

# one coefficient per matrix in the stack below
coeffs = jnp.array([1., 2., 3.])
# three 2x2 matrices stacked along a leading axis (the 1j entry makes the stack complex)
dense_array = jnp.array([[[0., 1.], [1., 0.]], [[0., 1.], [1j, 0.]], [[0., 1.], [0., 1.]]])
# n_batch=1 treats the leading axis as a batch dimension of 2x2 sparse matrices
sparse_array = jsparse.BCOO.fromdense(dense_array, n_batch=1)
```
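To make the `n_batch=1` argument concrete, the sketch below rebuilds the same stack and inspects the resulting BCOO array: the leading axis becomes a batch axis, the two trailing axes are sparse, and `todense()` recovers the original stack. It only uses standard BCOO attributes (`shape`, `n_batch`, `n_sparse`, `data`, `indices`) and assumes a JAX version recent enough to run the snippets in this thread.

```python
import jax.numpy as jnp
from jax.experimental import sparse as jsparse

# Same stack of three 2x2 matrices as in the setup.
dense_array = jnp.array([[[0., 1.], [1., 0.]],
                         [[0., 1.], [1j, 0.]],
                         [[0., 1.], [0., 1.]]])
sparse_array = jsparse.BCOO.fromdense(dense_array, n_batch=1)

# The leading axis is a batch axis; the two trailing axes are sparse.
print(sparse_array.shape, sparse_array.n_batch, sparse_array.n_sparse)

# Storage is per batch entry: data has shape (3, nse) and indices has
# shape (3, nse, 2), where nse is the number of stored elements per matrix.
print(sparse_array.data.shape, sparse_array.indices.shape)

# Converting back reproduces the dense input exactly.
print(jnp.array_equal(sparse_array.todense(), dense_array))
```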

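Test code for the linear combination:

```python
def jsparse_linear_combo(coeffs, mats):
    # scale each matrix in the batch by its coefficient, then sum over the batch axis
    return jsparse_sum(coeffs[:, None, None] * mats, axis=0)

jsparse_linear_combo(coeffs, sparse_array)
```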
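A quick way to confirm that the snippet computes the intended linear combination is to compare it against a dense reference, here `jnp.tensordot(coeffs, dense_array, axes=1)`, i.e. the sum of `coeffs[k] * dense_array[k]` over `k`. This is a self-contained sketch; the `isinstance` check is only there because the sparsified sum may return a BCOO array, which has to be densified before comparison.

```python
import jax.numpy as jnp
from jax.experimental import sparse as jsparse

jsparse_sum = jsparse.sparsify(jnp.sum)

coeffs = jnp.array([1., 2., 3.])
dense_array = jnp.array([[[0., 1.], [1., 0.]],
                         [[0., 1.], [1j, 0.]],
                         [[0., 1.], [0., 1.]]])
sparse_array = jsparse.BCOO.fromdense(dense_array, n_batch=1)

def jsparse_linear_combo(coeffs, mats):
    # Scale each batch entry by its coefficient, then sum over the batch axis.
    return jsparse_sum(coeffs[:, None, None] * mats, axis=0)

result = jsparse_linear_combo(coeffs, sparse_array)
# Densify only if the sparsified sum returned a BCOO array.
result_dense = result.todense() if isinstance(result, jsparse.BCOO) else result

# Dense reference: sum over k of coeffs[k] * dense_array[k].
expected = jnp.tensordot(coeffs, dense_array, axes=1)
print(jnp.allclose(result_dense, expected))
```

For these inputs the combination works out by hand to `[[0, 6], [1+2j, 3]]`.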

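Reverse-mode differentiation test for the triple product:

```python
# sparse A @ X @ B, where A and B are batched BCOO arrays and X is dense
jsparse_triple_product = jsparse.sparsify(lambda A, X, B: A @ X @ B)

def f(X):
    # real-valued scalar output, so grad is well defined for a real X
    return jsparse_triple_product(sparse_array, X, sparse_array).real.sum()

jit_grad_f = jit(grad(f))
jit_grad_f(jnp.eye(2, dtype=float))
```
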
DanPuzzuoli commented 2 years ago

Update:

As of jax 0.2.26 and jaxlib 0.1.75, the code snippets above work. PR #69 removes the caveat that LindbladModel.evaluate_rhs cannot be reverse-mode differentiated in sparse mode, and switches the autodiff test case back to testing reverse-mode autodiff.
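To illustrate the kind of computation that reverse-mode autodiff in sparse mode enables, here is a hypothetical sketch that does not use qiskit-dynamics at all: a hand-written two-level Lindblad right-hand side, -i[H, ρ] + γ(LρL† - ½{L†L, ρ}) with H = (Z + amp·X)/2 and L = |0⟩⟨1|, evaluated with BCOO operators and differentiated with `jax.grad` with respect to the drive amplitude `amp`. The operators, rate `gamma`, initial state, and loss function are all assumptions made for illustration. For ρ = |1⟩⟨1| the loss is 2γ² + amp²/2, so both the sparse and dense gradients should equal `amp`.

```python
import jax
import jax.numpy as jnp
from jax.experimental import sparse as jsparse

# Hypothetical two-level system (not qiskit-dynamics' LindbladModel):
# H = (Z + amp * X) / 2, one decay operator L = |0><1| with rate gamma.
Z = jnp.array([[1., 0.], [0., -1.]], dtype=jnp.complex64)
X = jnp.array([[0., 1.], [1., 0.]], dtype=jnp.complex64)
L = jnp.array([[0., 1.], [0., 0.]], dtype=jnp.complex64)
gamma = 0.1

dense_ops = (Z, X, L, L.conj().T, L.conj().T @ L)
sparse_ops = tuple(jsparse.BCOO.fromdense(op) for op in dense_ops)

def lindblad_rhs(amp, rho, ops):
    """-i[H, rho] + gamma * (L rho L^dag - {L^dag L, rho} / 2).

    ops holds (Z, X, L, L^dag, L^dag L) as dense or BCOO arrays; rho is
    dense, so every product is sparse-dense or dense-sparse."""
    Z, X, L, Ldag, LdagL = ops
    H_rho = 0.5 * (Z @ rho) + 0.5 * amp * (X @ rho)
    rho_H = 0.5 * (rho @ Z) + 0.5 * amp * (rho @ X)
    dissipator = (L @ rho) @ Ldag - 0.5 * (LdagL @ rho + rho @ LdagL)
    return -1j * (H_rho - rho_H) + gamma * dissipator

# Excited state |1><1|.
rho1 = jnp.array([[0., 0.], [0., 1.]], dtype=jnp.complex64)

def loss(amp, ops):
    # Real-valued objective: squared Frobenius norm of the right-hand side.
    rhs = lindblad_rhs(amp, rho1, ops)
    return jnp.sum(rhs.real ** 2 + rhs.imag ** 2)

grad_loss = jax.jit(jax.grad(loss))
amp = 0.3
g_sparse = grad_loss(amp, sparse_ops)
g_dense = grad_loss(amp, dense_ops)
print(g_sparse, g_dense)
```

Passing the operators as an argument keeps one jitted function usable for both the dense and the BCOO tuples, since BCOO arrays are pytrees.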

Updating jsparse_linear_combo in operator_collections.py is still outstanding. Although the snippet above works, swapping it in for the current jsparse_linear_combo causes several test failures, and the cause has not yet been identified. They may all be type errors from mixing numpy.array and jax.numpy.array in the test case setups.
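If the failures do come from numpy.array inputs reaching the sparse code path, one hedged way to test that hypothesis is to coerce inputs at the function boundary and rerun the tests. The helper below, `jsparse_linear_combo_checked`, is hypothetical (it is not the function in operator_collections.py): it converts coefficients with `jnp.asarray` and converts a dense operator stack to BCOO before applying the same linear-combination logic as the snippet above.

```python
import numpy as np
import jax.numpy as jnp
from jax.experimental import sparse as jsparse

jsparse_sum = jsparse.sparsify(jnp.sum)

def jsparse_linear_combo_checked(coeffs, mats):
    """Hypothetical variant of jsparse_linear_combo that coerces its inputs.

    numpy coefficients become jax arrays, and a dense operator stack becomes a
    batched BCOO array, so the sparse code path only ever sees JAX types."""
    coeffs = jnp.asarray(coeffs)
    if not isinstance(mats, jsparse.BCOO):
        mats = jsparse.BCOO.fromdense(jnp.asarray(mats), n_batch=1)
    return jsparse_sum(coeffs[:, None, None] * mats, axis=0)

# Test-style inputs built with plain numpy rather than jax.numpy.
np_coeffs = np.array([1., 2., 3.])
np_mats = np.array([[[0., 1.], [1., 0.]],
                    [[0., 1.], [1j, 0.]],
                    [[0., 1.], [0., 1.]]])

result = jsparse_linear_combo_checked(np_coeffs, np_mats)
result_dense = result.todense() if isinstance(result, jsparse.BCOO) else result
print(jnp.allclose(result_dense, np.tensordot(np_coeffs, np_mats, axes=1)))
```

If the tests pass with inputs coerced this way but fail without it, that would point to the numpy vs. jax.numpy explanation; if they still fail, the cause lies elsewhere.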