StefanDevAccount / LearnMHN

A Python package to work with Mutual Hazard Networks (MHN)
MIT License
0 stars 2 forks source link

Parallelization using `joblib` #1

Open pawel-czyz opened 1 year ago

pawel-czyz commented 1 year ago

Hi, thanks a lot for a great and scalable implementation!

I don't have an actual issue, just a note for people who only have access to CPUs but need to process many patients. The implementation works nicely with joblib, so a simple parallelization trick (scoring each patient separately, in parallel) can speed up the computation considerably:

$ python benchmark.py --n-mutations 30 --n-patients 1000
Patients: 1000
Mutations 30
Maximal observed number of mutations: 18
Score (joint): -34594.397
Score (indiv): -34594.397
Time (joint): 3.193
Time (indiv): 1.234

However, for a small number of patients, calling the provided function once on all patients (joint) outperforms the parallel version by a large margin:

$ python benchmark.py --n-mutations 30 --n-patients 10  
Patients: 10
Mutations 30
Maximal observed number of mutations: 12
Score (joint): -344.539
Score (indiv): -344.539
Time (joint): 0.007
Time (indiv): 0.378

Benchmark script

The script benchmark.py is the following:

import argparse
import joblib
import time

import mhn
import numpy as np

def generate_mutations(rng, n_patients: int, n_mutations: int, p: float) -> np.ndarray:
    """Sample a naive binary mutation matrix.

    Every entry is drawn independently as a Bernoulli(p) variable using ``rng``.

    Returns:
        int32 array of 0/1 values, shape (n_patients, n_mutations)
    """
    shape = (n_patients, n_mutations)
    draws = rng.binomial(1, p, size=shape)
    return np.asarray(draws, dtype=np.int32)

def generate_theta(rng, n_mutations: int) -> np.ndarray:
    """Draw a random (log-)theta matrix with i.i.d. standard-normal entries.

    Returns:
        float64 array of shape (n_mutations, n_mutations)
    """
    samples = rng.normal(size=(n_mutations, n_mutations))
    return np.asarray(samples, dtype=np.float64)

def create_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the benchmark.

    Options:
        --n-patients: number of simulated patients (rows), default 30.
        --n-mutations: number of genes (columns), default 20.
        --prob: per-gene mutation probability, default 0.3.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--n-patients", type=int, default=30, help="Number of patients")
    parser.add_argument("--n-mutations", type=int, default=20, help="Number of genes in which mutations are measured.")
    # A probability is fractional: ``type=int`` would reject e.g. ``--prob 0.5``.
    parser.add_argument("--prob", type=float, default=0.3, help="Probability that a gene is mutated (to simulate a naive data set).")

    return parser

def calculate_score_joint(mutations, theta) -> float:
    """Score all patients at once with a single ``cython_gradient_and_score`` call.

    The score returned by ``cython_gradient_and_score`` is the total
    log-likelihood divided by the number of patients, so it is scaled back up
    by ``len(mutations)``. The gradient is computed but discarded.
    """
    n_rows = len(mutations)
    state_container = mhn.ssr.state_containers.StateContainer(mutations)
    _, mean_score = mhn.ssr.state_space_restriction.cython_gradient_and_score(theta, state_container)
    return n_rows * mean_score

def _get_function_for_theta(theta):
    def helper(mvec):
        container = mhn.ssr.state_containers.StateContainer(mvec.reshape((1, -1)))
        return mhn.ssr.state_space_restriction.cython_gradient_and_score(theta, container)
    return helper

def calculate_score_individual(mutations, theta, n_jobs: int = 4) -> float:
    """Score patients one at a time, in parallel with joblib, and sum the scores.

    Each row of ``mutations`` is scored as its own one-patient data set by the
    callable from ``_get_function_for_theta``. For a single patient, the
    per-patient-averaged score equals that patient's own score. The sum is
    therefore expected to match ``calculate_score_joint`` up to floating-point
    rounding.

    Args:
        mutations: mutation matrix, one row per patient.
        theta: (log-)theta matrix passed to every per-patient call.
        n_jobs: number of joblib workers (default 4, the previous hard-coded value).

    Returns:
        Sum of the per-patient scores (0.0 when ``mutations`` has no rows).
    """
    score_one = _get_function_for_theta(theta)

    grads_and_scores = joblib.Parallel(n_jobs=n_jobs)(
        joblib.delayed(score_one)(state) for state in mutations
    )

    # Only the scores are needed; the per-patient gradients are discarded.
    return np.sum([score for _grad, score in grads_and_scores])

def main() -> None:
    """Benchmark joint vs. per-patient (joblib-parallel) scoring on random data.

    Generates a mutation matrix and a theta matrix from a fixed seed (42). It
    then times ``calculate_score_joint`` and ``calculate_score_individual`` on
    them and prints both scores and both timings.
    """
    parser = create_parser()
    args = parser.parse_args()

    n_patients = args.n_patients
    n_mutations = args.n_mutations

    # Fixed seed so that repeated runs benchmark the same data.
    rng = np.random.default_rng(42)

    print(f"Patients: {n_patients}\nMutations {n_mutations}")
    mutations = generate_mutations(
        rng,
        n_patients=n_patients,
        n_mutations=n_mutations,
        p=args.prob,
    )

    print(f"Maximal observed number of mutations: {np.max(np.sum(mutations, axis=1))}")

    theta = generate_theta(rng, n_mutations)

    # perf_counter is monotonic and high-resolution; time.time() is wall-clock
    # time that can jump, so it is the wrong clock for measuring durations.
    t0 = time.perf_counter()
    score_joint = calculate_score_joint(mutations, theta)
    dt_joint = time.perf_counter() - t0

    t0 = time.perf_counter()
    score_individual = calculate_score_individual(mutations, theta)
    dt_individual = time.perf_counter() - t0

    print(f"Score (joint): {score_joint:.3f}")
    print(f"Score (indiv): {score_individual:.3f}")

    print(f"Time (joint): {dt_joint:.3f}")
    print(f"Time (indiv): {dt_individual:.3f}")

if __name__ == "__main__":
    main()
StefanDevAccount commented 1 year ago

Hey, very cool results! Maybe a future version of the mhn package will also include a parallelized version of the CPU code; we'll see.