Closed: chvandorp closed this issue 4 years ago.
Thanks, good catch, Chris! This is one of those features that I built in but never really used. Could I trouble you to post a minimal working example so I can debug it? I also want to double-check that there aren't any downstream errors once the datasets are stitched together.
Cheers, Jon
No problem. See the script below. I'm not sure it is minimal, but it is an example. I did get this to work when I commented out lines 68-70 of stem_measure.R (although I initially found another problem with prevalence-based observations, but I'll bother you with that later).
library(stemr)
library(ggplot2)
library(gridExtra)
# Model specification ----
# Parameter values used to simulate the data. theta and delta enter the
# emission rates ("S2I * theta" and "I2R * delta").
true_pars <- c(beta = 1, gamma = 0.5, delta = 0.01, theta = 0.2)
compartments <- c("S", "I", "R")

# S -> I (infection) and I -> R (recovery). incidence = TRUE so that the
# transitions can be referenced as S2I and I2R in the emission rates.
rates <- list(
  rate(rate = "beta * I/(S+I+R)", from = "S", to = "I", incidence = TRUE),
  rate(rate = "gamma", from = "I", to = "R", incidence = TRUE)
)

N <- 1e5        # population size
epsilon <- 1e-4 # inoculum, as a fraction of the population size

# Single initial state; fixed = TRUE presumably keeps it out of estimation
# (confirm against stemr's stem_initializer docs).
state_initializer <- list(
  stem_initializer(
    init_states = c(S = N * (1 - epsilon), I = N * epsilon, R = 0),
    fixed = TRUE
  )
)
# Only beta, gamma and delta are estimated. Every other entry of true_pars
# (here: theta) is passed to the model as a constant, along with t0.
par_names <- c("beta", "gamma", "delta")
constants <- c(t0 = 0, true_pars[setdiff(names(true_pars), par_names)])
params <- true_pars[par_names]
tmax <- 35

# Compile the model, including ODE, rate and LNA code.
dynamics <- stem_dynamics(
  rates = rates,
  parameters = params,
  tmax = tmax,
  state_initializer = state_initializer,
  compartments = compartments,
  constants = constants,
  compile_ode = TRUE,
  compile_rates = TRUE,
  compile_lna = TRUE,
  messages = TRUE
)
# Measurement model ----
# Case counts every 7 days (7, 14, ..., tmax); death counts daily (1..tmax).
obstimes.cases <- seq(7, tmax, by = 7)
obstimes.deaths <- seq_len(tmax)

# Poisson observations of a fraction of the incident transitions:
# cases ~ S2I * theta, deaths ~ I2R * delta.
emissions <- list(
  emission(
    meas_var = "cases",
    distribution = "poisson",
    emission_params = c("S2I * theta"),
    incidence = TRUE,
    obstimes = obstimes.cases
  ),
  emission(
    meas_var = "deaths",
    distribution = "poisson",
    emission_params = c("I2R * delta"),
    incidence = TRUE,
    obstimes = obstimes.deaths
  )
)

# No data attached yet: this measurement process is only used to simulate.
measurement_process <- stem_measure(
  emissions = emissions,
  dynamics = dynamics,
  messages = TRUE
)
stem_object <- make_stem(
  dynamics = dynamics,
  measurement_process = measurement_process
)
# Simulate data from the MJP (Gillespie) ----
sim_mjp <- simulate_stem(
  stem_object = stem_object,
  method = "gillespie",
  messages = TRUE
)
sim_data <- as.data.frame(sim_mjp$datasets[[1]])

# Split the simulated dataset into one series per emission, keeping only the
# rows that fall on that emission's observation times.
sim_case_rows <- sim_data$time %in% obstimes.cases
sim_case_data <- sim_data[sim_case_rows, ]
sim_death_rows <- sim_data$time %in% obstimes.deaths
sim_death_data <- sim_data[sim_death_rows, ]

plot.cases <- ggplot(sim_case_data, aes(x = time, y = cases)) + geom_point()
plot.deaths <- ggplot(sim_death_data, aes(x = time, y = deaths)) + geom_point()
grid.arrange(plot.cases, plot.deaths)
# Measurement process backed by the simulated data ----
# One data frame per emission (cases first, then deaths), each on its own
# observation schedule. Passing a list like this as `data` is what raised the
# indexing error reported in this issue.
data.list <- list(
  sim_case_data[c("time", "cases")],
  sim_death_data[c("time", "deaths")]
)

measurement_process <- stem_measure(
  emissions = emissions,
  dynamics = dynamics,
  data = data.list,
  messages = TRUE
)
stem_object <- make_stem(
  dynamics = dynamics,
  measurement_process = measurement_process
)
# Fit the model to the simulated data ----
# Map (beta, gamma, delta) onto the unconstrained estimation scale: log for
# the two positive rates, logit for delta. Elements are matched by position,
# so the input must be ordered like par_names.
to_estimation_scale <- function(params_nat) {
  rates_est <- log(params_nat[1:2])
  delta_est <- logit(params_nat[3])
  c(rates_est, delta_est)
}

