neurophysik / jitcdde

Just-in-time compilation for delay differential equations

jitcdde and jax #52

Open boxaio opened 8 months ago

boxaio commented 8 months ago

Is JiTCDDE compatible with JAX?

Wrzlprmft commented 8 months ago

As far as I can tell, Jax is for vectorisable problems (SPMD). The main strength of JiTC*DE is that it can handle non-vectorisable problems (MPMD). Therefore combining this with Jax makes little sense. If your problem is vectorisable, e.g., a typical PDE, there are usually tools better suited than JiTC*DE, and they may use or be Jax.

I acknowledge, however, that there is a lack of tools for DDEs, so you might want the best of both worlds for something like a PDDE (partial delay differential equation). In that case, you have the additional problem that JiTCDDE’s integrator is strongly intertwined with the computation of the derivative (where Jax might help).

That being said, there may be a solution for some problems where you can benefit from both modules, but for that I would need to know the detailed problem.

boxaio commented 8 months ago

Thanks for your reply. Actually, I was trying to solve DDEs using JiTCDDE, starting from an enormous number of different initial conditions. For ODEs, solving with various initial conditions is easily done with the jax.vmap() function in JAX. However, in JiTCDDE, the data are represented using NumPy arrays instead of jax.numpy arrays; see line 3 of the file jitcdde/jitced_template.c: # include <numpy/arrayobject.h>. I guess I have to convert from numpy to jax.numpy, but I did not find a similar file for jax.numpy.

Wrzlprmft commented 8 months ago

I see two main ways in which parallelising initial conditions could let you win over JiTCDDE:

In both cases, the more initial conditions you run in parallel, the worse the drawbacks of parallelising:

First, suppose that you parallelise by “copying” your differential equations, e.g., instead of having six three-dimensional systems, you have one eighteen-dimensional system consisting of six non-interacting subsystems.

One of the strengths of JiTCDDE is that it uses an adaptive integrator, i.e., the step size changes according to how easily the dynamics can be integrated. The problem with the copying approach is that the integrator will choose the smallest step size needed by any of these six systems, and thus many of the systems will be integrated with a smaller step size than necessary. A variation of this approach that avoids this would be adapting the step size for each subsystem separately. However, that way some integrations may take longer than others.
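
A minimal sketch of this effect, using a toy adaptive ODE integrator (an embedded Heun/Euler pair, which is not the method JiTCDDE uses) and two harmonic oscillators as stand-ins for the subsystems. All names and parameters are made up for illustration. Each oscillator is integrated on its own and then both are integrated as one “copied” system with a shared step size:

```python
def adaptive_heun(f, y0, t_end, tol=1e-4, dt=1e-3):
    """Integrate y' = f(y) from t=0 to t_end; return (state, accepted steps)."""
    t, y, steps = 0.0, list(y0), 0
    while t_end - t > 1e-12:
        dt = min(dt, t_end - t)
        k1 = f(y)
        y_euler = [yi + dt * ki for yi, ki in zip(y, k1)]
        k2 = f(y_euler)
        y_heun = [yi + 0.5 * dt * (a + b) for yi, a, b in zip(y, k1, k2)]
        # Error estimate: largest Heun/Euler difference over ALL components.
        err = max(abs(a - b) for a, b in zip(y_heun, y_euler))
        if err <= tol:  # accept the step
            t, y, steps = t + dt, y_heun, steps + 1
        # Step-size controller for an error estimate that scales with dt**2.
        dt *= min(2.0, max(0.2, 0.9 * (tol / max(err, 1e-300)) ** 0.5))
    return y, steps


def oscillator(w):
    """Harmonic oscillator x'' = -w^2 x as a first-order system (x, v)."""
    return lambda y: [y[1], -w * w * y[0]]


slow, fast = oscillator(1.0), oscillator(10.0)


def joint(y):
    # "Copied" system: two non-interacting oscillators in one state vector.
    return slow(y[:2]) + fast(y[2:])


_, steps_slow = adaptive_heun(slow, [1.0, 0.0], 2.0)
_, steps_fast = adaptive_heun(fast, [1.0, 0.0], 2.0)
_, steps_joint = adaptive_heun(joint, [1.0, 0.0, 1.0, 0.0], 2.0)
print(steps_slow, steps_fast, steps_joint)
```

With these parameters, the joint run should need about as many steps as the fast oscillator alone, i.e., the slow subsystem is integrated with many more steps than it would need on its own.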

Moreover, in many scenarios where you run a lot of initial conditions, you have some criterion for aborting a run, e.g., because the dynamics clearly converged to a fixed point. With the copying approach, you lose that option.

Probably the most worthwhile way to use the copying approach is to use a fixed step size, since the step size will be rather constant anyway for a large number of subsystems. This way, you save some overhead required for step-size computations, but then you are so far removed from what JiTCDDE does that you might as well write your own tool. However, since you do not have to handle changing step sizes, implementing this becomes a lot easier.
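
To give an idea of how little is needed once the step size is fixed, here is a rough sketch (not part of JiTCDDE; all names are hypothetical). It assumes a scalar DDE y′(t) = f(y(t), y(t−τ)) with a single constant delay that is a multiple of the step size, and a constant past for each run, so the delayed state can be read directly from a ring buffer without interpolation. It uses Heun's method and plain Python lists; with NumPy or jax.numpy arrays, the inner loop over runs would presumably become a single array operation.

```python
from collections import deque


def integrate_batch(f, initial_values, delay, dt, n_steps):
    """Fixed-step Heun integration of y'(t) = f(y(t), y(t - delay)) for many
    constant initial pasts at once. Returns the final state of each run."""
    lag = round(delay / dt)
    if lag < 1 or abs(lag * dt - delay) > 1e-12:
        raise ValueError("delay must be a positive multiple of dt")
    # Ring buffer holding the last lag+1 states of all runs, oldest first.
    history = deque(
        (list(initial_values) for _ in range(lag + 1)), maxlen=lag + 1
    )
    for _ in range(n_steps):
        now = history[-1]        # y(t)
        past = history[0]        # y(t - delay)
        past_next = history[1]   # y(t + dt - delay)
        new = []
        for y, y_del, y_del_next in zip(now, past, past_next):
            k1 = f(y, y_del)                   # slope at t
            k2 = f(y + dt * k1, y_del_next)    # slope at t + dt (predictor)
            new.append(y + 0.5 * dt * (k1 + k2))
        history.append(new)      # the oldest state drops out automatically
    return history[-1]


# Example: y'(t) = -y(t - 1) for three constant pasts, integrated up to t = 2.
final = integrate_batch(
    lambda y, y_del: -y_del, [1.0, 2.0, -0.5], delay=1.0, dt=0.01, n_steps=200
)
print(final)
```

Note that all runs share the same number of steps here, so there is no per-run abort criterion.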

“For solving ODEs with various initial conditions, this is easily done with jax.vmap() function in JAX.”

At a quick glance, unless you create your own solver using Jax routines, this will either not grant you any speed-up compared to basic multi-core parallelisation tools or use the copying approach outlined above with all its drawbacks.
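
For reference, basic multi-core parallelisation can be as simple as the following sketch using Python's standard concurrent.futures module. The function run_one is a hypothetical stand-in (explicit Euler for y′(t) = −y(t−1)) for whatever you do per initial condition, e.g., a JiTCDDE integration loop; the abort criterion and all parameters are assumptions for illustration. The point is only that each run is independent and can stop as soon as it has converged:

```python
from concurrent.futures import ProcessPoolExecutor


def run_one(y0, dt=0.01, delay=1.0, t_max=200.0, eps=1e-6):
    """Stand-in solver: integrate y'(t) = -y(t - delay) from a constant past
    and abort once |y| stayed below eps for one full delay interval."""
    lag = round(delay / dt)
    history = [y0] * (lag + 1)   # last lag+1 states, oldest first
    steps = 0
    while steps * dt < t_max:
        # Explicit Euler step using the delayed state history[0].
        history.append(history[-1] - dt * history[0])
        history.pop(0)
        steps += 1
        if max(abs(v) for v in history) < eps:   # abort criterion
            break
    return y0, steps * dt


def main():
    initial_conditions = [0.001, 0.1, 1.0, 10.0]
    # One task per initial condition, distributed over the available cores.
    with ProcessPoolExecutor() as pool:
        for y0, t_stop in pool.map(run_one, initial_conditions):
            print(f"y0 = {y0}: stopped at t = {t_stop:.2f}")


if __name__ == "__main__":
    main()
```

Runs starting closer to the fixed point stop earlier, which frees their core for the next initial condition.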