stan-dev / loo

loo R package for approximate leave-one-out cross-validation (LOO-CV) and Pareto smoothed importance sampling (PSIS)
https://mc-stan.org/loo

Clarification on using loo_moment_match() with non-Stan objects #209

Open wlandau opened 2 years ago

wlandau commented 2 years ago

I am working on a model averaging problem with very simple models, and I am getting intermittently high Pareto k values even on simple well-behaved simulated datasets. I would like to apply the moment matching correction to both non-longitudinal JAGS models and longitudinal Stan models. The latter case is trivially easy with moment_match = TRUE in loo(), but I do not have a stanfit object in the former case.

Would you help me understand how to use loo_moment_match() in the case where my model fit is a posterior::as_draws_df() data frame with columns for parameters and pointwise log likelihoods? To set up a sufficiently motivating scenario, I converted the roaches example from the vignette into JAGS. I also put constrained priors on the scale parameters for the sake of learning what to do with the unconstrain_pars, log_prob_upars, and log_lik_i_upars arguments of loo_moment_match(). (Is it even appropriate to consider "unconstrained parameters" without HMC?)

    library(dplyr)
#> 
#> Attaching package: 'dplyr'
#> The following objects are masked from 'package:stats':
#> 
#>     filter, lag
#> The following objects are masked from 'package:base':
#> 
#>     intersect, setdiff, setequal, union

    data(roaches, package = "rstanarm")
    roaches$roach1 <- sqrt(roaches$roach1)
    x <- roaches[, c("roach1", "treatment", "senior")]
    data <- list(
        N = nrow(x),
        K = ncol(x),
        x = as.matrix(x),
        y = roaches$y,
        offset = log(roaches[,"exposure2"])
    )

    model_text <- "
model {
  for (n in 1:N) {
    y[n] ~ dpois(exp(inprod(x[n,], beta) + intercept + offset[n]))
  }
  for (k in 1:K) {
    beta[k] ~ dnorm(0, 1 / (scale_beta * scale_beta))
  }
  intercept ~ dnorm(0, 1 / (scale_alpha * scale_alpha))
  scale_beta ~ dunif(0, 10)
  scale_alpha ~ dnorm(0, 10) T(0,)
  for (n in 1:N) {
    log_lik[n] <- log(dpois(y[n], exp(inprod(x[n,], beta) + intercept + offset[n])))
  }
}
"
    file <- tempfile()
    writeLines(model_text, file)

    tmp <- capture.output({
        model <- rjags::jags.model(
            file = file,
            data = data,
            n.chains = 4,
            n.adapt = 2e3
        )
        stats::update(model, n.iter = 2e3, quiet = TRUE)
        coda <- rjags::coda.samples(
            model = model,
            variable.names = c(
                "beta",
                "intercept",
                "scale_beta",
                "scale_alpha",
                "log_lik"
            ),
            n.iter = 4e3
        )
    })

    fit <- posterior::as_draws_df(coda)
    print(fit) # This is the model fit object I can work with.
#> # A draws_df: 4000 iterations, 4 chains, and 268 variables
#>    beta[1] beta[2] beta[3] intercept log_lik[1] log_lik[2] log_lik[3]
#> 1     0.16   -0.55   -0.30       2.5        -19        -16       -2.1
#> 2     0.16   -0.55   -0.29       2.5        -19        -16       -2.1
#> 3     0.16   -0.52   -0.27       2.5        -17        -14       -2.1
#> 4     0.16   -0.55   -0.38       2.5        -16        -14       -2.1
#> 5     0.16   -0.56   -0.25       2.5        -16        -14       -2.1
#> 6     0.16   -0.59   -0.26       2.5        -18        -15       -2.1
#> 7     0.16   -0.60   -0.25       2.5        -19        -16       -2.0
#> 8     0.16   -0.58   -0.34       2.5        -18        -16       -2.0
#> 9     0.16   -0.58   -0.30       2.5        -18        -15       -2.1
#> 10    0.16   -0.60   -0.33       2.5        -17        -15       -2.1
#>    log_lik[4]
#> 1        -2.2
#> 2        -2.2
#> 3        -2.3
#> 4        -2.2
#> 5        -2.2
#> 6        -2.2
#> 7        -2.2
#> 8        -2.2
#> 9        -2.2
#> 10       -2.2
#> # ... with 15990 more draws, and 260 more variables
#> # ... hidden reserved variables {'.chain', '.iteration', '.draw'}

    # Convergence looks okay.
    fit %>%
        select(starts_with(c("beta", "intercept", "scale"))) %>%
        posterior::summarize_draws() %>%
        print()
#> Warning: Dropping 'draws_df' class as required metadata was removed.
#> # A tibble: 6 × 10
#>   variable      mean median      sd     mad     q5    q95  rhat ess_bulk ess_t…¹
#>   <chr>        <dbl>  <dbl>   <dbl>   <dbl>  <dbl>  <dbl> <dbl>    <dbl>   <dbl>
#> 1 beta[1]      0.161  0.161 0.00193 0.00194  0.158  0.164  1.00    1704.   2993.
#> 2 beta[2]     -0.566 -0.566 0.0248  0.0246  -0.607 -0.524  1.00    3388.   5597.
#> 3 beta[3]     -0.312 -0.312 0.0334  0.0335  -0.368 -0.259  1.00    5957.   8365.
#> 4 intercept    2.52   2.52  0.0260  0.0260   2.48   2.56   1.00    1339.   2777.
#> 5 scale_alpha  0.905  0.889 0.156   0.153    0.673  1.19   1.00    9336.   8030.
#> 6 scale_beta   0.816  0.569 0.835   0.298    0.272  2.20   1.00    2056.    864.
#> # … with abbreviated variable name ¹​ess_tail

    # LOO without the moment matching correction is straightforward.
    log_lik <- as.matrix(dplyr::select(fit, tidyselect::starts_with("log_lik")))
#> Warning: Dropping 'draws_df' class as required metadata was removed.
    r_eff <- loo::relative_eff(x = log_lik, chain_id = fit$.chain)
    loo <- loo::loo(x = log_lik, r_eff = r_eff)
#> Warning: Some Pareto k diagnostic values are too high. See help('pareto-k-diagnostic') for details.

    # But we get high Pareto k values.
    print(loo)
#> 
#> Computed from 16000 by 262 log-likelihood matrix
#> 
#>          Estimate     SE
#> elpd_loo  -5462.1  696.5
#> p_loo       261.3   57.6
#> looic     10924.3 1393.0
#> ------
#> Monte Carlo SE of elpd_loo is NA.
#> 
#> Pareto k diagnostic values:
#>                          Count Pct.    Min. n_eff
#> (-Inf, 0.5]   (good)     239   91.2%   537       
#>  (0.5, 0.7]   (ok)         9    3.4%   76        
#>    (0.7, 1]   (bad)        7    2.7%   11        
#>    (1, Inf)   (very bad)   7    2.7%   1         
#> See help('pareto-k-diagnostic') for details.

    # How do I use loo_moment_match() in this situation?
    # loo::loo_moment_match(
    #   x = fit,
    #   post_draws = function(x) as.matrix(x),
    #   log_lik_i = function(x, i) x[[sprintf("log_lik[%s]", i)]],
    #   unconstrain_pars = "???", # Do we even need to consider the unconstrained space for non-HMC MCMC?
    #   log_prob_upars = "???", # Here is where I start to get lost.
    #   log_lik_i_upars = "???" # Same here.
    # )

Created on 2022-12-01 with reprex v2.0.2
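
A minimal sketch of what the three "???" arguments might compute for this particular model, written in base R so that it runs without JAGS or loo. Everything here is an assumption layered on the reprex: the helper names (`unconstrain_sketch`, `log_prob_upars_sketch`, `log_lik_i_upars_sketch`), the simulated `toy` data, and the parameter column order are hypothetical, and the helpers would still need thin wrappers that follow the argument conventions described in `?loo_moment_match`. The scale parameters are mapped to the real line with a logit transform for `scale_beta` (bounded on (0, 10)) and a log transform for `scale_alpha` (positive), and the matching log-Jacobian terms are added to the log density.

```r
# Simulated stand-in for the `data` list in the reprex (values are made up).
set.seed(1)
N <- 20
K <- 3
toy <- list(N = N, K = K, x = matrix(rnorm(N * K), N, K),
            y = rpois(N, 3), offset = rep(0, N))

# Assumed column order of a draws matrix:
# beta[1..K], intercept, scale_beta, scale_alpha.

# Map constrained draws (one row per draw) to the unconstrained space.
unconstrain_sketch <- function(pars, K) {
  upars <- pars
  upars[, K + 2] <- qlogis(pars[, K + 2] / 10)  # (0, 10)  -> real line
  upars[, K + 3] <- log(pars[, K + 3])          # (0, Inf) -> real line
  upars
}

# Log likelihood of observation i at each row of unconstrained draws.
log_lik_i_upars_sketch <- function(upars, i, data) {
  K <- data$K
  eta <- drop(upars[, 1:K, drop = FALSE] %*% data$x[i, ]) +
    upars[, K + 1] + data$offset[i]
  dpois(data$y[i], lambda = exp(eta), log = TRUE)
}

# Unnormalized log posterior density on the unconstrained scale.
log_prob_upars_sketch <- function(upars, data) {
  K <- data$K
  u_sb <- upars[, K + 2]
  u_sa <- upars[, K + 3]
  s_beta <- 10 * plogis(u_sb)
  s_alpha <- exp(u_sa)
  log_lik <- Reduce(`+`, lapply(seq_len(data$N), function(i) {
    log_lik_i_upars_sketch(upars, i, data)
  }))
  log_prior <-
    rowSums(matrix(dnorm(upars[, 1:K, drop = FALSE], 0, s_beta, log = TRUE),
                   ncol = K)) +
    dnorm(upars[, K + 1], 0, s_alpha, log = TRUE) +
    dunif(s_beta, 0, 10, log = TRUE) +
    dnorm(s_alpha, 0, 1 / sqrt(10), log = TRUE)  # truncation constant dropped
  # Log-Jacobian of the inverse transforms (scaled logistic and exp).
  log_jac <- log(10) + plogis(u_sb, log.p = TRUE) +
    plogis(-u_sb, log.p = TRUE) + u_sa
  log_lik + log_prior + log_jac
}

# Five made-up constrained draws, just to exercise the helpers.
pars <- cbind(matrix(rnorm(5 * K, sd = 0.1), 5, K), rnorm(5),
              runif(5, 0.1, 9), rexp(5) + 0.1)
upars <- unconstrain_sketch(pars, K)
lp <- log_prob_upars_sketch(upars, toy)
```

The draws matrix passed to helpers like these would hold parameters only, so the `log_lik[...]` columns of `fit` would have to be dropped first.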

Session info

``` r
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.2.1 (2022-06-23)
#>  os       macOS Big Sur ... 10.16
#>  system   x86_64, darwin17.0
#>  ui       X11
#>  language (EN)
#>  collate  en_US.UTF-8
#>  ctype    en_US.UTF-8
#>  tz       America/Indiana/Indianapolis
#>  date     2022-12-01
#>  pandoc   2.19.2 @ /Applications/RStudio.app/Contents/MacOS/quarto/bin/tools/ (via rmarkdown)
#>
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package        * version date (UTC) lib source
#>  abind            1.4-5   2016-07-21 [1] CRAN (R 4.2.0)
#>  assertthat       0.2.1   2019-03-21 [1] CRAN (R 4.2.0)
#>  backports        1.4.1   2021-12-13 [1] CRAN (R 4.2.0)
#>  checkmate        2.1.0   2022-04-21 [1] CRAN (R 4.2.0)
#>  cli              3.4.1   2022-09-23 [1] CRAN (R 4.2.0)
#>  coda             0.19-4  2020-09-30 [1] CRAN (R 4.2.0)
#>  colorspace       2.0-3   2022-02-21 [1] CRAN (R 4.2.0)
#>  DBI              1.1.3   2022-06-18 [1] CRAN (R 4.2.0)
#>  digest           0.6.30  2022-10-18 [1] CRAN (R 4.2.0)
#>  distributional   0.3.1   2022-09-02 [1] CRAN (R 4.2.0)
#>  dplyr          * 1.0.10  2022-09-01 [1] CRAN (R 4.2.0)
#>  evaluate         0.18    2022-11-07 [1] CRAN (R 4.2.0)
#>  fansi            1.0.3   2022-03-24 [1] CRAN (R 4.2.0)
#>  farver           2.1.1   2022-07-06 [1] CRAN (R 4.2.0)
#>  fastmap          1.1.0   2021-01-25 [1] CRAN (R 4.2.0)
#>  fs               1.5.2   2021-12-08 [1] CRAN (R 4.2.0)
#>  generics         0.1.3   2022-07-05 [1] CRAN (R 4.2.0)
#>  ggplot2          3.4.0   2022-11-04 [1] CRAN (R 4.2.0)
#>  glue             1.6.2   2022-02-24 [1] CRAN (R 4.2.0)
#>  gtable           0.3.1   2022-09-01 [1] CRAN (R 4.2.0)
#>  highr            0.9     2021-04-16 [1] CRAN (R 4.2.0)
#>  htmltools        0.5.3   2022-07-18 [1] CRAN (R 4.2.0)
#>  knitr            1.41    2022-11-18 [1] CRAN (R 4.2.0)
#>  lattice          0.20-45 2021-09-22 [1] CRAN (R 4.2.1)
#>  lifecycle        1.0.3   2022-10-07 [1] CRAN (R 4.2.0)
#>  loo              2.5.1   2022-03-24 [1] CRAN (R 4.2.0)
#>  magrittr         2.0.3   2022-03-30 [1] CRAN (R 4.2.0)
#>  matrixStats      0.63.0  2022-11-18 [1] CRAN (R 4.2.0)
#>  munsell          0.5.0   2018-06-12 [1] CRAN (R 4.2.0)
#>  pillar           1.8.1   2022-08-19 [1] CRAN (R 4.2.0)
#>  pkgconfig        2.0.3   2019-09-22 [1] CRAN (R 4.2.0)
#>  posterior        1.3.1   2022-09-06 [1] CRAN (R 4.2.0)
#>  purrr            0.3.5   2022-10-06 [1] CRAN (R 4.2.0)
#>  R.cache          0.16.0  2022-07-21 [1] CRAN (R 4.2.0)
#>  R.methodsS3      1.8.2   2022-06-13 [1] CRAN (R 4.2.0)
#>  R.oo             1.25.0  2022-06-12 [1] CRAN (R 4.2.0)
#>  R.utils          2.12.2  2022-11-11 [1] CRAN (R 4.2.0)
#>  R6               2.5.1   2021-08-19 [1] CRAN (R 4.2.0)
#>  reprex           2.0.2   2022-08-17 [1] CRAN (R 4.2.0)
#>  rjags            4-13    2022-04-19 [1] CRAN (R 4.2.0)
#>  rlang            1.0.6   2022-09-24 [1] CRAN (R 4.2.0)
#>  rmarkdown        2.18    2022-11-09 [1] CRAN (R 4.2.0)
#>  rstudioapi       0.14    2022-08-22 [1] CRAN (R 4.2.0)
#>  scales           1.2.1   2022-08-20 [1] CRAN (R 4.2.0)
#>  sessioninfo      1.2.2   2021-12-06 [1] CRAN (R 4.2.0)
#>  stringi          1.7.8   2022-07-11 [1] CRAN (R 4.2.0)
#>  stringr          1.4.1   2022-08-20 [1] CRAN (R 4.2.0)
#>  styler           1.8.1   2022-11-07 [1] CRAN (R 4.2.0)
#>  tensorA          0.36.2  2020-11-19 [1] CRAN (R 4.2.0)
#>  tibble           3.1.8   2022-07-22 [1] CRAN (R 4.2.0)
#>  tidyselect       1.2.0   2022-10-10 [1] CRAN (R 4.2.1)
#>  utf8             1.2.2   2021-07-24 [1] CRAN (R 4.2.0)
#>  vctrs            0.5.1   2022-11-16 [1] CRAN (R 4.2.0)
#>  withr            2.5.0   2022-03-03 [1] CRAN (R 4.2.0)
#>  xfun             0.35    2022-11-16 [1] CRAN (R 4.2.0)
#>  yaml             2.3.6   2022-10-18 [1] CRAN (R 4.2.0)
#>
#>  [1] /Library/Frameworks/R.framework/Versions/4.2/Resources/library
#>
#> ──────────────────────────────────────────────────────────────────────────────
```

wlandau commented 2 years ago

Are there other convenient ways to make approximate LOO more robust?

n-kall commented 2 years ago

Hi, you might have some luck with the generic moment matching functions from https://github.com/topipa/iwmm. You'll need to specify the target function or the importance weight function manually, but it should work on a matrix object.

wlandau commented 1 year ago

Thanks, @n-kall. Is iwmm a generic implementation of https://mc-stan.org/loo/articles/loo2-moment-matching.html, or is the underlying statistical method itself different too?

n-kall commented 1 year ago

Yes, it is the same underlying mechanism, just generic (i.e. not tied to importance weights for leave-one-out posteriors). Given a log_ratio_fun, the moment_match function will return transformed draws and importance weights (and Pareto-k diagnostic values). k_threshold = 0.7 and split = TRUE would match the loo_moment_match defaults.

If you want to use it for the leave-one-out case, the log_ratio_fun should be a function that returns the negative log likelihood of the left-out observation. See the tests for an example. You'd likely need to wrap it in a loop (for each observation) and use the resulting draws+weights to calculate the elpd or other metrics you're interested in.
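
To make the loop over observations concrete, here is a minimal base R sketch of the leave-one-out bookkeeping only. It is not the iwmm API: it uses raw self-normalized importance weights, with no Pareto smoothing and no moment matching, and the helper names `log_sum_exp` and `elpd_loo_raw_is` as well as the toy normal model are hypothetical. In practice the draws and weights for each observation would come from the moment matching step instead.

```r
# Numerically stable log(sum(exp(v))).
log_sum_exp <- function(v) {
  m <- max(v)
  m + log(sum(exp(v - m)))
}

# Hypothetical helper: pointwise elpd from an S x N matrix of log likelihoods,
# using raw importance ratios (no smoothing, no moment matching).
elpd_loo_raw_is <- function(log_lik) {
  vapply(seq_len(ncol(log_lik)), function(i) {
    # Log ratio for the LOO case: negative log likelihood of observation i.
    log_ratio <- -log_lik[, i]
    # Self-normalized log weights.
    log_w <- log_ratio - log_sum_exp(log_ratio)
    # Log of the weighted mean likelihood of the left-out observation.
    log_sum_exp(log_w + log_lik[, i])
  }, numeric(1))
}

# Toy example: normal model with known sd and simulated draws of the mean.
set.seed(123)
y <- rnorm(10)
mu_draws <- rnorm(1000, mean = mean(y), sd = 1 / sqrt(length(y)))
log_lik <- sapply(y, function(y_i) dnorm(y_i, mean = mu_draws, sd = 1, log = TRUE))

elpd_i <- elpd_loo_raw_is(log_lik)
sum(elpd_i)  # raw importance sampling estimate of elpd_loo for the toy model
```

With raw weights this reduces to the harmonic mean of the pointwise likelihoods, which is exactly the estimator that becomes unstable when Pareto k is high, so it only illustrates where the per-observation weights enter the elpd calculation.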