google / jaxopt

Hardware accelerated, batchable and differentiable optimizers in JAX.
https://jaxopt.github.io
Apache License 2.0

LevenbergMarquardt does not seem to work with non-flat input. #505

Open bolducke opened 1 year ago

bolducke commented 1 year ago

GaussNewton works as intended with a pytree input. I would expect the same for LevenbergMarquardt. Instead, I had to flatten the array to make it work properly.

[screenshot of the error]

The error appears at line 445. It comes from the fact that the pytree structures of params and vec do not match.
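As an illustration of the kind of mismatch (the objects below are made up for the example, not taken from the traceback): a non-flat parameter pytree and a flat vector have different tree structures, so tree-wise operations that zip them together fail.

import jax
import jax.numpy as jnp

params = {'A': jnp.zeros(3), 'B': jnp.zeros(3)}  # non-flat pytree of parameters
vec = jnp.zeros(6)                               # a flat vector

# The two objects have different pytree structures, so any tree_map over
# both of them raises a structure-mismatch error.
print(jax.tree_util.tree_structure(params))
print(jax.tree_util.tree_structure(vec))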

amir-saadat commented 1 year ago

Hi @bolducke, thanks for reporting this. Do you have an example I can use to reproduce the issue? I plan to update the unit tests to cover that use case for both GN and LM.

nickmcgreivy commented 10 months ago

@amir-saadat I was going to report this as well. I wrote a very simple test case to demonstrate, though I'm not sure it's what you're looking for.

from jax import config
config.update("jax_enable_x64", True)

import jax.numpy as jnp
import jaxopt

M = 5
# The same parameters twice: once as a (2, M) array, once as a dict pytree.
params = jnp.zeros((2, M))
params = params.at[0].set(jnp.arange(M) * 1.0)
params = params.at[1].set(jnp.arange(M)**2 * 1.0)
params_dict = {'A': jnp.arange(M) * 1.0, 'B': jnp.arange(M)**2 * 1.0}

def F(params):
    return jnp.asarray([jnp.sum(params[0]), jnp.sum(params[0] * params[1]**2)])

def F_dict(params):
    return jnp.asarray([jnp.sum(params['A']), jnp.sum(params['A'] * params['B']**2)])

def optimize_F_gn(params, F):
    gn = jaxopt.GaussNewton(residual_fun=F)
    return gn.run(params).params

def optimize_F_lm(params, F):
    lm = jaxopt.LevenbergMarquardt(residual_fun=F)
    return lm.run(params).params

print(params)
print(optimize_F_gn(params, F))            # works: (2, M) array input
print(optimize_F_gn(params_dict, F_dict))  # works: dict pytree input
print(optimize_F_lm(params, F))            # fails: non-flat (2, M) array input
print(optimize_F_lm(params_dict, F_dict))  # fails: dict pytree input
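
For reference, a minimal sketch of the flattening workaround mentioned in the original report, applied to this repro. It assumes jax.flatten_util.ravel_pytree; the helper optimize_F_lm_flat is illustrative and not part of jaxopt.

from jax.flatten_util import ravel_pytree

def optimize_F_lm_flat(params, F):
    # Flatten the parameter pytree into a single 1D vector, run LM on the
    # flat vector, then restore the original structure for the solution.
    flat_params, unravel = ravel_pytree(params)
    lm = jaxopt.LevenbergMarquardt(residual_fun=lambda p: F(unravel(p)))
    return unravel(lm.run(flat_params).params)

print(optimize_F_lm_flat(params, F))
print(optimize_F_lm_flat(params_dict, F_dict))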