Open vitkl opened 3 years ago
Thanks for the suggestion, Vitalii! If you get to try this yourself on scATAC data, feel free to add a notebook in a pull request.
LDA is a relatively simple model, so it makes sense to use tailored inference (EM). It should be fairly easy to modify this numpy implementation of variational inference for LDA (https://github.com/ddbourgin/numpy-ml/blob/master/numpy_ml/lda/lda.py#L5-L247) to work with JAX. Maybe even the following could work:
# replace
import numpy as np
from scipy.special import digamma, polygamma, gammaln
# with
import jax.numpy as np
from jax.scipy.special import digamma, polygamma, gammaln
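A small sketch of how the swapped-in functions behave under JAX, using the Dirichlet expectation that appears in variational LDA as the example. The helper name `dirichlet_expectation` is illustrative and is not taken from numpy-ml. One caveat with a pure import swap: JAX arrays are immutable, so any in-place assignment in the original numpy code (`x[i] = v`) would need to become `x = x.at[i].set(v)`.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import digamma


@jax.jit
def dirichlet_expectation(gamma):
    """Row-wise E[log theta] for theta ~ Dirichlet(gamma)."""
    return digamma(gamma) - digamma(jnp.sum(gamma, axis=-1, keepdims=True))


# Two documents, three topics, uniform variational parameters.
gamma = jnp.ones((2, 3))
elog_theta = dirichlet_expectation(gamma)

# JAX arrays cannot be modified in place; .at[...].set(...) returns a new array.
updated = gamma.at[0, 0].set(2.0)
```

With all parameters equal to 1, every entry of `elog_theta` is digamma(1) - digamma(3), which is -1.5.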
I am not sure I am invested enough to do that now.
I am curious whether numpy_ml/lda is already better than cisTopic.
This code is weird. For some reason, it handles data as a list of lists rather than an array. It would take some strategic rethinking to make that work with jax.
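One possible way around the list-of-lists layout is to convert the corpus to a dense document-by-term count matrix before handing it to array-based code. The sketch below assumes each document is a list of integer token ids in `[0, vocab_size)`; the function name `docs_to_count_matrix` is hypothetical and this is not how numpy-ml does it.

```python
import numpy as np


def docs_to_count_matrix(docs, vocab_size):
    """Turn a ragged list of token-id lists into a (n_docs, vocab_size) count matrix."""
    rows = [
        # bincount tallies how often each token id occurs in one document;
        # minlength pads the tally out to the full vocabulary.
        np.bincount(np.asarray(doc, dtype=np.int64), minlength=vocab_size)
        for doc in docs
    ]
    return np.stack(rows)


docs = [[0, 2, 2], [1], [3, 3, 3, 0]]
counts = docs_to_count_matrix(docs, vocab_size=4)
```

The resulting fixed-shape array can be passed to `jax.numpy.asarray` or to scikit-learn, at the cost of losing token order, which LDA does not use anyway.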
https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.LatentDirichletAllocation.html is probably more immediately useful as a Python alternative to cisTopic.
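For reference, a minimal sketch of fitting scikit-learn's `LatentDirichletAllocation` on a count matrix. The random Poisson matrix is only a stand-in for a cells-by-regions scATAC matrix, and the number of topics and iterations are arbitrary assumptions, not recommendations.

```python
import numpy as np
from sklearn.decomposition import LatentDirichletAllocation

# Toy stand-in for a (cells x regions) accessibility count matrix.
rng = np.random.default_rng(0)
X = rng.poisson(0.3, size=(50, 200))

lda = LatentDirichletAllocation(
    n_components=5,           # number of topics (assumed for illustration)
    learning_method="batch",  # full-batch variational Bayes
    max_iter=10,
    random_state=0,
)

# Rows are cells; each row is a normalized distribution over topics.
cell_topic = lda.fit_transform(X)

# Rows are topics; each row holds unnormalized weights over regions.
region_topic = lda.components_
```

scikit-learn also accepts SciPy sparse matrices here, which is likely the more practical input format for real scATAC data.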
Hi Emma,
I would encourage you to look into JAX and NumPyro for a faster LDA implementation. With the JAX backend, the same code can run on multi-core CPUs and on GPUs without modification. You might find this repo a useful inspiration: https://github.com/srush/jax-lda