stan-dev / posterior

The posterior R package
https://mc-stan.org/posterior/

Weighted rvars #331

Open mjskay opened 10 months ago

mjskay commented 10 months ago

Summary

This PR aims to address (at least part of) #184 by implementing weighted rvars.

Currently, rvars cannot contain weights, and weighting of them can only be done by putting them in a draws_rvars object that itself contains a ".log_weight" rvar containing the weights. This leads to counterintuitive behavior, like the default output of the rvar (showing mean and sd) using unweighted versions of those statistics.

This PR addresses that issue in the following ways:

Demo

library(posterior)

set.seed(1234)
x = rvar(rnorm(1000))
x
#> rvar<1000>[1] mean ± sd:
#> [1] -0.027 ± 1

w1 = rexp(1000)
x1 = weight_draws(x, w1)
x1
#> weighted rvar<1000>[1] mean ± sd:
#> [1] -0.00087 ± 1

w2 = rexp(1000)
x2 = weight_draws(x, w2)
x2
#> weighted rvar<1000>[1] mean ± sd:
#> [1] -0.003 ± 0.96
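For readers who want to see what a weighted mean and sd amount to, here is a minimal base-R sketch. It is an illustration under assumptions, not the PR's internals: the names `draws`, `log_w`, `w_norm`, `w_mean` and `w_sd` are made up for this example, the weights are normalized on the log scale, and the sd shown is the plain weighted (population-style) version, which may differ from whatever correction the package applies.

```r
set.seed(1234)
draws <- rnorm(1000)   # stand-in for the draws of an rvar
w <- rexp(1000)        # raw, unnormalized weights

# Normalize on the log scale with a log-sum-exp step for numerical stability
log_w <- log(w)
m <- max(log_w)
log_norm <- m + log(sum(exp(log_w - m)))
w_norm <- exp(log_w - log_norm)   # normalized weights, summing to 1

# Weighted mean and (uncorrected) weighted standard deviation
w_mean <- sum(w_norm * draws)
w_sd <- sqrt(sum(w_norm * (draws - w_mean)^2))

c(mean = w_mean, sd = w_sd)
```

The weighted mean here agrees with `stats::weighted.mean(draws, w)`; the log-scale normalization only matters when the raw weights span many orders of magnitude.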

You can't combine two rvars with different weights:

x1 + x2
#> Error: Random variables have different log weights and cannot be used together:
#> <dbl> 0.794199981930473, -1.61888922585584, 1.02558084358998, -0.657945687118312, 0.132635682996154 ...
#> <dbl> 0.661721670766407, -1.46589074644228, -1.39312536919089, 0.318133129307739, 0.66043235310858 ...

The check for equality of weights is done on the internal weights using identical(), which should be fast, especially when the two weight vectors are actually pointers to the same vector in memory (in which case the comparison is constant time). This does mean the weight vectors must be exactly the same (no tolerance for floating point error), but I suspect that in most cases where weighting happens, the exact same weight vector is being applied to many objects. In any case, if someone did encounter this issue, they could simply assign the log weights from one object to the other.
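To make the "no tolerance" point concrete, here is a small base-R sketch comparing `identical()` with `all.equal()` on plain numeric vectors. The vectors `log_w`, `same` and `perturbed` are hypothetical stand-ins for internal log-weight vectors; nothing here touches the rvar internals.

```r
log_w <- log(c(0.2, 0.3, 0.5))
same <- log_w                 # a plain copy of the same values
perturbed <- log_w + 1e-12    # differs only by tiny floating point noise

identical(log_w, same)               # TRUE: exact match
identical(log_w, perturbed)          # FALSE: no tolerance
isTRUE(all.equal(log_w, perturbed))  # TRUE: equal within default tolerance

# The workaround mentioned above: reuse one weight vector for both objects
perturbed <- log_w
identical(log_w, perturbed)          # TRUE
```

So two weight vectors that are "the same" up to rounding would still be rejected, and sharing a single vector avoids the problem.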

If one rvar is weighted and another is not, the weights of the weighted rvar are inherited, which I believe covers the use case of (weighted draws from some model) + (unweighted draws, e.g. used to simulate predictions):

x1 + rvar(rnorm(1000, 1))
#> weighted rvar<1000>[1] mean ± sd:
#> [1] 0.96 ± 1.4
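The combination rules described above (error when both sides carry different weights, inherit when only one side is weighted) can be sketched as a small helper. This is a hypothetical illustration: `combine_log_weights()` is not a function in posterior, and representing "unweighted" as `NULL` is an assumption made for the sketch.

```r
# Hypothetical helper: decide which log weights the result of a binary
# operation should carry. NULL stands for "unweighted" (an assumption).
combine_log_weights <- function(a, b) {
  if (is.null(a)) return(b)   # only b is weighted (or neither is)
  if (is.null(b)) return(a)   # only a is weighted
  if (!identical(a, b)) {
    stop("Random variables have different log weights and cannot be used together")
  }
  a                           # both weighted with exactly the same weights
}

lw <- log(c(0.2, 0.3, 0.5))
combine_log_weights(lw, NULL)    # inherits lw
combine_log_weights(NULL, NULL)  # NULL: result stays unweighted
```

Calling `combine_log_weights(lw, lw + 1)` raises the error, mirroring the `x1 + x2` case in the demo.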

If you install the dev version of {ggdist}:

remotes::install_github("mjskay/ggdist")

It supports weighted rvars in all functions (densities, CDFs, quantiles, all interval types and all point summaries):

Without weights:

library(ggplot2)
library(ggdist)

set.seed(1234)
x = rvar(rnorm(10000, c(1,5)))

ggplot() + stat_slabinterval(aes(xdist = x))

image

With weights:

xw = weight_draws(x, rep(c(1,2), 5000))
ggplot() + stat_slabinterval(aes(xdist = xw))

image
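As a rough sanity check on what the weights in this example should do, here is a base-R sketch that skips rvars entirely. Because `rnorm()` recycles the means `c(1, 5)` and `rep(c(1, 2), 5000)` alternates in step with them, the draws centered at 5 get twice the weight of the draws centered at 1, so the weighted mean should move from about 3 toward (1*1 + 2*5)/3 = 11/3. The names `draws` and `w` are just for this illustration.

```r
set.seed(1234)
draws <- rnorm(10000, c(1, 5))   # means alternate 1, 5, 1, 5, ...
w <- rep(c(1, 2), 5000)          # weights alternate 1, 2, 1, 2, ...

mean(draws)               # unweighted: close to 3
weighted.mean(draws, w)   # weighted: close to 11/3

# Share of total weight on each mixture component
tapply(w, rep(c("mean 1", "mean 5"), 5000), sum) / sum(w)
```

This matches what the weighted plot should show: the mode near 5 carrying about two thirds of the mass.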

Weights should work basically everywhere:

ggplot() + 
  stat_slabinterval(
    aes(xdist = xw), 
    point_interval = mode_hdi, 
    density = "histogram", 
    breaks = 50
  )

image

TODOs and Questions

TODOs:

Questions:

Would love for folks to kick the tires. I think once this is in we could also start thinking about what a successor to summarise_draws() might look like that supports weights (and solves the various other open issues on summarise_draws()).

Copyright and Licensing

By submitting this pull request, the copyright holder is agreeing to license the submitted work under the following licenses:

codecov-commenter commented 10 months ago

Codecov Report

Attention: Patch coverage is 98.89299%, with 3 lines in your changes missing coverage. Please review.

Project coverage is 95.80%. Comparing base (c312846) to head (88fa83d). Report is 12 commits behind head on master.

:exclamation: Current head 88fa83d differs from pull request most recent head 1079cef. Consider uploading reports for the commit 1079cef to get more accurate results

| Files | Patch % | Lines |
|---|---|---|
| R/rvar-.R | 94.87% | 2 Missing :warning: |
| R/weighted.R | 98.07% | 1 Missing :warning: |

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##           master     #331      +/-   ##
==========================================
+ Coverage   95.31%   95.80%   +0.49%
==========================================
  Files          50       51       +1
  Lines        3840     3979     +139
==========================================
+ Hits         3660     3812     +152
+ Misses        180      167      -13
```


github-actions[bot] commented 10 months ago

This is how benchmark results would change (along with a 95% confidence interval in relative change) if bdef35c34867a28c7e956da43881bffde8bda583 is merged into master:

avehtari commented 10 months ago

I don't know if any of the other functions in R/convergence.R should be modified for weighted rvars.

Currently, everything other than the pareto_ functions assumes non-weighted MCMC. I have so far assumed that MCMC and weighting are independent of each other (there might be some less common algorithms that jointly sample parameter values and weights).

Pinging @n-kall, too

n-kall commented 10 months ago

  • If there are both (assumed) Markov dependency and weights, we could follow the approach presented in the PSIS paper for the ess_ and mcse_ functions

For reference: Equations 6 (MCSE) and 7 (ESS) in preprint v6

  • pareto_ functions check the tail(s) of a given argument, and they have been used to check the tails of raw weights/ratios (r or r(theta) in PSIS paper notation), a function of a variable (h or h(theta)), or the product (hr). With weight support, they could automatically compute the diagnostics for r and hr (and, if there are no weights, just h). Here I'm assuming that we almost always use self-normalization, so we need to check both the normalization (E[r]) and the quantity of interest (E[hr]).

Any thoughts on how the two sets of diagnostics should be presented in summarise_draws? Would it make sense to have separate columns, e.g. pareto_khat_quantity and pareto_khat_weights?
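For anyone following along, here is a minimal base-R sketch of the quantities in the notation above: the ratios r, the function of interest h, the product hr, and the self-normalized estimate. It is an illustration under assumptions, not the proposed implementation. The proposal and target distributions are toy choices, and the effective sample size shown is the common Kish-style approximation 1/sum(w^2) for independent draws, which is not necessarily the equation referenced in the preprint and ignores any Markov dependency.

```r
set.seed(1)
S <- 4000
theta <- rnorm(S)   # draws from a toy proposal, N(0, 1)

# Log importance ratios for a toy target, N(0.5, 1)
log_r <- dnorm(theta, 0.5, 1, log = TRUE) - dnorm(theta, 0, 1, log = TRUE)
r <- exp(log_r - max(log_r))   # ratios, rescaled for numerical stability
h <- theta                     # quantity of interest h(theta)
hr <- h * r                    # the product whose tail would also be checked

# Self-normalized importance sampling estimate of E[h] under the target
w <- r / sum(r)
est <- sum(w * h)

# Kish-style effective sample size (assumes independent draws)
ess_kish <- 1 / sum(w^2)

c(estimate = est, ess = ess_kish)
```

In this setup the two sets of diagnostics would correspond to the tails of `r` (the normalization) and of `hr` (the quantity of interest).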

n-kall commented 10 months ago

I'm currently working on updating the pareto_, ess_ and mcse_ functions for weighted rvars in a fork

mjskay commented 9 months ago

Okay, I think this is ready for review, pending two things:

  1. Do we want to add an attribute distinguishing MC vs MCMC to this PR? I would suggest we wait and do that as a separate PR to address #239, as it will involve further discussion / thought, and this PR is already large.
  2. Do we want to merge this PR to master and then merge @n-kall's branch on weighted diagnostics to master, or do we want to merge @n-kall's branch into this PR and then merge to master?

n-kall commented 9 months ago

Re: option 2, I'll still need some more time to finish up the weighted mcse and ess, so I think it's better to merge without waiting for me.

paul-buerkner commented 9 months ago

I am at a conference and then on vacation for the next two weeks. Can someone else review this PR?