starsimhub / starsim

Starsim disease modeling framework
http://starsim.org
MIT License

Divergence in performance with multi-RNG #285

Closed cliffckerr closed 8 months ago

cliffckerr commented 8 months ago

On main, this test script is comparably fast with and without multi-RNG (except for the first run when it recompiles the numba functions):

import starsim as ss
import sciris as sc
import scipy.stats as sps
import warnings

n_agents = 20_000

if sc.compareversions(ss, '0.2.0') >= 0:
    Deaths = ss.Deaths
    RandNet = ss.RandomNet
else:
    Deaths = ss.background_deaths
    RandNet = ss.RandomNetwork

def make_sim_pars():
    pars = dict(
        n_agents = n_agents,
        birth_rate = 20,
        death_rate = 0.015,
        networks = dict(
            type = 'randomnet',
            n_contacts = 4  # sps.poisson(mu=4),
        ),
        diseases = dict(
            type = 'sir',
            dur_inf = 10,
            beta = 0.1,
        ),
        verbose = 0,
    )
    return pars

def test_sir_epi(i, multirng):
    sc.printgreen(f'Test with i={i+1}, multirng={multirng}, starsim={ss.__version__}')

    pars0 = make_sim_pars()
    dis = pars0.pop('diseases')
    net = pars0.pop('networks')
    net['n_contacts'] = sps.poisson(mu=net['n_contacts'])
    dis.pop('type')
    sir = ss.SIR(**dis)
    net = RandNet(**net)
    dem = ss.Pregnancy(pars=dict(birth_rate=pars0.pop('birth_rate')))
    dea = Deaths(pars=dict(death_rate=pars0.pop('death_rate')))

    # Run the simulations and pull out the results
    s0 = ss.Sim(pars0, diseases=sir, demographics=[dem, dea], networks=net).run()
    # print(s0.summarize()) # for debugging

    return s0

if __name__ == '__main__':

    with warnings.catch_warnings(action="ignore"): # requires python >= 3.11
        for i,multirng in enumerate([True, True, True, False, False, False]):
            T = sc.timer()
            ss.options(multirng=multirng)
            s0 = test_sir_epi(i, multirng)
            T.toc()

Running this on main gives:

STIsim 0.1.8 (2024-01-30) — © 2024 by IDM
Test with i=1, multirng=True, starsim=0.1.8
Elapsed time: 1.62 s
Test with i=2, multirng=True, starsim=0.1.8
Elapsed time: 0.325 s
Test with i=3, multirng=True, starsim=0.1.8
Elapsed time: 0.415 s
Test with i=4, multirng=False, starsim=0.1.8
Elapsed time: 0.240 s
Test with i=5, multirng=False, starsim=0.1.8
Elapsed time: 0.375 s
Test with i=6, multirng=False, starsim=0.1.8
Elapsed time: 0.350 s

But on the apis branch, multirng=False has gotten 30% faster (yay!), and multirng=True has gotten 7x slower (boo!):

Starsim 0.2.0 (2024-02-15) — © 2023-2024 by IDM
Test with i=1, multirng=True, starsim=0.2.0
Elapsed time: 2.62 s
Test with i=2, multirng=True, starsim=0.2.0
Elapsed time: 2.57 s
Test with i=3, multirng=True, starsim=0.2.0
Elapsed time: 2.54 s
Test with i=4, multirng=False, starsim=0.2.0
Elapsed time: 0.277 s
Test with i=5, multirng=False, starsim=0.2.0
Elapsed time: 0.273 s
Test with i=6, multirng=False, starsim=0.2.0
Elapsed time: 0.239 s
cliffckerr commented 8 months ago

Ah, it turns out that on main it's not even using the multi-RNG make_new_cases, which is why it's fast. Here's another test script, similar to test_baselines.py, that includes an MSM network so that node degree is >1. It fails on main (the MSM network fails to update for some reason), but on apis it illustrates the speed difference (also ~10x).

import sciris as sc
import starsim as ss

pars = sc.objdict(
    start         = 2000,       # Starting year
    n_years       = 20,         # Number of years to simulate
    dt            = 0.2,        # Timestep
    verbose       = 0,          # Don't print details of the run
    rand_seed     = 2,          # Set a non-default seed
)

def make_sim():
    n_agents = int(10e3)
    networks = [ss.MFNet(), ss.MSMNet(pars=dict(part_rates=0.5))]
    ppl = ss.People(n_agents)

    hiv = ss.HIV()
    hiv.pars['beta'] = {'mf': [0.15, 0.10], 'msm': [0.15, 0.15]}
    sim = ss.Sim(pars=pars, people=ppl, networks=networks, demographics=ss.Pregnancy(), diseases=hiv)
    sim.run()

    return sim

if __name__ == '__main__':

    T = sc.timer()
    ss.options(multirng=False)
    sim = make_sim()
    T.toc()
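For reference, the line-by-line timings below are line_profiler output. One way to generate that kind of profile is via sciris' wrapper; a minimal sketch, reusing the make_sim above and guessing at the transmission function to follow (both illustrative, and version-dependent):

import sciris as sc
import starsim as ss

# Hypothetical profiling call: line-profile the transmission logic while
# running the simulation above (the followed function is a guess)
sc.profile(run=make_sim, follow=ss.HIV.make_new_cases)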

The slowness is entirely due to the pandas operations, specifically the groupby/apply:

Line #      Hits         Time  Per Hit   % Time  Line Contents
   332       105   18127919.0 172646.8      0.3          df = pd.DataFrame({'p1': dfp1, 'p2': dfp2, 'p': dfp})
   333       105     158623.0   1510.7      0.0          if len(df) == 0:
   334                                                       return np.empty((0,), dtype=int), np.empty((0,), dtype=int)
   335                                           
   336       105 5202520344.0    5e+07     91.2          p_acq_node = df.groupby('p2').apply(lambda x: 1 - np.prod(1 - x['p']))  # prob(inf) for each potential infectee
   337       105     170461.0   1623.4      0.0          uids = p_acq_node.index.values  # UIDs of those who come into contact with 1 or more infected people
   338                                           
   339                                                   # Slotted draw, need to find a long-term place for this logic
   340       105    2151946.0  20494.7      0.0          slots = people.slot[uids]  # Slots for the possible infectee
   341       105   15729519.0 149804.9      0.3          new_cases_bool = ss.uniform.rvs(size=np.max(slots) + 1)[slots] < p_acq_node.values
   342       105     402472.0   3833.1      0.0          new_cases = uids[new_cases_bool]
   343                                           
   344                                                   # Now choose infection source for new cases
   345       105      36159.0    344.4      0.0          def choose_source(df):
   346                                                       if len(df) == 1:  # Easy if only one possible source
   347                                                           src_idx = 0
   348                                                       else:
   349                                                           # Roulette selection using slotted draw r associated with this new case
   350                                                           cumsum = df['p'].cumsum() / df['p'].sum()
   351                                                           src_idx = np.argmax(cumsum >= df['r'])
   352                                                       return df['p1'].iloc[src_idx]
   353                                           
   354       105   38309361.0 364851.1      0.7          df['r'] = ss.uniform.rvs(size=np.max(slots) + 1)[slots[df.p2.values]]  # Draws for each potential infectee
   355       105  353537793.0    3e+06      6.2          sources = df.set_index('p2').loc[new_cases].groupby('p2').apply(choose_source)
   356                                           
   357       105   24881995.0 236971.4      0.4          return new_cases, sources[new_cases].values
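For reference, the 91% groupby/apply is a grouped reduction that can be expressed without pandas. A minimal pure-numpy sketch, assuming p2 and p are plain 1-D arrays (per-edge potential infectee and per-edge transmission probability) with no p exactly equal to 1, using the identity 1 - prod(1 - p) = 1 - exp(sum(log(1 - p))):

import numpy as np

def p_acq_per_node(p2, p):
    # Sort edges so that each potential infectee's edges are contiguous
    order = np.argsort(p2, kind='stable')
    p2s, ps = p2[order], p[order]
    uids, start = np.unique(p2s, return_index=True)  # group start indices
    # Sum log(1 - p) within each group, then convert back to 1 - prod(1 - p)
    log_surv = np.add.reduceat(np.log1p(-ps), start)
    return uids, 1.0 - np.exp(log_surv)

The "slotted draw" on lines 340-341 then indexes a single array of uniforms by each agent's slot, which appears to be what keeps results reproducible under multi-RNG regardless of which agents are present.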
daniel-klein commented 8 months ago

That groupby/apply is doing a lot of the work to compute the probability that each node acquires a case of each disease. This pandas-based approach is more performant and memory-efficient than the previous numpy-based matrix approach, but there are still opportunities for speedups. Cython is an option if need be, but hopefully numba or one of the other JIT/GPU approaches, and/or a pure-numpy approach that uses fewer than (nodes) x (edges) entries, will suffice; one possible shape is sketched below.
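As a sketch of the numba route, here is the same per-node reduction written as an njit loop, assuming the edge arrays arrive pre-sorted by p2 (names are illustrative, not starsim API):

import numba as nb
import numpy as np

@nb.njit(cache=True)
def node_acquisition(p2, p):
    uids = np.empty(len(p2), dtype=np.int64)
    probs = np.empty(len(p2), dtype=np.float64)
    k = -1       # index of the current group in the output arrays
    surv = 1.0   # running product of (1 - p) within the current group
    for i in range(len(p2)):
        if k < 0 or p2[i] != uids[k]:
            if k >= 0:
                probs[k] = 1.0 - surv  # close out the previous group
            k += 1
            uids[k] = p2[i]
            surv = 1.0
        surv *= 1.0 - p[i]
    if k >= 0:
        probs[k] = 1.0 - surv
    return uids[:k + 1], probs[:k + 1]

This needs only O(edges) memory, avoiding the (nodes) x (edges) blowup of the old matrix approach.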

daniel-klein commented 8 months ago

Keen to try with #337 in place!

kaa2102 commented 8 months ago

I am evaluating the time and space/memory benefits of using Polars instead of pandas (https://github.com/amath-idm/starsim/issues/196). I'm also looking at what/where the code could be refactored with Polars; a rough sketch of what the hotspot would look like is below.
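For comparison, the groupby/apply hotspot translates fairly directly to Polars. A sketch with toy data, assuming a recent Polars where the grouping method is group_by:

import polars as pl

# Hypothetical data: p2 = potential infectee per edge, p = per-edge
# transmission probability
df = pl.DataFrame({'p2': [0, 0, 1, 2, 2], 'p': [0.1, 0.2, 0.3, 0.1, 0.4]})
p_acq_node = (
    df.group_by('p2', maintain_order=True)
      .agg((1.0 - (1.0 - pl.col('p')).product()).alias('p_acq'))
)
print(p_acq_node)

Unlike pandas' apply here, this stays inside Polars' expression engine, so the per-group product runs in compiled code rather than a Python lambda.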

cliffckerr commented 8 months ago

Fixed