fintzij / stemr

Fit Stochastic Epidemic Models via Bayesian Data Augmentation
https://fintzij.github.io/stemr/

list of multiple datasets in `stem_measure` causes error #12

Closed. chvandorp closed this issue 4 years ago.

chvandorp commented 4 years ago

Hi Jon,

Found another small bug. I want to use multiple datasets with different sampling intervals, so I passed a list of data frames to the `data` argument of `stem_measure`, as described in the documentation. However, this leads to an indexing error. I think the cause is the argument validity check at the beginning of `stem_measure`. When I comment out the following lines, the list of datasets is accepted without error:

## lines 68-70 of stem_measure.R
if(max(data[,1]) != dynamics$tmax) { ## if data is a list, then the indexing causes an error
    warning("The maximum observation time in the data object is not equal to the maximum observation time in the emission lists.")
}

Hope this helps! Many thanks,

Chris
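The failure is easy to reproduce outside stemr. `data[, 1]` assumes a two-dimensional object, and indexing a plain list that way fails with an "incorrect number of dimensions" error. Below is a minimal sketch of a list-aware version of the same check. `max_obs_time` is a hypothetical helper written for illustration. It is not the fix that was pushed to stemr, and it assumes that the first column of each dataset holds the observation times, as in the snippet above.

```r
# Hypothetical helper (not part of stemr): largest observation time in `data`,
# which may be a single matrix / data frame or a list of them.
# Assumes the first column of every dataset holds the observation times.
max_obs_time <- function(data) {
  # Check for a single dataset first: a data frame is also a list in R
  if (is.data.frame(data) || is.matrix(data)) {
    return(max(data[, 1]))
  }
  # Otherwise take the maximum over the time column of each dataset
  max(vapply(data, function(d) max(d[, 1]), numeric(1)))
}

tmax   <- 35
cases  <- data.frame(time = seq(7, tmax, 7), cases = 0)   # weekly times
deaths <- data.frame(time = seq(1, tmax), deaths = 0)     # daily times

# Same warning condition as lines 68-70, but safe when `data` is a list
if (max_obs_time(list(cases, deaths)) != tmax) {
  warning("The maximum observation time in the data object is not equal to the maximum observation time in the emission lists.")
}
```

Handling a single data frame or matrix first matters. `is.list()` is also `TRUE` for data frames, so a plain "is it a list?" test would send a single dataset down the wrong branch.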

fintzij commented 4 years ago

Thanks, good catch Chris! This is one of those features that I built in but never really used. Could I trouble you to post a minimal working example so I can debug it? I also want to double-check that there aren't any downstream errors once the datasets are stitched together.

Cheers, Jon
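The thread does not say how stemr stitches multiple datasets together internally. The sketch below is only an assumption made for illustration. It aligns a weekly and a daily series on a shared time column with base R's `merge()`, so a variable that was not observed at a given time becomes `NA`. In that representation, any downstream code would need to handle those missing entries. The data values are placeholders.

```r
# Weekly case counts and daily death counts (placeholder values, not real data)
cases  <- data.frame(time = seq(7, 35, 7), cases = 0)
deaths <- data.frame(time = seq(1, 35), deaths = 0)

# Outer join on time: every observation time from either dataset is kept,
# and a variable that was not measured at that time is filled with NA
combined <- merge(cases, deaths, by = "time", all = TRUE)

head(combined, 8)   # cases is NA on days 1-6 and observed on day 7
```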

chvandorp commented 4 years ago

No problem, see the script below. Not sure if it is minimal, but it is an example. I did get this to work when I commented out lines 68-70 of stem_measure.R (although I initially ran into another problem with prevalence-based observations, but I'll bother you with that later).

library(stemr)
library(ggplot2)
library(gridExtra)

true_pars <- c(beta=1, gamma=0.5, delta=0.01, theta=0.2)

compartments <- c("S", "I", "R")

rates <- list(
  rate(rate="beta * I/(S+I+R)",
       from="S",
       to="I",
       incidence=TRUE),
  rate(rate="gamma",
       from="I",
       to="R",
       incidence=TRUE)
)

N <- 1e5 ## popsize
epsilon <- 1e-4 ## inoculum (fraction of popsize)

state_initializer <- list(
  stem_initializer(
    init_states = c(
      S = N * (1-epsilon), 
      I = N * epsilon,
      R = 0
    ),
    fixed = TRUE)
  )

# only estimate beta, gamma and delta
par_names <- c("beta", "gamma", "delta")
constants <- c(t0=0, true_pars[!names(true_pars) %in% par_names])
params <- true_pars[par_names]

tmax <- 35

# compile the model
dynamics <- stem_dynamics(
  rates = rates,
  parameters = params,
  tmax = tmax,
  state_initializer = state_initializer,
  compartments = compartments,
  constants = constants,
  compile_ode = TRUE,  
  compile_rates = TRUE, 
  compile_lna = TRUE,
  messages = TRUE
)

obstimes.cases <- seq(7,tmax,7) ## weekly case counts
obstimes.deaths <- seq(1,tmax) ## daily death counts

emissions <- list(
  emission(meas_var = "cases",
           distribution    = "poisson",
           emission_params = c("S2I * theta"),
           incidence       = TRUE,
           obstimes        = obstimes.cases),
  emission(meas_var = "deaths",
           distribution    = "poisson",
           emission_params = c("I2R * delta"),
           incidence       = TRUE,
           obstimes        = obstimes.deaths)
)

measurement_process <- stem_measure(
  emissions = emissions,
  dynamics  = dynamics,
  messages  = TRUE
)

stem_object <- make_stem(
  dynamics = dynamics,
  measurement_process = measurement_process
)

########### simulate data using the MJP ############

sim_mjp <- simulate_stem(
  stem_object=stem_object, 
  method="gillespie", 
  messages=TRUE
)

sim_data <- as.data.frame(sim_mjp$datasets[[1]])
## select data using the obstimes
sim_death_rows <- sim_data[,"time"] %in% obstimes.deaths
sim_death_data <- sim_data[sim_death_rows,]

sim_case_rows <- sim_data[,"time"] %in% obstimes.cases
sim_case_data <- sim_data[sim_case_rows,]

plot.cases <- ggplot(data=sim_case_data, aes(x=time, y=cases)) + geom_point()
plot.deaths <- ggplot(data=sim_death_data, aes(x=time, y=deaths)) + geom_point()

grid.arrange(plot.cases, plot.deaths)

####### create a new measurement process using the simulated data #######

data.list <- list(
  sim_case_data[c("time", "cases")],
  sim_death_data[c("time", "deaths")]
)

measurement_process <- stem_measure(
  emissions = emissions,
  dynamics  = dynamics,
  data      = data.list, ### ERROR!!!
  messages  = TRUE
)

stem_object <- make_stem(
  dynamics = dynamics,
  measurement_process = measurement_process
)

############## fit model to simulated data #############

to_estimation_scale <- function(params_nat) {
  c(log(params_nat[1:2]), logit(params_nat[3]))
}

from_estimation_scale <- function(params_est) {
  c(exp(params_est[1:2]), expit(params_est[3]))
}

priors <- list(
  logprior = function(params_est){return(0.0)}, ## flat prior
  to_estimation_scale = to_estimation_scale,
  from_estimation_scale = from_estimation_scale
)

mcmc_kern <- mcmc_kernel(
  parameter_blocks = list(
    parblock(
      pars_nat = par_names,
      pars_est = c("log_beta", "log_gamma", "logit_delta"),
      priors = priors,
      alg = "mvnmh",
      sigma = diag(0.01, length(par_names)),
      control = mvnmh_control(stop_adaptation = 2.5e2)
    )
  ), 
  lna_ess_control = lna_control(bracket_update_iter = 50)
)

res <- fit_stem(
  stem_object = stem_object,
  method = "lna",
  mcmc_kern = mcmc_kern,
  thinning_interval = 10,
  iterations = 2e3,
  print_progress = 10
)

posterior <- res$results$posterior # list with posterior objects

############ plot trajectories ###############

num_samples <- dim(posterior$latent_paths)[3]

cases.plot <- ggplot(data=sim_case_data, aes(x=time, y=cases)) +
  geom_point()

for ( i in seq(num_samples/2, num_samples, 10) ) {
  traj <- as.data.frame(posterior$latent_paths[,,i])
  traj$expect_cases <- traj$S2I * true_pars["theta"]
  cases.plot <- cases.plot + geom_line(data=traj, aes(x=time, y=expect_cases), 
                                       alpha=0.2, color='green')
}

deaths.plot <- ggplot(data=sim_death_data, aes(x=time, y=deaths)) +
  geom_point()

for ( i in seq(num_samples/2, num_samples, 10) ) {
  traj <- as.data.frame(posterior$latent_paths[,,i])
  traj$expect_deaths <- traj$I2R * true_pars["delta"]
  deaths.plot <- deaths.plot + geom_line(data=traj, aes(x=time, y=expect_deaths), 
                                       alpha=0.2, color='red')
}

grid.arrange(cases.plot, deaths.plot)

fintzij commented 4 years ago

Hmm, I'm having trouble simulating data with this code. For some reason, my R session crashes during the Gillespie simulation. I assume this doesn't happen for you?

chvandorp commented 4 years ago

No, it works fine for me. I'm using R 4.0.2 with gcc 9.3.0 on Ubuntu 20.04.1. The result should look something like this:

[Image: Rplot, the simulated weekly case counts and daily death counts from the script]

fintzij commented 4 years ago

Hi Chris,

Just pushed the correction, along with a fix for another bug I hadn't noticed before: constants were not parsed correctly when the initial compartment counts are fixed. You should now be able to pass the population size as a constant and use it in the rates. For example, use rate="beta * I/N" with N passed as a constant, instead of rate="beta * I/(S+I+R)".

Cheers, Jon
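Swapping (S+I+R) for a constant N leaves this model unchanged because the SIR model is closed. Every S to I and I to R transition only moves individuals between compartments, so S+I+R always equals the initial population size. The base-R sketch below checks this numerically using the parameter values from Chris's script. It uses a simple forward-Euler loop in which the S to I flow is the rate times S. The loop is an illustration only and is not how stemr simulates or fits models.

```r
# Forward-Euler integration of a deterministic SIR model (illustration only)
N     <- 1e5
beta  <- 1
gamma <- 0.5
dt    <- 0.1
state <- c(S = N * (1 - 1e-4), I = N * 1e-4, R = 0)

max_gap <- 0
for (step in seq_len(350)) {                 # 350 steps of 0.1 covers tmax = 35
  S <- state[["S"]]; I <- state[["I"]]; R <- state[["R"]]
  rate_const <- beta * I / N                 # population size passed as a constant
  rate_sum   <- beta * I / (S + I + R)       # population size recomputed from states
  max_gap    <- max(max_gap, abs(rate_const - rate_sum))
  new_inf <- rate_const * S * dt             # S -> I flow in this step
  new_rec <- gamma * I * dt                  # I -> R flow in this step
  state   <- state + c(-new_inf, new_inf - new_rec, new_rec)
}

max_gap          # essentially 0: the two forms differ only by rounding error
sum(state) - N   # essentially 0: the population size is conserved
```

With an open population (births, deaths removed from the system, immigration), S+I+R would vary over time, and the two forms would no longer be interchangeable.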