jan-o-e opened 9 months ago
`mesolve` does not support `jit`. It does safety checks, manages metadata, uses Cython, etc., which do not work well inside `jit`. That's why most examples split the setup (`solver = MESolver(...)`) from the computation (`solver.run`), and only the second is inside the jit-compiled function.
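A minimal sketch of that split, assuming a toy two-level model written in plain JAX. `ToySolver` is hypothetical and is not QuTiP's `MESolver`; it only mimics the pattern: Python-side validation and operator construction happen once in the constructor, while `run` contains only `jax.numpy`/`lax` operations, so it can be traced by `jax.jit`.

```python
import jax
import jax.numpy as jnp


class ToySolver:
    """Hypothetical stand-in for a solver object (NOT the QuTiP API)."""

    def __init__(self, omega, t_final=1.0, n_steps=200):
        # Setup: plain-Python checks and operator construction, done once, outside jit.
        if n_steps <= 0:
            raise ValueError("n_steps must be positive")
        sx = jnp.array([[0, 1], [1, 0]], dtype=jnp.complex64)
        sz = jnp.array([[1, 0], [0, -1]], dtype=jnp.complex64)
        self.H0 = 0.5 * omega * sz  # static part of the Hamiltonian
        self.H1 = 0.5 * sx          # drive operator
        self.n_steps = n_steps
        self.dt = t_final / n_steps

    def run(self, rho0, amp):
        # Computation: only traceable JAX operations, so jit can compile it.
        H = self.H0 + amp * self.H1

        def rhs(rho):
            # Closed-system Liouville-von Neumann equation: drho/dt = -i[H, rho]
            return -1j * (H @ rho - rho @ H)

        def rk4_step(rho, _):
            dt = self.dt
            k1 = rhs(rho)
            k2 = rhs(rho + 0.5 * dt * k1)
            k3 = rhs(rho + 0.5 * dt * k2)
            k4 = rhs(rho + dt * k3)
            return rho + dt / 6.0 * (k1 + 2 * k2 + 2 * k3 + k4), None

        rho_final, _ = jax.lax.scan(rk4_step, rho0, None, length=self.n_steps)
        return rho_final


solver = ToySolver(omega=1.0)   # setup: not jitted
run_jit = jax.jit(solver.run)   # only the computation is compiled

rho0 = jnp.array([[1, 0], [0, 0]], dtype=jnp.complex64)
rho_a = run_jit(rho0, 0.5)
rho_b = run_jit(rho0, 2.0)      # same compiled function, new parameter value
```

Keeping the parameter (`amp`) as a traced argument, rather than baking it into the setup, is what lets repeated calls reuse the compiled code.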
With spline coefficients, separating the setup from the computation is harder if you want to reuse the solver: you would need to pass the `InterpolatedUnivariateSpline` as `args`.
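A minimal sketch of passing an interpolant through `args` into a jitted evolution, assuming the spline object is (or can be made) a registered JAX pytree. `LinearInterp` is a hypothetical, piecewise-linear stand-in for jax_cosmo's `InterpolatedUnivariateSpline`, and `coeff` and `evolve` are illustrative names, not QuTiP functions. The counter shows that new spline values with the same shapes reuse the compiled function instead of re-tracing it.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class LinearInterp:
    """Hypothetical stand-in for InterpolatedUnivariateSpline (linear, not cubic).

    What matters here is that it is a registered pytree, so jit can accept it
    as an ordinary argument and trace its arrays.
    """

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __call__(self, t):
        return jnp.interp(t, self.x, self.y)

    def tree_flatten(self):
        # Arrays become traced leaves; no static auxiliary data is needed.
        return (self.x, self.y), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


SX = jnp.array([[0, 1], [1, 0]], dtype=jnp.complex64)
SZ = jnp.array([[1, 0], [0, -1]], dtype=jnp.complex64)
N_STEPS, T_FINAL = 200, 1.0
n_traces = 0  # counts how often jit traces `evolve`


def coeff(t, args):
    # Coefficient-style function: the drive envelope is read from `args`.
    return args["spline"](t)


@jax.jit
def evolve(rho0, args):
    global n_traces
    n_traces += 1  # Python side effect: runs only while tracing

    def rhs(t, rho):
        H = 0.5 * SZ + 0.5 * coeff(t, args) * SX
        return -1j * (H @ rho - rho @ H)

    def rk4_step(carry, _):
        t, rho = carry
        dt = T_FINAL / N_STEPS
        k1 = rhs(t, rho)
        k2 = rhs(t + 0.5 * dt, rho + 0.5 * dt * k1)
        k3 = rhs(t + 0.5 * dt, rho + 0.5 * dt * k2)
        k4 = rhs(t + dt, rho + dt * k3)
        return (t + dt, rho + dt / 6.0 * (k1 + 2 * k2 + 2 * k3 + k4)), None

    (_, rho), _ = jax.lax.scan(rk4_step, (jnp.float32(0.0), rho0), None, length=N_STEPS)
    return rho


x = jnp.linspace(0.0, T_FINAL, 11)
rho0 = jnp.array([[1, 0], [0, 0]], dtype=jnp.complex64)
rho_a = evolve(rho0, {"spline": LinearInterp(x, 2.0 * jnp.ones_like(x))})
rho_b = evolve(rho0, {"spline": LinearInterp(x, jnp.sin(jnp.pi * x))})
print(n_traces)  # -> 1: new spline values, same shapes, no re-tracing
```

Closing over the spline instead (for example, building `coeff` with the spline baked in) would turn its arrays into compile-time constants, so each new set of time-dependent parameters would trigger a fresh compilation.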
Ah, got it, thanks. Yeah, I was thinking of passing the spline as an arg.
Following the end of the discussion in https://github.com/qutip/qutip-jax/issues/26: I am working on the qutip-jax-dia branch and tried to implement cubic splines from jax_cosmo for a time-dependent Hamiltonian simulation (master equation with static collapse operators). This kind of simulation can be massively sped up with JAX when re-running it with different time-dependent parameters. I can run the simulation on a CPU without jitting the function `sim()`, but it's very slow (as expected).
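Since the goal is re-running one simulation for many time-dependent parameters, one option on top of `jit` is `jax.vmap`, which batches a whole parameter sweep into a single compiled call. A minimal sketch on a toy two-level model, assuming a hypothetical Gaussian drive envelope whose amplitude and width are the swept parameters; it uses neither QuTiP nor jax_cosmo, and `envelope` and `excited_population` are illustrative names.

```python
import jax
import jax.numpy as jnp

SX = jnp.array([[0, 1], [1, 0]], dtype=jnp.complex64)
SZ = jnp.array([[1, 0], [0, -1]], dtype=jnp.complex64)
N_STEPS, T_FINAL = 200, 1.0


def envelope(t, amp, width):
    # Hypothetical Gaussian drive centred at T_FINAL / 2.
    return amp * jnp.exp(-0.5 * ((t - 0.5 * T_FINAL) / width) ** 2)


def excited_population(amp, width):
    rho0 = jnp.array([[1, 0], [0, 0]], dtype=jnp.complex64)

    def rhs(t, rho):
        H = 0.5 * SZ + 0.5 * envelope(t, amp, width) * SX
        return -1j * (H @ rho - rho @ H)

    def rk4_step(carry, _):
        t, rho = carry
        dt = T_FINAL / N_STEPS
        k1 = rhs(t, rho)
        k2 = rhs(t + 0.5 * dt, rho + 0.5 * dt * k1)
        k3 = rhs(t + 0.5 * dt, rho + 0.5 * dt * k2)
        k4 = rhs(t + dt, rho + dt * k3)
        return (t + dt, rho + dt / 6.0 * (k1 + 2 * k2 + 2 * k3 + k4)), None

    (_, rho), _ = jax.lax.scan(rk4_step, (jnp.float32(0.0), rho0), None, length=N_STEPS)
    return rho[1, 1].real  # excited-state population at T_FINAL


# vmap maps over the parameter axis; jit compiles the batched sweep once.
sweep = jax.jit(jax.vmap(excited_population))
amps = jnp.linspace(0.0, 4.0, 5)
widths = jnp.full_like(amps, 0.2)
pops = sweep(amps, widths)  # one compiled call for all five parameter sets
```

Whether batching beats a loop over a jitted single-run function depends on problem size and hardware; for small Hilbert spaces on a CPU, the batched version mostly saves Python-dispatch overhead.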
I know qutip-jax-dia is still very much in beta, but maybe some of the clever people here have suggestions as to why I can't jit the function.
For reference, I'm working on osx-arm64 with an M1 chip.
The error message I get is:
The full code is as follows (note that I am just using some constant functions to test the whole thing):
Many thanks!