paul-buerkner / LFO-CV-paper

BSD 3-Clause "New" or "Revised" License

`rstan` example #2

Open jackbrookes opened 5 years ago

jackbrookes commented 5 years ago

Hello, I am very interested in using this technique to compare a set of time-series models. However, my models are written directly in Stan, and I interface with them through rstan.

I have been working on a function to perform approximate LFO like the one you have here: https://github.com/paul-buerkner/LFO-CV-paper/blob/master/sim_functions.R#L123-L204

Because computing the log likelihood and re-slicing the data vary from case to case, here is my attempt at a general function. It takes two functions as arguments: one that returns the log likelihood of (new or existing) observations given a stanfit object, and one that returns the slice of input observations at the requested indexes. I'm posting it here in case it is helpful, and hopefully someone can spot any errors.
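To make the log_lik_fun contract concrete, here is a minimal sketch for a hypothetical AR(1) model whose Stan program has scalar parameters alpha, beta and sigma, data N and y, and treats y[1] as fixed. The model, the parameter names and the helpers ar1_log_lik / ar1_log_lik_fun are assumptions for illustration only. The per-draw computation is kept in a plain function so it can be checked without running Stan; the adapter pulls the draws out with rstan::extract.

```r
# Hypothetical model: y[t] ~ normal(alpha + beta * y[t - 1], sigma) for t >= 2,
# with y[1] conditioned on rather than modeled.

# Pointwise log likelihood: rows = posterior draws, columns = observations.
ar1_log_lik <- function(draws, input_data) {
  y <- input_data$y
  n <- input_data$N
  ll <- matrix(0, nrow = length(draws$sigma), ncol = n)  # column 1 stays 0
  for (t in seq_len(n)[-1]) {
    mu <- draws$alpha + draws$beta * y[t - 1]  # one mean per draw
    ll[, t] <- dnorm(y[t], mean = mu, sd = draws$sigma, log = TRUE)
  }
  ll
}

# Adapter with the log_lik_fun(stanfit, input_data) signature used by approx_lfo.
ar1_log_lik_fun <- function(stanfit, input_data) {
  draws <- rstan::extract(stanfit, pars = c("alpha", "beta", "sigma"))
  ar1_log_lik(draws, input_data)
}
```

Because each term conditions on the observed y[t - 1], summing these columns over the out-of-sample indexes gives the joint M-step-ahead log predictive density for each draw.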

EDIT: Modified data_fun usage for clarity; the dataset is now passed to approx_lfo alongside the subsetting function.

approx_lfo <- function(fit, N, M, L, dataset, log_lik_fun, data_fun, k_thres = 0.6, ...) {
  # args:
  #   fit: stanfit object fitted to all N observations
  #   N: total number of observations in dataset
  #   M: number of observations ahead to predict
  #   L: minimum number of observations to keep (the model is never fit to fewer)
  #   dataset: full set of observations, passed to data_fun to subset data
  #   log_lik_fun: function(stanfit, input_data) returning an S x n matrix of
  #     pointwise log likelihood (rows = posterior draws, columns = observations)
  #   data_fun: function(dataset, irange) returning the subset of dataset at indexes irange
  #   k_thres: Pareto k threshold above which the model is refit
  #   ...: additional arguments passed to rstan::stan when refitting
  #     (e.g. fit = fit to reuse the compiled model, iter, chains)
  #
  # returns: vector where out[i] estimates the log predictive density of
  #   observations (i + 1):(i + M) given observations 1:i, for i in L:(N - M).
  #   Assumes every refit yields the same number of posterior draws S.

  stopifnot(M >= 1, L >= 1, L <= N - M)

  # log likelihood of the full dataset under the original fit
  ll_fit <- log_lik_fun(fit, data_fun(dataset, 1:N))
  S <- nrow(ll_fit)  # number of posterior draws

  # approximate LFO likelihoods
  loglikm <- loglik <- matrix(nrow = S, ncol = N)
  out <- ks <- rep(NA_real_, N)

  # last observation included in the current fit
  i_star <- N
  refits <- integer(0)
  # no isolated predictions of the last M observations, but their columns
  # are needed for the first importance ratios
  loglik[, (N - M + 1):N] <- ll_fit[, (N - M + 1):N, drop = FALSE]
  for (i in (N - M):L) {
    oos <- (i + 1):(i + M)  # out-of-sample indexes to predict
    # pointwise log likelihood under the current fit
    loglik[, i] <- ll_fit[, i]
    loglikm[, i] <- rowSums(ll_fit[, oos, drop = FALSE])
    # importance ratios for removing observations (i + 1):i_star from the current fit
    logratio <- sum_log_ratios(loglik, (i + 1):i_star)
    psis_part <- suppressWarnings(loo::psis(logratio))
    k <- loo::pareto_k_values(psis_part)
    ks[i] <- k
    # can we estimate, or do we need to refit?
    if (k > k_thres) {
      # refit the model based on the first i observations
      i_star <- i
      refits <- c(refits, i)
      fit <- rstan::stan(data = data_fun(dataset, 1:i), ...)

      # log likelihood of observations 1:(i + M), including unseen ones, under the new fit
      ll_fit <- log_lik_fun(fit, data_fun(dataset, 1:(i + M)))

      loglik[, i] <- ll_fit[, i]
      loglikm[, i] <- rowSums(ll_fit[, oos, drop = FALSE])
      out[i] <- log_mean_exp(loglikm[, i])
    } else {
      # PSIS approximate LFO is possible
      lw_i <- weights(psis_part, normalize = TRUE)[, 1]
      out[i] <- log_sum_exp(lw_i + loglikm[, i])
    }
  }
  attr(out, "is") <- 1:N
  attr(out, "ks") <- ks
  attr(out, "refits") <- refits
  out
}

# helpers used above, defined here so the snippet is self-contained
log_sum_exp <- function(x) {
  max_x <- max(x)
  max_x + log(sum(exp(x - max_x)))
}
log_mean_exp <- function(x) log_sum_exp(x) - log(length(x))
# log importance ratios for removing observations `ids` from the posterior
sum_log_ratios <- function(loglik, ids) -rowSums(loglik[, ids, drop = FALSE])
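For reference, these are the quantities the loop is meant to compute, written out as we read the approximate LFO-CV scheme this function is modelled on (a sketch, not a quote from the paper). Here the draws θ^(s), s = 1, ..., S, come from the posterior conditioned on y_{1:i*}, where i* is the last refit point (initially N), and each term is a conditional likelihood of one observation given all earlier ones.

```latex
\log r_i^{(s)} = -\sum_{j=i+1}^{i^\star} \log p\left(y_j \mid y_{1:(j-1)}, \theta^{(s)}\right)

p\left(y_{(i+1):(i+M)} \mid y_{1:i}, \theta^{(s)}\right)
  = \prod_{m=1}^{M} p\left(y_{i+m} \mid y_{1:(i+m-1)}, \theta^{(s)}\right)

\widehat{\mathrm{elpd}}_i
  = \log \sum_{s=1}^{S} w_i^{(s)} \, p\left(y_{(i+1):(i+M)} \mid y_{1:i}, \theta^{(s)}\right)
```

The weights w_i^(s) are the Pareto-smoothed, normalized versions of r_i^(s). When the Pareto k of the smoothed ratios exceeds k_thres, the model is refit to y_{1:i}, i* becomes i, and every w_i^(s) equals 1/S, which is the log_mean_exp branch. The minus sign in the first line is why the log ratios are the negated row sums of the stored log-likelihood columns.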
paul-buerkner commented 5 years ago

Thanks for sharing the code! One quick question: how is data_fun getting the data if it is not passed as an argument?

jackbrookes commented 5 years ago

In my code, data_fun is a function returned by another function (a closure over the data), something like:

make_data_fun <- function(df) {
  force(df)  # capture the data now rather than lazily
  function(indexes) df[indexes, , drop = FALSE]  # keep the requested rows
}

I do this because my input data is not actually a data frame but a list, so removing observations isn't as simple as subsetting with [ as in the example above.
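For list-shaped Stan data, one way to write a data_fun with the two-argument (dataset, irange) signature is to subset only the elements indexed by observation and update the count. This is a sketch under assumptions: the observation count is stored as N, and the caller names the per-observation elements in per_obs; subset_stan_data and per_obs are hypothetical names.

```r
# Subset a list of Stan data to the observations in `irange`.
# Assumes N holds the number of observations and that the elements named in
# `per_obs` are vectors (or matrices with one row per observation).
subset_stan_data <- function(dataset, irange, per_obs = "y") {
  out <- dataset
  for (nm in per_obs) {
    x <- dataset[[nm]]
    out[[nm]] <- if (is.matrix(x)) x[irange, , drop = FALSE] else x[irange]
  }
  out$N <- length(irange)  # keep Stan's size declaration consistent
  out
}
```

It could be passed as data_fun = function(d, idx) subset_stan_data(d, idx, per_obs = c("y", "X")). Naming the per-observation elements explicitly avoids accidentally truncating a constant that happens to have length N.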

paul-buerkner commented 5 years ago

I see. Still, the data is not explicitly passed inside the code, which may be confusing to people trying to use your function.