donaldRwilliams / chkptstanr

Checkpoint Stan R
https://donaldrwilliams.github.io/chkptstanr/

multivariate models appear supported #9

Open: peclayson opened this issue 2 years ago

peclayson commented 2 years ago

I tried to fit a multivariate model, but chkpt_brms throws an error.

chkpt_brms checks whether formula is of class formula or brmsformula. If it is neither, it stops with an error:

  if (isFALSE(is(formula, "formula") |
              is(formula, "brmsformula"))) {
    stop("formula must be of class formula or brmsformula")
  }

I changed these lines to let mvbrmsformula objects through, and this seems to have no detrimental downstream effect: the model fits look fine. Is there a reason I'm missing why multivariate models shouldn't be used with chkpt_brms?

  if (isFALSE(is(formula, "formula") |
              is(formula, "brmsformula") |
              is(formula, "mvbrmsformula"))) {
    stop("formula must be of class formula, brmsformula, or mvbrmsformula")
  }
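For reference, the same check can be written more compactly with base R's inherits(), which accepts a vector of class names and returns TRUE if the object's class attribute contains any of them. This is a minimal sketch, not chkptstanr's actual code: check_formula_class() is a hypothetical helper, and mv_stub is a hand-built stand-in (its class vector is an assumption) so the example runs without brms.

```r
# Hypothetical helper sketching the class check inside chkpt_brms().
# inherits() with a character vector is TRUE if any listed class matches.
check_formula_class <- function(formula) {
  allowed <- c("formula", "brmsformula", "mvbrmsformula")
  if (!inherits(formula, allowed)) {
    stop("formula must be of class formula, brmsformula, or mvbrmsformula")
  }
  invisible(TRUE)
}

# A plain formula passes
check_formula_class(y ~ x)

# Stand-in for bf(mvbind(...)) + set_rescor(TRUE); the class vector is an
# assumption, built by hand so this runs without brms installed
mv_stub <- structure(list(), class = c("mvbrmsformula", "bform"))
check_formula_class(mv_stub)

# A character string is rejected with the error message
tryCatch(check_formula_class("y ~ x"),
         error = function(e) conditionMessage(e))
```

Listing the allowed classes in one vector also makes it a one-word change if another formula class ever needs to be accepted.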

I tested using the following code.

library(brms)
library(cmdstanr)
library(chkptstanr)

# Model example from the brms multivariate vignette:
# https://cran.r-project.org/web/packages/brms/vignettes/brms_multivariate.html
data("BTdata", package = "MCMCglmm")

bform1 <- 
  bf(mvbind(tarsus, back) ~ sex + hatchdate + (1|p|fosternest) + (1|q|dam)) +
  set_rescor(TRUE)

fit1 <- brm(bform1, 
            data = BTdata, 
            chains = 2, 
            cores = 2,
            seed = 062122,
            iter = 2000)

chkpt_path <- create_folder(folder_name = "mv_test")

fit1_chkpt <- chkpt_brms(formula = bform1,
                         data = BTdata,
                         path = chkpt_path,
                         parallel_chains = 2,
                         iter_warmup = 1000,
                         iter_sampling = 1000,
                         iter_per_chkpt = 250,
                         seed = 062122)

Here are the model summaries.

Using brms...

> summary(fit1)
 Family: MV(gaussian, gaussian) 
  Links: mu = identity; sigma = identity
         mu = identity; sigma = identity 
Formula: tarsus ~ sex + hatchdate + (1 | p | fosternest) + (1 | q | dam) 
         back ~ sex + hatchdate + (1 | p | fosternest) + (1 | q | dam) 
   Data: BTdata (Number of observations: 828) 
  Draws: 2 chains, each with iter = 2000; warmup = 1000; thin = 1;
         total post-warmup draws = 2000

Group-Level Effects: 
~dam (Number of levels: 106) 
                                     Estimate Est.Error l-95% CI u-95% CI Rhat
sd(tarsus_Intercept)                     0.48      0.05     0.39     0.59 1.00
sd(back_Intercept)                       0.25      0.07     0.10     0.39 1.00
cor(tarsus_Intercept,back_Intercept)    -0.51      0.22    -0.93    -0.06 1.00
                                     Bulk_ESS Tail_ESS
sd(tarsus_Intercept)                      823     1140
sd(back_Intercept)                        343      748
cor(tarsus_Intercept,back_Intercept)      577      808

~fosternest (Number of levels: 104) 
                                     Estimate Est.Error l-95% CI u-95% CI Rhat
sd(tarsus_Intercept)                     0.27      0.05     0.17     0.37 1.00
sd(back_Intercept)                       0.35      0.06     0.24     0.47 1.00
cor(tarsus_Intercept,back_Intercept)     0.68      0.21     0.20     0.98 1.01
                                     Bulk_ESS Tail_ESS
sd(tarsus_Intercept)                      643     1189
sd(back_Intercept)                        530      882
cor(tarsus_Intercept,back_Intercept)      225      656

Population-Level Effects: 
                 Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
tarsus_Intercept    -0.41      0.07    -0.53    -0.28 1.00     1781     1515
back_Intercept      -0.01      0.07    -0.15     0.12 1.00     2302     1516
tarsus_sexMale       0.77      0.06     0.66     0.88 1.00     3588     1614
tarsus_sexUNK        0.23      0.13    -0.02     0.48 1.00     2977     1591
tarsus_hatchdate    -0.04      0.06    -0.15     0.07 1.00     1740     1410
back_sexMale         0.01      0.07    -0.12     0.14 1.00     3571     1214
back_sexUNK          0.15      0.15    -0.15     0.45 1.00     4185     1752
back_hatchdate      -0.09      0.05    -0.19     0.02 1.00     2004     1235

Family Specific Parameters: 
             Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sigma_tarsus     0.76      0.02     0.72     0.80 1.00     2284     1456
sigma_back       0.90      0.02     0.85     0.95 1.00     2538     1248

Residual Correlations: 
                    Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS
rescor(tarsus,back)    -0.05      0.04    -0.13     0.02 1.00     2960
                    Tail_ESS
rescor(tarsus,back)     1599

Draws were sampled using sampling(NUTS). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).

Using chkptstanr...

 Family: MV(gaussian, gaussian) 
  Links: mu = identity; sigma = identity
         mu = identity; sigma = identity 
Formula: tarsus ~ sex + hatchdate + (1 | p | fosternest) + (1 | q | dam) 
         back ~ sex + hatchdate + (1 | p | fosternest) + (1 | q | dam) 
   Data: data (Number of observations: 828) 
  Draws: 2 chains, each with iter = 250; warmup = 0; thin = 1;
         total post-warmup draws = 500

Group-Level Effects: 
~dam (Number of levels: 106) 
                                     Estimate Est.Error l-95% CI u-95% CI Rhat
sd(tarsus_Intercept)                     0.48      0.05     0.39     0.59 1.00
sd(back_Intercept)                       0.25      0.08     0.09     0.39 1.00
cor(tarsus_Intercept,back_Intercept)    -0.51      0.22    -0.92    -0.05 1.00
                                     Bulk_ESS Tail_ESS
sd(tarsus_Intercept)                      933     1323
sd(back_Intercept)                        323      589
cor(tarsus_Intercept,back_Intercept)      673      591

~fosternest (Number of levels: 104) 
                                     Estimate Est.Error l-95% CI u-95% CI Rhat
sd(tarsus_Intercept)                     0.27      0.06     0.16     0.38 1.00
sd(back_Intercept)                       0.35      0.06     0.24     0.47 1.00
cor(tarsus_Intercept,back_Intercept)     0.67      0.22     0.18     0.98 1.01
                                     Bulk_ESS Tail_ESS
sd(tarsus_Intercept)                      527      953
sd(back_Intercept)                        462     1418
cor(tarsus_Intercept,back_Intercept)      208      527

Population-Level Effects: 
                 Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
tarsus_Intercept    -0.40      0.07    -0.54    -0.27 1.00     2203     1757
back_Intercept      -0.01      0.07    -0.14     0.11 1.00     2504     1601
tarsus_sexMale       0.77      0.06     0.66     0.88 1.00     4196     1856
tarsus_sexUNK        0.23      0.13    -0.03     0.49 1.00     2255     1661
tarsus_hatchdate    -0.04      0.06    -0.16     0.07 1.00     2115     1404
back_sexMale         0.01      0.07    -0.12     0.15 1.00     3657     1421
back_sexUNK          0.15      0.15    -0.16     0.45 1.00     3231     1507
back_hatchdate      -0.09      0.05    -0.19     0.01 1.00     3022     1794

Family Specific Parameters: 
             Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sigma_tarsus     0.76      0.02     0.72     0.80 1.00     2351     1463
sigma_back       0.90      0.02     0.86     0.95 1.00     2271     1520

Residual Correlations: 
                    Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS
rescor(tarsus,back)    -0.05      0.04    -0.13     0.02 1.01     3044
                    Tail_ESS
rescor(tarsus,back)     1172

Draws were sampled using sample(hmc). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).

The model summaries are very similar: every point estimate agrees to within 0.01, and the Rhat values are all 1.00 or 1.01.
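To put "similar" on a scale, the differences between the two fits can be expressed in units of the posterior standard deviation (Est.Error). A small worked check, using only the population-level numbers transcribed from the two summaries above:

```r
# Population-level estimates transcribed from the two summaries above
est_brms  <- c(tarsus_Intercept = -0.41, back_Intercept = -0.01,
               tarsus_sexMale   =  0.77, tarsus_sexUNK  =  0.23,
               tarsus_hatchdate = -0.04, back_sexMale   =  0.01,
               back_sexUNK      =  0.15, back_hatchdate = -0.09)
est_chkpt <- c(tarsus_Intercept = -0.40, back_Intercept = -0.01,
               tarsus_sexMale   =  0.77, tarsus_sexUNK  =  0.23,
               tarsus_hatchdate = -0.04, back_sexMale   =  0.01,
               back_sexUNK      =  0.15, back_hatchdate = -0.09)

# Est.Error (posterior SD) from the brms fit, same order as above
se_brms <- c(0.07, 0.07, 0.06, 0.13, 0.06, 0.07, 0.15, 0.05)

# Absolute difference in point estimates
abs_diff <- abs(est_brms - est_chkpt)

# Same differences in posterior-SD units
round(abs_diff / se_brms, 2)
```

Only tarsus_Intercept differs, by about 0.14 posterior SDs, which is well within ordinary Monte Carlo variation between two runs. With the fitted objects, the same comparison could be made directly from fixef(fit1)[, "Estimate"] and fixef(fit1_chkpt)[, "Estimate"] instead of transcribing numbers.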

Separately, I don't know whether you've noticed this, but I've now seen it several times: calling summary() on a model fit with chkpt_brms reports the wrong number of iterations. It shows the number of iterations in a single checkpoint rather than the total. Compare the headers above:

  Draws: 2 chains, each with iter = 2000; warmup = 1000; thin = 1;
         total post-warmup draws = 2000

and

  Draws: 2 chains, each with iter = 250; warmup = 0; thin = 1;
         total post-warmup draws = 500

However, when comparing dim(posterior_samples(fit1)) with dim(posterior_samples(fit1_chkpt)), both have 2000 rows (draws). The number of columns differs (I'm not sure why), but the number of draws is correct.
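The arithmetic behind the mismatch can be laid out from the chkpt_brms() settings in the test code. This is a sketch of the bookkeeping only, assuming (as the headers suggest) that summary() is describing a single checkpoint of iter_per_chkpt sampling iterations rather than the whole run:

```r
# Settings used in the chkpt_brms() call above
chains         <- 2
iter_warmup    <- 1000
iter_sampling  <- 1000
iter_per_chkpt <- 250

# Post-warmup draws actually stored, summed over chains
total_draws <- chains * iter_sampling

# Header values consistent with summary() describing one checkpoint
header_iter  <- iter_per_chkpt
header_draws <- chains * iter_per_chkpt

c(total_draws = total_draws,
  header_iter = header_iter,
  header_draws = header_draws)
```

So 2000 matches the stored draws, while iter = 250 and 500 total draws match one checkpoint. In recent brms versions, ndraws(fit1_chkpt) gives the stored draw count directly, independent of what the summary header says.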

Anyway, I see no problem using multivariate models, but please let me know if I've missed something. I'm about to queue up several :)