Ericgig closed this 1 year ago.
`grad` is working with the solver:
import qutip_jax
import qutip as qt
from qutip.solver.sesolve import sesolve, SeSolver
import jax
@jax.jit
def fp(t, w):
    return jax.numpy.exp(1j * t * w)

@jax.jit
def fm(t, w):
    return jax.numpy.exp(-1j * t * w)
H = (
    qt.num(2)
    + qt.destroy(2) * qt.coefficient(fp, args={"w": 3.1415})
    + qt.create(2) * qt.coefficient(fm, args={"w": 3.1415})
).to("jax")
ket = qt.basis(2, 1)
solver = SeSolver(H, options={"method": "diffrax"})
def f(w, solver):
    # w is passed through args, so jax.grad differentiates the final <n> with respect to it
    result = solver.run(ket, [0, 1], e_ops=qt.num(2).to("jax"), args={"w": w})
    return result.expect[0][-1].real
jax.grad(f)(0.5, solver)
100.0%. Run time: 1.64s. Est. time left: 00:00:00:00
Total run time: 1.65s
CPU times: user 7.08 s, sys: 37.2 ms, total: 7.11 s
Wall time: 7.03 s
DeviceArray(-0.06224748, dtype=float64, weak_type=True)
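As an independent check on this value (our own derivation, not part of the thread, assuming the `args={"w": w}` passed to `solver.run` replaces the default `w = 3.1415`): the Hamiltonian is H(t) = n + a e^{iwt} + a† e^{-iwt} acting on |ψ(0)⟩ = |1⟩, and substituting c₁ = e^{-iwt} b₁ removes the time dependence, so ⟨n⟩ at t = 1 has a Rabi-type closed form.

```latex
\begin{align*}
H(t) &= \begin{pmatrix} 0 & e^{iwt} \\ e^{-iwt} & 1 \end{pmatrix}
  \;\xrightarrow{\;c_1 = e^{-iwt} b_1\;}\;
  \tilde{H} = \begin{pmatrix} 0 & 1 \\ 1 & 1 - w \end{pmatrix}, \\
\langle n \rangle(t) &= 1 - \frac{\sin^2(\Omega t)}{\Omega^2},
  \qquad \Omega = \sqrt{1 + \tfrac{(1 - w)^2}{4}}, \\
\left.\frac{\partial \langle n \rangle}{\partial w}\right|_{t=1}
  &= \frac{(1 - w)\,\sin\Omega\,(\Omega\cos\Omega - \sin\Omega)}{2\,\Omega^4}.
\end{align*}
```

With w = 0.5 (Ω² = 1.0625) this evaluates to about -0.062247, matching the output above, and with w = 0.2 (Ω² = 1.16) to about -0.096878, the value returned by the second call below.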
The first call is slow (it includes tracing and compilation), but subsequent calls are fast:
def f(w, solver):
    result = solver.run(ket, [0, 1], e_ops=qt.num(2).to("jax"), args={"w": w})
    return result.expect[0][-1].real
%time jax.grad(f)(0.2, solver)
100.0%. Run time: 0.01s. Est. time left: 00:00:00:00
Total run time: 0.01s
CPU times: user 31.2 ms, sys: 4.15 ms, total: 35.4 ms
Wall time: 34.4 ms
DeviceArray(-0.09687763, dtype=float64, weak_type=True)
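For a dependency-free cross-check of both gradients, here is a minimal sketch that does not use QuTiP or JAX: it integrates the lab-frame Schrödinger equation for the same H(t) with fixed-step RK4 and takes a central finite difference in w. `expect_n` and `grad_fd` are hypothetical helper names, and the sketch assumes the `args={"w": w}` passed to `solver.run` replaces the default coefficient argument.

```python
import cmath


def expect_n(w, t_final=1.0, steps=1000):
    """<n>(t_final) for H(t) = n + a exp(i w t) + a^dag exp(-i w t), psi(0) = |1>.

    Fixed-step RK4 on d(psi)/dt = -i H(t) psi, with psi = (c0, c1) in the (|0>, |1>) basis.
    """

    def rhs(t, c0, c1):
        fp = cmath.exp(1j * t * w)   # coefficient of destroy(2)
        fm = cmath.exp(-1j * t * w)  # coefficient of create(2)
        # H(t) = [[0, fp], [fm, 1]]
        return -1j * (fp * c1), -1j * (fm * c0 + c1)

    dt = t_final / steps
    t, c0, c1 = 0.0, 0j, 1 + 0j  # qt.basis(2, 1)
    for _ in range(steps):
        k1 = rhs(t, c0, c1)
        k2 = rhs(t + dt / 2, c0 + dt / 2 * k1[0], c1 + dt / 2 * k1[1])
        k3 = rhs(t + dt / 2, c0 + dt / 2 * k2[0], c1 + dt / 2 * k2[1])
        k4 = rhs(t + dt, c0 + dt * k3[0], c1 + dt * k3[1])
        c0 += dt / 6 * (k1[0] + 2 * k2[0] + 2 * k3[0] + k4[0])
        c1 += dt / 6 * (k1[1] + 2 * k2[1] + 2 * k3[1] + k4[1])
        t += dt
    return abs(c1) ** 2  # <n> = |c1|^2


def grad_fd(w, h=1e-5):
    """Central finite difference of <n>(t=1) with respect to w."""
    return (expect_n(w + h) - expect_n(w - h)) / (2 * h)


print(grad_fd(0.5), grad_fd(0.2))  # approx -0.06225 and -0.09688
```

Finite differences need two solves per parameter and are sensitive to the step h, so they only serve as a reference here; `jax.grad` gets the derivative from a single differentiated solve.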
Changes Missing Coverage | Covered Lines | Changed/Added Lines | % Covered
---|---|---|---
src/qutip_jax/jaxarray.py | 18 | 24 | 75.0%
src/qutip_jax/qutip_trees.py | 44 | 55 | 80.0%
src/qutip_jax/ode.py | 66 | 83 | 79.52%
src/qutip_jax/qobjevo.py | 93 | 111 | 83.78%
Total | 227 | 279 | 81.36%

Files with Coverage Reduction | New Missed Lines | File Coverage
---|---|---
src/qutip_jax/jaxarray.py | 1 | 87.14%
src/qutip_jax/linalg.py | 1 | 93.75%
Total | 2 |

Totals | Value
---|---
Change from base Build 4470247850 | -4.7%
Covered Lines | 717
Relevant Lines | 792
Hi @Ericgig, I'm trying out the code above. `SeSolver` is now named `SESolver`, I guess, which is OK.
But for the expectation values, QuTiP core converts the data to NumPy arrays in `Result.expect`:
File .../qutip/solver/result.py:330, in <listcomp>(.0)
    328     @property
    329     def expect(self):
--> 330         return [np.array(e_op) for e_op in self.e_data.values()]
This was introduced in https://github.com/qutip/qutip/pull/2077
@BoxiLi Returning `expect` as a list instead of an array was one of the issues that often came up when migrating from v4 to v5. `e_data` returns the list without forcing a cast to NumPy, so it should work with JAX for automatic differentiation here.
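Why the cast matters: inside `jax.grad`, values are tracers that carry derivative information, and converting one to a NumPy array (as `np.array(e_op)` does in `Result.expect`) raises an error instead of returning a number. The stdlib sketch below imitates that with a toy forward-mode dual number. `Dual`, `dcos`, `ToyResult` and `toy_run` are hypothetical names; the toy only mirrors the `expect`/`e_data` split and is not QuTiP's `Result` class or how JAX works internally.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class Dual:
    """Toy forward-mode AD value: primal `val` plus tangent `der` (d val / d w).

    Stands in for a JAX tracer; this is not how JAX is implemented internally.
    """

    val: float
    der: float = 0.0

    def __mul__(self, other):
        if isinstance(other, Dual):
            # Product rule
            return Dual(self.val * other.val, self.der * other.val + self.val * other.der)
        return Dual(self.val * other, self.der * other)

    __rmul__ = __mul__

    def __float__(self):
        # Mimic JAX refusing to turn a traced value into a concrete number/array
        raise TypeError("cannot convert a traced value to a concrete float")


def dcos(x):
    """cos for Dual inputs: d cos(x) = -sin(x) dx."""
    return Dual(math.cos(x.val), -math.sin(x.val) * x.der)


class ToyResult:
    """Hypothetical result object; only mirrors the expect/e_data split."""

    def __init__(self, e_data):
        self.e_data = e_data  # {key: list of values}, returned untouched

    @property
    def expect(self):
        # Eager cast to concrete numbers, analogous to np.array(e_op) in result.py
        return [[float(v) for v in values] for values in self.e_data.values()]


def toy_run(w, times):
    """Toy 'solver' whose expectation value is cos(w t)**2 at each time."""
    return ToyResult({0: [dcos(w * t) * dcos(w * t) for t in times]})


w = Dual(0.5, 1.0)  # seed the tangent dw/dw = 1
result = toy_run(w, [0.0, 1.0])
final = result.e_data[0][-1]  # still a Dual: value and derivative survive
print(final.val, final.der)   # cos(0.5)**2 and -sin(1.0)
# Accessing result.expect here would raise TypeError, like np.array(tracer) under jax.grad
```

Reading from `e_data` keeps the derivative attached to the final value, while going through the eager cast in `expect` fails, which is the same split the comment above describes.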
Working on an `Integrator` with `jax`, `diffrax` and `equinox`. `jit` and `grad` work without issue, but it does not pickle well.

I used the changes from qutip/qutip#1816 so that coefficients do not force complex output and so that subprojects like this one can define new coefficient types (see ericgig/qutip/autodiff for a cleaned-up #1816).

`Qobj` and `QobjEvo` will be hard to make into Pytrees, so I tried to make a simplified version, `JaxQobjEvo`. It's immutable, and with operations and support for the `callable(t, args) -> Qobj` format removed, making it JAX-friendly is easy. All methods are compiled and can be differentiated.

`DiffraxIntegrator` is in progress. I can differentiate the step function, and pushing the differentiation capabilities up to the `Solver` layer seems doable.
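One common way to make an immutable container JAX-friendly is the pytree flatten/unflatten protocol: split the object into traceable children (operators and the differentiable argument `w`) and static aux data (here, the coefficient functions). The thread does not show how `JaxQobjEvo` does this, so the sketch below is hypothetical: `ToyQobjEvo` and the `coeff_*` functions are illustrative names, and plain nested tuples replace arrays so it runs without JAX. With JAX installed, decorating the class with `jax.tree_util.register_pytree_node_class` registers these two methods.

```python
import cmath
from dataclasses import dataclass
from typing import Callable, Tuple

Matrix = Tuple[Tuple[complex, ...], ...]


@dataclass(frozen=True)  # frozen: attributes cannot be reassigned after creation
class ToyQobjEvo:
    """Hypothetical H(t) = sum_k f_k(t, w) * A_k; not qutip_jax's JaxQobjEvo."""

    ops: Tuple[Matrix, ...]      # constant operators A_k (traceable leaves)
    funcs: Tuple[Callable, ...]  # coefficient functions f_k (static metadata)
    w: float                     # coefficient argument (differentiable leaf)

    def tree_flatten(self):
        # (children, aux_data): children are traced, aux_data must be hashable
        return (self.ops, self.w), self.funcs

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        ops, w = children
        return cls(ops, aux_data, w)

    def __call__(self, t):
        """Evaluate H(t) as a nested list of complex numbers."""
        n = len(self.ops[0])
        out = [[0j] * n for _ in range(n)]
        for op, f in zip(self.ops, self.funcs):
            c = f(t, self.w)
            for i in range(n):
                for j in range(n):
                    out[i][j] += c * op[i][j]
        return out


def coeff_one(t, w):
    return 1.0


def coeff_plus(t, w):
    return cmath.exp(1j * t * w)


def coeff_minus(t, w):
    return cmath.exp(-1j * t * w)


num = ((0, 0), (0, 1))
destroy = ((0, 1), (0, 0))
create = ((0, 0), (1, 0))

H = ToyQobjEvo(ops=(num, destroy, create), funcs=(coeff_one, coeff_plus, coeff_minus), w=0.5)
children, aux = H.tree_flatten()
H2 = ToyQobjEvo.tree_unflatten(aux, children)
print(H2 == H, H(0.0))
```

Keeping the coefficient functions in the aux data means JAX would treat them as static (passing a different function triggers a retrace), while `w` stays a leaf that `jax.grad` can differentiate.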