# Inverse of to_estimation_scale: exp for the rates, expit for delta.
from_estimation_scale <- function(params_est) {
  rates_nat <- exp(params_est[1:2])
  delta_nat <- expit(params_est[3])
  c(rates_nat, delta_nat)
}
# Flat prior on the estimation scale (log beta, log gamma, logit delta).
priors <- list(
  logprior = function(params_est) 0,
  to_estimation_scale = to_estimation_scale,
  from_estimation_scale = from_estimation_scale
)

# A single "mvnmh" parameter block over all estimated parameters, with an
# initial sigma of 0.01 * I; stop_adaptation = 250.
mcmc_kern <- mcmc_kernel(
  parameter_blocks = list(
    parblock(
      pars_nat = par_names,
      pars_est = c("log_beta", "log_gamma", "logit_delta"),
      priors = priors,
      alg = "mvnmh",
      sigma = diag(0.01, length(par_names)),
      control = mvnmh_control(stop_adaptation = 250)
    )
  ),
  lna_ess_control = lna_control(bracket_update_iter = 50)
)

# Fit via the LNA: 2000 iterations with a thinning interval of 10.
res <- fit_stem(
  stem_object = stem_object,
  method = "lna",
  mcmc_kern = mcmc_kern,
  thinning_interval = 10,
  iterations = 2000,
  print_progress = 10
)
posterior <- res$results$posterior # list with posterior objects
############ plot trajectories ###############
# Overlay posterior expected-observation curves on the simulated data.
# latent_paths is a 3-d array whose third dimension indexes samples; each
# slice has time, S2I and I2R columns. Only the second half of the stored
# samples is drawn (the first half is treated as burn-in), every 10th one.
num_samples <- dim(posterior$latent_paths)[3]
# %/% gives an integer start index. The old num_samples / 2 produced
# fractional indices for odd counts and index 0 for a single sample.
# max() keeps at least sample 1.
plot_idx <- seq(max(1L, num_samples %/% 2L), num_samples, by = 10L)

# delta is estimated, so each deaths curve must use that sample's own delta,
# not the true value. theta is a constant, so its true value is correct.
# NOTE(review): assumes posterior$parameter_samples_nat is a
# (sample x parameter) matrix on the natural scale, aligned with the third
# dimension of latent_paths -- confirm against stemr's fit_stem output.
# Falls back to the true delta when that component is not available.
delta_at <- function(i) {
  draws <- posterior$parameter_samples_nat
  if (!is.null(draws) && "delta" %in% colnames(draws) && i <= nrow(draws)) {
    draws[i, "delta"]
  } else {
    true_pars[["delta"]]
  }
}

cases_plot <- ggplot(data = sim_case_data, aes(x = time, y = cases)) +
  geom_point()
deaths_plot <- ggplot(data = sim_death_data, aes(x = time, y = deaths)) +
  geom_point()
for (i in plot_idx) {
  traj <- as.data.frame(posterior$latent_paths[, , i])
  traj$expect_cases <- traj$S2I * true_pars[["theta"]]
  traj$expect_deaths <- traj$I2R * delta_at(i)
  # Each layer keeps its own copy of traj, so reusing the name is safe.
  cases_plot <- cases_plot +
    geom_line(
      data = traj, aes(x = time, y = expect_cases),
      alpha = 0.2, color = "green"
    )
  deaths_plot <- deaths_plot +
    geom_line(
      data = traj, aes(x = time, y = expect_deaths),
      alpha = 0.2, color = "red"
    )
}
grid.arrange(cases_plot, deaths_plot)
Hmm, I'm having trouble simulating data with this code. For some reason, my R session crashes in the Gillespie simulation. I assume this doesn't happen for you?
No, it works fine for me. I'm using R version 4.0.2 with gcc 9.3.0 on Ubuntu 20.04.1. The result should look something like this:
Hi Chris,
I just pushed the correction, along with a fix for another bug I hadn't previously noticed. There was an issue with how constants were parsed when the initial compartment counts are fixed. You should now be able to pass the population size as a constant and use it in the rates, e.g., `rate = "beta * I / N"` with `N` passed as a constant, instead of `"beta * I / (S + I + R)"`.
Cheers, Jon
Hi Jon,
Found another small bug. I want to use multiple datasets with different sampling intervals, so I passed a list of data frames to the `data` argument of the `stem_measure` function, as described in the documentation. However, this leads to an indexing error. I think this is caused by how the validity of the arguments is checked at the beginning of the `stem_measure` function. When I comment out the following lines, the list of data is accepted without error:
Hope this helps! Many thanks,
Chris