rmnldwg / lymph

Python package for statistical modelling of lymphatic metastatic spread in head & neck cancer.
https://lymph-model.readthedocs.io
MIT License

explosion of sampling time depending on training set size #89

Open · YoelPH opened this issue 1 week ago

YoelPH commented 1 week ago

I noticed an interesting issue: increasing the size of the training dataset results in a massive increase in computation time. Example code below:

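import numpy as np
import scipy as sp
import scipy.stats  # ensures sp.stats is available on older SciPy versions
from scipy.special import factorial  # vectorised, works on NumPy arrays
import emcee
from multiprocess import Pool  # dill-based fork of the stdlib multiprocessing
import lymph
import lymph.modalities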
number_of_pats = 300
graph = {
    ('tumor', 'primary')  : ['II', 'III', 'IV'], 
    ('lnl'  , 'II') :       ['III'], 
    ('lnl'  , 'III'):       ['IV'], 
    ('lnl'  , 'IV') :       []
}
model = lymph.models.Unilateral(
    graph_dict=graph, tumor_state=2, allowed_states=[0, 1, 2], max_time=15
)
model.set_modality(name='pathology', spec=1, sens=1, kind='pathological')
model.set_modality(name='diagnostic_consensus', spec=0.94, sens=1, kind='clinical')
# combined_dataset: the real patient DataFrame, loaded elsewhere (not shown in this issue)
model.load_patient_data(combined_dataset, side='ipsi')

def binom_pmf(k: np.ndarray, n: int, p: float):
    """Binomial PMF"""
    if p > 1. or p < 0.:
        raise ValueError("Binomial prob must be btw. 0 and 1")
    q = (1. - p)
    binom_coeff = factorial(n) / (factorial(k) * factorial(n - k))
    return binom_coeff * p**k * q**(n - k)

def late_binomial(support: np.ndarray, p: float = 0.5) -> np.ndarray:
    """Parametrized binomial distribution."""
    return binom_pmf(support, n=support[-1], p=p)

max_t = 15  # same value as max_time of the model above
early_prior = sp.stats.binom.pmf(np.arange(max_t + 1), max_t, 0.3)
model.set_distribution(t_stage='early', distribution=early_prior)
model.set_distribution(t_stage='late', distribution=late_binomial)

starting_points = {
    'growth': 0.5,
    'II_growth': 0.7,
    'primarytoII_spread': 0.24,
    'primarytoIII_spread': 0.03,
    'primarytoIV_spread': 0.2,
    'IItoIII_spread': 0.18,
    'IIItoIV_spread': 0.18,
    'late_p': 0.5,
}
model.set_params(**starting_points)
model.likelihood()

random_pats = model.draw_patients(number_of_pats, [1, 0], seed=13)
random_pats.replace('early', 1, inplace=True)

model.load_patient_data(random_pats, side='ipsi')
backend = emcee.backends.HDFBackend(filename="trinary", name='artificial')

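ndim = 8

# number of concurrent walkers that sample the space
nwalkers = 10 * ndim

# define the log-probability (here simply the log-likelihood); because
# parameter_names is passed to the sampler below, emcee hands theta to this
# function as a dict keyed by parameter name rather than as an array
def log_prob_fn(theta):
    return model.likelihood(given_params=theta, log=True)

# this chain will surely be too short, but it doesn't matter here
max_steps = 10000
backend.reset(nwalkers, ndim)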

# initialize the walkers at random positions in the unit hypercube
# (renamed from starting_points to avoid shadowing the parameter dict above)
initial_state = np.random.uniform(size=(nwalkers, ndim))

# use Pool() from multiprocess (a fork of multiprocessing) for parallelisation
with Pool() as pool:
    original_sampler = emcee.EnsembleSampler(
        nwalkers, ndim, log_prob_fn,
        pool=pool, backend=backend, parameter_names=[
            'growth',
            'II_growth',
            'primarytoII_spread',
            'primarytoIII_spread',
            'primarytoIV_spread',
            'IItoIII_spread',
            'IIItoIV_spread',
            'late_p',
        ],
    )
    original_sampler.run_mcmc(initial_state=initial_state, nsteps=max_steps, progress=True)

The provided code runs in roughly 15 min on my laptop. If I increase number_of_pats to 600, for example, the computation time explodes to about 2 h. Beyond some number of patients, the code seems to have a problem with saving earlier results.
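The reported timings can be turned into a rough scaling estimate. Assuming the runtime follows a power law t ∝ n^k in the number of patients n (an assumption; two rough measurements cannot confirm the form), going from 300 to 600 patients while the runtime grows from about 15 min to about 120 min gives k = log(8) / log(2) ≈ 3. Since walkers and steps are fixed, this is far steeper than the linear growth one might naively expect if each patient only added a fixed amount of work per likelihood call. The helper name scaling_exponent is hypothetical.

```python
import math


def scaling_exponent(n1: float, t1: float, n2: float, t2: float) -> float:
    """Exponent k in t ∝ n**k implied by two (size, runtime) measurements."""
    return math.log(t2 / t1) / math.log(n2 / n1)


# Figures reported in the issue: ~15 min for 300 patients, ~2 h (120 min) for 600.
k = scaling_exponent(300, 15, 600, 120)
print(f"apparent scaling exponent: {k:.2f}")  # prints: apparent scaling exponent: 3.00
```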

rmnldwg commented 6 days ago

OK, in my testing, any performance issues came down to multiprocess and the scope of the model variable. But I don't think that is the problem in your case, because model appears to be a global variable.
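One plausible mechanism behind the scope effect, sketched under assumptions: if the model is a module-level global that forked worker processes inherit, each task (or each chunk of tasks, depending on how the pool batches them) only has to ship the parameter vector. If the model is instead bound to the callable, for example passed as an argument or captured from a local scope, its state has to be serialised and sent along again. The sketch compares the two payload sizes with the stdlib pickle; model_state is a hypothetical stand-in, and the actual overhead also depends on the start method (fork vs. spawn) and on the serialiser (multiprocess uses dill).

```python
import pickle

import numpy as np

# Hypothetical stand-in for the model's internal state: one large cached array
# (1e6 float64 values, i.e. 8 MB). Real model objects may be smaller or larger.
model_state = {"cache": np.zeros(1_000_000)}

# One walker position, as emcee would send it to a worker (ndim = 8).
theta = np.random.default_rng(0).uniform(size=8)

# Model as a global inherited by forked workers: only theta must be serialised.
payload_global = pickle.dumps(theta)

# Model bound to the callable (passed as an argument or captured from a local
# scope): its state is serialised again together with the task.
payload_bound = pickle.dumps((theta, model_state))

print(f"per-task payload, global model: {len(payload_global):>10,} bytes")
print(f"per-task payload, bound model:  {len(payload_bound):>10,} bytes")
```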

Anyway, the number of patients did not seem to make any difference in my testing: whether I used 50 or 5000 patients, the 10000 steps would always take around half an hour (my office PC is slow).
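Before launching 10000 steps, it can help to time a single log-probability call at several dataset sizes, which is one way to reproduce a 50 vs. 5000 patients comparison quickly. The sketch below uses a hypothetical toy likelihood (make_toy_log_prob), not the lymph model, in which the patient data enter through one states-by-patients matrix; with the real model, one would load each dataset and time log_prob_fn instead.

```python
import timeit

import numpy as np


def make_toy_log_prob(num_patients: int, num_states: int = 27, seed: int = 0):
    """Hypothetical toy likelihood, NOT lymph's implementation.

    Patient data enter through one (num_states x num_patients) matrix, so the
    work per call grows linearly with the number of patients.
    """
    rng = np.random.default_rng(seed)
    matrix = rng.uniform(0.1, 1.0, size=(num_states, num_patients))

    def log_prob(theta: np.ndarray) -> float:
        logits = np.resize(theta, num_states)          # repeat theta to length num_states
        dist = np.exp(logits) / np.exp(logits).sum()   # a valid state distribution
        return float(np.sum(np.log(dist @ matrix)))    # one likelihood term per patient

    return log_prob


def seconds_per_call(fn, theta: np.ndarray, number: int = 200) -> float:
    """Average wall time of fn(theta) over `number` calls."""
    return timeit.timeit(lambda: fn(theta), number=number) / number


theta = np.full(8, 0.5)
timings = {n: seconds_per_call(make_toy_log_prob(n), theta) for n in (50, 500, 5000)}
for n, seconds in timings.items():
    print(f"{n:>5} patients: {seconds * 1e6:8.1f} µs per call")
```

Multiplying the per-call time by nwalkers * max_steps gives a rough lower bound on the sampling time, ignoring pool overhead.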

I've linked a slightly modified version of your script, adapted so that I could switch multiprocess and the global/local scope of the model on and off. Try to replicate your issue with that.

https://gist.github.com/rmnldwg/ea790ac9fa469a6cd51613c94aa005a9

rmnldwg commented 6 days ago

Oh, and caching does not seem to be an issue either. The cached data and diagnose matrix do not change at all as long as nobody changes the data or the modalities, so the cache limit is never hit.
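The caching argument can be illustrated with functools.lru_cache as a stand-in; lymph's own caching mechanism is not shown in this thread and may work differently. The cache key is built only from the data and the modalities, so as long as those stay fixed, every call after the first is a cache hit and the cache never grows toward its limit, no matter how many sampling steps run. The function diagnose_matrix and its arguments are hypothetical.

```python
import functools

import numpy as np


@functools.lru_cache(maxsize=4)
def diagnose_matrix(data_key: tuple, modality_key: tuple) -> np.ndarray:
    """Hypothetical stand-in for an expensive matrix built from data and modalities.

    The result depends only on its (hashable) arguments, so repeated calls with
    unchanged data and modalities are answered from the cache.
    """
    spec, sens = modality_key
    data = np.asarray(data_key, dtype=float)
    return np.outer(np.full(27, sens), data) + (1.0 - spec)


data_key = tuple(range(300))  # fixed patient data (hashable stand-in)
modality_key = (0.94, 1.0)    # fixed (specificity, sensitivity)

# During sampling only the model parameters change; data and modalities do not.
for _ in range(1000):
    matrix = diagnose_matrix(data_key, modality_key)

print(diagnose_matrix.cache_info())
# CacheInfo(hits=999, misses=1, maxsize=4, currsize=1)
```

Because the cached array is shared between all callers, code using it must not modify it in place.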

YoelPH commented 2 days ago

The problem has been solved :) The issue was the NumPy version. While lymph 1.2.2.dev0 still required numpy < 2.0, the newest version also allows numpy 2.x. With numpy 2.x, the matrix multiplication no longer seems to slow down beyond a specific size, so the explosion of computation time disappears.
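To check whether a given NumPy installation shows such a slowdown, one could time a matrix product at growing sizes and compare the results across NumPy versions, for example in two virtual environments, one with numpy < 2.0 and one with numpy 2.x. The shapes below are an assumption: a distribution over 27 hidden states (three LNLs with three allowed states each, 3**3 = 27) multiplied into a hypothetical states-by-patients matrix; lymph's actual internals may differ. np.show_config() additionally reports which BLAS library a NumPy build links against, which can differ between versions.

```python
import time

import numpy as np

print(f"NumPy {np.__version__}")  # the thread reports the slowdown with numpy < 2.0


def matmul_seconds(num_patients: int, num_states: int = 27,
                   repeats: int = 20, seed: int = 0) -> float:
    """Best-of-`repeats` time (s) for one (num_states,) @ (num_states, num_patients) product."""
    rng = np.random.default_rng(seed)
    dist = rng.dirichlet(np.ones(num_states))              # hypothetical state distribution
    matrix = rng.uniform(size=(num_states, num_patients))  # hypothetical per-patient matrix
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        dist @ matrix
        best = min(best, time.perf_counter() - start)
    return best


# Under linear scaling, doubling the patients should at most roughly double the
# time (fixed overhead dominates at small sizes); a much larger jump between
# neighbouring sizes would point to a slowdown beyond some size.
sizes = (300, 600, 1200, 2400, 4800)
timings = {n: matmul_seconds(n) for n in sizes}
for n, seconds in timings.items():
    print(f"{n:>5} patients: {seconds * 1e6:8.2f} µs")
```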