theislab / cellrank

CellRank: dynamics from multi-view single-cell data
https://cellrank.org
BSD 3-Clause "New" or "Revised" License

Profile GPCCA with krylov-schur #104

Closed Marius1311 closed 4 years ago

Marius1311 commented 4 years ago

I think we should run a profiler over this to identify current bottlenecks - I have a feeling we can further speed this up.
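For reference, a minimal profiling sketch with Python's built-in cProfile/pstats; `run_gpcca()` is a hypothetical wrapper standing in for whatever entry point we actually want to profile (e.g. the GPCCA fit on the pancreas transition matrix):

```python
import cProfile
import pstats


def run_gpcca():
    # hypothetical wrapper: call the code under investigation here,
    # e.g. the GPCCA fit with n=12 macrostates
    ...


# dump the raw profile to a file ("restats", as in the logs below) and
# print the 50 most expensive entries sorted by internal time
cProfile.run("run_gpcca()", "restats")
stats = pstats.Stats("restats")
stats.sort_stats("tottime").print_stats(50)
```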

michalk8 commented 4 years ago

Here's my log on the pancreas data with n=12:

Sun May 10 23:31:49 2020    restats

         3546064 function calls (3498571 primitive calls) in 16.301 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
   167458    5.093    0.000    5.093    0.000 {method 'reduce' of 'numpy.ufunc' objects}
421166/376374    3.103    0.000   10.332    0.000 {built-in method numpy.core._multiarray_umath.implement_array_function}
    24201    1.803    0.000    9.848    0.000 gpcca.py:579(_fill_matrix)
        1    1.375    1.375   14.458   14.458 optimize.py:456(_minimize_neldermead)
        1    0.520    0.520    0.520    0.520 {built-in method scipy.sparse.linalg.dsolve._superlu.gstrf}
    24200    0.423    0.000   11.267    0.000 gpcca.py:327(_objective)
    45679    0.386    0.000    0.386    0.000 {method 'take' of 'numpy.ndarray' objects}
144176/144008    0.320    0.000    5.203    0.000 fromnumeric.py:73(_wrapreduction)
    89064    0.295    0.000    0.295    0.000 {method 'dot' of 'numpy.ndarray' objects}
    24200    0.201    0.000    0.201    0.000 {method 'trace' of 'numpy.ndarray' objects}
        1    0.197    0.197    0.581    0.581 gpcca.py:438(_indexsearch)
    24205    0.183    0.000    0.274    0.000 twodim_base.py:216(diag)
    44354    0.178    0.000    0.306    0.000 linalg.py:2316(norm)
        2    0.146    0.073    0.146    0.073 {method 'solve' of 'slepc4py.SLEPc.EPS' objects}
    22729    0.141    0.000    0.141    0.000 {method 'argsort' of 'numpy.ndarray' objects}
   144176    0.099    0.000    0.099    0.000 fromnumeric.py:74(<dictcomp>)
    48595    0.098    0.000    0.573    0.000 fromnumeric.py:2092(sum)
    48568    0.097    0.000    0.097    0.000 {built-in method numpy.zeros}
    92377    0.089    0.000    0.689    0.000 fromnumeric.py:55(_wrapfunc)
    46924    0.084    0.000    4.524    0.000 fromnumeric.py:2551(amax)
    67291    0.061    0.000    0.061    0.000 {method 'ravel' of 'numpy.ndarray' objects}
117928/117721    0.058    0.000    0.060    0.000 {built-in method numpy.array}
    48595    0.055    0.000    0.681    0.000 <__array_function__ internals>:2(sum)
    46924    0.053    0.000    4.629    0.000 <__array_function__ internals>:2(amax)
    45428    0.053    0.000    0.487    0.000 fromnumeric.py:97(take)
    24200    0.052    0.000   11.319    0.000 optimize.py:325(function_wrapper)
136874/136872    0.048    0.000    0.058    0.000 {built-in method builtins.isinstance}
    68569    0.048    0.000    2.916    0.000 <__array_function__ internals>:2(dot)
    24661    0.047    0.000    0.047    0.000 {method 'reshape' of 'numpy.ndarray' objects}
    45428    0.041    0.000    0.569    0.000 <__array_function__ internals>:2(take)
    24200    0.040    0.000    0.274    0.000 fromnumeric.py:1625(trace)
    22723    0.039    0.000    0.113    0.000 fromnumeric.py:1693(ravel)
      134    0.037    0.000    0.037    0.000 {built-in method scipy.sparse._sparsetools.csc_matvec}
    71534    0.036    0.000    0.081    0.000 _asarray.py:88(asanyarray)
    24204    0.036    0.000    0.187    0.000 fromnumeric.py:2676(amin)
       12    0.035    0.003    0.035    0.003 decomp_svd.py:16(svd)
120867/120866    0.031    0.000    0.032    0.000 {built-in method builtins.getattr}
    22716    0.030    0.000    0.205    0.000 fromnumeric.py:997(argsort)
    24207    0.030    0.000    0.184    0.000 fromnumeric.py:2236(any)

I currently don't see a quick and easy way to further optimize the code - I'd have to take a deeper look, since based on this, we can only optimize how the numpy functions are used.
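To illustrate what "optimizing how the numpy functions are used" could mean here (just a sketch, not a measured fix): a sizeable chunk of the profile above is the `__array_function__` dispatch overhead (`implement_array_function`), which calling the ndarray methods directly avoids:

```python
import numpy as np

A = np.random.rand(12, 12)

# going through the np.* wrappers triggers the __array_function__ dispatch
# machinery (the implement_array_function entries above), which adds up when
# a call site is hit tens of thousands of times inside the optimizer loop
s = np.sum(A)
m = np.amax(A)

# calling the ndarray methods directly gives the same results without the
# dispatch overhead
s = A.sum()
m = A.max()
```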

Marius1311 commented 4 years ago

What I take from this is that the sparse Schur decomposition is not the bottleneck here - the solve method of SLEPc only took 0.146 seconds in total.

Why do you sort by internal time? Can you sort by total time and post the output please?
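For reference, re-sorting the same dump by cumulative time is a one-liner with pstats (assuming the stats were saved to `restats` as above); note that pstats labels internal time `tottime` and cumulative time `cumtime`:

```python
import pstats

# re-sort the saved profile by cumulative time instead of internal time
# and show the 50 most expensive entries
pstats.Stats("restats").sort_stats("cumulative").print_stats(50)
```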

Marius1311 commented 4 years ago

It could also be that this example was too small to reveal what the bottleneck will be for large datasets - these computations scale differently with cell number, so we should also try a larger example.

michalk8 commented 4 years ago

Agreed, testing on a larger dataset will give us a better overview. Here's a call graph for a better representation: [callgraph image]
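The thread doesn't say which tool produced the graph; one common way to turn the same cProfile dump into a call-graph image is gprof2dot plus graphviz, e.g.:

```python
import subprocess

# render the saved cProfile dump as a call-graph image; requires gprof2dot
# and graphviz's `dot` to be installed
subprocess.run(
    "gprof2dot -f pstats restats | dot -Tpng -o callgraph.png",
    shell=True, check=True,
)
```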

Marius1311 commented 4 years ago

Super nice! That clearly shows that on a dataset of this size, the bottleneck is scipy's fmin! And within that, it's the _fill_matrix method...

I would be interested to see whether this changes on a larger dataset, i.e. whether the Schur decomposition ever becomes the bottleneck.
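For context, a rough sketch of what the SLEPc-based computation looks like (the `createAIJ` and `EPS.solve` entries in the profiles below), assuming a scipy CSR transition matrix. This only retrieves leading eigenpairs; cellrank's actual `sorted_schur` additionally extracts and sorts the Schur vectors, which this sketch omits:

```python
import numpy as np
from scipy.sparse import csr_matrix
from petsc4py import PETSc
from slepc4py import SLEPc


def leading_eigenpairs(T: csr_matrix, k: int = 12):
    """Leading eigenpairs of a sparse matrix via SLEPc's Krylov-Schur solver."""
    # wrap the scipy CSR matrix as a PETSc AIJ matrix (the createAIJ call
    # visible in the profiles)
    A = PETSc.Mat().createAIJ(size=T.shape, csr=(T.indptr, T.indices, T.data))

    E = SLEPc.EPS().create()
    E.setOperators(A)
    E.setType(SLEPc.EPS.Type.KRYLOVSCHUR)
    E.setWhichEigenpairs(SLEPc.EPS.Which.LARGEST_MAGNITUDE)
    E.setDimensions(nev=k)
    E.solve()  # the expensive `solve` of slepc4py.SLEPc.EPS seen above

    vr, _ = A.getVecs()
    vi, _ = A.getVecs()
    eigvals, eigvecs = [], []
    for i in range(min(k, E.getConverged())):
        eigvals.append(E.getEigenpair(i, vr, vi))
        eigvecs.append(vr.getArray().copy())  # real parts only, for brevity
    return np.asarray(eigvals), np.column_stack(eigvecs)
```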

Marius1311 commented 4 years ago

I ran it on the lung data and got the following results:

1381597 function calls (1281515 primitive calls) in 681.425 seconds

   Ordered by: internal time
   List reduced from 1076 to 50 due to restriction <50>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1  643.843  643.843  643.843  643.843 {built-in method scipy.sparse.linalg.dsolve._superlu.gstrf}
        2   13.045    6.522   13.045    6.522 {method 'solve' of 'slepc4py.SLEPc.EPS' objects}
    65380    9.463    0.000    9.463    0.000 {method 'dot' of 'numpy.ndarray' objects}
        2    3.832    1.916    3.832    1.916 {method 'createAIJ' of 'petsc4py.PETSc.Mat' objects}
197467/99612    2.494    0.000    4.652    0.000 {built-in method numpy.core._multiarray_umath.implement_array_function}
        2    0.962    0.481    0.962    0.481 {built-in method _imp.create_dynamic}
        6    0.883    0.147    0.883    0.147 {built-in method scipy.sparse._sparsetools.csr_matmat_pass2}
        1    0.829    0.829   27.512   27.512 gpcca.py:164(_do_schur)
        1    0.506    0.506    0.506    0.506 {method 'solve' of 'SuperLU' objects}
        1    0.452    0.452    3.018    3.018 gpcca.py:744(coarsegrain)
       16    0.423    0.026    0.717    0.045 numeric.py:2244(within_tol)
        1    0.421    0.421    2.883    2.883 gpcca.py:507(_opt_soft)
        7    0.393    0.056    0.393    0.056 {method 'computeError' of 'slepc4py.SLEPc.EPS' objects}
        1    0.360    0.360    0.360    0.360 {built-in method scipy.sparse._sparsetools.csc_tocsr}
        6    0.344    0.057    0.344    0.057 {built-in method scipy.sparse._sparsetools.csr_matmat_pass1}
        1    0.304    0.304    0.949    0.949 gpcca.py:443(_indexsearch)
      164    0.294    0.002    0.294    0.002 {built-in method builtins.abs}
        1    0.273    0.273    0.273    0.273 {built-in method scipy.sparse._sparsetools.csr_sort_indices}
        1    0.269    0.269  645.110  645.110 stationary_vector.py:76(stationary_distribution_from_backward_iteration)
    97679    0.257    0.000    0.510    0.000 linalg.py:2316(norm)
        2    0.214    0.107    0.614    0.307 {scipy.sparse.csgraph._traversal.connected_components}
        2    0.211    0.106    1.097    0.549 assessment.py:32(is_transition_matrix)
      132    0.188    0.001    0.188    0.001 twodim_base.py:216(diag)
        1    0.180    0.180    0.180    0.180 {built-in method scipy.sparse._sparsetools.csc_minus_csc}
       16    0.114    0.007    0.841    0.053 numeric.py:2167(isclose)
        1    0.107    0.107    0.107    0.107 {built-in method scipy.sparse._sparsetools.csr_row_index}
     1028    0.081    0.000    0.081    0.000 {method 'reduce' of 'numpy.ufunc' objects}
        2    0.051    0.025    0.051    0.025 {petsc4py.PETSc._initialize}
        1    0.049    0.049   27.561   27.561 gpcca.py:1112(_do_schur_helper)
    97774    0.047    0.000    0.047    0.000 {method 'ravel' of 'numpy.ndarray' objects}
        3    0.042    0.014    0.042    0.014 {method 'reduceat' of 'numpy.ufunc' objects}
    97679    0.040    0.000    0.602    0.000 <__array_function__ internals>:2(norm)
    97821    0.038    0.000    3.081    0.000 <__array_function__ internals>:2(dot)
        1    0.038    0.038    0.038    0.038 {built-in method scipy.sparse._sparsetools.csr_sum_duplicates}
       80    0.035    0.000    0.035    0.000 {method 'astype' of 'numpy.ndarray' objects}
        1    0.028    0.028  680.676  680.676 gpcca.py:1197(optimize)
        1    0.025    0.025   18.593   18.593 sorted_schur.py:445(sorted_schur)
98815/98814    0.025    0.000    0.025    0.000 {built-in method numpy.array}
197305/196218    0.023    0.000    0.035    0.000 {built-in method builtins.issubclass}
        6    0.022    0.004    0.022    0.004 <frozen importlib._bootstrap_external>:830(get_data)
        1    0.020    0.020   17.507   17.507 sorted_schur.py:274(sorted_krylov_schur)
98168/98167    0.019    0.000    0.041    0.000 _asarray.py:16(asarray)
    97685    0.019    0.000    0.029    0.000 linalg.py:121(isComplexType)
        2    0.013    0.007    0.013    0.007 {built-in method io.open}
    97679    0.010    0.000    0.010    0.000 linalg.py:2312(_norm_dispatcher)
    97821    0.010    0.000    0.010    0.000 multiarray.py:707(dot)
       12    0.010    0.001    0.011    0.001 decomp_svd.py:16(svd)
       10    0.006    0.001    0.006    0.001 {method 'nonzero' of 'numpy.ndarray' objects}
      128    0.006    0.000    0.106    0.001 gpcca.py:583(_fill_matrix)
        2    0.005    0.003    0.005    0.003 {built-in method _imp.exec_dynamic}

Marius1311 commented 4 years ago

I realised that there was a problem with computing the stationary vectors; changing this to an eigenvector method reduced the computation time by a factor of 20:

 1374120 function calls (1276155 primitive calls) in 34.288 seconds

   Ordered by: internal time
   List reduced from 912 to 50 due to restriction <50>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        2   13.143    6.572   13.143    6.572 {method 'solve' of 'slepc4py.SLEPc.EPS' objects}
    65380    6.040    0.000    6.040    0.000 {method 'dot' of 'numpy.ndarray' objects}
        2    3.666    1.833    3.666    1.833 {method 'createAIJ' of 'petsc4py.PETSc.Mat' objects}
      126    3.170    0.025    3.170    0.025 {built-in method scipy.sparse._sparsetools.csc_matvec}
197390/99541    2.469    0.000    4.173    0.000 {built-in method numpy.core._multiarray_umath.implement_array_function}
        1    0.795    0.795   23.361   23.361 gpcca.py:164(_do_schur)
        6    0.741    0.124    0.741    0.124 {built-in method scipy.sparse._sparsetools.csr_matmat_pass2}
        1    0.433    0.433    2.981    2.981 gpcca.py:744(coarsegrain)
        1    0.415    0.415    2.179    2.179 gpcca.py:507(_opt_soft)
        7    0.377    0.054    0.377    0.054 {method 'computeError' of 'slepc4py.SLEPc.EPS' objects}
       16    0.329    0.021    0.514    0.032 numeric.py:2244(within_tol)
        1    0.325    0.325    0.325    0.325 {built-in method scipy.sparse._sparsetools.csc_tocsr}
        1    0.290    0.290    0.933    0.933 gpcca.py:443(_indexsearch)
        6    0.263    0.044    0.263    0.044 {built-in method scipy.sparse._sparsetools.csr_matmat_pass1}
    97677    0.246    0.000    0.503    0.000 linalg.py:2316(norm)
        2    0.202    0.101    0.571    0.286 {scipy.sparse.csgraph._traversal.connected_components}
      166    0.184    0.001    0.184    0.001 {built-in method builtins.abs}
      132    0.182    0.001    0.182    0.001 twodim_base.py:216(diag)
        2    0.165    0.082    0.816    0.408 assessment.py:32(is_transition_matrix)
        1    0.097    0.097    0.097    0.097 {built-in method scipy.sparse._sparsetools.csr_row_index}
       16    0.093    0.006    0.615    0.038 numeric.py:2167(isclose)
     1026    0.080    0.000    0.080    0.000 {method 'reduce' of 'numpy.ufunc' objects}
      127    0.079    0.001    3.258    0.026 arpack.py:720(iterate)
    97896    0.045    0.000    0.045    0.000 {method 'ravel' of 'numpy.ndarray' objects}
       41    0.040    0.001    0.040    0.001 {method 'astype' of 'numpy.ndarray' objects}
        3    0.039    0.013    0.039    0.013 {method 'reduceat' of 'numpy.ufunc' objects}
    97677    0.038    0.000    0.591    0.000 <__array_function__ internals>:2(norm)
    97819    0.038    0.000    2.864    0.000 <__array_function__ internals>:2(dot)
        1    0.025    0.025   33.854   33.854 gpcca.py:1197(optimize)
    99088    0.023    0.000    0.023    0.000 {built-in method numpy.array}
   196080    0.021    0.000    0.021    0.000 {built-in method builtins.issubclass}
        1    0.021    0.021   17.436   17.436 sorted_schur.py:445(sorted_schur)
    98263    0.019    0.000    0.040    0.000 _asarray.py:16(asarray)
    97683    0.018    0.000    0.029    0.000 linalg.py:121(isComplexType)
        1    0.016    0.016   17.414   17.414 sorted_schur.py:274(sorted_krylov_schur)
        1    0.014    0.014   23.375   23.375 gpcca.py:1112(_do_schur_helper)
    97677    0.010    0.000    0.010    0.000 linalg.py:2312(_norm_dispatcher)
    97819    0.010    0.000    0.010    0.000 multiarray.py:707(dot)
       12    0.009    0.001    0.010    0.001 decomp_svd.py:16(svd)
        9    0.009    0.001    0.009    0.001 {method 'nonzero' of 'numpy.ndarray' objects}
      128    0.006    0.000    0.106    0.001 gpcca.py:583(_fill_matrix)
        2    0.005    0.003    0.005    0.003 _util.py:103(_aligned_zeros)
        1    0.005    0.005    3.283    3.283 stationary_vector.py:101(stationary_distribution_from_eigenvector)
        1    0.004    0.004    0.011    0.011 arpack.py:598(__init__)
        1    0.004    0.004    0.369    0.369 csc.py:136(tocsr)
        1    0.003    0.003    0.004    0.004 arpack.py:760(extract)
        2    0.003    0.002    0.003    0.002 {method 'create' of 'petsc4py.PETSc.Mat' objects}
        1    0.003    0.003    0.004    0.004 interface.py:658(aslinearoperator)
        6    0.003    0.000    0.003    0.000 {built-in method scipy.sparse._sparsetools.coo_tocsr}
        1    0.003    0.003    0.003    0.003 arpack.py:310(__init__)

Marius1311 commented 4 years ago

We realized that the current bottleneck was computing the stationary distribution using backward iteration. I changed this in https://github.com/msmdev/msmtools/pull/3.
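For context, a sketch of the eigenvector-based alternative (roughly what `stationary_distribution_from_eigenvector` does; the exact msmtools implementation may differ), which avoids the sparse LU factorization (`gstrf`) that dominated the first lung profile:

```python
import numpy as np
from scipy.sparse import csr_matrix
from scipy.sparse.linalg import eigs


def stationary_distribution(T: csr_matrix) -> np.ndarray:
    """Stationary distribution of a row-stochastic transition matrix T."""
    # pi solves pi @ T = pi, i.e. it is the leading eigenvector of T.T
    # (eigenvalue 1); ARPACK computes this without any LU factorization
    vals, vecs = eigs(T.T, k=1, which="LM")
    pi = np.abs(np.real(vecs[:, 0]))  # eigenvector is only defined up to sign/phase
    return pi / pi.sum()
```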