qutip / qutip-tutorials

QuTiP Tutorials
BSD 3-Clause "New" or "Revised" License

Port "Steady-State: Homodyned Jaynes-Cummings emission" tutorial to QuTiP 5. Closes #91 #98

ArturDomingues closed this 3 months ago

ArturDomingues commented 4 months ago

Closes #91. I am participating in UnitaryHack2024 and found https://github.com/qutip/qutip-tutorials/issues/91 among the proposed challenges. I updated both mesolve calls that used Options, since it is being deprecated, changing options=Options(old_args) to options={args}, where old_args were the keyword arguments passed to Options() and args are the equivalent key-value pairs accepted by options. I replaced parallel_map with a list comprehension, since it was breaking when running the notebooks. I also changed the inline HTML link to "An architecture for self-homodyned nonclassical light" into a proper markdown link ([An architecture for self-homodyned nonclassical light](https://arxiv.org/abs/1611.01566)).
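A minimal, self-contained sketch of the two migration patterns described above: a plain options dictionary instead of Options, and a list comprehension instead of parallel_map. It runs without QuTiP: solve_point is a hypothetical stand-in for a solver call, the option names and values are illustrative rather than the notebook's actual ones, and the standard library's ThreadPoolExecutor.map plays the role of a parallel map.

```python
from concurrent.futures import ThreadPoolExecutor

# QuTiP 4.x style (deprecated): mesolve(..., options=Options(nsteps=5000, atol=1e-8))
# QuTiP 5.x style:              mesolve(..., options={"nsteps": 5000, "atol": 1e-8})
# The option names and values here are illustrative only.
OPTIONS = {"nsteps": 5000, "atol": 1e-8}


def solve_point(theta, options):
    """Hypothetical stand-in for one solver run at parameter value `theta`.

    A real version would call qutip.mesolve(..., options=options) and return
    some quantity derived from the result; this one returns a cheap,
    deterministic number so the two sweeps below can be compared.
    """
    return theta**2 + options["atol"]


thetas = [0.1 * k for k in range(10)]

# Serial sweep with a list comprehension, the replacement used in this PR.
serial = [solve_point(theta, OPTIONS) for theta in thetas]

# Parallel sweep for comparison; ThreadPoolExecutor.map yields results in input order.
with ThreadPoolExecutor() as pool:
    parallel = list(pool.map(lambda theta: solve_point(theta, OPTIONS), thetas))

assert serial == parallel
```

Since both sweeps return results in the order of the input parameters, replacing a parallel map with a list comprehension should change only the run time, not the values the notebook computes.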

hodgestar commented 4 months ago

Thanks @ArturDomingues. Could you add a Testing section at the end and write some assertions in it that check that the notebook produces sensible results?

I've activated CI. Please address any errors it reports once the run finishes.

ArturDomingues commented 4 months ago

Some of the files breaking these CI builds are ones I didn't even change; what is going on? One other thing: I edited the Jaynes-Cummings file and added it directly to my GitHub fork. Now I have cloned the fork and created a venv to run everything. The environment.yml file pins qutip 4.7.0, and I changed it before running so that 5.0.2 is installed in the environment. Should I leave this change out when I commit and push the fixes?

ArturDomingues commented 4 months ago

Another question about this Testing section: should I assert that the g values respect the relations presented in the text? If so, there are some problems: some g20 values are greater than one, although the text addresses that. What would be some reasonable tests to apply?
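A sketch of what such a Testing section could check, using hypothetical placeholder numbers rather than the notebook's results (the list g2_values and the reference value 0.42 are assumptions). Rather than asserting that every g20 value is below one, which the text itself says is not the case, it asserts properties any computed g2(0) must have, plus agreement with a value recorded from a trusted earlier run.

```python
import math


def check_g2_values(g2_values, atol=1e-9):
    """Checks that hold for any second-order correlation value g2(0)."""
    for g2 in g2_values:
        # NaN or inf usually signals a solver or normalization problem.
        assert math.isfinite(g2), f"non-finite g2(0): {g2}"
        # For a mode operator b, g2(0) = <b†b†bb> / <b†b>^2 is a ratio of
        # non-negative quantities, so it can only dip below zero numerically.
        assert g2 >= -atol, f"negative g2(0): {g2}"


def check_against_reference(value, reference, rel_tol=1e-3):
    """Compare a computed value with one recorded from a trusted earlier run."""
    assert math.isclose(value, reference, rel_tol=rel_tol), (value, reference)


# Hypothetical placeholder results; the notebook would pass its own arrays here.
g2_values = [0.42, 0.97, 1.35, 2.0]
check_g2_values(g2_values)
check_against_reference(g2_values[0], 0.42)
```

Checks like these catch broken runs (NaN, negative values, silent changes in results) without encoding claims the text explicitly qualifies, such as g2(0) < 1 for every setting.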

ArturDomingues commented 4 months ago

I wrote some tests, but I don't know if they are the most appropriate ones, so please check that the Testing section is what you wanted. I also made sure the formatting of the .md file was correct by running:

nbqa black 020_homodyned-Jaynes-Cummings-emission.md

ArturDomingues commented 4 months ago

There was a failure in CI, but in a file I didn't change. Here is the report:

=================================== FAILURES ===================================
_ /home/runner/work/qutip-tutorials/qutip-tutorials/notebooks/miscellaneous/JAX_backend.ipynb _
---------------------------------------------------------------------------
@jax.jit
def fp(t, w):
    return jax.numpy.exp(1j * t * w)
@jax.jit
def fm(t, w):
    return jax.numpy.exp(-1j * t * w)
@jax.jit
def cte(t, A):
    return A
with qutip.CoreOptions(default_dtype="jax"):
    H = qutip.num(10)
    c_ops = [qutip.QobjEvo([qutip.destroy(10), fm], args={"w": 1.0})]
H.isherm  # Precomputing the `isherm` flag
solver = qutip.MESolver(
    H, c_ops, options={"method": "diffrax", "normalize_output": False}
)
def final_expect(solver, rho0, t, w):
    result = solver.run(rho0, [0, t], args={"w": w}, e_ops=H)
    return result.e_data[0][-1].real
dfinal_expect_dt = jax.jit(
    jax.grad(final_expect, argnums=[2]), static_argnames=["solver"]
)
dfinal_expect_dt(solver, qutip.basis(10, 8, dtype="jax"), 0.1, 1.0)
---------------------------------------------------------------------------
TracerArrayConversionError                Traceback (most recent call last)
Cell In[6], line 35
     29     return result.e_data[0][-1].real
     32 dfinal_expect_dt = jax.jit(
     33     jax.grad(final_expect, argnums=[2]), static_argnames=["solver"]
     34 )
---> 35 dfinal_expect_dt(solver,qutip.basis(10,8,dtype="jax"),0.1,1.0)
    [... skipping hidden 21 frame]
Cell In[6], line 28, in final_expect(solver, rho0, t, w)
     27 def final_expect(solver, rho0, t, w):
---> 28     result = solver.run(rho0,[0,t],args={"w":w},e_ops=H)
     29     return result.e_data[0][-1].real
File /usr/share/miniconda3/envs/test-environment/lib/python3.10/site-packages/qutip/solver/solver_base.py:176, in Solver.run(self, state0, tlist, e_ops, args)
    142 """
    143 Do the evolution of the Quantum system.
    144 
   (...)
    173     can control the saved data in the options.
    174 """
    175 _time_start = time()
--> 176 _data0 = self._prepare_state(state0)
    177 self._integrator.set_state(tlist[0], _data0)
    178 self._argument(args)
File /usr/share/miniconda3/envs/test-environment/lib/python3.10/site-packages/qutip/solver/solver_base.py:107, in Solver._prepare_state(self, state)
    104     norm = state.norm()
    105 # Use the settings atol instead of the solver one since the second
    106 # refer to the ODE tolerance and some integrator do not use it.
--> 107 self._normalized = np.abs(norm-1) <= settings.core["atol"]
    108 if self.rhs.dims[1] == state.dims:
    109     return stack_columns(state.data)
File /usr/share/miniconda3/envs/test-environment/lib/python3.10/site-packages/jax/_src/core.py:650, in Tracer.__array__(self, *args, **kw)
    649 def __array__(self, *args, **kw):
--> 650   raise TracerArrayConversionError(self)
TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float64[].
The error occurred while tracing the function final_expect at /tmp/ipykernel_7480/401557022.py:27 for jit. This value became a tracer due to JAX operations on these lines:
  operation a:f64[] = convert_element_type[new_dtype=float64 weak_type=False] b
    from line /tmp/ipykernel_7480/401557022.py:28 (final_expect)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
=========================== short test summary info ============================
FAILED miscellaneous/JAX_backend.ipynb:: - See https://jax.readthedocs.io/en/...
================== 1 failed, 74 passed in 2407.87s (0:40:07) ===================
Error: Process completed with exit code 1.

ArturDomingues commented 4 months ago

I was trying to fix the .md file of the JAX notebook but ran into a problem; there may be a compatibility issue in the qutip-jax library, or between JAX and QuTiP running together. At the end of the notebook I ran the following lines:

print(final_expect(solver,qutip.basis(10, 8, dtype="jax"),0.1,1.0))
print(jax.grad(final_expect, argnums=[2])(solver, qutip.basis(10, 8, dtype="jax"), 0.1, 1.0))
%timeit final_expect(solver,qutip.basis(10, 8, dtype="jax"),0.1,1.0)
%timeit jax.grad(final_expect, argnums=[2])(solver, qutip.basis(10, 8, dtype="jax"), 0.1, 1.0)
%timeit dfinal_expect_dt(solver, qutip.basis(10, 8, dtype="jax"), 0.1, 1.0)

and got these outputs:

c:\Users\artur\anaconda3\envs\qutip-tutorials\lib\site-packages\equinox\_jit.py:49: UserWarning: Complex dtype support is work in progress, please read https://github.com/patrick-kidger/diffrax/pull/197 and proceed carefully.
  out = fun(*args, **kwargs)
7.176200305515273
c:\Users\artur\anaconda3\envs\qutip-tutorials\lib\site-packages\equinox\_jit.py:49: UserWarning: Complex dtype support is work in progress, please read https://github.com/patrick-kidger/diffrax/pull/197 and proceed carefully.
  out = fun(*args, **kwargs)
(Array(-9.6348726, dtype=float64, weak_type=True),)
55.1 ms ± 3.59 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
228 ms ± 21.2 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

and this as an error

---------------------------------------------------------------------------
TracerArrayConversionError                Traceback (most recent call last)
Cell In[22], line 37
     35 get_ipython().run_line_magic('timeit', 'final_expect(solver,qutip.basis(10, 8, dtype="jax"),0.1,1.0)')
     36 get_ipython().run_line_magic('timeit', 'jax.grad(final_expect, argnums=[2])(solver, qutip.basis(10, 8, dtype="jax"), 0.1, 1.0)')
---> 37 dfinal_expect_dt(solver, qutip.basis(10, 8, dtype="jax"), 0.1, 1.0)
     38 get_ipython().run_line_magic('timeit', 'dfinal_expect_dt(solver, qutip.basis(10, 8, dtype="jax"), 0.1, 1.0)')

    [... skipping hidden 21 frame]

Cell In[22], line 27
     26 def final_expect(solver, rho0, t, w):
---> 27     result = solver.run(rho0, [0, t], args={"w": w}, e_ops=H)
     28     return result.e_data[0][-1].real

File c:\Users\artur\anaconda3\envs\qutip-tutorials\lib\site-packages\qutip\solver\solver_base.py:163, in Solver.run(self, state0, tlist, args, e_ops)
    129 """
    130 Do the evolution of the Quantum system.
    131 
   (...)
    160     can control the saved data in the options.
    161 """
    162 _time_start = time()
--> 163 _data0 = self._prepare_state(state0)
    164 self._integrator.set_state(tlist[0], _data0)
    165 self._argument(args)

File c:\Users\artur\anaconda3\envs\qutip-tutorials\lib\site-packages\qutip\solver\solver_base.py:101, in Solver._prepare_state(self, state)
     98     norm = state.norm()
     99 # Use the settings atol instead of the solver one since the second
    100 # refer to the ODE tolerance and some integrator do not use it.
--> 101 self._normalized = np.abs(norm - 1) <= settings.core["atol"]
    102 if self.rhs.dims[1] == state.dims:
    103     return stack_columns(state.data)

File c:\Users\artur\anaconda3\envs\qutip-tutorials\lib\site-packages\jax\_src\core.py:650, in Tracer.__array__(self, *args, **kw)
    649 def __array__(self, *args, **kw):
--> 650   raise TracerArrayConversionError(self)

TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float64[].
The error occurred while tracing the function final_expect at C:\Users\artur\AppData\Local\Temp\ipykernel_4980\3189065131.py:26 for jit. This value became a tracer due to JAX operations on these lines:

  operation a:f64[] = convert_element_type[new_dtype=float64 weak_type=False] b
    from line C:\Users\artur\AppData\Local\Temp\ipykernel_4980\3189065131.py:27 (final_expect)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

So the first four lines run fine, and the second one gives the same answer as the tutorial. The fifth line, however, which is the same dfinal_expect_dt call as in the tutorial, fails, and I don't know why.

PS: It seems that the "normalize_output": False option is not enough: the initial state is still checked for normalization in Solver._prepare_state, and that check fails because np.abs tries to convert the traced JAX value into a NumPy array.
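The mechanism described in the PS can be sketched as a toy model in plain Python, without JAX or QuTiP. AbstractScalar is a hypothetical stand-in for a JAX tracer: arithmetic on it stays symbolic, but forcing it into a concrete number raises, much as Tracer.__array__ raises TracerArrayConversionError in the traceback. prepare_state_check mimics the np.abs(norm - 1) <= settings.core["atol"] line of Solver._prepare_state; all names here are illustrative, not QuTiP's or JAX's actual code.

```python
class ConcretizationError(TypeError):
    """Raised when a toy traced value is forced into a concrete number."""


class AbstractScalar:
    """Toy stand-in for a traced value: arithmetic stays symbolic, but there
    is no concrete number behind it (loosely like a JAX tracer under jit)."""

    def __sub__(self, other):
        return AbstractScalar()

    def __abs__(self):
        return AbstractScalar()

    def __float__(self):
        raise ConcretizationError("cannot convert a traced value to a concrete number")


def prepare_state_check(norm, atol=1e-12):
    # Mimics `np.abs(norm - 1) <= atol`: the subtraction and abs stay symbolic,
    # but the NumPy-style comparison needs a concrete number, forced here by float().
    return float(abs(norm - 1)) <= atol


def prepare_state_check_tolerant(norm, atol=1e-12):
    # One possible workaround in this toy model (not necessarily QuTiP's fix):
    # report the normalization status as unknown while tracing.
    diff = abs(norm - 1)
    if isinstance(diff, AbstractScalar):
        return None
    return float(diff) <= atol


print(prepare_state_check(1.0))  # concrete input: the check works
try:
    prepare_state_check(AbstractScalar())
except ConcretizationError as err:
    print("tracing failed:", err)
print(prepare_state_check_tolerant(AbstractScalar()))
```

In the lines shown in the traceback, this normalization check is not guarded by the "normalize_output" option, which would explain why setting that option alone did not avoid the error.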

ArturDomingues commented 4 months ago

I made some changes to the JAX_backend.md file. If you find another solution, please change it.

hodgestar commented 4 months ago

@ArturDomingues Commenting out the line causing the JAX issue was the right thing to do for now. There were some other small CI issues. I've approved the CI workflows to run again.

ArturDomingues commented 4 months ago

As far as I can see, there are no CI build errors. What else needs to be done?

ArturDomingues commented 4 months ago

Hey @hodgestar, is there anything more I need to do, or is the pull request ready to be accepted as it is?

ArturDomingues commented 4 months ago

I just saw that https://github.com/qutip/qutip/pull/2448 was merged into qutip. Should I revert the JAX_backend.md file to its original form?

ArturDomingues commented 4 months ago

Is there anything more to be done, @hodgestar?

ArturDomingues commented 3 months ago

@hodgestar if there are no other changes to be made, could you accept the pull request?