google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Support (nonsymmetric) np.linalg.eig on GPU #1259

Open clemisch opened 5 years ago

clemisch commented 5 years ago

Dear jax team,

this is just a friendly bump on the implementation of eigendecomposition and batched SVD on GPU. Are you planning on implementing these?

If I wanted to implement them myself, could I do it with the primitives in jax.lax, or would I have to hook up a new part of cuSolver? I am willing to spend the time, as I would benefit a lot from these features, but I have no experience with extending jax and would not know where to look.

mattjj commented 5 years ago

Thanks for the ping! Are there open issues for these already?

@hawkinsp is the expert on these things and can provide the best advice, but for GPU implementations of linalg my understanding is we set up some wrappers in jaxlib, then set up backend-specific translation rules for the appropriate primitives in lax_linalg.py. As for adding batching specifically, I think we just need to make sure batch dimensions are plumbed through properly, which if the cusolver kernels themselves don't support batch dimensions might mean adding some kind of a loop over cusolver calls. It looks like Peter added batched triangular solve and LU decomposition for GPU in #1144, so that might provide hints for the plumbing needed.
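The difference between "a loop over cusolver calls" and a genuinely batched call can be illustrated at the user level with a minimal JAX sketch. This is only an analogue under that assumption, not how jaxlib plumbs batch dimensions into cuSOLVER; `svd_single`, `svd_loop`, and `svd_vmapped` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def svd_single(a):
    # SVD of one (m, n) matrix; the reduced form keeps output shapes small.
    return jnp.linalg.svd(a, full_matrices=False)


@jax.jit
def svd_loop(x):
    # lax.map applies svd_single to each slice along the leading axis one at a
    # time, analogous to looping over unbatched solver calls.
    return jax.lax.map(svd_single, x)


@jax.jit
def svd_vmapped(x):
    # vmap pushes the batch dimension into the operation itself, so a backend
    # with a batched kernel can process all matrices in one call.
    return jax.vmap(svd_single)(x)


x = jax.random.normal(jax.random.PRNGKey(0), (8, 4, 3))
u1, s1, vh1 = svd_loop(x)
u2, s2, vh2 = svd_vmapped(x)
```

Both functions return the same singular values; what differs is whether the backend sees one batched operation or a sequence of per-matrix ones.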

What do you think? Questions welcome! I can only provide high-level pointers to the right places, but if we sniff around there I bet we'll find things.

clemisch commented 5 years ago

Thanks for your quick response! I think there are already some issues concerning linear ops, but not specifically eigendecomp or batched SVD.

Also thanks a lot for the explanation! I'll try to get oriented and come here if I have questions.

hawkinsp commented 5 years ago

@clemisch I can take a look at these if you aren't already working on them.

clemisch commented 5 years ago

Hey @hawkinsp, thank you for getting back to me on this! Tbh I haven't looked into this yet. It would be great if you could have a look too!

hawkinsp commented 5 years ago

PR #1314 adds batched SVD on CPU and GPU. On CPU or for large matrices on GPU it merely calls the current code in a loop. On GPU for small matrices it calls the batched Jacobi kernel from Cusolver.

Unfortunately np.linalg.eig is a little harder. I can add a "batched" implementation on CPU (simply looping over the batch elements). However, there is no support for non-symmetric eigendecomposition in Cusolver (batched or unbatched). If you really need this, we'd need to add another dependency (probably MAGMA), which is a lot more work. Do SVD and symmetric eigendecomposition satisfy your needs for now?
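When inputs may or may not be symmetric, one option is to use the symmetric solver (`jnp.linalg.eigvalsh`, which is available on GPU) whenever it applies and fall back to the CPU-only general solver otherwise. The sketch below does this under two assumptions: a host-side symmetry check is acceptable, so the helper cannot be used inside `jax.jit`, and only real symmetric inputs take the fast path. `eigvals_auto` is a hypothetical name.

```python
import jax
import jax.numpy as jnp
import numpy as np


def eigvals_auto(a, atol=1e-6):
    """Eigenvalues of a (batched) square matrix, preferring the symmetric solver.

    The symmetry test runs on the host with NumPy, so this helper cannot be
    used inside jax.jit.
    """
    a_host = np.asarray(a)
    if np.allclose(a_host, np.swapaxes(a_host, -1, -2), atol=atol):
        # Real symmetric input: eigvalsh returns real eigenvalues in ascending
        # order and runs on the default device, which may be a GPU.
        return jnp.linalg.eigvalsh(a_host)
    # Nonsymmetric input: place the data on the CPU backend, where eigvals works.
    a_cpu = jax.device_put(a_host, jax.devices("cpu")[0])
    return jnp.linalg.eigvals(a_cpu)


sym = np.array([[2.0, 1.0], [1.0, 2.0]])
nonsym = np.array([[0.0, -1.0], [1.0, 0.0]])
w_sym = eigvals_auto(sym)     # real eigenvalues 1 and 3
w_non = eigvals_auto(nonsym)  # complex eigenvalues +1j and -1j
```

Note that the two branches return different dtypes (real versus complex), which callers need to handle.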

hawkinsp commented 5 years ago

I merged the PR that adds batched SVD support. You'll need to rebuild Jaxlib (or wait for us to make a release).

I retitled the issue to reflect the open action item (nonsymmetric eigendecomposition on GPU).

hawkinsp commented 5 years ago

GPU Eigendecomposition via MAGMA might fall into the "contributions welcome" category, unless it proves to be a popular request.

clemisch commented 5 years ago

Thank you @hawkinsp, this is great! Non-symmetric eigendecomposition is not very urgent for me, especially if it's so cumbersome to add to jax.

Concerning batched SVD, I have a question about speed: in this little test I only see a 4x speedup vs. single-core numpy. Is this expected?

# Run in IPython/Jupyter: %timeit is an IPython magic.
import jax
import jax.numpy as np
import numpy as onp

x_host = onp.random.rand(100000, 3, 3).astype(onp.float32)
x_gpu = np.array(x_host)

svd_batch = jax.jit(jax.vmap(np.linalg.svd, 0, 0))

u1, s1, v1 = onp.linalg.svd(x_host)
u2, s2, v2 = np.linalg.svd(x_gpu)
u3, s3, v3 = svd_batch(x_gpu)

%timeit onp.linalg.svd(x_host)                       # 495 ms
%timeit np.linalg.svd(x_gpu)[0].block_until_ready()  # 122 ms
%timeit svd_batch(x_gpu)[1].block_until_ready()      # 123 ms

(sorry about the repost, I deleted the original comment by mistake)

clemisch commented 4 years ago

Bump :smiley_cat:


hawkinsp commented 4 years ago

I believe that's just how fast NVIDIA's cuSOLVER batched Jacobi implementation is. On my GPU, we seem to spend 99.9% of the time in the batched Jacobi kernel:

            Type  Time(%)      Time     Calls       Avg       Min       Max  Name
 GPU activities:   99.90%  3.20600s        12  267.17ms  233.68ms  305.82ms  void batched_svd_parallel_jacobi_32x16<float, float>(int, int, int, int, float*, unsigned long, int, float*, float*, unsigned long, int, float*, unsigned long, int, float, int, int*, float, int, int*, int, float)

The algorithm does have some tunable parameters that one might explore setting: https://docs.nvidia.com/cuda/cusolver/index.html#cuds-lt-t-gt-gesvdjbatch

If you wanted to try that, I think you just need to call the functions that modify the Jacobi parameters at this line and then rebuild Jaxlib. https://github.com/google/jax/blob/master/jaxlib/cusolver.cc#L731
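To get a feel for what tunables like a convergence tolerance and a maximum number of sweeps control (as we understand the linked cuSOLVER documentation, these are the main knobs), here is a small NumPy sketch of the textbook one-sided (Hestenes) Jacobi SVD. It illustrates the general method only and is not NVIDIA's gesvdjBatched kernel; `jacobi_svd` and its parameter names are hypothetical.

```python
import numpy as np


def jacobi_svd(a, tol=1e-10, max_sweeps=30):
    """One-sided (Hestenes) Jacobi SVD of an (m, n) matrix with m >= n.

    tol: a column pair counts as orthogonal once
         |<u_p, u_q>| <= tol * ||u_p|| * ||u_q||.
    max_sweeps: upper bound on full passes over all column pairs.
    Returns (U, s, Vt, sweeps_used) with a ~= U @ np.diag(s) @ Vt.
    """
    u = np.array(a, dtype=np.float64)  # working copy; columns converge to U * s
    n = u.shape[1]
    v = np.eye(n)  # accumulates the rotations, so u == a @ v throughout
    sweeps_used = 0
    for _ in range(max_sweeps):
        sweeps_used += 1
        rotated = False
        for p in range(n - 1):
            for q in range(p + 1, n):
                alpha = u[:, p] @ u[:, p]
                beta = u[:, q] @ u[:, q]
                gamma = u[:, p] @ u[:, q]
                scale = np.sqrt(alpha * beta)
                if scale == 0.0 or abs(gamma) <= tol * scale:
                    continue  # this pair is already orthogonal to within tol
                rotated = True
                # Rotation (c, s) that makes columns p and q orthogonal.
                zeta = (beta - alpha) / (2.0 * gamma)
                sign = 1.0 if zeta >= 0 else -1.0
                t = sign / (abs(zeta) + np.sqrt(1.0 + zeta * zeta))
                c = 1.0 / np.sqrt(1.0 + t * t)
                s = c * t
                for mat in (u, v):
                    col_p = mat[:, p].copy()
                    mat[:, p] = c * col_p - s * mat[:, q]
                    mat[:, q] = s * col_p + c * mat[:, q]
        if not rotated:
            break  # converged: every pair passed the tolerance test
    sigma = np.linalg.norm(u, axis=0)
    order = np.argsort(sigma)[::-1]
    sigma = sigma[order]
    U = u[:, order] / np.where(sigma > 0, sigma, 1.0)
    return U, sigma, v[:, order].T, sweeps_used


rng = np.random.default_rng(0)
a = rng.standard_normal((5, 3))
U, s, Vt, sweeps = jacobi_svd(a)
```

A looser `tol` or a smaller `max_sweeps` ends the iteration earlier at the cost of accuracy; a batched GPU kernel runs this kind of iteration for many small matrices at once.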

clemisch commented 4 years ago

Thank you very much for clarifying!

mganahl commented 4 years ago

Hi! Just popping up to ask if there is any progress regarding eig. I'm currently preparing a JAX implementation of implicitly restarted Arnoldi (for non-symmetric operators). The working CPU implementation relies on jax.numpy.linalg.eig to compute the eigenvalues of the Hessenberg matrix returned by Arnoldi. It would be great to have this run on GPU eventually.
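The structure described here (Arnoldi on the device, then a small dense eig of the Hessenberg matrix) can be sketched as a single, non-restarted Arnoldi pass with modified Gram-Schmidt. This is a minimal illustration assuming a plain Python loop is acceptable; restarting is omitted, `arnoldi` and `ritz_values` are hypothetical names, and the final eig runs with NumPy on the host because that is the step lacking GPU support.

```python
import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_enable_x64", True)  # double precision for the orthogonalization


def arnoldi(matvec, v0, k):
    """k Arnoldi steps: V is (n, k+1), H is (k+1, k), and A @ V[:, :k] == V @ H."""
    n = v0.shape[0]
    V = jnp.zeros((n, k + 1), dtype=v0.dtype)
    H = jnp.zeros((k + 1, k), dtype=v0.dtype)
    V = V.at[:, 0].set(v0 / jnp.linalg.norm(v0))
    for j in range(k):
        w = matvec(V[:, j])
        # Modified Gram-Schmidt against all previous basis vectors.
        for i in range(j + 1):
            h = jnp.vdot(V[:, i], w)
            H = H.at[i, j].set(h)
            w = w - h * V[:, i]
        h_next = jnp.linalg.norm(w)
        H = H.at[j + 1, j].set(h_next)
        # Guard against breakdown (an exactly invariant Krylov subspace).
        V = V.at[:, j + 1].set(w / jnp.where(h_next > 0, h_next, 1.0))
    return V, H


def ritz_values(H):
    """Eigenvalues of the square k x k Hessenberg block, computed on the host."""
    k = H.shape[1]
    return np.linalg.eigvals(np.asarray(H[:k, :k]))


rng = np.random.default_rng(0)
A = jnp.asarray(rng.standard_normal((6, 6)))
v0 = jnp.asarray(rng.standard_normal(6))
V, H = arnoldi(lambda x: A @ x, v0, k=6)
approx = ritz_values(H)
```

With k equal to the matrix size the Ritz values reproduce all eigenvalues; in practice k is much smaller and only the extremal eigenvalues are approximated well.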

joncarter1 commented 3 years ago

Hey, I thought I'd also express my interest in this; my use case is finding the poles of many autoregressive models in parallel with np.roots. Thanks to all the contributors for getting JAX to where it already is, it's amazing.
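The link between np.roots and eig is that polynomial roots are computed as eigenvalues of a (nonsymmetric) companion matrix. For an AR(p) model x_t = a_1 x_{t-1} + ... + a_p x_{t-p} + noise, the poles are the roots of z^p - a_1 z^(p-1) - ... - a_p, so a whole batch of models can go through one batched eigvals call. A minimal sketch with a hypothetical `ar_poles` helper; the eigvals step is the one that, in the JAX versions discussed in this thread, only runs on CPU.

```python
import jax.numpy as jnp
import numpy as np


def ar_poles(coeffs):
    """Poles of a batch of AR(p) models.

    coeffs: (batch, p) array holding a_1..a_p for each model.
    Returns (batch, p) complex poles, the roots of z^p - a_1 z^(p-1) - ... - a_p.
    """
    batch, p = coeffs.shape
    # Companion matrix: AR coefficients in the first row, ones on the subdiagonal.
    companion = jnp.zeros((batch, p, p), dtype=coeffs.dtype)
    companion = companion.at[:, 0, :].set(coeffs)
    companion = companion.at[:, 1:, :-1].set(jnp.eye(p - 1, dtype=coeffs.dtype))
    # Batched, nonsymmetric eigenvalue problem.
    return jnp.linalg.eigvals(companion)


# z^2 - 0.5 z + 0.06 = (z - 0.2)(z - 0.3)   and   z^2 + 0.25 = (z - 0.5j)(z + 0.5j)
coeffs = jnp.array([[0.5, -0.06], [0.0, -0.25]])
poles = ar_poles(coeffs)
```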

hawkinsp commented 3 years ago

I'm curious how folks would feel about the following: suppose MAGMA were an optional dependency of JAX, i.e., we wouldn't bundle it in jaxlib builds, but if you installed it yourself (or perhaps via conda?) and JAX could find the shared library in your library path, then jnp.linalg.eig would work on GPU.

(I'm a bit reluctant to bundle it with jaxlib unconditionally for just one function!)
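The "use it only if the shared library is on the library path" idea can be sketched in Python by probing with `ctypes` and falling back to CPU. This only illustrates runtime detection under that assumption; the helper names and backend labels are hypothetical and say nothing about how jaxlib actually locates libraries.

```python
import ctypes
import ctypes.util
from typing import Optional


def load_optional_library(name: str) -> Optional[ctypes.CDLL]:
    """Return a handle to lib<name> if the dynamic loader can find it, else None."""
    path = ctypes.util.find_library(name)  # a loadable name/path, or None
    if path is None:
        return None
    try:
        return ctypes.CDLL(path)
    except OSError:
        # Found but failed to load (missing dependencies, wrong architecture, ...).
        return None


def select_eig_backend(gpu_available: bool) -> str:
    """Pick a (hypothetical) backend label for nonsymmetric eig."""
    if gpu_available and load_optional_library("magma") is not None:
        return "gpu-magma"
    return "cpu-lapack"
```

The design choice is that a missing or broken library degrades to the existing CPU path instead of raising at import time.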

joncarter1 commented 3 years ago

I'd be totally fine with this. Could always be bundled in later down the line but as you say I feel the critical threshold for functional usage is perhaps a bit higher than one! :)

ianwilliamson commented 3 years ago

+1 that support for GPU-backed eig would be great.

drscook commented 1 year ago

+1 for GPU-support for nonsymmetric eig to allow GPU-enabled numpy.roots

melsophos commented 1 year ago

I also strongly support implementing this feature, so that jnp.roots can be used on GPU. I am training a network whose loss function requires computing the roots of a polynomial, and training on CPU is far too slow.

mfschubert commented 1 year ago

I developed a workaround for my use case, which involves using the jax.experimental.host_callback module. Just sharing it in case it's useful.

from typing import Tuple

import jax
import jax.numpy as jnp
from jax.experimental import host_callback  # Deprecated in later JAX releases.

def _eig_host(matrix: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Wraps jnp.linalg.eig so that it can be jit-ed on a machine with GPUs."""
    eigenvalues_shape = jax.ShapeDtypeStruct(matrix.shape[:-1], complex)
    eigenvectors_shape = jax.ShapeDtypeStruct(matrix.shape, complex)
    return host_callback.call(
        # We force this computation to be performed on the cpu by jit-ing and
        # explicitly specifying the device.
        jax.jit(jnp.linalg.eig, device=jax.devices("cpu")[0]),
        matrix.astype(complex),
        result_shape=[eigenvalues_shape, eigenvectors_shape],
    )

# `m` is any square matrix (or batch of square matrices).
jax.jit(_eig_host, device=jax.devices("gpu")[0])(m)  # This works, we can jit on GPU.

mfschubert commented 1 year ago

A brief update: we have a slightly modified version that no longer specifies the device in the call to jax.jit, following the new recommended practice:

from typing import Tuple

import jax
import jax.numpy as jnp
from jax.experimental import host_callback

def _eig_host(matrix: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Wraps jnp.linalg.eig so that it can be jit-ed on a machine with GPUs."""
    eigenvalues_shape = jax.ShapeDtypeStruct(matrix.shape[:-1], complex)
    eigenvectors_shape = jax.ShapeDtypeStruct(matrix.shape, complex)

    def _eig_cpu(matrix: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
        # We force this computation to be performed on the cpu by jit-ing and
        # explicitly specifying the device.
        with jax.default_device(jax.devices("cpu")[0]):
            return jax.jit(jnp.linalg.eig)(matrix)

    return host_callback.call(
        _eig_cpu,
        matrix.astype(complex),
        result_shape=(eigenvalues_shape, eigenvectors_shape),
    )

tsunhopang commented 5 months ago

Hi there! I was just wondering whether there has been any progress on this issue, since eig is a common and essential function for scientific work.

ju-kreber commented 4 months ago

Note for anyone using the above workaround with host_callback nowadays: host_callback has been deprecated; use external callbacks instead (most likely pure_callback()). These also work nicely under function transformations.

qiyang-ustc commented 4 months ago

I have implemented (matrix-free) eigs in JAX for scientific purposes in jaxeigs. I borrowed some code from TensorNetwork and perform the Arnoldi decomposition on the GPU. However, I have kept the last step, which solves the eigenproblem in the projected Krylov space, on the CPU (via a callback), since the algorithm is divide-and-conquer and thus not efficient on GPU.

I must admit that this code is currently extremely unstable, and the documentation is incomplete. Despite these limitations, it is functional for my own use.

moskomule commented 4 months ago


As pure_callback does not seem to support fp64 at the moment, you need additional tricks (in case you are using fp32).

Looking forward to the implementation of eig on GPU.

from typing import Tuple

import jax
import jax.numpy as jnp

def _eig_host(matrix: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
    eigenvalues_shape = jax.ShapeDtypeStruct(matrix.shape[:-1], matrix.dtype)
    eigenvectors_shape = jax.ShapeDtypeStruct(matrix.shape, matrix.dtype)

    def _eig_cpu(matrix: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
        # We force this computation to be performed on the cpu by jit-ing and
        # explicitly specifying the device.
        with jax.default_device(jax.devices("cpu")[0]):
            val, vec = jax.jit(jnp.linalg.eig)(matrix)
            # Return real and imaginary parts separately, in the input's real dtype.
            return (val.real, val.imag), (vec.real, vec.imag)

    val, vec = jax.pure_callback(_eig_cpu,
                                 ((eigenvalues_shape, eigenvalues_shape),
                                  (eigenvectors_shape, eigenvectors_shape)),
                                 matrix)
    return val[0] + 1j * val[1], vec[0] + 1j * vec[1]

mfschubert commented 4 months ago


We don't seem to have issues supporting fp32 and fp64 with the following implementation in fmmax:

from typing import Tuple

import jax
import jax.numpy as jnp

def _eig_host(matrix: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Wraps jnp.linalg.eig so that it can be jit-ed on a machine with GPUs."""

    def _eig_cpu(matrix: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
        # We force this computation to be performed on the cpu by jit-ing and
        # explicitly specifying the device.
        with jax.default_device(jax.devices("cpu")[0]):
            return jax.jit(jnp.linalg.eig)(matrix)

    return jax.pure_callback(
        _eig_cpu,
        (
            jnp.ones(matrix.shape[:-1], dtype=complex),  # Eigenvalues
            jnp.ones(matrix.shape, dtype=complex),  # Eigenvectors
        ),
        matrix.astype(complex),
        # Newer JAX releases deprecate `vectorized` in favor of `vmap_method`.
        vectorized=True,
    )

mfschubert commented 3 weeks ago

I did some tests comparing eig performance for scipy, numpy, jax, and torch and found that they can differ quite a bit, with torch generally being the fastest. In lieu of a GPU-accelerated eig, simply using the torch version may be of benefit.

I also created a pip-installable jeig package which wraps all of these for use with jax. All implementations can be jit-compiled, including on machines with GPUs.

Here is an example of the performance difference I am seeing. This was generated on CPU colab, but torch comes out ahead also on my Apple and Intel machines. I didn't investigate the origin of the difference, but presumably there's a different linear algebra library being used in each of these packages.

[Image: eig performance comparison of scipy, numpy, jax, and torch]
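A comparison along these lines can be reproduced with a small timing harness. The sketch below times NumPy, SciPy, and JAX (pinned to the CPU backend) on the same random matrix; torch is left out to keep it dependency-free, and the matrix size and repeat count are arbitrary assumptions. With JAX's default 32-bit mode the JAX input is float32 while NumPy and SciPy use float64, so the numbers are not strictly like-for-like. `benchmark_eig` is a hypothetical name.

```python
import timeit

import jax
import jax.numpy as jnp
import numpy as np
import scipy.linalg


def benchmark_eig(n=200, repeats=5, seed=0):
    """Best-of-`repeats` wall time in seconds for one dense n x n eig per library."""
    a = np.random.default_rng(seed).standard_normal((n, n))
    a_jax = jax.device_put(a, jax.devices("cpu")[0])  # float32 unless x64 is enabled
    eig_jax = jax.jit(jnp.linalg.eig)
    eig_jax(a_jax)[0].block_until_ready()  # compile once, outside the timed region
    candidates = {
        "numpy": lambda: np.linalg.eig(a),
        "scipy": lambda: scipy.linalg.eig(a),
        "jax": lambda: eig_jax(a_jax)[0].block_until_ready(),
    }
    return {
        name: min(timeit.repeat(fn, number=1, repeat=repeats))
        for name, fn in candidates.items()
    }


timings = benchmark_eig(n=50, repeats=3)
```

Taking the minimum over repeats reduces noise from other processes; `block_until_ready` ensures JAX's asynchronous dispatch is included in the measured time.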

hawkinsp commented 3 weeks ago

JAX just calls SciPy's copy of LAPACK. You can probably speed it up by installing, e.g., a SciPy build linked against Intel's MKL.

Torch, as far as I know, also just calls LAPACK. It may be linking it with a different BLAS library; JAX will just be using OpenBLAS from SciPy.

gautierronan commented 6 days ago

Hi, bumping this to ask if there is any plan from the jax team to implement this feature? @jakevdp

We'd also need this feature for dynamiqs, for the simulation of quantum systems in the so-called Floquet basis (time-periodic quantum systems).

Thanks!
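For context on where eig enters this use case: for a time-periodic Hamiltonian H(t + T) = H(t), the Floquet quasi-energies ε come from the eigenvalues exp(-iεT) of the one-period propagator U(T) (with ħ = 1). U(T) is not Hermitian, so eigh does not apply and the general eig is needed. A minimal sketch, assuming a closed system and a piecewise-constant (midpoint) approximation of the time-ordered exponential; `floquet_quasienergies` and the driven-qubit Hamiltonian are hypothetical and unrelated to dynamiqs' API.

```python
import jax.numpy as jnp
import numpy as np
from jax.scipy.linalg import expm


def floquet_quasienergies(hamiltonian, period, steps=100):
    """Quasi-energies in [-pi/T, pi/T) and Floquet multipliers for H(t) with period T.

    hamiltonian: callable mapping a time t to a (d, d) Hermitian matrix.
    """
    dt = period / steps
    d = hamiltonian(0.0).shape[0]
    u = jnp.eye(d, dtype=jnp.complex64)
    for k in range(steps):
        # Midpoint rule: treat H as constant on each sub-interval.
        u = expm(-1j * dt * hamiltonian((k + 0.5) * dt)) @ u
    # General (nonsymmetric) eig: CPU-only in the JAX versions discussed here.
    multipliers = jnp.linalg.eigvals(u)
    quasienergies = -jnp.angle(multipliers) / period
    return quasienergies, multipliers


sigma_x = jnp.array([[0.0, 1.0], [1.0, 0.0]])
sigma_z = jnp.array([[1.0, 0.0], [0.0, -1.0]])


def driven_qubit(t, omega=2 * np.pi):
    # Hypothetical example: static splitting plus a drive with period 1.
    return 0.3 * sigma_z + 0.2 * jnp.cos(omega * t) * sigma_x


eps, mult = floquet_quasienergies(driven_qubit, period=1.0)
```

For a closed system the multipliers lie on the unit circle, which makes a quick sanity check on the time-stepping accuracy.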