qutip / qutip-jax

JAX backend for QuTiP
BSD 3-Clause "New" or "Revised" License

Higher level classes with jax #14

Closed: Ericgig closed this pull request 1 year ago

Ericgig commented 1 year ago

Working on an Integrator built with jax, diffrax and equinox.

Ericgig commented 1 year ago

jax.grad now works through the solver:

import qutip_jax  # registers the "jax" data layer used by .to("jax")
import qutip as qt
from qutip.solver.sesolve import sesolve, SeSolver  # SeSolver was later renamed SESolver
import jax

# JIT-compiled time-dependent coefficients
@jax.jit
def fp(t, w):
    return jax.numpy.exp(1j * t * w)

@jax.jit
def fm(t, w):
    return jax.numpy.exp(-1j * t * w)

# H(t) = n + a exp(i w t) + a^dag exp(-i w t), stored as JAX arrays
H = (
    qt.num(2)
    + qt.destroy(2) * qt.coefficient(fp, args={"w": 3.1415})
    + qt.create(2) * qt.coefficient(fm, args={"w": 3.1415})
).to("jax")

ket = qt.basis(2, 1)

solver = SeSolver(H, options={"method": "diffrax"})

# <n> at t = 1 as a function of the coefficient argument w
def f(w, solver):
    result = solver.run(ket, [0, 1], e_ops=qt.num(2).to("jax"), args={"w": w})
    return result.expect[0][-1].real

%time jax.grad(f)(0.5, solver)
100.0%. Run time:   1.64s. Est. time left: 00:00:00:00
Total run time:   1.65s
CPU times: user 7.08 s, sys: 37.2 ms, total: 7.11 s
Wall time: 7.03 s

DeviceArray(-0.06224748, dtype=float64, weak_type=True)
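To see what jax.grad(f) is differentiating, here is a plain-JAX sketch of the same problem. It assumes a hand-written fixed-step RK4 integrator in place of the diffrax-based solver (this is not how qutip-jax implements it), builds the same two-level Hamiltonian, evolves basis(2, 1) to t = 1, and differentiates <n>(t = 1) with respect to w. As a cross-check it compares against a closed form: in the frame rotating with exp(i w t n), H becomes the time-independent (1 - w) n + a + a^dag, which gives a Rabi formula with detuning (1 - w) / 2. The names final_population and final_population_exact are illustrative. For this Hamiltonian the gradient should come out close to the values printed in the thread (about -0.0622 at w = 0.5 and -0.0969 at w = 0.2).

```python
import jax
import jax.numpy as jnp

# Double precision so the gradient can be compared tightly with the closed form.
jax.config.update("jax_enable_x64", True)

# Two-level operators, analogous to qt.destroy(2), qt.create(2) and qt.num(2).
a = jnp.array([[0.0, 1.0], [0.0, 0.0]], dtype=jnp.complex128)
adag = a.conj().T
num = adag @ a


def hamiltonian(t, w):
    # H(t) = n + a exp(i w t) + a^dag exp(-i w t)
    return num + a * jnp.exp(1j * t * w) + adag * jnp.exp(-1j * t * w)


def rhs(t, psi, w):
    # Schrodinger equation with hbar = 1: d|psi>/dt = -i H(t) |psi>
    return -1j * (hamiltonian(t, w) @ psi)


def evolve(w, psi0, t1=1.0, steps=200):
    # Fixed-step RK4 (an assumption standing in for diffrax's adaptive solver).
    dt = t1 / steps
    ts = jnp.arange(steps) * dt  # start time of each step

    def rk4_step(psi, t):
        k1 = rhs(t, psi, w)
        k2 = rhs(t + dt / 2, psi + dt / 2 * k1, w)
        k3 = rhs(t + dt / 2, psi + dt / 2 * k2, w)
        k4 = rhs(t + dt, psi + dt * k3, w)
        return psi + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4), None

    # lax.scan keeps the loop differentiable in reverse mode.
    psi, _ = jax.lax.scan(rk4_step, psi0, ts)
    return psi


def final_population(w):
    psi0 = jnp.array([0.0, 1.0], dtype=jnp.complex128)  # like qt.basis(2, 1)
    psi = evolve(w, psi0)
    return jnp.real(jnp.vdot(psi, num @ psi))  # <n> at t = 1


def final_population_exact(w, t=1.0):
    # Rotating frame: H' = (1 - w) n + a + a^dag, coupling 1, detuning (1 - w) / 2.
    rabi = jnp.sqrt(1.0 + ((1.0 - w) / 2) ** 2)
    return 1.0 - jnp.sin(rabi * t) ** 2 / rabi**2


grad_numeric = jax.jit(jax.grad(final_population))
grad_exact = jax.grad(final_population_exact)

for w in (0.5, 0.2):
    print(w, grad_numeric(w), grad_exact(w))
```

Differentiating through a fixed number of scan steps is the simplest way to get exact gradients of the discretized solution; an adaptive solver such as diffrax trades that simplicity for error control and memory-efficient adjoints.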

The first call is slow, presumably because JAX traces and compiles everything on that call, but subsequent calls are fast:

def f(w, solver):
    result = solver.run(ket, [0, 1], e_ops=qt.num(2).to("jax"), args={"w":w})
    return result.expect[0][-1].real

%time jax.grad(f)(0.2, solver)
100.0%. Run time:   0.01s. Est. time left: 00:00:00:00
Total run time:   0.01s
CPU times: user 31.2 ms, sys: 4.15 ms, total: 35.4 ms
Wall time: 34.4 ms

DeviceArray(-0.09687763, dtype=float64, weak_type=True)
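The speed-up on the second call is consistent with JAX's tracing cache: a jitted function is traced and compiled once per input shape and dtype, and later calls with a matching signature reuse the compiled code, even when the value of w changes. The sketch below shows this with a hypothetical plain-JAX stand-in (loss is not the qutip-jax solver); that qutip-jax caches in exactly this way is an assumption.

```python
import jax
import jax.numpy as jnp

n_traces = 0  # how many times JAX has traced the Python function


def loss(w):
    global n_traces
    n_traces += 1  # Python side effect: runs only while JAX traces, not on cached calls
    return jnp.sin(w) ** 2


grad_loss = jax.jit(jax.grad(loss))

g1 = grad_loss(0.5)  # first call: trace + compile (the slow path)
g2 = grad_loss(0.2)  # same input shape/dtype: the cached executable is reused
print(n_traces, g1, g2)  # n_traces stays at 1
```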
coveralls commented 1 year ago

Pull Request Test Coverage Report for Build 5158224090


Changes Missing Coverage        Covered Lines   Changed/Added Lines   %
src/qutip_jax/jaxarray.py       18              24                    75.0%
src/qutip_jax/qutip_trees.py    44              55                    80.0%
src/qutip_jax/ode.py            66              83                    79.52%
src/qutip_jax/qobjevo.py        93              111                   83.78%
Total                           227             279                   81.36%

Files with Coverage Reduction   New Missed Lines   %
src/qutip_jax/jaxarray.py       1                  87.14%
src/qutip_jax/linalg.py         1                  93.75%
Total                           2

Totals
Change from base Build 4470247850: -4.7%
Covered Lines: 717
Relevant Lines: 792

BoxiLi commented 1 year ago

Hi @Ericgig, I'm trying out the code above. SeSolver seems to have been renamed SESolver, which is fine.

However, the expectation values are converted to NumPy arrays in qutip/solver/result.py:

File .../qutip/solver/result.py:330, in <listcomp>(.0)
    328 @property
    329 def expect(self):
--> 330     return [np.array(e_op) for e_op in self.e_data.values()]

This was introduced in https://github.com/qutip/qutip/pull/2077

Ericgig commented 1 year ago

@BoxiLi Returning expect as a list instead of an array was one of the issues that came up repeatedly when migrating from v4 to v5. e_data returns the list without forcing a cast to NumPy, so it should work with JAX auto-differentiation here.
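The difference between expect and e_data can be reproduced in plain JAX: np.array needs concrete numbers, so calling it on traced values inside jax.grad raises an error, while returning the list keeps the values traceable. The two functions below are hypothetical stand-ins for the two access patterns, not QuTiP's Result class.

```python
import numpy as np
import jax
import jax.numpy as jnp


def expect_as_numpy(w):
    # Hypothetical list of expectation values at two times.
    values = [jnp.cos(w), jnp.cos(2 * w)]
    # Mirrors np.array(e_op) in Result.expect: needs concrete values,
    # so it fails when w is a tracer inside jax.grad.
    return np.array(values)[-1]


def expect_as_list(w):
    values = [jnp.cos(w), jnp.cos(2 * w)]
    # Indexing the list keeps a JAX value, so the gradient flows through.
    return values[-1]


print(jax.grad(expect_as_list)(0.5))  # d/dw cos(2w) = -2 sin(2w)

try:
    jax.grad(expect_as_numpy)(0.5)
except jax.errors.TracerArrayConversionError as err:
    print("np.array on a tracer fails:", type(err).__name__)
```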