mjskay / tidybayes

Bayesian analysis + tidy data + geoms (R package)
http://mjskay.github.io/tidybayes
GNU General Public License v3.0

Recover higher-level groups with gather_draws() and a nested model #269

Open BKoblinger opened 4 years ago

BKoblinger commented 4 years ago

Hi! I learned about tidybayes during StanCon and I am very pleased with what I am finding.

I am wondering if there is a tidier way to use gather_draws() with a nested model akin to x_at_y(). I think an example would be most clear.

In my model, each Prog belongs to one of three Pillars. I am very pleased with how easily I can get my data list for Stan, especially with x_at_y().

data <- data %>% select(y, Treatment, Pillar, Prog)
sdata <- data %>% compose_data(Pillar = x_at_y(Pillar, Prog))

I am then also happy to get the draws out of the stan model using gather_draws.

draws <- fit %>% 
  recover_types(data) %>% 
  gather_draws(d_theta[Prog]) %>% 
  inner_join(data %>% select(Prog, Pillar) %>% distinct(), "Prog")

The only thing that isn't so "tidy" is the inner_join() and what I did within it. gather_draws() and recover_types(data) give me Prog, but I also want the higher-level Pillar, hence the join.

It gets a bit ugly if I try to also include the top-level hyperparameters in a single call to gather_draws():

draws <- fit %>% 
  recover_types(data) %>% 
  gather_draws(d_theta[Prog], d_theta_mu[Pillar]) %>% 
  left_join(data %>% transmute(Prog, pillar2 = Pillar) %>% distinct(), "Prog") %>% 
  mutate(Pillar = coalesce(Pillar, pillar2)) %>% 
  select(-pillar2)

gather_draws(d_theta_mu[Pillar]) produces a Pillar column, but that column is NA in the rows that come from d_theta[Prog], so I do the ugly join and then coalesce.

May I kindly ask: is there a tidier way to do this?

mjskay commented 4 years ago

This is a really good question. It does seem unsatisfying to not have a counterpart for x_at_y() for use "on the way back out". I'd want to think carefully about what that API should look like --- there might be room for a more general version that allows you to specify lookup vectors or joins that solves your current problem while also supporting some other use cases (more on that later). I'd prefer that if possible, as the spread/gather_draws() syntax is already a bit overloaded.

In the absence of that easier method already existing, your approach with joins is close to the simplest thing I can think of off the top of my head. You could also try using the lookup vector created in sdata instead of doing a join, which should be faster. It would probably look something like this (not tested):

draws <- fit %>% 
  recover_types(data) %>% 
  gather_draws(d_theta[Prog]) %>% 
  mutate(Pillar = sdata$Pillar[Prog])

Or if Pillar is a factor, you could do this to recover the level names (again not tested):

draws <- fit %>% 
  recover_types(data) %>% 
  gather_draws(d_theta[Prog]) %>% 
  mutate(Pillar = factor(sdata$Pillar[Prog], labels = levels(data$Pillar)))

I think that should work but let me know if it doesn't.
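
For readers who want to try the lookup-vector idea without fitting a model, here is a minimal base-R sketch. The toy `data` and the hand-built `lookup` vector are assumptions standing in for the real data and for `sdata$Pillar` (one Pillar index per Prog level); `draws_prog` is a hypothetical stand-in for the Prog column of the gathered draws.

```r
# Toy stand-in for the original data: each Prog belongs to exactly one Pillar
data <- data.frame(
  Pillar = factor(c("Pillar1", "Pillar1", "Pillar2", "Pillar2")),
  Prog   = factor(c("ProgA", "ProgB", "ProgX", "ProgX"))
)

# Hand-built stand-in for sdata$Pillar: the Pillar code of each Prog level,
# in the order of levels(data$Prog)
lookup <- as.integer(data$Pillar)[match(levels(data$Prog), data$Prog)]

# Hypothetical Prog column as it might appear in the gathered draws
draws_prog <- factor(c("ProgX", "ProgA", "ProgB"), levels = levels(data$Prog))

# Index the lookup vector by the integer codes of Prog, then map the
# resulting Pillar codes back to their level names
pillar <- factor(
  lookup[as.integer(draws_prog)],
  levels = seq_along(levels(data$Pillar)),
  labels = levels(data$Pillar)
)
print(data.frame(Prog = draws_prog, Pillar = pillar))
```

Passing `levels` as well as `labels` keeps the code-to-name mapping fixed even when only some Pillars appear in the indexed vector.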

The future...

As to the question of some kind of general solution to this problem... I usually try to imagine what the syntax should look like first, then sort out the details later :). In that spirit, one option might be to think about allowing arbitrary expressions to determine the index value to join on, so maybe allowing syntax like this:

draws <- fit %>% 
  recover_types(data) %>% 
  gather_draws(d_theta[Prog], d_theta_mu[Pillar = sdata$Pillar[Prog]])

However, that might be hard to figure out automatically. One option would be to turn the derived index (Pillar = sdata$Pillar[Prog]) into a mutate on the previously-generated tables (i.e. the internal output of tidying d_theta[Prog]) before the internal join between the outputs of tidying both parameters, but that could be a fairly opaque / "magical" solution that would be hard for users to reason about.

Maybe a better option is to explicitly allow the construction of such a "derived" index that the join could then be made on, so something like this:

draws <- fit %>% 
  recover_types(data) %>% 
  gather_draws(d_theta[Prog, Pillar = sdata$Pillar[Prog]], d_theta_mu[Pillar])

Where a named index is simply an additional column added to the output evaluated as an expression of columns in the tidied table and the user's environment (thus allowing you to access both Prog and sdata). Then the Pillar column would (1) automatically get its levels back-translated in the same way as usual and (2) be available internally for a join against d_theta_mu's Pillar index. Basically, a bit more explicit version of the previous solution. The downside is it kind of breaks the "index" semantics implied by the use of [ ] and which I like about spread/gather_draws.
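
To make that evaluation rule concrete, here is a minimal base-R sketch of the idea, not of anything tidybayes actually implements. `add_derived_index()` is a hypothetical helper, and `lookup` and `tidied` are toy stand-ins for `sdata$Pillar` and for the internal table produced by tidying `d_theta[Prog]`.

```r
# Hypothetical helper: evaluate `expr` with the columns of `tidied` in scope,
# falling back to the caller's environment for everything else
add_derived_index <- function(tidied, name, expr, env = parent.frame()) {
  tidied[[name]] <- eval(expr, envir = tidied, enclos = env)
  tidied
}

# Toy stand-ins: a Pillar code per Prog index, and a tidied table of draws
lookup <- c(1L, 1L, 2L)
tidied <- data.frame(Prog = c(3L, 1L, 2L), .value = c(0.2, -0.1, 0.4))

# `Prog` is found in the table, `lookup` in the calling environment
out <- add_derived_index(tidied, "Pillar", quote(lookup[Prog]))
print(out)
```

The resulting Pillar column could then serve as a join key against another table that is indexed by Pillar.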

Anyway, I think something like this could solve your problem if implemented, but is also verging a bit on "too clever by half" territory, so I'd want to think a bit before committing to implementing it. Would be curious to hear your thoughts.

BKoblinger commented 4 years ago

Thanks for the response. If thought is going to go into this, I think it is worthwhile to have a working example.

library(dplyr)
library(tidyr)
library(tidybayes)
library(rstan)

set.seed(1425)
data <- tribble(
  ~Pillar, ~Prog, ~y,
  "Pillar1", "ProgA", rbinom(50, 1, 0.5),
  "Pillar1", "ProgB", rbinom(40, 1, 0.4),
  "Pillar1", "ProgC", rbinom(30, 1, 0.3),
  "Pillar2", "ProgX", rbinom(20, 1, 0.2),
  "Pillar2", "ProgY", rbinom(10, 1, 0.1)) %>% 
  unnest(y) %>% 
  mutate(across(where(is.character), as.factor))
data %>% group_by(Prog) %>% summarise(p = sum(y) / n(), q = qnorm(p))

sdata <- data %>% compose_data(Pillar = x_at_y(Pillar, Prog))
print(sdata)

smodel <- "
data {
  int<lower=0> n;
  int<lower=0, upper=1> y[n];
  int n_Prog;
  int<lower=1, upper=n_Prog> Prog[n];
  int n_Pillar;
  int<lower=1, upper=n_Pillar> Pillar[n_Prog];
}
parameters {
  real a[n_Prog];
  real mu_a[n_Pillar];
  real<lower=0> sigma_a[n_Pillar];
}
model {
  for (i in 1:n) {
    y[i] ~ bernoulli(Phi(a[Prog[i]]));
  }
  for (pr in 1:n_Prog) {
    a[pr] ~ normal(mu_a[Pillar[pr]], sigma_a[Pillar[pr]]);
  }
  for (pi in 1:n_Pillar) {
    mu_a[pi] ~ normal(0, 1);
    sigma_a[pi] ~ exponential(1);
  }
}
"

fit <- stan(model_code = smodel, data = sdata)

# draws does not contain Pillar for the parameter "a"
draws <- fit %>% 
  recover_types(data) %>% 
  gather_draws(a[Prog], mu_a[Pillar])

BKoblinger commented 4 years ago

Based on your suggestion, the solution to my immediate problem is (noting that the groups "get in the way"):

# draws can be wrangled to contain "Pillar"
draws <- fit %>% 
  recover_types(data) %>% 
  gather_draws(a[Prog]) %>% 
  ungroup(Prog) %>% 
  mutate(Pillar = factor(sdata$Pillar[Prog], labels = levels(data$Pillar))) %>% 
  group_by(Prog, Pillar, .add = TRUE)
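
One plausible reason the groups "get in the way" is that a grouped mutate() evaluates the factor() call once per group, and within a single Prog group only one Pillar code occurs. The base-R sketch below reproduces that situation under this assumption; `pillar_levels` and `codes_in_one_group` are toy values, not output from the model.

```r
pillar_levels <- c("Pillar1", "Pillar2")
codes_in_one_group <- c(1L, 1L, 1L)  # a group in which only Pillar1 occurs

# labels only: factor() infers a single level from the data,
# so supplying two labels is rejected with an error
labels_only <- tryCatch(
  factor(codes_in_one_group, labels = pillar_levels),
  error = function(e) "error"
)

# levels + labels: the mapping no longer depends on which codes are present
with_levels <- factor(
  codes_in_one_group,
  levels = seq_along(pillar_levels),
  labels = pillar_levels
)
print(with_levels)
```

If that is indeed the cause, supplying `levels = seq_along(levels(data$Pillar))` inside the mutate() may make the ungroup()/group_by() round trip unnecessary, although the result would then not be grouped by Pillar.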

BKoblinger commented 4 years ago

I have been thinking about the syntax. Please note that the simplified example has the parameters a indexed by Prog, and the hyper-parameters mu_a and sigma_a indexed by Pillar. The a parameters have a shared prior a ~ normal(mu_a, sigma_a). Apologies for any confusion.

I originally thought about gather_draws(a[Prog[Pillar]]) or gather_draws(a[Pillar[Prog]]), but find these cryptic or illogical.

I find your idea good:

draws <- fit %>% 
  recover_types(data) %>% 
  gather_draws(d_theta[Prog, Pillar = sdata$Pillar[Prog]], d_theta_mu[Pillar])

It is clear (to me) that Pillar is being joined/added to d_theta[Prog].

I wonder if a separate function would be more in line with the tidyverse philosophy. The previous looks a bit more like data.table syntax. I have yet to come up with a decent idea for the name of the function, or even the function definition. Could the name involve join? index? Or is it just mutate? Could x_at_y() be reused here (not really, I guess)? Perhaps a decent idea will come to me. I guess I imagine something similar to the current solution, but simpler:

draws <- fit %>%
  recover_types(data) %>%
  gather_draws(a[Prog]) %>%
  mutate(Pillar = Pillar[Prog])

mjskay commented 4 years ago

Thanks, this is really helpful!

I agree re: the potential crypticness / non-tidyverse style of this. On the other hand, the indexing syntax is already a little different from tidyverse style. I think the idea of an additional function is a good one; one thing I've been thinking about is adding a kind of alternative tidyverse-like syntax that can be used to create specs inside of spread/gather_draws, in the way that across() can be used inside mutate() in the most recent versions of dplyr. That might be a way to introduce something more like the x_at_y() syntax.