cupy / cupy

NumPy & SciPy for GPU
https://cupy.dev
MIT License

Feature Request: Add support to find smallest magnitude eigenvalues in eigsh #4692

Open safreita1 opened 3 years ago

safreita1 commented 3 years ago

Hi,

Thanks for all the work on this great library!

I'm wondering if there are any plans to add support for finding the smallest eigenvalues using the eigsh function (similar to the Scipy function: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.linalg.eigsh.html)?

Being able to calculate the smallest eigenvalues is quite important for a range of analyses using the Laplacian matrix.
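For reference, here is a minimal sketch of the requested behaviour using SciPy's own eigsh on the CPU. The 20-node path-graph Laplacian and the helper name path_graph_laplacian are illustrative choices, not part of the issue. For large matrices, which='SM' without a shift can converge slowly, which is why SciPy also offers a shift-invert mode through its sigma argument.

```python
import numpy as np
import scipy.sparse as sp
import scipy.sparse.linalg as spla


def path_graph_laplacian(n):
    """Combinatorial Laplacian L = D - A of an n-node path graph, in CSR format."""
    main = np.full(n, 2.0)
    main[0] = main[-1] = 1.0  # the two end nodes have degree 1
    off = -np.ones(n - 1)
    return sp.diags([off, main, off], [-1, 0, 1], format="csr")


lap = path_graph_laplacian(20)
# The three eigenvalues of smallest magnitude; lap is positive semidefinite,
# so 'SM' and 'SA' select the same eigenvalues here.
vals = spla.eigsh(lap, k=3, which="SM", return_eigenvectors=False)
print(np.sort(vals))
```

The path graph is used only because its Laplacian spectrum is known in closed form (2 - 2cos(pi*j/n)), which makes the result easy to check.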

leofang commented 3 years ago

Looks like the support for which='SM'/'SA' has been in upstream for quite a long time. Doesn't seem to be very difficult to support.

safreita1 commented 3 years ago

That's great to hear! I'm hoping to use it to accelerate some graph algorithms in a library I'm developing: https://github.com/safreita1/TIGER

leofang commented 3 years ago

Hi @safreita1, I skimmed over the code. If you're interested in putting up a PR, it should be fairly straightforward: the part handling SM/SA is done on the CPU, so it's just plain NumPy. It looks to me like these are the only two places to change: https://github.com/cupy/cupy/blob/880918c978dc6f8eb91585764e2501b1ca2bea89/cupyx/scipy/sparse/linalg/_eigen.py#L61-L63 https://github.com/cupy/cupy/blob/880918c978dc6f8eb91585764e2501b1ca2bea89/cupyx/scipy/sparse/linalg/_eigen.py#L271-L278 It's just a matter of reversing the sort order.

As for the tests, here's the file to touch: https://github.com/cupy/cupy/blob/755025fa7966acd94f66799f0897961dfc8b1c6f/tests/cupyx_tests/scipy_tests/sparse_tests/test_linalg.py#L115-L122 I'm more than happy to provide advice whenever needed 🙂
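A minimal NumPy sketch of the "reverse the sort order" idea, assuming (as described above) that the Ritz values are picked by taking the last k entries of an argsort. The function select_ritz is hypothetical and only mirrors that pattern; it is not CuPy's actual code.

```python
import numpy as np


def select_ritz(w, s, k, which):
    """Pick k Ritz pairs from the eigenpairs (w, s) of the small tridiagonal matrix.

    The selection always keeps the last k indices of an argsort, so reversing
    the argsort turns a "largest" selection into the matching "smallest" one.
    """
    if which == "LA":    # largest algebraic
        idx = np.argsort(w)
    elif which == "LM":  # largest magnitude
        idx = np.argsort(np.abs(w))
    elif which == "SA":  # smallest algebraic: reversed LA order
        idx = np.argsort(w)[::-1]
    elif which == "SM":  # smallest magnitude: reversed LM order
        idx = np.argsort(np.abs(w))[::-1]
    else:
        raise ValueError(f"unsupported which={which!r}")
    return w[idx[-k:]], s[:, idx[-k:]]


# Tiny usage example: eigenpairs of a symmetric 4x4 matrix
t = np.diag([-3.0, -0.5, 1.0, 4.0])
w, s = np.linalg.eigh(t)
print(select_ritz(w, s, 2, "SM")[0])
```

The returned pairs are not in any particular documented order; how they are finally sorted for the caller is a separate question.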

povinsahu1909 commented 3 years ago

Can I work on this issue?

povinsahu1909 commented 3 years ago

While going through this issue, I found that if device.get_cusparse_handle() is None, the eigsh function ends up raising a ValueError. I think the most probable reason for this behaviour is line 204, which multiplies a NumPy array by a CuPy array.

https://github.com/cupy/cupy/blob/880918c978dc6f8eb91585764e2501b1ca2bea89/cupyx/scipy/sparse/linalg/_eigen.py#L200-L205

If I am wrong please guide me to overcome this issue.

povinsahu1909 commented 3 years ago

@leofang, I also want to bring to your notice that this function was not listed in the comparison list, and many other functions that are already implemented were not listed there either.

leofang commented 3 years ago

Can I work on this issue?

Absolutely 🙂 AFAIK no one is working on this.

While going through this issue, I found that if device.get_cusparse_handle() is None, the eigsh function ends up raising a ValueError. I think the most probable reason for this behaviour is line 204, which multiplies a NumPy array by a CuPy array.

I think line 204 is fine. Both A and v are CuPy arrays IIUC. Do we really need to worry about this part? What happens when you follow my suggestion https://github.com/cupy/cupy/issues/4692#issuecomment-782930397? I expected that enabling SM/SA would require only very small changes, as the workflow is largely identical.

I also want to bring to your notice that this function was not listed in the comparison list, and many other functions that are already implemented were not listed there either.

It would be great if you can push a separate PR to document all missing functions you've found! 👍 If it's unclear where to add the entries just let me know.

povinsahu1909 commented 3 years ago

I think line 204 is fine. Both A and v are CuPy arrays IIUC. Do we really need to worry about this part?

There is no need to worry about this; I had overlooked that.

povinsahu1909 commented 3 years ago

What happens when you follow my suggestion #4692 (comment)? I expected that enabling SM/SA would require only very small changes, as the workflow is largely identical.

I applied your suggestion and made some changes. Currently, the tests for SM fail with an AssertionError, caused either by an array mismatch or by a value exceeding the tolerance.

I added the following lines

elif which == 'SA':
    idx = numpy.argsort(w)[::-1]
elif which == 'SM':
    idx = numpy.argsort(numpy.absolute(w))[::-1]

to https://github.com/cupy/cupy/blob/880918c978dc6f8eb91585764e2501b1ca2bea89/cupyx/scipy/sparse/linalg/_eigen.py#L271-L278

Could you please guide me on how to get the unit tests to pass?

It would be great if you can push a separate PR to document all missing functions you've found! +1 If it's unclear where to add the entries just let me know.

It would be easier if you could tell me where to add the entries, as I am new to this project.

leofang commented 3 years ago

I applied your suggestion and made some changes. Currently, the tests for SM fail with an AssertionError, caused either by an array mismatch or by a value exceeding the tolerance.

Could you post the full traceback here? It's not possible for me or most people to immediately see where an AssertionError could be raised from without more info.

It would be great if you can push a separate PR to document all missing functions you've found! +1 If it's unclear where to add the entries just let me know.

It would be easier if you could tell me where to add the entries, as I am new to this project.

Well, for this task a quick search is the best first step. You can first take a look at which functions in the same module are documented (which you already did), then search for the module name, either with GitHub's search bar in the top left corner (which is surprisingly powerful if you select search "in this repository") or just by grepping locally:

$ grep -R docs/ -e 'cupyx.scipy.sparse.linalg'
docs/source/_comparison_generator.py:        'scipy.sparse.linalg', 'cupyx.scipy.sparse.linalg', 'SciPy')
docs/source/reference/sparse.rst:   cupyx.scipy.sparse.linalg.LinearOperator
docs/source/reference/sparse.rst:   cupyx.scipy.sparse.linalg.norm
docs/source/reference/sparse.rst:   cupyx.scipy.sparse.linalg.spsolve
docs/source/reference/sparse.rst:   cupyx.scipy.sparse.linalg.spsolve_triangular
docs/source/reference/sparse.rst:   cupyx.scipy.sparse.linalg.cg
docs/source/reference/sparse.rst:   cupyx.scipy.sparse.linalg.gmres
docs/source/reference/sparse.rst:   cupyx.scipy.sparse.linalg.lsqr
docs/source/reference/sparse.rst:   cupyx.scipy.sparse.linalg.eigsh
docs/source/reference/sparse.rst:   cupyx.scipy.sparse.linalg.svds
docs/source/reference/sparse.rst:   cupyx.scipy.sparse.linalg.aslinearoperator
docs/source/reference/sparse.rst:   cupyx.scipy.sparse.linalg.lobpcg

Then you'll notice that functions in the same module are documented together, and in this case you know docs/source/reference/sparse.rst is the place to go 🙂

povinsahu1909 commented 3 years ago

Could you post the full traceback here? It's not possible for me or most people to immediately see where an AssertionError could be raised from without more info.

When return_eigenvectors=True in the test case:

_______________________TestEigsh.test_sparse[csr-_param_3_{k=3, return_eigenvectors=True, use_linear_operator=True, which='SM'}]___________________

AssertionError: Only cupy raises error
Traceback (most recent call last):
    File "/home/povins/project/cupy-main/cupy/cupy/testing/helper.py", line 47, in _call_func
        result = impl(self, *args, **kw)
    File "/home/povins/project/cupy-main/cupy/tests/cupyx_tests/scipy_tests/sparse_tests/test_linalg.py", line 158, in test_sparse
        return self._test_eigsh(a, xp, sp)
    File "/home/povins/project/cupy-main/cupy/tests/cupyx_tests/scipy_tests/sparse_tests/test_linalg.py", line 145, in _test_eigsh
        assert(res < to)
AssertionError: assert array(14.449945, dtype=float32) < 1e-05

cupy/tests/cupyx_tests/scipy_tests/sparse_tests/test_linalg.py:145: AssertionError

When return_eigenvectors=False in the test case:

_______________________TestEigsh.test_sparse[csr-_param_11_{k=3, return_eigenvectors=False, use_linear_operator=True, which='SM'}]___________________

self = <<cupyx_tests.scipy_tests.sparse_tests.test_linalg.TestEigsh object at 0x7fdf94e51a10>  parameter: {'k': 3, 'return_eigenvectors': False, 'use_linear_operator': True, 'which': 'SM'}>, args = ()
kw = {'dtype': <class 'numpy.float32'>, 'format': 'csr'}, dtype = 'f'

    @_wraps_partial(impl, name)
    def test_func(self, *args, **kw):
        for dtype in dtypes:
            try:
                kw[name] = numpy.dtype(dtype).type
>               impl(self, *args, **kw)

cupy/cupy/testing/helper.py:825: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
cupy/cupy/testing/helper.py:343: in test_func
    check_func(cupy_r, numpy_r)
cupy/cupy/testing/helper.py:495: in check_func
    array.assert_allclose(c, n, rtol1, atol1, err_msg, verbose)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

actual = array([0.03023103, 0.1156988 , 0.27153265], dtype=float32), desired = array([-1.1323305, -0.2813205,  1.182373 ], dtype=float32), rtol = 1e-05, atol = 1e-05, err_msg = '', verbose = True

    def assert_allclose(actual, desired, rtol=1e-7, atol=0, err_msg='',
                        verbose=True):
        """Raises an AssertionError if objects are not equal up to desired tolerance.

        Args:
             actual(numpy.ndarray or cupy.ndarray): The actual object to check.
             desired(numpy.ndarray or cupy.ndarray): The desired, expected object.
             rtol(float): Relative tolerance.
             atol(float): Absolute tolerance.
             err_msg(str): The error message to be printed in case of failure.
             verbose(bool): If ``True``, the conflicting
                 values are appended to the error message.

        .. seealso:: :func:`numpy.testing.assert_allclose`

        """  # NOQA
        numpy.testing.assert_allclose(
            cupy.asnumpy(actual), cupy.asnumpy(desired),
>           rtol=rtol, atol=atol, err_msg=err_msg, verbose=verbose)
E       AssertionError: 
E       Not equal to tolerance rtol=1e-05, atol=1e-05
E       
E       Mismatched elements: 3 / 3 (100%)
E       Max absolute difference: 1.1625615
E       Max relative difference: 1.4112705
E        x: array([0.030231, 0.115699, 0.271533], dtype=float32)
E        y: array([-1.132331, -0.281321,  1.182373], dtype=float32)

cupy/cupy/testing/array.py:26: AssertionError

leofang commented 3 years ago

OK, I see what you mean. In this case, I would suggest looking at the description of return_eigenvectors:

Return eigenvectors (True) in addition to eigenvalues. This value determines the order in which eigenvalues are sorted. The sort order is also dependent on the which variable. For which = 'LM' or 'SA': If return_eigenvectors is True, eigenvalues are sorted by algebraic value. If return_eigenvectors is False, eigenvalues are sorted by absolute value. For which = 'BE' or 'LA': eigenvalues are always sorted by algebraic value. For which = 'SM': If return_eigenvectors is True, eigenvalues are sorted by algebraic value. If return_eigenvectors is False, eigenvalues are sorted by decreasing absolute value.

https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.linalg.eigsh.html

It also controls how return values are sorted, so it's likely that the patch you applied in https://github.com/cupy/cupy/issues/4692#issuecomment-797713190 is not sufficient to guarantee the sort order.
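As a rough illustration, the ordering rules quoted from the SciPy documentation can be written as a small NumPy helper. order_like_scipy_docs is hypothetical and derived from the documentation text only, not from SciPy's source; the docs do not say whether "sorted by absolute value" means ascending, so ascending is assumed here.

```python
import numpy as np


def order_like_scipy_docs(w, which, return_eigenvectors):
    """Permutation that orders eigenvalues w as the SciPy eigsh docs describe.

    Assumption: "sorted by absolute value" (LM/SA without eigenvectors) is
    taken to mean ascending absolute value.
    """
    if return_eigenvectors or which in ("BE", "LA"):
        return np.argsort(w)                  # algebraic value, ascending
    if which == "SM":
        return np.argsort(np.abs(w))[::-1]    # decreasing absolute value
    return np.argsort(np.abs(w))              # 'LM' / 'SA': by absolute value


w = np.array([0.3, -1.2, 0.05])
print(w[order_like_scipy_docs(w, "SM", False)])
```

Applying the same permutation to the eigenvector columns (x[:, idx]) keeps every eigenvalue paired with its eigenvector. Note that for 'SM' without eigenvectors the sort key is the absolute value, not the algebraic value.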

povinsahu1909 commented 3 years ago

you applied in #4692 (comment) is not sufficient to guarantee the sort order

In my case, only the second condition for 'SM' was not fulfilled, but the test cases for return_eigenvectors=True are failing as well.

For which = 'SM': If return_eigenvectors is True, eigenvalues are sorted by algebraic value. If return_eigenvectors is False, eigenvalues are sorted by decreasing absolute value.

So I modified the code by adding a condition for the second case, but the tests still fail. The modification I made is:

    if return_eigenvectors:
        idx = cupy.argsort(w)
        return w[idx], x[:, idx]
    else:
        if which == 'SM':
            return cupy.sort(w)[::-1]
        else:    
            return cupy.sort(w)

https://github.com/cupy/cupy/blob/880918c978dc6f8eb91585764e2501b1ca2bea89/cupyx/scipy/sparse/linalg/_eigen.py#L126-L130

Am I doing something wrong :confused:?

povinsahu1909 commented 3 years ago

I tried solving the problem of adding the SM method, and I found that for return_eigenvectors = True, all the generated eigenvalues match what is expected, but their corresponding eigenvectors don't, which causes the tests to fail.

For return_eigenvectors = False, the generated eigenvalues match the desired ones for up to three iterations, but after that the test fails the tolerance check.

None of this happens with any of the other methods.

leofang commented 3 years ago

Sorry for not getting back to you, @povinsahu1909. I haven't had time to look into this yet, but it looks like eigsh might have a bug somewhere (#5001, #5024). I'm not sure whether that is also the root cause of the problem you're seeing...

povinsahu1909 commented 3 years ago

I think the main reason behind the error with 'SM' is the algorithm currently used. We are using the Thick-Restart Lanczos method, which does not produce solutions of the same accuracy for the smallest eigenvalues (page 38, paragraph 4 of https://sdm.lbl.gov/~kewu/ps/trlan.ps).

SciPy, in contrast, uses the Implicitly Restarted Lanczos Method (IRLM) with the shift-invert spectral transformation mode, which transforms the eigenvalue problem into an equivalent problem with different eigenvalues. This can be found here
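The effect of the shift-invert transformation can be checked with a small dense NumPy example: if A x = λ x, then (A - σI)^{-1} x = x / (λ - σ), so the eigenvalues ν of the inverse with the largest magnitude belong to the λ closest to σ, and λ = σ + 1/ν. The matrix and σ below are arbitrary illustrative choices.

```python
import numpy as np

# Symmetric 5x5 test matrix with known eigenvalues lam (illustrative choice)
rng = np.random.default_rng(0)
q, _ = np.linalg.qr(rng.standard_normal((5, 5)))
lam = np.array([-2.0, -0.1, 0.2, 1.5, 3.0])
A = q @ np.diag(lam) @ q.T

sigma = 0.1
# Eigenvalues nu of (A - sigma*I)^{-1} are 1 / (lam - sigma)
nu = np.linalg.eigvalsh(np.linalg.inv(A - sigma * np.eye(5)))

# The k largest |nu| correspond to the k eigenvalues of A closest to sigma
k = 2
nu_top = nu[np.argsort(np.abs(nu))[-k:]]
recovered = sigma + 1.0 / nu_top  # undo the transformation
print(np.sort(recovered))
```

In a real solver the inverse is never formed explicitly; it is applied through a factorization, as discussed further down in this thread.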

fjbay commented 3 years ago

You might be able to use this shift-inverse based algorithm directly from the cuSolver API: https://docs.nvidia.com/cuda/cusolver/index.html#cusolver-lt-t-gt-csreigsi

However, note that the current implementation only returns one eigenpair, with the eigenvalue closest to the target.
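According to the cuSOLVER documentation, csreigsi is based on a shift-inverse method. The NumPy/SciPy sketch below illustrates that general technique (shift-inverse power iteration) on the CPU; it is not the cuSOLVER API, and shift_inverse_power is a hypothetical name. Like the routine described above, it yields a single eigenpair, the one whose eigenvalue is closest to the target mu0.

```python
import numpy as np
import scipy.linalg as sla


def shift_inverse_power(A, mu0, iters=50, seed=0):
    """Approximate the eigenpair of symmetric A whose eigenvalue is closest to mu0."""
    n = A.shape[0]
    lu_piv = sla.lu_factor(A - mu0 * np.eye(n))  # factor (A - mu0*I) once
    x = np.random.default_rng(seed).standard_normal(n)
    x /= np.linalg.norm(x)
    for _ in range(iters):
        x = sla.lu_solve(lu_piv, x)              # apply (A - mu0*I)^{-1}
        x /= np.linalg.norm(x)
    return x @ A @ x, x                          # Rayleigh quotient, eigenvector


A = np.diag([1.0, 2.0, 3.2, 5.0])
val, vec = shift_inverse_power(A, mu0=3.0)
print(val)
```

The iteration amplifies the component whose eigenvalue is nearest mu0, so getting several eigenpairs requires a subspace method such as Lanczos on the same inverse operator.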

fjbay commented 3 years ago

In order to find multiple eigenvalues closest to a target value (e.g. zero, for the smallest-magnitude eigenvalues), the straightforward strategy would be to follow the same approach as the original SciPy package.

If eigsh is called with sigma not None, we follow these steps:

  1. Shift the matrix: mat -= sigma * eye(mat.shape[0])
  2. Compute the LU decomposition of mat
  3. Construct a LinearOperator whose _matvec(self, b) uses splu.solve() to compute x in A * x = b. This is an efficient way to apply the inverse of A on the fly without forming it explicitly.
  4. Pass the LinearOperator, instead of the original matrix, into the actual eigenvalue solver and compute its eigenpairs.
  5. Undo the shift-invert for the eigenvalues (trivial, can be done on the CPU). The eigenvectors of a matrix and of its shift-invert are the same.

As far as I can see, everything we need is already there. One can follow the SciPy code path and just replace each step with its cupyx equivalent as needed.
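A CPU sketch of the five steps using SciPy's public API (splu, LinearOperator, eigsh). eigsh_shift_invert is a hypothetical helper that follows the steps as listed, not a copy of SciPy's internal code path; swapping in the cupyx equivalents is the porting idea discussed here.

```python
import numpy as np
import scipy.sparse as sp
import scipy.sparse.linalg as spla


def eigsh_shift_invert(A, k=6, sigma=0.0):
    """Eigenpairs of symmetric sparse A whose eigenvalues are closest to sigma."""
    n = A.shape[0]
    shifted = (A - sigma * sp.eye(n)).tocsc()           # 1. shift
    lu = spla.splu(shifted)                             # 2. LU decomposition
    op = spla.LinearOperator((n, n), matvec=lu.solve,   # 3. b -> (A - sigma*I)^{-1} b
                             dtype=A.dtype)
    nu, vecs = spla.eigsh(op, k=k, which="LM")          # 4. largest |nu| of the inverse
    lam = sigma + 1.0 / nu                              # 5. undo shift-invert
    order = np.argsort(lam)                             #    eigenvectors are unchanged
    return lam[order], vecs[:, order]


n = 100
A = sp.diags([-np.ones(n - 1), 2 * np.ones(n), -np.ones(n - 1)], [-1, 0, 1], format="csr")
lam, vecs = eigsh_shift_invert(A, k=4, sigma=0.0)
print(lam)
```

With sigma = 0 this returns the smallest-magnitude eigenvalues, which is exactly what the issue asks for.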

albertomercurio commented 2 years ago

Hello to everyone.

I tried the method suggested by @fjbay, and it seems to work. More precisely, I implemented the following function:

import numpy as np
import scipy.sparse as scisp
import cupy as cp
import cupyx.scipy.sparse as cpsp
import cupyx.scipy.sparse.linalg as cpsp_la

def sp_eigenvalues_shift_inverse_gpu(A, k=6, sigma=0, which="LM"):
    # Shift on the host (A is a SciPy sparse matrix), then copy it to the GPU as CSR
    A_gpu_shifted = cpsp.csr_matrix(A - sigma * scisp.eye(A.shape[0], format="csr"))

    A_gpu_LU = cpsp_la.splu(A_gpu_shifted)  # LU decomposition of A - sigma*I
    # Linear operator that applies (A - sigma*I)^-1 through the LU solve
    A_gpu_LO = cpsp_la.LinearOperator(A_gpu_shifted.shape, matvec=A_gpu_LU.solve,
                                      dtype=A_gpu_shifted.dtype)

    # Largest-magnitude eigenvalues of the inverse <-> eigenvalues of A closest to sigma
    eigenvalues_gpu, eigenstates_gpu = cpsp_la.eigsh(A_gpu_LO, k=k, which=which)

    eigenvalues = eigenvalues_gpu.get()
    eigenstates = eigenstates_gpu.get()
    eigenvalues = (1 + eigenvalues * sigma) / eigenvalues  # undo shift-invert: sigma + 1/nu

    # Sort the eigenvalues and reorder the eigenvector columns to match
    idx = np.argsort(eigenvalues)
    return eigenvalues[idx], eigenstates[:, idx]

which accepts a SciPy sparse matrix and returns its eigenvalues and eigenvectors using the shift-invert method.

I compared this approach with the SciPy one. However, the CuPy version is about 1.3 times slower than SciPy in wall time (45.3 s vs. 35.4 s). Here is the comparison for the following matrix:

<100000x100000 sparse matrix of type '<class 'numpy.complex128'>'
    with 1684000 stored elements in Compressed Sparse Row format>

Using SciPy: %time scisp.linalg.eigsh(H, k = 8, sigma = sigma_eigs)

CPU times: user 3min 19s, sys: 1.61 s, total: 3min 21s
Wall time: 35.4 s
array([-2.59464574, -2.51914667, -2.51564061, -2.44332617, -2.43867536,
       -2.36739023, -2.36182555, -2.29138912])

Using CuPy: %time sp_eigenvalues_shift_inverse_gpu(H, k = 8, sigma = sigma_eigs)

CPU times: user 2min 31s, sys: 2.04 s, total: 2min 33s
Wall time: 45.3 s
array([-2.59464574, -2.51914667, -2.51564061, -2.44332617, -2.43867536,
       -2.36739023, -2.36182555, -2.29138912])

which is a little bit slower. I noticed that it spends about 27 seconds in the cpsp_la.splu(A_gpu_shifted) function, with a GPU utilization of 30%, and about 17 seconds in cpsp_la.eigsh(A_gpu_LO, k = k, which = which), with 100% utilization.

I'm using an NVIDIA GeForce GTX 1650 Ti.

Am I doing something wrong?

Edit

Using cupyx.scipy.sparse.linalg.spilu instead of cupyx.scipy.sparse.linalg.splu is considerably faster: it takes about 5 seconds for the LU decomposition and 8 seconds for the diagonalization. The eigenvalues are not exactly the same, but the difference might be acceptable.

CPU times: user 28.7 s, sys: 268 ms, total: 29 s
Wall time: 14.9 s
array([-2.5942168 , -2.51823041, -2.51609732, -2.44250206, -2.43907822,
       -2.36634605, -2.36239581, -2.29051711])

fjbay commented 2 years ago

Hello @albertomercurio,

from briefly looking at your description, I don't think you are doing anything wrong.

Your matrix is probably not very well suited for the combination of solvers you are using.

Let me go into a bit more detail. In this shift-invert algorithm we are in fact alternating between two kinds of 'solvers'. On the one hand, there is the repeated solution of the linear system that applies the inverse of your matrix to a single (column or row) vector via the LU solve. On the other hand, there is the actual Lanczos solver, which calls that LU solve many times and finds all the eigenvectors you want.
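To make the two roles concrete, here is a small SciPy sketch (CPU only, illustrative 1D Laplacian) that factors once and counts how often the Lanczos/ARPACK iteration calls the LU solve through a LinearOperator. The counter is purely for illustration.

```python
import numpy as np
import scipy.sparse as sp
import scipy.sparse.linalg as spla

n = 400
A = sp.diags([-np.ones(n - 1), 2 * np.ones(n), -np.ones(n - 1)], [-1, 0, 1], format="csc")

lu = spla.splu(A)        # the (expensive) factorization happens exactly once
calls = {"solve": 0}


def apply_inverse(b):
    calls["solve"] += 1  # one forward/backward triangular solve per call
    return lu.solve(b)


op = spla.LinearOperator(A.shape, matvec=apply_inverse, dtype=A.dtype)
nu = spla.eigsh(op, k=4, which="LM", return_eigenvectors=False)
print("factorizations: 1, solves:", calls["solve"])
```

Because the solve runs once per Lanczos step, its per-call speed matters as much as the one-time factorization cost.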

Now, the last time I had a look at the CuPy code, the LU solver provided by CuPy did the LU factorization on the CPU (once) and then solved the equation system on the GPU (many times). However, this solve turned out to be relatively slow on my graphics cards (Quadro RTX 4000 and RTX A4000) and much faster on my CPU (Threadripper 2970WX).

So I use the CuPy eigsh solver in combination with a CPU-based LU solver (either the SciPy default or Intel PARDISO, which can be much better on some(!) matrices).
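A sketch of that hybrid setup, assuming cupyx's LinearOperator and eigsh accept a Python callable as matvec (as in the shift-invert function earlier in this thread): the LU factorization and every solve run on the CPU via SciPy, while the Lanczos iteration runs through cupyx. The helper name is hypothetical, and the NumPy/SciPy fallback (used only when CuPy cannot be imported) exists so the sketch also runs without a GPU.

```python
import numpy as np
import scipy.sparse as sp
import scipy.sparse.linalg as spla

try:  # GPU path, if CuPy is installed
    import cupy as xp
    import cupyx.scipy.sparse.linalg as xla
    to_host = xp.asnumpy
except ImportError:  # CPU fallback so the sketch still runs without CuPy
    xp, xla, to_host = np, spla, np.asarray


def eigsh_cpu_lu_gpu_lanczos(A, k=6, sigma=0.0):
    """Shift-invert eigsh: LU factor/solve on the CPU, Lanczos iteration via xla.eigsh."""
    n = A.shape[0]
    lu = spla.splu((A - sigma * sp.eye(n)).tocsc())  # CPU factorization, done once

    def matvec(b):
        # Copy the Lanczos vector to the host, solve there, send the result back
        return xp.asarray(lu.solve(to_host(b).ravel()))

    op = xla.LinearOperator((n, n), matvec=matvec, dtype=A.dtype)
    nu, vecs = xla.eigsh(op, k=k, which="LM")
    lam = sigma + 1.0 / to_host(nu)  # undo shift-invert on the host
    order = np.argsort(lam)
    return lam[order], to_host(vecs)[:, order]
```

Each solve copies one vector host-to-device and back, so this pays off mainly when the CPU solve is much faster than the GPU one, as described above.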

As you already noticed, the LU solver often limits the performance here. With your incomplete LU, the total solve times are much shorter.

Concerning the comparison of the CuPy and SciPy eigsh themselves: if you were to analyze the stacks of your processes (or profile the kernel usage on your GPU) for sufficiently large problems (large matrices and many eigenvectors), you'd see that most of the time is spent in GEMM operations. These matrix-matrix multiplications are needed for the orthogonalization step of the eigenvectors and will in fact be the limiting factor for large problems.
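A minimal NumPy sketch of where such dense products appear in a Lanczos-type solver, assuming full reorthogonalization against a stored orthonormal basis V. These helpers are illustrative, not CuPy's implementation. Reorthogonalizing a single vector costs two matrix-vector products against all of V; when w is a block of several vectors (an n x b array), or when the small eigenvectors are lifted to Ritz vectors, the same lines become matrix-matrix products (GEMM).

```python
import numpy as np


def reorthogonalize(V, w):
    """Remove from w its components along the orthonormal columns of V.

    Classical Gram-Schmidt applied twice; each pass is two dense products
    against the whole stored basis V (n x m).
    """
    for _ in range(2):
        w = w - V @ (V.conj().T @ w)
    return w


def ritz_vectors(V, S):
    """Lift the small eigenvectors S (m x k) to the full space: an (n x m) @ (m x k) GEMM."""
    return V @ S


rng = np.random.default_rng(0)
V, _ = np.linalg.qr(rng.standard_normal((1000, 30)))  # orthonormal Lanczos-like basis
w = reorthogonalize(V, rng.standard_normal(1000))
print(np.abs(V.T @ w).max())
```

The cost of both operations grows with n times the basis size m, which is why they dominate once many eigenvectors are requested.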

These operations do not scale well on multicore CPUs, since they run into CPU cache and RAM bandwidth limits. GPUs typically perform a bit better here because of their higher memory bandwidth. However, you can get even more performance out of the system by making efficient use of NVIDIA's Tensor Cores (TC). They are built for matrix-matrix multiplications with AI workloads in mind, where single precision is a sufficient data type. If you can get away with this reduced precision for your type of problem, then switching to single precision can drastically improve performance on the GPU. But please keep in mind that the reduced precision (and range) of single precision can introduce additional issues, so careful testing is advised.

Edit: I just looked up your GPU. This model does not have any TCs, so switching to single precision is probably not worth testing for you at the moment.
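Before switching precision, the accuracy cost can be measured on representative matrices. This NumPy sketch (CPU only; it does not exercise Tensor Cores) compares the smallest eigenvalues computed in float32 and float64. single_precision_error is a hypothetical helper and assumes no eigenvalue is exactly zero, since it also reports relative error.

```python
import numpy as np


def single_precision_error(A, k=4):
    """Max absolute and relative errors of the k smallest eigenvalues of symmetric A
    when computed in float32 instead of float64 (assumes no zero eigenvalues)."""
    ref = np.linalg.eigvalsh(A.astype(np.float64))[:k]
    low = np.linalg.eigvalsh(A.astype(np.float32))[:k].astype(np.float64)
    abs_err = np.max(np.abs(low - ref))
    rel_err = np.max(np.abs(low - ref) / np.abs(ref))
    return float(abs_err), float(rel_err)


n = 300
A = np.diag(np.full(n, 2.0)) + np.diag(-np.ones(n - 1), 1) + np.diag(-np.ones(n - 1), -1)
abs_err, rel_err = single_precision_error(A)
print(f"max abs error {abs_err:.2e}, max rel error {rel_err:.2e}")
```

For eigenvalues close to zero, a small absolute error can still be a large relative one, which is the case that matters when looking for the smallest eigenvalues.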

I'd recommend testing PARDISO as the LU solver, which I'm interfacing with via pyMKL (for single-precision support and fixes for issues with the latest MKL release, see: https://github.com/dwfmarchant/pyMKL/pull/16).

MartinXM commented 2 years ago

Hey guys, I found that PyTorch has a function, torch.linalg.eigh, which achieves good results but is much faster than CuPy or SciPy. I use it to replace scipy.sparse.linalg.eigsh in normalized_cut. In my case, the whole implementation is around 100x faster, and it can be GPU-accelerated if you need (although sometimes the CPU is faster than the GPU).

fjbay commented 2 years ago

Please note that this issue concerns an eigensolver for sparse matrices, while torch.linalg.eigh is for dense matrices.

In your case, the matrices are likely small and rather dense, and a sparse solver is just the wrong tool for the job. You should compare PyTorch to scipy.linalg.eigh instead.
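For a small dense symmetric matrix, SciPy's dense solver can return just the smallest eigenpairs directly. This sketch assumes SciPy 1.5 or later (where scipy.linalg.eigh gained subset_by_index); the random matrix is illustrative.

```python
import numpy as np
import scipy.linalg as sla

rng = np.random.default_rng(0)
M = rng.standard_normal((300, 300))
A = (M + M.T) / 2  # small, dense, symmetric matrix

k = 5
# Only the k smallest eigenpairs (indices 0..k-1 in ascending order)
w, v = sla.eigh(A, subset_by_index=[0, k - 1])
print(w)
```

This selects the smallest algebraic eigenvalues (like 'SA'), not the smallest magnitude; for the latter, compute the full spectrum and sort by absolute value.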

RisaKirisu commented 1 year ago

Hi @fjbay , Which library did you use to compute LU decomposition on GPU? Did you conclude that LU decomposition on GPU is slower than on CPU?

I am also trying to use CuPy to accelerate sparse eigenpair computation and need to find eigenvalues around an initial guess. I tried to implement a shift-invert method exactly as described in this thread, and found that the LU decomposition is the speed bottleneck here. The cupyx.scipy.sparse.linalg.splu method just calls the SciPy splu function. It uses only a single CPU thread, so it doesn't benefit from multithreading on the CPU either. Is your LU solver multithreaded?