theislab / cellrank

CellRank: dynamics from multi-view single-cell data
https://cellrank.org
BSD 3-Clause "New" or "Revised" License

Computing macrostates takes very long for certain matrices #518

Closed: Marius1311 closed this issue 3 years ago.

Marius1311 commented 3 years ago

It's not really a bug; I think it's an inefficiency. For certain matrices that are reducible or almost reducible, computing macrostates can take very long because of the time pyGPCCA needs to compute the stationary distribution. I've observed this in two datasets so far. Below is the snakeviz output I get for a matrix with 12k states when computing 2 macrostates. Note that this matrix comes from a Palantir kernel with k=10; for k=3, everything is fine.

[image: snakeviz profile]
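
For anyone wanting to reproduce such a profile: snakeviz visualizes cProfile output, so a stats file can be produced with the standard-library profiler. The expensive_step function below is a hypothetical stand-in for the real call (e.g. compute_macrostates); it builds a small, strictly positive random row-stochastic matrix and asks ARPACK for its leading eigenvalue, similar to the eigs(P.transpose(), ..., which="LR") call pyGPCCA makes. The output file name is an arbitrary choice.

```python
import cProfile
import pstats
import tempfile
from pathlib import Path

import numpy as np
import scipy.sparse as sp
from scipy.sparse.linalg import eigs


def expensive_step(n: int = 300) -> np.ndarray:
    """Hypothetical stand-in for the profiled call: leading eigenvalue of a random chain."""
    M = np.random.default_rng(0).random((n, n)) + 1e-3  # strictly positive, hence irreducible
    P = sp.csr_matrix(M / M.sum(axis=1, keepdims=True))  # row-stochastic
    vals, _ = eigs(P.transpose(), k=1, which="LR")  # left eigenvector problem, as in pyGPCCA
    return vals


profiler = cProfile.Profile()
profiler.enable()
vals = expensive_step()
profiler.disable()

out = Path(tempfile.gettempdir()) / "expensive_step.prof"
profiler.dump_stats(str(out))  # then, in a shell: snakeviz <path to expensive_step.prof>
pstats.Stats(profiler).sort_stats("cumulative").print_stats(10)  # plain-text summary
```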

Essentially, all of the time is spent computing the stationary distribution. Here is the spectrum of this matrix:

[image: spectrum of the transition matrix]
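
A toy illustration of how an irreducible chain can still look reducible in a spectrum plot (a hypothetical 4-state example, not the 12k matrix from this issue): two internally well-mixed blocks, {0, 1} and {2, 3}, leak into each other with total probability eps per step. Its eigenvalues are 1, 1 - 2*eps, 0 and 0, so for eps = 1e-6 the second eigenvalue is 0.999998 and visually indistinguishable from 1, even though every state can reach every other one. A spectral gap this small is one plausible reason (not confirmed in this thread) why an iterative eigensolver needs many iterations to separate the top eigenvector.

```python
import numpy as np


def two_block_chain(eps: float) -> np.ndarray:
    """Hypothetical 4-state chain: blocks {0, 1} and {2, 3}, coupled with total leak eps."""
    block = np.full((2, 2), 0.5)  # uniform mixing within a block
    P = np.zeros((4, 4))
    P[:2, :2] = (1 - eps) * block
    P[2:, 2:] = (1 - eps) * block
    P[:2, 2:] = eps / 2  # leak spread evenly over the other block
    P[2:, :2] = eps / 2
    return P


for eps in (1e-1, 1e-6):
    P = two_block_chain(eps)
    # P is symmetric here, so eigvalsh applies; it returns eigenvalues in ascending order
    top_two = np.linalg.eigvalsh(P)[-2:][::-1]
    print(f"eps={eps:g}: top eigenvalues {top_two}, gap {top_two[0] - top_two[1]:.1e}")
```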

It looks like it's reducible (several eigenvalues at 1), but it actually isn't: there is only one eigenvalue that is truly 1. (I checked irreducibility by computing strongly connected components; this matrix is irreducible.) So what's going on here? I checked the condition number, which is of order 1e7, so the matrix is not particularly ill-conditioned compared to other examples we deal with. Also, computing the Schur decomposition is not a problem, so I wonder why computing the stationary distribution would be so difficult.
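
The irreducibility check mentioned above can be sketched with SciPy's graph routines: a finite Markov chain is irreducible exactly when its transition graph (an edge i -> j wherever P[i, j] > 0) forms a single strongly connected component. The helper name is_irreducible is illustrative, not a CellRank or pyGPCCA function.

```python
import scipy.sparse as sp
from scipy.sparse.csgraph import connected_components


def is_irreducible(P) -> bool:
    """True if the chain with transition matrix P has exactly one strongly connected component."""
    G = sp.csr_matrix(P, copy=True)
    G.eliminate_zeros()  # explicitly stored zeros must not count as edges
    n_components, _ = connected_components(G, directed=True, connection="strong")
    return n_components == 1


P_irreducible = sp.csr_matrix([[0.9, 0.1], [0.2, 0.8]])
P_reducible = sp.csr_matrix([[1.0, 0.0], [0.5, 0.5]])  # state 0 is absorbing
print(is_irreducible(P_irreducible), is_irreducible(P_reducible))  # True False
```

This test is purely structural, so a chain whose blocks are coupled only by tiny probabilities still passes it, which is consistent with the matrix here being irreducible yet slow to solve.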

Versions:

cellrank==1.2.0+g62915ec.d20210222 scanpy==1.6.0 anndata==0.7.5 numpy==1.19.2 numba==0.51.2 scipy==1.5.3 pandas==1.1.4 pygpcca==1.0.1 scikit-learn==0.23.2 statsmodels==0.12.1 python-igraph==0.8.3 scvelo==0.2.3 pygam==0.8.0 matplotlib==3.3.2 seaborn==0.11.0

...

michalk8 commented 3 years ago

Seems like an issue with scipy.sparse.linalg.eigs. Maybe we could default to a different solver if PETSc/SLEPc is installed?
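
A SciPy-only alternative, shown purely as a sketch and not as what pyGPCCA implements, is to skip the iterative eigensolver altogether and obtain the stationary distribution from one sparse direct solve. For an irreducible row-stochastic P, the system (P^T - I) pi = 0 has rank n - 1; replacing its last equation by the normalization sum(pi) = 1 makes it nonsingular. The function name is hypothetical.

```python
import numpy as np
import scipy.sparse as sp
from scipy.sparse.linalg import spsolve


def stationary_distribution_direct(P) -> np.ndarray:
    """Stationary distribution of an irreducible row-stochastic P via one sparse direct solve.

    Solves (P^T - I) pi = 0 with its last equation replaced by sum(pi) = 1.
    """
    n = P.shape[0]
    A = (sp.csr_matrix(P).T - sp.identity(n, format="csr")).tolil()
    A[n - 1, :] = np.ones(n)  # normalization row makes the system nonsingular
    b = np.zeros(n)
    b[n - 1] = 1.0
    pi = spsolve(A.tocsc(), b)
    return pi / pi.sum()  # guard against round-off in the normalization


P = sp.csr_matrix([[0.9, 0.1], [0.2, 0.8]])
print(stationary_distribution_direct(P))  # approximately [2/3, 1/3]
```

The cost is dominated by fill-in during the sparse LU factorization, so it depends on the sparsity pattern rather than on the spectral gap. For a reducible P the system is singular and the result is not meaningful (SciPy's spsolve typically warns in that case), so irreducibility should be checked first.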

Marius1311 commented 3 years ago

Yes, I would strongly vote for that.

Marius1311 commented 3 years ago

Could you open a PR on pyGPCCA for that, please? I can then test this again on my matrix.

Marius1311 commented 3 years ago

We need to look into the computation of the stationary distribution in pyGPCCA a bit more. On a reducible example, I just got:

---------------------------------------------------------------------------
ArpackNoConvergence                       Traceback (most recent call last)
<ipython-input-58-f86032d967d9> in <module>
----> 1 g_fwd.compute_macrostates(n_states=6, cluster_key='Ground_truth')

~/Projects/cellrank/cellrank/tl/estimators/_gpcca.py in compute_macrostates(self, n_states, n_cells, use_min_chi, cluster_key, en_cutoff, p_thresh)
    174         start = logg.info(f"Computing `{n_states}` macrostates")
    175         try:
--> 176             self._gpcca = self._gpcca.optimize(m=n_states)
    177         except ValueError as e:
    178             # this is the following case - we have 4 Schur vectors, user requests 5 states, but it splits the conj. ev.

~/miniconda3/envs/py38_devel/lib/python3.8/site-packages/pygpcca/_gpcca.py in optimize(self, m)
   1061         # coarse-grained stationary distribution
   1062         self._pi_coarse = (
-> 1063             None if self.stationary_probability is None else np.dot(self.memberships.T, self.stationary_probability)
   1064         )
   1065         # coarse-grained input (initial) distribution of states

~/miniconda3/envs/py38_devel/lib/python3.8/functools.py in __get__(self, instance, owner)
    965                 val = cache.get(self.attrname, _NOT_FOUND)
    966                 if val is _NOT_FOUND:
--> 967                     val = self.func(instance)
    968                     try:
    969                         cache[self.attrname] = val

~/miniconda3/envs/py38_devel/lib/python3.8/site-packages/pygpcca/_gpcca.py in stationary_probability(self)
   1251         """
   1252         try:
-> 1253             return stationary_distribution(self._P)
   1254         except ValueError as e:
   1255             warnings.warn(f"Stationary distribution couldn't be calculated. Reason: {e}.")

~/miniconda3/envs/py38_devel/lib/python3.8/functools.py in wrapper(*args, **kw)
    873                             '1 positional argument')
    874 
--> 875         return dispatch(args[0].__class__)(*args, **kw)
    876 
    877     funcname = getattr(func, '__name__', 'singledispatch function')

~/miniconda3/envs/py38_devel/lib/python3.8/site-packages/pygpcca/utils/_utils.py in _sds(P)
    224 def _sds(P: spmatrix) -> np.ndarray:
    225     # get the top two eigenvalues and vecs so we can check for irreducibility
--> 226     vals, vecs = eigs(P.transpose(), k=2, which="LR", ncv=None)
    227 
    228     # check for irreducibility

~/miniconda3/envs/py38_devel/lib/python3.8/site-packages/scipy/sparse/linalg/eigen/arpack/arpack.py in eigs(A, k, M, sigma, which, v0, ncv, maxiter, tol, return_eigenvectors, Minv, OPinv, OPpart)
   1345     with _ARPACK_LOCK:
   1346         while not params.converged:
-> 1347             params.iterate()
   1348 
   1349         return params.extract(return_eigenvectors)

~/miniconda3/envs/py38_devel/lib/python3.8/site-packages/scipy/sparse/linalg/eigen/arpack/arpack.py in iterate(self)
    755                 pass
    756             elif self.info == 1:
--> 757                 self._raise_no_convergence()
    758             else:
    759                 raise ArpackError(self.info, infodict=self.iterate_infodict)

~/miniconda3/envs/py38_devel/lib/python3.8/site-packages/scipy/sparse/linalg/eigen/arpack/arpack.py in _raise_no_convergence(self)
    375             vec = np.zeros((self.n, 0))
    376             k_ok = 0
--> 377         raise ArpackNoConvergence(msg % (num_iter, k_ok, self.k), ev, vec)
    378 
    379 

ArpackNoConvergence: ARPACK error -1: No convergence (34271 iterations, 0/2 eigenvectors converged) [ARPACK error -14: DNAUPD  did not find any eigenvalues to sufficient accuracy.]
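
In the traceback, stationary_probability wraps the call in except ValueError, but SciPy's ArpackNoConvergence derives from ArpackError, which is a RuntimeError, so the failure escapes that handler and aborts compute_macrostates instead of producing a warning. Below is a hedged sketch of a wrapper that retries ARPACK with a larger Krylov subspace and iteration budget, then re-raises as ValueError so a caller like that would fall back to its warning. The retry values and the function name are illustrative assumptions, not pyGPCCA's behavior.

```python
import numpy as np
import scipy.sparse as sp
from scipy.sparse.linalg import ArpackNoConvergence, eigs


def stationary_via_arpack(P, retries: int = 1) -> np.ndarray:
    """Leading left eigenvector of row-stochastic P, normalized to sum to 1.

    Hypothetical retry policy: on ArpackNoConvergence, enlarge the Krylov subspace
    (ncv) and the iteration budget (maxiter); if every attempt fails, re-raise as
    ValueError so callers that only catch ValueError can degrade gracefully.
    """
    PT = sp.csr_matrix(P).T  # left eigenvectors of P are right eigenvectors of P^T
    n = PT.shape[0]
    ncv, maxiter = None, None  # None means SciPy's defaults on the first attempt
    last_error = None
    for attempt in range(retries + 1):
        try:
            _, vecs = eigs(PT, k=1, which="LR", ncv=ncv, maxiter=maxiter)
            v = vecs[:, 0].real
            return v / v.sum()  # also removes the arbitrary sign of the eigenvector
        except ArpackNoConvergence as err:
            last_error = err
            ncv = min(n, 40 * (attempt + 2))  # eigs requires k + 1 < ncv <= n
            maxiter = 100 * n
    raise ValueError(f"ARPACK did not converge after {retries + 1} attempts: {last_error}") from last_error
```

Retrying only helps when the iteration budget was the bottleneck; it does not detect reducibility, so a structural check is still needed before trusting the result.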

Marius1311 commented 3 years ago

I will check whether this is solved by the linked PR in pyGPCCA.

Marius1311 commented 3 years ago

Happy to report that on the 12k example matrix from above, https://github.com/msmdev/pyGPCCA/pull/22 reduced the computation time to about half a second (from over 4 minutes). Furthermore, on another, reducible example, it also worked and correctly identified the matrix as reducible. In conclusion, https://github.com/msmdev/pyGPCCA/pull/22 fully solved this issue.