jaeyoo opened this issue 3 years ago
Higham and Hopkins' updated catalog of matrix functions recommends using the inverse scaling and squaring algorithm for the matrix logarithm instead of the Schur-Parlett algorithm: http://eprints.maths.manchester.ac.uk/1851/. The corresponding Frechet derivative is given in http://eprints.maths.manchester.ac.uk/2015/. A custom JVP for the Frechet derivative may be needed, depending on how JAX implements JVPs of triangular matrices.
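For anyone picking this up, here is a minimal sketch of the inverse scaling and squaring idea, not the algorithm from those papers (which chooses the number of square roots and the Padé degree adaptively and works on the triangular Schur factor). It assumes a matrix square root such as jax.scipy.linalg.sqrtm is available (the PR linked below adds one); `logm_iss`, `k`, and `terms` are illustrative names and values, not an existing API:

```python
import jax.numpy as jnp
from jax.scipy.linalg import sqrtm  # assumes the sqrtm PR linked below

def logm_iss(a, k=5, terms=8):
    # Inverse scaling: take k matrix square roots so that a**(1/2**k) is close
    # to the identity. sqrtm goes through the Schur form, so the result may be
    # complex even for real input.
    for _ in range(k):
        a = sqrtm(a)
    # Approximate log(I + X) with a truncated series (a production version
    # would use a Pade approximant), then undo the scaling:
    # log(A) = 2**k * log(A**(1/2**k)).
    x = a - jnp.eye(a.shape[0], dtype=a.dtype)
    res = jnp.zeros_like(x)
    for j in range(1, terms + 1):
        res += (-1.0) ** (j + 1) * jnp.linalg.matrix_power(x, j) / j
    return (2.0 ** k) * res
```

Fixing `k` and `terms` keeps all intermediates statically shaped under jit; the adaptive choices in the papers would need something like lax.while_loop plus norm-based heuristics.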
Any progress on this? @jaeyoo @sethaxen
I don't plan to work on this. I don't know if @jaeyoo is.
As there is still interest, I don't mind taking a look.
Taking a first look, I noticed that some crucial building blocks for implementing inverse scaling and squaring are also still missing (such as sqrtm). Because of that, adding logm alone would already require quite a bit of code. I'm not sure whether supporting the full scipy.linalg API is a goal? Perhaps one of the core developers could confirm this?
@MichaelMarien note there's an open PR adding sqrtm: https://github.com/google/jax/pull/9544
I don't think it's necessarily a goal to support everything in scipy.linalg, but if there are useful functions in there that people want, we have no objection to adding them. In general, these are in the "contributions welcome" category.
Having logm implemented would be quite useful, especially since expm has already been added.
expm and logm would be very useful to have when implementing matrix manifolds
This solution works well enough for matrices that are close to the identity:
```python
import jax
import jax.numpy as jnp

@jax.jit
def logm(b):
    # Truncated series log(I + X) with X = b - I; accurate only when the
    # eigenvalues of b - I lie well inside the unit disc.
    I = jnp.eye(b.shape[0], dtype=b.dtype)
    res = jnp.zeros_like(b)
    ITERATIONS = 20
    for k in range(1, ITERATIONS):
        res += (-1) ** (k + 1) * jnp.linalg.matrix_power(b - I, k) / k
    return res
```
ITERATIONS must be tuned to reach the precision you need.
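A quick way to gauge whether ITERATIONS is sufficient for a given matrix is to round-trip through expm, which JAX already has; the random matrix below is just an example close to the identity:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
b = jnp.eye(3) + 0.05 * jax.random.normal(key, (3, 3))
err = jnp.max(jnp.abs(jax.scipy.linalg.expm(logm(b)) - b))
print(err)  # close to machine precision once the series has converged
```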
Have there been any updates on this? logm would be very useful for me, and I can take a stab at implementing it.
No progress that I'm aware of!
Thanks for the response. If I were to implement it, would the (apparently better, though I don't know enough to comment on it) inverse scaling and squaring algorithm be necessary, or is Schur-Parlett (which I believe scipy uses) OK?
I have no idea :) I'd have to read up on those.
Sometimes JAX's/XLA's constraints, like having statically-shaped intermediates, inform what the best approach is.
Maybe we can check for a TF implementation?
The thread above already talks about a TF implementation; we could check for any changes there, I guess! But it sounds like @sethaxen has some expertise on this, so maybe we should try his recommendations first.
Yes, TF used the Schur-Parlett approach (and, assuming their documentation is correct, still does), but if there is expertise that says otherwise, I would be interested in hearing it.
As noted in https://github.com/google/jax/issues/5469#issuecomment-773567386, Nicholas Higham recommends inverse scaling and squaring for the matrix logarithm. Some direct comparisons are made in Figs 1 and 2 of http://eprints.maths.manchester.ac.uk/1851/. For a few of the matrices they examined, Schur-Parlett had a much higher error than inverse scaling and squaring. While they don't directly comment on computational cost, I believe the specialized algorithm is also faster than Schur-Parlett.
Also note that http://eprints.maths.manchester.ac.uk/2015/ gives a 2x faster version of the algorithm if the real Schur form and real matrix square root are computed for a real matrix (I haven't checked whether JAX has the real Schur form). This also avoids unnecessarily complexifying the output (i.e. if a real matrix sqrt/log exists, it is returned). That same paper gives an algorithm for the Frechet derivative (i.e. the jvp), which removes the need to AD through the Schur decomposition.
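Just to illustrate the JVP wiring (this is not the specialized Frechet-derivative algorithm from that paper), one could lean on the generic block identity f([[A, E], [0, A]]) = [[f(A), L_f(A, E)], [0, f(A)]] from Higham's book, so that nothing inside the Schur-based code has to be differentiated. A sketch, reusing the hypothetical `logm_iss` helper from above (any logm that handles non-diagonalizable inputs would do):

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def logm(a):
    return logm_iss(a)  # hypothetical helper sketched earlier in the thread

@logm.defjvp
def _logm_jvp(primals, tangents):
    (a,), (e,) = primals, tangents
    n = a.shape[-1]
    # One logm of the 2n x 2n block matrix [[A, E], [0, A]] yields both the
    # primal value (top-left block) and the Frechet derivative L_log(A, E)
    # (top-right block).
    big = logm_iss(jnp.block([[a, e], [jnp.zeros_like(a), a]]))
    return big[:n, :n], big[:n, n:]
```

Reverse mode would need a separate custom_vjp registered along the same lines, since this tangent computation is not structurally linear in E and so can't be transposed automatically.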
> Thanks for the response. If I were to implement it, would the (apparently better, though I don't know enough to comment on it) inverse scaling and squaring algorithm be necessary, or is Schur-Parlett (which I believe scipy uses) OK?
Scipy implements the inverse scaling and squaring algorithm for logm (https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.logm.html) but uses the Schur-Parlett algorithm for funm (https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.funm.html).
@lockwo, if you're already interested, I suggest implementing funm: that would be a useful general-purpose addition, and logm (and maybe some other matrix functions in scipy.linalg, https://docs.scipy.org/doc/scipy/reference/linalg.html#matrix-functions) could for now default to using funm. But I do think someone should eventually implement the inverse scaling and squaring algorithm with special-casing for real matrices.
I haven't checked scipy's funm implementation, but the docstring claims they use the pointwise Schur-Parlett algorithm, which is problematic when eigenvalues are nonunique. Higham's catalog recommends this blocked approach: https://doi.org/10.1137/S0895479802410815
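If it helps as a ground truth while developing a JAX version, the two scipy code paths can be compared directly; the matrix below is just an arbitrary example with real, positive eigenvalues:

```python
import numpy as np
from scipy.linalg import expm, funm, logm

A = np.array([[4.0, 1.0],
              [0.5, 3.0]])
L1 = logm(A)           # inverse scaling and squaring
L2 = funm(A, np.log)   # Schur-Parlett with a user-supplied scalar function
print(np.allclose(expm(L1), A), np.allclose(L1, L2))
```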
FYI: on the MATLAB Central File Exchange there is an implementation of the matrix logarithm and its Frechet derivative using the algorithm of http://eprints.ma.man.ac.uk/1852/: Samuel Relton (2023), Matrix Logarithm with Frechet Derivatives and Condition Number.
I realized that JAX has expm but doesn't have its inverse, logm. Could you please add this feature? In TF, the matrix logarithm is computed using the Schur-Parlett algorithm. Details of the algorithm can be found in Section 11.6.2 of: Nicholas J. Higham, Functions of Matrices: Theory and Computation, SIAM 2008, ISBN 978-0-898716-46-7.