patrick-kidger / lineax

Linear solvers in JAX and Equinox. https://docs.kidger.site/lineax
Apache License 2.0

New solver for block tridiagonal matrices #80

Open aidancrilly opened 9 months ago

aidancrilly commented 9 months ago

Hi,

I needed a block tridiagonal solver for a project of mine, so I took a stab at adding one to lineax (very nice package, thank you!). The solve is a simple extension of the Thomas algorithm to blocks, which can scale better than LU because LU doesn't exploit the banded structure. I tested my implementation against MatrixLinearOperator on the same matrix, and it can be considerably faster (~4x faster for 100 diagonal blocks of size 2x2).
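For readers unfamiliar with it, the block Thomas algorithm is ordinary Thomas elimination with scalar divisions replaced by small dense solves. Below is a minimal JAX sketch under assumed conventions (stacked `(n, b, b)` blocks, no pivoting across blocks). It is not the code from this PR, and `block_thomas_solve` and `to_dense` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def block_thomas_solve(lower, diag, upper, rhs):
    """Block Thomas algorithm: block tridiagonal elimination, no block pivoting.

    Assumed (hypothetical) layout:
      diag:  (n, b, b)      diagonal blocks D_i
      lower: (n - 1, b, b)  lower[i] is the block at (row i + 1, column i)
      upper: (n - 1, b, b)  upper[i] is the block at (row i, column i + 1)
      rhs:   (n, b)         right-hand side split into block rows
    """
    n, b, _ = diag.shape
    zero_block = jnp.zeros((1, b, b), dtype=diag.dtype)
    # Pad so every block row has a lower and upper neighbour; padding is zero.
    lower_p = jnp.concatenate([zero_block, lower])  # lower_p[i] multiplies x[i - 1]
    upper_p = jnp.concatenate([upper, zero_block])  # upper_p[i] multiplies x[i + 1]

    def forward(carry, row):
        c_prev, y_prev = carry
        l_i, d_i, u_i, r_i = row
        # Eliminate the sub-diagonal block using the previous reduced row.
        m_i = d_i - l_i @ c_prev
        c_i = jnp.linalg.solve(m_i, u_i)
        y_i = jnp.linalg.solve(m_i, r_i - l_i @ y_prev)
        return (c_i, y_i), (c_i, y_i)

    init = (jnp.zeros((b, b), dtype=diag.dtype), jnp.zeros((b,), dtype=diag.dtype))
    _, (c, y) = jax.lax.scan(forward, init, (lower_p, diag, upper_p, rhs))

    def backward(x_next, row):
        c_i, y_i = row
        # x_i = y_i - C_i x_{i+1}; for the last row C_i = 0, so x_i = y_i.
        x_i = y_i - c_i @ x_next
        return x_i, x_i

    # reverse=True walks the rows from last to first but keeps output order.
    _, x = jax.lax.scan(backward, jnp.zeros((b,), dtype=diag.dtype), (c, y), reverse=True)
    return x


def to_dense(lower, diag, upper):
    """Assemble the full (n*b, n*b) matrix, for checking only."""
    n, b, _ = diag.shape
    idx = jnp.arange(n)
    grid = jnp.zeros((n, n, b, b), dtype=diag.dtype)
    grid = grid.at[idx, idx].set(diag)
    grid = grid.at[idx[1:], idx[:-1]].set(lower)
    grid = grid.at[idx[:-1], idx[1:]].set(upper)
    return grid.transpose(0, 2, 1, 3).reshape(n * b, n * b)


k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
n, b = 100, 2
diag = jax.random.normal(k1, (n, b, b)) + 10.0 * jnp.eye(b)  # well-conditioned pivots
lower = jax.random.normal(k2, (n - 1, b, b))
upper = jax.random.normal(k3, (n - 1, b, b))
rhs = jax.random.normal(k4, (n, b))

x = block_thomas_solve(lower, diag, upper, rhs)
x_dense = jnp.linalg.solve(to_dense(lower, diag, upper), rhs.reshape(-1))
```

The cost is O(n b^3) for n block rows of size b, versus O((n b)^3) for a dense LU of the assembled matrix, which is where the banded structure pays off.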

I have run the existing tests and have written another of my own to test the block tridiagonal representation and solve. All tests pass except “test_identity_with_different_structures_complex”, but this also fails for me on the main branch(?). I will admit that the tag and token methodology used by lineax isn't super familiar to me so apologies if I have not used this properly for the new block tridiagonal class.

Hopefully this addition is of use.

patrick-kidger commented 9 months ago

Ah, this is excellent! Thank you for the contribution. I really like this implementation, which looks very clean throughout.

This does touch on an interesting point: @packquickly and I were discussing adding a general "block operator" abstraction, which wraps around other operators. (For example, it's common to have a block [dense identity] matrix when solving Levenberg-Marquardt-type problems, and that could benefit from some specialist solving as well.)
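As a concrete illustration of the Levenberg-Marquardt case, assuming the usual damped Gauss-Newton step and reading "block [dense identity]" as the vertically stacked matrix [J; sqrt(λ) I]: the step from the damped normal equations equals the least-squares solution against that stacked block matrix. The function names below are hypothetical.

```python
import jax
import jax.numpy as jnp


def lm_step_stacked(jac, residual, damping):
    """Least-squares solve against the stacked block matrix [J; sqrt(damping) I]."""
    n = jac.shape[1]
    stacked = jnp.concatenate(
        [jac, jnp.sqrt(damping) * jnp.eye(n, dtype=jac.dtype)], axis=0
    )
    rhs = jnp.concatenate([residual, jnp.zeros((n,), dtype=jac.dtype)])
    step, _, _, _ = jnp.linalg.lstsq(stacked, rhs)
    return step


def lm_step_normal(jac, residual, damping):
    """Same step from the damped normal equations (J^T J + damping I) step = J^T r."""
    n = jac.shape[1]
    lhs = jac.T @ jac + damping * jnp.eye(n, dtype=jac.dtype)
    return jnp.linalg.solve(lhs, jac.T @ residual)


k1, k2 = jax.random.split(jax.random.PRNGKey(0))
jac = jax.random.normal(k1, (6, 3))
residual = jax.random.normal(k2, (6,))
step_stacked = lm_step_stacked(jac, residual, 0.5)
step_normal = lm_step_normal(jac, residual, 0.5)
```

Both minimise ||J δ - r||^2 + λ ||δ||^2; the stacked form never builds J^T J, which is the kind of structure a specialist block solver could exploit.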

I'm wondering if it might be possible to support block-tridiagonal as a special case of that. I suspect the general case of doing a linear solve with respect to arbitrary block matrices might be a bit tricky, though.

@packquickly WDYT?

packquickly commented 9 months ago

First off, excellent stuff!! Thank you very much for this PR.

Regarding block matrices, now may be a good time to settle on an abstraction for them. General linear solves against the compositional block matrices Patrick mentioned do seem a bit painful to do efficiently, but only a bit. At first glance, the non-recursive implementation of block LU in Golub and Van Loan should work as a generic solver for these block operators. Looking at the code here and at the block tridiagonal algorithm outlined in Golub and Van Loan, I think going from this implementation of block tridiagonal to one using the compositional "block operator" Patrick suggested would not require too many changes either.
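A minimal sketch of a generic, non-recursive block LU (right-looking block elimination, no pivoting between blocks), assuming all blocks are stored in one `(nb, nb, b, b)` array. It is written from the standard algorithm rather than transcribed from Golub and Van Loan, and `block_lu` / `block_lu_solve` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def block_lu(blocks):
    """Right-looking block LU without block pivoting.

    blocks: (nb, nb, b, b), blocks[i, j] is the (i, j) block.
    Returns (lower, upper): lower has identity diagonal blocks and zeros above
    the block diagonal; upper is zero below it. Pivot blocks must be nonsingular.
    """
    nb, _, b, _ = blocks.shape
    a = blocks
    lower = jnp.zeros_like(blocks)
    for k in range(nb):
        lower = lower.at[k, k].set(jnp.eye(b, dtype=blocks.dtype))
        for i in range(k + 1, nb):
            # L_ik = A_ik A_kk^{-1}, computed as a transposed solve.
            l_ik = jnp.linalg.solve(a[k, k].T, a[i, k].T).T
            lower = lower.at[i, k].set(l_ik)
            # Schur-complement update of the trailing blocks: A_ij -= L_ik A_kj.
            for j in range(k + 1, nb):
                a = a.at[i, j].add(-l_ik @ a[k, j])
    mask = jnp.triu(jnp.ones((nb, nb), dtype=bool))
    upper = jnp.where(mask[:, :, None, None], a, 0.0)
    return lower, upper


def block_lu_solve(lower, upper, rhs):
    """Solve (L U) x = rhs with rhs of shape (nb, b)."""
    nb = lower.shape[0]
    # Forward substitution; L has identity diagonal blocks.
    y = []
    for i in range(nb):
        acc = rhs[i]
        for j in range(i):
            acc = acc - lower[i, j] @ y[j]
        y.append(acc)
    # Back substitution with dense solves against the diagonal blocks of U.
    x = [None] * nb
    for i in reversed(range(nb)):
        acc = y[i]
        for j in range(i + 1, nb):
            acc = acc - upper[i, j] @ x[j]
        x[i] = jnp.linalg.solve(upper[i, i], acc)
    return jnp.stack(x)


nb, b = 3, 2
k1, k2 = jax.random.split(jax.random.PRNGKey(2))
dense = jax.random.normal(k1, (nb * b, nb * b)) + 6.0 * jnp.eye(nb * b)
blocks = dense.reshape(nb, b, nb, b).transpose(0, 2, 1, 3)  # blocks[i, j] = (i, j) block
rhs = jax.random.normal(k2, (nb, b))

lower, upper = block_lu(blocks)
x = block_lu_solve(lower, upper, rhs)
x_dense = jnp.linalg.solve(dense, rhs.reshape(-1))
```

Restricting the inner loops to the nonzero blocks of a known pattern (for example only i = k + 1 and j = k + 1) recovers the block tridiagonal elimination.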

f0uriest commented 6 months ago

Any progress on this? I'd be super interested in having a general abstraction for block linear operators, especially if the blocks themselves can be block linear operators. For example, a problem I'm working on now has a block triangular matrix where each block is itself block tridiagonal, with the innermost blocks being dense.

A general block LU would be a good starting point, though specialized options for common sparsity patterns (e.g. block tridiagonal, triangular, Hessenberg, diagonal) would be useful too.

The main limitation I see with nested block operators is that JAX doesn't really like array-of-structs-type things, so maybe there's a more clever way?
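One common workaround, offered as an assumption rather than anything lineax currently does, is struct-of-arrays storage: stack all blocks that play the same role into one array, so operations become batched contractions. The sketch contrasts a Python list of per-row blocks with stacked `(n, b, b)` arrays for a block tridiagonal matvec; both function names are hypothetical.

```python
import jax
import jax.numpy as jnp


def block_tridiag_mv_aos(rows, x_chunks):
    """Array-of-structs: rows[i] = (lower block or None, diag block, upper block or None)."""
    n = len(rows)
    out = []
    for i, (l_i, d_i, u_i) in enumerate(rows):
        y_i = d_i @ x_chunks[i]
        if i > 0:
            y_i = y_i + l_i @ x_chunks[i - 1]
        if i < n - 1:
            y_i = y_i + u_i @ x_chunks[i + 1]
        out.append(y_i)
    return out


def block_tridiag_mv_soa(lower, diag, upper, x):
    """Struct-of-arrays: three stacked arrays, one batched contraction each."""
    y = jnp.einsum("nij,nj->ni", diag, x)
    y = y.at[1:].add(jnp.einsum("nij,nj->ni", lower, x[:-1]))   # lower[i] sits at (i + 1, i)
    y = y.at[:-1].add(jnp.einsum("nij,nj->ni", upper, x[1:]))   # upper[i] sits at (i, i + 1)
    return y


k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(4), 4)
n, b = 5, 2
diag = jax.random.normal(k1, (n, b, b))
lower = jax.random.normal(k2, (n - 1, b, b))
upper = jax.random.normal(k3, (n - 1, b, b))
x = jax.random.normal(k4, (n, b))

rows = [
    (lower[i - 1] if i > 0 else None, diag[i], upper[i] if i < n - 1 else None)
    for i in range(n)
]
y_aos = block_tridiag_mv_aos(rows, [x[i] for i in range(n)])
y_soa = jax.jit(block_tridiag_mv_soa)(lower, diag, upper, x)
```

Under `jax.jit`, the list-based version unrolls into O(n) separate small matmuls in the traced program, whereas the stacked version traces to a fixed number of batched operations regardless of n. Nesting would need the same trick applied at each level, which is the part that remains awkward.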

patrick-kidger commented 6 months ago

I'd be very happy to see this revived as well!

On the topic of block operators: @packquickly and I discussed this a bit offline, and we don't think there's actually a nice way to express this in general. IIRC this basically boiled down to the mathematics: there isn't (?) a nice way to express a linear-solve-against-a-block-operator in terms of linear solves against the individual operators.

That leaves me wondering if the best way to handle this is something like (a) nonetheless introducing a general block operator, but (b) not trying to introduce a corresponding general "block solver": rather, have individual solvers such as the one implemented here, and have them verify that the components of the block-operator look like what they expect.
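A rough sketch of (a) plus (b), assuming a hypothetical `BlockOperator` that is just a square grid of dense blocks with `None` marking structurally zero blocks. The solver verifies the structure it expects before solving. Block diagonal is used only because it is the shortest case; a block tridiagonal solver would check the band instead. None of these names are lineax API.

```python
from dataclasses import dataclass
from typing import List, Optional

import jax
import jax.numpy as jnp


@dataclass
class BlockOperator:
    """Hypothetical block operator: a square grid of dense blocks (None = zero block)."""

    blocks: List[List[Optional[jax.Array]]]


def block_diagonal_solve(op: BlockOperator, rhs_chunks: List[jax.Array]) -> List[jax.Array]:
    """Solve op @ x = rhs, but only after checking that op is block diagonal."""
    n = len(op.blocks)
    for i, row in enumerate(op.blocks):
        if len(row) != n:
            raise ValueError("expected a square grid of blocks")
        for j, block in enumerate(row):
            if i == j and block is None:
                raise ValueError(f"diagonal block ({i}, {i}) is missing")
            if i != j and block is not None:
                raise ValueError(
                    f"block ({i}, {j}) is not structurally zero; "
                    "this solver only handles block-diagonal operators"
                )
    # With the structure verified, each block row decouples into its own solve.
    return [jnp.linalg.solve(op.blocks[i][i], rhs_chunks[i]) for i in range(n)]


d0 = jnp.array([[2.0, 0.0], [1.0, 3.0]])
d1 = jnp.array([[4.0]])
rhs_chunks = [jnp.array([2.0, 4.0]), jnp.array([8.0])]

diag_op = BlockOperator([[d0, None], [None, d1]])
x0, x1 = block_diagonal_solve(diag_op, rhs_chunks)

# An operator with a nonzero off-diagonal block is rejected rather than mis-solved.
coupled_op = BlockOperator([[d0, jnp.ones((2, 1))], [None, d1]])
try:
    block_diagonal_solve(coupled_op, rhs_chunks)
    rejected = False
except ValueError:
    rejected = True
```

The list-of-lists storage here is the array-of-structs layout; a production version would likely want stacked storage for the blocks instead.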

I'm fairly open to suggestions/implementations on this one. :)

f0uriest commented 6 months ago

Can you clarify a bit what you mean by

there isn't (?) a nice way to express a linear-solve-against-a-block-operator in terms of linear solves against the individual operators.

Do you just mean that at some point you end up with things like solve(operator1, operator2) (which seems like it would require materializing operator2)? I would think that could be handled with some sort of InverseLinearOperator abstraction plus the regular ComposedLinearOperator, such that solve(operator1, operator2) -> ComposedLinearOperator(InverseLinearOperator(operator1), operator2).

InverseLinearOperator.mv would basically be a wrapper around linear_solve, perhaps with some precomputation where it makes sense (like doing an LU factorization only once, similar to the Solver.init methods of some solvers).
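A minimal sketch of that idea using plain Python classes rather than lineax's own: `MatrixOperator`, `InverseOperator` and `ComposedOperator` are hypothetical stand-ins, and the LU factorisation uses `jax.scipy.linalg.lu_factor` / `lu_solve`, computed once at construction and reused on every `mv`.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import lu_factor, lu_solve


class MatrixOperator:
    """Minimal stand-in for a matrix-backed operator (not lineax's class)."""

    def __init__(self, matrix):
        self.matrix = matrix

    def mv(self, vector):
        return self.matrix @ vector


class InverseOperator:
    """Hypothetical inverse operator: mv(v) returns A^{-1} v via a linear solve.

    The LU factorisation happens once, at construction, and is reused by
    every mv call (analogous to a solver's init step).
    """

    def __init__(self, matrix):
        self.lu_and_piv = lu_factor(matrix)

    def mv(self, vector):
        return lu_solve(self.lu_and_piv, vector)


class ComposedOperator:
    """Hypothetical composition: mv(v) = first.mv(second.mv(v))."""

    def __init__(self, first, second):
        self.first = first
        self.second = second

    def mv(self, vector):
        return self.first.mv(self.second.mv(vector))


k1, k2, k3 = jax.random.split(jax.random.PRNGKey(3), 3)
op1 = jax.random.normal(k1, (4, 4)) + 5.0 * jnp.eye(4)
op2 = jax.random.normal(k2, (4, 4))
v = jax.random.normal(k3, (4,))

# Lazily represents solve(op1, op2) = op1^{-1} op2 without materialising it.
lazy = ComposedOperator(InverseOperator(op1), MatrixOperator(op2))
result = lazy.mv(v)
reference = jnp.linalg.solve(op1, op2 @ v)
```

Each `mv` costs one matvec plus one pair of triangular solves, so the inverse never has to be formed as a dense matrix.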

patrick-kidger commented 6 months ago

Hmmm. So to your point, I think we can sort-of do this. The following identity certainly exists (do others like it exist for NxM blocks, not just 2x2?):

https://en.wikipedia.org/wiki/Block_matrix#Inversion

and the (D - CA^{-1}B)^{-1} components could be handled in the way you describe above.
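To make the identity concrete: with S = D - C A^{-1} B (the Schur complement of A), a 2x2 block system can be solved using only solves against A and S. The sketch below (hypothetical `schur_solve`, dense blocks assumed) also illustrates the dense-case caveat: it still forms A^{-1} B explicitly, so it saves nothing over a single LU of the assembled matrix.

```python
import jax
import jax.numpy as jnp


def schur_solve(a, b, c, d, rhs1, rhs2):
    """Solve [[A, B], [C, D]] [x1; x2] = [rhs1; rhs2] via the Schur complement of A."""
    a_inv_b = jnp.linalg.solve(a, b)         # A^{-1} B
    a_inv_rhs1 = jnp.linalg.solve(a, rhs1)   # A^{-1} rhs1
    schur = d - c @ a_inv_b                  # S = D - C A^{-1} B
    x2 = jnp.linalg.solve(schur, rhs2 - c @ a_inv_rhs1)
    x1 = a_inv_rhs1 - a_inv_b @ x2           # x1 = A^{-1} (rhs1 - B x2)
    return x1, x2


keys = jax.random.split(jax.random.PRNGKey(1), 6)
a = jax.random.normal(keys[0], (3, 3)) + 5.0 * jnp.eye(3)
b = jax.random.normal(keys[1], (3, 2))
c = jax.random.normal(keys[2], (2, 3))
d = jax.random.normal(keys[3], (2, 2)) + 5.0 * jnp.eye(2)
rhs1 = jax.random.normal(keys[4], (3,))
rhs2 = jax.random.normal(keys[5], (2,))

x1, x2 = schur_solve(a, b, c, d, rhs1, rhs2)
x_dense = jnp.linalg.solve(jnp.block([[a, b], [c, d]]), jnp.concatenate([rhs1, rhs2]))
```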

I'm guessing the above is what you have in mind?

I think the question is on the solver. How would the above help us avoid writing the custom solver implemented in this PR? I don't see a way to do that. It's a neat mathematical identity, but in the dense case it's not better than just materialising and using LU, whilst in the structured case we still need to write custom solvers.

So the place I land on is that whilst we can certainly create a block operator abstraction without difficulty, we would still need individual solver implementations corresponding to different block structures.

f0uriest commented 6 months ago

Ok yeah, there would still be a need for specific solvers like this one, but I think the block abstraction would still be useful, possibly combined with the existing tag system to mark structured block matrices, each with a specialized solver (like the existing ones for diagonal, tridiagonal, triangular, PSD, etc.).

I'll make an issue for tracking further ideas for this.