patrick-kidger / lineax

Linear solvers in JAX and Equinox. https://docs.kidger.site/lineax
Apache License 2.0
365 stars, 24 forks

AbstractTraceEstimation #88

Closed: quattro closed this issue 8 months ago

quattro commented 8 months ago

Hi 👋,

Along with a trainee, we've implemented several stochastic trace estimators (e.g., Hutchinson, Hutch++, XTrace) for square matrices using lineax's linear operators, and I was curious whether this would be useful within the lineax library itself.

We've coded up a fairly general solution, but I would need help with any internals involving PyTrees or partitioning/combining (e.g., the internals of linear_solve).

The general idea is to have a function along the lines of:


A = ... # square AbstractLinearOperator
k = ... # roughly the number of matvec operations/samples to perform
estimator = ... # AutoTraceEstimator, HutchinsonEstimator, HutchPlusPlusEstimator, XTraceEstimator, XNysTraceEstimator (if A is PSD), etc.
estimate = lx.estimate_trace(A, k, estimator)

General options could include which sampling distribution to use for the probe vectors (e.g., Rademacher, Gaussian, or uniform on the sphere).
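Each of those distributions works for the same reason: if E[v v^T] = I, then E[v^T A v] = tr(A E[v v^T]) = tr(A). A minimal sketch of a probe sampler with such an option follows; sample_probes and its string argument are hypothetical. Note that uniform samples on the unit sphere must be rescaled by sqrt(n) to satisfy E[v v^T] = I.

```python
import jax
import jax.numpy as jnp
import jax.random as jr


def sample_probes(key: jax.Array, n: int, k: int, distribution: str = "rademacher",
                  dtype=jnp.float32) -> jax.Array:
    """Draw k probe vectors of length n (shape (k, n)), each with E[v v^T] = I."""
    if distribution == "rademacher":
        # Entries are +1 or -1 with equal probability.
        return jr.rademacher(key, (k, n)).astype(dtype)
    if distribution == "gaussian":
        # Entries are independent standard normals.
        return jr.normal(key, (k, n), dtype=dtype)
    if distribution == "sphere":
        # Normalise Gaussian samples onto the unit sphere, then rescale by sqrt(n).
        g = jr.normal(key, (k, n), dtype=dtype)
        return jnp.sqrt(n) * g / jnp.linalg.norm(g, axis=-1, keepdims=True)
    raise ValueError(f"unknown distribution: {distribution!r}")
```

All three give unbiased estimates; they differ in variance, so exposing the choice as an option lets users trade off between them.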

Is this of interest and worth integrating? If so, I can set up a pull request for initial review.

patrick-kidger commented 8 months ago

I'm afraid I'm probably going to have to declare this out of scope, simply because I don't think we can commit to maintaining it ourselves. More generally, there are a lot of linear algebra operations out there, and if we accept all of them into Lineax then things will definitely grow beyond what we can handle!

That said, I would really encourage you to set this up in your own library. If you can get things working, I'd be very happy to link to it from the Lineax documentation. I'm starting to think about creating a more structured set of ecosystem links.

quattro commented 8 months ago

I understand completely, and would be more than happy to push these into a small library. Cheers, and thanks again.

quattro commented 8 months ago

It's still hyper alpha, but if you're interested, check it out here: https://github.com/mancusolab/traceax