starsimhub / starsim

Starsim disease modeling framework
http://starsim.org
MIT License
15 stars 9 forks source link

Cannot run custom measles: distribution already initialised #468

Closed wagathu closed 7 months ago

wagathu commented 7 months ago

# -*- coding: utf-8 -*-
"""
Created on Tue Apr 16 16:25:15 2024

@author: macuser
"""
# Importing the Modules
import sciris as sc
import numpy as np
import starsim as ss
import matplotlib.pyplot as plt

# Creating the class measles
class Measles2(ss.SIR):

    def __init__(self, pars=None, par_dists=None, *args, **kwargs):
        """ Initialize with parameters """

        pars = ss.omergeleft(pars,
            # Natural history parameters, all specified in days
            dur_exp = 8,       # (days) - source: US CDCv3
            dur_inf = 11,      # (days) - source: US CDC
            p_death = 0.005,   # Probability of death

            # Initial conditions and beta
            init_prev = 0.005,
            beta = None,
        )

        par_dists = ss.omergeleft(par_dists,
            dur_exp   = ss.normal,
            dur_inf   = ss.normal,
            init_prev = ss.bernoulli,
            p_death   = ss.bernoulli,
        )

        super().__init__(pars=pars, par_dists=par_dists, *args, **kwargs)

        # SIR are added automatically, here we add E
        self.add_states(
            ss.State('exposed', bool, False),
            ss.State('ti_exposed', float, np.nan),
        )

        self.age_bins = [0, 1, 5]

        return

    @property
    def infectious(self):
        return self.infected | self.exposed

    def update_pre(self, sim):
        # Progress exposed -> infected
        infected = ss.true(self.exposed & (self.ti_infected <= sim.ti))
        self.exposed[infected] = False
        self.infected[infected] = True

        # Progress infected -> recovered
        recovered = ss.true(self.infected & (self.ti_recovered <= sim.ti))
        self.infected[recovered] = False
        self.recovered[recovered] = True

        # Trigger deaths
        deaths = ss.true(self.ti_dead <= sim.ti)
        if len(deaths):
            sim.people.request_death(deaths)
        return

    def get_age_bins(self, sim, uids):
        ages = sim.people.age[uids]
        age_bins = np.digitize(ages, self.age_bins)
        return age_bins

    def set_prognoses(self, sim, uids, source_uids=None):
        """ Set prognoses for those who get infected """
        # Do not call set_prognosis on parent
        # super().set_prognoses(sim, uids, source_uids)

        age_bins = self.get_age_bins(sim, uids)

        self.susceptible[uids] = False
        self.exposed[uids] = True
        self.ti_exposed[uids] = sim.ti

        p = self.pars

        # Determine when exposed become infected
        self.ti_infected[uids] = sim.ti + p.dur_exp.rvs(uids) / sim.dt

        # Sample duration of infection, being careful to only sample from the
        # distribution once per timestep.
        dur_inf = p.dur_inf.rvs(uids)

        # Determine who dies and who recovers and when
        will_die = p.p_death.rvs(uids)
        dead_uids = uids[will_die]
        rec_uids = uids[~will_die]
        self.ti_dead[dead_uids] = self.ti_infected[dead_uids] + dur_inf[will_die] / sim.dt
        self.ti_recovered[rec_uids] = self.ti_infected[rec_uids] + dur_inf[~will_die] / sim.dt

        return

    def update_death(self, sim, uids):
        # Reset infected/recovered flags for dead agents
        for state in ['susceptible', 'exposed', 'infected', 'recovered']:
            self.statesdict[state][uids] = False
        return

# Creating the class for the two doses
class measles_vaccine(ss.sir_vaccine):
    """
    Create a vaccine product that changes susceptible people to recovered (i.e., perfect immunity)
    """
    def administer(self, people, uids):
        people.measles2.rel_sus[uids] *= 1-self.pars.efficacy
        return

measles2 = Measles2(beta = .6)

# The parameters
pars = sc.objdict(
    n_agents = 5000,
    dt = 1/12,
    birth_rate = 27,
    death_rate = 8,
    networks = dict(
        type = 'randomnet',
        n_contacts = 10
        ),
    )

# The vaccines
my_vax1 = measles_vaccine(name='vax1', pars=dict(efficacy=0.85))
my_vax2 = measles_vaccine(name='vax2', pars=dict(efficacy=0.95))
my_vax3 = measles_vaccine(name='vax3', pars=dict(efficacy=0.9))

# The interventions from the vaccine
intv1 = ss.routine_vx(name='routine1', start_year=2010, prob= 0.95, product=my_vax1)
intv2 = ss.routine_vx(name='routine2', start_year=2013, prob= 0.95, product=my_vax2)
sia = ss.campaign_vx(name='SIA', years= [2015, 2017, 2019, 2021], prob=0.9, product=my_vax3) # The campaigns representing the SIAs
intv = [intv1, intv2, sia]

# The simulations
sim_base = ss.Sim(pars=pars, diseases=measles2, start = 2005, end = 2040)
sim_intv = ss.Sim(pars = pars, diseases=measles2, interventions = intv, start = 2005, end = 2040)
sim_base.run()
sim_intv.run()

# Plotting
plt.figure()
plt.plot(sim_base.yearvec, sim_base.results.measles2.prevalence, label='Baseline')
plt.plot(sim_intv.yearvec, sim_intv.results.measles2.prevalence, label='Vax')
plt.axvline(x=2013, color='k', ls='--')
plt.axvline(x=2010, color='k', ls='--')

plt.title('Prevalence')
plt.legend()
plt.show();

sim_base.plot()
robynstuart commented 7 months ago

@wagathu The problem is with these lines:

# The simulations
sim_base = ss.Sim(pars=pars, diseases=measles2, start = 2005, end = 2040)
sim_intv = ss.Sim(pars = pars, diseases=measles2, interventions = intv, start = 2005, end = 2040)
sim_base.run()
sim_intv.run()

It will work if you change to:

# The simulations
sim_base = ss.Sim(pars=pars, diseases=Measles2(beta = .6), start = 2005, end = 2040)
sim_intv = ss.Sim(pars = pars, diseases=Measles2(beta = .6), interventions = intv, start = 2005, end = 2040)
sim_base.run()
sim_intv.run()
cliffckerr commented 7 months ago

Related to #103