ASKurz / Statistical_Rethinking_with_brms_ggplot2_and_the_tidyverse_2_ed

The bookdown version lives here:
https://bookdown.org/content/4857/
Creative Commons Zero v1.0 Universal

Section 4.4.3.5 Using `nesting` when doing posterior calculations #45

Closed: dipetkov closed this issue 1 year ago.

dipetkov commented 1 year ago

In Section 4.4.3.5.1 Overthinking: Rolling your own predict(), you use nesting(b_Intercept, b_weight_c, sigma) to calculate prediction intervals.

This seems to drop duplicate posterior draws. Is this the intended behavior? Shouldn't we use all the draws from the posterior for inference, assuming we've already checked that there are no convergence issues?

```r
# There are 4,000 samples from the posterior.
post %>%
  nrow()
#> [1] 4000

# But 3,949 unique samples.
post %>%
  select(b_Intercept, b_weight_c, sigma) %>%
  distinct() %>%
  nrow()
#> Warning: Dropping 'draws_df' class as required metadata was removed.
#> [1] 3949

# By nesting we drop the duplicates.
post %>%
  expand(
    nesting(b_Intercept, b_weight_c, sigma)
  ) %>%
  nrow()
#> [1] 3949

# But perhaps we should keep all the draws?
# Or thin the draws... If duplicates remain (very unlikely), keep them.
post %>%
  crossing(
    weight = 50 - mean(Howell1$weight)
  ) %>%
  nrow()
#> [1] 4000
```
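The same behavior can be reproduced without fitting a model. The sketch below uses a small toy tibble whose values are hypothetical (merely shaped like draws from b4.3) and contain one exact repeat: tidyr::nesting() inside expand() keeps only the unique combinations, whereas expand_grid() keeps every row.

```r
library(tibble)
library(tidyr)

# Hypothetical toy "posterior": four draws, the first two identical
toy_post <- tibble(
  b_Intercept = c(154.6, 154.6, 154.5, 154.7),
  b_weight_c  = c(0.90, 0.90, 0.91, 0.89),
  sigma       = c(5.1, 5.1, 5.0, 5.2)
)

# nesting() keeps unique combinations only, so the repeated draw collapses
nrow(expand(toy_post, nesting(b_Intercept, b_weight_c, sigma)))
#> [1] 3

# expand_grid() pairs every row of toy_post with every weight: 4 x 2 rows
nrow(expand_grid(toy_post, weight_c = c(-15, 5)))
#> [1] 8
```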
ASKurz commented 1 year ago

Hmm, I'll have to look at this more closely. This doesn't look good to me. Great catch and great question!

ASKurz commented 1 year ago

A quick fix to tide you over would be something like this:

```r
library(tidyverse)

tibble(x = c(1, 1:3),
       y = c(11, 11:13)) %>%
  mutate(.draw = 1:n()) %>%
  expand(l = letters[1:2],
         nesting(.draw, x, y))
```

```
# A tibble: 8 × 4
  l     .draw     x     y
  <chr> <int> <dbl> <dbl>
1 a         1     1    11
2 a         2     1    11
3 a         3     2    12
4 a         4     3    13
5 b         1     1    11
6 b         2     1    11
7 b         3     2    12
8 b         4     3    13
```

Here the new variable .draw indexes the rows of the original data frame, which makes every (.draw, x, y) combination unique, so nesting() no longer drops the repeated (1, 11) row. I'm calling it .draw because the brms::as_draws_df() function automatically includes a .draw index, so anytime you are working with the results of as_draws_df(), it'll already be there for you. I'll be on the lookout for a more elegant solution, though.

ASKurz commented 1 year ago

Okay, the good people of Twitter (namely the great A. Jordan Nafa) came through (see here). A better solution is to use tidyr::expand_grid().

```r
library(tidyverse)

tibble(x = c(1, 1:3),
       y = c(11, 11:13)) %>%
  expand_grid(l = letters[1:2]) %>%
  # not needed, just added for pedagogical purposes
  arrange(l)
```

```
# A tibble: 8 × 3
      x     y l
  <dbl> <dbl> <chr>
1     1    11 a
2     1    11 a
3     2    12 a
4     3    13 a
5     1    11 b
6     1    11 b
7     2    12 b
8     3    13 b
```
dipetkov commented 1 year ago

But why do you need either nesting or expand_grid in the first place?

This:

```r
as_draws_df(b4.3) %>%
  select(b_Intercept, b_weight_c, sigma)
```

gives you a sample of parameter values from the posterior. That's all you need to calculate $E_{\text{posterior}}(\text{quantity of interest})$.
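Spelled out, the posterior expectation is approximated by an average over the $S$ draws, each carrying weight $1/S$. In this sketch $\theta^{(s)}$ stands for the triplet (b_Intercept, b_weight_c, sigma) from draw $s$ and $f$ for the quantity of interest; the notation is illustrative, not the book's own.

$$
E_{\text{posterior}}\big[f(\theta)\big] = \int f(\theta)\, p(\theta \mid \text{data})\, d\theta \approx \frac{1}{S} \sum_{s=1}^{S} f\big(\theta^{(s)}\big)
$$

Because each draw has weight $1/S$, a value that appears twice should count twice; de-duplicating the draws quietly changes these weights. Exact repeats can occur legitimately in MCMC output, for example when a chain does not move on a given iteration.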

Importantly, the b_Intercept, b_weight_c and sigma values are correlated. Does it really make sense to create a grid b_Intercept × b_weight_c × sigma (of all possible unique triplets)? I don't think so, because not all of these triplets are equally likely under the posterior.
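A point of tidyr semantics that may matter here: nesting() does not build the full grid of unique values. Calling expand() with bare columns does that, while nesting() keeps only the combinations that actually occur together in the data, so the joint (correlated) structure of the draws is retained and the only side effect is de-duplication. A toy illustration with hypothetical values:

```r
library(tibble)
library(tidyr)

# Hypothetical draws in which the two parameters move together
draws <- tibble(
  a = c(1, 2, 3),
  b = c(10, 20, 30)
)

# Bare columns: the full grid of unique values, 3 x 3 = 9 rows,
# including pairs such as (a = 1, b = 30) that never occurred
nrow(expand(draws, a, b))
#> [1] 9

# nesting(): only the combinations present in the data, 3 rows
nrow(expand(draws, nesting(a, b)))
#> [1] 3
```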

I suggest instead:

```r
# Compute compatibility intervals
as_draws_df(b4.3) %>%
  select(b_Intercept, b_weight_c, sigma) %>%
  crossing(
    weight_c = c(30, 40, 50) - mean(Howell1$weight)
  ) %>%
  mutate(
    Ey = b_Intercept + b_weight_c * weight_c
  ) %>%
  group_by(
    weight_c
  ) %>%
  summarise(
    lower = quantile(Ey, .025),
    estimate = mean(Ey),
    upper = quantile(Ey, .975)
  )
```

And:

```r
# Compute prediction intervals
as_draws_df(b4.3) %>%
  select(b_Intercept, b_weight_c, sigma) %>%
  crossing(
    weight_c = c(30, 40, 50) - mean(Howell1$weight)
  ) %>%
  mutate(
    y = rnorm(n(), b_Intercept + b_weight_c * weight_c, sigma)
  ) %>%
  group_by(
    weight_c
  ) %>%
  summarise(
    lower = quantile(y, .025),
    estimate = mean(y),
    upper = quantile(y, .975)
  )
```
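One caveat worth checking with the two pipelines above: tidyr documents crossing() as a wrapper around expand_grid() that de-duplicates and sorts its inputs, and a data frame passed as an input is de-duplicated too. In the first comment crossing() returned all 4,000 rows, presumably because the as_draws_df() output still carried the unique .draw index. After select(b_Intercept, b_weight_c, sigma) that index appears to be gone (the 3,949 distinct rows in the first comment suggest so), in which case the repeated draws would be collapsed again. A toy comparison with hypothetical values:

```r
library(tibble)
library(tidyr)

# Hypothetical draws with one exact repeat (rows 1 and 2)
draws <- tibble(
  b_Intercept = c(154.6, 154.6, 154.5),
  b_weight_c  = c(0.90, 0.90, 0.91),
  sigma       = c(5.1, 5.1, 5.0)
)
new_weight_c <- c(-15, -5, 5)

# crossing() de-duplicates each input: 2 unique draws x 3 weights
nrow(crossing(draws, weight_c = new_weight_c))
#> [1] 6

# expand_grid() keeps every draw: 3 draws x 3 weights
nrow(expand_grid(draws, weight_c = new_weight_c))
#> [1] 9
```

Swapping crossing() for expand_grid(), or keeping .draw among the selected columns, avoids the collapse.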
ASKurz commented 1 year ago

A few things:

dipetkov commented 1 year ago

I had applied the age >= 18 filter but hadn't changed the name of the tibble. I've now rechristened it Howell1_adults. On this point I disagree with McElreath and you: d2 is a bad variable name and makes the code harder to read.

I admit I don't understand how expand_grid would do the job, since it takes name-value pairs rather than a tibble (of posterior draws) as input. But there are often several ways to implement the same logic in the tidyverse! What I'm aiming for is a "by-hand" version of this snippet in rethinking::sim: https://github.com/rmcelreath/rethinking/blob/master/R/sim.r#L190

Thanks a lot for taking a look at this!
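For what it's worth, expand_grid() does accept data frames: as in the expand_grid() example earlier in this thread, an unnamed data frame argument is spliced into its columns, and unlike crossing() nothing is de-duplicated. A sketch of the prediction-interval pipeline built that way follows. The draws are simulated stand-ins for as_draws_df(b4.3), with all values hypothetical, and weight_c = c(-15, -5, 5) approximates the centered weights of 30, 40 and 50 kg shown in the reprex output below.

```r
library(tibble)
library(tidyr)
library(dplyr)

set.seed(4)

# Hypothetical stand-in for as_draws_df(b4.3): 4,000 simulated draws
draws <- tibble(
  b_Intercept = rnorm(4000, 155, 0.3),
  b_weight_c  = rnorm(4000, 0.9, 0.04),
  sigma       = runif(4000, 4.8, 5.4)
)
newdata <- tibble(weight_c = c(-15, -5, 5))

# Both unnamed data frames are unpacked, so every draw is paired with
# every row of newdata: 4,000 x 3 = 12,000 rows, duplicates kept
preds <- expand_grid(draws, newdata) %>%
  mutate(y = rnorm(n(), b_Intercept + b_weight_c * weight_c, sigma)) %>%
  group_by(weight_c) %>%
  summarise(
    lower    = quantile(y, .025),
    estimate = mean(y),
    upper    = quantile(y, .975),
    n_draws  = n()  # should equal the number of draws for every weight_c
  )

preds
```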

dipetkov commented 1 year ago

Came up with an even more elegant solution, to my eyes at least! Use full_join with by = character() to generate all combinations of the posterior draws and the new data. Here is the full reprex:

```r
library("brms")
#> Loading required package: Rcpp
#> Loading 'brms' package (version 2.18.0). Useful instructions
#> can be found by typing help('brms'). A more detailed introduction
#> to the package is available through vignette('brms_overview').
#> 
#> Attaching package: 'brms'
#> The following object is masked from 'package:stats':
#> 
#>     ar
library("tidyverse")

data(Howell1, package = "rethinking")

Howell1_adults <- Howell1 %>%
  filter(
    age >= 18
  ) %>%
  mutate(
    weight_c = weight - mean(weight)
  )

b4.3 <-
  brm(
    data = Howell1_adults,
    family = gaussian,
    height ~ 1 + weight_c,
    prior = c(
      prior(normal(178, 20), class = Intercept),
      prior(lognormal(0, 1), class = b),
      prior(uniform(0, 50), class = sigma, ub = 50)
    ),
    iter = 2000, warmup = 1000, chains = 4, cores = 4,
    seed = 4
  )

newdata <- tibble(
  weight_c = c(30, 40, 50) - mean(Howell1_adults$weight)
)

as_draws_df(b4.3) %>%
  full_join(
    newdata,
    by = character()
  ) %>%
  mutate(
    Ey = b_Intercept + b_weight_c * weight_c
  ) %>%
  group_by(
    weight_c
  ) %>%
  summarise(
    lower = quantile(Ey, .025),
    estimate = mean(Ey),
    upper = quantile(Ey, .975)
  )
#> # A tibble: 3 × 4
#>   weight_c lower estimate upper
#>      <dbl> <dbl>    <dbl> <dbl>
#> 1   -15.0   140.     141.  142.
#> 2    -4.99  149.     150.  151.
#> 3     5.01  158.     159.  160.

as_draws_df(b4.3) %>%
  full_join(
    newdata,
    by = character()
  ) %>%
  mutate(
    y = rnorm(n(), b_Intercept + b_weight_c * weight_c, sigma)
  ) %>%
  group_by(
    weight_c
  ) %>%
  summarise(
    lower = quantile(y, .025),
    estimate = mean(y),
    upper = quantile(y, .975)
  )
#> # A tibble: 3 × 4
#>   weight_c lower estimate upper
#>      <dbl> <dbl>    <dbl> <dbl>
#> 1   -15.0   131.     141.  151.
#> 2    -4.99  140.     150.  160.
#> 3     5.01  149.     159.  169.
```

Created on 2023-01-13 with reprex v2.0.2
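A note for newer setups: as of dplyr 1.1.0, performing a cross join with by = character() is deprecated in favor of the dedicated cross_join() function. A minimal sketch of the same every-draw-by-every-row join, assuming dplyr >= 1.1.0 and using hypothetical toy values:

```r
library(tibble)
library(dplyr)  # cross_join() requires dplyr >= 1.1.0

# Hypothetical draws with one exact repeat, plus the new predictor values
draws <- tibble(
  b_Intercept = c(155, 155, 154.9),
  b_weight_c  = c(0.90, 0.90, 0.91)
)
newdata <- tibble(weight_c = c(-15, -5, 5))

# cross_join() returns every combination of rows (3 draws x 3 values),
# keeping the repeated draw
ey_grid <- cross_join(draws, newdata) %>%
  mutate(Ey = b_Intercept + b_weight_c * weight_c)

nrow(ey_grid)
#> [1] 9
```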