greta-dev / greta.dynamics

a greta extension for modelling dynamical systems
https://greta-dev.github.io/greta.dynamics/
Other

Add functions for implementing continuous relaxations of discrete stochastic transitions #31

Open goldingn opened 3 months ago

goldingn commented 3 months ago

Background

Gradient-based inference (like the HMC that greta uses) can only operate on continuous parameter spaces. That means it cannot learn the values of parameters with discrete support (e.g. an unobserved Poisson random variable cannot be a model parameter).

But demographic stochasticity (discrete random variation in population sizes between timesteps, e.g. the number of new infectees is Poisson and the number of surviving individuals is binomial) is often important in models of populations and infections, especially when populations are small and near extinction. We cannot directly model the values of these discrete random variables, but we can approximate them with continuous relaxations: keep the state values continuous, and replace each discrete random variable with a continuous random variable that matches its mean and variance (and ideally the full shape of its distribution).

Continuous approximations to distributions

E.g. a Poisson random variable can be approximated with an appropriately-shaped gamma distribution that exactly matches the PMF at discrete values, or by some other distribution that is a close-enough approximation.
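One standard, exact link between the two distributions is through the CDF: for integer $k$, $P(X \le k)$ with $X \sim Poisson(\lambda)$ equals $P(G > \lambda)$ with $G \sim Gamma(k + 1, 1)$, which is why the incomplete gamma function turns up in Poisson relaxations. This is a well-known identity, not necessarily the exact construction intended in this issue; a base-R check, with arbitrary illustrative values of `lambda` and `k`:

```r
# Check the identity P(X <= k) = P(G > lambda), where X ~ Poisson(lambda)
# and G ~ Gamma(shape = k + 1, rate = 1). lambda and k are arbitrary.
lambda <- 3.7
k <- 0:20

poisson_cdf <- ppois(k, lambda)
gamma_upper_tail <- pgamma(lambda, shape = k + 1, rate = 1, lower.tail = FALSE)

# largest absolute difference (floating-point precision)
max(abs(poisson_cdf - gamma_upper_tail))
```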

We might write a discrete stochastic growth-rate model like this:

  1. $$x_t \sim Poisson(x_{t-1} \times r)$$

where $x$ takes integer values, $Poisson(\lambda)$ is the Poisson distribution, and $r$ is a positive-valued growth-rate parameter. To estimate the posterior over the values of $x$ in this model using HMC, we could instead fit the model:

  2. $$y_t \sim \pi(y_{t-1} \times r)$$

where $y_t$ is a (strictly positive) real-valued parameter, and $\pi(\lambda)$ is some probability distribution with support on the positive reals that has similar moments (mean, variance, skewness, etc.) to $Poisson(\lambda)$.
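To make the contrast between equations 1 and 2 concrete, here is a minimal base-R forward simulation of both. It assumes, purely for illustration, that $\pi(\lambda)$ is a Gamma(shape = λ, rate = 1) distribution, which has mean and variance both equal to λ; the function names are hypothetical and not part of greta.dynamics.

```r
# Forward-simulate equation 1 (discrete Poisson states) and equation 2
# (continuous relaxation). Assumption: pi(lambda) = Gamma(shape = lambda,
# rate = 1), which shares the Poisson's mean and variance.
simulate_discrete <- function(n_times, x0, r) {
  x <- numeric(n_times)
  x_previous <- x0
  for (t in seq_len(n_times)) {
    x[t] <- rpois(1, x_previous * r)
    x_previous <- x[t]
  }
  x
}

simulate_relaxed <- function(n_times, y0, r) {
  y <- numeric(n_times)
  y_previous <- y0
  for (t in seq_len(n_times)) {
    # continuous draw with the same mean and variance as rpois(1, lambda)
    y[t] <- rgamma(1, shape = y_previous * r, rate = 1)
    y_previous <- y[t]
  }
  y
}

set.seed(3)
x <- simulate_discrete(50, x0 = 25, r = 1.01)  # integer-valued trajectory
y <- simulate_relaxed(50, y0 = 25, r = 1.01)   # real-valued trajectory
```

The relaxed trajectory can be rounded for plotting, but HMC would work with the continuous values directly.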

Reparameterisation

If we structure these probability distributions so that they can be reparameterised in terms of the parameter and some latent 'innovation' (noise) with a known distribution, we can significantly improve the ability to sample these models, since we can decorrelate the posterior distribution in a similar way to the (non-centred) reparameterisation trick for hierarchical models. This can also provide some computational advantages in greta/TensorFlow by working with arrays rather than scalars.

If we know the quantile function of the continuous distribution (e.g. $q_{\pi}(p, \lambda)$, the quantile function of $\pi(\lambda)$ with probability argument $p$), we can reparameterise the innovations as $u \sim U(0, 1)$ and plug them into the quantile function to generate draws from $\pi(\lambda)$. I.e. equation 2 above is equivalent to:

3. $$
\displaylines{
  y_t = q_{\pi}(u_t, \lambda_t) \\
  \lambda_t = y_{t-1} \times r \\
  u_t \sim U(0, 1)
}
$$

The vector of $u$ values can then be defined in advance (as model parameters) and passed into the solver, which picks out the appropriate element at each timestep. The dependency structure in the model means $y_t$ depends on $y_{t-1}$, so they are correlated a priori (and therefore also a posteriori). In this reparameterised formulation, however, HMC operates on $u$ instead, and $u_t$ does not depend on $u_{t-1}$, so they are a priori uncorrelated. This removes much of the correlation in the posterior and makes sampling much easier.
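A minimal base-R sketch of equation 3, under the same illustrative assumption that $\pi(\lambda)$ is Gamma(shape = λ, rate = 1), so that $q_{\pi}(u, \lambda)$ is `qgamma(u, shape = lambda)`. The whole trajectory becomes a deterministic function of the innovation vector `u`, which is the quantity HMC would sample. The function name is hypothetical.

```r
# Equation 3: y_t = q_pi(u_t, lambda_t), lambda_t = y_{t-1} * r, u_t ~ U(0, 1).
# Assumption: q_pi is the Gamma(shape = lambda, rate = 1) quantile function.
trajectory_from_innovations <- function(u, y0, r) {
  y <- numeric(length(u))
  y_previous <- y0
  for (t in seq_along(u)) {
    lambda_t <- y_previous * r
    y[t] <- qgamma(u[t], shape = lambda_t, rate = 1)
    y_previous <- y[t]
  }
  y
}

set.seed(1)
u <- runif(10)  # a priori independent innovations
y <- trajectory_from_innovations(u, y0 = 25, r = 1.01)

# the same innovations always reproduce the same trajectory
identical(y, trajectory_from_innovations(u, y0 = 25, r = 1.01))
```

Because each `u[t]` is independent of the others a priori, a sampler moving in `u`-space does not have to follow the strong serial correlation between successive `y[t]` values.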

Note that if the quantile function is expensive to compute, other approximations and reparameterisations may be more appealing. E.g. for the Poisson, the inverse of the incomplete gamma function gives the quantile of the gamma distribution whose PDF matches the Poisson PMF at discrete values, but the function has no analytic form and is expensive to compute. A lognormal approximation (with either uniform or standard normal innovations) is imperfect but much more computationally efficient.
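The lognormal approximation can take either kind of innovation because the lognormal quantile function is the exponentiated normal quantile: for $u \sim U(0, 1)$, `qlnorm(u, mu, sigma)` equals `exp(mu + sigma * qnorm(u))`, so a uniform innovation maps to a standard normal one through `qnorm()`. A base-R sketch using the moment-matched parameters σ² = log(1 + 1/λ) and μ = log(λ) - σ²/2, which give mean and variance both equal to λ; the value of `lambda` is arbitrary:

```r
# Moment-matched lognormal approximation to Poisson(lambda):
# mean = variance = lambda when sigma^2 = log(1 + 1 / lambda)
# and mu = log(lambda) - sigma^2 / 2.
lambda <- 4.2
sigma <- sqrt(log1p(1 / lambda))
mu <- log(lambda) - sigma^2 / 2

set.seed(2)
u <- runif(5)   # uniform innovations
z <- qnorm(u)   # the equivalent standard normal innovations

from_uniform <- qlnorm(u, meanlog = mu, sdlog = sigma)
from_normal <- exp(mu + sigma * z)

all.equal(from_uniform, from_normal)
```

Both routes avoid the inverse incomplete gamma function; the standard-normal version only needs `exp()`, which is cheap to evaluate and differentiate.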

Currently this reparameterisation trick is only applicable in the greta_2 branch with iterate_dynamic_function(), since it requires functionality for indexing time-varying parameters.

Implementation

We just need to provide functions for the relaxations, documentation, and examples of applying this approach. I have some, I just need to add them to the greta_2 branch and work out the neatest user interface.

goldingn commented 3 months ago

Here's a script with my current hacky functions, using the current greta_2 branch (which has time-varying parameters etc and bug fixes), applied to a problem where the population goes extinct.

# 1D (single-population) version of stochastic dynamics in greta.dynamics

# simulate a discrete stochastic growth rate population model
set.seed(3)
n_times <- 50
# daily population growth rate
r_true <- 1.01
# initial population (just before timeseries)
pop_init_true <- 25

time <- seq_len(n_times)
pop_true <- rep(NA, n_times)
pop_previous <- pop_init_true
for (i in 1:n_times) {
  pop_true[i] <- rpois(1, pop_previous * r_true)
  pop_previous <- pop_true[i]
}

# add an observation process (binomial with fixed detection probability)
obs_prob <- 0.7
pop_obs <- rbinom(n_times, size = pop_true, prob = obs_prob)

# define functions for continuous relaxation and reparameterisation of the
# Poisson distributions

# given poisson rate parameter lambda and random uniform deviate u, a continuous
# relaxation of poisson random variable generation is computed using the inverse
# of the incomplete gamma function. ie. igammainv(lambda, 1 - u) is
# approximately equal to qpois(u, lambda) (and exactly equal to qgamma(1 - u,
# lambda) in the R implementation)
gamma_continuous_poisson <- function(lambda, u) {
  igammainv(lambda, 1 - u)
}
# # check:
# lambda <- as_data(pi)
# u <- uniform(0, 1)
# y <- gamma_continuous_poisson(lambda, u)
# sims <- calculate(y, u, nsim = 1e5)
# max(abs(sims$y - qgamma(1 - sims$u, pi)))
# quantile(round(sims$y))
# quantile(rpois(1e5, pi))

# the inverse incomplete gamma function (the major part of the quantile function
# of a gamma distribution)
igammainv <- function(a, p) {
  op <- greta::.internals$nodes$constructors$op
  op("igammainv", a, p,
     tf_operation = "tf_igammainv"
  )
}
tf_igammainv <- function(a, p) {
  tfp <- greta:::tfp
  tfp$math$igammainv(a, p)
}

# given random variables z (with standard normal distribution a priori), and
# Poisson rate parameter lambda, return a strictly positive continuous random
# variable with the same mean and variance as a poisson random variable with
# rate lambda, by approximating the poisson as a lognormal distribution.
lognormal_continuous_poisson <- function(lambda, z) {
  sigma <- sqrt(log1p(lambda / exp(2 * log(lambda))))
  # sigma2 <- log1p(1 / lambda)
  mu <- log(lambda) - sigma^2 / 2
  exp(mu + z * sigma)
}
# Working: The lognormal mean and variance should both equal lambda. The
# lognormal mean and variance can both be expressed in terms of the parameters
# mu and sigma.
# mean = lambda = exp(mu + sigma^2 / 2)
# variance = lambda = (exp(sigma ^ 2) - 1) * exp(2 * mu + sigma ^ 2)

# solve to get sigma and mu as a function of lambda:
# mu = log(lambda) - sigma^2 / 2
# lambda = (exp(sigma ^ 2) - 1) * exp(2 * mu + sigma ^ 2)
# lambda = (exp(sigma ^ 2) - 1) * exp(2 * (log(lambda) - sigma^2 / 2) + sigma ^ 2)
# lambda = (exp(sigma ^ 2) - 1) * exp(sigma ^ 2) * exp(2 * (log(lambda) - sigma^2 / 2))
# lambda = (exp(sigma ^ 2) - 1) * exp(sigma ^ 2) * exp(2 * log(lambda) - sigma^2)
# lambda = (exp(sigma ^ 2) - 1) * exp(sigma ^ 2) * exp(2 * log(lambda)) / exp(sigma^2)
# lambda / exp(2 * log(lambda)) = (exp(sigma ^ 2) - 1) * exp(sigma ^ 2) * 1 / exp(sigma^2)
# lambda / exp(2 * log(lambda)) = (exp(sigma ^ 2) - 1)
# log(lambda / exp(2 * log(lambda)) + 1) = sigma ^ 2
# sigma = sqrt(log(lambda / exp(2 * log(lambda)) + 1)) = sqrt(log1p(1 / lambda))
# mu = log(lambda) - sigma^2 / 2

# # check these numerically
# library(tidyverse)
# compare <- tibble(
#   lambda = seq(0.01, 1000, length.out = 100)
# ) %>%
#   mutate(
#     sigma = sqrt(log(lambda / exp(2 * log(lambda)) + 1)),
#     mu = log(lambda) - sigma^2 / 2
#   ) %>%
#   mutate(
#     mean = exp(mu + sigma^2 / 2),
#     variance = (exp(sigma ^ 2) - 1) * exp(2 * mu + sigma ^ 2)
#   ) %>%
#   mutate(
#     diff_mean_variance = abs(mean - variance),
#     diff_mean_lambda = abs(mean - lambda),
#     diff_variance_lambda = abs(variance - lambda)
#   ) %>%
#   summarise(
#     across(
#       starts_with("diff"),
#       ~max(.x)
#     )
#   )

library(greta.dynamics)
#> Loading required package: greta
#> 
#> Attaching package: 'greta'
#> The following objects are masked from 'package:stats':
#> 
#>     binomial, cov2cor, poisson
#> The following objects are masked from 'package:base':
#> 
#>     %*%, apply, backsolve, beta, chol2inv, colMeans, colSums, diag,
#>     eigen, forwardsolve, gamma, identity, rowMeans, rowSums, sweep,
#>     tapply

# growth rate
r <- normal(1, 0.1, truncation = c(0, Inf))
#> ℹ Initialising python and checking dependencies, this may take a moment.
#> ✔ Initialising python and checking dependencies ... done!
#> 

# latent variable for stochastic transitions
latent_z_vec <- normal(0, 1, dim = n_times)
# latent_u_vec <- uniform(0, 1, dim = n_times)

# initial population size (with mean equal to truth)
init <- exponential(1 / pop_init_true)

# transition function for the population process as a difference equation in
# integer timesteps
fun <- function(state, iter, r, latent_z) {
  lambda <- state * r
  state_new <- lognormal_continuous_poisson(lambda, latent_z)
  # state_new <- gamma_continuous_poisson(lambda, latent_u)
  state_new
}

# solve it (integer times)
pop_sim <- iterate_dynamic_function(
  transition_function = fun,
  initial_state = init,
  niter = n_times,
  tol = 0,
  r = r,
  latent_z = latent_z_vec,
  parameter_is_time_varying = "latent_z",
  # clamp the populations to reasonable values to avoid numerical under/overflow
  state_limits = c(1e-3, 1e3)
)

# get the modelled true population
pop_modelled <- t(pop_sim$all_states)
# and the expected value of the observation distribution
pop_obs_expected <- pop_modelled * obs_prob
pop_obs_ga <- as_data(pop_obs)
distribution(pop_obs_ga) <- poisson(pop_obs_expected)

# fit this, initialising the latent innovations close to zero to reduce the
# stochasticity when finding the initial values
n_chains <- 4
inits <- replicate(n_chains,
                   initials(
                     init = pop_init_true,
                     latent_z_vec = rnorm(n_times, 0, 0.1)),
                   simplify = FALSE)

m <- model(r, latent_z_vec, init)
draws <- mcmc(m,
              initial_values = inits)
#> running 4 chains simultaneously on up to 8 CPU cores
#> 
#>     warmup ====================================== 1000/1000 | eta:  0s          
#>   sampling ====================================== 1000/1000 | eta:  0s

# get posterior samples and plot summaries
sims <- calculate(pop_modelled, nsim = 1000, values = draws)
posterior_sims <- sims$pop_modelled[, , 1]
posterior_sims_discrete <- round(posterior_sims)
posterior_est <- colMeans(posterior_sims_discrete)
posterior_ci <- apply(posterior_sims_discrete, 2, quantile, c(0.025, 0.975))

# plot posterior draws, summary stats, and the truth
plot(posterior_est,
     ylim = range(c(posterior_ci, pop_true)),
     type = "n",
     xlab = "day",
     ylab = "population")
for (i in 1:50) {
  lines(posterior_sims_discrete[i, ],
        lwd = 0.1,
        col = grey(0.4))
}
lines(posterior_ci[1, ], lty = 2)
lines(posterior_ci[2, ], lty = 2)
lines(posterior_est,
     lwd = 1.5)
lines(pop_true,
      col = "blue")
# noisy observations, naively adjusted for detection probability
points(pop_obs / obs_prob,
       cex = 0.5)

Created on 2024-03-13 with reprex v2.0.2
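The commented-out tidyverse check in the script above can also be written in base R, so it runs without extra packages. This sketch additionally confirms that `sqrt(log1p(lambda / exp(2 * log(lambda))))`, the form used in `lognormal_continuous_poisson()`, equals the simpler `sqrt(log1p(1 / lambda))` (the square root of the commented-out `sigma2` line):

```r
# Numerically check that the lognormal parameters give mean = variance = lambda
lambda <- seq(0.01, 1000, length.out = 100)

sigma <- sqrt(log1p(lambda / exp(2 * log(lambda))))  # form used in the script
sigma_simple <- sqrt(log1p(1 / lambda))              # algebraically equal
mu <- log(lambda) - sigma^2 / 2

lognormal_mean <- exp(mu + sigma^2 / 2)
lognormal_variance <- (exp(sigma^2) - 1) * exp(2 * mu + sigma^2)

# largest relative discrepancies (should be at floating-point precision)
max(abs(sigma - sigma_simple) / sigma_simple)
max(abs(lognormal_mean - lambda) / lambda)
max(abs(lognormal_variance - lambda) / lambda)
```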

goldingn commented 3 months ago

Here's a demo estimating population trajectories and growth rates from data with discrete stochastic (poisson + multinomial) population and extinction/invasion dynamics in a metapopulation:

# demo of stochastic dispersal and growth dynamics

# simulate a discrete stochastic growth rate population model
set.seed(1)
n_times <- 20
n_pops <- 4

# daily population growth rate in each location
r_true <- runif(n_pops, 1, 1.3)

# initial populations (just before timeseries)
pop_init_true <- c(15, 0, 0, 0)

# population locations and dispersal matrix
coords <- matrix(runif(n_pops * 2), nrow = n_pops)
dispersal_range <- 6
dispersal_weight_raw <- exp(-dispersal_range * as.matrix(dist(coords)))
# add a nugget effect (increased probability of not dispersing)
prob_dispersing <- 0.1
dispersal_weight <- dispersal_weight_raw * prob_dispersing +
  diag(n_pops) * (1 - prob_dispersing)

# plot these
par(mfrow = c(1, 1))
plot(coords,
     type = "n",
     ylab = "",
     xlab = "",
     axes = FALSE)
for (i in 1:n_pops) {
  for (j in i:n_pops) {
    arrows(x0 = coords[i, 1],
           y0 = coords[i, 2],
           x1 = coords[j, 1],
           y1 = coords[j, 2],
           length = 0,
           lwd = 10 * dispersal_weight[i, j])
  }
}
points(coords,
       pch = 21,
       bg = grey(0.8),
       cex = 2)
text(coords[, 1],
     coords[, 2],
     labels = paste("pop", seq_len(n_pops)),
     pos = 3,
     xpd = NA)

# normalise dispersal weights to get dispersal probabilities
dispersal_prob <- sweep(dispersal_weight,
                        1,
                        rowSums(dispersal_weight),
                        FUN = "/")

# simulate stochastic population dynamics and dispersal
time <- seq_len(n_times)
pop_true <- matrix(NA, nrow = n_times, ncol = n_pops)
pop_previous <- pop_init_true
for (i in 1:n_times) {
  # grow populations (Poisson growth step)
  pop_grown <- rpois(n_pops, pop_previous * r_true)
  # do dispersal, with multinomial randomness
  pop_dispersed <- matrix(NA, n_pops, n_pops)
  for (pop in 1:n_pops) {
    pop_dispersed[pop, ] <- rmultinom(1,
                                      pop_grown[pop],
                                      prob = dispersal_prob[pop, ])
  }
  # collate all the individuals staying, arriving, less those leaving
  pop_new <- colSums(pop_dispersed)
  # store the states
  pop_previous <- pop_true[i, ] <- pop_new
}

# add an observation process (binomial with fixed detection probability)
obs_prob <- 0.8
pop_obs <- pop_true * NA
pop_obs[] <- rbinom(length(pop_true),
                    size = pop_true[],
                    prob = obs_prob)

# # plot true (lines) and observed (points) populations across these populations
# par(mfrow = n2mfrow(n_pops))
# for (i in seq_len(n_pops)) {
#   plot(pop_true[, i] ~ time,
#        type = "l",
#        ylab = "population",
#        ylim = range(pop_true),
#        main = paste("pop", i))
#
#   points(pop_obs[, i] ~ time,
#        type = "b",
#        pch = 21,
#        bg = ifelse(pop_true[, i] > 0, grey(0.4, 0.5), NA))
# }

# build a greta model to infer latent populations, using stochastic transitions
# but a continuous relaxation of the discrete process
library(greta.dynamics)
#> Loading required package: greta
#> 
#> Attaching package: 'greta'
#> The following objects are masked from 'package:stats':
#> 
#>     binomial, cov2cor, poisson
#> The following objects are masked from 'package:base':
#> 
#>     %*%, apply, backsolve, beta, chol2inv, colMeans, colSums, diag,
#>     eigen, forwardsolve, gamma, identity, rowMeans, rowSums, sweep,
#>     tapply


# define functions for continuous relaxation and reparameterisation of the
# Poisson distributions

# given random variables z (with standard normal distribution a priori), and
# Poisson rate parameter lambda, return a strictly positive continuous random
# variable with the same mean and variance as a poisson random variable with
# rate lambda, by approximating the poisson as a lognormal distribution.
lognormal_continuous_poisson <- function(lambda, z) {
  sigma <- sqrt(log1p(lambda / exp(2 * log(lambda))))
  # sigma2 <- log1p(1 / lambda)
  mu <- log(lambda) - sigma^2 / 2
  exp(mu + z * sigma)
}

# dispersal parameters are fixed for now; we only learn the growth rates, the
# initial populations and the stochastic transitions
r <- normal(1, 0.25,
            truncation = c(0, Inf),
            dim = n_pops)
#> ℹ Initialising python and checking dependencies, this may take a moment.
#> ✔ Initialising python and checking dependencies ... done!
#> 

# latent variable for stochastic transitions in lognormal approximation to
# Poisson
latent_z_vec <- normal(0, 1, dim = c(n_times, n_pops))

# initial population sizes, with prior means equal to the truth but with zero
# elements replaced by a small positive value
pop_init_prior <- pmax(pop_init_true, 1e-3)
init <- exponential(1 / pop_init_prior)

# transition function for the population process as a difference equation in
# integer timesteps
fun <- function(state, iter, r, latent_z) {
  # grow the populations
  expected_pop_grown <- state * r
  # disperse the populations
  expected_pop <- t(dispersal_prob) %*% expected_pop_grown
  # add stochasticity
  state_new <- lognormal_continuous_poisson(expected_pop, latent_z)
  state_new
}

# solve it (integer times)
pop_sim <- iterate_dynamic_function(
  transition_function = fun,
  initial_state = init,
  niter = n_times,
  tol = 0,
  r = r,
  latent_z = latent_z_vec,
  parameter_is_time_varying = "latent_z",
  # clamp the populations to reasonable values to avoid numerical under/overflow
  state_limits = c(1e-3, 1e3)
)

# get the modelled true population
pop_modelled <- t(pop_sim$all_states)
# and the expected value of the observation distribution
pop_obs_expected <- pop_modelled * obs_prob
pop_obs_ga <- as_data(pop_obs)
distribution(pop_obs_ga) <- poisson(pop_obs_expected)

# set the initial values for the population trajectories to be near the
# (deterministic) expected values, so sampling the stochastic values doesn't
# initialise us in very weird parts of parameter space
n_chains <- 4
inits <- replicate(n_chains,
                   initials(
                     init = pop_init_prior,
                     latent_z_vec = matrix(rnorm(n_times * n_pops, 0, 0.1),
                                           n_times,
                                           n_pops)),
                   simplify = FALSE)

m <- model(r, latent_z_vec, init)
draws <- mcmc(m,
              initial_values = inits)
#> running 4 chains simultaneously on up to 8 CPU cores
#> 
#>     warmup ====================================== 1000/1000 | eta:  0s          
#>   sampling ====================================== 1000/1000 | eta:  0s

# check convergence
summary(coda::gelman.diag(draws,
                          autoburnin = FALSE,
                          multivariate = FALSE)$psrf)
#>    Point est.      Upper C.I.   
#>  Min.   :1.002   Min.   :1.004  
#>  1st Qu.:1.048   1st Qu.:1.123  
#>  Median :1.112   Median :1.298  
#>  Mean   :1.211   Mean   :1.538  
#>  3rd Qu.:1.234   3rd Qu.:1.616  
#>  Max.   :2.750   Max.   :5.395

# check growth rate estimates. We would not expect the posterior to be informed
# by the data at all for populations 2 and 3, where the population doesn't have
# a chance to grow. For populations 1 and 4 (established from near the start,
# with no stochastic extinction) we would expect to estimate a positive growth
# rate somewhere in the correct ballpark
summary(calculate(r, values = draws))$statistics
#>             Mean         SD     Naive SE Time-series SE
#> r[1,1] 1.0957812 0.03560134 0.0005629066    0.004680752
#> r[2,1] 0.7560150 0.24658097 0.0038987875    0.040579097
#> r[3,1] 0.9790061 0.12836447 0.0020296205    0.015262417
#> r[4,1] 1.1751366 0.06651551 0.0010517025    0.008744965
r_true
#> [1] 1.079653 1.111637 1.171856 1.272462

# posterior simulations
sims_posterior <- calculate(pop_modelled,
                        values = draws,
                        nsim = 100)

par(mfrow = n2mfrow(n_pops))
for (pop in 1:n_pops) {
  trajectories_posterior <- round(sims_posterior[[1]][, , pop])
  plot(trajectories_posterior[1, ] ~ time,
       type = "n",
       ylim = range(c(trajectories_posterior)),
       ylab = "population",
       main = paste("pop", pop))
  apply(jitter(trajectories_posterior),
        1,
        lines,
        col = grey(0.4, 0.2),
        lwd = 2)
  lines(pop_true[, pop] ~ time,
        lwd = 2,
        col = "blue")
  points(pop_obs[, pop] / obs_prob ~ time,
         bg = ifelse(pop_true[, pop] > 0, "skyblue", "white"),
         cex = 1,
         pch = 21)
}

Created on 2024-03-15 with reprex v2.0.2
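In the metapopulation transition function, `t(dispersal_prob) %*% expected_pop_grown` is the expected value of the multinomial dispersal step used in the simulation: element j of the product is the sum over source populations i of `dispersal_prob[i, j] * grown[i]`, and because each row of `dispersal_prob` sums to one, total expected abundance is conserved. A base-R check with a made-up three-population dispersal matrix (all numbers here are hypothetical):

```r
# Hypothetical dispersal weights among 3 populations (rows = source population)
dispersal_weight <- matrix(c(0.90, 0.08, 0.02,
                             0.05, 0.90, 0.05,
                             0.02, 0.08, 0.90),
                           nrow = 3, byrow = TRUE)
dispersal_prob <- sweep(dispersal_weight, 1, rowSums(dispersal_weight), FUN = "/")
grown <- c(30, 5, 0)

# expected populations after dispersal, as in the transition function
expected_pop <- t(dispersal_prob) %*% grown

# the same quantity as expected column sums of the multinomial draws
# (each row i of dispersal_prob is multiplied by grown[i])
expected_by_hand <- colSums(dispersal_prob * grown)

all.equal(c(expected_pop), expected_by_hand)
sum(expected_pop)  # equal to sum(grown), up to floating-point error
```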

goldingn commented 3 months ago

Here are some ballpark run times (MCMC stage only, 4 chains, 2K warmup, 2K samples, not accounting for convergence) with varying dimensions of this example problem:

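| n_times | n_pops | latent state-space size | run time (s) |
|---------|--------|-------------------------|--------------|
| 20      | 4      | 80                      | 108          |
| 50      | 4      | 200                     | 232          |
| 100     | 4      | 400                     | 466          |
| 20      | 10     | 200                     | 111          |
| 20      | 100    | 2000                    | 190          |
| 20      | 500    | 10000                   | 820          |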

So the model run time is approximately linear in n_times, as expected. It scales sub-linearly with the number of populations (especially at small numbers of populations), which is also expected, since TensorFlow can parallelise the matrix multiply. This will likely change when the dispersal matrix is also learned.
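The scaling claims can be checked directly from the timings in the table above; this base-R snippet only re-uses those reported numbers.

```r
# Reported run times (seconds) from the table above
timings <- data.frame(
  n_times = c(20, 50, 100, 20, 20, 20),
  n_pops  = c(4, 4, 4, 10, 100, 500),
  seconds = c(108, 232, 466, 111, 190, 820)
)
timings$latent_states <- timings$n_times * timings$n_pops

# with 4 populations: seconds per timestep stays roughly constant,
# consistent with run time being linear in n_times
four_pops <- subset(timings, n_pops == 4)
four_pops$seconds / four_pops$n_times

# with 20 timesteps: seconds per latent state falls as populations are added,
# consistent with sub-linear scaling in n_pops
twenty_steps <- subset(timings, n_times == 20)
twenty_steps$seconds / twenty_steps$latent_states
```

The per-timestep cost is about 5.4, 4.64 and 4.66 seconds for 20, 50 and 100 timesteps, while the cost per latent state drops from 1.35 to 0.082 seconds as the number of populations grows from 4 to 500.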