google / jaxopt

Hardware accelerated, batchable and differentiable optimizers in JAX.
https://jaxopt.github.io
Apache License 2.0

OptaxSolver Error: too many positional arguments #360

Open MarioAuditore opened 1 year ago

MarioAuditore commented 1 year ago

Hello! I tried to implement the implicit differentiation example shown here, but with my own functions. The task is to find the mean of a set of vectors X via gradient descent.

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

import jax
import jax.numpy as jnp
from jax import grad, random, jit
from jax import jacobian, hessian, jacfwd, jacrev
key = random.PRNGKey(0)

import jaxopt
from jaxopt import implicit_diff
from jaxopt import linear_solve
from jaxopt import OptaxSolver, GradientDescent
import optax

def euclidean_distance(a, b):
    """
    Squared Euclidean distance
    """
    return jnp.inner(a - b, a - b)

def weighted_distance(x, X, w):
    loss = 0
    for i, obj in enumerate(X):
        loss += w[i] * euclidean_distance(obj, x)
    return loss

def identical(Y, Y_grad):
    return Y

The algorithm for finding the mean:

# Mean calculation for manifolds with gradient descent
@implicit_diff.custom_root(jax.grad(weighted_distance))
def euclidean_weighted_mean(X_set, weights = None, lr = 0.1, n_iter = 50, plot_loss_flag = False):

    if weights is None:
        weights = jnp.full((X_set.shape[0],), 1.0) / X_set.shape[0]

    # init mean with random element from set
    Y = X_set[np.random.randint(0, X_set.shape[0], (1,))][0] 

    if plot_loss_flag:
        plot_loss = []
        prev_loss = 0
        plato_iter = 0
        plato_reached = False

    for i in range(n_iter):

        # calculate loss
        loss = weighted_distance(Y, X_set, weights)

        if plot_loss_flag:
            if jnp.allclose(jnp.array(loss), jnp.array(prev_loss)):
                if not plato_reached:
                    plato_iter = i
                    plato_reached = True
            else:
                prev_loss = loss
                plato_reached = False

        Y_grad = grad(weighted_distance, argnums=0)(Y, X_set, weights)

        # calculate Riemannian gradient
        riem_grad_Y = Y_grad

        # update Y
        Y_step = Y - lr * riem_grad_Y

        # project new Y on manifold with retraction
        Y = Y_step

        if plot_loss_flag:
            # collect loss for plotting
            plot_loss.append(loss)

    if plot_loss_flag:
        print(f"Total loss: {weighted_distance(Y, X_set, weights)} got in {plato_iter} iterations")    
        fig, ax = plt.subplots()
        ax.plot(plot_loss)
        ax.set_xlabel("Iteration")
        ax.set_ylabel("Loss")
        plt.show()
    return Y

You can launch it like this:

d = 2
m = 4
X = jax.random.uniform(key, (m,d))
euclidean_weighted_mean(X, weights = None, lr = 1e-3, n_iter = 100, plot_loss_flag = True)

As you can see, I am calculating a weighted version of the mean, and that's where I use jaxopt. Let me define the global objective (just as an example): I want the weights to take the values that minimise the distance between the resulting mean and a desired point. In my case, I want the weights to influence the algorithm so that the resulting mean is as close to X[0] as possible:

def global_task_objective(w, X, target_point, lr, n_iter):
    x = euclidean_weighted_mean(X, w, lr = lr, n_iter = n_iter)
    loss = euclidean_distance(x, target_point)
    return loss, x

target_point = X[0]

w_init = jnp.array(np.random.randn(X.shape[0])) * jnp.square(2 / X.shape[0]) 

lr = 1e-3
n_iter = 100

global_task_objective(w_init, X, target_point, lr, n_iter)
solver = OptaxSolver(opt=optax.amsgrad(1e-2), fun=global_task_objective, has_aux=True)
state = solver.init_state(w_init, X=X, target_point=target_point, lr=lr, n_iter=n_iter)

The problem emerges when I call

w_init, state = solver.update(params=w_init, 
                             state=state, 
                             X=X, target_point=target_point, lr=lr, n_iter=n_iter)
(screenshot of the traceback: "too many positional arguments")

Meanwhile, the official Ridge regression example works perfectly. Any suggestions?

mblondel commented 1 year ago

When you use

@implicit_diff.custom_root(optimality_fun)
def fun(...):

fun and optimality_fun should have the same number of arguments, which is not the case here.

Your function euclidean_weighted_mean includes non-differentiable arguments like n_iter and plot_loss_flag. You need to remove them.

def make_euclidean_weighted_mean(lr = 0.1, n_iter = 50, plot_loss_flag = False):
  def euclidean_weighted_mean(x, X, weights):
    [...]
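
For completeness, here is a minimal sketch of how that factory could look, assuming the custom_root decorator moves onto the inner function, the plotting logic is dropped, and weights are always passed explicitly. The zero initialisation and the trimmed global_task_objective below are illustrative, not verbatim from this thread:

def make_euclidean_weighted_mean(lr=0.1, n_iter=50):

    @implicit_diff.custom_root(jax.grad(weighted_distance))
    def euclidean_weighted_mean(x, X, weights):
        # Plain gradient descent on the weighted distance; x serves as the initial guess.
        # The signature (x, X, weights) now matches jax.grad(weighted_distance).
        for _ in range(n_iter):
            x = x - lr * grad(weighted_distance, argnums=0)(x, X, weights)
        return x

    return euclidean_weighted_mean

# The outer objective then only receives differentiable arguments:
euclidean_weighted_mean = make_euclidean_weighted_mean(lr=1e-3, n_iter=100)

def global_task_objective(w, X, target_point):
    x = euclidean_weighted_mean(jnp.zeros(X.shape[1]), X, w)
    return euclidean_distance(x, target_point), x

solver = OptaxSolver(opt=optax.amsgrad(1e-2), fun=global_task_objective, has_aux=True)
state = solver.init_state(w_init, X=X, target_point=target_point)
w_next, state = solver.update(params=w_init, state=state, X=X, target_point=target_point)

Because custom_root supplies the backward pass through the optimality conditions, the Python loop inside the solver does not itself need to be differentiable.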