pymc-devs / pytensor

PyTensor allows you to define, optimize, and efficiently evaluate mathematical expressions involving multi-dimensional arrays.
https://pytensor.readthedocs.io

Allow access to `solve_triangular` in `pt.linalg` #291

Closed jessegrabowski closed 1 year ago

jessegrabowski commented 1 year ago

Description

The current deprecation warning for pt.slinalg.solve_lower_triangular and pt.slinalg.solve_upper_triangular directs the user to pt.linalg.solve, but the specialized LAPACK triangular solver wrapped by scipy.linalg.solve_triangular (TRTRS) is not reachable through any combination of arguments to pt.linalg.solve.
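For reference, here is a minimal sketch of the scipy call in question, using only NumPy and SciPy (no PyTensor). scipy.linalg.solve_triangular does substitution against one triangle of the matrix, while the general scipy.linalg.solve factorizes the whole matrix first. The matrix size and random data are arbitrary choices for illustration.

```python
import numpy as np
from scipy.linalg import solve, solve_triangular

rng = np.random.default_rng(42)
n = 4
# Strictly lower part plus a positive diagonal: a well-conditioned lower-triangular matrix
L = np.tril(rng.standard_normal((n, n)), k=-1) + np.diag(rng.uniform(1.0, 2.0, n))
b = rng.standard_normal(n)

# TRTRS path: forward substitution only, O(n^2) work per right-hand side
x_tri = solve_triangular(L, b, lower=True)

# General path: LU factorization of the full matrix (O(n^3)), then solve
x_gen = solve(L, b, assume_a="gen")

# solve_triangular only reads the requested triangle, so anything
# above the diagonal is ignored
L_junk = L + np.triu(rng.standard_normal((n, n)), k=1)
x_junk = solve_triangular(L_junk, b, lower=True)
```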

Currently, if one of the inputs to pt.linalg.solve is a Cholesky node, the solve is replaced with a Solve Op with assume_a='sym'. This invokes the SYSV routine, which is not the same as TRTRS. It would be worth benchmarking these two routines to make sure PyTensor uses the faster one by default. If SYSV were faster in all cases, there would be no reason to have a triangular solve routine at all, so a priori I suspect it is not.
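To make the SYSV/TRTRS distinction concrete, both LAPACK drivers can be called directly through scipy.linalg.lapack. This sketch assumes double precision (dsysv, dtrtrs) and a symmetric positive definite test matrix, since SYSV assumes symmetric input. It only checks that both paths give the same answer to A x = b; either call could be wrapped in timeit for the speed test suggested above.

```python
import numpy as np
from scipy.linalg import lapack

rng = np.random.default_rng(1)
n = 6
M = rng.standard_normal((n, n))
A = M @ M.T + n * np.eye(n)  # symmetric positive definite
b = rng.standard_normal((n, 1))

# SYSV: factorizes the symmetric matrix (Bunch-Kaufman, A = U D U^T by default)
# and then solves; the factorization is returned alongside the solution.
udut, ipiv, x_sysv, info_sysv = lapack.dsysv(A, b)

# TRTRS: no factorization at all, only substitution against a matrix that is
# already triangular. Here the triangular factor comes from a Cholesky step.
L = np.linalg.cholesky(A)                                  # A = L L^T
y, info_fwd = lapack.dtrtrs(L, b, lower=1)                 # solve L y = b
x_trtrs, info_bwd = lapack.dtrtrs(L, y, lower=1, trans=1)  # solve L^T x = y
```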

Anyway, I would suggest exposing pt.slinalg.solve_triangular in pt.linalg and changing the deprecation warning to direct users to this function. This would match the scipy API.
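A sketch of what the redirected deprecation could look like, written against NumPy/SciPy rather than PyTensor: the two old names become thin aliases that warn and forward to solve_triangular with the matching lower flag. These functions are hypothetical illustrations, not PyTensor's actual code.

```python
import warnings

import numpy as np
from scipy.linalg import solve_triangular


def solve_lower_triangular(a, b):
    """Hypothetical deprecated alias that forwards to solve_triangular(lower=True)."""
    warnings.warn(
        "solve_lower_triangular is deprecated; use solve_triangular(a, b, lower=True)",
        DeprecationWarning,
        stacklevel=2,
    )
    return solve_triangular(a, b, lower=True)


def solve_upper_triangular(a, b):
    """Hypothetical deprecated alias that forwards to solve_triangular(lower=False)."""
    warnings.warn(
        "solve_upper_triangular is deprecated; use solve_triangular(a, b, lower=False)",
        DeprecationWarning,
        stacklevel=2,
    )
    return solve_triangular(a, b, lower=False)


# Usage: the alias still returns the right answer but emits a warning
U = np.array([[2.0, 1.0], [0.0, 4.0]])
rhs = np.array([3.0, 8.0])
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    x = solve_upper_triangular(U, rhs)  # [0.5, 2.0]
```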

Another alternative would be to add an assume_a option for triangular matrices to pt.linalg.solve, plus some control flow inside pt.slinalg.Solve.perform that routes to scipy.linalg.solve_triangular if assume_a == 'triang'. The disadvantage is that the API would no longer follow that of scipy.
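The alternative design would put the routing inside the solve itself. This is a minimal sketch assuming a hypothetical function solve_with_structure and the hypothetical assume_a='triang' value proposed above; it is not how pt.slinalg.Solve.perform is actually written.

```python
import numpy as np
from scipy.linalg import solve, solve_triangular


def solve_with_structure(a, b, assume_a="gen", lower=False):
    """Hypothetical dispatcher for the proposed assume_a='triang' option.

    'triang' is the value suggested in the issue, used here for illustration.
    """
    if assume_a == "triang":
        # Route to the TRTRS-backed solver; `lower` picks which triangle is read
        return solve_triangular(a, b, lower=lower)
    # Every other value is passed through to scipy.linalg.solve unchanged
    return solve(a, b, assume_a=assume_a, lower=lower)


# Usage: both paths agree on a small lower-triangular system
L = np.array([[2.0, 0.0], [1.0, 4.0]])
rhs = np.array([2.0, 9.0])
x_fast = solve_with_structure(L, rhs, assume_a="triang", lower=True)  # [1.0, 2.0]
x_general = solve_with_structure(L, rhs)                              # [1.0, 2.0]
```

Only the first call skips the full factorization, but callers now have to learn a string value that scipy.linalg.solve itself does not document, which is the API divergence mentioned above.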

ricardoV94 commented 1 year ago

Sounds fine to me to match the scipy API (even if we were to use Solve for it under the hood). In this case, since there's a performance difference, we can use the already-implemented SolveTriangular Op.

More generally, are there any trade-offs? In what cases would we want to use Solve(assume_a="sym") under the hood instead of SolveTriangular?

jessegrabowski commented 1 year ago

The LAPACK documentation for SYSV (used in Solve(assume_a="sym")) and TRTRS (used in SolveTriangular) indicates that:

From my admittedly limited understanding, Solve(assume_a="sym") just performs wasted computation if we knowingly give it a triangular matrix, because it will re-factorize the matrix into another triangular factorization of a different form. The trade-off might be that, for a symmetric positive definite matrix, Solve(assume_a='sym') could be faster than SolveTriangular(Cholesky(A), Eye). This needs testing, especially since there is also a CholeskySolve Op that doesn't see much use.
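A rough benchmark for the trade-off raised here, assuming only NumPy/SciPy: three ways to solve A X = B for a symmetric positive definite A. solve(..., assume_a="sym") goes through SYSV; cholesky plus two solve_triangular calls is the POTRF + TRTRS route; cho_solve (POTRF + POTRS) is presumably the closest scipy analogue of the CholeskySolve Op, though that mapping is an assumption. The size n = 200 is arbitrary, and timings depend on the machine and BLAS/LAPACK build, so no numbers are quoted.

```python
import timeit

import numpy as np
from scipy.linalg import cho_factor, cho_solve, cholesky, solve, solve_triangular

rng = np.random.default_rng(0)
n = 200
M = rng.standard_normal((n, n))
A = M @ M.T + n * np.eye(n)  # symmetric positive definite
B = np.eye(n)                # identity right-hand side (mirrors "Eye"), so X = A^-1


def via_sysv():
    # SYSV: Bunch-Kaufman factorization of A, then solve
    return solve(A, B, assume_a="sym")


def via_cholesky_triangular():
    # POTRF once, then two TRTRS substitutions: L Y = B, then L^T X = Y
    L = cholesky(A, lower=True)
    Y = solve_triangular(L, B, lower=True)
    return solve_triangular(L, Y, lower=True, trans="T")


def via_cho_solve():
    # POTRF followed by POTRS on the packed Cholesky factor
    return cho_solve(cho_factor(A, lower=True), B)


for fn in (via_sysv, via_cholesky_triangular, via_cho_solve):
    # Best-of-3 average over 20 calls, reported in milliseconds per call
    seconds = min(timeit.repeat(fn, number=20, repeat=3)) / 20
    print(f"{fn.__name__:>24}: {seconds * 1e3:.3f} ms")
```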

ricardoV94 commented 1 year ago

I updated the title of this issue to reflect what needs to be done. The related bug is now tracked in #382.

ricardoV94 commented 1 year ago

I think this one is complete.

jessegrabowski commented 1 year ago

Yes, this was done en passant in #417, so it can be closed.