patrick-kidger / lineax

Linear solvers in JAX and Equinox. https://docs.kidger.site/lineax
Apache License 2.0

Multivariate linear solver #29

Closed: quattro closed this issue 1 year ago

quattro commented 1 year ago

Hi all, thanks again for developing such an intuitive and useful library! It would be great to have functionality that extends solving $Ax = b$ to systems $AX = B$ with matrices $X$ and $B$, and likewise to the normal equations $A'AX = A'B$.
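Writing $X = [x_1, \dots, x_k]$ and $B = [b_1, \dots, b_k]$ column by column shows why a vector solver is enough. The matrix equation decouples into $k$ independent vector solves that share the same $A$, and the normal equations decouple in the same way. Their solutions minimize $\|AX - B\|_F^2 = \sum_j \|A x_j - b_j\|_2^2$ one column at a time.

```latex
AX = B \iff A x_j = b_j, \quad j = 1, \dots, k,
\qquad
A'AX = A'B \iff A'A\, x_j = A' b_j, \quad j = 1, \dots, k.
```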

I tried to do this with a naive use of the API and got a shape error:

Vector and operator structures do not match. Got a vector with structure ShapeDtypeStruct(shape=(50, 100), dtype=float32) and an operator with out-structure ShapeDtypeStruct(shape=(50,), dtype=float32)

patrick-kidger commented 1 year ago

This should be doable by using jax.vmap to map lx.linear_solve over the columns of B.

quattro commented 1 year ago

Fantastic!

In [49]: eqx.filter_vmap(lx.linear_solve, in_axes=(None, 1))(A, G).value.shape
Out[49]: (100, 50)

In [51]: jax.vmap(lx.linear_solve, in_axes=(None, 1))(A, G).value.shape
Out[51]: (100, 50)

It would be great to have an expanded FAQ or Cookbook for new users.

patrick-kidger commented 1 year ago

Agreed! This should go in the FAQ. (If you wish, we'd be happy to take a PR on this.)