google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Better expm #17756

Open howsiyu opened 9 months ago

howsiyu commented 9 months ago

Hi, it seems that jax.scipy.linalg.expm's implementation is based on "The Scaling and Squaring Method for the Matrix Exponential Revisited" by Higham, Nicholas J. (2005), which is also featured in "Functions of Matrices: Theory and Computation". However, SciPy's implementation is based on "A New Scaling and Squaring Algorithm for the Matrix Exponential" by the same author (2009), which gives a few improvements; notably, it computes the diagonal and superdiagonal of the exponential of a triangular matrix in exact arithmetic, and it uses norms of powers of A to better determine the number of squarings needed.
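
To see the gap concretely, something like the following (just an illustrative comparison I put together, not code from either library) exercises both implementations on an upper-triangular test matrix, which is the case where the 2009 paper's exact-arithmetic handling of the diagonal and superdiagonal matters:

```python
import numpy as np
import scipy.linalg
import jax.numpy as jnp
from jax.scipy.linalg import expm as jexpm

# Upper-triangular test matrix; Higham (2009) computes the diagonal and
# superdiagonal of exp(T) for triangular T in exact arithmetic.
rng = np.random.default_rng(0)
t = np.triu(rng.standard_normal((8, 8)))

e_scipy = scipy.linalg.expm(t)              # SciPy: 2009 algorithm
e_jax = np.asarray(jexpm(jnp.asarray(t)))   # JAX: current 2005-based algorithm

print(np.max(np.abs(e_scipy - e_jax)))      # elementwise discrepancy
```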

Also, I noticed that the current implementation of expm uses a Padé approximation of degree 3, 5, 7, 9, or 13 depending on the norm of the matrix. I believe this is done in the original paper mainly to save computation, since a lower-degree Padé approximation suffices for a matrix of small norm, and a higher-degree Padé approximation doesn't hurt accuracy. However, on GPU we probably execute all the branches anyway. I wonder whether we should just always use the degree-13 Padé approximation, which would simplify the code substantially and may help performance on GPU.
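
To make the suggestion concrete, here is a rough sketch of what an expm that always uses the degree-13 Padé approximant could look like in JAX. This is my own sketch, not the existing implementation; `expm13` is a hypothetical name, and the coefficients and the θ₁₃ threshold are the standard values from Higham's papers (the same ones SciPy uses):

```python
import jax
import jax.numpy as jnp

# Degree-13 Pade coefficients b_0..b_13 from Higham's scaling-and-squaring papers.
_B13 = jnp.array([
    64764752532480000., 32382376266240000., 7771770303897600.,
    1187353796428800., 129060195264000., 10559470521600.,
    670442572800., 33522128640., 1323241920., 40840800.,
    960960., 16380., 182., 1.])
_THETA13 = 5.371920351148152  # scaling threshold for the degree-13 approximant


def expm13(a):
    """Hypothetical expm for a single square matrix, degree-13 Pade only."""
    a = jnp.asarray(a)
    norm = jnp.linalg.norm(a, 1)
    # Number of squarings s such that ||A / 2**s||_1 <= theta_13.
    s = jnp.maximum(0.0, jnp.ceil(jnp.log2(norm / _THETA13))).astype(jnp.int32)
    a = a / 2.0 ** s
    ident = jnp.eye(a.shape[-1], dtype=a.dtype)
    a2 = a @ a
    a4 = a2 @ a2
    a6 = a2 @ a4
    u = a @ (a6 @ (_B13[13] * a6 + _B13[11] * a4 + _B13[9] * a2)
             + _B13[7] * a6 + _B13[5] * a4 + _B13[3] * a2 + _B13[1] * ident)
    v = (a6 @ (_B13[12] * a6 + _B13[10] * a4 + _B13[8] * a2)
         + _B13[6] * a6 + _B13[4] * a4 + _B13[2] * a2 + _B13[0] * ident)
    # (13, 13) Pade approximant of exp(A / 2**s).
    r = jnp.linalg.solve(v - u, v + u)
    # Undo the scaling by repeated squaring.
    return jax.lax.fori_loop(0, s, lambda _, x: x @ x, r)
```

With a fixed degree there is no branching over the norm at all, so nothing to trace through lax.cond/select; the trade-off is extra matmuls for small-norm inputs, which is the cost the original paper's degree selection was designed to avoid.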

jakevdp commented 9 months ago

cc/ @sriharikrishna who first contributed expm in #1940