patrick-kidger / optimistix

Nonlinear optimisation (root-finding, least squares, ...) in JAX+Equinox. https://docs.kidger.site/optimistix/
Apache License 2.0
326 stars 14 forks

Support complex-to-real optimization #76

Open Randl opened 2 months ago

Randl commented 2 months ago

@patrick-kidger Specifically, the third point requires a design decision that I don't yet know how to realize. Should we add complex-to-real operator support in Lineax?

patrick-kidger commented 2 months ago

I think JAX has a sharply limited notion of complex linearity, to be honest. They're pretty up-front about the fact that they're really just treating things as functions between vector spaces over the reals.

Thinking out loud: linearity is the property that f(a + λb) = f(a) + λf(b). If f maps between vector spaces over different fields then this isn't defined, and you need to start reaching for some generalised notion of linearity.
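
To make that concrete, a quick toy example (mine, purely for illustration, and not specific to JAX): f(z) = Re(z) is additive and real-homogeneous, but complex homogeneity fails, so it is not C-linear.

```python
# Toy illustration: f(z) = Re(z) satisfies f(a + b) = f(a) + f(b) and
# f(c * a) = c * f(a) for real c, but not for complex c.
f = lambda z: z.real

z, w = 1.0 + 2.0j, 3.0 - 1.0j
assert f(z + w) == f(z) + f(w)      # additivity holds
assert f(2.0 * z) == 2.0 * f(z)     # real homogeneity holds
print(f(1j * z), 1j * f(z))         # -2.0 vs 1j: complex homogeneity fails
```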

We should follow JAX's lead on this one, I think, if we're going to interface well with the rest of its ecosystem.

Randl commented 2 months ago

You are correct that a linear map between vector spaces over different fields doesn't make sense, certainly when the field of the input has elements not contained in the field of the output.

If I understand correctly, jax.linearize is equivalent to the JVP, which, following the JAX docs, is linear in c and d separately but not in c + id (which amounts to the same thing for addition, but not for multiplication by a constant). This is why things don't fail in Lineax, at least until you try to materialize the operator.
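
As a concrete check (my own example, writing |z|² as z·conj(z) just to keep the derivative rules simple), the JVP of a complex-to-real function scales correctly under real factors in the tangent but not under multiplication by i:

```python
import jax
import jax.numpy as jnp

def f(z):
    return (z * jnp.conj(z)).real  # |z|^2 as a complex-to-real function

z = jnp.asarray(1.0 + 2.0j)
t = jnp.asarray(0.5 - 0.3j)

_, jvp_t = jax.jvp(f, (z,), (t,))          # JVP in direction t
_, jvp_2t = jax.jvp(f, (z,), (2.0 * t,))   # real scaling of the tangent
_, jvp_it = jax.jvp(f, (z,), (1j * t,))    # complex scaling of the tangent

print(jvp_2t, 2.0 * jvp_t)  # these agree: real-linearity holds
print(jvp_it, 1j * jvp_t)   # these do not: the left is real, the right is not
```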

I'm not familiar with the internals, but making it an R->R map under the hood makes the most sense to me.
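
A minimal sketch of that "R->R under the hood" idea, just splitting the input into real and imaginary parts by hand (the helper name as_real_function is made up, not an existing API):

```python
import jax
import jax.numpy as jnp

def as_real_function(f):
    """Wrap a C^N -> R^M function as a function of (real part, imaginary part)."""
    def g(x, y):
        return f(x + 1j * y)
    return g

def f(z):  # example C^2 -> R function
    return jnp.sum((z * jnp.conj(z)).real)

z = jnp.array([1.0 + 2.0j, 3.0 - 1.0j])
g = as_real_function(f)

# jax.jacfwd refuses non-holomorphic complex inputs, but is happy with the split:
d_real, d_imag = jax.jacfwd(g, argnums=(0, 1))(z.real, z.imag)
print(d_real, d_imag)  # derivatives w.r.t. the real and imaginary parts of z
```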

Randl commented 2 months ago

So, thinking out loud, the only place we are really affected is materialization. As long as we do not materialize the matrix, the JAX internals take care of everything for us.

When we want to materialize, we have a problem, since materializing a C->R operator is impossible (it is not a linear operator). Instead, for a C^N -> R^M function we can materialize the real M×2N Jacobian of the corresponding R^2N -> R^M map. It should then be the user's responsibility to perform the C -> R^2 transformation on the inputs.

However, if we want a uniform interface, I think we should allow representing the C->R operator as R^2->R for any operator class. One possible solution is to have a flag indicating that the inputs are complex even though the matrix is real, and then to perform the C->R^2 transformation internally.
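
Roughly something like the following sketch (the class and attribute names here are made up, not Lineax's actual API): the operator stores a real matrix but accepts complex inputs and splits them internally.

```python
import jax.numpy as jnp

class ComplexToRealMatrixOperator:
    """Rough sketch (hypothetical, not the Lineax API): a C^N -> R^M operator
    stored as a real matrix acting on the stacked real and imaginary parts."""

    def __init__(self, matrix, input_is_complex):
        self.matrix = matrix                    # real matrix with 2N columns
        self.input_is_complex = input_is_complex

    def mv(self, z):
        if self.input_is_complex:
            x = jnp.concatenate([z.real, z.imag])  # the C -> R^2 transformation
        else:
            x = z
        return self.matrix @ x
```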

So, the first step would be to make Lineax C->R compatible.

Then it should be possible to use those operators for optimization. What do you think?

patrick-kidger commented 2 months ago

I like the identification of materialization as being where things go wrong. I think that's plausibly the main problem or only problem for us.

What does jax.jacfwd(some_complex_to_real_function) do here? This is the native JAX equivalent. We should be able to do whatever they do.

Randl commented 2 months ago

JAX gives an error and suggests using jvp directly: https://github.com/google/jax/blob/b957f8baab287f1a0e1e880b885f89b1f4272b50/jax/_src/api.py#L846-L850

I'm not sure that's an option for Lineax solvers, and as such for Optimistix.
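
"Using jvp directly" would presumably look something like this sketch (my own example), pushing forward real and imaginary unit directions by hand:

```python
import jax
import jax.numpy as jnp

def f(z):
    return (z * jnp.conj(z)).real  # elementwise |z|^2, a C^2 -> R^2 function

z = jnp.array([1.0 + 2.0j, 3.0 - 1.0j])

# jax.jacfwd(f)(z) raises a TypeError pointing at jax.jvp; using jvp directly
# means pushing forward a real basis of tangent directions ourselves:
basis = jnp.eye(z.size, dtype=z.dtype)
d_real = jnp.stack([jax.jvp(f, (z,), (e,))[1] for e in basis])
d_imag = jnp.stack([jax.jvp(f, (z,), (1j * e,))[1] for e in basis])
print(d_real, d_imag)  # derivatives w.r.t. the real and imaginary parts of z
```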

NeilGirdhar commented 2 months ago

It would be really cool to have this in Optimistix!

> JAX gives an error and suggests using jvp directly:

That's funny, I think I actually added that to the error message in this pull request, which incidentally is about supporting heterogeneous pytrees (with complex and real values).

> - Add flag indicating that the R^2->R operator is in fact C->R

I just wonder why you need the flag? Wouldn't it be more ideal to support any pytree input? It might be worth taking a look at my pull request to see how I did it, and whether you can adapt my solution (or, more ideally, call into it somehow) in Optimistix?

Randl commented 2 months ago

I agree that support for arbitrary pytrees would be nice. I'm not sure what exactly you propose to call into, since we want an equivalent of jax.jacfwd, which doesn't support C->R in JAX. We probably do not need the flag explicitly, as the input structure has the required information about which parts are complex. What we need, then, is to convert the complex input into a real one during operator application (in case the operator is stored as a matrix) and during materialization of the Jacobian operator.
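
As a rough sketch of how "the input structure has the required information" could look in practice (the helpers split_complex and merge_complex are made up, not an existing API):

```python
import jax
import jax.numpy as jnp

def split_complex(tree):
    # Record complexness in the structure itself by splitting each complex
    # leaf into a (real, imag) pair.
    return jax.tree_util.tree_map(
        lambda x: (x.real, x.imag) if jnp.iscomplexobj(x) else x, tree
    )

def merge_complex(split_tree, original_tree):
    # Rebuild complex leaves from their (real, imag) pairs; the original tree
    # tells us which leaves were complex to begin with.
    return jax.tree_util.tree_map(
        lambda pair, orig: pair[0] + 1j * pair[1] if jnp.iscomplexobj(orig) else pair,
        split_tree,
        original_tree,
        is_leaf=lambda x: isinstance(x, tuple),
    )

params = {"a": jnp.array([1.0 + 2.0j]), "b": jnp.array([3.0])}
roundtrip = merge_complex(split_complex(params), params)  # recovers params
```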

NeilGirdhar commented 2 months ago

> We probably do not need the flag explicitly, as the input structure has the required information about which parts are complex.

Totally agree!

> What we need, then, is to convert the complex input into a real one during operator application (in case the operator is stored as a matrix) and during materialization of the Jacobian operator.

Maybe. Is it possible to fix jax.jacfwd? Then Patrick's wish that "We should be able to do whatever they do" would come true.

Randl commented 2 months ago

> Is it possible to fix jax.jacfwd?

It depends on your expectations of jacfwd. You can't possibly expect to get a matrix that you multiply a vector by to get the JVP (see the discussion of complex linearity above). We can make it output R^2 for each C input, but the bookkeeping of transforming the input is still on the user. I'm not sure JAX would want this change; I'd say it is controversial to expose to the end user (I'm more OK with it for internal use).

Randl commented 1 month ago

A related tutorial that discusses similar issues with the JVP of C->R functions:

https://arxiv.org/abs/2409.06752