sgkit-dev / sgkit

Scalable genetics toolkit
https://sgkit-dev.github.io/sgkit
Apache License 2.0

Understand Hail GWAS regression implementation #448

Closed hammer closed 3 years ago

hammer commented 3 years ago

We'd like to bring our costs closer to Hail's. To do so, it would be helpful to understand the Hail implementation and see whether there are any ideas in it that we might reuse in ours.

tomwhite commented 3 years ago

In sgkit, the gwas_linear_regression function "[removes] sample (i.e. person/individual) covariates through orthogonal projection of both the genetic variant and phenotype data". It cites the BOLT-LMM paper as the source for doing this, where in the Supplementary Information it states "We model covariates by projecting them out from both genotypes and phenotypes, which is equivalent to including them as fixed effects."
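The equivalence BOLT-LMM describes is the Frisch-Waugh-Lovell theorem. Here is a minimal NumPy check on simulated data. The helper names `beta_fixed_effects` and `beta_projected` are hypothetical, not sgkit's API. The genotype coefficient from a regression that includes the covariates should match the one obtained after projecting the covariates out of both genotype and phenotype.

```python
import numpy as np

def beta_fixed_effects(x, y, covariates):
    """OLS coefficient for x when the covariates (incl. intercept) are fixed effects."""
    design = np.column_stack([x, covariates])
    coef, *_ = np.linalg.lstsq(design, y, rcond=None)
    return coef[0]

def beta_projected(x, y, covariates):
    """OLS coefficient for x after projecting the covariates out of x and y."""
    def residualize(v):
        # Subtract the least-squares fit of v on the covariates
        proj, *_ = np.linalg.lstsq(covariates, v, rcond=None)
        return v - covariates @ proj
    x_res = residualize(x)
    y_res = residualize(y)
    return (x_res @ y_res) / (x_res @ x_res)

rng = np.random.default_rng(0)
n = 200
covariates = np.column_stack([np.ones(n), rng.normal(size=(n, 3))])
x = rng.integers(0, 3, size=n).astype(float)  # simulated dosages 0/1/2
y = 0.5 * x + covariates @ np.array([1.0, 0.2, -0.3, 0.1]) + rng.normal(size=n)

print(np.isclose(beta_fixed_effects(x, y, covariates),
                 beta_projected(x, y, covariates)))  # True
```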

Projecting out the covariates from the genotypes (as opposed to just the phenotypes, say) is an expensive operation, since it involves computing an outer matmul whose output is of shape (variants, samples).

https://github.com/pystatgen/sgkit/blob/c210fae783e927a45d5b852105cf42d9b47a32e4/sgkit/stats/association.py#L74

Note that this operation is at the heart of the performance issues observed in #390. (There is a mitigation in #454, but it's still an expensive operation.)

For the UKB GWAS run on Hail, there are a couple of linear regression implementations: linreg3 in Hail 0.1, and linear_regression_rows in Hail 0.2. It's not clear to me whether either of these projects out covariates from the genotypes. I think it's important to find out, so that we have a like-for-like comparison of the performance and cost of running GWAS using sgkit and Hail.

hammer commented 3 years ago

You've likely already seen it, but I believe linear_regression_rows is the right place to start given https://github.com/Nealelab/UK_Biobank_GWAS/issues/37#issuecomment-751293038:

One key development was an early version of https://hail.is/docs/0.2/methods/stats.html#hail.methods.linear_regression_rows that we used.

tomwhite commented 3 years ago

Exactly - Hail's linear_regression_rows uses standard least-squares linear regression, whereas sgkit's gwas_linear_regression uses a mixed model, which is computationally more expensive.

For the purposes of comparison I think we should have an implementation of the standard linear regression in sgkit. I'd appreciate some help writing it though, as I'm not sure how to express it using NumPy/Dask array operations.

eric-czech commented 3 years ago

sgkit's gwas_linear_regression uses a mixed model

What makes you say that @tomwhite? There are no random effects in it and it's equivalent to ordinary least-squares/fixed effects models, e.g. the tests are against statsmodels.OLS.

It cites the BOLT-LMM paper as the source for doing this

IIRC, lots of other GWAS modeling methods do the covariate projection (e.g. REGENIE and Fast-LMM in addition to BOLT) and it's often used as both a performance enhancement and a way to simplify the code/algebra for the per-variant regressions. It wouldn't surprise me though if, in general, it's an improvement for single-server software but not in a distributed system. I remember being worried at first that it was doing something super-linear in XL (the variants x samples matrix) and then thinking that it wouldn't be a problem based on dask/array/linalg.py#L1404-L1407 (a QR factorization is run on only the much smaller XC). Do you have any sense of which part of that might be blowing things up? Apologies if I missed that somewhere since I'm only partially caught up.

It's not clear to me if either of these projects out covariates from the genotypes. I think it's important to understand if they do or not, so we have a like-for-like comparison in the performance

I don't think Hail does it; at least, I've never seen anything obvious in the code for it. In theory it wouldn't be difficult to broadcast/repeat the covariates for all the individual regressions instead, but that would mean introducing another dimension into the matrix algebra. It might be worth the effort regardless. I'm at a loss, though, for an intuition as to why it would be substantially better that way (if not worse).
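Here is a rough sketch of the broadcast alternative. It assumes NumPy, a (samples, variants) genotype matrix `XL`, covariates `XC` that include an intercept column, and a single phenotype `Y`. The function name is hypothetical. Repeating the covariates for every variant adds a leading variants dimension to the design, and the per-variant normal equations are then solved as one batch:

```python
import numpy as np

def betas_with_covariates(XL, XC, Y):
    """Per-variant OLS genotype coefficients, with covariates as fixed effects.

    XL: (samples, variants) genotypes; XC: (samples, covariates) incl. intercept;
    Y: (samples,) phenotype. No projection: covariates are repeated per variant.
    """
    n_variants = XL.shape[1]
    # Repeat the covariates for every variant: shape (variants, samples, covariates)
    XC_rep = np.broadcast_to(XC, (n_variants,) + XC.shape)
    # Design per variant = [genotype, covariates]: (variants, samples, 1 + covariates)
    design = np.concatenate([XL.T[:, :, None], XC_rep], axis=2)
    # Batched normal equations: (variants, k, k) and (variants, k, 1)
    xtx = np.einsum("vsi,vsj->vij", design, design)
    xty = np.einsum("vsi,s->vi", design, Y)[:, :, None]
    coef = np.linalg.solve(xtx, xty)[:, :, 0]
    return coef[:, 0]  # genotype coefficient for each variant

rng = np.random.default_rng(1)
n = 100
XC = np.column_stack([np.ones(n), rng.normal(size=(n, 2))])
XL = rng.integers(0, 3, size=(n, 5)).astype(float)
Y = 0.3 * XL[:, 0] + XC @ np.array([1.0, 0.5, -0.2]) + rng.normal(size=n)

# Cross-check against one ordinary lstsq fit per variant
loop = [np.linalg.lstsq(np.column_stack([XL[:, j], XC]), Y, rcond=None)[0][0]
        for j in range(XL.shape[1])]
print(np.allclose(betas_with_covariates(XL, XC, Y), loop))  # True
```

The materialized design array is (1 + covariates) times larger than `XL`. That is one concrete reason this layout might not beat projection.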

FYI the REGENIE Spark (but not C++) code avoids covariate projection based on https://github.com/projectglow/glow/issues/266#issuecomment-670146265. IMO that wasn't a very convincing reason, since every other part of the method operates on out-of-core chunks anyway. Maybe there is some wisdom in neither Hail nor Glow doing the projection, though.

tomwhite commented 3 years ago

There are no random effects in it and it's equivalent to ordinary least-squares/fixed effects models, e.g. the tests are against statsmodels.OLS.

Ah, sorry I was mistaken.

I remember being worried at first that it was doing something super-linear in XL (the variants x samples matrix) and then thinking that it wouldn't be a problem based on dask/array/linalg.py#L1404-L1407 (a QR factorization is run on only the much smaller XC). Do you have any sense of which part of that might be blowing things up?

XC @ da.linalg.lstsq(XC, XL)[0] is large: it has shape (variants, samples), whereas the two inputs to the matmul are both small. So computing it involves a lot of communication. I wonder if it's possible to broadcast the covariates (so they are only sent to each worker once) and combine them with each block in XL.
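Here is a minimal NumPy sketch of that idea. It assumes a (samples, variants) layout for `XL` and uses a hypothetical `project_out_blockwise` helper. A single QR of the small `XC` yields `Q`, which can be shared with every block, and each variant block is then residualized independently. The result should match the one-shot `XL - XC @ lstsq(XC, XL)[0]` computation:

```python
import numpy as np

def project_out_blockwise(XL, XC, block_size=64):
    """Residualize XL (samples, variants) against XC, one variant block at a time."""
    XL = np.asarray(XL, dtype=float)
    # One QR of the small (samples, covariates) matrix; Q is shared by all blocks
    Q, _ = np.linalg.qr(XC)
    out = np.empty_like(XL)
    for start in range(0, XL.shape[1], block_size):
        block = XL[:, start:start + block_size]
        # Remove the component of this block lying in the span of the covariates
        out[:, start:start + block_size] = block - Q @ (Q.T @ block)
    return out

rng = np.random.default_rng(2)
XC = np.column_stack([np.ones(300), rng.normal(size=(300, 4))])
XL = rng.integers(0, 3, size=(300, 200)).astype(float)

# Same result as the one-shot projection XL - XC @ lstsq(XC, XL)
full = XL - XC @ np.linalg.lstsq(XC, XL, rcond=None)[0]
print(np.allclose(project_out_blockwise(XL, XC), full))  # True
```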

tomwhite commented 3 years ago

I tried another experiment, where I used Dask map_blocks to independently process each block of variants. This is akin to what Hail does (except I'm still doing covariate projection as discussed above). It's important that the array is not chunked in the samples dimension, which means that the chunk size in the variants dimension has to be quite small. I used 64 to give ~100MB chunks.
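The per-block work can be sketched as a plain NumPy function that a `map_blocks`-style call would apply to each variant block. This is a hypothetical illustration, not the notebook's code. It assumes that `XC` includes an intercept, that there is a single phenotype `Y`, and that the samples axis is unchunked so each block is self-contained. The function projects out the covariates and returns per-variant OLS betas and t-statistics:

```python
import numpy as np

def assoc_block(xl_block, XC, Y):
    """OLS betas and t-statistics for one block of variants (samples unchunked).

    xl_block: (samples, block_variants); XC: (samples, covariates) incl. intercept;
    Y: (samples,). Needs nothing from other blocks, so blocks can run in parallel.
    """
    n, k = XC.shape
    Q, _ = np.linalg.qr(XC)
    x = xl_block - Q @ (Q.T @ xl_block)   # project covariates out of the genotypes
    y = Y - Q @ (Q.T @ Y)                 # and out of the phenotype
    xx = np.sum(x * x, axis=0)
    beta = (x.T @ y) / xx
    resid = y[:, None] - x * beta         # residuals of each variant's model
    dof = n - k - 1                       # samples minus covariates minus genotype
    sigma2 = np.sum(resid ** 2, axis=0) / dof
    t = beta / np.sqrt(sigma2 / xx)
    return beta, t

def assoc_all(XL, XC, Y, block_size=64):
    """Apply assoc_block to each variant block independently, then concatenate."""
    parts = [assoc_block(XL[:, s:s + block_size], XC, Y)
             for s in range(0, XL.shape[1], block_size)]
    return (np.concatenate([p[0] for p in parts]),
            np.concatenate([p[1] for p in parts]))

rng = np.random.default_rng(3)
n = 150
XC = np.column_stack([np.ones(n), rng.normal(size=(n, 2))])
XL = rng.integers(0, 3, size=(n, 130)).astype(float)
Y = 0.4 * XL[:, 0] + XC @ np.array([1.0, 0.3, -0.5]) + rng.normal(size=n)
beta, t = assoc_all(XL, XC, Y)
print(beta.shape, t.shape)  # (130,) (130,)
```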

On 8x data the processing time on a 16 node cluster was 77s, compared to 110s from the equivalent run in https://github.com/pystatgen/sgkit/issues/390#issuecomment-768332568. This is a 1.4x speedup.

Translating this into normalized numbers (using https://github.com/pystatgen/sgkit/issues/390#issuecomment-768380382):

This is a ~6x speedup from the original, and if we could use preemptible instances to get a ~5x cost saving, I think that would put us in the same cost ballpark as Hail.

Ideally Dask would do this kind of optimization for us so we didn't have to resort to map_blocks, but it's good to know that this is a technique we can fall back to if needed.

Here's the notebook I used, and the performance report.

eric-czech commented 3 years ago

That's amazing @tomwhite!

It's important that the array is not chunked in the samples dimension

Do you know what happens using the original code without chunking in the samples dimension (instead of map_blocks)?
Would the chunking mismatch of https://github.com/pystatgen/sgkit/issues/390#issuecomment-768332568 still apply as a limitation?

tomwhite commented 3 years ago

Thanks @eric-czech.

Do you know what happens using the original code without chunking in the samples dimension (instead of map_blocks)?

No, I haven't tried that. But given that the map_blocks approach is embarrassingly parallel, it's more likely to be robust when running with preemptible instances.

ravwojdyla commented 3 years ago

@tomwhite I would be +1 to trying "original code without chunking" since that is essentially the case from the suggested "optimisation flow" in https://github.com/pystatgen/sgkit/issues/390#issuecomment-764950336:

To be more concrete, we saw a benefit from going from 5216 -> 652 in the variant axis, specifically no spilling (less memory overhead), and increasing the chunk size in the sample axis should reduce the communication/transfer time. So, to be more precise, I believe we should try to:

  • increase the chunk size in the sample axis as much as practical
  • and if we start spilling reduce the chunk size in the variant axis
  • if we need to increase the WT/Memory ratio, previous comments describe how to do that
  • so a "degenerated" case would be chunking: {variant: 1, sample: MAX}.
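With the sample axis left whole (`-1`), a chunk holds `variant_chunk * n_samples * itemsize` bytes, so the variant chunk size follows directly from a memory target. The helper below is hypothetical. Its ~100MB default mirrors the earlier 64-variant experiment and is not a recommendation. For very large sample counts, it bottoms out at the degenerate `{variant: 1, sample: MAX}` case:

```python
def variant_chunk_size(n_samples, itemsize, target_bytes=100_000_000):
    """Largest variant chunk c so that a (c, n_samples) chunk fits in target_bytes.

    Hypothetical helper for {variant: c, sample: -1} chunking: with the sample
    axis unchunked, each chunk holds c * n_samples * itemsize bytes.
    """
    per_variant = n_samples * itemsize          # bytes for one full variant row
    return max(1, target_bytes // per_variant)  # never below 1 (the degenerate case)

# Example: 1,000,000 samples of float32 (4 bytes) -> 25 variants per ~100MB chunk
chunks = {"variants": variant_chunk_size(1_000_000, 4), "samples": -1}
print(chunks)  # {'variants': 25, 'samples': -1}
```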
tomwhite commented 3 years ago

I tried with chunks: {variant: 64, sample: -1} and it was a bit faster than map_blocks on this dataset:

Here's the notebook and performance report.

tomwhite commented 3 years ago

I updated the benchmarking code to use Dask 2021.3.1 (it was previously using 2.30.0) and I got the same result as before for gwas_simulation_unchunked_sample.ipynb (70s).

This version of Dask has the improved matmul implementation (https://github.com/dask/dask/pull/7000), and dask-cloudprovider has had some changes that required alterations to the benchmarking code. See https://github.com/tomwhite/gwas-benchmark/commit/2931599f5b67c51943711e222c88151f8237a8bd.