patrick-kidger / lineax

Linear solvers in JAX and Equinox. https://docs.kidger.site/lineax
Apache License 2.0
326 stars, 22 forks

Block operators / solvers #96

Open. f0uriest opened this issue 1 month ago.

f0uriest commented 1 month ago

Following the discussion in #80, I thought it would be good to lay out some ideas for a general block matrix/operator abstraction.

Some useful types of block operators:

I think the more powerful capability would be to allow block operators where each block can itself be an arbitrary LinearOperator, including nested block operators. This would allow for a pretty expressive API that could represent most of the structured sparsity patterns that commonly arise when discretizing PDEs, and in lots of other applications.
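To make the nesting idea concrete, here is a minimal sketch in plain JAX. Dense and Block are hypothetical stand-ins, not lineax classes: a Block is a list of rows of operators, and because its matvec only calls mv on each entry, any entry can itself be a Block.

```python
import itertools

import jax.numpy as jnp


class Dense:
    """Hypothetical leaf operator wrapping a dense matrix."""

    def __init__(self, matrix):
        self.matrix = matrix
        self.shape = matrix.shape

    def mv(self, x):
        return self.matrix @ x


class Block:
    """Hypothetical block operator: a list of rows of operators.

    Any entry may itself be a Block, so nesting works through recursion.
    """

    def __init__(self, blocks):
        self.blocks = blocks
        row_sizes = [row[0].shape[0] for row in blocks]
        self.col_sizes = [op.shape[1] for op in blocks[0]]
        self.shape = (sum(row_sizes), sum(self.col_sizes))

    def mv(self, x):
        # Split x according to the column partition of the blocks.
        split_points = list(itertools.accumulate(self.col_sizes))[:-1]
        xs = jnp.split(x, split_points)
        # Each output chunk is sum_j row[j].mv(x_j); nested Blocks recurse here.
        ys = [sum(op.mv(xj) for op, xj in zip(row, xs)) for row in self.blocks]
        return jnp.concatenate(ys)


# Usage: a 3x3 block operator nested inside a 5x5 block operator.
inner = Block([[Dense(jnp.eye(2)), Dense(jnp.ones((2, 1)))],
               [Dense(jnp.zeros((1, 2))), Dense(2.0 * jnp.ones((1, 1)))]])
outer = Block([[inner, Dense(jnp.ones((3, 2)))],
               [Dense(jnp.ones((2, 3))), Dense(jnp.eye(2))]])

# The same operator materialised, for checking only.
inner_dense = jnp.block([[jnp.eye(2), jnp.ones((2, 1))],
                         [jnp.zeros((1, 2)), 2.0 * jnp.ones((1, 1))]])
outer_dense = jnp.block([[inner_dense, jnp.ones((3, 2))],
                         [jnp.ones((2, 3)), jnp.eye(2)]])
x = jnp.arange(5.0)
```

The recursion does all the work here: Block never needs to know what its entries are, only their shapes and how to apply them.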

Supporting arbitrary (possibly nested) LinearOperator blocks would likely require a few things:

Some other thoughts:

patrick-kidger commented 1 month ago

Thanks for opening this issue -- I think you've hit pretty much all the major points I can think of, and also touched on several I hadn't, like vmap'ing blocks!

Some initial comments:

> Allow linear_solve to accept an operator as the RHS, such that linear_solve(operator1, operator2) -> ComposedLinearOperator(InverseLinearOperator(operator1), operator2). This would allow us to symbolically/algorithmically invert an arbitrary block operator without having to know the specifics of what each block is.

It's not clear to me why we'd need linear_solve(op1, op2). We can write InverseLinearOperator(op1) @ op2 anyway, and I don't think we have any scenarios where we need vector/operator polymorphism. Other than that: right now linear_solve always materialises its output. It'd be nice to maintain this invariant.
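A minimal sketch of the InverseLinearOperator(op1) @ op2 alternative, assuming hypothetical lazy classes (Matrix, Inverse and Composed are illustrative, not lineax's): building the composition does no work, and a solve only happens when the composed operator is applied to a vector, so no operator-valued result is ever materialised.

```python
import jax.numpy as jnp


class Op:
    """Hypothetical base class: `a @ b` builds a lazy composition."""

    def __matmul__(self, other):
        return Composed(self, other)


class Matrix(Op):
    def __init__(self, matrix):
        self.matrix = matrix

    def mv(self, x):
        return self.matrix @ x


class Inverse(Op):
    """Lazy A^{-1}: applying it to x solves A y = x; A is never inverted."""

    def __init__(self, operator):
        self.operator = operator

    def mv(self, x):
        return jnp.linalg.solve(self.operator.matrix, x)


class Composed(Op):
    """Lazy first @ second: applies second, then first."""

    def __init__(self, first, second):
        self.first = first
        self.second = second

    def mv(self, x):
        return self.first.mv(self.second.mv(x))


A = jnp.array([[4.0, 1.0], [2.0, 3.0]])
B = jnp.array([[1.0, 2.0], [0.0, 1.0]])
op = Inverse(Matrix(A)) @ Matrix(B)  # no solve happens here
v = jnp.array([1.0, 1.0])
result = op.mv(v)  # solves A y = B v, giving approximately [0.8, -0.2]
```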

> If each sub-block is dense the overall block operator can usually be stored as a dense array with some special indexing to avoid having to store blocks of zeros. This is nice because arrays can be scanned/vmapped etc.

It sounds like this is an attempt to represent a block operator as something other than a list-of-lists-of-operators. (Itself then wrapped into a BlockLinearOperator.) Do you have anything specific in mind for a general approach to this? Whilst I can imagine us eventually adding various kinds of sparse linear operators (JAX's own sparse arrays seem to have hit a bit of a dead end), I don't know if that's what we want for block structure.

FWIW we don't have to make a single choice here. Whilst we should certainly minimise the number we create, if we need to, then we can create multiple {Foo,Bar}LinearOperators.
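For the all-dense case raised above, one way to avoid storing zero blocks, sketched here for the block-diagonal case only (an assumption; other sparsity patterns would need other index layouts), is to stack the diagonal blocks into a single (num_blocks, b, b) array so that both the matvec and the solve become a single vmap. The helper names are hypothetical.

```python
import jax
import jax.numpy as jnp
import jax.scipy.linalg as jsl


def block_diag_mv(blocks, x):
    """Matvec for a block-diagonal operator stored as (num_blocks, b, b)."""
    num_blocks, b, _ = blocks.shape
    xs = x.reshape(num_blocks, b)
    # One batched matvec over all diagonal blocks; zero blocks are never stored.
    return jax.vmap(jnp.matmul)(blocks, xs).reshape(-1)


def block_diag_solve(blocks, y):
    """Solve each diagonal block independently, batched with vmap."""
    num_blocks, b, _ = blocks.shape
    ys = y.reshape(num_blocks, b)
    return jax.vmap(jnp.linalg.solve)(blocks, ys).reshape(-1)


blocks = jnp.stack([
    jnp.array([[2.0, 1.0], [0.0, 3.0]]),
    jnp.array([[1.0, 0.0], [1.0, 4.0]]),
    5.0 * jnp.eye(2),
])
x = jnp.arange(6.0)
dense = jsl.block_diag(*blocks)  # materialised 6x6 matrix, for checking only
```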

> However, if we allow sub-blocks to be LinearOperators then we might have to resort to lists of blocks, which (generally?) can't be scanned/vmapped as easily, so compile times might be longer? This might not be too bad if we generally assume that block_size >> num_blocks.

This sounds really tough. I think figuring out what is optimal here is probably the main thing we still need to do.
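A rough way to see the compile-time concern, sketched with hypothetical helpers: both functions below compute the same block-diagonal matvec, but tracing unrolls the list version into one matmul per block, while the stacked version traces to a single batched op. Counting jaxpr equations is only a proxy for compile cost, not a measurement of it.

```python
import jax
import jax.numpy as jnp


def mv_from_list(blocks, x):
    # `blocks` is a Python list, so tracing unrolls one matmul per block.
    xs = jnp.split(x, len(blocks))
    return jnp.concatenate([blk @ xi for blk, xi in zip(blocks, xs)])


def mv_from_stack(blocks, x):
    # `blocks` is one stacked array, so a single batched op covers every block.
    xs = x.reshape(blocks.shape[0], -1)
    return jax.vmap(jnp.matmul)(blocks, xs).reshape(-1)


stacked = jnp.ones((8, 4, 4))
as_list = list(stacked)
x = jnp.arange(32.0)

# Number of equations in each traced program: grows with num_blocks for the list.
list_eqns = len(jax.make_jaxpr(mv_from_list)(as_list, x).jaxpr.eqns)
stack_eqns = len(jax.make_jaxpr(mv_from_stack)(stacked, x).jaxpr.eqns)
```

This is the usual trade-off: heterogeneous blocks (different operator types or shapes) can't be stacked, so they fall back to the unrolled form.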

> For blocks that share a particular structure, can we vmap block creation, instead of having to create each one separately and store them in a list?

I think this is probably doable by vmap'ing the creation, then indexing into the result, then passing those blocks into the BlockLinearOperator. I speculate that for most use cases there's little to be gained by pushing the indexing later into the computation graph.
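A minimal sketch of the vmap-then-index approach, assuming a hypothetical make_block constructor and using None as a stand-in for a zero block (neither is lineax API): the blocks are built in one batched trace, then indexed out into the list-of-lists layout that a BlockLinearOperator-style constructor would take.

```python
import jax
import jax.numpy as jnp


def make_block(coeff):
    """Hypothetical block constructor: a 3x3 tridiagonal stencil scaled by coeff."""
    main = 2.0 * coeff * jnp.ones(3)
    off = -coeff * jnp.ones(2)
    return jnp.diag(main) + jnp.diag(off, 1) + jnp.diag(off, -1)


coeffs = jnp.array([1.0, 2.0, 3.0])
stacked = jax.vmap(make_block)(coeffs)  # shape (3, 3, 3), built in one batched trace

# Index out the individual blocks and arrange them in a list-of-lists layout,
# with None standing in for a structurally zero block (an assumption).
blocks = [stacked[i] for i in range(stacked.shape[0])]
layout = [[blocks[i] if i == j else None for j in range(3)] for i in range(3)]


def to_dense(layout, b):
    """Materialise the layout, filling None with zeros; for checking only."""
    return jnp.block([[blk if blk is not None else jnp.zeros((b, b)) for blk in row]
                      for row in layout])


dense = to_dense(layout, 3)
```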

That said, we could consider the option of allowing a component linear operator to span multiple blocks, e.g.

```
+-----------------+
|      FooOp      |
+--------+--------+
| BarOp1 | BarOp2 |
+--------+--------+
```

Such adjacencies are really the only scenario I can think of in which it might be advantageous not to split things up.
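For the spanning layout in the diagram, the matvec stays simple: the spanning operator acts on the whole input vector, while the row below it splits the input by block column. A sketch with a hypothetical spanning_mv helper, using dense arrays for FooOp, BarOp1 and BarOp2 purely for illustration:

```python
import jax.numpy as jnp


def spanning_mv(foo, bar1, bar2, x):
    """Hypothetical matvec for a first block row spanned by a single operator.

    foo spans every block column of the first block row; the second block
    row is split into bar1 (first column) and bar2 (second column).
    """
    n1 = bar1.shape[1]
    x1, x2 = x[:n1], x[n1:]
    y_top = foo @ x                   # FooOp acts on the whole input vector
    y_bottom = bar1 @ x1 + bar2 @ x2  # ordinary per-block products
    return jnp.concatenate([y_top, y_bottom])


foo = jnp.arange(6.0).reshape(2, 3)
bar1 = jnp.array([[1.0]])
bar2 = jnp.array([[2.0, 3.0]])
x = jnp.array([1.0, 2.0, 3.0])
dense = jnp.block([[foo], [bar1, bar2]])  # same operator, materialised
y = spanning_mv(foo, bar1, bar2, x)       # [8.0, 26.0, 14.0]
```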

> For a given sub-operator, how to decide which solver to use to invert it? Likely this can be done with proper tags using AutoLinearSolver, but we may want to add some new tags to indicate that AutoLinearSolver should use an iterative method in some cases? Or maybe the solver could similarly be a blocked/pytree-type structure, allowing the user to specify which solver to use at each level/block of the operator?

For InverseLinearOperator, I think the choice of solver should be an argument to InverseLinearOperator.
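One reading of that suggestion, sketched with hypothetical names (Inverse, lu_solve, cholesky_solve and block_diag_inverse_mv are illustrative, not lineax API): each inverse operator carries its own solver, so a block-diagonal inverse can use Cholesky on a symmetric positive definite block and a general solve on another.

```python
import jax.numpy as jnp
import jax.scipy.linalg as jsl


def lu_solve(matrix, rhs):
    """General square solve via jnp.linalg.solve."""
    return jnp.linalg.solve(matrix, rhs)


def cholesky_solve(matrix, rhs):
    """Only valid when matrix is symmetric positive definite."""
    return jsl.cho_solve(jsl.cho_factor(matrix), rhs)


class Inverse:
    """Hypothetical inverse operator that carries its own solver choice."""

    def __init__(self, matrix, solver):
        self.matrix = matrix
        self.solver = solver

    def mv(self, x):
        return self.solver(self.matrix, x)


def block_diag_inverse_mv(inverses, x):
    """Apply a block-diagonal inverse, using each block's own solver."""
    out, start = [], 0
    for inv in inverses:
        n = inv.matrix.shape[0]
        out.append(inv.mv(x[start:start + n]))
        start += n
    return jnp.concatenate(out)


spd = jnp.array([[4.0, 1.0], [1.0, 3.0]])      # symmetric positive definite
general = jnp.array([[0.0, 1.0], [2.0, 0.0]])  # not symmetric, needs pivoting
inverses = [Inverse(spd, cholesky_solve), Inverse(general, lu_solve)]
y = jnp.array([1.0, 2.0, 3.0, 4.0])
x = block_diag_inverse_mv(inverses, y)
```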

For the solver for the overall block operator: I think we should expect each block structure to have its own solver, which is free to do anything it likes. For example, something like this:

```python
class BlockLinearOperator(AbstractLinearOperator):
    operators: list[list[AbstractLinearOperator]]
    ...

op = BlockLinearOperator(
    [[FooLinearOperator(), ZeroLinearOperator()],
     [ZeroLinearOperator(), BarLinearOperator()]]
)
linear_solve(op, vec, solver=BlockDiagonal())
```

and in particular note that there are no "tags" needed for BlockLinearOperator itself: its structure is already determinable at compile time by checking its operators.

For example, BlockDiagonal.init can check that all off-diagonal members are isinstance(x, ZeroLinearOperator).
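A minimal sketch of that structural check, with hypothetical Zero and BlockDiagonal classes whose init/compute split only loosely mirrors a solver interface (it is not lineax's actual signature): init inspects the blocks with isinstance in plain Python, so under jit the check happens at trace time, and it factorises the diagonal blocks once so that compute can reuse them.

```python
import itertools

import jax.numpy as jnp
import jax.scipy.linalg as jsl


class Zero:
    """Hypothetical marker for a structurally zero block."""

    def __init__(self, shape):
        self.shape = shape


class BlockDiagonal:
    """Hypothetical block-diagonal solver with an init/compute split."""

    def init(self, blocks):
        # Structure is read off the block types; no tags are needed.
        for i, row in enumerate(blocks):
            for j, blk in enumerate(row):
                if i != j and not isinstance(blk, Zero):
                    raise ValueError(f"off-diagonal block ({i}, {j}) is not Zero")
        # Factorise each dense diagonal block once, for reuse across solves.
        return [jsl.lu_factor(blocks[i][i]) for i in range(len(blocks))]

    def compute(self, state, vector):
        sizes = [lu.shape[0] for lu, _ in state]
        chunks = jnp.split(vector, list(itertools.accumulate(sizes))[:-1])
        return jnp.concatenate([jsl.lu_solve(f, c) for f, c in zip(state, chunks)])


A = jnp.array([[3.0, 1.0], [1.0, 2.0]])
B = jnp.array([[2.0, 0.0, 1.0], [0.0, 3.0, 0.0], [1.0, 0.0, 4.0]])
blocks = [[A, Zero((2, 3))], [Zero((3, 2)), B]]

solver = BlockDiagonal()
state = solver.init(blocks)
v = jnp.arange(1.0, 6.0)
x = solver.compute(state, v)

# A dense off-diagonal block is rejected before any numerical work happens.
try:
    solver.init([[A, jnp.ones((2, 3))], [Zero((3, 2)), B]])
    rejected = False
except ValueError:
    rejected = True
```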

WDYT?