quattro closed this issue 1 year ago.
This should be doable using jax.vmap.
Fantastic!
In [49]: eqx.filter_vmap(lx.linear_solve, in_axes=(None, 1))(A, G).value.shape
Out[49]: (100, 50)
In [51]: jax.vmap(lx.linear_solve, in_axes=(None, 1))(A, G).value.shape
Out[51]: (100, 50)
It would be great to have an expanded FAQ or Cookbook for new users.
Agreed! This should go in the FAQ. (If you wish, we'd be happy to take a PR on this.)
Hi all, thanks again for developing such an intuitive and useful library! It would be great to have functionality that generalizes solving $Ax = b$ to the more general system $AX = B$, where $X$ and $B$ are matrices. The same goes for the normal equations $A'AX = A'B$.
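Both systems split column by column. With $X = [x_1, \dots, x_k]$ and $B = [b_1, \dots, b_k]$, each column is an independent single-vector solve, so mapping a vector solver over the columns of $B$ is enough:

$$
AX = B \iff A x_j = b_j \quad (j = 1, \dots, k),
\qquad
A'AX = A'B \iff A'A\, x_j = A' b_j \quad (j = 1, \dots, k).
$$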
I tried to do this with a naive use of the API and got shape errors.