qutip / qutip-jax

JAX backend for QuTiP
BSD 3-Clause "New" or "Revised" License

Better ways to extract dimensions for `QobjEvo` #33

Closed. BoxiLi closed this issue 7 months ago.

BoxiLi commented 7 months ago

I'm not sure whether it is better to raise this issue here or in the main qutip repository.

In functions like `qeye_like`, the dimensions are obtained by first evaluating the `QobjEvo` at `t=0`:

```python
if isinstance(qobj, QobjEvo):
    qobj = qobj(0)
```

This is not always JIT-compatible, because the evaluation at `t=0` may depend on values that are traced under `jax.jit` (for example, arguments passed to a coefficient). Below is an oversimplified example.

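```python
import jax
import qutip
import qutip_jax  # registers the JAX data layer with qutip

@jax.jit
def pulse_fun(t, T):
    return t / T

@jax.jit
def tmp(t, T):
    # T is traced inside tmp, so the coefficient has no concrete value here.
    H = qutip.destroy(3) * qutip.coefficient(pulse_fun, args={"T": T})
    # qeye_like evaluates H at t=0 to find its dimensions, which fails under jit.
    return qutip.qeye_like(H).data._jxa

tmp(1., 10.)
```

This raises:

```
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape float64[].
```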
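The same failure mode can be reproduced without qutip. The sketch below is plain JAX and only mimics the pattern; it assumes, rather than demonstrates, that evaluating `H(0)` inside qutip ends up forcing the traced coefficient into a concrete Python number. The names `identity_via_evaluation` and `identity_from_shape` are hypothetical. Converting a tracer with `complex()` raises the same `ConcretizationTypeError`, while the operand's shape and dtype stay static during tracing and can be used freely.

```python
import jax
import jax.numpy as jnp


def pulse_fun(t, T):
    # Time-dependent coefficient; T is a traced argument under jit.
    return t / T


@jax.jit
def identity_via_evaluation(op, T):
    # Hypothetical stand-in for "evaluate at t=0 first" (assumption):
    # complex() needs a concrete value, so this raises under jit.
    c0 = complex(pulse_fun(0.0, T))
    evaluated = op * c0
    return jnp.eye(evaluated.shape[0], dtype=evaluated.dtype)


@jax.jit
def identity_from_shape(op, T):
    # op.shape and op.dtype are static during tracing, so the identity
    # is built without ever evaluating the coefficient.
    return jnp.eye(op.shape[0], dtype=op.dtype)


# A 3-level annihilation operator as a plain complex array.
op = jnp.diag(jnp.sqrt(jnp.arange(1.0, 3.0)), k=1).astype(jnp.complex64)
```

Calling `identity_via_evaluation(op, 10.0)` raises `ConcretizationTypeError`, whereas `identity_from_shape(op, 10.0)` returns a 3×3 identity. Only static metadata is touched in the second version.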

A side question: what is the best way to get the `jnp.array` from a `Qobj`? Is there a public equivalent of `Qobj.data._jxa`?

Ericgig commented 7 months ago

It's not that easy to fix, as `qeye_like` needs the data-layer type, not just the dimensions. But in a `QobjEvo`, different parts could have different representations... Still, evaluating the `QobjEvo` does feel unnecessary, and I will try to think of something.
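The difficulty can be sketched in plain Python. `Term` and `identity_spec` are hypothetical and not part of qutip: each term carries its dimensions and a data-layer name, the dimensions must agree across terms, but the data layer for the identity is only determined when every term shares one representation.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Term:
    """Hypothetical stand-in for one term of a QobjEvo (not a qutip class)."""
    dims: tuple  # e.g. ((3,), (3,)) for a 3-level operator
    dtype: str   # data-layer name, e.g. "jax", "csr", "dense"


def identity_spec(terms):
    """Return (dims, dtype) for an identity matching the terms, without evaluating them.

    dtype is None when the terms disagree, i.e. when metadata alone
    cannot decide which data layer the identity should use.
    """
    dims = {term.dims for term in terms}
    if len(dims) != 1:
        raise ValueError(f"terms have inconsistent dims: {dims}")
    dtypes = {term.dtype for term in terms}
    dtype = next(iter(dtypes)) if len(dtypes) == 1 else None
    return next(iter(dims)), dtype
```

With two `"jax"` terms the result is `(((3,), (3,)), "jax")`; mixing `"jax"` and `"csr"` terms leaves the dtype undecided (`None`), which is where some extra rule would be needed.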

For the side question: `Qobj.data_as("jax")` returns the `jnp.array` from a `Qobj`.

P.S. Feel free to open issues here about anything in qutip that is not JIT-compatible.