cornellius-gp / linear_operator

A LinearOperator implementation to wrap the numerical nuts and bolts of GPyTorch
MIT License

linear_operator cat_rows performance improvement #93

Closed: naefjo closed this 8 months ago

naefjo commented 8 months ago

Hello :)

This PR addresses the observations about computational bottlenecks made in cornellius-gp/gpytorch#2468

In cat_rows, the schur_root is converted to a dense tensor using to_dense. Hence, the subsequent inversion of the root cannot exploit the (triangular) structure of the operator and falls back to stable_pinverse, which uses a QR decomposition.
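
To see why the structure matters, here is a minimal sketch in plain PyTorch of the block root that a cat_rows-style update builds. It is not linear_operator's actual code: the helper name `cat_rows_root` is hypothetical, and it assumes a single (non-batched) matrix whose existing block A already has a lower-triangular Cholesky root L. Because both L and the Cholesky factor G of the Schur complement are lower triangular, the combined root Z = [[L, 0], [F, G]] is lower triangular too, so it can be inverted with a triangular solve rather than a general pseudo-inverse.

```python
import torch


def cat_rows_root(L, cross_mat, new_mat):
    """Hypothetical helper (not the linear_operator API).

    Given a lower-triangular root L with L @ L.mT == A, build a root Z of
    K = [[A, B^T], [B, D]] with B = cross_mat and D = new_mat, plus Z^{-1}.
    """
    m, k = L.shape[-1], new_mat.shape[-1]
    # F = B L^{-T}: solve L F^T = B^T by forward substitution
    F = torch.linalg.solve_triangular(L, cross_mat.mT, upper=False).mT
    # Schur complement S = D - F F^T = D - B A^{-1} B^T
    schur = new_mat - F @ F.mT
    # G is lower triangular because it is a Cholesky factor of S
    G = torch.linalg.cholesky(schur)
    # Z = [[L, 0], [F, G]] is lower triangular as well
    Z = torch.zeros(m + k, m + k, dtype=L.dtype, device=L.device)
    Z[:m, :m] = L
    Z[m:, :m] = F
    Z[m:, m:] = G
    # The triangular structure lets us invert Z with a triangular solve
    # instead of a general (QR- or SVD-based) pseudo-inverse
    eye = torch.eye(m + k, dtype=L.dtype, device=L.device)
    Z_inv = torch.linalg.solve_triangular(Z, eye, upper=False)
    return Z, Z_inv
```

For example, with an SPD matrix `K` split after its first 5 rows, `cat_rows_root(torch.linalg.cholesky(K[:5, :5]), K[5:, :5], K[5:, 5:])` returns a `Z` with `Z @ Z.mT` equal to `K` up to rounding. Once G is flattened into an unstructured dense tensor, a check for triangular inputs can no longer take this path.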

The PR contains the following modifications:

Balandat commented 8 months ago

Thanks for the contribution! Overall this makes sense to me. Would you be able to provide some benchmark results of this change relative to the previous implementation?

naefjo commented 8 months ago

Sure thing. Timing cat_rows in isolation does not seem to make any noticeable difference in a toy example I tried to cook up. However, the difference is very noticeable in gpytorch's get_fantasy_model. Here is a graph showing the computation time of gpytorch's get_fantasy_model method as a function of the number of data points, based on the basic example notebook. The updates were performed in a "batched" setting, i.e., 10 points at a time. Note that this is in conjunction with the changes from cornellius-gp/gpytorch#2494.

[Figure: computation time of gpytorch's get_fantasy_model vs. number of data points]
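
A micro-benchmark of only the inversion step can be sketched as follows. This is an illustration under stated assumptions, not the benchmark used above: `qr_inverse` is a hypothetical stand-in for a QR-based inverse in the spirit of stable_pinverse (not its actual code), `tri_inverse` uses a triangular solve, and neither gpytorch nor get_fantasy_model is involved. It prints whatever timings your machine produces.

```python
import time

import torch


def qr_inverse(Z):
    # Hypothetical QR-based inverse of a square, full-rank Z:
    # Z = QR  =>  Z^{-1} = R^{-1} Q^T (solve R X = Q^T)
    Q, R = torch.linalg.qr(Z)
    return torch.linalg.solve_triangular(R, Q.mT, upper=True)


def tri_inverse(Z):
    # Z is lower triangular, so solve Z X = I directly by substitution
    eye = torch.eye(Z.shape[-1], dtype=Z.dtype, device=Z.device)
    return torch.linalg.solve_triangular(Z, eye, upper=False)


def time_fn(fn, Z, repeats=10):
    # Median wall-clock time in seconds over several runs
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(Z)
        times.append(time.perf_counter() - start)
    return sorted(times)[len(times) // 2]


if __name__ == "__main__":
    torch.manual_seed(0)
    for n in (200, 400, 800):
        X = torch.randn(n, n, dtype=torch.float64)
        # A well-conditioned SPD matrix, so Cholesky succeeds
        Z = torch.linalg.cholesky(X @ X.mT + n * torch.eye(n, dtype=torch.float64))
        print(f"n={n}: QR {time_fn(qr_inverse, Z):.4f}s, "
              f"triangular {time_fn(tri_inverse, Z):.4f}s")
```

Both functions return the same inverse up to rounding; the difference is only in how much work is done to get it.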

BTW, the failing CI seems to be related to a newly released version of mpmath being downloaded; 1.4.0 seems to have some breaking API changes. A quick Google search shows that other repos were affected as well: https://github.com/pytorch/pytorch/issues/120995 and https://github.com/NVIDIA/TensorRT-LLM/issues/1145

naefjo commented 8 months ago

Disregard my last comment about there being no difference in cat_rows. Apparently the matrices I was testing were not positive definite enough for Cholesky factorization, which led to root decompositions being performed with symeig, which in turn led to root inverses being computed with stable_pinverse in cat_rows. If the matrices are well-conditioned enough for Cholesky not to fail, the results look as follows:

[Figure: benchmark results with matrices well-conditioned enough for Cholesky factorization]

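
The fallback chain described in the comment above can be sketched in plain PyTorch. This is not linear_operator's implementation: the helper name `root_and_inv_root` is hypothetical, `torch.linalg.eigh` stands in for the symeig call mentioned above, and `torch.linalg.pinv` (SVD-based) stands in for the QR-based stable_pinverse. The point it illustrates is that only the Cholesky branch yields a triangular root that can be inverted cheaply.

```python
import torch


def root_and_inv_root(A):
    """Hypothetical helper: return (root, inv_root, method) with root @ root.mT ≈ A."""
    L, info = torch.linalg.cholesky_ex(A)
    if info.item() == 0:
        # Cholesky succeeded: L is lower triangular, so a triangular solve inverts it
        eye = torch.eye(A.shape[-1], dtype=A.dtype, device=A.device)
        return L, torch.linalg.solve_triangular(L, eye, upper=False), "cholesky"
    # Cholesky failed (A not numerically positive definite): use an
    # eigendecomposition root Q diag(sqrt(lambda)), clamping tiny negative eigenvalues
    evals, evecs = torch.linalg.eigh(A)
    root = evecs * evals.clamp_min(0).sqrt().unsqueeze(-2)
    # This root is dense, not triangular, so only a general pseudo-inverse applies
    return root, torch.linalg.pinv(root), "eigh"
```

For instance, `[[4, 2], [2, 3]]` takes the Cholesky branch, while the singular matrix `[[1, 1], [1, 1]]` makes Cholesky fail and drops into the eigendecomposition and pseudo-inverse branch.
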
Balandat commented 8 months ago

Nice, this seems like a meaningful improvement. Thanks for the perf fix.

> BTW, the failing CI seems to be related to an updated version of mpmath

Yes, thanks, I've run into this with other libraries before. #94 pins the version to avoid this issue for now. Could you rebase on that change so we can run the tests?

Balandat commented 8 months ago

Not sure what is going on with the docs, but it's unrelated to this PR.