Hi, thanks a lot for a great and scalable implementation!
I don't have an actual issue, just a note for people who have access only to CPUs, but need to process a lot of patients: the implementation is very nice and compatible with joblib, so that a simple parallelization trick could help to speed up the computations:
$ python benchmark.py --n-mutations 30 --n-patients 1000
Patients: 1000
Mutations 30
Maximal observed number of mutations: 18
Score (joint): -34594.397
Score (indiv): -34594.397
Time (joint): 3.193
Time (indiv): 1.234
However, for a small number of patients, the function provided is unrivaled in performance by a large margin:
$ python benchmark.py --n-mutations 30 --n-patients 10
Patients: 10
Mutations 30
Maximal observed number of mutations: 12
Score (joint): -344.539
Score (indiv): -344.539
Time (joint): 0.007
Time (indiv): 0.378
Benchmark script
The script benchmark.py is the following:
import argparse
import joblib
import time
import mhn
import numpy as np
def generate_mutations(rng, n_patients: int, n_mutations: int, p: float) -> np.ndarray:
    """Simulate a binary mutation matrix for a naive synthetic data set.

    Each entry is an independent Bernoulli(p) draw: 1 means the gene is
    mutated in that patient, 0 means it is not.

    Args:
        rng: NumPy random generator used for the draws.
        n_patients: number of rows (patients).
        n_mutations: number of columns (genes).
        p: per-entry mutation probability.

    Returns:
        int32 array of 0/1 values with shape (n_patients, n_mutations).
    """
    draws = rng.binomial(1, p, size=(n_patients, n_mutations))
    return draws.astype(np.int32)
def generate_theta(rng, n_mutations: int) -> np.ndarray:
    """Draw a random square (log-)theta matrix with standard-normal entries.

    Args:
        rng: NumPy random generator used for the draws.
        n_mutations: side length of the square matrix.

    Returns:
        float64 array with shape (n_mutations, n_mutations).
    """
    samples = rng.normal(size=(n_mutations, n_mutations))
    return samples.astype(np.float64)
def create_parser() -> argparse.ArgumentParser:
    """Build the CLI parser for the benchmark script.

    Returns:
        Parser with --n-patients, --n-mutations and --prob options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--n-patients", type=int, default=30, help="Number of patients")
    parser.add_argument("--n-mutations", type=int, default=20, help="Number of genes in which mutations are measured.")
    # BUGFIX: --prob is a probability and must be parsed as float; with
    # type=int, any fractional value on the CLI (e.g. --prob 0.3) raised
    # a parse error even though the default itself was a float.
    parser.add_argument("--prob", type=float, default=0.3, help="Probability that a gene is mutated (to simulate a naive data set).")
    return parser
def calculate_score_joint(mutations, theta) -> float:
    """Score the whole cohort in one call to the cython routine.

    Args:
        mutations: binary mutation matrix, shape (n_patients, n_mutations).
        theta: (log-)theta matrix, shape (n_mutations, n_mutations).

    Returns:
        Total log-likelihood over all patients.
    """
    cohort = mhn.ssr.state_containers.StateContainer(mutations)
    _grad, mean_score = mhn.ssr.state_space_restriction.cython_gradient_and_score(theta, cohort)
    # The cython routine reports the log-likelihood averaged over patients,
    # so scale back up to the cohort total.
    return len(mutations) * mean_score
def _get_function_for_theta(theta):
    """Bind *theta* into a closure that scores one patient at a time.

    Args:
        theta: (log-)theta matrix shared by all per-patient calls.

    Returns:
        Callable taking a single patient's mutation vector and returning
        the (gradient, score) pair from the cython routine.
    """
    def score_single_patient(mutation_vector):
        # Reshape the 1-D patient vector into a one-row matrix so it can
        # be wrapped in a StateContainer.
        single = mhn.ssr.state_containers.StateContainer(mutation_vector.reshape((1, -1)))
        return mhn.ssr.state_space_restriction.cython_gradient_and_score(theta, single)
    return score_single_patient
def calculate_score_individual(mutations, theta) -> float:
    """Score each patient independently, in parallel, and sum the results.

    Each patient row is dispatched to a joblib worker; the per-patient
    log-likelihood scores are then summed to the cohort total. The
    gradients returned by the cython routine are discarded here.

    Args:
        mutations: binary mutation matrix, shape (n_patients, n_mutations).
        theta: (log-)theta matrix, shape (n_mutations, n_mutations).

    Returns:
        Total log-likelihood over all patients.
    """
    score_one = _get_function_for_theta(theta)
    grads_and_scores = joblib.Parallel(n_jobs=4)(
        joblib.delayed(score_one)(state)
        for state in mutations
    )
    # Each result is a (gradient, score) pair; only the scores are needed,
    # so avoid materializing an unused list of gradients.
    return np.sum([score for _grad, score in grads_and_scores])
def main() -> None:
    """Entry point: benchmark joint vs. per-patient (parallel) scoring."""
    args = create_parser().parse_args()
    n_patients = args.n_patients
    n_mutations = args.n_mutations

    # Fixed seed so both scoring paths see the same synthetic cohort.
    rng = np.random.default_rng(42)
    print(f"Patients: {n_patients}\nMutations {n_mutations}")

    mutations = generate_mutations(
        rng,
        n_patients=n_patients,
        n_mutations=n_mutations,
        p=args.prob,
    )
    print(f"Maximal observed number of mutations: {np.max(np.sum(mutations, axis=1))}")
    theta = generate_theta(rng, n_mutations)

    def timed(scorer):
        # Wall-clock a single scoring call; returns (score, elapsed seconds).
        start = time.time()
        value = scorer(mutations, theta)
        return value, time.time() - start

    score_joint, dt_joint = timed(calculate_score_joint)
    score_individual, dt_individual = timed(calculate_score_individual)

    print(f"Score (joint): {score_joint:.3f}")
    print(f"Score (indiv): {score_individual:.3f}")
    print(f"Time (joint): {dt_joint:.3f}")
    print(f"Time (indiv): {dt_individual:.3f}")
# Run the benchmark only when executed as a script, not on import.
if __name__ == "__main__":
    main()
Hi, thanks a lot for a great and scalable implementation!
I don't have an actual issue, just a note for people who have access only to CPUs, but need to process a lot of patients: the implementation is very nice and compatible with joblib, so that a simple parallelization trick could help to speed up the computations:
However, for a small number of patients, the function provided is unrivaled in performance by a large margin:
Benchmark script
The script
benchmark.py
is the following: