qutip / qutip-jax

JAX backend for QuTiP
BSD 3-Clause "New" or "Revised" License

qutip-jax and MEsolve incompatibility using jit #49

Closed: ArturDomingues closed this issue 1 month ago

ArturDomingues commented 1 month ago

I was porting the 020_homodyned-Jaynes-Cummings-emission.md file to tutorial-v5 and ran into a CI build failure in a file I hadn't even changed, JAX-backend.md. The CI build error was:

=================================== FAILURES ===================================
_ /home/runner/work/qutip-tutorials/qutip-tutorials/notebooks/miscellaneous/JAX_backend.ipynb _
---------------------------------------------------------------------------
@jax.jit
def fp(t, w):
    return jax.numpy.exp(1j * t * w)
@jax.jit
def fm(t, w):
    return jax.numpy.exp(-1j * t * w)
@jax.jit
def cte(t, A):
    return A
with qutip.CoreOptions(default_dtype="jax"):
    H = qutip.num(10)
    c_ops = [qutip.QobjEvo([qutip.destroy(10), fm], args={"w": 1.0})]
H.isherm  # Precomputing the `isherm` flag
solver = qutip.MESolver(
    H, c_ops, options={"method": "diffrax", "normalize_output": False}
)
def final_expect(solver, rho0, t, w):
    result = solver.run(rho0, [0, t], args={"w": w}, e_ops=H)
    return result.e_data[0][-1].real
dfinal_expect_dt = jax.jit(
    jax.grad(final_expect, argnums=[2]), static_argnames=["solver"]
)
dfinal_expect_dt(solver, qutip.basis(10, 8, dtype="jax"), 0.1, 1.0)
---------------------------------------------------------------------------
TracerArrayConversionError                Traceback (most recent call last)
Cell In[6], line 35
     29     return result.e_data[0][-1].real
     32 dfinal_expect_dt = jax.jit(
     33     jax.grad(final_expect, argnums=[2]), static_argnames=["solver"]
     34 )
---> 35 dfinal_expect_dt(solver,qutip.basis(10,8,dtype="jax"),0.1,1.0)
    [... skipping hidden 21 frame]
Cell In[6], line 28, in final_expect(solver, rho0, t, w)
     27 def final_expect(solver, rho0, t, w):
---> 28     result = solver.run(rho0,[0,t],args={"w":w},e_ops=H)
     29     return result.e_data[0][-1].real
File /usr/share/miniconda3/envs/test-environment/lib/python3.10/site-packages/qutip/solver/solver_base.py:176, in Solver.run(self, state0, tlist, e_ops, args)
    142 """
    143 Do the evolution of the Quantum system.
    144 
   (...)
    173     can control the saved data in the options.
    174 """
    175 _time_start = time()
--> 176 _data0 = self._prepare_state(state0)
    177 self._integrator.set_state(tlist[0], _data0)
    178 self._argument(args)
File /usr/share/miniconda3/envs/test-environment/lib/python3.10/site-packages/qutip/solver/solver_base.py:107, in Solver._prepare_state(self, state)
    104     norm = state.norm()
    105 # Use the settings atol instead of the solver one since the second
    106 # refer to the ODE tolerance and some integrator do not use it.
--> 107 self._normalized = np.abs(norm-1) <= settings.core["atol"]
    108 if self.rhs.dims[1] == state.dims:
    109     return stack_columns(state.data)
File /usr/share/miniconda3/envs/test-environment/lib/python3.10/site-packages/jax/_src/core.py:650, in Tracer.__array__(self, *args, **kw)
    649 def __array__(self, *args, **kw):
--> 650   raise TracerArrayConversionError(self)
TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float64[].
The error occurred while tracing the function final_expect at /tmp/ipykernel_7480/401557022.py:27 for jit. This value became a tracer due to JAX operations on these lines:
  operation a:f64[] = convert_element_type[new_dtype=float64 weak_type=False] b
    from line /tmp/ipykernel_7480/401557022.py:28 (final_expect)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
=========================== short test summary info ============================
FAILED miscellaneous/JAX_backend.ipynb:: - See https://jax.readthedocs.io/en/...
================== 1 failed, 74 passed in 2407.87s (0:40:07) ===================
Error: Process completed with exit code 1.
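The last frames of the log point at Solver._prepare_state, where np.abs(norm-1) is applied to the state norm. Under jax.jit that norm is an abstract tracer, and a NumPy function tries to convert its input to an ndarray through __array__, which JAX refuses. Below is a minimal standalone sketch of the same pattern, assuming only jax and numpy (no QuTiP); ATOL, check_with_numpy and check_with_jnp are hypothetical stand-ins, not qutip code.

```python
import numpy as np
import jax
import jax.numpy as jnp

ATOL = 1e-12  # hypothetical stand-in for qutip.settings.core["atol"]


def check_with_numpy(state):
    # Same pattern as the failing line: a NumPy function applied to a JAX value.
    norm = jnp.linalg.norm(state)
    return np.abs(norm - 1) <= ATOL


def check_with_jnp(state):
    # jax.numpy keeps the comparison inside the traced computation.
    norm = jnp.linalg.norm(state)
    return jnp.abs(norm - 1) <= ATOL


state = jnp.array([1.0, 0.0])

# Outside jit, `norm` is a concrete array, so NumPy can convert it.
print(check_with_numpy(state))

# Inside jit, `norm` is a tracer and np.abs triggers Tracer.__array__.
try:
    jax.jit(check_with_numpy)(state)
except jax.errors.TracerArrayConversionError as err:
    print("jit + numpy:", type(err).__name__)

# The jax.numpy version traces without error.
print(jax.jit(check_with_jnp)(state))
```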

After converting the file to a notebook and running it locally, I found that the jitted gradient call

dfinal_expect_dt(solver, qutip.basis(10, 8, dtype="jax"), 0.1, 1.0)

raised the following error:

---------------------------------------------------------------------------
TracerArrayConversionError                Traceback (most recent call last)
Cell In[22], line 37
     35 get_ipython().run_line_magic('timeit', 'final_expect(solver,qutip.basis(10, 8, dtype="jax"),0.1,1.0)')
     36 get_ipython().run_line_magic('timeit', 'jax.grad(final_expect, argnums=[2])(solver, qutip.basis(10, 8, dtype="jax"), 0.1, 1.0)')
---> 37 dfinal_expect_dt(solver, qutip.basis(10, 8, dtype="jax"), 0.1, 1.0)
     38 get_ipython().run_line_magic('timeit', 'dfinal_expect_dt(solver, qutip.basis(10, 8, dtype="jax"), 0.1, 1.0)')

    [... skipping hidden 21 frame]

Cell In[22], line 27
     26 def final_expect(solver, rho0, t, w):
---> 27     result = solver.run(rho0, [0, t], args={"w": w}, e_ops=H)
     28     return result.e_data[0][-1].real

File c:\Users\artur\anaconda3\envs\qutip-tutorials\lib\site-packages\qutip\solver\solver_base.py:163, in Solver.run(self, state0, tlist, args, e_ops)
    129 """
    130 Do the evolution of the Quantum system.
    131 
   (...)
    160     can control the saved data in the options.
    161 """
    162 _time_start = time()
--> 163 _data0 = self._prepare_state(state0)
    164 self._integrator.set_state(tlist[0], _data0)
    165 self._argument(args)

File c:\Users\artur\anaconda3\envs\qutip-tutorials\lib\site-packages\qutip\solver\solver_base.py:101, in Solver._prepare_state(self, state)
     98     norm = state.norm()
     99 # Use the settings atol instead of the solver one since the second
    100 # refer to the ODE tolerance and some integrator do not use it.
--> 101 self._normalized = np.abs(norm - 1) <= settings.core["atol"]
    102 if self.rhs.dims[1] == state.dims:
    103     return stack_columns(state.data)

File c:\Users\artur\anaconda3\envs\qutip-tutorials\lib\site-packages\jax\_src\core.py:650, in Tracer.__array__(self, *args, **kw)
    649 def __array__(self, *args, **kw):
--> 650   raise TracerArrayConversionError(self)

TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float64[].
The error occurred while tracing the function final_expect at C:\Users\artur\AppData\Local\Temp\ipykernel_4980\3189065131.py:26 for jit. This value became a tracer due to JAX operations on these lines:

  operation a:f64[] = convert_element_type[new_dtype=float64 weak_type=False] b
    from line C:\Users\artur\AppData\Local\Temp\ipykernel_4980\3189065131.py:27 (final_expect)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

I couldn't find a fix other than calling the gradient without jax.jit:

jax.grad(final_expect, argnums=[2])(solver, qutip.basis(10, 8, dtype="jax"), 0.1, 1.0)

Since this change made everything work, I included the modified file in the pull request as well, as follows:

# When qutip-jax is fixed, uncomment the line below and delete the line after it
# dfinal_expect_dt(solver, qutip.basis(10, 8, dtype="jax"), 0.1, 1.0)
jax.grad(final_expect, argnums=[2])(solver, qutip.basis(10, 8, dtype="jax"), 0.1, 1.0)

Even though this avoids the failure, it is only a workaround: the jitted version worked correctly before, so qutip-jax should be alerted to fix the underlying problem.
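A plausible explanation for why the un-jitted jax.grad call works (an interpretation of the tracebacks, not something confirmed in this thread): jax.grad(final_expect, argnums=[2]) only traces t, so rho0 and its norm stay concrete and NumPy can still convert them, whereas jax.jit traces every non-static argument, including the state. The sketch below reproduces that difference without QuTiP; final_value, ATOL and the toy objective are hypothetical.

```python
import numpy as np
import jax
import jax.numpy as jnp

ATOL = 1e-6  # hypothetical stand-in for qutip.settings.core["atol"]


def final_value(state, t):
    # Hypothetical stand-in for `final_expect`: the state's norm is checked
    # with NumPy, mirroring the failing line in Solver._prepare_state.
    norm = jnp.linalg.norm(state)
    normalized = np.abs(norm - 1) <= ATOL  # needs a concrete `norm`
    scale = 1.0 if normalized else 1.0 / norm**2
    # Toy "expectation value" that depends smoothly on t.
    return scale * t**2 * jnp.sum(jnp.abs(state) ** 2)


state = jnp.array([0.6, 0.8])  # already normalized

# grad alone traces only argument 1 (t); `state` stays concrete.
grad_t = jax.grad(final_value, argnums=1)
print(grad_t(state, 3.0))  # approximately 2 * t = 6.0

# jit traces every non-static argument, so `norm` becomes a tracer.
try:
    jax.jit(grad_t)(state, 3.0)
except jax.errors.TracerArrayConversionError as err:
    print("jit(grad):", type(err).__name__)
```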

Originally posted by @ArturDomingues in https://github.com/qutip/qutip-tutorials/issues/100