cai4cai / torchsparsegradutils

A collection of utility functions to work with PyTorch sparse tensors
Apache License 2.0

Make SparseSolveJ4T more flexible #19

Status: closed (tvercaut closed this issue 1 year ago)

tvercaut commented 1 year ago

We should provide a way to choose the backend solver and to specify its corresponding options.

Currently, jax.scipy.sparse.linalg.cg is hardcoded with its default options:
https://github.com/cai4cai/torchsparsegradutils/blob/a50406e985afeaf62d4befd8c496103c6bd0c336/torchsparsegradutils/jax/jax_sparse_solve.py#L30
https://github.com/cai4cai/torchsparsegradutils/blob/a50406e985afeaf62d4befd8c496103c6bd0c336/torchsparsegradutils/jax/jax_sparse_solve.py#L53

However, several other solvers are available in JAX: https://jax.readthedocs.io/en/latest/jax.scipy.html#module-jax.scipy.sparse.linalg

This is a follow-up to #5.