rmcelreath / rethinking

Statistical Rethinking course and book package

How to marginalize over random effects with `ulam()` model? #352

Open wzbillings opened 2 years ago

wzbillings commented 2 years ago

When I use map2stan(), I can easily use link() to get marginal predictions without random effects by replacing part of the data with 0.

However, when I do the same thing using a model fitted with ulam(), I get a "non-conformable array" error. (Error in alpha[TRUE, species[i]] + beta * sepal_width[i] : non-conformable arrays)
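For what it's worth, the error message itself can be reproduced in base R without fitting anything. The sketch below is only a guess at the mechanism, not something confirmed from the rethinking source: it assumes the posterior draws for `alpha` are stored as a draws-by-species matrix and the draws for `beta` as a 1-d array, and all values are made up.

```r
# Made-up stand-ins for posterior draws (4 draws, 3 species).
alpha <- matrix(rnorm(12), nrow = 4, ncol = 3)  # assumed layout: draws x species
beta  <- array(rnorm(4), dim = 4)               # assumed: 1-d array of draws

# Indexing column 0 keeps the rows but returns zero columns.
zero_col <- alpha[TRUE, 0]
dim(zero_col)

# Adding a 4 x 0 matrix to a 1-d array of length 4 fails, because the two
# arrays do not have matching dimensions.
msg <- tryCatch(
    zero_col + beta * 3.5,
    error = function(e) conditionMessage(e)
)
print(msg)
```

If that is what happens inside link(), then a species index of 0 can never select an intercept column, which would explain why the trick that works with map2stan() fails here.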

Is there a way to get the same marginal predictions from a model fitted with ulam()?

I've included a small example with a random intercept below. When I run the example on my machine (Windows 10, rethinking 2.21, latest versions of rstan and cmdstanr), the link code works for the map2stan model but not for the ulam model.

# Create fake data with iris to use for fitting
test_dat <-
    iris[,c("Species", "Sepal.Width", "Petal.Width")] |>
    `names<-`(c("species", "sepal_width", "petal_width")) |>
    as.list()

# Data for marginalizing over random effects: replace Species with 0
marginalizing_data <- test_dat
marginalizing_data$species <- 0

# Fit model with ulam
test_model <- ulam(
    data = test_dat,
    flist = alist(
        petal_width ~ dnorm(mu, sigma),
        mu <- alpha[species] + beta * sepal_width,
        alpha[species] ~ dnorm(mu_alpha, gamma),
        mu_alpha ~ dnorm(0, 10),
        gamma ~ dcauchy(0, 2),
        beta ~ dnorm(0, 10),
        sigma ~ dcauchy(0, 2)
    ),
    constraints = list(
        sigma = "lower=0",
        gamma = "lower=0"
    ),
    seed = 12345,
    chains = 4
)

# Fit model with map2stan
test_model2 <- map2stan(
    data = test_dat,
    flist = alist(
        petal_width ~ dnorm(mu, sigma),
        mu <- alpha[species] + beta * sepal_width,
        alpha[species] ~ dnorm(mu_alpha, gamma),
        mu_alpha ~ dnorm(0, 10),
        gamma ~ dcauchy(0, 2),
        beta ~ dnorm(0, 10),
        sigma ~ dcauchy(0, 2)
    ),
    constraints = list(
        sigma = "lower=0",
        gamma = "lower=0"
    ),
    rng_seed = 12345,
    chains = 4
)

# Try getting marginal predictions with both
link_ulam <- link(test_model,  data = marginalizing_data)
link_m2s  <- link(test_model2, data = marginalizing_data)

To get the marginal effects, I tried to replicate what link() appears to do internally: I manually modified my posterior samples and then passed the modified samples to link(), as follows.

post_samples <- extract.samples(test_model)
post_samples$mu_alpha <- array(0, dim = dim(post_samples$mu_alpha))
link_ulam <- link(test_model, data = test_dat, post = post_samples)

Will this work to get marginal predictions without random effects?
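An alternative I am considering is to skip link() and compute the linear predictor by hand from the posterior draws, simulating a fresh species intercept for each draw. The code below is a minimal sketch of that idea, not a confirmed rethinking workflow: `post` is a simulated stand-in for the list I would expect from extract.samples(test_model), the element names are assumed to match the model above, and the numbers are invented.

```r
set.seed(1)
n_draws <- 1000

# Stand-in for extract.samples(test_model); values are made up for illustration.
post <- list(
    mu_alpha = rnorm(n_draws, 0.5, 0.1),
    gamma    = abs(rnorm(n_draws, 0.8, 0.1)),
    beta     = rnorm(n_draws, 0.2, 0.05)
)

# Predictor values at which to evaluate the marginal prediction.
sepal_width_seq <- seq(2, 4.5, length.out = 20)

# One simulated intercept per draw for a new, unobserved species.
alpha_new <- rnorm(n_draws, as.numeric(post$mu_alpha), as.numeric(post$gamma))

# Linear predictor for the new species: rows are draws, columns are
# sepal widths.
mu_marginal <- sapply(
    sepal_width_seq,
    function(x) alpha_new + as.numeric(post$beta) * x
)

# Summaries across draws at each sepal width.
mu_mean <- colMeans(mu_marginal)
mu_ci   <- apply(mu_marginal, 2, quantile, probs = c(0.055, 0.945))
```

Replacing `alpha_new` with `post$mu_alpha` would instead give predictions for an "average" species, which is a different quantity from the marginal over species.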