Kevin-Haigis-Lab / speclet

A Bayesian hierarchical model to discover tissue-specific cancer driver genes and synthetic lethal interactions from CRISPR/Cas9 LoF screens.
GNU General Public License v3.0

Add support for faster PyMC backends #161

Closed: jhrcook closed this issue 2 years ago

jhrcook commented 2 years ago

Add support for the JAX backend for PyMC. There appear to be substantial speed gains even on the CPU alone, though getting the GPU working is also worth trying. The only change seems to be calling a different "inference button" (sampling function) from PyMC; the model itself does not need to change.
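A minimal sketch of what "swapping the inference button" could look like, assuming a recent PyMC 5.x release where `pm.sample()` accepts a `nuts_sampler` argument. The helper names (`BACKENDS`, `sampler_kwargs`, `sample_model`) are hypothetical and not part of speclet or PyMC; the point is that only the sampling call changes while the model object is passed in untouched.

```python
"""Hypothetical helpers for choosing a NUTS backend in PyMC.

Assumes PyMC >= 5, where pm.sample() accepts `nuts_sampler`.
"""

# Map a short backend name to the value pm.sample() expects for `nuts_sampler`.
BACKENDS = {
    "default": "pymc",       # PyMC's own NUTS implementation
    "numpyro": "numpyro",    # NumPyro's NUTS, compiled through JAX
    "blackjax": "blackjax",  # BlackJAX's NUTS, also JAX-based
}


def sampler_kwargs(backend: str) -> dict:
    """Return the keyword arguments that select a NUTS backend in pm.sample()."""
    try:
        return {"nuts_sampler": BACKENDS[backend]}
    except KeyError:
        raise ValueError(
            f"unknown backend {backend!r}; expected one of {sorted(BACKENDS)}"
        ) from None


def sample_model(model, backend: str = "default", **kwargs):
    """Sample an existing PyMC model; the model definition itself is unchanged."""
    import pymc as pm  # imported lazily so sampler_kwargs() works without PyMC

    with model:
        return pm.sample(**sampler_kwargs(backend), **kwargs)
```

With a model built inside `with pm.Model() as model:`, calling `sample_model(model, backend="numpyro", draws=1000, tune=1000)` would run NumPyro's sampler (this requires `jax` and `numpyro` to be installed), and `backend="default"` runs PyMC's own NUTS. In the PyMC 4.x releases that were current when this issue was opened, the NumPyro sampler was instead called directly as `pymc.sampling_jax.sample_numpyro_nuts()`.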

Here is a post comparing the speeds of the different samplers (and to Stan): MCMC for big datasets: faster sampling with JAX and the GPU.

This is a plot from the post showing the runtimes on the largest experimental dataset:

[Figure: walltime_full, wall time of each sampler on the largest dataset]

One concern is that GPU time is far more restricted on O2: "Currently there is an active limit of 160 GPU hours for each user... If you use just 1 GPU card, the partition maximum wall time will limit you to 120 hours (5 days)." If I can figure out a way to save chains mid-run, I could use the GPUs over multiple runs.
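To see how the two quoted limits interact, here is a back-of-the-envelope planner. It assumes (not confirmed by the quote) that the 160 GPU-hour limit is charged as the number of GPUs requested times the requested wall time, so a single job can request at most `min(120, 160 / n_gpus)` hours. A sampling run longer than that would have to be split into sequential jobs, each resuming from saved chain state. The function names are hypothetical.

```python
import math

# Limits quoted from the O2 documentation in this issue.
GPU_HOUR_LIMIT = 160.0  # active GPU hours per user
MAX_WALLTIME_H = 120.0  # partition maximum wall time per job (hours)


def max_job_walltime(n_gpus: int) -> float:
    """Longest wall time (hours) one job using n_gpus could request.

    Assumes the GPU-hour limit is charged as n_gpus * requested wall time.
    """
    if n_gpus < 1:
        raise ValueError("n_gpus must be at least 1")
    return min(MAX_WALLTIME_H, GPU_HOUR_LIMIT / n_gpus)


def plan_segments(total_hours: float, n_gpus: int = 1) -> list[float]:
    """Split an estimated run into sequential jobs that each fit the limits.

    Every segment after the first would need to resume from saved chain state.
    """
    if total_hours <= 0:
        raise ValueError("total_hours must be positive")
    cap = max_job_walltime(n_gpus)
    n_jobs = math.ceil(total_hours / cap)
    segments = [cap] * (n_jobs - 1)
    segments.append(total_hours - cap * (n_jobs - 1))  # remainder for the last job
    return segments
```

For example, a hypothetical 300-hour run on one GPU would need three jobs of 120, 120, and 60 hours, while requesting two GPUs caps each job at 80 hours.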

Using O2 GPU resources

jhrcook commented 2 years ago

Code for the blog post linked above: https://github.com/martiningram/mcmc_runtime_comparison

jhrcook commented 2 years ago

Got the NumPyro backend working locally. There is a definite speed-up even on the CPU alone.

jhrcook commented 2 years ago

In the notebooks notebooks/misc/010_015_pymc-backends(_o2).ipynb, I ran MCMC on two benchmark models using both the default pm.sample() and the NumPyro backend sampler. The first model is a linear regression with a lot of data; the second is the current hierarchical negative binomial model of the CRISPR screen.
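The notebooks are not reproduced here, but a comparison like this one needs little more than a wall-clock timer around each sampling call. The sketch below is a hypothetical, stdlib-only harness: each entry in `samplers` is assumed to be a zero-argument callable that wraps one sampling call on the same model, and durations are reported as MM:SS. Note that for the JAX-based samplers the measured time would include JIT compilation.

```python
import time
from typing import Callable


def format_mmss(seconds: float) -> str:
    """Format a duration as MM:SS (minutes may exceed 59)."""
    minutes, secs = divmod(round(seconds), 60)
    return f"{minutes:02d}:{secs:02d}"


def benchmark(samplers: dict[str, Callable[[], object]]) -> dict[str, str]:
    """Run each zero-argument sampling function once and record its wall time."""
    results = {}
    for name, run in samplers.items():
        start = time.perf_counter()
        run()  # e.g. a closure around one pm.sample() call (hypothetical)
        results[name] = format_mmss(time.perf_counter() - start)
    return results
```

Running each sampler once on the same model and data keeps the comparison fair; repeating the timings would also help separate compilation overhead from sampling time.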

| Model | Sampler | Local (MacBook) | O2 |
|---|---|---|---|
| linear reg. | default | 00:22 | 00:14 |
| linear reg. | NumPyro | 00:18 | 00:12 |
| CRISPR | default | 18:38 | 18:32 |
| CRISPR | NumPyro | 02:15 | 03:52 |
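The relative speed-up is easier to compare as a ratio of default to NumPyro runtime. The snippet below computes it from the runtimes in the table above, assuming they are written as MM:SS; because both entries in each ratio share the same format, the ratios would be identical if the columns were HH:MM instead.

```python
def to_seconds(mmss: str) -> int:
    """Convert an 'MM:SS' string to seconds."""
    minutes, seconds = mmss.split(":")
    return int(minutes) * 60 + int(seconds)


# (default, NumPyro) runtimes per model and machine, copied from the table.
TIMINGS = {
    "linear reg.": {"Local (MacBook)": ("00:22", "00:18"), "O2": ("00:14", "00:12")},
    "CRISPR": {"Local (MacBook)": ("18:38", "02:15"), "O2": ("18:32", "03:52")},
}


def speedups(timings: dict) -> dict:
    """Ratio of default runtime to NumPyro runtime, rounded to 2 decimals."""
    return {
        (model, machine): round(to_seconds(default) / to_seconds(numpyro), 2)
        for model, machines in timings.items()
        for machine, (default, numpyro) in machines.items()
    }


for (model, machine), ratio in speedups(TIMINGS).items():
    print(f"{model:12s} {machine:16s} {ratio:.2f}x")
```

For the CRISPR model, NumPyro is about 8.3x faster locally and about 4.8x faster on O2, whereas the linear regression gains only about 1.2x on either machine.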

The "transformation time" for the linear regression model differed between my MacBook Pro and O2: it took much longer on my computer. I think the difference was RAM. Memory was completely maxed out on my laptop, whereas on O2 I allocated 32 GB of RAM and the transformation time was trivial. For the CRISPR model, the transformation times were the same on both computers.
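If the transformation step has to hold the full array of posterior draws for each variable in memory, a quick size estimate shows how a model with large variables could exhaust a laptop's RAM. The sketch assumes float64 storage (8 bytes per value); the example sizes are hypothetical and are not the sizes of the benchmark models.

```python
import math


def draws_nbytes(chains: int, draws: int, shape: tuple, itemsize: int = 8) -> int:
    """Bytes needed to store every posterior draw of one variable.

    Assumes one array of shape (chains, draws, *shape) with itemsize-byte values.
    """
    return chains * draws * math.prod(shape) * itemsize


# Hypothetical: a variable with 100,000 elements, 4 chains x 1,000 draws.
print(draws_nbytes(4, 1000, (100_000,)) / 1e9, "GB")
```

Under these assumptions, a single 100,000-element variable already takes 3.2 GB, and any additional variables of similar size add to that total.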

Overall, I think this was a successful demonstration of the NumPyro backend for PyMC.