joeryjoery opened 2 years ago
Hi @joeryjoery. Indeed, `jit=True` is necessary for a solver to be "vmappable". We can definitely improve the documentation on this, but I am not sure whether it's possible to detect this case and raise an error message. @froystig will definitely know more.
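On the question of detecting this: below is a minimal sketch of one way an unstaged loop could do it. It assumes a helper of our own; `python_while_loop`, `cond`, and `body` are hypothetical names, not jaxopt's implementation or API. The helper converts the loop condition to a Python `bool` on every iteration, as any loop driven from Python must. It then re-raises the resulting `jax.errors.ConcretizationTypeError` as a `ValueError` that points at `jit=True`.

```python
import jax
import jax.numpy as jnp


def python_while_loop(cond_fun, body_fun, init_val, maxiter):
    # Hypothetical unstaged loop helper (not jaxopt's code): the condition
    # must be converted to a Python bool on every iteration.
    val = init_val
    for _ in range(maxiter):
        try:
            keep_going = bool(cond_fun(val))
        except jax.errors.ConcretizationTypeError as err:
            # A traced condition (e.g. under jax.vmap) cannot become a bool,
            # so surface an actionable message instead of the raw error.
            raise ValueError(
                "The loop condition is a traced value (for example under "
                "jax.vmap); run the solver with jit=True so the loop is staged."
            ) from err
        if not keep_going:
            break
        val = body_fun(val)
    return val


def cond(v):
    # Hypothetical stopping criterion: keep going while |v| > 1e-3.
    return jnp.abs(v) > 1e-3


def body(v):
    # Hypothetical update: halve v.
    return 0.5 * v


# Eager use works: v is halved until |v| <= 1e-3.
print(python_while_loop(cond, body, jnp.array(1.0), maxiter=100))

# Under vmap the helper raises the clearer ValueError instead.
try:
    jax.vmap(lambda v: python_while_loop(cond, body, v, maxiter=100))(
        jnp.array([1.0, 2.0])
    )
except ValueError as err:
    print(err)
```

This only improves the error message; it does not make the unstaged loop vmappable.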
I'm not sure whether this is possible, but I found out, quite awkwardly, that applying `vmap` to a solver strictly requires the optimized function to be staged. For example, this code works fine:
But this code produces a concretization error because of `jit=False` in the solver. The exception is raised by the `cond_fun` in the `while_loop`, so upon inspection I completely understand why this happens. The solver's argument specification, though, made me believe that I had the option to `jit` or not to `jit`, whereas with a `vmap` you have to `jit`. Could this perhaps be inferred automatically? Or is this unfortunately just a sharp edge of JAX? If so, I believe it would be useful to note this in the documentation of the `jit` argument in the `Solver` classes.
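The sharp edge can be reproduced without jaxopt. The sketch below is illustrative only. `solve_python_loop` and `solve_lax_loop` are hypothetical toy solvers (gradient descent on `0.5 * (x - a)**2`) standing in for a solver run with `jit=False` and `jit=True` respectively. This assumes the unjitted path drives its loop from Python. Under `jax.vmap` the stopping criterion becomes a batched tracer, which cannot be converted to a Python `bool`. By contrast, `jax.lax.while_loop` accepts a batched predicate and keeps iterating until every element of the batch has converged.

```python
import jax
import jax.numpy as jnp


def grad_step(x, a):
    # One gradient-descent step (step size 0.5) on f(x) = 0.5 * (x - a)**2.
    return x - 0.5 * (x - a)


def solve_python_loop(a, tol=1e-6, maxiter=100):
    # Hypothetical stand-in for jit=False: the loop is driven by Python.
    x = jnp.zeros_like(a)
    it = 0
    # Under vmap, the comparison below is a batched tracer, so bool() raises.
    while it < maxiter and bool(jnp.abs(x - a) > tol):
        x = grad_step(x, a)
        it += 1
    return x


def solve_lax_loop(a, tol=1e-6, maxiter=100):
    # Hypothetical stand-in for jit=True: the loop is staged with lax.while_loop.
    def cond_fun(state):
        x, it = state
        return (it < maxiter) & (jnp.abs(x - a) > tol)

    def body_fun(state):
        x, it = state
        return grad_step(x, a), it + 1

    x, _ = jax.lax.while_loop(cond_fun, body_fun, (jnp.zeros_like(a), 0))
    return x


a_batch = jnp.array([1.0, -2.0, 3.0])

# The staged loop batches fine; each entry converges to its own `a`.
print(jax.vmap(solve_lax_loop)(a_batch))

# The Python loop works on a single input...
print(solve_python_loop(jnp.array(1.0)))

# ...but fails under vmap with a concretization error.
try:
    jax.vmap(solve_python_loop)(a_batch)
except jax.errors.ConcretizationTypeError as err:
    print("vmap failed with", type(err).__name__)
```

Under `vmap`, finished batch elements in the staged loop are held fixed by a masked select while the others keep iterating. Every element therefore runs as many iterations as the slowest one.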