CAST-genomics / haptools

Ancestry and haplotype aware simulation of genotypes and phenotypes for complex trait analysis
https://haptools.readthedocs.io
MIT License

speeding up the `transform` command #92

Open aryarm opened 2 years ago

aryarm commented 2 years ago

Description

Transforming a set of haplotypes can take a while, especially if you have many haplotypes. A quick benchmark via the bench_transform.py script seems to indicate that it scales linearly with the number of samples, the max number of alleles in each haplotype, and the number of haplotypes.

[plot: benchmark results showing roughly linear scaling of transform runtime with each of the three parameters]

@s041629 wants to transform roughly 21,000 haplotypes for 500,000 samples where the max number of alleles in a haplotype is 115. Currently (as of f45742f466ac0eb1de6d0358494d6ace8fd5856f), the Haplotypes.transform() function takes about 10-25 real minutes to accomplish this task. (It varies based on the willingness of the kernel to prioritize our job on the node.)

For now, I think this will be fast enough. But we might want to brainstorm some ideas for the future. In any case, I don't really think I can justify spending a lot more time to improve this right now (since I've already spent quite a bit), so I'm gonna take a break and solicit suggestions until I feel energized enough to tackle this again.

Details

The vast majority of the time seems to be spent in this for-loop.

https://github.com/CAST-genomics/haptools/blob/639985ec639b0f45be36f7f3e6641152dc056267/haptools/data/haplotypes.py#L1054-L1055

Unfortunately, I can't use broadcasting to speed up this loop because the size of idxs[i] might vary on each iteration. And I can't pad the array to make broadcasting work because I quickly run out of memory for arrays of that shape. Ideally, I would like to find a solution that doesn't require multi-threading or the installation of new dependencies besides those we have already.

Steps to reproduce

A basic test

This test uses just 5 samples, 4 variants, and 3 haplotypes. It's a sanity check to make sure the Haplotypes.transform() function works and that none of the code will fail.

pytest -sx tests/test_data.py::TestHaplotypes::test_haps_transform

A benchmarking test

To benchmark the function and see how it scales with the number of samples, the max number of alleles in each haplotype, and the number of haplotypes, you can use the bench_transform.py script as follows. It will create a plot like the one I've included above.

https://github.com/CAST-genomics/haptools/blob/639985ec639b0f45be36f7f3e6641152dc056267/tests/bench_transform.py#L19-L21

A pure numpy-based test of the for-loop

The tests above require that you install the dev setup for haptools. If you'd like to reproduce just the problematic for-loop for @s041629's situation in pure numpy you can do something like this:

import numpy as np

# dimensions matching @s041629's use case
num_samps, num_alleles, num_haps, max_alleles = 500000, 6200, 21000, 115

# ragged list: each haplotype refers to a variable number of allele indices
idxs = [np.random.randint(0, num_alleles, np.random.randint(0, max_alleles)) for j in range(num_haps)]
# phased genotypes: samples x alleles x 2 chromosomes
arr = np.random.choice([True, False], size=num_samps*num_alleles*2).reshape((num_samps, num_alleles, 2))
output = np.empty((num_samps, num_haps, 2), dtype=np.bool_)

# a sample carries a haplotype iff it carries every one of its alleles
for i in range(num_haps):
    output[:, i] = np.all(arr[:, idxs[i]], axis=1)

A full, long-running test

If requested, I can provide files that reproduce the entire setup so that you can execute the transform command with them. The files are a bit too large to attach here.

Some things I've already tried

  1. Adding a dimension to arr (aka equality_arr) of size equal to the number of haplotypes and then performing np.all() on the other dimension. This takes longer and requires far too much memory. See fac79b07c085490351b1ad2e3b61bc06732d731b.
  2. Using numba to JIT-compile the for-loop. This took longer for some reason. Still not sure why. Also, I'm averse to installing new dependencies. See 5afcf81667b8a718ab28ff39f101f1de78614c43.
  3. Looping over the allele dimension instead of the haplotype dimension (since num_alleles < num_haps, at least in this situation). This was a clever strategy from Ryan Eveloff. Strangely, this also took longer than looping over the haplotypes. I'm not sure why. See 335e5f7063f84f8f739dd62e789cd5a48a1468df.

Potential ideas

  1. Try taking advantage of the fact that some of the haplotypes share alleles (ie some of the haplotypes are subsets of other haplotypes). So I could try to construct some sort of tree and then iterate through that via a dynamic programming approach, instead. In theory, this should reduce the time complexity by a log factor somewhere. A back-of-the-envelope analysis of the .hap file Matteo gave me indicates that there would be at least 1,000 leaves in such a tree.
  2. Talk to Tara! She might have had a similar issue.
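To make idea 1 concrete, here's a hypothetical sketch (my own, untested at scale) that skips building an explicit tree: sort each haplotype's allele indices and memoize every partial AND-reduction, so a haplotype that is a superset of another only pays for its extra alleles. Note the cache trades memory for time, so it would need pruning in practice.

```python
import numpy as np

def transform_prefix_cache(arr, idxs):
    # arr: (num_samps, num_alleles, 2) phased genotypes
    # idxs: ragged list of allele-index arrays, one per haplotype
    num_samps, _, ploidy = arr.shape
    out = np.empty((num_samps, len(idxs), ploidy), dtype=np.bool_)
    # cache maps a sorted allele-index prefix to its AND-reduction
    cache = {(): np.ones((num_samps, ploidy), dtype=np.bool_)}
    for i, hap in enumerate(idxs):
        key = tuple(sorted(hap.tolist()))
        # find the longest prefix we've already reduced
        k = len(key)
        while key[:k] not in cache:
            k -= 1
        acc = cache[key[:k]]
        # extend the cached reduction one allele at a time, caching each step
        for j in key[k:]:
            acc = acc & arr[:, j]
            k += 1
            cache[key[:k]] = acc
        out[:, i] = acc
    return out
```

Sorting the alleles is safe because AND is order-independent; how much this helps depends entirely on how much prefix-sharing the real .hap file exhibits.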
aryarm commented 2 years ago

Just recording a suggestion from @RossDeVito for us to use multi-processing!

from functools import partial
import numpy as np
from multiprocessing import Pool

def single_check(idxs, arr):
    # compute one haplotype's column of the output
    return np.all(arr[:, idxs], axis=1)

def parallel_loop(idxs, arr, num_haps, n_processes=None):
    # farm out one task per haplotype to a pool of worker processes
    with Pool(processes=n_processes) as pool:
        res = pool.map(partial(single_check, arr=arr), idxs)
    return np.stack(res, axis=1)

I was initially hesitant to use this because the transform command itself will be run in a multi-threaded way: we expect it to be called multiple times in parallel for each chromosome. But perhaps the multi-threading could just be an option for the user. Another consideration is that the arrays might be getting copied when we use subprocesses, which itself might take some time.

In any case, I would still like to explore other options to improve the speed first. I know that we'll ultimately have to use multi-processing, but I'd like to leave it as the method of last resort, especially since other strategies we implement may change everything anyway. This code will definitely be helpful to have then.
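On the copying concern: one way to avoid pickling arr into every subprocess is to place it in shared memory so workers attach to it instead. A rough sketch with the standard library's multiprocessing.shared_memory (the function names here are mine, not from haptools):

```python
import numpy as np
from multiprocessing import Pool, shared_memory

def _worker(args):
    # attach to the shared block instead of receiving a copy of arr
    shm_name, shape, dtype, idxs = args
    shm = shared_memory.SharedMemory(name=shm_name)
    try:
        arr = np.ndarray(shape, dtype=dtype, buffer=shm.buf)
        return np.all(arr[:, idxs], axis=1)
    finally:
        shm.close()

def parallel_transform(arr, idxs_list, n_processes=None):
    # copy arr into shared memory once, up front
    shm = shared_memory.SharedMemory(create=True, size=arr.nbytes)
    try:
        shared = np.ndarray(arr.shape, dtype=arr.dtype, buffer=shm.buf)
        shared[:] = arr
        args = [(shm.name, arr.shape, arr.dtype, ix) for ix in idxs_list]
        with Pool(processes=n_processes) as pool:
            res = pool.map(_worker, args)
        return np.stack(res, axis=1)
    finally:
        shm.close()
        shm.unlink()
```

The per-haplotype result arrays still get pickled back to the parent, so this only eliminates the cost of shipping arr out, not the cost of collecting results.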

aryarm commented 2 years ago

Another idea from Tara is to use np.where or something similar to figure out which indices are 0s and then try to implement the short-circuiting manually. According to https://github.com/numpy/numpy/issues/3446, np.all does not really short-circuit properly.
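A manual short-circuit might look something like this hypothetical sketch (not benchmarked): once both of a sample's chromosomes have already failed a haplotype, drop that sample from the remaining allele checks.

```python
import numpy as np

def transform_short_circuit(arr, idxs):
    # arr: (num_samps, num_alleles, 2) phased genotypes
    # idxs: ragged list of allele-index arrays, one per haplotype
    num_samps, _, ploidy = arr.shape
    out = np.ones((num_samps, len(idxs), ploidy), dtype=np.bool_)
    for i, hap_idx in enumerate(idxs):
        # samples whose haplotype status isn't fully decided yet
        alive = np.arange(num_samps)
        for j in hap_idx:
            out[alive, i] &= arr[alive, j]
            # keep only samples with at least one chromosome still True
            alive = alive[out[alive, i].any(axis=1)]
            if alive.size == 0:
                break
    return out
```

Whether this beats np.all depends on how quickly samples drop out; the fancy indexing itself has a cost, so it could easily be slower when most samples stay alive for most alleles.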

aryarm commented 1 year ago

Another thing to consider is that we could load the data from pgenlib as a packed-bit array or as chunks. This might significantly reduce memory usage and presumably improve speed, as well.

Also, we would first need to check whether it's possible to do all of our operations on packed bit arrays or chunks of data at a time. If not, then we could consider writing our own compiled Cython extension for parts of this command, but that's a whole thing.
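As a rough illustration of the packed-bit idea (a sketch for a single chromosome's worth of genotypes, not how pgenlib actually stores data): np.packbits along the sample axis lets the AND-reduction operate on 8 samples per byte.

```python
import numpy as np

rng = np.random.default_rng(0)
num_samps, num_alleles = 16, 6
arr = rng.random((num_samps, num_alleles)) < 0.5   # one chromosome's genotypes

# pack along the sample axis: each byte now holds 8 samples
packed = np.packbits(arr, axis=0)                  # shape (num_samps // 8, num_alleles)

hap_idx = np.array([1, 3, 4])
# AND-reduce the haplotype's columns, 8 samples per byte operation
packed_res = np.bitwise_and.reduce(packed[:, hap_idx], axis=1)
res = np.unpackbits(packed_res, count=num_samps).astype(bool)

assert np.array_equal(res, np.all(arr[:, hap_idx], axis=1))
```

The open question from above still stands: the per-haplotype loop and any downstream operations would all need to work on the packed representation to keep the savings.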

aryarm commented 1 year ago

@Ayimany and I discussed this issue today, and we've begun to identify a path forward! We can probably implement multiple strategies, all of which should speed things up a bit.

d-laub commented 10 months ago

I've used numba to great success with variable-length data (inspired by all of Awkward Array being built with numba). For example, I use it in GenVarLoader to construct haplotypes with indels about as fast as numpy can copy reference sequence into an output array (i.e. out = construct_haplotypes(reference, variants) as fast as out[...] = reference). Here's an example of refactoring _transform() from 5afcf81 so you can get the idea (not shown: concatenating idxs into a flat array and creating offset array at caller):

import numpy as np
import numpy.typing as npt
import numba as nb

# enable parallel execution
@nb.njit(nogil=True, parallel=True, cache=True)
def _transform(
    arr: npt.NDArray[np.bool_],
    idxs: npt.NDArray[np.integer],
    idx_offsets: npt.NDArray[np.integer],
    out: npt.NDArray[np.bool_],  # must be initialized to all True: we accumulate with &=
):
    # use nb.prange() to mark loops that are parallelizable
    for i in nb.prange(len(idx_offsets) - 1):
        hap_idx = idxs[idx_offsets[i] : idx_offsets[i+1]]
        # the inner accumulation must stay serial: it read-modify-writes out[:, i]
        for _j in range(len(hap_idx)):
            j = hap_idx[_j]
            out[:, i] &= arr[:, j]
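The caller-side prep left "not shown" above might look like this sketch (the variable names are mine): flatten the ragged per-haplotype index arrays and record where each haplotype's slice begins.

```python
import numpy as np

# hypothetical ragged haplotype indices
idxs_list = [np.array([0, 2]), np.array([1]), np.array([0, 1, 3])]

# flatten, and build offsets so haplotype i owns flat_idxs[idx_offsets[i]:idx_offsets[i+1]]
flat_idxs = np.concatenate(idxs_list)
idx_offsets = np.concatenate(([0], np.cumsum([len(ix) for ix in idxs_list])))

# out must start all True, since _transform accumulates with &=
num_samps, num_alleles = 4, 5
arr = np.ones((num_samps, num_alleles, 2), dtype=np.bool_)
out = np.ones((num_samps, len(idxs_list), 2), dtype=np.bool_)
```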

I’d also recommend speed profiling with py-spy to confirm that this is the bottleneck if you haven't already. (Can also recommend memray for memory profiling.)

aryarm commented 10 months ago

Thank you so much, @d-laub ! I'm looking forward to trying this!! I didn't think to use an idx_offsets array like that. That's quite clever.