patrick-kidger / optimistix

Nonlinear optimisation (root-finding, least squares, ...) in JAX+Equinox. https://docs.kidger.site/optimistix/
Apache License 2.0

Slow compile of least_squares with large dict parameters #62

Closed MaximilianJHuber closed 1 month ago

MaximilianJHuber commented 1 month ago

I find it convenient to organize the parameters of a non-linear function in a dict, but for a large model this causes prohibitively long compile times.

MWE (compiles for >1 minute; change 100 to 500 and it does not return):

import jax
import numpy as np
import optimistix as optx

@jax.jit
def rosenbrock(p, args):
   return [1 - p['1'], 100 * (p['2'] - p['1']**2)]

optx.least_squares(
   rosenbrock, 
   optx.LevenbergMarquardt(rtol=1e-8, atol=1e-8), 
   {str(int(k)): 1.0 for k in np.arange(0,100)}
)

Is there a better way of organizing named parameters?

patrick-kidger commented 1 month ago

Such compile times are an unfortunate facet of JAX when describing very large computation graphs. (Like this one.)

Probably the simplest thing to do would be to call params, unflatten = jax.flatten_util.ravel_pytree(...) to squash all your parameters together into one array prior to calling least_squares, and then unflatten them again inside your target function.
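A minimal sketch of that approach, reusing the rosenbrock residuals and parameter dict from the MWE above (the names flat_rosenbrock, flat_params and fitted are just illustrative):

import jax.flatten_util
import numpy as np
import optimistix as optx

def rosenbrock(p, args):
    return [1 - p['1'], 100 * (p['2'] - p['1']**2)]

params = {str(int(k)): 1.0 for k in np.arange(0, 100)}

# Squash the parameter dict into a single flat array, once, before the solve.
flat_params, unflatten = jax.flatten_util.ravel_pytree(params)

def flat_rosenbrock(flat_p, args):
    # Rebuild the named dict inside the target function, then evaluate the residuals.
    return rosenbrock(unflatten(flat_p), args)

sol = optx.least_squares(
    flat_rosenbrock,
    optx.LevenbergMarquardt(rtol=1e-8, atol=1e-8),
    flat_params,
)

# Recover the named parameters from the flat solution.
fitted = unflatten(sol.value)

This way the solver only ever sees a single array rather than a pytree with one leaf per parameter.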

MaximilianJHuber commented 1 month ago

Perfect, thank you!