Open Randl opened 2 months ago
@patrick-kidger specifically, the third point requires a design decision that I don't yet know how to realize. Should Lineax support complex-to-real operators?
I think JAX has a sharply limited notion of complex linearity, to be honest. They're pretty up-front about the fact that they're really just treating things as functions between vector spaces over the reals.
Thinking out loud: linearity is the property that `f(a + λb) = f(a) + λf(b)`. If `f` maps between vector spaces over different fields then this isn't defined, and you need to start reaching for some generalised notion of linearity.
We should follow JAX's lead on this one, I think, if we're going to interface well with the rest of its ecosystem.
You are correct that a linear map between vector spaces over different fields doesn't make sense, certainly when the field of the input has elements not contained in the field of the output.
If I understand correctly, `jax.linearize` is equivalent to a JVP, which, following the JAX docs, is linear in `c` and `d` separately but not in `c+id` (the two notions agree for addition but not for multiplication by a complex constant). This is why stuff doesn't fail in Lineax, at least until you try to materialize it.
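To make the "linear over the reals but not over the complexes" point concrete, here is a minimal sketch. The function `f` and the helper `push` are made up for illustration; only `jax.jvp` is the real API. It checks additivity, scaling by a real constant, and scaling by `1j` for the JVP of a `C^2->R` function.

```python
import jax
import jax.numpy as jnp


def f(z):
    # Illustrative C^2 -> R function: the squared Euclidean norm of z.
    return jnp.sum(jnp.real(z * jnp.conj(z)))


z = jnp.array([1.0 + 2.0j, 3.0 - 1.0j])
t1 = jnp.array([1.0 + 1.0j, 2.0 - 1.0j])
t2 = jnp.array([0.5 - 2.0j, -1.0 + 0.0j])


def push(t):
    # Hypothetical helper: the tangent output of the JVP of f at z along t.
    return jax.jvp(f, (z,), (t,))[1]


# Additivity holds.
additive = jnp.allclose(push(t1 + t2), push(t1) + push(t2))
# Scaling by a real constant holds.
real_homogeneous = jnp.allclose(push(2.0 * t1), 2.0 * push(t1))
# Scaling by a complex constant cannot hold: the output tangent is real.
complex_homogeneous = jnp.allclose(push(1j * t1), 1j * push(t1))
```

Under these assumptions the first two checks come out true and the last one false, which is the sense in which the JVP is a map between real vector spaces.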
I'm not familiar with the internals, but making it an `R->R` map under the hood makes the most sense to me.
So, thinking out loud, the only real place we are affected is materialization. As long as we do not materialize the matrix, JAX internals take care of everything for us.
When we want to materialize, we have a problem, since materializing a `C->R` operator is impossible (it's not a linear operator). Instead, we can materialize the `2NxM` matrix for the `C^N->R^M` Jacobian. It should then be the responsibility of the user to apply the `C->R^2` transformations to the inputs.
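A rough sketch of that idea, not Lineax code: view the complex input as a real vector of length `2N`, materialize the real Jacobian with `jax.jacfwd`, and compare the matrix-vector product with the JVP. The names `f`, `f_real` and `to_real` are hypothetical, and `jacfwd` returns the matrix with shape `(M, 2N)`, i.e. the transpose of the `2NxM` layout mentioned above.

```python
import jax
import jax.numpy as jnp

N = 3  # number of complex inputs


def f(z):
    # Illustrative C^3 -> R^2 function.
    return jnp.stack([
        jnp.sum(jnp.real(z * jnp.conj(z))),
        jnp.real(z[0] * z[1]) + jnp.imag(z[2]),
    ])


def to_real(v):
    # C^N -> R^(2N): real parts first, then imaginary parts.
    return jnp.concatenate([jnp.real(v), jnp.imag(v)])


def f_real(xy):
    # Same function, seen as R^(2N) -> R^M.
    return f(xy[:N] + 1j * xy[N:])


z = jnp.array([1.0 + 2.0j, 3.0 - 1.0j, 0.5 + 0.5j])
t = jnp.array([1.0 + 1.0j, 2.0 - 1.0j, -1.0 + 3.0j])

# Real Jacobian of shape (M, 2N); jacfwd accepts it because the input is real.
J = jax.jacfwd(f_real)(to_real(z))

# The JVP on the complex input agrees with the real matrix-vector product,
# provided the tangent goes through the same C -> R^2 transformation.
_, jvp_out = jax.jvp(f, (z,), (t,))
matvec_out = J @ to_real(t)
```

The bookkeeping (`to_real`) is exactly the part that would fall on the user unless the operator does it internally.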
However, if we want to have a uniform interface, I think we should allow representing the `C->R` operator as `R^2->R` for any operator class. One possible solution is to have a flag that indicates that inputs are complex even though the matrix is real, and then internally perform the `C->R^2` transformations.
So, the first step would be to make Lineax `C->R` compatible:
- [...] `R^2->R` operators
- Add flag indicating that the `R^2->R` operator is in fact `C->R`
- [...] `C` vector)

Then, it should be possible to use those for optimization. What do you think?
I like the identification of materialization as being where things go wrong. I think that's plausibly the main problem or only problem for us.
What does `jax.jacfwd(some_complex_to_real_function)` do here? This is the native JAX equivalent. We should be able to do whatever they do.
JAX gives an error and suggests using `jvp` directly:
https://github.com/google/jax/blob/b957f8baab287f1a0e1e880b885f89b1f4272b50/jax/_src/api.py#L846-L850
I'm not sure that's an option for Lineax solvers, and as such for Optimistix.
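For reference, a small sketch of the behaviour described above, assuming a recent JAX release. The function `f` is illustrative. `jax.jacfwd` rejects the complex input with a `TypeError`, while `jax.jvp` on the same function works.

```python
import jax
import jax.numpy as jnp


def f(z):
    # Illustrative C^2 -> R function.
    return jnp.sum(jnp.real(z * jnp.conj(z)))


z = jnp.array([1.0 + 2.0j, 3.0 - 1.0j])

# jacfwd without holomorphic=True requires real-valued inputs.
try:
    jax.jacfwd(f)(z)
    message = None
except TypeError as err:
    message = str(err)

# The JVP is still available; the tangent must share the dtype of the input.
_, tangent_out = jax.jvp(f, (z,), (jnp.ones_like(z),))
```

Passing `holomorphic=True` does not help here either, since `f` is not holomorphic and its output is real.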
It would be really cool to have this in Optimistix!
> JAX gives an error and suggests using `jvp` directly:
That's funny, I think I actually added that to the error message in this pull request, which incidentally is about supporting heterogeneous pytrees (with complex and real values).
> - Add flag indicating that the `R^2->R` operator is in fact `C->R`
I just wonder why you need the flag? Wouldn't it be better to support any pytree input? It might be worth taking a look at my pull request to see how I did it, and whether you can adapt my solution (or, more ideally, call into it somehow) for Optimistix.
I agree that support for arbitrary pytrees would be nice.
I'm not sure what exactly you propose to call, since we want an equivalent of `jacfwd`, which doesn't support `C->R` in JAX.
We probably do not need the flag explicitly, as the input structure already carries the information about which parts are complex.
What we need then is to convert the complex input into a real one, both when applying an operator stored as a matrix and when materializing a Jacobian operator.
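One way this could look, purely as a sketch and not as the Lineax or Optimistix API: the helpers `complex_to_real` and `apply_real_matrix` are hypothetical. Complex leaves are detected from the input pytree itself, split into real and imaginary parts, and the flattened real vector is multiplied by a real matrix.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def complex_to_real(tree):
    # Hypothetical helper: replace every complex leaf by a (real, imag) pair
    # and leave real leaves untouched. No flag is needed; the dtype of each
    # leaf says which parts are complex.
    def split(leaf):
        leaf = jnp.asarray(leaf)
        if jnp.iscomplexobj(leaf):
            return (jnp.real(leaf), jnp.imag(leaf))
        return leaf

    return jax.tree_util.tree_map(split, tree)


def apply_real_matrix(matrix, tree):
    # Hypothetical helper: apply a real matrix to a possibly complex pytree
    # by flattening its real representation into a single vector.
    flat, _ = ravel_pytree(complex_to_real(tree))
    return matrix @ flat


# Heterogeneous input: one complex leaf (2 entries) and one real leaf (1 entry).
x = {"a": jnp.array([1.0 + 2.0j, 3.0 + 4.0j]), "b": jnp.array([5.0])}

# Real representation has 2 * 2 + 1 = 5 entries.
flat, _ = ravel_pytree(complex_to_real(x))

matrix = jnp.arange(10.0).reshape(2, 5)  # a real 2x5 operator
out = apply_real_matrix(matrix, x)
```

The ordering of the flattened vector (real parts, then imaginary parts, per leaf) is a choice made for this sketch; whatever layout is picked has to match the one used when the matrix was materialized.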
> We probably do not need the flag explicitly, as the input structure already carries the information about which parts are complex.
Totally agree!
> What we need then is to convert the complex input into a real one, both when applying an operator stored as a matrix and when materializing a Jacobian operator.
Maybe. Is it possible to fix `jax.jacfwd`? Then Patrick's wish, "We should be able to do whatever they do", would come true?
> Is it possible to fix `jax.jacfwd`?
Depends on your expectations of `jacfwd`. You can't possibly expect to get a matrix that you multiply the vector by to get the JVP (see the discussion of complex linearity above). We can make it output R^2 for each C input, but the bookkeeping of transforming the input is still on the user. I'm not sure JAX would want this change; I'd say it is controversial to expose to the end user (I'm more OK with it for internal use).
Related tutorial that discusses similar issues of JVP of C->R functions: