qutip / qutip-jax

JAX backend for QuTiP
BSD 3-Clause "New" or "Revised" License

diffrax with mesolve is very slow #26

Open nwlambert opened 11 months ago

nwlambert commented 11 months ago

Doing some basic tests with mesolve() + qutip-jax, the diffrax method seems excessively slow, on both CPU and GPU. Manually extracting the ODE data and passing it to diffrax.diffeqsolve() is much quicker, so perhaps there is a bottleneck somewhere?

e.g., comparing

import qutip as qt
import qutip_jax  # import needed to enable the "jax" dtype and the "diffrax" method
import jax

with qt.CoreOptions(default_dtype="jax"):
    N = 2
    # Three coupled modes, each truncated to N levels (`&` is the tensor product)
    a = qt.destroy(N) & qt.qeye(N) & qt.qeye(N)
    b = qt.qeye(N) & qt.destroy(N) & qt.qeye(N)
    c = qt.qeye(N) & qt.qeye(N) & qt.destroy(N)
    H = (
        a.dag() * a + b.dag() * b + c.dag() * c
        + (a.dag() + a) * (b + b.dag())
        + (b.dag() + b) * (c + c.dag())
    )

    c_ops = [a, b, c]

    t = 10
    options = {"method": "diffrax", "normalize_output": False}
    solver = qt.MESolver(H, c_ops, options=options)

    # The first mode starts with one excitation; track its occupation number
    result = solver.run(
        qt.basis(N, 1, dtype="jax") & qt.basis(N, 0, dtype="jax") & qt.basis(N, 0, dtype="jax"),
        [0, t],
        e_ops=qt.num(N, dtype="jax") & qt.qeye(N, dtype="jax") & qt.qeye(N, dtype="jax"),
    )

Compare this to a manual attempt using the RHS and initial condition from above (I'm not entirely sure this is correct, but it seems to give reasonably similar output):

import jax.numpy as jnp
from diffrax import diffeqsolve, ODETerm, Dopri5, PIDController

# Dense Liouvillian and vectorized initial density matrix as JAX arrays
L = qt.liouvillian(H, c_ops)
LJ = jnp.array(L.full())
rho0J = jnp.array(
    qt.operator_to_vector(
        qt.ket2dm(qt.basis(N, 1) & qt.basis(N, 0) & qt.basis(N, 0))
    ).full()
)

def f(t, y, args):
    # Right-hand side of the vectorized master equation: dy/dt = L @ y
    return LJ @ y

stepsize_controller = PIDController(rtol=1e-5, atol=1e-5)
term = ODETerm(f)
solver = Dopri5()
y0 = rho0J
solution = diffeqsolve(term, solver, t0=0, t1=t, dt0=0.01, y0=y0, stepsize_controller=stepsize_controller)

nwlambert commented 11 months ago

Just to clarify: for the above example, the first snippet takes ~100 s on a GPU (a bit less on a CPU), while the second snippet, using diffeqsolve() directly, takes 114 ms. The same example with standard mesolve() and the normal CSR data layer takes < 600 microseconds.

There's probably some overhead on the GPU side, but this also scales very badly (increasing N quickly makes the first snippet unusable).
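When comparing timings like these, it may help to separate one-off compilation cost from steady-state run time. JAX compiles a jitted function on its first call and dispatches work asynchronously, so a wall-clock measurement is only meaningful after blocking on the result. The sketch below is a generic timing helper, not taken from qutip-jax; the 64x64 matrix is a hypothetical stand-in for a real Liouvillian, and `timed` is an illustrative name.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def rhs(L, y):
    # Same form as the right-hand side used in the manual snippet: L @ y
    return L @ y


def timed(fn, *args):
    """Run fn(*args), wait for the device to finish, return (result, seconds)."""
    start = time.perf_counter()
    out = jax.block_until_ready(fn(*args))
    return out, time.perf_counter() - start


# Hypothetical stand-in data (assumption): not a real Liouvillian
L = -jnp.eye(64)
y = jnp.ones((64, 1))

out, first_call = timed(rhs, L, y)   # includes tracing and compilation
out, second_call = timed(rhs, L, y)  # reuses the compiled function

print("devices:", jax.devices())
print(f"first call: {first_call:.4f} s, second call: {second_call:.6f} s")
```

Timing the first and second calls separately shows how much of a measurement is compilation rather than the solve itself.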

Ericgig commented 11 months ago

I did expect the diffrax solver to be slower than normal mesolve (on CPU), but not by that much...

On CPU, the first snippet takes 3.5 s on my computer, and using diffeqsolve directly takes 400 ms. Not as bad, but still not great.

When I use the same algorithm by adding "stepsize_controller": PIDController(rtol=1e-5, atol=1e-5), "solver": Dopri5() to the options, I get 500 ms using qutip, which is about the same. So we are not that bad on CPUs. (The defaults we use are "solver": diffrax.Tsit5(), "stepsize_controller": diffrax.ConstantStepSize().)

Can you try to profile on GPU to see why we are so inefficient? My guess is that we always compute a coefficient, even for a constant QobjEvo, with the diffrax method. This coefficient is a function returning a constant, but it's probably computed on the CPU...

PS: I am on the jaxdia branch.

nwlambert commented 11 months ago

Thanks Eric, that helps a lot! Playing around with combinations of options, it seems the step size is what was really slowing it down. E.g., running the native diffeqsolve() with the same constant step size is also very slow (though a little faster than qutip-jax, which could just be due to a different choice of dt0 in qutip-jax).

I didn't have much luck with the profiler (I will keep trying to get something useful out of it), but after some playing around, it's not clear to me whether there's really a problem. Some examples, with N=4 (to slow things down a bit), the Dopri5() solver, and PIDController for the step size:

1) Standard CSR qutip: 0.2 s
2) qutip-jax CPU: 14 s
3) qutip-jax GPU: 0.7 s
4) native jax diffeqsolve() CPU: 14 s
5) native jax diffeqsolve() GPU: 0.6 s

nwlambert commented 11 months ago

Just a quick addition: I gave the jaxdia branch a try, and this is super encouraging! With jaxdia we can really push up the Hilbert space size, and I see some pretty impressive numbers.

For N=10: standard qutip (CSR or Dia): 200 s; qutip-jaxdia on GPU: 5.3 s

For N=12: standard qutip (CSR or Dia): 824 s; qutip-jaxdia on GPU: 16 s

With standard qutip-jax I tend to run out of memory around N=5, so jaxdia really helps us see some crossover.

I will try to double-check that I am not messing something up, but this seems very impressive!

jan-o-e commented 7 months ago

Hey guys, interesting discussion here.

Is there a way to include time-dependent Hamiltonian parameters in mesolve with JAX? I just gave the jaxdia branch a try, with no luck.

Ericgig commented 7 months ago

The normal list format works if jitted functions are used: H = [H0, [H1, jax.jit(f)]].

jan-o-e commented 7 months ago

Thanks. Can you also pass arrays, with a length corresponding to the number of steps in the numerical solver, for the time-dependent parameters instead of a function?

This is option 3 described here: https://qutip.org/docs/latest/guide/dynamics/dynamics-time.html

Ericgig commented 7 months ago

No (not yet). That version uses scipy and cython, which don't mix well with jax. For now, you would have to write or find a spline function that supports jit. jax-cosmo seems to have something promising.
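As a rough interim option, sampled values can be wrapped in a jit-compatible function with jnp.interp. Note that this is piecewise-linear interpolation, not a cubic spline, so it is only a stand-in for the spline approach mentioned above. The time grid and the sampled values below are hypothetical. Whether qutip-jax accepts such a function unchanged in the H = [H0, [H1, f]] list format is an assumption that would need testing.

```python
import jax
import jax.numpy as jnp

# Hypothetical sampled data (assumption): a drive amplitude on a regular time grid
tlist = jnp.linspace(0.0, 10.0, 101)
values = jnp.sin(tlist)


@jax.jit
def f(t):
    # Piecewise-linear interpolation of the samples; not a cubic spline
    return jnp.interp(t, tlist, values)


print(f(0.0), f(2.55), f(10.0))
```

A finer time grid reduces the interpolation error, at the cost of a larger array captured by the jitted function.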