epiverse-trace / serofoi

Estimates the Force-of-Infection of a given pathogen from population based sero-prevalence studies
https://epiverse-trace.github.io/serofoi/

Simulate from age-varying FOI serocatalytic model #74

Closed: ekamau closed this issue 7 months ago.

ekamau commented 1 year ago

Add functionality to simulate from an age-varying FOI model:


$$
\begin{aligned}
\lambda(a) &= \lambda_{0} e^{-ra} \\
\frac{dS}{da} &= -\lambda(a) S + \mu Z \\
\frac{dZ}{da} &= \lambda(a) S - \mu Z \\
\frac{dZ}{da} &= \lambda(a)(1-Z) - \mu Z \qquad (S = 1 - Z) \\
Z(a) &= \exp\left(-\sum_{a'=0}^{a} [\mu + \lambda(a')]\right) \sum_{a''=0}^{a} \exp\left(\sum_{a'=0}^{a''} [\mu + \lambda(a')]\right) \lambda(a'')
\end{aligned}
$$

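With a binomial likelihood, where $X_i$ is the number seropositive out of $N_i$ tested in age group $i$:

$$ X_{i} \sim \mathrm{Binomial}(N_{i}, Z_{i}) $$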
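A minimal sketch of evaluating this likelihood in Python (a hypothetical port, since the thread itself uses R and Stan; the name binomial_log_lik is illustrative only), assuming the $Z_i$ have already been computed and that age groups are independent, so the log-likelihood is a sum over groups:

```python
import math


def binomial_log_lik(n_pos, n_total, z):
    """Sum over age groups of log Binomial(n_pos[i] | n_total[i], z[i])."""
    total = 0.0
    for x, n, p in zip(n_pos, n_total, z):
        # handle p = 0 or p = 1 explicitly to avoid log(0)
        if p <= 0.0:
            total += 0.0 if x == 0 else -math.inf
        elif p >= 1.0:
            total += 0.0 if x == n else -math.inf
        else:
            # log of the binomial coefficient C(n, x)
            log_coef = math.lgamma(n + 1) - math.lgamma(x + 1) - math.lgamma(n - x + 1)
            total += log_coef + x * math.log(p) + (n - x) * math.log1p(-p)
    return total


print(binomial_log_lik([3, 7], [10, 10], [0.2, 0.6]))
```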
ben18785 commented 1 year ago

After thinking about the above quite deeply, we realised it's not correct to discretise the numerator (the exponential term inside the second sum), because it varies within a year. Instead, we can integrate the ODE exactly within each one-year piece (where $\lambda$ is constant). Starting from $Z(t)$ at the beginning of a piece, this yields:

$$ Z(t+1)=\frac{e^{-(\lambda +\mu )} \left(\lambda \left(e^{\lambda +\mu }+Z(t)-1\right)+\mu Z(t)\right)}{\lambda +\mu } $$
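One way to see where this comes from: within a one-year piece the ODE is linear with constant coefficients, so $Z$ relaxes exponentially towards the equilibrium $\lambda/(\lambda+\mu)$ at rate $\lambda+\mu$. Writing that out for a step of one year recovers the update used in the R code below (with $\mu = 0$ it reduces to $Z(t+1) = 1 - (1 - Z(t))e^{-\lambda}$):

$$
\begin{aligned}
\frac{dZ}{da} &= \lambda - (\lambda+\mu) Z \\
Z(t+1) &= \frac{\lambda}{\lambda+\mu} + \left(Z(t) - \frac{\lambda}{\lambda+\mu}\right)e^{-(\lambda+\mu)} \\
&= \frac{e^{-(\lambda+\mu)}\left(\lambda\left(e^{\lambda+\mu}+Z(t)-1\right)+\mu Z(t)\right)}{\lambda+\mu}
\end{aligned}
$$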

ben18785 commented 1 year ago

The R code below simulates from this model and shows that the solution matches the exact analytical solution of the system when $\lambda$ is constant:

library(tidyverse)

# age-dependent FOIs
as <- 1:80

# seroreversion rate
mu <- 0.05

simulate_foi_age_exact <- function(a, fois, mu) {

  I <- 0
  # solves ODE exactly within pieces
  for(i in 1:a) {
    lambda <- fois[i]
    I <- (1 / (lambda + mu)) * exp(- (lambda + mu)) * (lambda * (exp(lambda + mu)  - 1) + I * (lambda + mu))
  }

  I
}

# check soln if FOI is constant (where an analytic solution is possible)
constant_solution <- function(a, foi, mu) {
  foi * (1 - exp(-a * (foi + mu))) / (foi + mu)
}

fois <- rep(0.02, 80)
prob_infected <- map_dbl(as, ~simulate_foi_age_exact(., fois, mu))
prob_infected_true <- map_dbl(as, ~constant_solution(., fois[1], mu))

tibble(a=as, approx=prob_infected, true=prob_infected_true) %>%
  pivot_longer(-a) %>%
  ggplot(aes(x=a, y=value, colour=name)) +
  geom_line()

# Try an exponentially declining FOI with mu = 0 (this permits an exact solution, but only if lambda varies continuously with age rather than by year)
lambda0 <- 0.1
r <- 0.1
fois <- lambda0 * exp(-r * as)
mu <- 0

exp_solution <- function(a, lambda0, r) {
  1 - exp(-lambda0 / r) * exp((exp(-a * r) * lambda0) / r)
}

prob_infected <- map_dbl(as, ~simulate_foi_age_exact(., fois, mu))

tibble(a=as, approx=prob_infected) %>%
  pivot_longer(-a) %>%
  ggplot(aes(x=a, y=value, colour=name)) +
  geom_line()
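For checking the same thing outside R, here is a minimal Python port of simulate_foi_age_exact and constant_solution (an illustrative translation, not code from serofoi). The one-year update is written in the algebraically equivalent form $\lambda/(\lambda+\mu) + (Z - \lambda/(\lambda+\mu))e^{-(\lambda+\mu)}$ and assumes $\lambda + \mu > 0$:

```python
import math


def simulate_foi_age_exact(age, fois, mu):
    """Probability seropositive at integer `age`, integrating
    dZ/da = lambda * (1 - Z) - mu * Z exactly over each one-year piece
    in which the FOI fois[i] is constant. Assumes fois[i] + mu > 0."""
    z = 0.0
    for i in range(age):
        lam = fois[i]
        k = lam + mu
        # one-year exact update: Z relaxes towards lam / k at rate k
        z = lam / k + (z - lam / k) * math.exp(-k)
    return z


def constant_solution(age, foi, mu):
    """Analytical solution for a constant FOI with Z(0) = 0."""
    return foi * (1 - math.exp(-age * (foi + mu))) / (foi + mu)


ages = range(1, 81)
mu = 0.05                # seroreversion rate, as in the R code
fois = [0.02] * 80       # constant FOI, as in fois <- rep(0.02, 80)

approx = [simulate_foi_age_exact(a, fois, mu) for a in ages]
exact = [constant_solution(a, fois[0], mu) for a in ages]
max_abs_diff = max(abs(x - y) for x, y in zip(approx, exact))
print(f"max |piecewise - analytical| = {max_abs_diff:.1e}")
```

With a constant FOI the piecewise recursion should reproduce the analytical curve up to floating-point rounding.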
ekamau commented 1 year ago

@ben18785 - in the R code above, where is the exp_solution function called? I have looked and can't see it.

ben18785 commented 1 year ago

Ah, good spot. No, I'm not calling it, since that analytical solution assumes the FOI varies continuously with age, whereas the simulation holds it constant within each year. So you can ignore that part!
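A small Python sketch of why exp_solution is not directly comparable (the helper names piecewise_mu0 and continuous_mu0 are hypothetical). With $\mu = 0$, the yearly simulation uses the FOI evaluated at the end of each year, $\lambda_0 e^{-r i}$, whereas exp_solution integrates $\lambda_0 e^{-ra}$ continuously. Because the FOI is decreasing, the continuous cumulative hazard is larger at every age, so the two curves differ:

```python
import math

lambda0, r = 0.1, 0.1
ages = range(1, 81)
# yearly FOIs, mirroring fois <- lambda0 * exp(-r * as) with as <- 1:80
fois = [lambda0 * math.exp(-r * a) for a in ages]


def piecewise_mu0(age):
    """Piecewise-constant FOI with mu = 0: each year multiplies the
    seronegative fraction by exp(-lambda_i), so Z = 1 - exp(-sum)."""
    return 1 - math.exp(-sum(fois[:age]))


def continuous_mu0(age):
    """exp_solution from the R code: lambda(a) = lambda0 * exp(-r * a)
    varying continuously with age, mu = 0."""
    return 1 - math.exp(-lambda0 / r) * math.exp(math.exp(-age * r) * lambda0 / r)


gaps = [continuous_mu0(a) - piecewise_mu0(a) for a in ages]
print(f"largest gap: {max(gaps):.4f}")
```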

ekamau commented 1 year ago

Hi @ben18785 - I've been going through the R code above. fois <- rep(0.02, 80) means the FOI is the same for all ages (assuming 80 age classes).

So prob_infected <- map_dbl(as, ~simulate_foi_age_exact(., fois, mu)) simulates an age-constant FOI model with seroreversion, and so does prob_infected_true <- map_dbl(as, ~constant_solution(., fois[1], mu)).

Wouldn't an age-varying FOI simulation use a different FOI for each age, e.g. fois <- rexp(80, 0.2)?

If so, what is the solution for Z(a), i.e. the solution of

$$ \frac{dZ}{da} = \lambda(a)(1-Z) - \mu Z $$
ben18785 commented 1 year ago

Hi @ekamau -- so, I do two things in the code above.

  1. I use a constant FOI because in that limit the model has an analytical solution, against which I can check my solution for more general FOIs.
  2. I try solving for an exponentially declining FOI via fois <- lambda0 * exp(-r * as) (note this is different from rexp, which draws random values from an exponential distribution).

The solution for $Z(a)$ is given by piecewise integration, which is handled in the function simulate_foi_age_exact.
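Unrolling the one-year update gives an explicit expression for $Z(a)$ with any sequence of yearly FOIs, writing $\lambda_i$ for the FOI during year $i$ (this follows algebraically from the recursion; it is not a formula stated in the thread):

$$ Z(a) = \sum_{i=1}^{a} \frac{\lambda_i}{\lambda_i+\mu}\left(1-e^{-(\lambda_i+\mu)}\right)\exp\left(-\sum_{j=i+1}^{a}(\lambda_j+\mu)\right) $$

A Python sketch (hypothetical function names) checking that it agrees with the recursion for arbitrary, randomly drawn FOIs:

```python
import math
import random


def recursive_z(fois, mu):
    """Z after len(fois) years via the exact one-year update."""
    z = 0.0
    for lam in fois:
        k = lam + mu
        z = lam / k + (z - lam / k) * math.exp(-k)
    return z


def unrolled_z(fois, mu):
    """Same quantity via the unrolled sum: year i's contribution is
    multiplied by exp(-(lambda_j + mu)) for every later year j."""
    total = 0.0
    for i, lam in enumerate(fois):
        k = lam + mu
        later = sum(lam_j + mu for lam_j in fois[i + 1:])
        total += lam / k * (1 - math.exp(-k)) * math.exp(-later)
    return total


random.seed(1)
mu = 0.05
# arbitrary age-varying FOIs (random, mean 0.05), purely for illustration
fois = [random.expovariate(20.0) for _ in range(80)]
for age in (10, 40, 80):
    print(age, recursive_z(fois[:age], mu), unrolled_z(fois[:age], mu))
```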

ekamau commented 1 year ago

Model without seroreversion. R code to simulate mock data and fit the model:

library(tidyverse)
library(rstan)
rstan_options(auto_write = TRUE)
options(mc.cores=4)

# generate synthetic data assuming we have data for all individuals aged between 1-10
ages <- seq(1, 10, 1)
sample_size <- 100
fois <- c(rep(0.01, 4), rep(0.05, 2), rep(0.1, 2), rep(0.03, 2))
mu <- 0 # no seroreversion

simulate_foi_age_exact <- function(a, fois, mu) {
  I <- 0
  # solves ODE exactly within pieces
  for(i in 1:a) {
    lambda <- fois[i]
    I <- (1 / (lambda + mu)) * exp(- (lambda + mu)) * (lambda * (exp(lambda + mu)  - 1) + I * (lambda + mu))
  }
  I
}

prob_infected_age <- map_dbl(ages, ~simulate_foi_age_exact(., fois, mu))
n_positive_age <- vector(length = length(ages))
for(i in seq_along(n_positive_age)) {
  n_positive_age[i] <- rbinom(1, size = sample_size, prob = prob_infected_age[i])
}

# data for stan model
data_stan <- list(
  N = length(n_positive_age),
  n_pos = n_positive_age,
  n_total = rep(sample_size, length(n_positive_age)),
  prior_choice = 3,
  is_binomial = TRUE,
  mu = mu,

  # these parameters have different meanings dependent on prior choice
  prior_a = 0,
  prior_b = 5
)

initfn <- function() {
  list(log_foi = rep(-8, length(ages)))
}

model_age <- stan_model("age_foi_no_seroreversion.stan")
fit_age <- optimizing(model_age, data = data_stan, init = initfn, as_vector = FALSE)
prob_age <- fit_age$par$prob_infected
fois_age <- fit_age$par$foi

fit_age2 <- sampling(model_age, data = data_stan, chains = 4, cores = 2, init = initfn, iter = 3000,
                warmup = 900, refresh = 0, seed = 345)

Stan:

// age-dependent FOI model - assuming a single cross-section

functions{
  real age_foi_calc(int N, vector foi, real mu){
    real prob = 0.0;
    for(j in 1:N){
      real lambda = foi[j];
      prob = (1 / (lambda + mu)) * exp(-(lambda + mu)) * (lambda * (exp(lambda + mu)  - 1) + prob * (lambda + mu));
    }
    return(prob);
  }

}

data{
  int<lower=0> N; // No. rows in data or no. age classes
  int n_pos[N]; // seropositive
  int n_total[N]; // tested

  // prior choices
  int<lower=1, upper=7> prior_choice;
  real<lower=0> prior_a; 
  real<lower=0> prior_b;

  real<lower=0> mu;

}

transformed data{
  int is_random_walk = prior_choice <= 4 ? 1 : 0;

}

parameters{
  row_vector[N] log_foi; 
  real<lower=0> sigma[is_random_walk ? 1 : 0]; 
  real<lower=0> nu[is_random_walk ? 1 : 0]; 

}

transformed parameters{
  real prob_infected[N];
  vector[N] foi;
  foi = to_vector(exp(log_foi));

  for(i in 1:N){
    prob_infected[i] = age_foi_calc(i, foi, mu);
  }

}

model{
  // likelihood
  n_pos ~ binomial(n_total, prob_infected) ; 

  // priors
  if(prior_choice == 1) { // forward random walk

    sigma ~ cauchy(0, 1);
    log_foi[1] ~ normal(prior_a, prior_b);

    for(i in 2:N)
      log_foi[i] ~ normal(log_foi[i - 1], sigma);

  } else if (prior_choice == 2) { // backward random walk

    sigma ~ cauchy(0, 1);
    log_foi[N] ~ normal(prior_a, prior_b);

    for(i in 1:(N - 1))
      log_foi[N - i] ~ normal(log_foi[N - i + 1], sigma);

  } else if(prior_choice == 3){ // forward random walk with Student-t

    sigma ~ cauchy(0, 1);
    nu ~ cauchy(0, 1);
    log_foi[1] ~ normal(prior_a, prior_b);

    for(i in 2:N)
      log_foi[i] ~ student_t(nu, log_foi[i - 1], sigma);

  } else if (prior_choice == 4) { // backward random walk with Student t

    sigma ~ cauchy(0, 1);
    nu ~ cauchy(0, 1);
    log_foi[N] ~ normal(prior_a, prior_b);

    for(i in 1:(N - 1))
      log_foi[N - i] ~ student_t(nu, log_foi[N - i + 1], sigma);

  } else if(prior_choice == 5){ // uniform

    foi ~ uniform(prior_a, prior_b);
    target += sum(log_foi); // log-Jacobian of foi = exp(log_foi)

  } else if(prior_choice == 6) { // weakly informative

    foi ~ cauchy(prior_a, prior_b);
    target += sum(log_foi); // log-Jacobian of foi = exp(log_foi)

  } else if(prior_choice == 7) { // Laplace (sparsity-inducing)

    foi ~ double_exponential(prior_a, prior_b);
    target += sum(log_foi); // log-Jacobian of foi = exp(log_foi)

  }

}

generated quantities {
  int pos_pred[N];
  pos_pred = binomial_rng(n_total, prob_infected);

}
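For the bounded priors (choices 5 to 7) the prior is stated on foi, while the sampled parameter is log_foi. For the prior to mean what it says, the change of variables foi = exp(log_foi) requires adding the log-Jacobian, log|d foi / d log_foi| = log_foi, to target. The hypothetical Python check below (not part of serofoi) integrates the implied density of log_foi for foi ~ uniform(0, b): adding log_foi gives a density that integrates to 1, whereas adding foi does not.

```python
import math


def log_density_log_foi(x, b, jacobian):
    """Log density of x = log_foi when foi = exp(x) ~ uniform(0, b),
    plus the chosen adjustment term ('log_foi' or 'foi')."""
    foi = math.exp(x)
    if foi >= b:
        return -math.inf                 # outside the uniform's support
    log_prior = -math.log(b)             # uniform(0, b) log density on foi
    if jacobian == "log_foi":
        return log_prior + x             # log |d foi / d log_foi| = log_foi
    return log_prior + foi               # adding foi instead


def integrate(b, jacobian, lo=-40.0, n=100_000):
    """Midpoint-rule integral of the density over log_foi in (lo, log b)."""
    hi = math.log(b)
    h = (hi - lo) / n
    return sum(math.exp(log_density_log_foi(lo + (i + 0.5) * h, b, jacobian)) * h
               for i in range(n))


print(integrate(5.0, "log_foi"))  # close to 1: a proper density
print(integrate(5.0, "foi"))      # grows as lo decreases: improper
```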
ekamau commented 1 year ago

@ben18785 - the code above works without errors or warnings as it is. I noticed in #69 that 'is_binomial = TRUE' is supplied in the Stan data list. Are we using this somewhere?

ekamau commented 1 year ago

Small update to the Stan code above (one FOI per age chunk, plus an optional seroreversion rate):

// age-dependent FOI model

functions{
  real age_foi_calc(int age, int[] chunks, vector foi, real mu){
    real prob = 0.0;
    for(j in 1:age){
      real lambda = foi[chunks[j]];
      prob = (1 / (lambda + mu)) * exp(-(lambda + mu)) * (lambda * (exp(lambda + mu)  - 1) + prob * (lambda + mu));
    }
    return prob;
  }

  real[] prob_infection_calc(int n_obs, int[] ages, int[] chunks, vector foi, real mu) {
    real prob_infected[n_obs];

    for(i in 1:n_obs){
      int age = ages[i];
      prob_infected[i] = age_foi_calc(age, chunks, foi, mu);
    }

    return prob_infected;
  }
}

data{
  int<lower=0> n_obs; // No. rows in data or no. age classes
  int n_pos[n_obs]; // seropositive
  int n_total[n_obs]; // tested
  int age_max;
  int chunks[age_max]; // vector of chunks of length age_max
  int ages[n_obs];

  // model type
  int<lower=0, upper=1> include_seroreversion;

  // prior choices
  int<lower=1, upper=6> prior_choice;
  real<lower=0> prior_a; 
  real<lower=0> prior_b;
}

transformed data{
  int is_random_walk = prior_choice <= 4 ? 1 : 0;
  int n_chunks = max(chunks); // number of FOI pieces (largest index in chunks)
}

parameters{
  row_vector[n_chunks] log_foi; // one log-FOI per chunk
  real<lower=0> sigma[is_random_walk ? 1 : 0]; // only used by random-walk priors
  real<lower=0> nu[is_random_walk ? 1 : 0]; // only used by random-walk priors
  real<lower=0> seroreversion_rate[include_seroreversion ? 1 : 0]; // rate of seroreversion
}

transformed parameters{
  real<lower=0> mu;
  real<lower=0> prob_infected[n_obs];
  vector<lower=0>[n_chunks] foi = to_vector(exp(log_foi));

  if(include_seroreversion){
    mu = seroreversion_rate[1];
  } else{
    mu = 0.0;
  }

  prob_infected = prob_infection_calc(n_obs, ages, chunks, foi, mu);
}

model{
  // likelihood
  n_pos ~ binomial(n_total, prob_infected);

  // priors
  if(include_seroreversion){
    seroreversion_rate ~ cauchy(0, 1);
  }

  if(prior_choice == 1) { // forward random walk

    sigma ~ cauchy(0, 1);
    log_foi[1] ~ normal(prior_a, prior_b);

    for(i in 2:n_chunks)
      log_foi[i] ~ normal(log_foi[i - 1], sigma);

  } else if (prior_choice == 2) { // backward random walk

    sigma ~ cauchy(0, 1);
    log_foi[n_chunks] ~ normal(prior_a, prior_b);

    for(i in 1:(n_chunks - 1))
      log_foi[n_chunks - i] ~ normal(log_foi[n_chunks - i + 1], sigma);

  } else if(prior_choice == 3){ // forward random walk with Student-t

    sigma ~ cauchy(0, 1);
    nu ~ cauchy(0, 1);
    log_foi[1] ~ normal(prior_a, prior_b);

    for(i in 2:n_chunks)
      log_foi[i] ~ student_t(nu, log_foi[i - 1], sigma);

  } else if (prior_choice == 4) { // backward random walk with Student t

    sigma ~ cauchy(0, 1);
    nu ~ cauchy(0, 1);
    log_foi[n_chunks] ~ normal(prior_a, prior_b);

    for(i in 1:(n_chunks - 1))
      log_foi[n_chunks - i] ~ student_t(nu, log_foi[n_chunks - i + 1], sigma);

  } else if(prior_choice == 5){ // uniform

    foi ~ uniform(prior_a, prior_b);
    target += sum(log_foi); // log-Jacobian of foi = exp(log_foi)

  } else if(prior_choice == 6) { // weakly informative

    foi ~ cauchy(prior_a, prior_b);
    target += sum(log_foi); // log-Jacobian of foi = exp(log_foi)

  }
}

generated quantities {
  int pos_pred[n_obs];
  pos_pred = binomial_rng(n_total, prob_infected);

}
ekamau commented 1 year ago

@ben18785 - Updated R and Stan code:

R code for simulating data and fitting the age-varying FOI model:

# age-FOI model: piece-wise lambda estimation 

library(tidyverse)
library(rstan)
rstan_options(auto_write = TRUE)
options(mc.cores = 4)

simulate_foi_age_exact <- function(age, fois, chunks, mu) {
  I <- 0
  # solves ODE exactly within pieces
  for(i in 1:age) {
    lambda <- fois[chunks[i]]
    I <- (1 / (lambda + mu)) * exp(- (lambda + mu)) * (lambda * (exp(lambda + mu)  - 1) + I * (lambda + mu))
  }
  I
}

ages <- seq(1, 40, 1)
sample_size <- 100
#fois <- c(rep(0.01, 10), rep(0.05, 10), rep(0.1, 10), rep(0.03, 10))
fois <- c(0.01, 0.05, 0.1, 0.03)
chunks <- unlist(map(seq(1, 4, 1), ~rep(., 10))) # how many FOIs to estimate
mu <- 0.0 # no seroreversion

# (1) Simulate data without seroreversion (mu = 0)
# generate synthetic data assuming we have data for all individuals aged 1-40

prob_infected_simulated <- map_dbl(ages, ~simulate_foi_age_exact(., fois, chunks, mu))
n_positive_age <- vector(length = length(ages))
for(i in seq_along(n_positive_age)) {
  n_positive_age[i] <- rbinom(1, size = sample_size, prob = prob_infected_simulated[i])
}

df <- data.frame(npos = n_positive_age, total = rep(sample_size, length(n_positive_age)))
# write.csv(df, "age_no_seroreversion_mock_data.csv", row.names = FALSE)

# check the stan functions:
expose_stan_functions("age_foi_seroreversion_13Aug23.stan")
prob_infected_stan <- prob_infection_calc(ages = ages, chunks = chunks, foi = fois, mu = mu, n_obs = length(ages))

isTRUE(all.equal(prob_infected_simulated, prob_infected_stan)) # compare up to floating-point tolerance

fig1 <- tibble(age=ages, R_function=prob_infected_simulated, stan_function=prob_infected_stan) %>%
  pivot_longer(-age) %>%
  ggplot(aes(x=age, y=value, colour=name)) +
  geom_line() + geom_point() + labs(y = "Probability infected") + guides(colour = "none") + 
  facet_wrap(.~ name)

ggsave(filename = "ageFOI_check_stan_fxn.png", fig1, width = 6, height = 4, dpi = 640)

#(2) Run stan model with simulated data:
# data for stan model
data_stan <- list(
  n_obs = length(n_positive_age),
  n_pos = n_positive_age,
  n_total = rep(sample_size, length(n_positive_age)),
  ages = ages,
  is_binomial = TRUE,
  include_seroreversion = 1,
  age_max = max(ages),
  chunks = chunks,

  # these parameters have different meanings dependent on prior choice
  foi_prior_choice = 3, # works w/o errors or warnings
  foi_prior_a = 0,
  foi_prior_b = 5,
  serorev_prior_choice = 1,
  serorev_prior_a = 0,
  serorev_prior_b = 1
)

model_age <- stan_model("age_foi_seroreversion_13Aug23.stan")
initfn <- function() { list(log_foi = rep(-8, max(chunks))) }
fit <- optimizing(model_age, data = data_stan, init = initfn, as_vector = FALSE)
prob <- fit$par$prob_infection
foi_optim <- fit$par$foi

fit_age <- sampling(model_age, data = data_stan, chains = 4, init = initfn, 
                     iter = 3000, warmup = 900, refresh = 0, seed = 345,
                     control = list(adapt_delta = 0.9999, max_treedepth = 25))

fit_age
s <- as.data.frame(summary(fit_age, probs = c(0.025, 0.975))$summary)
model_ppc <- tibble::rownames_to_column(s, "parameter") %>% filter(grepl("pos_pred", parameter))

fig2 <- tibble(age=ages, simulated=n_positive_age, model_predicted=model_ppc$mean) %>%
  pivot_longer(-age) %>%
  ggplot(aes(x=age, y=value, colour=name)) +
  geom_line() + geom_point() + labs(y = "Number positive", colour="")

ggsave(filename = "ageFOI_compare_outputs.png", fig2, width = 6, height = 4, dpi = 640)

Stan code:

// age-dependent FOI model - v1

functions{
  real age_foi_calc(int age, int[] chunks, vector foi, real mu){
    real prob = 0.0;
    for(j in 1:age){
      real lambda = foi[chunks[j]];
      prob = (1 / (lambda + mu)) * exp(-(lambda + mu)) * (lambda * (exp(lambda + mu)  - 1) + prob * (lambda + mu));
    }

    return prob;
  }

  real[] prob_infection_calc(int[] ages, int[] chunks, vector foi, real mu, int n_obs) {
    real prob_infected[n_obs];
    for(i in 1:n_obs){
      int age = ages[i];
      prob_infected[i] = age_foi_calc(age, chunks, foi, mu);
    }

    return prob_infected;
  }

}

data{
  int<lower=0> n_obs; // No. rows in data or no. age classes
  int n_pos[n_obs]; // seropositive
  int n_total[n_obs]; // tested
  int age_max;
  int chunks[age_max]; // vector of length age_max
  int ages[n_obs];

  // model type
  int<lower=0, upper=1> include_seroreversion;

  // prior choices
  int<lower=1, upper=6> foi_prior_choice;
  real<lower=0> foi_prior_a; 
  real<lower=0> foi_prior_b;

  int<lower=1, upper=3> serorev_prior_choice;
  real<lower=0> serorev_prior_a; 
  real<lower=0> serorev_prior_b;

}

transformed data{
  int is_random_walk = foi_prior_choice <= 4 ? 1 : 0;
  int n_chunks = max(chunks); // number of FOI pieces (largest index in chunks)

}

parameters{
  row_vector[n_chunks] log_foi; // one log-FOI per chunk
  real<lower=0> sigma[is_random_walk ? 1 : 0]; // only for R/W models
  real<lower=0> nu[is_random_walk ? 1 : 0]; // only for R/W models
  real<lower=0> seroreversion_rate[include_seroreversion ? 1 : 0]; // rate of seroreversion

}

transformed parameters{
  real<lower=0> mu;
  real<lower=0> prob_infection[n_obs];
  vector<lower=0>[n_chunks] foi = to_vector(exp(log_foi));

  if(include_seroreversion){
    mu = seroreversion_rate[1];
  } else{
    mu = 0.0;
  }

  prob_infection = prob_infection_calc(ages, chunks, foi, mu, n_obs);

}

model{
  // likelihood
  n_pos ~ binomial(n_total, prob_infection);

  // priors - seroreversion
  if(include_seroreversion) {
    if(serorev_prior_choice == 1) {
      seroreversion_rate ~ cauchy(serorev_prior_a, serorev_prior_b);

    } else if (serorev_prior_choice == 2) {
      seroreversion_rate ~ normal(serorev_prior_a, serorev_prior_b);

    } else if (serorev_prior_choice == 3) {
      seroreversion_rate ~ uniform(serorev_prior_a, serorev_prior_b);

    }
  }

  // priors - FOI
  if(foi_prior_choice == 1) { // forward random walk

    sigma ~ cauchy(0, 1);
    log_foi[1] ~ normal(foi_prior_a, foi_prior_b);

    for(i in 2:n_chunks)
      log_foi[i] ~ normal(log_foi[i - 1], sigma);

  } else if (foi_prior_choice == 2) { // backward random walk

    sigma ~ cauchy(0, 1);
    log_foi[n_chunks] ~ normal(foi_prior_a, foi_prior_b);

    for(i in 1:(n_chunks - 1))
      log_foi[n_chunks - i] ~ normal(log_foi[n_chunks - i + 1], sigma);

  } else if(foi_prior_choice == 3){ // forward random walk with Student-t

    sigma ~ cauchy(0, 1);
    nu ~ cauchy(0, 1);
    log_foi[1] ~ normal(foi_prior_a, foi_prior_b);

    for(i in 2:n_chunks)
      log_foi[i] ~ student_t(nu, log_foi[i - 1], sigma);

  } else if (foi_prior_choice == 4) { // backward random walk with Student t

    sigma ~ cauchy(0, 1);
    nu ~ cauchy(0, 1);
    log_foi[n_chunks] ~ normal(foi_prior_a, foi_prior_b);

    for(i in 1:(n_chunks - 1))
      log_foi[n_chunks - i] ~ student_t(nu, log_foi[n_chunks - i + 1], sigma);

  } else if(foi_prior_choice == 5){ // uniform

    foi ~ uniform(foi_prior_a, foi_prior_b);
    target += sum(log_foi); // log-Jacobian of foi = exp(log_foi)

  } else if(foi_prior_choice == 6) { // weakly informative

    foi ~ cauchy(foi_prior_a, foi_prior_b);
    target += sum(log_foi); // log-Jacobian of foi = exp(log_foi)

  }
}

generated quantities {
  int pos_pred[n_obs];
  pos_pred = binomial_rng(n_total, prob_infection);

}

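Figures: ageFOI_check_stan_fxn.png (probability infected by age from the R function vs. the exposed Stan function) and ageFOI_compare_outputs.png (simulated vs. model-predicted number positive by age).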

ekamau commented 1 year ago

Model-estimated FOI vs. true values (used in the data simulation):

model = c(0.01, 0.05, 0.14, 0.16)
true = c(0.01, 0.05, 0.1, 0.03)

Seroreversion rate: model estimate = 0.02 (true value 0; the data were simulated with mu = 0)
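One way to judge whether the mismatch in the last two chunks matters is to push both parameter sets through the same piecewise recursion and compare the implied seroprevalence by age. Note that the fit included a seroreversion rate (include_seroreversion = 1) even though the data were simulated with mu = 0. This is a hypothetical Python check, not part of the thread's scripts, using the rounded point estimates quoted above; it only shows how to make the comparison, and whether any differences matter depends on the sampling noise in the simulated counts.

```python
import math


def prob_seropositive(age, fois, chunks, mu):
    """Exact one-year updates; year i uses the FOI of its chunk, as in the
    R function simulate_foi_age_exact(age, fois, chunks, mu)."""
    z = 0.0
    for i in range(age):
        lam = fois[chunks[i] - 1]  # chunks holds 1-based indices, as in R
        k = lam + mu
        z = lam / k + (z - lam / k) * math.exp(-k)
    return z


chunks = [c for c in (1, 2, 3, 4) for _ in range(10)]  # same as chunks in the R script
ages = range(1, 41)
true_curve = [prob_seropositive(a, [0.01, 0.05, 0.1, 0.03], chunks, 0.0) for a in ages]
fitted_curve = [prob_seropositive(a, [0.01, 0.05, 0.14, 0.16], chunks, 0.02) for a in ages]
for a in (10, 20, 30, 40):
    print(a, round(true_curve[a - 1], 3), round(fitted_curve[a - 1], 3))
```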
ben18785 commented 1 year ago

Thanks @ekamau -- these look great. @davidsantiagoquevedo -- these are now ready for your review. Thanks!