flatironinstitute / inferelator

Task-based gene regulatory network inference using single-cell or bulk gene expression data conditioned on a prior network.
BSD 2-Clause "Simplified" License
47 stars, 12 forks

Performance improvements #46

Closed (bgorissen closed this 3 years ago)

bgorissen commented 3 years ago

The improvements to base regression (base_regression.py and bayes_stats.py) are:

  1. Using the LDL decomposition instead of the Bunch–Kaufman decomposition to solve the regression problem X'Xb = X'y for b, which requires only half the flops.
  2. Removing the explicit rank checks, because scipy.linalg.solve already raises a LinAlgError if X'X is singular. It issues a LinAlgWarning if X'X is near-singular, which is almost as bad from a theoretical perspective (multicollinearity) and equally bad from a practical one (numerically, the boundary between the two cases is fuzzy). The code therefore treats a LinAlgWarning as an error.
  3. Replacing np.log with math.log, because math.log is faster for scalars and the function is called so often that the difference matters. The two differ slightly: np.log(0) triggers a (suppressed) warning and returns -inf, while math.log(0) raises a ValueError, so there is an if-statement for log(0).
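A minimal sketch of items 1 and 3 on hypothetical toy data (a full-rank design matrix, so X'X is positive definite). In SciPy, `assume_a="pos"` routes `scipy.linalg.solve` through a Cholesky factorization and `assume_a="sym"` through a symmetric-indefinite (Bunch–Kaufman) one. The helper `scalar_log` is a hypothetical name, and it assumes the zero case should reproduce `np.log`'s `-inf`.

```python
import math

import numpy as np
import scipy.linalg

# Hypothetical toy data: 5 observations, 2 predictors with full column rank.
x = np.array([[1.0, 0.0],
              [0.0, 1.0],
              [1.0, 1.0],
              [2.0, 1.0],
              [1.0, 3.0]])
y = np.array([1.0, 2.0, 3.0, 4.0, 7.0])

xtx = x.T @ x  # X'X: symmetric, and positive definite here because rank(X) = 2
xty = x.T @ y  # X'y

# 'pos' -> Cholesky-based LAPACK solver; 'sym' -> Bunch-Kaufman (symmetric indefinite)
beta_pos = scipy.linalg.solve(xtx, xty, assume_a="pos")
beta_sym = scipy.linalg.solve(xtx, xty, assume_a="sym")


def scalar_log(v):
    """math.log for a Python scalar, returning -inf at 0 like np.log does.

    math.log(0.0) raises ValueError ("math domain error"), whereas np.log(0.0)
    returns -inf with a RuntimeWarning, so the zero case is handled explicitly.
    """
    if v == 0:
        return -math.inf
    return math.log(v)
```

On a positive definite X'X both hints should agree to rounding error; the hint only changes which LAPACK factorization does the work.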

For amusr, almost all computation time is spent in the functions updateS and updateB, which mostly contain simple loops and basic arithmetic operations. Compiling this code to machine code reduces the runtime by a factor of 10 (on my workstation, without multiprocessing). For this I used numba, which requires only a few lines of code. Because numba cannot convert a matrix to column-major order, which was a step in updateB, this conversion is now done in the line that calls updateB (np.asarray(sparse_matrix, order="F")). It is not possible to check within updateB whether the matrix passed in is indeed in column-major order (Numba does not support S.flags['F_CONTIGUOUS']), so this requires care when updateB is used in a different context.
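The pattern described above (a loop-heavy kernel compiled with numba, with the column-major conversion done by the caller) can be sketched as follows. `_column_sums` is a hypothetical stand-in for updateB, not the PR's code, and the sketch falls back to plain Python when numba is not installed.

```python
import numpy as np

try:
    from numba import njit  # optional: JIT-compile the kernel to machine code
except ImportError:  # numba not installed: keep the plain Python kernel
    njit = None


def _column_sums(m):
    # Hypothetical stand-in for an updateB-style kernel: explicit loops and
    # basic arithmetic. The inner loop walks down one column, which is the
    # contiguous direction in memory when m is column-major (order="F").
    n_rows, n_cols = m.shape
    out = np.zeros(n_cols)
    for j in range(n_cols):
        for i in range(n_rows):
            out[j] += m[i, j]
    return out


# Compile (lazily, on first call) if numba is available.
column_sums_kernel = njit(_column_sums) if njit is not None else _column_sums


def column_sums(matrix):
    # The caller converts to a column-major float64 array, mirroring
    # np.asarray(sparse_matrix, order="F") in the PR; the kernel assumes
    # this layout and does not check it.
    return column_sums_kernel(np.asarray(matrix, dtype=np.float64, order="F"))
```

Keeping the layout conversion at the call site keeps the compiled kernel simple, at the cost of the caveat in the PR: any new caller has to remember to pass column-major data.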

asistradition commented 3 years ago

Thanks for the PR! Looks great.

  1. There's a failing unit test because we can't actually guarantee that xtx is positive definite (only that it's positive semidefinite) and the 'pos' hint doesn't work as a result. That's not a big deal though, I can roll that back to symmetric.

  2. Numba is a bit more of a problem. There's a lot of code out there that's hard to get running because of numba and the knock-on effect numba has on other dependencies. I don't want it as a required dependency for this project. That said, I'm fine with it as an optional dependency. So what I'm gonna suggest is that it gets rolled into an optional dependency (e.g. pass use_numba=True and it'll use the numba routines).
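The point in item 1 about X'X can be checked numerically. In this sketch the design matrix is hypothetical, with a third column equal to the sum of the first two. Since v'X'Xv = ||Xv||^2 >= 0 for every v, X'X is always positive semidefinite; but a nonzero v with Xv = 0 makes it singular, which is the case a positive-definite ('pos') hint does not cover.

```python
import numpy as np

# Hypothetical design matrix: column 3 = column 1 + column 2, so rank(X) = 2.
x = np.array([[1.0, 0.0, 1.0],
              [0.0, 1.0, 1.0],
              [1.0, 1.0, 2.0],
              [2.0, 1.0, 3.0]])
xtx = x.T @ x

# All eigenvalues of X'X are >= 0 (up to rounding): positive semidefinite ...
eigenvalues = np.linalg.eigvalsh(xtx)

# ... but X lacks full column rank, so X'X has a zero eigenvalue and is
# not positive definite.
rank = np.linalg.matrix_rank(x)
null_vector = np.array([1.0, 1.0, -1.0])  # X @ null_vector == 0
```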

I'll go ahead and make those changes and then merge it into the dev branch so it can be tested on the cluster.

asistradition commented 3 years ago

Aight, I went ahead and made those changes and merged them into dev. The implementation is that you can call wkf.set_run_parameters(use_numba=True), and it will replace updateB and updateS with the numba JIT routines at runtime.
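The thread does not show how `set_run_parameters(use_numba=True)` is wired internally. The sketch below is a hypothetical illustration of the general pattern: a module-level kernel that is rebound at runtime, with numba imported only when requested so it stays an optional dependency. The names `_scale_python`, `set_use_numba`, and `scale` are invented for illustration.

```python
import warnings

import numpy as np


def _scale_python(values, factor):
    # Plain-Python reference kernel (hypothetical stand-in for updateB/updateS).
    out = np.empty_like(values)
    for i in range(values.shape[0]):
        out[i] = values[i] * factor
    return out


# Active implementation; starts as the pure-Python version.
_active_scale = _scale_python


def set_use_numba(use_numba):
    """Hypothetical switch: rebind the kernel to a JIT-compiled copy on request.

    numba is imported only inside this function, so it is never required
    unless asked for. Returns True if the numba kernel is now active.
    """
    global _active_scale
    if not use_numba:
        _active_scale = _scale_python
        return False
    try:
        from numba import njit
    except ImportError:
        warnings.warn("numba is not installed; using the slower Python kernels")
        _active_scale = _scale_python
        return False
    _active_scale = njit(_scale_python)
    return True


def scale(values, factor):
    # Callers always go through this dispatcher, never a kernel directly.
    return _active_scale(np.asarray(values, dtype=np.float64), factor)
```

Because callers only see `scale`, switching implementations does not change results, and environments without numba keep working with a warning instead of an import error.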

I will find some time next week to write some tests for this functionality so I can package this up as a release.

bgorissen commented 3 years ago

There's a failing unit test because we can't actually guarantee that xtx is positive definite (only that it's positive semidefinite) and the 'pos' hint doesn't work as a result. That's not a big deal though, I can roll that back to symmetric.

I'm not sure why the test fails. The test log shows that a LinAlgWarning was issued, and base_regression.py contains code that treats this warning as an error (so that beta_hat=0 is selected in base_regression.predict_error_reduction). Locally, this test passes.

Numba is a bit more of a problem. There's a lot of code out there that's hard to get running because of numba and the knock-on effect numba has on other dependencies.

What specific code are you referring to? Unless numba is imported and enabled for a specific function, it does not interfere with other code. Numba gives an order-of-magnitude improvement, so perhaps you could display a warning about reduced performance when use_numba=False.

asistradition commented 3 years ago

The scipy solve test failure is specific to newer versions of scipy (I didn't get it locally with 1.6.x but it failed when I updated to 1.7.1). It's not unexpected, as the test case is positive semidefinite, and the 'pos' hint requires positive definite.

asistradition commented 3 years ago

The performance boost from numba is really impressive btw, I have the cluster regression tests going and they're way faster now. It's just that the dependency is hard to maintain (same as the math kernel library, which is why that's optional as well). I just don't want to be answering issues about cryptic runtime errors that occur when numba and numpy aren't exactly matching versions for the next three years.

bgorissen commented 3 years ago

The scipy solve test failure is specific to newer versions of scipy

That's not the point. The logs of the failed test show:

    inferelator/tests/test_base_regression.py::TestBaseRegression::test_predict_error_reduction
      /home/runner/work/inferelator/inferelator/inferelator/regression/base_regression.py:259: LinAlgWarning: Ill-conditioned matrix (rcond=8.88178e-17): result may not be accurate.
        beta_hat = scipy.linalg.solve(xtx, xty, assume_a='pos')

Due to base_regression.py:15, this should be treated as an error:

    warnings.filterwarnings(action='error', category=scipy.linalg.LinAlgWarning)

So this code (around base_regression.py:259) should run the except block:

        try:
            xt = x_leaveout.T
            xtx = np.dot(xt, x_leaveout)
            xty = np.dot(xt, y)
            beta_hat = scipy.linalg.solve(xtx, xty, assume_a='pos')
        except (np.linalg.LinAlgError, scipy.linalg.LinAlgWarning):
            beta_hat = np.zeros(len(leave_out), dtype=np.dtype(float))

On my local machine, when I run python -m nose, a LinAlgWarning is triggered, and indeed the except block is run. I'm at a loss why this doesn't work here or on your local machine. I've tried Scipy 1.4.1 and 1.7.1.

asistradition commented 3 years ago

Sorry, I'm not sure exactly what the problem is. SciPy and NumPy are backed by LAPACK and BLAS, so it's possible that it's a difference in those libraries.

The regression test package passed for amusr (with the numba JIT versions of the updateB and updateS routines) and runs substantially faster. It's failing for BBSR (the differences aren't huge, but they are there), so I'm going to roll those changes back.

I'll try to get tests written this week so that it can get merged into the stable version and released with a version number.

Thanks again!

bgorissen commented 3 years ago

Sorry, I'm not sure exactly what the problem is.

test_base_regression.py:test_predict_error_reduction gives the wrong output:

    x: array([-116., -116., -116.])
    y: array([-133.33, -133.33, -133.33])

where x is the output from base_regression.predict_error_reduction and y is the (hardcoded) expected output. The reason for this erroneous output is that beta_hat = scipy.linalg.solve(xtx, xty, assume_a='pos') returns beta_hat = [0.05 0.08333333] instead of failing (in which case beta_hat = [0 0] would be used).

This is not an issue about BLAS implementations or the numerically vague boundary between singular and ill-conditioned matrices. scipy.linalg.solve issues a LinAlgWarning, so it should be possible to handle that warning and return beta_hat = [0 0]. In fact, the code to do that is there, and it works locally.

After some digging I found that the coverage module interferes with the warnings module. The same test passes with nose and there is nothing wrong with the code itself.
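One property of the warnings module is relevant here: a module-level warnings.filterwarnings call mutates process-wide state, and warnings.catch_warnings saves the filter list on entry and restores it on exit, so a filter installed while some tool's catch_warnings block is active does not survive that block. Whether this is exactly what happens under coverage is not established in the thread. A sketch that avoids depending on global filter state scopes the filter to the solve call itself; `solve_or_zero` is a hypothetical name.

```python
import warnings

import numpy as np
import scipy.linalg


def solve_or_zero(xtx, xty):
    """Solve xtx @ beta = xty, returning zeros if xtx is singular or ill-conditioned.

    Hypothetical helper: the LinAlgWarning filter is installed only for the
    duration of this call (and removed again when the with-block exits),
    instead of once at module import time.
    """
    with warnings.catch_warnings():
        # Promote SciPy's ill-conditioning warning to an exception, locally.
        warnings.simplefilter("error", category=scipy.linalg.LinAlgWarning)
        try:
            return scipy.linalg.solve(xtx, xty, assume_a="pos")
        except (np.linalg.LinAlgError, scipy.linalg.LinAlgWarning):
            # Singular (LinAlgError) or near-singular (LinAlgWarning): fall back.
            return np.zeros(len(xty), dtype=float)
```

Because the filter lives only inside the function, the fallback no longer depends on which filters a test runner, coverage tool, or other library has installed when the call happens.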

asistradition commented 3 years ago

Oh, that's interesting; I wouldn't have expected there to be a problem with the warnings module. I'm running coverage with the test package locally (& it's part of the CI workflow), so that would explain it.

Good detective work.