qutip / qutip-jax

JAX backend for QuTiP
BSD 3-Clause "New" or "Revised" License

`tr` is not jit compatible because of isherm #34

Open BoxiLi opened 4 months ago

BoxiLi commented 4 months ago

`tr` first checks whether the matrix is Hermitian and, if it is, returns only the real part of the trace, which changes the type of the output:

...
out = _data.trace(self._data)
# This ensures that trace can return something that is not a number such
# as a `tensorflow.Tensor` in qutip-tensorflow.
return out.real if (self.isherm
                and hasattr(out, "real")
                ) else out
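To make the type change concrete, here is a minimal eager sketch in plain JAX that mimics the branch above. `trace_like` is a hypothetical stand-in for `Qobj.tr`, and the tolerance is an illustrative choice. Outside of `jit`, a Hermitian input yields a real scalar while a non-Hermitian input yields a complex one, so the output dtype depends on the matrix values.

```python
import jax.numpy as jnp

def trace_like(matrix, tol=1e-12):
    """Hypothetical stand-in for Qobj.tr: real trace if Hermitian, else complex."""
    out = jnp.trace(matrix)
    # Eager evaluation: calling bool() on a concrete array is allowed.
    is_herm = bool(jnp.allclose(matrix, matrix.T.conj(), atol=tol, rtol=0))
    return out.real if (is_herm and hasattr(out, "real")) else out

herm = jnp.array([[1.0, 1.0 - 1.0j], [1.0 + 1.0j, 1.0]])
non_herm = jnp.array([[1.0, 1.0 - 1.0j], [1.0 - 1.0j, 1.0]])

print(trace_like(herm).dtype)      # a real dtype
print(trace_like(non_herm).dtype)  # a complex dtype
```

Under `jit`, the output dtype must be fixed at trace time, so a value-dependent choice like this cannot be compiled as-is.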

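However, `isherm` is determined with `jnp.allclose`, which checks whether the matrix is close to Hermitian. Under `jit`, the result is a traced boolean array rather than a Python `bool`, so using it in `self.isherm and ...` fails with a tracer-to-bool conversion error: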

@partial(jit, static_argnames=["tol"])
def _isherm(matrix, tol):
    return jnp.allclose(matrix, matrix.T.conj(), atol=tol, rtol=0)
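The failure can be reproduced without QuTiP. In this sketch, `_isherm_sketch` and `tr_sketch` are hypothetical names: the Hermiticity check itself traces fine, but Python's `and` calls `bool()` on its result, and JAX rejects that for a tracer with a `ConcretizationTypeError` (recent versions raise the subclass `TracerBoolConversionError`).

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames=["tol"])
def _isherm_sketch(matrix, tol):
    # Same check as in qutip-jax: returns a 0-d boolean JAX array.
    return jnp.allclose(matrix, matrix.T.conj(), atol=tol, rtol=0)

def tr_sketch(matrix, tol=1e-12):
    out = jnp.trace(matrix)
    # Python `and` forces bool(); under jit the flag is a tracer, so this fails.
    return out.real if (_isherm_sketch(matrix, tol) and hasattr(out, "real")) else out

m = jnp.array([[1.0, 1.0 - 1.0j], [1.0 + 1.0j, 1.0]])

print(tr_sketch(m))  # eager: works and returns a real scalar

try:
    jax.jit(tr_sketch)(m)
except jax.errors.ConcretizationTypeError as err:
    print("jit failed:", type(err).__name__)
```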

In principle, following the philosophy of `jit`, a matrix property that is used to branch the computation should not be determined by evaluating the matrix. Maybe for `JaxArray` we should just leave it as `False` if it cannot be derived explicitly?
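A minimal sketch of that fallback, assuming it is acceptable to report "not Hermitian" whenever the answer depends on traced values. `isherm_or_false` and `tr_fallback` are hypothetical helpers, not part of qutip-jax: the check is evaluated eagerly when possible and returns `False` when `bool()` raises during tracing. Because `False` only disables the real-part shortcut, the jitted trace is still correct but comes back as a complex number.

```python
import jax
import jax.numpy as jnp

def isherm_or_false(matrix, tol=1e-12):
    """Hypothetical helper: True/False when concrete, False when traced."""
    close = jnp.allclose(matrix, matrix.T.conj(), atol=tol, rtol=0)
    try:
        return bool(close)  # concrete array: a real answer
    except jax.errors.ConcretizationTypeError:
        return False  # traced under jit: cannot be derived, assume not Hermitian

def tr_fallback(matrix):
    out = jnp.trace(matrix)
    return out.real if (isherm_or_false(matrix) and hasattr(out, "real")) else out

m = jnp.array([[1.0, 1.0 - 1.0j], [1.0 + 1.0j, 1.0]])
print(tr_fallback(m))           # eager: real scalar
print(jax.jit(tr_fallback)(m))  # jit: complex scalar with zero imaginary part
```

The trade-off is that the same call returns different dtypes inside and outside `jit`; always returning a complex trace would avoid that at the cost of changing the eager behaviour.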

Example:

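import qutip
import qutip_jax  # registers the JAX data layer with QuTiP
import jax
import jax.numpy as np  # JAX's NumPy API, aliased as np here
# Make "jax" the default data type for new Qobj data.
qutip.settings.core["default_dtype"] = "jax"

@jax.jit
def tmp(a):
    m = qutip.Qobj(np.array([[1., a], [np.conjugate(a), 1.]]))
    return m.tr()  # fails under jit: tr() needs isherm as a Python bool
tmp(1.-1.j)

BoxiLi commented 4 months ago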

What I don't understand is why `dag` seems to work fine in this case:

@jax.jit
def tmp(a):
    m = qutip.Qobj(np.array([[1., a], [np.conjugate(a), 1.]]))
    m.dag()
    return 0.
tmp(1.-1.j)

In `dag`, `_isherm` is used instead of `isherm`.
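One plausible explanation, sketched under the assumption that `Qobj._isherm` is a cache that stays `None` until the `isherm` property is first read: `dag` only inspects the cached value, and `if None:` is simply false, so no matrix values are evaluated. `tr` goes through the property, which fills the cache by running the traced check. `LazyHermSketch` is a hypothetical toy class, not QuTiP's `Qobj`.

```python
import jax
import jax.numpy as jnp

class LazyHermSketch:
    """Toy model of a cached Hermiticity flag (hypothetical, not qutip.Qobj)."""

    def __init__(self, data, isherm=None):
        self.data = data
        self._isherm = isherm  # None means "not computed yet"

    @property
    def isherm(self):
        if self._isherm is None:
            # Computing the flag evaluates the matrix; under jit this is a tracer.
            self._isherm = jnp.allclose(self.data, self.data.T.conj())
        return self._isherm

    def dag(self):
        # Reads only the cache: `if None:` is falsy, so nothing is evaluated.
        if self._isherm:
            return LazyHermSketch(self.data, isherm=True)
        return LazyHermSketch(self.data.T.conj(), isherm=self._isherm)

    def tr(self):
        out = jnp.trace(self.data)
        # The property may return a tracer, and `and` then calls bool() on it.
        return out.real if (self.isherm and hasattr(out, "real")) else out

@jax.jit
def uses_dag(a):
    m = LazyHermSketch(jnp.array([[1.0, a], [jnp.conjugate(a), 1.0]]))
    return m.dag().data

@jax.jit
def uses_tr(a):
    m = LazyHermSketch(jnp.array([[1.0, a], [jnp.conjugate(a), 1.0]]))
    return m.tr()

a = 1.0 - 1.0j
print(uses_dag(a))  # works: the cache was never filled
try:
    uses_tr(a)
except jax.errors.ConcretizationTypeError as err:
    print("tr failed under jit:", type(err).__name__)
```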

def dag(self):
    """Get the Hermitian adjoint of the quantum object."""
    if self._isherm:
        return self.copy()