patrick-kidger / optimistix

Nonlinear optimisation (root-finding, least squares, ...) in JAX+Equinox. https://docs.kidger.site/optimistix/
Apache License 2.0

Support for nested root finders #86

Open djbower opened 1 month ago

djbower commented 1 month ago

I just wanted to confirm: does Optimistix support nested root finders? I'm getting a ValueError related to closure conversion:

  def __call__(self, x):
>       return self.fn(x, self.args)
E       ValueError: Closure-converted function called with different dynamic arguments to the example arguments provided.

If so, I'll spend more time trying to figure out the problem.

patrick-kidger commented 1 month ago

Optimistix does support this! It looks like you're running into a known issue that occurs when an old version of Equinox is used with a new version of JAX. Upgrading to the latest version of Equinox (by coincidence I've just released v0.11.8, so give it a try!) should resolve this.

djbower commented 1 month ago

Indeed, it was an easy fix: I'm running Equinox 0.11.8 and JAX 0.4.33, and everything works as expected now. Interestingly, with JAX 0.4.34 I ran into a different error.

patrick-kidger commented 1 month ago

Regarding JAX 0.4.34: I've just run into an issue with it myself; possibly it's the same one! I've got a possible fix for Optimistix in https://github.com/patrick-kidger/optimistix/pull/87 and one for Lineax in https://github.com/patrick-kidger/lineax/pull/111.

(For reference, Optimistix depends on both Lineax and Equinox, which is why I'm mentioning them here.)

Can you give those a try and see if they fix things for you? If so, great, and I'll do a release! If not, I'd love to have an MWE.
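Since the behaviour in this thread depends on the exact combination of JAX, Equinox and Lineax versions, it can help to include them alongside an MWE. Below is a small sketch that reports them using only the standard library's `importlib.metadata`. The helper name `installed_versions` and the package list are illustrative choices, not part of Optimistix.

```python
from importlib.metadata import PackageNotFoundError, version

# Packages relevant to this thread: Optimistix depends on both Lineax and Equinox,
# which in turn depend on JAX.
PACKAGES = ["jax", "jaxlib", "equinox", "lineax", "optimistix"]


def installed_versions(packages=PACKAGES):
    """Return a dict mapping each package name to its installed version, or None if absent."""
    out = {}
    for name in packages:
        try:
            out[name] = version(name)
        except PackageNotFoundError:
            out[name] = None
    return out


for name, ver in installed_versions().items():
    print(f"{name}: {ver}")
```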