LBHB / NEMS

The Neural Encoding Model System (NEMS) provides tools to fit and evaluate models of sensory encoding and decoding.
GNU General Public License v3.0

speed up scipy optimize #8

Open · jacobpennington opened this issue 2 years ago

jacobpennington commented 2 years ago

Try out the suggestion here re: using the JAX library to compute the cost function's gradient and passing that gradient to scipy (for their specific example, they quote a ~5000x speedup): https://stackoverflow.com/questions/68507176/faster-scipy-optimizations

This is a little more involved than I first thought, since np.<whatever> operations have to be replaced with jnp.<whatever>. But aside from a few caveats, like not using in-place operations (JAX arrays are immutable), most Layer implementations would otherwise be identical, so this could be added as a backend. It would be much simpler than the TF backend: just define evaluate_jax and still use scipy, but with a hook to pass in the gradient. Without configuring GPU usage this would still be slower than TF, but it may be a good intermediate option: still much faster than vanilla scipy/numpy, and easier for new users to implement.
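A sketch of what the paired methods might look like, using a hypothetical rectifying layer (the class name and its `gain`/`offset` parameters are invented for illustration and are not the NEMS Layer API). The NumPy `evaluate` can clip in place, while `evaluate_jax` has to express the same step as a pure expression; `jax.grad` can then differentiate through it.

```python
import numpy as np
import jax
import jax.numpy as jnp

class HypotheticalRectifiedLinear:
    """Toy layer: output = max(x * gain + offset, 0). Not a real NEMS class."""

    def __init__(self, gain, offset):
        self.gain = gain
        self.offset = offset

    def evaluate(self, x):
        out = x * self.gain + self.offset
        out[out < 0] = 0  # in-place assignment: fine for NumPy arrays
        return out

    def evaluate_jax(self, x):
        out = x * self.gain + self.offset
        # JAX arrays are immutable, so the in-place line above becomes a pure
        # expression; jnp.where(out < 0, 0, out) would also work.
        return jnp.maximum(out, 0)

layer = HypotheticalRectifiedLinear(gain=2.0, offset=-1.0)
x = np.array([-2.0, 0.5, 3.0])
numpy_out = layer.evaluate(x)
jax_out = layer.evaluate_jax(x)

# The JAX path is differentiable: gradient of the summed output w.r.t. the input.
grad_x = jax.grad(lambda v: layer.evaluate_jax(v).sum())(jnp.asarray([-2.0, 1.0, 3.0]))
```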

Separately, try adding Numba (http://numba.pydata.org/) to the standard scipy evaluate methods. Unlike JAX, it's supposed to work with standard NumPy code, so the improvements may be simple to integrate.

jacobpennington commented 2 years ago

Notes so far on Numba: it worked great for the STP revision, with a ~1000x speedup vs. the old non-quick algorithm and ~40x vs. the quick_eval algorithm. It adds one extra dependency, but so far that seems worth it for those speedups. Still pending: more testing to make sure it doesn't interfere with optimization (it shouldn't) and that the outputs are numerically close enough for a variety of inputs.

Other @njit options to try:

nogil=True: releases the global interpreter lock (GIL) while the compiled function runs, so multiple threads can execute it simultaneously. This one requires some thought, but it could speed up side-effect-free work that is safe to run asynchronously, such as error checks (detecting NaNs, for example).

cache=True: saves compiled functions to __pycache__ (or a fallback directory) to speed up subsequent runs. With the default cache=False, functions are compiled again every time the program runs, so first uses will always be slower. Not a big factor for optimization, but turning this on for things like STP could make post-fit analyses (like loading and plotting a model) faster, although those are generally fast enough already.

parallel=True: automatically parallelizes many NumPy functions and other operations that support parallelization, as well as loops written with numba.prange.

https://numba.readthedocs.io/en/stable/reference/jit-compilation.html

jacobpennington commented 2 years ago

For FIR:

1) It should be possible to get rid of the loop over output channels with proper reshaping.
2) Try routines from scipy.ndimage (e.g. convolve1d/correlate1d)? They're supposed to have some additional optimizations for specific use cases.
3) Alternatively, try coding up the custom 1D filtering in Cython or Numba as simple nested for loops. Those should still be easy enough to read, and may speed things up by skipping many of the options/checks that generic convolution functions include.

It's not clear how much effect these changes would have on performance.