greta-dev / greta

simple and scalable statistical modelling in R
https://greta-stats.org
Other
535 stars 63 forks source link

Add custom summary method for greta MCMC chains #653

Open njtierney opened 3 months ago

njtierney commented 3 months ago

Currently it uses the coda::summary.mcmc.list method, which is fine, but I think we could make it a bit nicer.

hrlai commented 3 months ago

Not sure how much effort you want to sink into this, so I'm just showing the posterior package here in case it helps you make a decision.

library(greta)
#> 
#> Attaching package: 'greta'
#> The following objects are masked from 'package:stats':
#> 
#>     binomial, cov2cor, poisson
#> The following objects are masked from 'package:base':
#> 
#>     %*%, apply, backsolve, beta, chol2inv, colMeans, colSums, diag,
#>     eigen, forwardsolve, gamma, identity, rowMeans, rowSums, sweep,
#>     tapply

x <- rnorm(10)
mu <- normal(0, 5)
#> ℹ Initialising python and checking dependencies, this may take a moment.
#> ✔ Initialising python and checking dependencies ... done!
#> 
sigma <- lognormal(1, 0.1)
distribution(x) <- normal(mu, sigma)
m <- model(mu, sigma)
draws <- mcmc(m)
#> running 4 chains simultaneously on up to 8 cores
#> 
#>     warmup                                           0/1000 | eta:  ?s              warmup ==                                       50/1000 | eta: 11s | 20% bad    warmup ====                                    100/1000 | eta:  7s | 13% bad    warmup ======                                  150/1000 | eta:  6s | 9% bad     warmup ========                                200/1000 | eta:  5s | 7% bad     warmup ==========                              250/1000 | eta:  4s | 5% bad     warmup ===========                             300/1000 | eta:  4s | 4% bad     warmup =============                           350/1000 | eta:  3s | 4% bad     warmup ===============                         400/1000 | eta:  3s | 3% bad     warmup =================                       450/1000 | eta:  3s | 3% bad     warmup ===================                     500/1000 | eta:  2s | 3% bad     warmup =====================                   550/1000 | eta:  2s | 2% bad     warmup =======================                 600/1000 | eta:  2s | 2% bad     warmup =========================               650/1000 | eta:  2s | 2% bad     warmup ===========================             700/1000 | eta:  1s | 2% bad     warmup ============================            750/1000 | eta:  1s | 2% bad     warmup ==============================          800/1000 | eta:  1s | 2% bad     warmup ================================        850/1000 | eta:  1s | 2% bad     warmup ==================================      900/1000 | eta:  1s | 2% bad     warmup ====================================    950/1000 | eta:  0s | 1% bad     warmup ====================================== 1000/1000 | eta:  0s | 1% bad 
#>   sampling                                           0/1000 | eta:  ?s            sampling ==                                       50/1000 | eta:  5s            sampling ====                                    100/1000 | eta:  4s            sampling ======                                  150/1000 | eta:  4s            sampling ========                                200/1000 | eta:  3s            sampling ==========                              250/1000 | eta:  3s            sampling ===========                             300/1000 | eta:  3s            sampling =============                           350/1000 | eta:  2s            sampling ===============                         400/1000 | eta:  2s            sampling =================                       450/1000 | eta:  2s            sampling ===================                     500/1000 | eta:  2s            sampling =====================                   550/1000 | eta:  2s            sampling =======================                 600/1000 | eta:  1s            sampling =========================               650/1000 | eta:  1s            sampling ===========================             700/1000 | eta:  1s            sampling ============================            750/1000 | eta:  1s            sampling ==============================          800/1000 | eta:  1s            sampling ================================        850/1000 | eta:  1s            sampling ==================================      900/1000 | eta:  0s            sampling ====================================    950/1000 | eta:  0s            sampling ====================================== 1000/1000 | eta:  0s

# default summary
summary(draws)
#> 
#> Iterations = 1:1000
#> Thinning interval = 1 
#> Number of chains = 4 
#> Sample size per chain = 1000 
#> 
#> 1. Empirical mean and standard deviation for each variable,
#>    plus standard error of the mean:
#> 
#>          Mean     SD Naive SE Time-series SE
#> mu    -0.1733 0.8015 0.012673       0.023347
#> sigma  2.5303 0.2506 0.003963       0.005833
#> 
#> 2. Quantiles for each variable:
#> 
#>         2.5%     25%     50%    75% 97.5%
#> mu    -1.728 -0.7143 -0.1803 0.3563 1.442
#> sigma  2.088  2.3528  2.5128 2.6896 3.059

# piggyback on the posterior package

library(posterior)
#> This is posterior version 1.5.0
#> 
#> Attaching package: 'posterior'
#> 
#> The following objects are masked from 'package:stats':
#> 
#>     mad, sd, var
#> 
#> The following objects are masked from 'package:base':
#> 
#>     %in%, match
draws2 <- as_draws(draws)

draws2
#> # A draws_list: 1000 iterations, 4 chains, and 2 variables
#> 
#> [chain = 1]
#> $mu
#>  [1] -0.741  0.507  0.507  0.638 -0.178 -0.208 -0.218  0.085 -0.057 -0.596
#> 
#> $sigma
#>  [1] 2.3 2.1 2.1 2.2 2.5 2.7 2.4 2.4 2.2 2.5
#> 
#> 
#> [chain = 2]
#> $mu
#>  [1]  0.33  0.37 -0.36  0.50  1.66  0.71  0.72 -0.18 -0.32  0.42
#> 
#> $sigma
#>  [1] 2.3 2.3 3.0 3.3 3.5 3.3 2.8 2.8 2.6 2.5
#> 
#> # ... with 990 more iterations, and 2 more chains

summary(draws2)
#> # A tibble: 2 × 10
#>   variable   mean median    sd   mad    q5   q95  rhat ess_bulk ess_tail
#>   <chr>     <dbl>  <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>    <dbl>    <dbl>
#> 1 mu       -0.173 -0.180 0.802 0.794 -1.49  1.20  1.00     961.     554.
#> 2 sigma     2.53   2.51  0.251 0.248  2.15  2.97  1.00    1725.    2453.

Created on 2024-07-30 with reprex v2.0.2