qutip / qutip-jax

JAX backend for QuTiP
BSD 3-Clause "New" or "Revised" License

failing to jit mesolve for td ham parameters #35

Open jan-o-e opened 9 months ago

jan-o-e commented 9 months ago

Following the end of the discussion in https://github.com/qutip/qutip-jax/issues/26, I am working on the qutip-jax-dia branch and tried to implement cubic splines from jax_cosmo for a time-dependent Hamiltonian simulation (master equation with static collapse operators), which can be massively sped up with JAX when re-running the same simulation with different time-dependent parameters. I can run the simulation on a CPU without jitting the function sim(), but it's very slow (as expected).

I know qutip-jax-dia is still very much in beta, but maybe some of the clever people here have suggestions as to why I can't jit the function.

For reference, I'm working on osx-arm64 with an M1 chip.

The error message I get is:

Traceback (most recent call last):
  File "/Users/janoleernst/Desktop/DPhil/Simulations/Code/rl-qc/qu_sim_speed_benchmarks/jitting_qutip_sim.py", line 105, in <module>
    ys = fast_sim(single_sample)
  File "/Users/janoleernst/Desktop/DPhil/Simulations/Code/rl-qc/qu_sim_speed_benchmarks/jitting_qutip_sim.py", line 65, in sim
    result=qt.mesolve(H=[[H_p, jax.jit(pump)], [H_s, jax.jit(stokes)], [H_dp, jax.jit(delta_pump)], [H_d2, jax.jit(delta_stokes)]],
  File "/Users/janoleernst/anaconda3/envs/qutip-jax/lib/python3.10/site-packages/qutip/solver/mesolve.py", line 128, in mesolve
    H = QobjEvo(H, args=args, tlist=tlist)
  File "qutip/core/cy/qobjevo.pyx", line 242, in qutip.core.cy.qobjevo.QobjEvo.__init__
  File "qutip/core/cy/qobjevo.pyx", line 807, in qutip.core.cy.qobjevo.QobjEvo.compress
  File "qutip/core/cy/qobjevo.pyx", line 761, in qutip.core.cy.qobjevo.QobjEvo._compress_merge_qobj
jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function sim at /Users/janoleernst/Desktop/DPhil/Simulations/Code/rl-qc/qu_sim_speed_benchmarks/jitting_qutip_sim.py:42 for jit. This value became a tracer due to JAX operations on these lines:

  operation a:c128[2,4] = pjit[name=atleast_2d jaxpr={ lambda ; b:c128[2,4]. let  in (b,) }] c
    from line /Users/janoleernst/Desktop/DPhil/Simulations/Code/rl-qc/qu_sim_speed_benchmarks/jitting_qutip_sim.py:65 (sim)

  operation a:c128[2,4] = pjit[name=atleast_2d jaxpr={ lambda ; b:c128[2,4]. let  in (b,) }] c
    from line /Users/janoleernst/Desktop/DPhil/Simulations/Code/rl-qc/qu_sim_speed_benchmarks/jitting_qutip_sim.py:65 (sim)

  operation a:bool[] = pjit[
  name=allclose
  jaxpr={ lambda ; b:c128[2,4] c:f64[] d:f64[]. let
      e:bool[2,4] = pjit[
        name=isclose
        jaxpr={ lambda ; f:c128[2,4] g:f64[] h:f64[] i:f64[]. let
            j:c128[] = convert_element_type[
              new_dtype=complex128
              weak_type=False
            ] g
            k:f64[] = convert_element_type[new_dtype=float64 weak_type=False] h
            l:f64[] = convert_element_type[new_dtype=float64 weak_type=False] i
            m:c128[2,4] = sub f j
            n:f64[2,4] = abs m
            o:f64[] = abs j
            p:f64[] = mul k o
            q:f64[] = add l p
            r:bool[2,4] = le n q
            s:bool[2,4] = pjit[
              name=isinf
              jaxpr={ lambda ; t:c128[2,4]. let
                  u:f64[2,4] = real t
                  v:f64[2,4] = imag t
                  w:f64[2,4] = abs u
                  x:bool[2,4] = eq w inf
                  y:f64[2,4] = abs v
                  z:bool[2,4] = eq y inf
                  ba:bool[2,4] = or x z
                in (ba,) }
            ] f
            bb:bool[] = pjit[
              name=isinf
              jaxpr={ lambda ; bc:c128[]. let
                  bd:f64[] = real bc
                  be:f64[] = imag bc
                  bf:f64[] = abs bd
                  bg:bool[] = eq bf inf
                  bh:f64[] = abs be
                  bi:bool[] = eq bh inf
                  bj:bool[] = or bg bi
                in (bj,) }
            ] j
            bk:bool[2,4] = or s bb
            bl:bool[2,4] = and s bb
            bm:bool[2,4] = not bk
            bn:bool[2,4] = and r bm
            bo:bool[2,4] = eq f j
            bp:bool[2,4] = and bl bo
            bq:bool[2,4] = or bn bp
            br:bool[2,4] = ne f f
            bs:bool[] = ne j j
            bt:bool[2,4] = or br bs
            bu:bool[2,4] = not bt
            bv:bool[2,4] = and bq bu
          in (bv,) }
      ] b c 1e-05 d
      bw:bool[] = reduce_and[axes=(0, 1)] e
    in (bw,) }
] bx by bz
    from line /Users/janoleernst/Desktop/DPhil/Simulations/Code/rl-qc/qu_sim_speed_benchmarks/jitting_qutip_sim.py:65 (sim)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
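The last QuTiP frame is `QobjEvo._compress_merge_qobj`, and the jaxpr above shows a traced `allclose` result; this suggests the result is used in a Python conditional while the `QobjEvo` is being built. A minimal pure-JAX sketch of that failure mode (`merge_if_close` and `merge_traceable` are hypothetical names, not QuTiP code):

```python
import jax
import jax.numpy as jnp


def merge_if_close(a, b):
    # A Python `if` needs a concrete bool. Eagerly, jnp.allclose returns a
    # concrete value; under jit it returns a tracer, and `if` fails on it.
    if jnp.allclose(a, b):
        return a
    return a + b


x = jnp.ones((2, 4))
eager_result = merge_if_close(x, x)  # works: no tracing involved

try:
    jax.jit(merge_if_close)(x, x)
    jit_raised = False
except jax.errors.TracerBoolConversionError:
    jit_raised = True


@jax.jit
def merge_traceable(a, b):
    # jnp.where selects on a traced bool without Python control flow
    return jnp.where(jnp.allclose(a, b), a, a + b)
```

`jnp.where` is the general JAX workaround for data-dependent choices in your own code; checks buried inside library setup code cannot be rewritten this way, so that setup has to run outside the traced function.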

Full code is as follows (note that I am just using some constant functions to test the whole thing):

import qutip as qt
import time
import numpy as np
import qutip_jax
import jax
import jax.numpy as jnp
import jax_cosmo.scipy.interpolate as inter
from diffrax import diffeqsolve, ODETerm, Dopri5, PIDController
from scipy.signal.windows import blackman

with qt.CoreOptions(default_dtype="jaxdia"):
    #initial state of the system
    rho0=qt.fock_dm(4,0)
    #simulation params
    n_steps = 10
    T=1
    resolution=10
    omega_0=30
    Omega_0 = omega_0
    gamma = 1.
    #Pump and Stokes Hamiltonians
    H_p = 0.5 * (qt.projection(4,1,0) + qt.projection(4,0,1))
    H_s = 0.5 * (qt.projection(4,2,1) + qt.projection(4,1,2))
    # Detuning Hamiltonian
    H_dp = qt.projection(4,1,1)
    H_d2 = qt.projection(4,2,2)

    time_list=jnp.linspace(0.,T, resolution*n_steps, dtype=jnp.float64)
    delta=0.0*omega_0
    Delta=0.0*omega_0
    #delta=Omega_0*0.15
    #Delta=-23.5*delta #uncomment if you want to introduce a bias
    H_d = Delta * qt.projection(4,1,1) + delta * qt.projection(4,2,2)
    #Artificial environment Lindblad operator:
    L = jnp.sqrt(gamma) * qt.projection(4,3,1)

    # Define  function that takes input arrays and transforms them into cubic splines
    def cubic_spline(input_array):
        return inter.InterpolatedUnivariateSpline(time_list, input_array)

def sim(single_action_sample):

    # Carries out a full episode of system dynamics for a single action sample within a batch
    delta_p, delta_2, omega_p, omega_s = single_action_sample

    # Define the control signals
    pump = cubic_spline(omega_p)
    stokes = cubic_spline(omega_s)
    delta_pump = cubic_spline(delta_p)
    delta_stokes = cubic_spline(delta_2)

    # Define mesolve options
    options = {
    "method": "diffrax", 
    "normalize_output": True, 
    "stepsize_controller" : PIDController(rtol=1e-5, atol=1e-5), 
    "solver": Dopri5()
    }  
    #start_time=time.time()
    result=qt.mesolve(H=[[H_p, jax.jit(pump)], [H_s, jax.jit(stokes)], [H_dp, jax.jit(delta_pump)], [H_d2, jax.jit(delta_stokes)]], 
                        rho0=rho0, tlist=time_list, c_ops=L, options=options)
    #final_time=time.time()
    #print(f"Time taken for single step: {final_time-start_time}")
    #reward is expectation value of final desired state
    reward = result.states[-1][2,2]

    return reward

# Test the environment
if __name__ == "__main__":
    # Test the step function  
    amp_stokes=jnp.array(50*blackman(100), dtype=jnp.complex128)
    amp_pump=jnp.array(50*blackman(100), dtype=jnp.complex128)
    det_stokes=jnp.array(-10*blackman(100), dtype=jnp.complex128)
    det_pump=jnp.zeros(100, dtype=jnp.complex128)  # zero pump detuning, same length as the other arrays
    single_sample = jnp.array(
        [det_pump, det_stokes,amp_pump, amp_stokes], dtype=jnp.complex128
    )

    #Test the sim function
    start=time.time()
    res=sim(single_sample)
    print(f"Time taken: {time.time()-start} ")
    fast_sim=jax.jit(sim) 
    start=time.time()
    ys = fast_sim(single_sample)
    print(f"Time taken jitted: {time.time()-start} ")
    print(ys)

Many thanks!
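The timing at the end of the script measures the first `fast_sim` call, which includes tracing and XLA compilation, and JAX dispatches work asynchronously, so a timer can stop before the result is ready. A generic sketch of timing the first and a cached call separately (`work` is a placeholder function, not the simulation):

```python
import time

import jax
import jax.numpy as jnp

n_traces = 0


def work(x):
    global n_traces
    n_traces += 1  # Python side effect: runs only while JAX traces the function
    return jnp.sum(jnp.sin(x) ** 2)


fast = jax.jit(work)
x = jnp.linspace(0.0, 1.0, 1000)

t0 = time.perf_counter()
fast(x).block_until_ready()  # first call: trace + compile + run
t_first = time.perf_counter() - t0

t0 = time.perf_counter()
fast(2.0 * x).block_until_ready()  # new values, same shape/dtype: cached executable
t_cached = time.perf_counter() - t0

print(f"first call {t_first:.4f}s, cached call {t_cached:.4f}s, traces: {n_traces}")
```

Only the cached timing reflects the speed-up you get when re-running with different parameters of the same shape.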

Ericgig commented 9 months ago

mesolve does not support jit. It does safety checks, manages metadata, uses Cython, etc., none of which works well inside jit. That's why most examples split the setup (solver = MESolver(...)) from the computation (solver.run), and only the latter is inside the jit-compiled function.
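The setup/run split can be illustrated without QuTiP. The sketch below is a hypothetical toy solver: the class name, the fixed-step RK4 integrator, and the linear interpolation standing in for a cubic spline are all assumptions, not qutip-jax code. The constructor does Python-side validation once, and only `run`, which is pure JAX, is called inside `jax.jit`, with the drive samples passed in as an argument.

```python
import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_enable_x64", True)  # complex128 / float64, as in the issue


class ToyMESolver:
    """Hypothetical stand-in illustrating a setup/run split (not QuTiP's MESolver)."""

    def __init__(self, H0, H1, c_op):
        # Setup: ordinary Python/NumPy checks, run once outside jit.
        H0, H1, c_op = (np.asarray(m, dtype=np.complex128) for m in (H0, H1, c_op))
        for H in (H0, H1):
            if not np.allclose(H, H.conj().T):  # fine here: values are concrete
                raise ValueError("Hamiltonian terms must be Hermitian")
        self.H0, self.H1, self.L = map(jnp.asarray, (H0, H1, c_op))

    def _rhs(self, rho, t, tlist, amps):
        # H(t) = H0 + f(t) H1, with f linearly interpolated from `amps`
        H = self.H0 + jnp.interp(t, tlist, amps) * self.H1
        L, Ld = self.L, self.L.conj().T
        LdL = Ld @ L
        # Lindblad master equation with a single collapse operator
        return -1j * (H @ rho - rho @ H) + L @ rho @ Ld - 0.5 * (LdL @ rho + rho @ LdL)

    def run(self, rho0, tlist, amps):
        # Computation: pure JAX (fixed-step RK4 on the tlist grid), safe inside jit.
        def step(rho, i):
            t, dt = tlist[i], tlist[i + 1] - tlist[i]
            k1 = self._rhs(rho, t, tlist, amps)
            k2 = self._rhs(rho + 0.5 * dt * k1, t + 0.5 * dt, tlist, amps)
            k3 = self._rhs(rho + 0.5 * dt * k2, t + 0.5 * dt, tlist, amps)
            k4 = self._rhs(rho + dt * k3, t + dt, tlist, amps)
            return rho + dt / 6.0 * (k1 + 2 * k2 + 2 * k3 + k4), None

        rho_final, _ = jax.lax.scan(step, rho0, jnp.arange(tlist.shape[0] - 1))
        return rho_final


# Setup once, outside jit: two-level system driven by (f(t)/2) sigma_x, no decay.
sigma_x = np.array([[0, 1], [1, 0]])
solver = ToyMESolver(np.zeros((2, 2)), 0.5 * sigma_x, np.zeros((2, 2)))
tlist = jnp.linspace(0.0, 1.0, 201)
rho0 = jnp.array([[1, 0], [0, 0]], dtype=jnp.complex128)


@jax.jit
def excited_population(amps):
    # Only the numerical run is traced; `amps` plays the role of the TD args.
    return jnp.real(solver.run(rho0, tlist, amps)[1, 1])


p = excited_population(jnp.full(tlist.shape, jnp.pi))  # pi pulse
```

Calling `excited_population` again with a different `amps` array of the same shape reuses the compiled function, which is the re-run speed-up the original question is after.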

With spline coefficients, separating setup from computation is harder if you want to reuse the solver: you would need to pass the InterpolatedUnivariateSpline as args.
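Passing an interpolant into jit-compiled code as an argument only works if JAX can treat the object as a pytree whose leaves are arrays. A minimal sketch with a hypothetical `LinearInterp` class (linear interpolation swapped in for a cubic spline; whether jax_cosmo's `InterpolatedUnivariateSpline` is already registered as a pytree is an assumption worth checking in its source):

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class

n_traces = 0


@register_pytree_node_class
class LinearInterp:
    """Hypothetical interpolant (linear, not cubic) that can cross a jit boundary."""

    def __init__(self, x, y):
        self.x, self.y = x, y

    def __call__(self, t):
        return jnp.interp(t, self.x, self.y)

    def tree_flatten(self):
        # Knots and values are array leaves; there is no static aux data.
        return (self.x, self.y), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


@jax.jit
def pulse_area(coeff, tlist):
    global n_traces
    n_traces += 1  # counts traces, not calls
    vals = coeff(tlist)
    # Trapezoidal rule on the tlist grid
    return jnp.sum(0.5 * (vals[1:] + vals[:-1]) * jnp.diff(tlist))


tlist = jnp.linspace(0.0, 1.0, 101)
area_const = pulse_area(LinearInterp(tlist, jnp.ones_like(tlist)), tlist)
area_ramp = pulse_area(LinearInterp(tlist, 2.0 * tlist), tlist)
```

Both calls share one compiled function because the two interpolants have the same pytree structure and leaf shapes; only the array values differ.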

jan-o-e commented 9 months ago

Ah, got it, thanks. Yeah, I was thinking of passing the spline as an arg.