stan-dev / posterior

The posterior R package
https://mc-stan.org/posterior/

Add nested R-hat convergence diagnostic #303

Closed: n-kall closed this 1 year ago

n-kall commented 1 year ago

Summary

This PR adds the nested R-hat convergence diagnostic, which is useful when running many short chains. Chains must be grouped into superchains, which are specified through an additional argument (superchain_ids). This addresses issue #256.

Nested R-hat is described in: Charles C. Margossian, Matthew D. Hoffman, Pavel Sountsov, Lionel Riou-Durand, Aki Vehtari, Andrew Gelman (2022). Nested R̂: Assessing the convergence of Markov chain Monte Carlo when running many short chains. https://arxiv.org/abs/2110.13017
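To make the grouping concrete, here is a minimal base-R sketch of the statistic following the definitions in the paper: between-superchain variance B over within-superchain variance W, where W adds the between-chain variance inside each superchain to the mean within-chain variance, and nested R-hat = sqrt(1 + B / W). The function name rhat_nested_sketch is hypothetical, it works on a plain iterations x chains matrix rather than a draws object, and it assumes equally sized superchains; it is not the posterior implementation added in this PR.

```r
# Hypothetical sketch of nested R-hat (not posterior's rhat_nested()).
# x: iterations x chains matrix of draws for a single variable.
# superchain_ids: one id per chain (column of x); equal-sized superchains assumed.
rhat_nested_sketch <- function(x, superchain_ids) {
  stopifnot(is.matrix(x), length(superchain_ids) == ncol(x))
  groups <- split(seq_len(ncol(x)), superchain_ids)  # chain indices per superchain
  chain_mean <- colMeans(x)
  chain_var <- apply(x, 2, var)  # within-chain variance with 1/(N - 1) scaling
  # With equal chain lengths, a superchain mean is the mean of its chain means
  superchain_mean <- vapply(groups, function(g) mean(chain_mean[g]), numeric(1))
  b <- var(superchain_mean)  # between-superchain variance
  # Within-superchain variance: between-chain variance inside the superchain
  # (zero with a single chain) plus the mean within-chain variance
  w_k <- vapply(groups, function(g) {
    between <- if (length(g) > 1) var(chain_mean[g]) else 0
    between + mean(chain_var[g])
  }, numeric(1))
  w <- mean(w_k)
  sqrt(1 + b / w)
}

set.seed(1)
draws <- matrix(rnorm(100 * 4), nrow = 100, ncol = 4)
rhat_nested_sketch(draws, superchain_ids = c(1, 1, 2, 2))
```

Because b is a variance, the sketch can never return a value below 1, unlike the classic R-hat.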

Status

Work in progress. The functionality seems to work but is untested. I'm opening this draft PR early as a place for discussion and further development.

TODOs:

Example usage


library(posterior)

x <- example_draws()
# chains 1 and 2 form superchain 1; chains 3 and 4 form superchain 2
example_superchain_ids <- c(1, 1, 2, 2)
summarise_draws(x, rhat_nested, .args = list(superchain_ids = example_superchain_ids))
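With many short chains, writing superchain_ids out by hand gets tedious. A minimal sketch, assuming chains are ordered so that consecutive chains share a superchain (as in c(1, 1, 2, 2) above), builds the ids with base R's rep(). The 64-chain, 8-superchain layout is hypothetical:

```r
# Hypothetical layout: 64 chains grouped into 8 superchains of 8 chains each,
# with consecutive chains belonging to the same superchain.
n_superchains <- 8
chains_per_superchain <- 8
superchain_ids <- rep(seq_len(n_superchains), each = chains_per_superchain)
table(superchain_ids)  # each superchain id appears 8 times
```

For a draws object with 64 chains, the resulting vector would then be passed the same way, via .args = list(superchain_ids = superchain_ids).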

Copyright and Licensing

By submitting this pull request, the copyright holder is agreeing to license the submitted work under the following licenses:

codecov-commenter commented 1 year ago

Codecov Report

Merging #303 (afd6cff) into master (e23467b) will decrease coverage by 0.07%. Report is 4 commits behind head on master. The diff coverage is 89.18%.

Note: Current head afd6cff differs from the pull request's most recent head 32f97c4. Consider uploading reports for commit 32f97c4 to get more accurate results.

@@            Coverage Diff             @@
##           master     #303      +/-   ##
==========================================
- Coverage   95.66%   95.60%   -0.07%     
==========================================
  Files          46       47       +1     
  Lines        3645     3682      +37     
==========================================
+ Hits         3487     3520      +33     
- Misses        158      162       +4     
Files              Coverage Δ
R/convergence.R    91.16% <ø> (ø)
R/nested_rhat.R    89.18% <89.18%> (ø)


github-actions[bot] commented 1 year ago

This is how benchmark results would change (along with a 95% confidence interval in relative change) if da6fed76bd51d9dfe9ca4cc6c4aa3db3759ee586 is merged into master:

github-actions[bot] commented 1 year ago

This is how benchmark results would change (along with a 95% confidence interval in relative change) if e1d22ff8e7bc5e0f06c4e98d718892f2d279b024 is merged into master:

github-actions[bot] commented 1 year ago

This is how benchmark results would change (along with a 95% confidence interval in relative change) if e4d63a948f35cf5f0a7bbab702bd6143216548cb is merged into master:

n-kall commented 1 year ago

One thing to check: should rhat_nested with one chain per superchain be exactly equal to rhat_basic? Do we have any reference data to compare against?

Currently rhat_basic gives slightly different values, so perhaps there is an issue in the current implementation.

summarise_draws(example_draws(), rhat_basic, rhat_nested, .args = list(superchain_ids = c(1, 2, 3, 4)))

# A tibble: 10 × 3
   variable   rhat_basic rhat_nested
   <chr>           <dbl>       <dbl>
 1 mu       0.9979105738 1.003389884
 2 tau      1.009976393  1.003445833
 3 theta[1] 1.014966741  1.007488973
 4 theta[2] 0.9981447065 1.002107795
 5 theta[3] 1.000405648  1.008267202
 6 theta[4] 0.9957624905 1.000607206
 7 theta[5] 0.9987923422 1.007260715
 8 theta[6] 0.9982158544 1.002237514
 9 theta[7] 1.002538583  1.003347448
10 theta[8] 0.9933503132 1.003124265
paul-buerkner commented 1 year ago

@avehtari do you want this feature to be included in the new posterior release that is coming out shortly? Just asking so I can set priorities accordingly.

n-kall commented 1 year ago

I've now added input checks for the superchain_ids and NA/Inf values in the chains, and corresponding tests.
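As an illustration only, checks of this kind could look like the hypothetical base-R helper below. validate_nested_rhat_input is not a posterior function, the checks actually added in the PR may differ, and the equal-size check assumes the paper's setup of superchains that all contain the same number of chains.

```r
# Hypothetical input validation for a nested R-hat computation.
# x: iterations x chains matrix; superchain_ids: one id per chain.
# Errors on malformed ids; returns FALSE if any draw is NA/NaN/Inf.
validate_nested_rhat_input <- function(x, superchain_ids) {
  if (length(superchain_ids) != ncol(x)) {
    stop("superchain_ids must have one entry per chain")
  }
  if (anyNA(superchain_ids)) {
    stop("superchain_ids must not contain NA")
  }
  sizes <- as.vector(table(superchain_ids))
  if (length(sizes) < 2) {
    stop("at least two superchains are required")
  }
  if (length(unique(sizes)) != 1) {
    stop("all superchains must contain the same number of chains")
  }
  all(is.finite(x))  # FALSE: caller should report NA instead of a value
}
```

A FALSE return lets the caller report NA rather than a misleading R-hat value.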

paul-buerkner commented 1 year ago

Great! Is there anything else you want to add from your side? Or is this ready for me to check and then merge?

paul-buerkner commented 1 year ago

I have checked, and things look good to me. @n-kall do I have your OK to merge?

github-actions[bot] commented 1 year ago

This is how benchmark results would change (along with a 95% confidence interval in relative change) if 10bdfc52858fddea8c0780ea24cf1da4ce33e791 is merged into master:

n-kall commented 1 year ago

@paul-buerkner please hold off on merging, as Charles might take a look beforehand. We'll let you know when it's ready to merge.

n-kall commented 1 year ago

There's still a discrepancy between rhat_basic (without splitting) and rhat_nested with 1 chain per superchain.

summarise_draws(
  example_draws(),
  rhat_basic_nosplit = ~rhat_basic(.x, split = FALSE),
  rhat_nested = ~rhat_nested(.x, superchain_ids = c(1, 2, 3, 4))
)
# # A tibble: 10 × 3
#    variable rhat_basic_nosplit rhat_nested
#    <chr>                 <dbl>       <dbl>
#  1 mu                  0.99839      1.0034
#  2 tau                 0.99845      1.0034
#  3 theta[1]            1.0025       1.0075
#  4 theta[2]            0.99711      1.0021
#  5 theta[3]            1.0033       1.0083
#  6 theta[4]            0.99560      1.0006
#  7 theta[5]            1.0023       1.0073
#  8 theta[6]            0.99724      1.0022
#  9 theta[7]            0.99835      1.0033
# 10 theta[8]            0.99813      1.0031
n-kall commented 1 year ago


OK, it pays to check the footnotes. Page 7 of Margossian et al. states: "The original R-hat uses a slightly different estimate for the within-chain variance when computing the numerator in R-hat. There W is scaled by 1/N, rather than 1/(N − 1). This explains why occasionally R-hat < 1. This is of little concern when N is large, but we care about the case where N is small, and we therefore adjust the R-hat statistic slightly." So I think this discrepancy is fine.
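For the one-chain-per-superchain case the footnote's point can be made exact. A short derivation, assuming K chains of N draws each, the classic (non-split) R-hat, and the nested R-hat definitions from the paper:

```latex
% K chains, N draws each; \bar{f}_k = mean of chain k, \bar{f} = overall mean,
% s_k^2 = within-chain variance of chain k with 1/(N-1) scaling.
W = \frac{1}{K}\sum_{k=1}^{K} s_k^2, \qquad
B = \frac{N}{K-1}\sum_{k=1}^{K}\left(\bar{f}_k - \bar{f}\right)^2

% Classic (non-split) R-hat: the numerator rescales W from 1/(N-1) to 1/N
\hat{R}^2 = \frac{\frac{N-1}{N}W + \frac{1}{N}B}{W} = \frac{N-1}{N} + \frac{B}{NW}

% Nested R-hat with one chain per superchain: the between-chain term inside
% each superchain is zero, so \hat{W}_\nu = W and \hat{B}_\nu = B/N
\hat{R}_{\text{nested}}^2 = 1 + \frac{\hat{B}_\nu}{\hat{W}_\nu} = 1 + \frac{B}{NW}

\hat{R}_{\text{nested}}^2 - \hat{R}^2 = \frac{1}{N}
```

So the two statistics differ by exactly 1/N on the squared scale. For mu in the table above, 1.0034² − 0.99839² ≈ 0.010, which matches 1/N for the 100 iterations per chain in example_draws().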

github-actions[bot] commented 1 year ago

This is how benchmark results would change (along with a 95% confidence interval in relative change) if 9500ca297e1dbea2c18344fe3802831b66800f1d is merged into master:

github-actions[bot] commented 1 year ago

This is how benchmark results would change (along with a 95% confidence interval in relative change) if 8fce13a0f5ec584816334e3baafa1c8e1c59fb9e is merged into master:

n-kall commented 1 year ago

I've now updated the docs as @charlesm93 suggested and added some further details. I think it's ready to merge.

github-actions[bot] commented 1 year ago

This is how benchmark results would change (along with a 95% confidence interval in relative change) if afd6cffac16f56570d97253b372af30e92110b1d is merged into master:

paul-buerkner commented 1 year ago

Thank you! The failing tests seem to be unrelated to this PR so I will merge it now.