zdebruine / RcppML

Rcpp Machine Learning: Fast robust NMF, divisive clustering, and more
GNU General Public License v2.0

Make crossValidate compatible with large matrices #38

Closed Wainberg closed 1 year ago

Wainberg commented 1 year ago

Fixes the bug where crossValidate() fails with the error "(nnz <- as.integer(nnz)) >= 0 is not TRUE" whenever nrow(data) * ncol(data) > .Machine$integer.max.
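To illustrate the underlying issue (this is just base-R integer behavior, not RcppML internals; the dimensions are made up):

```r
# Dimensions chosen so nrow * ncol exceeds .Machine$integer.max (2^31 - 1)
nr <- 30000L   # e.g. genes
nc <- 100000L  # e.g. cells
nr * nc                          # NA, with an integer-overflow warning
as.numeric(nr) * as.numeric(nc)  # 3e9 -- fine in double precision
# as.integer() of NA (or of a value above 2^31 - 1) is NA, so the
# "(nnz <- as.integer(nnz)) >= 0 is not TRUE" check fails unless the
# arithmetic is done in doubles.
```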

Wainberg commented 1 year ago

Awesome! One other enhancement that would be really cool to integrate into RcppML is integrative NMF (iNMF, https://www.ncbi.nlm.nih.gov/pmc/articles/PMC5006236), which lets you find shared NMF components across multiple studies while accounting for between-study batch effects/heterogeneity.

iNMF is used in the LIGER method for single-cell integration (https://www.cell.com/cell/fulltext/S0092-8674%2819%2930504-5, https://www.nature.com/articles/s41587-021-00867-x, https://github.com/welch-lab/liger). However, their code base has a reputation for being slow and, probably as a result, they choose the number of NMF components heuristically rather than doing a hyperparameter search.

If you could extend your ultra-fast NMF solver (really, it's ridiculous how much faster it is than anything else I've tried) and cross-validation setup to iNMF, it would allow people to find continuous axes of variation shared across multiple single-cell studies in a fast and rigorous way. I'm currently co-supervising a large meta-analysis of Alzheimer's single-cell studies (3.4 million cells across 7 studies) and we would eagerly use something like that if it were available!

zdebruine commented 1 year ago

Indeed. I am happy to work with Josh Welch on iNMF if we can get the funding to support someone to do that.

But there are theoretical problems with iNMF: it assumes that shared and dataset-specific signals (W and U_1, U_2, ..., U_n) combine linearly and additively. There is no reasonable basis for assuming that batch effects correspond to biological signal in that way. We have explored other methods (Linked NMF) without success. We have now implemented Graph-Convolutional NMF for spatial analysis and, incidentally, find it works well for integration as well. Now the secret is finding the best graph-construction algorithm to squash two or more datasets together along only batch effects and not biological information.
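Concretely, the objective in the iNMF paper linked above is, up to notation and matrix orientation, roughly

$$\min_{W,\,U_i,\,H_i \ge 0} \ \sum_i \lVert X_i - (W + U_i)\,H_i \rVert_F^2 \;+\; \lambda \sum_i \lVert U_i H_i \rVert_F^2$$

so the shared signal W and the dataset-specific signals U_i are forced to combine additively in each reconstruction.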

For most applications, I still recommend learning large joint NMF models (cbind everything together -> run NMF) and then discarding factors that are largely dataset-specific and do not contain biologically interesting information.
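A rough sketch of that workflow (assuming the list-style return value of RcppML::nmf() with $w/$d/$h, and hypothetical gene x cell matrices A1 and A2 on a shared gene space; the 0.95 cutoff is arbitrary):

```r
library(Matrix)
library(RcppML)

# A1, A2: hypothetical gene x cell sparse matrices (dgCMatrix) from two studies
A <- cbind(A1, A2)
dataset <- rep(c("study1", "study2"), c(ncol(A1), ncol(A2)))

model <- RcppML::nmf(A, k = 30, seed = 1)

# Fraction of each factor's total H loading contributed by each dataset;
# factors dominated by a single dataset are likely batch-driven.
h <- model$h
frac <- t(apply(h, 1, function(x) tapply(x, dataset, sum) / sum(x)))
batch_factors <- which(apply(frac, 1, max) > 0.95)  # arbitrary cutoff
keep <- setdiff(seq_len(nrow(h)), batch_factors)
h_bio <- h[keep, , drop = FALSE]
w_bio <- model$w[, keep, drop = FALSE]
```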

Would be excited to learn more about your work! We have used our NMF implementations on 20 million single-cell transcriptomes and have learned trillion-parameter models from WGS sequencing data. It's promising. We also have a CUDA implementation forthcoming -- scales much better for large datasets.

Wainberg commented 1 year ago

Fair point that iNMF is suboptimal. I like the idea of running NMF and discarding dataset-specific factors, I'll try that out. What's the ETA on the CUDA implementation?

We're trying to find NMF components within each broad brain cell type (like microglia and inhibitory neurons) that generalize across the 7 datasets, and then test which of these components differ between Alzheimer's cases and controls. Some issues we've run into:

  1. How to regularize k, the number of components. We're currently using the one-standard-error rule: selecting the smallest k whose validation MSE is within one standard error of the best validation MSE, similar to lambda.1se in glmnet (a sketch of this selection is below this list). But we'd like to go even further: selecting an even smaller k with an even worse validation MSE, so long as it's not "that much worse", whatever that means. I'm guessing this would have to be done heuristically, but I'm not sure what a principled heuristic would be. Bayesian NMF methods like https://github.com/getzlab/SignatureAnalyzer explicitly put a prior on k but are very slow.
  2. How to select the L1 penalties L1_w and L1_h on the W and H matrices efficiently. We tried Bayesian hyperparameter search over k, L1_w and L1_h with Optuna, but it converges too slowly. Our current strategy is to set L1_w and L1_h to 0 and just do a grid search over k, which still takes 30 hours for excitatory neurons (the largest cell type) with reps=3, even on 1.6 million cells (the largest single dataset), and that's still not the full 3.4 million cells we plan to analyze eventually.
  3. How to normalize the data before running NMF. Counts per million seems to work a fair bit better than raw counts. Haven't played around with normalization much beyond that.
  4. How to fit the data into memory. Excitatory neurons can't fit into a dgCMatrix/dgRMatrix because those sparse matrix classes use 32-bit integers for indexing, so they can only fit 2,147,483,647 non-zero elements. We're getting around this by converting to a dense matrix, using a machine with several hundred GB of memory, and using float32 instead of float64 to represent the data to save memory.
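Here's roughly what the one-standard-error selection in point 1 looks like, assuming the data.frame returned by RcppML::crossValidate() has columns k, rep, and value (the test-set error):

```r
# cv <- RcppML::crossValidate(A, k = 2:30, reps = 3)
one_se_k <- function(cv) {
  mean_err <- tapply(cv$value, cv$k, mean)
  se_err   <- tapply(cv$value, cv$k, function(x) sd(x) / sqrt(length(x)))
  ks       <- as.integer(names(mean_err))
  best     <- which.min(mean_err)
  # smallest k whose mean validation error is within one SE of the best k's
  min(ks[mean_err <= mean_err[best] + se_err[best]])
}
```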
zdebruine commented 1 year ago

CUDA implementation should be available March/April next year. It scales better, but is memory-bound to datasets of <30 GB or so at the moment, so it's great for Graph-Convolutional NMF (coming soon) and cross-validation on moderately sized datasets, but it doesn't yet scale to TB-sized datasets.

  1. This is a very simple problem: select the k that gives the best generalization, as measured by the mean squared error of reconstructing a random speckled test set. This is available in singlet. In other words, we select the rank that gives the best imputation accuracy or the most robust transfer-learning objective. This also happens to be the most robust model across random restarts. We are implementing use of the "adam" optimizer against test-set reconstruction error this month; it should be in RcppML/singlet soon. You really don't need a heuristic at all. I admire your attempt to solve the problem, but it's not particularly elegant.

  2. What is the objective you are trying to minimize when tuning L1_w and L1_h?

  3. Standard log-normalization as Seurat does it: divide columns by their sums, multiply by a scaling factor (e.g. 10,000), and take the log1p (a quick sketch is below this list). It works as well as at least 6 other methods I've tried.

  4. This is painful to hear; we are working on releasing a solution, possibly even this month. I am unable to use dgCMatrix for my own work because of the size limitations you mention and have been using my own sparse matrix structures for a while. Sometime Jan-Feb we will be rolling out new data structures for single-cell sparse matrices that require 5-10x less RAM than dgCMatrix without compromising read speed, play well with R/Rcpp/RcppEigen, and have a C++ API that allows distributed computing. We will immediately implement NMF for these structures, but the challenge is getting developers to swap dgCMatrix out for the new encodings.
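For point 3, a quick sketch of that normalization in R (counts here is a hypothetical gene x cell dgCMatrix; this just mirrors the standard Seurat-style transform, not singlet/RcppML code):

```r
library(Matrix)

log_normalize <- function(counts, scale_factor = 1e4) {
  # divide each column (cell) by its sum, scale, then log1p; stays sparse
  norm <- counts %*% Matrix::Diagonal(x = scale_factor / Matrix::colSums(counts))
  norm@x <- log1p(norm@x)
  norm
}
```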

Wainberg commented 1 year ago

Thanks, this is very helpful! Looking forward to the CUDA implementation and improved sparse matrices. I'll try out log normalization and singlet.

Is singlet's "mean squared error of reconstruction of a random speckled test set" conceptually different from RcppML's crossValidate(), or just implemented more efficiently? I'm currently using crossValidate() with the one-standard-error trick, which helps, but even smaller ks than the one-standard-error k give almost as good results. Maybe adding an L1 penalty on W/H would help? (Incidentally, if you're interested in implementing the one-standard-error rule in singlet, glmnet's code for it is here, and here's some theoretical justification for it.)

Re: point #2, ideally you'd be able to automatically choose the L1 penalty to optimize the same metric you're optimizing k for, the mean squared error of reconstruction of a random speckled test set. What do you think of the idea of implementing 2D adam optimization over k and L1 (and L2, if specified) simultaneously? So in singlet's RunNMF, you could choose whether to turn off L1 (L1 = 0), set it to a fixed value, or select it automatically (L1 = NULL), and similarly for L2. The defaults could be L1 = NULL and L2 = 0.