qutip / qutip-jax

JAX backend for QuTiP
BSD 3-Clause "New" or "Revised" License

qobj methods depending on isherm attribute give TracerBoolConversionError #30

Status: Open. flowerthrower opened this issue 7 months ago.

flowerthrower commented 7 months ago

The following code:

    import jax
    import jax.numpy as jnp
    import qutip as qt
    import qutip_jax  # noqa: F401  registers the "jax" dtype and the "diffrax" method

    @jax.jit
    def sin(t, p):
        return p[0] * jnp.sin(p[1] * t + p[2])

    with qt.CoreOptions(default_dtype="jax"):
        H = qt.QobjEvo([qt.sigmax(), sin], args={"p": [1., 1., 0.]})
        solver = qt.MESolver(H, options={'method': 'diffrax', "normalize_output": False})

    @jax.jit
    def f(T, p):
        evo = solver.run(qt.qeye(2, dtype="jax"),
                         tlist=[0, T],
                         args={'p': p})
        diff = evo.final_state - qt.destroy(2, dtype="jax")
        return diff.tr()

    print(f(1., [1., 1., 0.]))

raises a TracerBoolConversionError because of the if/else clause in qutip.core.qobj.Qobj.tr() (the same happens with Qobj.dag()). From qutip.core.qobj:

    def tr(self):
        """Trace of a quantum object.

        Returns
        -------
        trace : float
            Returns the trace of the quantum object.

        """
        out = _data.trace(self._data)
        # This ensures that trace can return something that is not a number such
        # as a `tensorflow.Tensor` in qutip-tensorflow.
        return out.real if (self.isherm
                        and hasattr(out, "real")
                        ) else out

I understand that the subtraction diff = evo.final_state - qt.destroy(2, dtype="jax") may produce a non-Hermitian operator, so isherm cannot be known before runtime: under jax.jit it is a traced value, and branching on it with a Python if raises the error. However, the if/else clause does not seem strictly necessary to me (I guess it exists for the performance of tr() and dag()). Is there a way to prevent this within the qutip-jax package, or does it require changes in qutip.core.qobj?
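A minimal, JAX-only sketch of the failure mode described above (no QuTiP involved; is_hermitian, trace_with_python_if and trace_with_where are hypothetical names, and is_hermitian is only an assumed stand-in for QuTiP's own check). A Python if needs a concrete bool, which a traced value cannot provide. jnp.where avoids the Python branch, but since a jitted function's output dtype is fixed at trace time, the result has to stay complex even for Hermitian input; a "real if Hermitian" rule therefore does not carry over unchanged to traced code.

```python
import jax
import jax.numpy as jnp


def is_hermitian(m):
    # Hypothetical data-dependent check: under jax.jit this is a traced boolean.
    return jnp.allclose(m, m.conj().T)


@jax.jit
def trace_with_python_if(m):
    # Same shape as Qobj.tr(): a Python `if` on a value computed from the data.
    # Under jax.jit the condition is a tracer, so the `if` raises an error.
    out = jnp.trace(m)
    return out.real if is_hermitian(m) else out


@jax.jit
def trace_with_where(m):
    # Trace-safe variant: both candidates are computed and one is selected.
    # The output dtype must be fixed while tracing, so it stays complex and
    # only the imaginary part is dropped when the input is Hermitian.
    out = jnp.trace(m)
    return jnp.where(is_hermitian(m), out.real + 0j, out)


a = jnp.array([[0, 1], [0, 0]], dtype=jnp.complex64)  # same matrix as destroy(2)
print(trace_with_where(a))  # complex result, no error
try:
    trace_with_python_if(a)
except jax.errors.ConcretizationTypeError as err:
    # TracerBoolConversionError is a subclass of ConcretizationTypeError.
    print("raised", type(err).__name__)
```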

Ericgig commented 7 months ago

If you know that diff is not Hermitian, you should be able to set diff._isherm = False, so that isherm is already known at compile time.
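A toy model of this suggestion, in plain JAX. ToyOperator and both jitted functions are hypothetical; this is not QuTiP's Qobj, only a sketch that assumes hermiticity is cached in an _isherm attribute the way the comment implies. When the cache holds a plain Python bool, the branch in tr() is decided during tracing, so jax.jit never has to convert a tracer to a bool.

```python
import jax
import jax.numpy as jnp


class ToyOperator:
    """Hypothetical stand-in for a Qobj: wraps an array and caches isherm."""

    def __init__(self, data):
        self.data = data
        self._isherm = None  # None means "unknown, compute on demand"

    @property
    def isherm(self):
        if self._isherm is None:
            # Computed from the data: a traced boolean under jax.jit.
            self._isherm = jnp.allclose(self.data, self.data.conj().T)
        return self._isherm

    def tr(self):
        out = jnp.trace(self.data)
        # Fine for a Python bool; raises for a traced boolean.
        return out.real if self.isherm else out


@jax.jit
def trace_unknown(m):
    # isherm must be computed from traced data, so tr() raises.
    return ToyOperator(m).tr()


@jax.jit
def trace_known_non_hermitian(m):
    op = ToyOperator(m)
    op._isherm = False  # static Python bool: the branch is resolved while tracing
    return op.tr()
```

Without jax.jit the same class works either way, because jnp.allclose then returns a concrete value that Python can convert to a bool.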

The if/else is there for user experience: traces of density matrices are expected to be real.
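The reasoning behind taking the real part can be spelled out in one line: for a Hermitian operator every diagonal entry equals its own complex conjugate, so the trace is real and dropping the imaginary part loses nothing. For a non-Hermitian operator the imaginary part of the trace can be nonzero, which is why the shortcut is guarded by isherm.

```latex
A = A^\dagger \;\Rightarrow\; A_{ii} = \overline{A_{ii}} \;\Rightarrow\; \operatorname{Tr} A = \sum_i A_{ii} \in \mathbb{R}
```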

flowerthrower commented 7 months ago

Thank you, Eric. I guess it does not hurt to set diff._isherm = False for my problem (even if diff might be Hermitian in some cases). Another workaround is to call diff.data.trace() directly. The issue does raise the question of whether it would be more sensible to put the if/else clause into the data-layer implementation of tr (and, e.g., dag).
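One possible shape for the data-layer idea, sketched in plain JAX. data_trace and its isherm argument are hypothetical and not part of qutip or qutip-jax. The hermiticity decision becomes a static argument, so jax.jit compiles one version per value and the Python if only ever sees a Python bool. The default keeps the complex trace, which is always safe, and the caller opts into the real-valued result explicitly.

```python
import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.jit, static_argnames="isherm")
def data_trace(m, isherm=None):
    """Hypothetical data-layer trace with a static realness flag.

    isherm=None or False returns the complex trace; isherm=True returns its
    real part. Because `isherm` is static, each value gets its own compiled
    version and the `if` below never receives a tracer.
    """
    out = jnp.trace(m)
    return out.real if isherm else out


h = jnp.array([[1, 2j], [-2j, 3]], dtype=jnp.complex64)
print(data_trace(h, isherm=True))  # real-valued trace
print(data_trace(h))               # complex-valued trace
```

The trade-off is that each distinct value of isherm triggers a separate compilation, and inside an outer jitted function the flag must still come from Python-level knowledge (such as a cached _isherm), not from the traced data itself.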