qutip / qutip-jax

JAX backend for QuTiP
BSD 3-Clause "New" or "Revised" License

QobjEvo tries to compress the Qobj elements #36

Open BoxiLi opened 7 months ago

BoxiLi commented 7 months ago

jax.jit fails when I include more than one time-dependent element in a QobjEvo. This is because QobjEvo tries to compare the Qobjs in its elements and compress them if possible. When I call QobjEvo directly, I can pass compress=False, but QobjEvo is also called inside sesolve:

H = QobjEvo(H, args=args, tlist=tlist)
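
To illustrate the kind of failure described above, here is a minimal pure-JAX sketch. It does not use QuTiP's actual compression code: `same_operator` and `merge_if_equal` are hypothetical stand-ins, and the assumption is that the real check similarly needs a concrete Python bool from a comparison of operator data. Under `jax.jit` the arrays are tracers, so that conversion raises an error.

```python
import jax
import jax.numpy as jnp


def same_operator(a, b):
    # Hypothetical stand-in for an equality check between two operators.
    # bool(...) needs a concrete value, which a traced array cannot provide.
    return bool(jnp.all(a == b))


def merge_if_equal(a, b):
    # Toy "compression": if both terms carry the same operator, keep one copy.
    if same_operator(a, b):
        return [a]
    return [a, b]


a = jnp.eye(3)
b = jnp.eye(3)

# Outside jit the comparison is concrete, so the merge works.
eager_terms = merge_if_equal(a, b)

# Under jit the inputs are tracers and the Python-level branch fails.
try:
    jax.jit(lambda x, y: merge_if_equal(x, y)[0])(a, b)
    jit_failed = False
except jax.errors.ConcretizationTypeError:
    jit_failed = True
```

If this is indeed the mechanism, skipping the comparison (as compress=False does for a direct QobjEvo call) avoids the problem, which is why the internal call in sesolve matters.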

Example

import jax
import qutip
import qutip_jax  # provides the JaxArray data layer and the "diffrax" method

@jax.jit
def pulse_fun(t, T):
    return t / T

@jax.jit
def tmp(t, T):
    # A simple case with more than one time-dependent element.
    H = (
        qutip.destroy(3) * qutip.coefficient(pulse_fun, args={"T": T})
        + qutip.create(3) * qutip.coefficient(pulse_fun, args={"T": T})
    )
    result = qutip.sesolve(
        H,
        qutip.basis(3),
        [0, 2.],
        options={
            "method": "diffrax",
            "dt0": 0.01,
            "progress_bar": "",
        },
    )
    return result.final_state.data_as("JaxArray")

tmp(1., 10.)

Ericgig commented 7 months ago

I think that none of the solver functions (sesolve, mesolve, etc.) work under jax.jit. The initialization step is not JAX-friendly. Returning a Qobj is also broken, and returning a Result is not supported.
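
As general background on why a jitted function cannot return an arbitrary object, here is a small JAX-only sketch. `PlainResult` and `PytreeResult` are hypothetical containers, not QuTiP's Result class, and this is not a claim about how qutip-jax handles Qobj or Result internally. It only shows that `jax.jit` rejects outputs that are not registered pytrees, and that registering the class lets the same container pass through.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


class PlainResult:
    # Hypothetical container that JAX knows nothing about.
    def __init__(self, final_state):
        self.final_state = final_state


@tree_util.register_pytree_node_class
class PytreeResult:
    # Same container, registered so JAX can flatten and rebuild it.
    def __init__(self, final_state):
        self.final_state = final_state

    def tree_flatten(self):
        # Children are the array leaves; no static auxiliary data is needed.
        return (self.final_state,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def evolve(x):
    # Placeholder for a solver step; it just doubles the input.
    return 2.0 * x


x0 = jnp.ones(3)

# Returning an unregistered object from a jitted function raises a TypeError.
try:
    jax.jit(lambda x: PlainResult(evolve(x)))(x0)
    plain_ok = True
except TypeError:
    plain_ok = False

# The registered container is accepted as a jit output.
out = jax.jit(lambda x: PytreeResult(evolve(x)))(x0)
```

Until the solver's return type is supported, returning raw array data from the jitted function, as the example in the report does with data_as("JaxArray"), sidesteps this particular restriction.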