emdann / scATAC_prep

A not-too-orderly collection of scripts for preprocessing and dimensionality reduction of scATAC-seq data

faster LDA implementation #2

Open · vitkl opened this issue 3 years ago

vitkl commented 3 years ago

Hi Emma,

I would encourage you to look into JAX and NumPyro for a faster LDA implementation. The JAX backend enables efficient multi-core CPU and GPU acceleration with no code changes needed. This repo might be a useful source of inspiration: https://github.com/srush/jax-lda

emdann commented 3 years ago

Thanks for the suggestion, Vitalii! If you get to try this yourself on scATAC data, feel free to add a notebook with a pull request.

vitkl commented 3 years ago

LDA is a relatively simple model, so it makes sense to use tailored inference (EM). It should be fairly easy to modify this NumPy implementation of variational inference for LDA (https://github.com/ddbourgin/numpy-ml/blob/master/numpy_ml/lda/lda.py#L5-L247) to work with JAX. Maybe even the following would work:

```python
# replace
import numpy as np
from scipy.special import digamma, polygamma, gammaln

# with
import jax.numpy as np
from jax.scipy.special import digamma, polygamma, gammaln
```

I am not sure I am invested enough to do that now.

I am curious whether numpy_ml/lda is already better than cisTopic.
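For reference, here is a minimal sketch of the kind of tailored inference discussed above: batch variational EM for LDA (in the style of Blei et al., 2003) written entirely as whole-array operations on a dense documents × words count matrix (for scATAC, cells × peaks). It is not a port of the numpy_ml code; the function name `lda_variational_em`, the fixed symmetric `alpha`, the random initialisation and the iteration counts are illustrative assumptions. Because each step is an array expression with no in-place mutation, swapping in `jax.numpy` and `jax.scipy.special.digamma` as in the import swap above might carry over, but that is untested here.

```python
import numpy as np
from scipy.special import digamma


def lda_variational_em(X, n_topics, alpha=0.1, n_iter=50, n_estep=20, seed=0):
    """Sketch of batch variational EM for LDA on a dense count matrix.

    X: (n_docs, n_words) non-negative counts, e.g. cells x peaks.
    Returns gamma (n_docs, n_topics), the variational Dirichlet parameters
    per document, and beta (n_topics, n_words), topic-word probabilities.
    alpha is a fixed symmetric Dirichlet prior (not re-estimated here).
    """
    rng = np.random.default_rng(seed)
    X = np.asarray(X, dtype=float)
    n_docs, n_words = X.shape
    eps = 1e-100  # guards against division by zero for all-zero words

    # Random positive initialisation of beta; each topic row sums to 1.
    beta = rng.gamma(100.0, 0.01, size=(n_topics, n_words))
    beta = beta / beta.sum(axis=1, keepdims=True)
    # Standard initialisation: gamma_dk = alpha + N_d / K.
    gamma = np.ones((n_docs, n_topics)) * (alpha + X.sum(axis=1, keepdims=True) / n_topics)

    for _ in range(n_iter):
        # E-step: iterate the coupled phi/gamma updates for every document at once.
        for _ in range(n_estep):
            # exp(E[log theta_dk]) under q(theta_d) = Dirichlet(gamma_d)
            exp_elog_theta = np.exp(digamma(gamma) - digamma(gamma.sum(axis=1, keepdims=True)))
            # norm[d, v] = sum_k exp(E[log theta_dk]) * beta_kv, the normaliser of phi_dv.
            norm = exp_elog_theta @ beta + eps
            # gamma_dk = alpha + sum_v X_dv * phi_dvk, with phi folded into norm.
            gamma = alpha + exp_elog_theta * ((X / norm) @ beta.T)

        # M-step: beta_kv proportional to sum_d X_dv * phi_dvk, then renormalise rows.
        exp_elog_theta = np.exp(digamma(gamma) - digamma(gamma.sum(axis=1, keepdims=True)))
        norm = exp_elog_theta @ beta + eps
        beta = beta * (exp_elog_theta.T @ (X / norm))
        beta = beta / beta.sum(axis=1, keepdims=True)

    return gamma, beta
```

Per-cell topic proportions can then be read off as `gamma / gamma.sum(axis=1, keepdims=True)`. The per-word responsibilities `phi` are never stored explicitly; they are folded into the `norm` matrix, so memory stays at O(docs × words) rather than O(docs × words × topics).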

vitkl commented 3 years ago

This code is odd: for some reason, it handles the data as a list of lists rather than as an array. Making it work with JAX would take some strategic rethinking, since JAX operates on fixed-shape arrays.
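One way around the list-of-lists layout is to convert the corpus into a fixed-shape count matrix up front. The sketch below assumes each document (for scATAC, each cell) is given as a list of integer token ids (for example, peak indices); the helper name `docs_to_counts` is hypothetical.

```python
import numpy as np


def docs_to_counts(docs, n_words):
    """Convert a list of token-id lists (one per document) into a dense
    (n_docs, n_words) integer count matrix with a fixed shape."""
    counts = np.zeros((len(docs), n_words), dtype=np.int64)
    for d, tokens in enumerate(docs):
        # np.add.at accumulates repeated ids correctly; counts[d][ids] += 1 would not.
        np.add.at(counts[d], np.asarray(tokens, dtype=np.int64), 1)
    return counts


counts = docs_to_counts([[0, 2, 2], [1], []], n_words=3)
# counts is [[1, 0, 2], [0, 1, 0], [0, 0, 0]]
```

For a real peak matrix a sparse format such as `scipy.sparse.csr_matrix` would usually be more memory-friendly than a dense array; the dense version is shown only because it is the simplest fixed-shape representation.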

scikit-learn's LatentDirichletAllocation (https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.LatentDirichletAllocation.html) is probably more immediately useful as a Python alternative to cisTopic.
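A minimal sketch of fitting scikit-learn's `LatentDirichletAllocation` to a cells × peaks matrix. The input here is a random toy binary matrix, not real scATAC data, and the parameter values (10 topics, 20 iterations, batch learning) are illustrative assumptions rather than recommendations. scikit-learn fits LDA with variational Bayes, so topics need not match cisTopic output exactly.

```python
import numpy as np
from sklearn.decomposition import LatentDirichletAllocation

# Toy stand-in for a binarised cells x peaks accessibility matrix (random, not real data).
rng = np.random.default_rng(0)
X = (rng.random((200, 500)) < 0.05).astype(np.int64)

lda = LatentDirichletAllocation(
    n_components=10,          # number of topics
    learning_method="batch",  # use all cells in every variational EM update
    max_iter=20,
    random_state=0,
)
cell_topic = lda.fit_transform(X)  # (n_cells, n_topics), each row sums to 1
topic_peak = lda.components_       # (n_topics, n_peaks), unnormalised topic-peak weights
```

Normalising each row of `topic_peak` gives per-topic peak distributions, analogous to the region-topic matrix one would inspect in cisTopic.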