patrick-kidger / lineax

Linear solvers in JAX and Equinox. https://docs.kidger.site/lineax
Apache License 2.0

Add PowLinearOperator for element-wise power operators #31

Open quattro opened 1 year ago

quattro commented 1 year ago

Hi all,

I'm prototyping a PowLinearOperator for some internal projects that need to compute (g * g).mv(x) for a LinearOperator g and a vector x, where * is elementwise, and wanted to gather some feedback before proceeding.

If we restrict to positive integer exponents, then we can rely on the decomposition (g * g).mv(x) == diag(g diag(x) g.T) [or recurse similarly for exponents > 2].
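A minimal JAX sketch that checks the identity numerically, assuming the operator is only accessible through its matrix-vector product and the matrix-vector product of its transpose. These are passed as plain functions `mv` and `rmv` (with a lineax operator they would presumably be `op.mv` and `op.transpose().mv`), and the helper name `hadamard_square_mv` is hypothetical. Entry i of the result is e_i^T g diag(x) g.T e_i, so this sketch costs one transpose-matvec and one matvec per output entry.

```python
import jax
import jax.numpy as jnp


def hadamard_square_mv(mv, rmv, x, n):
    """Hypothetical helper: compute (G * G) @ x, where * is elementwise.

    Uses only the actions of G and G.T, via the identity
    (G * G) @ x == diag(G @ diag(x) @ G.T).

    mv:  function v -> G @ v   (length-m input, length-n output)
    rmv: function u -> G.T @ u (length-n input, length-m output)
    n:   number of rows of G
    """

    def diag_entry(e_i):
        # e_i^T G diag(x) G^T e_i == sum_j G[i, j]**2 * x[j]
        return jnp.dot(e_i, mv(x * rmv(e_i)))

    # One basis vector per output entry, vectorised with vmap.
    basis = jnp.eye(n, dtype=x.dtype)
    return jax.vmap(diag_entry)(basis)


G = jax.random.normal(jax.random.PRNGKey(0), (4, 3))
x = jnp.arange(1.0, 4.0)
out = hadamard_square_mv(lambda v: G @ v, lambda u: G.T @ u, x, n=4)
print(jnp.allclose(out, (G * G) @ x, atol=1e-5))  # True
```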

Would this be of general interest? If so, I'm happy to expand the test suite and open a pull request.

patrick-kidger commented 1 year ago

If I understand your definition correctly: thinking of the operator as a matrix, this is an elementwise power of each entry in the matrix?

That does seem nontrivial to implement for general operators, so I think this would make sense as a PR! Probably the main thing to be careful about is to avoid compile times increasing as the power increases. The recursion should be done using a lax.scan to avoid this.
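One possible sketch of the lax.scan approach for a general positive integer power p, again assuming only `mv` (for G) and `rmv` (for G.T) are available; `elementwise_pow_mv` is a hypothetical name, not lineax API. Unrolling the recursion gives entry i of (G ** p) @ x as e_i^T G (x * r_i ** (p - 1)), where r_i = G.T e_i is row i of G. The p - 1 elementwise multiplications run inside a `lax.scan`, so the traced loop body is the same for every p and only the trip count changes.

```python
import jax
import jax.numpy as jnp
from jax import lax


def elementwise_pow_mv(mv, rmv, x, n, power):
    """Hypothetical helper: compute (G ** power) @ x with ** elementwise.

    Only the actions of G (mv) and G.T (rmv) are used. `power` must be a
    positive Python int; it sets the scan length, so changing it changes
    the number of iterations but not the size of the traced loop body.
    """
    if not isinstance(power, int) or power < 1:
        raise ValueError("power must be a positive integer")
    if power == 1:
        return mv(x)

    def entry(e_i):
        row = rmv(e_i)  # row i of G, i.e. G.T @ e_i

        def step(w, _):
            # One level of the recursion: multiply in another copy of row i.
            return w * row, None

        # After power - 1 steps, w == x * row ** (power - 1).
        w, _ = lax.scan(step, x, xs=None, length=power - 1)
        # e_i^T G w == sum_j G[i, j] ** power * x[j]
        return jnp.dot(e_i, mv(w))

    basis = jnp.eye(n, dtype=x.dtype)
    return jax.vmap(entry)(basis)
```

For example, `elementwise_pow_mv(lambda v: G @ v, lambda u: G.T @ u, x, n=G.shape[0], power=3)` should match `(G ** 3) @ x` for a dense `G`. When the row is materialised like this, `row ** (power - 1)` would do the same job; the scan form is shown because it keeps one Hadamard factor per step, mirroring the recursion described above.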

(And moreover it's probably also worth overriding MatrixLinearOperator.__pow__, which can be efficiently special-cased.)
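When the entries are stored explicitly, as in a MatrixLinearOperator, the elementwise power needs no matvec tricks: raise the stored matrix to the power once, then multiply. A minimal sketch in plain `jax.numpy`; the function name is hypothetical, and in lineax a `__pow__` override would presumably wrap the powered matrix in a new MatrixLinearOperator rather than return a vector.

```python
import jax.numpy as jnp


def dense_elementwise_pow_mv(matrix, x, power):
    """Hypothetical dense fast path: (matrix ** power) @ x, ** elementwise.

    One elementwise power plus one matvec, instead of one matvec per
    output entry as in the generic, operator-only approach.
    """
    return jnp.power(matrix, power) @ x


matrix = jnp.array([[1.0, -2.0], [3.0, 0.5]])
x = jnp.array([1.0, 1.0])
print(dense_elementwise_pow_mv(matrix, x, 2))  # expected values: [5.0, 9.25]
```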