google / jaxopt

Hardware accelerated, batchable and differentiable optimizers in JAX.
https://jaxopt.github.io
Apache License 2.0

Parallel execution of multiple optimization processes #508

Closed lmriccardo closed 1 year ago

lmriccardo commented 1 year ago

Hello, I'm using JaxOpt for a master's degree thesis on parameter estimation of dynamical systems in systems biology.

Goal: run multiple optimizations of the same model, each starting from a different initial solution.

I have tried two different things:

  1. vmap the objective function and run a single optimization.
  2. vmap the solver.run method.

The first approach does not work at all: since the solution is no longer a vector but a matrix of points, the condition function of the inner while loop raises a JAX error of the form cond_fun must return a boolean scalar, but got output type(s) [ShapedArray(bool[3])]. I would prefer not to change the source code, so I guess this is not the right approach.

With the second approach, I vmap the entire optimization. The problem is that it takes a long time to complete, since no real parallelization seems to happen. Keep in mind that the simulation and the gradient computation are expensive, so each optimization step takes a noticeable amount of time. (I sketch this approach after the code below.)

How can I proceed?

Here is the code to reproduce the error with the first approach:

import jax
import jax.numpy as jnp
import jaxopt
import numpy as np

@jax.jit
@jax.value_and_grad
def loss(params, x, target) -> jnp.ndarray:
    """ Compute the loss function """
    diff = params * x - target
    return diff.T @ diff / diff.shape[0]

# Define the initial starting points: 3 candidate parameter vectors of size 10
params = jnp.array(np.random.uniform(size=(3, 10)))
# Inputs and targets, tiled so that each candidate sees the same data
x = jnp.array(np.random.uniform(size=(10,)))
x = jnp.tile(x, (3, 1))
target = jnp.array(np.random.uniform(size=(10,))) * 10
target = jnp.tile(target, (3, 1))

# VMAP on the loss function: it now returns a batch of values and gradients
_loss = jax.vmap(loss)

solver = jaxopt.GradientDescent(_loss, maxiter=100, value_and_grad=True)
params, res = solver.run(params, x, target)  # raises: cond_fun must return a boolean scalar
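
And this is roughly what the second approach looks like (a sketch on the same toy data; names like batched_run are just illustrative, and in my real code the loss involves an expensive simulation):

# Second approach: keep the per-instance loss and vmap the whole solver.run
# call, mapping over the leading axis of the initial points and of the data
solver = jaxopt.GradientDescent(loss, maxiter=100, value_and_grad=True)
batched_run = jax.vmap(solver.run, in_axes=(0, 0, 0))
params_batch, state_batch = batched_run(params, x, target)
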
fabianp commented 1 year ago

Hi @lmriccardo. You don't need to decorate with @jax.value_and_grad: GradientDescent expects the loss function, not its gradient (it will compute the gradient itself).
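
Something like this (a minimal sketch on your toy problem; plain_loss is just an illustrative name for the undecorated loss):

# Pass the plain, scalar-valued loss; GradientDescent differentiates it itself
def plain_loss(params, x, target):
    diff = params * x - target
    return diff.T @ diff / diff.shape[0]

solver = jaxopt.GradientDescent(plain_loss, maxiter=100)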

lmriccardo commented 1 year ago

Hi @fabianp, thanks for the response. I saw that it is also possible to pass a function returning both the value and the gradient by setting the attribute value_and_grad=True of GradientDescent. In fact, it works properly when no vmap is involved, i.e., with a single parameter vector.
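
For example, this single-vector version runs as expected for me (same decorated loss and toy data as above):

# No vmap: a single initial point, loss decorated with @jax.value_and_grad
solver = jaxopt.GradientDescent(loss, maxiter=100, value_and_grad=True)
sol, state = solver.run(params[0], x[0], target[0])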

Is this an error?

fabianp commented 1 year ago

I don't think so; that seems to me to be the intended behavior.

vmapping your objective function is not something the solvers expect: they expect a real-valued function as input.
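
Concretely (a sketch reusing the toy shapes from your snippet): after vmapping, the objective is vector-valued rather than real-valued, which is exactly what the cond_fun error is pointing at.

# The vmapped objective returns one value per initial point, not a scalar
values, grads = _loss(params, x, target)
print(values.shape)  # (3,) -- so the solver's stopping condition is no longer a boolean scalar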