patrick-kidger / optimistix

Nonlinear optimisation (root-finding, least squares, ...) in JAX+Equinox. https://docs.kidger.site/optimistix/
Apache License 2.0
334 stars, 14 forks

Optimization across multidimensional array #70

Open arunoruto opened 3 months ago

arunoruto commented 3 months ago

I only recently came across JAX and am now trying to use it for my implementation of the Hapke Anisotropic Multiple Scattering Approximation (AMSA) model. I opened a similar issue on jaxopt, but since that repository isn't going to be maintained much in the future, I gave optimistix a try! It seems to be faster than jaxopt and a bit faster than scipy:

# optimistix:
Inverse AMSA: Mean +- std dev: 984 ms +- 30 ms
# scipy:
Inverse AMSA: Mean +- std dev: 1.17 sec +- 0.03 sec
# LM:
Inverse AMSA: Mean +- std dev: 42.7 sec +- 0.5 sec

I am using optx.least_squares with LM in a for-loop which iterates over all pixels and tries to find the best fit to the target function. The structure is similar to the code provided here.

I was wondering if the current implementation would somehow allow me to pass a multidimensional array, or even a matrix, and optimize along an axis. Is there a trick maybe to achieve what I want?

Also, would it be possible to provide a derivative function for the target function myself? I am still impressed that I got such runtimes without providing one, but why have it derived automatically if I can supply it :)

arunoruto commented 3 months ago

I came across a vectorized Levenberg-Marquardt implementation, which sounds like exactly what I need!

Sadly I have only a little bit of experience in programming with JAX. I will try to naively implement this paper in Python, and then try to port it to Optimistix.

tl;dr: the paper introduces an operator G and a factor mu, so that the final LM objective becomes:

arg min ||f(p)-y||^2 + mu ||Gp||^2
     p

Gp returns a concatenation of the spatial gradients of each parameter map.

patrick-kidger commented 3 months ago

JAX has a great tool for this called jax.vmap. First describe a single optimization problem you want to solve, and then vmap it over as many extra dimensions as you like. JAX and Optimistix will autovectorize the whole operation.

arunoruto commented 3 months ago

Thanks for the tip! I was reading about vmaps this morning, so I was able to sketch a small PoC and it seems to work! I am not sure if it is the right way to do it, so I wanted to ask if I am on the right track:

And btw, have I missed something in the documentation regarding vmap and optimistix? I couldn't find anything regarding automatic vectorization. Or is it something fairly common with JAX and is implicitly known?

patrick-kidger commented 3 months ago

> My current amsa function works for scalars, vectors and images/matrices. Would you say it would be better to cater it to scalars and then apply vmap to obtain the vector and matrix versions? What is the JAX way of doing it?

Yes, exactly! Usually better to let vmap handle the rank-polymorphism for you.

> I used vmap twice:

This looks reasonable. FWIW you can simplify your handling of the vmaps a bit: pack your arguments into batched and nonbatched groups and you can just do jax.vmap(fn, in_axes=(0, None))((to_batch1, to_batch2, to_batch3, ...), (nobatch1, nobatch2, ...)). All that's going on there is that in_axes should be a pytree-prefix of the arguments.
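
As a rough illustration of the pytree-prefix idea, here is a small sketch with a made-up per-pixel function. The names `fn`, `spectrum`, `incidence`, `emission`, `weights` and `scale` are hypothetical placeholders, not anything from the AMSA code.

```python
import jax
import jax.numpy as jnp

# Hypothetical per-pixel function: `batched` holds per-pixel values,
# `shared` holds values that are identical for every pixel.
def fn(batched, shared):
    spectrum, incidence, emission = batched
    weights, scale = shared
    return scale * jnp.dot(spectrum, weights) + incidence - emission

n_pixels, n_bands = 4, 3
spectrum = jnp.arange(n_pixels * n_bands, dtype=jnp.float32).reshape(n_pixels, n_bands)
incidence = jnp.linspace(0.0, 1.0, n_pixels)
emission = jnp.linspace(1.0, 2.0, n_pixels)
weights = jnp.array([0.2, 0.3, 0.5])
scale = 2.0

# in_axes=(0, None) is a prefix of the argument pytree: the 0 applies to
# every leaf of the first tuple, the None to every leaf of the second.
out = jax.vmap(fn, in_axes=(0, None))(
    (spectrum, incidence, emission), (weights, scale)
)

# The same computation with one in_axes entry per argument, for comparison.
def fn_flat(spectrum, incidence, emission, weights, scale):
    return fn((spectrum, incidence, emission), (weights, scale))

out_flat = jax.vmap(fn_flat, in_axes=(0, 0, 0, None, None))(
    spectrum, incidence, emission, weights, scale
)
```

Both calls should give one value per pixel; the grouped form just avoids repeating a `0` or `None` for every argument.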

You might also like equinox.filter_vmap if that's ever helpful.

> I am getting an out of memory error when using the code on GPU, which doesn't happen with scipy.

If I had to guess, SciPy is probably doing its optimizer-internal operations on the CPU -- only evaluating your user-provided vector field on the GPU. So probably it's using less memory anyway.

> Or is it something fairly common with JAX and is implicitly known?

Yup, common with JAX!