Closed: lmriccardo closed this issue 1 year ago.
Hi @lmriccardo. You don't need to decorate with @jax.value_and_grad. GradientDescent expects the loss function, not its gradient (it will compute the gradient itself).
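For illustration, a minimal sketch of that intended usage (the quadratic loss below is hypothetical, standing in for the real objective):

```python
import jax.numpy as jnp
from jaxopt import GradientDescent

def loss(params):
    # Hypothetical scalar objective; the solver differentiates it internally.
    return jnp.sum((params - 1.0) ** 2)

solver = GradientDescent(fun=loss, maxiter=100)
params, state = solver.run(jnp.zeros(3))
```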
Hi @fabianp, thanks for the response. I saw that it is also possible to pass a function that returns both the value and the gradient by setting the attribute value_and_grad=True of GradientDescent. In fact, it works properly when no vmap is involved, i.e., with a single parameter vector (a sketch of that working, non-vmapped case follows below).
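A minimal sketch of that variant, reusing the same hypothetical quadratic loss: with value_and_grad=True, the function passed to the solver is expected to return the pair (value, gradient).

```python
import jax
import jax.numpy as jnp
from jaxopt import GradientDescent

def loss(params):
    return jnp.sum((params - 1.0) ** 2)

# With value_and_grad=True, fun must return (value, grad).
solver = GradientDescent(fun=jax.value_and_grad(loss),
                         value_and_grad=True, maxiter=100)
params, state = solver.run(jnp.zeros(3))
```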
Is this failure under vmap an error?
I don't think so; that seems to me to be the intended behavior.
Vmapping your objective function is not something that solvers expect; solvers expect a real-valued function as input.
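A sketch of the pattern that keeps the objective scalar: vmap the whole solver.run over a batch of initial points instead of vmapping the objective (the loss and shapes below are hypothetical):

```python
import jax
import jax.numpy as jnp
from jaxopt import GradientDescent

def loss(params):
    # Hypothetical scalar objective.
    return jnp.sum((params - 1.0) ** 2)

solver = GradientDescent(fun=loss, maxiter=100)
init_batch = jnp.zeros((8, 3))  # 8 independent initial parameter vectors
# Each mapped run sees a single parameter vector, so the objective stays scalar.
params_batch, states = jax.vmap(solver.run)(init_batch)
```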
Hello, I'm using JaxOpt for a master's thesis on parameter estimation of dynamical systems in systems biology.
Goal: run multiple optimizations of the same model from different initial solutions.
I have tried two different things:

1. VMAP the objective function and then call the solver.run method on a batch of initial points.
2. VMAP the entire optimization, i.e., the whole solver.run call.

The first approach does not work at all. Since the solution is no longer a vector but a matrix of points, the condition function of the inner while loop raises a JAX error of the type: cond_fun must return a boolean scalar, but got output type(s) [ShapedArray(bool[3])]. I would rather not change the source code, so I guess this is not the right approach.

The second approach runs, but it takes a lot of time to complete, since no real parallelization is happening here. Obviously, you have to assume that the simulation and the computation of the gradient are expensive and that each optimization step takes some time to complete.
How can I proceed?
Here is the code to reproduce the error with the first approach:
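(The original snippet is not shown here. The following is a minimal sketch, with a hypothetical quadratic objective and a batch of three initial points, that should raise the same kind of error.)

```python
import jax
import jax.numpy as jnp
from jaxopt import GradientDescent

def loss(params):
    # Hypothetical objective standing in for the real model.
    return jnp.sum((params - 1.0) ** 2)

# Vmapping value_and_grad makes the "value" a vector of 3 losses instead
# of a scalar, which breaks the boolean conditions of the solver's
# internal while loops.
batched_value_and_grad = jax.vmap(jax.value_and_grad(loss))

solver = GradientDescent(fun=batched_value_and_grad,
                         value_and_grad=True, maxiter=100)
init_batch = jnp.zeros((3, 2))  # 3 initial points of dimension 2
params, state = solver.run(init_batch)
# TypeError: cond_fun must return a boolean scalar,
# but got output type(s) [ShapedArray(bool[3])].
```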