jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Qutip can't handle jnp arrays #7134

Closed: nlwach closed this issue 3 years ago

nlwach commented 3 years ago

I am trying to differentiate a quantum circuit that uses qutip. Whenever I create a qutip.Qobj from a jnp.array, the resulting qutip.Qobj is effectively empty: it comes out as a 1x1 bra containing 0 instead of the expected 2x1 ket. Somewhere, jnp arrays seem to differ too much from regular np arrays. Creating a qutip.Qobj from a standard np.array works perfectly fine.

Here is the code using jnp arrays:

import jax.numpy as jnp
import qutip as q

state = jnp.array([1+0j, 1+0j]) / jnp.sqrt(2)
Q_state = q.Qobj(state)
print(Q_state)

The output:

Quantum object: dims = [[1], [1]], shape = (1, 1), type = bra
Qobj data =
[[0.]]

And using NumPy arrays:

import numpy as np
import qutip as q

state = np.array([1+0j, 1+0j]) / np.sqrt(2)
Q_state = q.Qobj(state)
print(Q_state)

The output:

Quantum object: dims = [[2], [1]], shape = (2, 1), type = ket
Qobj data =
[[0.70710678]
 [0.70710678]]

jakevdp commented 3 years ago

Hi, thanks for the question. It appears that Qutip does not support duck-typed ndarrays, as you can see in the source: https://github.com/qutip/qutip/blob/dc2abc4be6540a3a3c549ed21d8da83f3f0db5aa/qutip/qobj.py#L221-L343

JAX arrays do not pass any of the isinstance checks used there, so they appear to be incompatible with the package. I would suggest explicitly converting the input to a NumPy array, i.e.

Q_state = q.Qobj(np.asarray(state))
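The isinstance mismatch and the suggested workaround can be checked directly. This is a minimal sketch assuming a recent JAX release; it deliberately does not import qutip and only shows what qutip's type checks would see. One caveat, since the goal here is differentiation: np.asarray only works on concrete arrays, so calling it inside a transformation such as jax.grad raises jax.errors.TracerArrayConversionError, which means gradients will not propagate through a qutip.Qobj built this way.

```python
import numpy as np
import jax
import jax.numpy as jnp

state = jnp.array([1 + 0j, 1 + 0j]) / jnp.sqrt(2)

# A JAX array is not an np.ndarray subclass, so isinstance-based dispatch misses it...
print(isinstance(state, np.ndarray))  # False
# ...but it does implement the __array__ protocol.
print(hasattr(state, "__array__"))  # True

# np.asarray goes through __array__ and yields a plain NumPy array,
# which isinstance(..., np.ndarray) checks accept.
state_np = np.asarray(state)
print(type(state_np), state_np.shape, state_np.dtype)

# Caveat: inside jax.grad the argument is a tracer, and converting a tracer
# to NumPy is not allowed, so differentiation cannot pass through np.asarray.
def loss(x):
    return jnp.sum(jnp.asarray(np.asarray(x)))

try:
    jax.grad(loss)(1.0)
except jax.errors.TracerArrayConversionError:
    print("cannot convert a traced value to a NumPy array")
```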

If you would like Qutip to support JAX and other NumPy-like array inputs natively (perhaps by checking for an __array__ attribute on the input), I would suggest opening a feature request in the Qutip repository.
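A rough sketch of the duck-typed check described above, as it might appear in a library's input handling. The helper name coerce_array_like is hypothetical and is not part of qutip's API; the real Qobj constructor has its own, more involved dispatch logic.

```python
import numpy as np
import jax.numpy as jnp


def coerce_array_like(data):
    """Hypothetical helper (not part of qutip): return a NumPy array for
    ndarrays, Python sequences, or any object exposing __array__."""
    if isinstance(data, np.ndarray):
        return data
    if isinstance(data, (list, tuple)):
        return np.asarray(data)
    if hasattr(data, "__array__"):
        # Duck typing: accepts JAX arrays and other libraries that
        # implement the NumPy __array__ protocol.
        return np.asarray(data)
    raise TypeError(f"unsupported input type: {type(data).__name__}")


ket = coerce_array_like(jnp.array([1 + 0j, 1 + 0j]) / jnp.sqrt(2))
print(type(ket).__name__, ket.shape)  # ndarray (2,)
```

Checking for __array__ instead of a fixed list of types lets any array library that follows NumPy's conversion protocol work without the library needing to know about it in advance.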