jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Implement `jax.scipy.linalg.logm` #5469

Open · jaeyoo opened this issue 3 years ago

jaeyoo commented 3 years ago

I realized that jax has expm but doesn't have its inverse, logm. Could you please add this feature? In TF, the matrix logarithm is computed using the Schur-Parlett algorithm. Details of the algorithm can be found in Section 11.6.2 of: Nicholas J. Higham, Functions of Matrices: Theory and Computation, SIAM 2008, ISBN 978-0-898716-46-7.
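As a minimal sketch of the requested behavior (using scipy.linalg.logm, which already exists, as the reference): jax.scipy.linalg.expm is the existing JAX function, while a JAX-side logm is only the requested API, not something that exists yet.

import numpy as np
import scipy.linalg
import jax.numpy as jnp
from jax.scipy.linalg import expm

# Reference behavior: logm is the (principal) inverse of expm.
a = np.array([[1.0, 0.5], [0.0, 2.0]])
log_a = scipy.linalg.logm(a)                 # what a jax.scipy.linalg.logm should compute
roundtrip = expm(jnp.asarray(log_a))         # expm already exists in JAX
print(np.allclose(roundtrip, a, atol=1e-6))  # expected: True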

sethaxen commented 3 years ago

Higham and Hopkins's updated catalog of matrix functions recommends using the inverse scaling and squaring algorithm for the matrix logarithm instead of the Schur-Parlett algorithm: http://eprints.maths.manchester.ac.uk/1851/. The corresponding Frechet derivative is given in http://eprints.maths.manchester.ac.uk/2015/. A custom JVP for the Frechet derivative may be needed, depending on how jax implements JVPs of triangular matrices.
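If a custom JVP does turn out to be needed, one way to sketch the wiring is via the standard block-matrix identity for Frechet derivatives, f([[A, E], [0, A]]) = [[f(A), L_f(A, E)], [0, f(A)]]. The names logm, _logm_impl and _logm_jvp below are hypothetical, and _logm_impl is only a placeholder (a truncated series, valid near the identity), not the inverse scaling and squaring primal.

import jax
import jax.numpy as jnp

def _logm_impl(b, terms=20):
    # Placeholder primal: truncated series for log(I + (B - I)), only valid
    # when the spectral radius of (B - I) is below 1. The real primal would
    # be the inverse scaling and squaring algorithm.
    i = jnp.eye(b.shape[-1], dtype=b.dtype)
    res = jnp.zeros_like(b)
    for k in range(1, terms + 1):
        res += (-1.0) ** (k + 1) / k * jnp.linalg.matrix_power(b - i, k)
    return res

@jax.custom_jvp
def logm(a):
    return _logm_impl(a)

@logm.defjvp
def _logm_jvp(primals, tangents):
    (a,), (e,) = primals, tangents
    n = a.shape[-1]
    # Frechet-derivative identity: logm([[A, E], [0, A]]) has log(A) on the
    # diagonal blocks and L_log(A, E) in the top-right block.
    block = jnp.block([[a, e], [jnp.zeros_like(a), a]])
    big = _logm_impl(block)
    return big[:n, :n], big[:n, n:]

This costs about as much as a logm of a 2n x 2n matrix per JVP; the dedicated Frechet-derivative algorithm from the paper above would be cheaper.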

oleg-kachan commented 2 years ago

Any progress on this? @jaeyoo @sethaxen

sethaxen commented 2 years ago

I don't plan to work on this. I don't know if @jaeyoo is.

MichaelMarien commented 2 years ago

As there is still interest, I don't mind taking a look.

MichaelMarien commented 2 years ago

Taking a first look, I noticed that some crucial building blocks for implementing inverse scaling and squaring (such as sqrtm) are also still missing. Because of that, adding just logm would already require quite a bit of code. I'm not sure whether the goal is to support the full scipy.linalg API; perhaps one of the core developers could confirm?

hawkinsp commented 2 years ago

@MichaelMarien note there's an open PR adding sqrtm: https://github.com/google/jax/pull/9544

I don't think it's necessarily a goal to support everything in scipy.linalg, but if there are useful functions in there that people want, we have no objection to adding them. In general, these are in the "contributions welcome" category.

donthomasitos commented 2 years ago

Having logm implemented would be quite useful, especially since expm has already been added.

flywithmath commented 2 years ago

expm and logm would be very useful to have when implementing matrix manifolds.

donthomasitos commented 2 years ago

This solution works well enough for matrices that are close to the identity:

import jax
import jax.numpy as jnp

@jax.jit
def logm(b):
    # Truncated series log(B) = sum_{k>=1} (-1)**(k+1) * (B - I)**k / k,
    # which only converges when the spectral radius of (B - I) is below 1.
    I = jnp.eye(b.shape[0])
    res = jnp.zeros_like(b)
    ITERATIONS = 20
    for k in range(1, ITERATIONS):
        res += pow(-1, k + 1) * jnp.linalg.matrix_power(b - I, k) / k
    return res

ITERATIONS must be tuned to reach the precision you need.
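For what it's worth, a quick way to sanity-check this against scipy (a sketch; the test matrix and tolerance here are arbitrary):

import numpy as np
import scipy.linalg

# A matrix close to the identity, where the truncated series converges quickly.
a = np.eye(3) + 0.1 * np.random.default_rng(0).standard_normal((3, 3))
print(np.max(np.abs(np.asarray(logm(jnp.asarray(a))) - scipy.linalg.logm(a))))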

lockwo commented 1 year ago

Have there been any updates on this? logm would be very useful for me, and I can take a stab at implementing it.

mattjj commented 1 year ago

No progress that I'm aware of!

lockwo commented 1 year ago

Thanks for the response. If I were to implement it, would the (apparently better, though I don't know enough to comment on it) inverse scaling and squaring algorithm be necessary, or would Schur-Parlett (which I believe scipy uses) be OK?

mattjj commented 1 year ago

I have no idea :) I'd have to read up on those.

Sometimes JAX's/XLA's constraints, like having statically-shaped intermediates, inform what the best approach is.

Maybe we can check for a TF implementation?

mattjj commented 1 year ago

I guess the thread above already talks about a TF implementation; we could check for any changes there. But it sounds like @sethaxen has some expertise on this, so maybe we should try his recommendations first.

lockwo commented 1 year ago

Yes, TF used Schur-Parlett (and, assuming their documentation is correct, still does), but if there is expertise that says otherwise, I would be interested in hearing it.

sethaxen commented 1 year ago

As noted in https://github.com/google/jax/issues/5469#issuecomment-773567386, Nicholas Higham recommends inverse scaling and squaring for the matrix logarithm. Some direct comparisons are made in Figs 1 and 2 of http://eprints.maths.manchester.ac.uk/1851/. For a few of the matrices they examined, Schur-Parlett had a much higher error than inverse scaling and squaring. While they don't directly comment on computational cost, I believe the specialized algorithm is also faster than Schur-Parlett.

Also note that http://eprints.maths.manchester.ac.uk/2015/ gives a 2x faster version of the algorithm if the real Schur form and real matrix square root are computed for a real matrix (I haven't checked if JAX has the real Schur form). This also avoids unnecessarily complexifying the output (i.e. if a real matrix sqrt/log exists, it is returned). That same paper gives an algorithm for the Frechet derivative (i.e. jvp), which removes the need to AD through the Schur decomposition.
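For orientation only, a heavily simplified sketch of the inverse scaling and squaring structure, with fixed square-root and series counts so shapes and trip counts stay static under jit (the real algorithm chooses both adaptively and uses a Padé approximant instead of a truncated series); logm_iss_sketch, num_sqrts and series_terms are made-up names, and it assumes the sqrtm from #9544 is available:

import jax.numpy as jnp
from jax.scipy.linalg import sqrtm  # assumes the sqrtm from #9544

def logm_iss_sketch(a, num_sqrts=5, series_terms=8):
    # Inverse scaling: repeated square roots drive A**(1/2**s) toward I.
    x = a
    for _ in range(num_sqrts):
        x = sqrtm(x)
    # Approximate log of the near-identity factor (a Pade approximant would
    # be used here in the real algorithm).
    i = jnp.eye(a.shape[-1], dtype=x.dtype)
    y = x - i
    log_root = jnp.zeros_like(y)
    for k in range(1, series_terms + 1):
        log_root += (-1.0) ** (k + 1) / k * jnp.linalg.matrix_power(y, k)
    # Undo the scaling: log(A) = 2**s * log(A**(1/2**s)).
    return (2.0 ** num_sqrts) * log_root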

> Thanks for the response. If I were to implement it, would the (apparently better, though I don't know enough to comment on it) inverse scaling and squaring algorithm be necessary, or would Schur-Parlett (which I believe scipy uses) be OK?

Scipy implements the inverse scaling and squaring algorithm for logm (https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.logm.html) but uses the Schur-Parlett algorithm for funm (https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.funm.html).

My suggestion, @lockwo: if you're interested in implementing something here, funm would be a useful general-purpose addition, and logm (and maybe some other matrix functions in scipy.linalg, https://docs.scipy.org/doc/scipy/reference/linalg.html#matrix-functions) could for now default to using funm. But I do think someone should eventually implement the inverse scaling and squaring algorithm with special-casing for real matrices.
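To make the proposed layering concrete, a hedged sketch of a naive eigendecomposition-based funm with logm delegating to it; funm and logm here are hypothetical names rather than existing JAX APIs, and this is not Schur-Parlett (it breaks down for defective or nearly defective matrices, the very issue raised in the next comment).

import jax.numpy as jnp

def funm(a, func):
    # Naive diagonalization-based f(A) = V f(D) V^{-1}; only reliable when
    # the eigenvector matrix is well conditioned (jnp.linalg.eig is CPU-only).
    # A robust version would use the blocked Schur-Parlett algorithm instead.
    w, v = jnp.linalg.eig(a)
    return (v * func(w)) @ jnp.linalg.inv(v)

def logm(a):
    # Proposed layering: logm (and similar matrix functions) defaults to funm
    # until a dedicated inverse scaling and squaring implementation exists.
    return funm(a, jnp.log)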

sethaxen commented 1 year ago

I haven't checked scipy's funm implementation, but the docstring claims they use the pointwise Schur-Parlett algorithm, which is problematic when eigenvalues are nonunique. Higham's catalog recommends this blocked approach: https://doi.org/10.1137/S0895479802410815

tfunatomi commented 1 year ago

FYI: on MATLAB Central File Exchange there is an implementation of the matrix logarithm and its Frechet derivative using the algorithm of http://eprints.ma.man.ac.uk/1852/: Samuel Relton (2023), "Matrix Logarithm with Frechet Derivatives and Condition Number".