MaximilianJHuber closed this issue 1 month ago.
Such compile times are an unfortunate facet of JAX when tracing very large computation graphs, like this one.
Probably the simplest thing to do would be to call `params, unflatten = jax.flatten_util.ravel_pytree(...)` to squash all your parameters together into one array prior to calling `least_squares`, and then `unflatten` them again inside your target function. That way the solver works with a single array rather than with one small array per dict entry.
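A minimal sketch of the flatten/unflatten pattern, assuming a hypothetical dict of 100 scalar parameters named `p0` to `p99` and a hypothetical linear model. The `least_squares` call from the thread is not reproduced here; as a stand-in, a hand-rolled Gauss-Newton step built from `jax.jacfwd` and `jnp.linalg.lstsq` plays the role of the solver.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical named parameters: 100 scalars stored in a dict.
n = 100
params_dict = {f"p{i}": jnp.asarray(0.0) for i in range(n)}

# Hypothetical data the model should reproduce.
targets = jnp.arange(n, dtype=jnp.float32)

# Squash the whole dict into one flat 1-D array and keep the inverse function.
flat0, unflatten = ravel_pytree(params_dict)

def residuals(flat):
    # Rebuild the named dict inside the target function...
    p = unflatten(flat)
    # ...then use the parameters by name, exactly as with the dict version.
    preds = jnp.stack([p[f"p{i}"] for i in range(n)])
    return preds - targets

@jax.jit
def gauss_newton_step(flat):
    # Stand-in solver step: the Jacobian is one dense (n, n) array
    # instead of a nested dict-of-dicts with n * n leaves.
    r = residuals(flat)
    J = jax.jacfwd(residuals)(flat)
    delta, *_ = jnp.linalg.lstsq(J, -r)
    return flat + delta

flat = flat0
for _ in range(3):
    flat = gauss_newton_step(flat)

# Convert the solution back into named parameters.
fitted = unflatten(flat)
print(float(fitted["p42"]))  # approximately 42.0
```

Note that `ravel_pytree` orders dict entries by sorted key (`p0`, `p1`, `p10`, ...), so positions in the flat array do not match the numeric suffixes; `unflatten` takes care of that, which is why the residual function should only ever read parameters through the rebuilt dict.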
Perfect, thank you!
I find it convenient to organize the parameters of a non-linear function in a `dict`, but for a large model that causes prohibitively long compile times.

MWE: (compiles for >1 minute; change `100` to `500` and it does not return.)

Is there a better way of organizing named parameters?