Hi @bolducke, thanks for reporting this. Do you have an example I can use as a repro? I plan to update the unit tests to cover that use case for both GaussNewton (GN) and LevenbergMarquardt (LM).
@amir-saadat I was going to report this as well. I wrote a super simple test case to demonstrate it, though I'm not sure it's what you're looking for.
```python
from jax import config
config.update("jax_enable_x64", True)

import jax.numpy as jnp
import jaxopt

M = 5
# The same parameters as a (2, M) array and as a dict pytree.
params = jnp.zeros((2, M))
params = params.at[0].set(jnp.arange(M) * 1.0)
params = params.at[1].set(jnp.arange(M) ** 2 * 1.0)
params_dict = {'A': jnp.arange(M) * 1.0, 'B': jnp.arange(M) ** 2 * 1.0}

def F(params):
    return jnp.asarray([jnp.sum(params[0]), jnp.sum(params[0] * params[1] ** 2)])

def F_dict(params):
    return jnp.asarray([jnp.sum(params['A']), jnp.sum(params['A'] * params['B'] ** 2)])

def optimize_F_gn(params, F):
    gn = jaxopt.GaussNewton(residual_fun=F)
    return gn.run(params).params

def optimize_F_lm(params, F):
    lm = jaxopt.LevenbergMarquardt(residual_fun=F)
    return lm.run(params).params

print(optimize_F_gn(params, F))
print(optimize_F_gn(params_dict, F_dict))
print(optimize_F_lm(params, F))            # fails
print(optimize_F_lm(params_dict, F_dict))  # fails
```
GaussNewton works as intended with pytrees, and I would expect the same from LevenbergMarquardt. Instead, I had to flatten the array to make it work properly.
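For reference, the flattening workaround can be written generically with `jax.flatten_util.ravel_pytree`. This is a minimal sketch, not part of jaxopt: `flatten_residual` is a hypothetical helper, and it assumes the solver only needs a flat 1-D parameter vector plus a residual function of that vector.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

jax.config.update("jax_enable_x64", True)

M = 5
params_dict = {'A': jnp.arange(M) * 1.0, 'B': jnp.arange(M) ** 2 * 1.0}


def F_dict(params):
    return jnp.asarray([jnp.sum(params['A']), jnp.sum(params['A'] * params['B'] ** 2)])


def flatten_residual(residual_fun, params):
    """Hypothetical helper: wrap a pytree residual so it takes a flat 1-D vector."""
    flat0, unravel = ravel_pytree(params)  # dict leaves concatenated in key order

    def flat_residual(x):
        # Rebuild the original pytree from the flat vector, then evaluate.
        return residual_fun(unravel(x))

    return flat_residual, flat0, unravel


flat_residual, flat0, unravel = flatten_residual(F_dict, params_dict)
print(flat0.shape)  # (10,)

# A solver that only handles flat arrays can now be run on (flat_residual, flat0),
# e.g. jaxopt.LevenbergMarquardt(residual_fun=flat_residual).run(flat0).params,
# and the result mapped back to the dict with unravel(...).
```

The same wrapper works for the `(2, M)` array case, since a single array is itself a one-leaf pytree; `unravel` restores the original shape.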
The error is raised at line 445, because the pytree structures of `params` and `vec` do not match.
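To illustrate the mismatch, here is a hypothetical pytree-aware damped step. `lm_step` is illustrative only, not jaxopt's implementation, and it assumes `vec` is the flat 1-D solution of the damped normal equations (J^T J + λI) vec = -J^T r. Adding such a flat `vec` directly to a dict `params` fails, while unraveling it first gives a step with the same structure. Only a single step is shown; the adaptive damping update of a full LM solver is omitted.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

jax.config.update("jax_enable_x64", True)

M = 5
params_dict = {'A': jnp.arange(M) * 1.0, 'B': jnp.arange(M) ** 2 * 1.0}


def F_dict(params):
    return jnp.asarray([jnp.sum(params['A']), jnp.sum(params['A'] * params['B'] ** 2)])


def cost(params):
    # Least-squares objective 0.5 * ||F(params)||^2.
    return 0.5 * jnp.sum(F_dict(params) ** 2)


def lm_step(residual_fun, params, damping=1e-3):
    """Hypothetical pytree-aware Levenberg-Marquardt step with fixed damping.

    Solves (J^T J + damping * I) vec = -J^T r on the flattened parameters,
    then unravels `vec` so it has the same pytree structure as `params`.
    """
    flat, unravel = ravel_pytree(params)

    def flat_fun(x):
        return residual_fun(unravel(x))

    r = flat_fun(flat)
    J = jax.jacfwd(flat_fun)(flat)  # shape (num_residuals, num_params)
    lhs = J.T @ J + damping * jnp.eye(flat.size)
    vec = jnp.linalg.solve(lhs, -J.T @ r)  # flat 1-D step
    return jax.tree_util.tree_map(jnp.add, params, unravel(vec))


# Combining a flat vector directly with the dict pytree fails: the structures differ.
flat, _ = ravel_pytree(params_dict)
mismatch_raised = False
try:
    jax.tree_util.tree_map(jnp.add, params_dict, jnp.zeros_like(flat))
except (ValueError, TypeError):
    mismatch_raised = True

new_params = lm_step(F_dict, params_dict)
print(mismatch_raised)  # True
```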