arunoruto opened 3 months ago
I came across a paper describing a vectorized Levenberg-Marquardt implementation, which sounds like exactly what I need! Sadly, I have only a little experience programming with JAX. I will try to naively implement the paper in Python first, and then port it to Optimistix.
tl;dr: the paper introduces an operator $G$ and a factor $\mu$, with the final LM problem becoming

$$\operatorname*{arg\,min}_{p}\;\lVert f(p) - y \rVert^2 + \mu\,\lVert G p \rVert^2,$$

where $Gp$ returns a concatenation of the spatial gradients of each parameter map.
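My rough plan for feeding this to an ordinary least-squares/LM solver (a sketch only; `forward_model` and `spatial_gradients` are placeholders for $f$ and $G$) is to stack the data residuals and the scaled gradient term into one residual vector:

```python
import jax.numpy as jnp

def forward_model(p):
    return p  # placeholder for f(p)

def spatial_gradients(p):
    # Placeholder for G: finite differences along both image axes.
    return jnp.concatenate([jnp.diff(p, axis=0).ravel(),
                            jnp.diff(p, axis=1).ravel()])

def augmented_residuals(p, args):
    # ||f(p) - y||^2 + mu * ||G p||^2 == ||[f(p) - y, sqrt(mu) * G p]||^2,
    # so the regularised problem becomes a standard least-squares residual.
    y, mu = args
    data_residual = (forward_model(p) - y).ravel()
    regulariser = jnp.sqrt(mu) * spatial_gradients(p)
    return jnp.concatenate([data_residual, regulariser])
```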
JAX has a great tool for this called `jax.vmap`. First describe a single optimization problem you want to solve, and then vmap it over as many extra dimensions as you like. JAX and Optimistix will autovectorize the whole operation.
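For instance, something along these lines (a made-up exponential-decay fit, purely to illustrate the pattern; only `optx.LevenbergMarquardt` and `optx.least_squares` are the real API):

```python
import jax
import jax.numpy as jnp
import optimistix as optx

# Toy residual function: fit y = a * exp(-b * t) for a single "pixel".
def residuals(params, args):
    a, b = params
    t, data = args
    return a * jnp.exp(-b * t) - data

solver = optx.LevenbergMarquardt(rtol=1e-6, atol=1e-6)

def solve_single(data, t, p0):
    sol = optx.least_squares(residuals, solver, p0, args=(t, data))
    return sol.value

# Batch over the leading "pixel" axis of `data` and `p0`; `t` is shared.
solve_batched = jax.vmap(solve_single, in_axes=(0, None, 0))

t = jnp.linspace(0.0, 1.0, 50)
data = jnp.broadcast_to(2.0 * jnp.exp(-3.0 * t), (100, 50))  # 100 identical "pixels"
p0 = jnp.ones((100, 2))
fitted = solve_batched(data, t, p0)  # shape (100, 2), roughly (2.0, 3.0) per pixel
```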
Thanks for the tip! I was reading about `vmap`s this morning, so I was able to sketch a small PoC, and it seems to work! I am not sure if it is the right way to do it, so I wanted to ask if I am on the right track:
My current `amsa` function works for scalars, vectors and images/matrices. Would you say it would be better to cater it to scalars and then apply `vmap` to obtain the vector and matrix versions? What is the JAX way of doing it?

My function accepts a lot of parameters:
```python
def amsa(
    single_scattering_albedo: ArrayLike,  # [a, b]
    incidence_direction: ArrayLike,       # [a, b, 3]
    emission_direction: ArrayLike,        # [a, b, 3]
    surface_orientation: ArrayLike,       # [a, b, 3]
    phase_function: dict,
    b_n: ArrayLike,
    a_n: ArrayLike = np.nan,
    roughness: float = 0,
    hs: float = 0,
    bs0: float = 0,
    hc: float = 0,
    bc0: float = 0,
    refl_optimization: ArrayLike = 0.0,   # [a, b]
):
    ...
```
I used `vmap` twice:
```python
amsa1 = jax.vmap(
    amsa,
    in_axes=(0, 0, 0, 0, None, None, None, None, None, None, None, None, 0),
    out_axes=0,
)
amsa2 = jax.vmap(
    amsa1,
    in_axes=(1, 1, 1, 1, None, None, None, None, None, None, None, None, 1),
    out_axes=1,
)


def amsa_optx_wrapper(x, args):
    return amsa2(x, *args)
```
Is this correct? Or is there a better way? This question is kind of related to the one above.
And by the way, have I missed something in the documentation regarding `vmap` and `optimistix`? I couldn't find anything about automatic vectorization. Or is it something fairly common with JAX that's implicitly assumed?
> My current `amsa` function works for scalars, vectors and images/matrices. Would you say it would be better to cater it to scalars and then apply `vmap` to obtain the vector and matrix versions? What is the JAX way of doing it?
Yes, exactly! Usually better to let `vmap` handle the rank-polymorphism for you.
> I used `vmap` twice:
This looks reasonable. FWIW you can simplify your handling of the vmaps a bit: pack your arguments into batched and non-batched groups, and then you can just do `jax.vmap(fn, in_axes=(0, None))((to_batch1, to_batch2, to_batch3, ...), (nobatch1, nobatch2, ...))`. All that's going on there is that `in_axes` should be a pytree-prefix of the arguments.
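For instance (a sketch only; the argument names are copied from your signature above and the grouping just mirrors your current `in_axes`):

```python
import jax

def amsa_grouped(batched, nonbatched):
    # Unpack the two groups and call your existing `amsa` in its original order.
    albedo, incidence, emission, surface, refl = batched
    phase_function, b_n, a_n, roughness, hs, bs0, hc, bc0 = nonbatched
    return amsa(
        albedo, incidence, emission, surface,
        phase_function, b_n, a_n, roughness, hs, bs0, hc, bc0, refl,
    )

# Map over one pixel axis of every array in the first group; broadcast the second.
amsa_batched = jax.vmap(amsa_grouped, in_axes=(0, None))
```

(You'd still nest this twice for the two pixel axes, just as before.)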
You might also like `equinox.filter_vmap` if that's ever helpful.
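A tiny illustration of its default behaviour (a toy function, not your `amsa`): array arguments are mapped over axis 0 and non-array arguments pass through unbatched, so no long `in_axes` tuple is needed:

```python
import equinox as eqx
import jax.numpy as jnp

def scaled_norm(x, scale: float):
    return scale * jnp.sum(x**2)

# By default filter_vmap batches every array argument over axis 0 and leaves
# non-array arguments (like the Python float `scale`) unbatched.
batched_norm = eqx.filter_vmap(scaled_norm)
out = batched_norm(jnp.ones((8, 3)), 2.0)  # shape (8,)
```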
> I am getting an out-of-memory error when using the code on the GPU, which doesn't happen with SciPy.
If I had to guess, SciPy is probably doing its optimizer-internal operations on the CPU -- only evaluating your user-provided vector field on the GPU. So probably it's using less memory anyway.
> Or is it something fairly common with JAX that's implicitly assumed?
Yup, common with JAX!
I just recently came across JAX and I am now trying to use it for my implementation of the Hapke Anisotropic Multiple Scattering Approximation (AMSA) model. I made a similar issue on jaxopt, but since that repository isn't going to be maintained much in the future, I gave Optimistix a try! It seems to be faster than jaxopt and a bit faster than scipy.
I am using `optx.least_squares` with LM in a for loop that iterates over all pixels and tries to find the best fit to the target function. The structure is similar to the code provided here.
I was wondering if the current implementation would somehow allow me to pass a multidimensional array, or even a matrix, and optimize along an axis. Is there maybe a trick to achieve what I want?
Also, would it be possible to provide a derivative function of the target function? I am still impressed I got such runtimes without providing it, but why have it derived automatically if I can provide it :)
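From what I've read so far, I think the usual JAX way (not Optimistix-specific) would be to attach a `jax.custom_jvp` to the model function, so that autodiff picks up a hand-written derivative; a minimal sketch with a placeholder model:

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def model(p):
    return jnp.exp(-p)  # placeholder forward model

@model.defjvp
def model_jvp(primals, tangents):
    (p,), (dp,) = primals, tangents
    out = jnp.exp(-p)
    return out, -out * dp  # hand-written derivative instead of traced autodiff
```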