When a parallel solver is used, it parallelizes over every time step being solved at once. As a result, memory usage grows with the total number of steps, and these parallel solvers can use very large amounts of memory. To enable these solvers to be used at higher dimensions or over longer time intervals, functionality should be added that limits how many steps are parallelized at once.
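As a rough, illustrative lower bound (assuming the solver holds one dense $d \times d$ `complex128` propagator per time step, 16 bytes per entry, and ignoring intermediate buffers), memory grows linearly in the number of steps $N$ and quadratically in the dimension $d$. Capping the number of simultaneously parallelized steps at $k$ replaces $N$ by $k$:

```latex
M_{\text{all}} \gtrsim 16\,N d^2 \ \text{bytes}, \qquad M_{\text{chunk}} \gtrsim 16\,k d^2 \ \text{bytes}

d = 10^3,\ N = 10^4:\quad 16 \cdot 10^4 \cdot 10^6 = 1.6 \times 10^{11}\ \text{bytes} \approx 160\ \text{GB}
d = 10^3,\ k = 10^2:\quad 16 \cdot 10^2 \cdot 10^6 = 1.6 \times 10^{9}\ \text{bytes} \approx 1.6\ \text{GB}
```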
A few possible ways to do this:
Have the user manually specify how many steps to parallelize over at once: quick to implement but less user-friendly, since it requires fine-grained control; may be a good initial solution.
Have the user specify how much memory is available, and use a heuristic to choose the number of parallel steps from that budget: users must keep in mind that other parts of Dynamics may use additional memory, and it may be hard to calculate exactly how much memory the parallel solver uses.
Fully automate the process based on JAX's internal memory limits.
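Options 2 and 3 both reduce to turning a byte budget into a step count. The sketch below is one possible heuristic, not existing Dynamics code: `estimate_max_parallel_steps`, `query_free_device_memory`, `overhead_factor`, and `fraction_of_free` are all hypothetical names, and the default `overhead_factor=4.0` is an uncalibrated guess for intermediate buffers (scan workspace, matrix-exponential temporaries). The automatic path relies on JAX's `Device.memory_stats()`, which typically reports limits on GPU but may return nothing on CPU, so it falls back to requiring an explicit budget.

```python
import jax


def query_free_device_memory(device=None):
    """Best-effort estimate of free device memory in bytes, or None if unavailable."""
    device = jax.devices()[0] if device is None else device
    try:
        stats = device.memory_stats()
    except Exception:  # Some backends may not implement memory_stats at all.
        return None
    if not stats or "bytes_limit" not in stats:
        return None  # E.g. CPU devices often report no memory statistics.
    return stats["bytes_limit"] - stats.get("bytes_in_use", 0)


def estimate_max_parallel_steps(dim, memory_budget_bytes=None, itemsize=16,
                                overhead_factor=4.0, fraction_of_free=0.5):
    """Hypothetical heuristic: how many d x d propagators fit in the budget.

    itemsize=16 assumes complex128 entries; overhead_factor is an assumed
    multiplier for temporaries and should be calibrated against real runs.
    """
    if memory_budget_bytes is None:
        free = query_free_device_memory()
        if free is None:
            raise ValueError(
                "memory_budget_bytes must be given when device memory cannot be queried"
            )
        # Leave headroom for the rest of the simulation (option 2's caveat).
        memory_budget_bytes = fraction_of_free * free
    bytes_per_step = overhead_factor * dim * dim * itemsize
    # Always parallelize at least one step so the solver can make progress.
    return max(1, int(memory_budget_bytes // bytes_per_step))
```

For example, `estimate_max_parallel_steps(1000, memory_budget_bytes=1_600_000_000, overhead_factor=1.0)` gives 100 steps, matching the back-of-the-envelope estimate above. Keeping the budget-to-steps conversion separate from the device query lets option 1 (a fixed step count), option 2 (a user budget), and option 3 (an automatic budget) share the same downstream code path.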
To add to the main comment: The parallel solver looping structure is implemented in the function fixed_step_lmde_solver_parallel_template_jax, in the solvers/fixed_step_solvers.py file.
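A minimal sketch of how the looping structure could be bounded, written as a standalone function rather than a patch to `fixed_step_lmde_solver_parallel_template_jax` (whose actual signature is not reproduced here). It assumes each step's propagator is the exponential of the generator frozen at the step midpoint, and that cumulative propagators are formed with a parallel prefix product via `jax.lax.associative_scan`. The name `chunked_parallel_solve` and the `max_parallel_steps` argument are hypothetical. Only one chunk of propagators is held in memory at a time, and the last state of each chunk seeds the next.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm


def chunked_parallel_solve(generator, t0, dt, n_steps, y0, max_parallel_steps):
    """Solve y'(t) = G(t) y(t) on a fixed grid, at most max_parallel_steps at a time.

    Returns the state after each step, with shape (n_steps,) + y0.shape.
    """
    # One-step propagator with the generator frozen at the step midpoint,
    # vectorized over a batch of step start times.
    step_propagator = jax.vmap(lambda t: expm(dt * generator(t + 0.5 * dt)))

    @jax.jit
    def solve_chunk(step_starts, y_carry):
        # All propagators of this chunk are materialized together; this is
        # the allocation whose size max_parallel_steps bounds.
        props = step_propagator(step_starts)
        # Prefix products P_k = U_k ... U_1 U_0: `a` is the earlier prefix,
        # so the later factor `b` multiplies on the left.
        cumulative = jax.lax.associative_scan(lambda a, b: b @ a, props)
        return cumulative @ y_carry

    step_starts = t0 + dt * jnp.arange(n_steps)
    y_carry = jnp.asarray(y0)
    outputs = []
    for start in range(0, n_steps, max_parallel_steps):
        ys = solve_chunk(step_starts[start:start + max_parallel_steps], y_carry)
        outputs.append(ys)
        # The final state of this chunk is the initial state of the next.
        y_carry = ys[-1]
    return jnp.concatenate(outputs, axis=0)
```

With `max_parallel_steps >= n_steps` this reduces to fully parallel behavior. The Python loop over chunks keeps the sketch simple; a shorter final chunk triggers one extra compilation. A production version might instead pad the last chunk with identity propagators and drive the chunks with `jax.lax.scan` so the whole solve stays traceable.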