patrick-kidger / lineax

Linear solvers in JAX and Equinox. https://docs.kidger.site/lineax
Apache License 2.0

Add conj function and adjoint tests #62

Closed: Randl closed this PR 1 year ago

Randl commented 1 year ago

Add adjoint property, which is useful for complex-valued operators. See https://github.com/google/lineax/issues/57

patrick-kidger commented 1 year ago

What's the use-case for adding this? The only reason we have a transpose function is to handle the transposition necessary to synthesise the backward (VJP) pass. In particular, JAX performs only transposition (in contrast to PyTorch, which performs conjugate-transposition), so I don't think we actually need this for any internal operation.
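A small plain-JAX illustration (not Lineax code) of the point that JAX's transposition does not conjugate: `jax.linear_transpose` of the complex linear map x -> A x applies Aᵀ, not Aᴴ. The matrix and vectors are arbitrary illustrative values.

```python
import jax
import jax.numpy as jnp

# Arbitrary complex matrix chosen for illustration.
A = jnp.array([[1 + 2j, 3 - 1j], [0.5j, 2 + 0j]], dtype=jnp.complex64)

def f(x):
    # The linear map x -> A x on complex vectors.
    return A @ x

# linear_transpose only uses the example primal for its shape and dtype.
x_example = jnp.zeros(2, dtype=jnp.complex64)
f_transpose = jax.linear_transpose(f, x_example)

y = jnp.array([1 + 0j, 1j], dtype=jnp.complex64)
(transposed,) = f_transpose(y)  # returns a tuple, one entry per primal

plain = A.T @ y              # plain transpose
hermitian = A.conj().T @ y   # conjugate transpose (adjoint)
# `transposed` matches `plain`, not `hermitian`.
```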

Randl commented 1 year ago

Specifically for normal operators, e.g., in the CG solver.

patrick-kidger commented 1 year ago

Hmm. I think we can probably handle that without adding a new API. Rewriting this line: https://github.com/google/lineax/blob/9a573529f3eea2a6944009f6e4afe4f364366baf/lineax/_solver/cg.py#L120 to be ω(_transpose_mv(ω(_mv(vector)).call(jnp.conj).ω)).call(jnp.conj).ω should work, I think. (This is morally just jnp.conj(_transpose_mv(jnp.conj(_mv(vector)))); the ω is just a neat syntax we have for tree-mapping an operation. Check my math, but I think conj(A conj(b)) = conj(A) b, so with A replaced by Aᵀ and b = A v this gives conj(Aᵀ) A v = Aᴴ A v.)
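A numerical sanity check of the conjugation trick, as a sketch in plain JAX. `_mv` and `_transpose_mv` here are hypothetical stand-ins for the CG solver's internal helpers (a dense complex matrix plays the operator), and `conj_tree` uses `jax.tree_util.tree_map` to do roughly what the `ω(...).call(jnp.conj).ω` spelling does.

```python
import jax
import jax.numpy as jnp

# Arbitrary complex matrix standing in for the operator.
A = jnp.array([[1 + 2j, 3 - 1j], [0.5j, 2 + 0j]], dtype=jnp.complex64)

def _mv(v):
    # Hypothetical stand-in for operator.mv.
    return A @ v

# Hypothetical stand-in for _transpose_mv: JAX's own (unconjugated) transpose of _mv.
_transpose_of_mv = jax.linear_transpose(_mv, jnp.zeros(2, dtype=jnp.complex64))

def _transpose_mv(v):
    return _transpose_of_mv(v)[0]

def conj_tree(tree):
    # Apply jnp.conj to every leaf of a pytree.
    return jax.tree_util.tree_map(jnp.conj, tree)

def normal_mv(vector):
    # conj(A^T conj(A v)) = conj(A^T) (A v) = A^H A v
    return conj_tree(_transpose_mv(conj_tree(_mv(vector))))

v = jnp.array([1 - 1j, 2 + 0.5j], dtype=jnp.complex64)
expected = A.conj().T @ (A @ v)   # A^H A v
naive = _transpose_mv(_mv(v))     # A^T A v, which is wrong for complex A
```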

Mostly I'm trying to avoid adding more to the public API. We'd also then need to register rules for every downstream operator.

Randl commented 1 year ago

I see the reasoning, but the code currently suggests that users should apply the adjoint themselves, at least in some cases:

https://github.com/google/lineax/blob/9a573529f3eea2a6944009f6e4afe4f364366baf/lineax/_solver/cg.py#L107-L112

Also, without an adjoint, bilinear forms are out of the question, if that's relevant.
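To make the point about forms concrete: under the standard complex inner product ⟨u, w⟩ = uᴴ w, the adjoint is characterised by ⟨u, A v⟩ = ⟨Aᴴ u, v⟩, which a plain transpose does not satisfy in general. A small numerical check in plain JAX, with arbitrary illustrative values:

```python
import jax.numpy as jnp

# Arbitrary complex matrix and vectors chosen for illustration.
A = jnp.array([[1 + 2j, 3 - 1j], [0.5j, 2 + 0j]], dtype=jnp.complex64)
u = jnp.array([1 + 0j, 1j], dtype=jnp.complex64)
v = jnp.array([1 - 1j, 2 + 0.5j], dtype=jnp.complex64)

A_H = A.conj().T  # adjoint (conjugate transpose)

# jnp.vdot conjugates its first argument, so vdot(a, b) = a^H b.
lhs = jnp.vdot(u, A @ v)    # <u, A v>
rhs = jnp.vdot(A_H @ u, v)  # <A^H u, v>
```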

patrick-kidger commented 1 year ago

Hmm. Okay, I think I buy your argument. In that case I'd suggest adding a lineax.conj function, using single-dispatch in the same way as lineax.{linearise, ...}. This means that we'll be compatible with any other existing AbstractLinearOperator subclasses out there (no new abstract method), and also means we don't need to worry about the complexities of doing the full adjoint in one go: we can decompose it into conjugation and transposition, which should be simpler to reason about.
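A minimal sketch of the single-dispatch pattern being proposed, assuming toy operator classes rather than Lineax's real ones: `ToyOperator`, `ToyMatrixOperator`, `ToyDiagonalOperator` and `adjoint` are all hypothetical names for illustration. Each operator type registers its own `conj` rule without any new abstract method on the base class, and the adjoint is then composed from conjugation followed by transposition.

```python
import functools as ft
from dataclasses import dataclass

import jax.numpy as jnp


class ToyOperator:
    # Hypothetical base class standing in for AbstractLinearOperator.
    pass


@dataclass
class ToyMatrixOperator(ToyOperator):
    matrix: jnp.ndarray

    def mv(self, v):
        return self.matrix @ v

    def transpose(self):
        return ToyMatrixOperator(self.matrix.T)


@dataclass
class ToyDiagonalOperator(ToyOperator):
    diagonal: jnp.ndarray

    def mv(self, v):
        return self.diagonal * v

    def transpose(self):
        return self  # a diagonal operator equals its own transpose


@ft.singledispatch
def conj(operator: ToyOperator) -> ToyOperator:
    # Fallback for operator types that have no registered rule.
    raise NotImplementedError(f"conj not implemented for {type(operator).__name__}")


@conj.register(ToyMatrixOperator)
def _(operator):
    return ToyMatrixOperator(jnp.conj(operator.matrix))


@conj.register(ToyDiagonalOperator)
def _(operator):
    return ToyDiagonalOperator(jnp.conj(operator.diagonal))


def adjoint(operator: ToyOperator) -> ToyOperator:
    # Adjoint decomposed as conjugation followed by transposition.
    return conj(operator).transpose()


# Usage: the adjoint of a matrix operator applies A^H.
A = jnp.array([[1 + 2j, 3 - 1j], [0.5j, 2 + 0j]], dtype=jnp.complex64)
v = jnp.array([1 - 1j, 2 + 0.5j], dtype=jnp.complex64)
result = adjoint(ToyMatrixOperator(A)).mv(v)
```

One design consequence of this pattern: a third-party operator type that nobody has registered still works everywhere else, and only fails (with `NotImplementedError`) if `conj` is actually called on it.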

patrick-kidger commented 1 year ago

By the way, let me know what your plan is regarding the various PRs. I think what might be easiest is to finish all of them before merging any of them, and then we can merge them in whatever order you think is best.

(Some good news: I have no changes planned for Lineax at the moment, so you shouldn't need to worry about keeping any of them up to date with any other changes.)

Randl commented 1 year ago

From my perspective, these 5 PRs are ready to merge (up to a few minor unresolved discussions). Together they don't fully implement complex support, but it would be simpler for me to continue once these are merged, so I can be sure this part already works.

As for the merge order, I think #59 and #63 should go first, then #60, then #61 and this one.

patrick-kidger commented 1 year ago

Also LGTM!