rmcelreath / rethinking

Statistical Rethinking course and book package

ulam fits giving odd errors #430

Status: Open (opened by jebyrnes 3 months ago)

jebyrnes commented 3 months ago

I'm working with ulam a bit to show mixed models in a few contexts. This example:

library(rethinking)
data(reedfrogs)

mod_fullpool <- alist(

  #likelihood
  surv ~ dbinom(density, prob),

  #Data Generating Process
  logit(prob) <- p,

  #Priors
  p ~ dnorm(0,10)
)

fit_fullpool <- ulam(mod_fullpool, data=reedfrogs)
postcheck(fit_fullpool)

however, gives an odd error

Error in pred[[j]][s, ] : subscript out of bounds

If I increase the number of chains, postcheck runs. However, it plots only a single point instead of one point for each row of the reedfrogs data.

Here's my relevant sessionInfo() output:

R version 4.3.2 (2023-10-31)
Platform: aarch64-apple-darwin20 (64-bit)
Running under: macOS Sonoma 14.2.1

Matrix products: default
BLAS:   /Library/Frameworks/R.framework/Versions/4.3-arm64/Resources/lib/libRblas.0.dylib 
LAPACK: /Library/Frameworks/R.framework/Versions/4.3-arm64/Resources/lib/libRlapack.dylib;  LAPACK version 3.11.0

locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8

time zone: America/New_York
tzcode source: internal

attached base packages:
[1] parallel  stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
[1] digest_0.6.35   rethinking_2.40 posterior_1.5.0 cmdstanr_0.7.1 
rmcelreath commented 3 months ago

Some bug specific to binomial outcomes perhaps? I may not have time to figure it out today, but will make an issue of it.

rmcelreath commented 3 months ago

Quick assessment, two things going on:

(1) Not enough samples in the model, hence the subscript error. This is a bug, but it goes away if you sample more chains.
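As a rough illustration of point (1), assuming postcheck picks one row per requested draw out of a matrix of stored posterior predictions: requesting a row index beyond the number of stored draws fails with exactly this message. The matrix below is a stand-in, and `clamp_draws()` is a hypothetical guard, not rethinking code.

```r
# Stand-in for a matrix of linked predictions: 500 stored draws x 3 cases
pred <- matrix(rnorm(500 * 3), nrow = 500, ncol = 3)

# Asking for a draw index beyond the stored rows reproduces the error message
err <- tryCatch(pred[1000, ], error = function(e) conditionMessage(e))
print(err)  # prints "subscript out of bounds"

# Hypothetical guard: never request more draws than were stored
clamp_draws <- function(n_req, n_stored) {
  if (n_req > n_stored) {
    warning(sprintf("only %d draws stored; using those instead of %d",
                    n_stored, n_req))
  }
  min(n_req, n_stored)
}

n_use <- suppressWarnings(clamp_draws(1000, nrow(pred)))
draws_ok <- pred[seq_len(n_use), , drop = FALSE]
```

Sampling more chains (or more iterations) has the same effect as the guard: the stored matrix simply gets enough rows.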

(2) The logit(prob) <- p line generates only a single value, so postcheck generates only a single prediction. You can hack it to work with something like:

logit(prob) <- p + 0*density

But I will need to add special logic to handle this kind of thing in general.
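To see why only one point comes back: when the linear predictor contains no data, it evaluates to a single value, whereas adding `0*density` makes R recycle it to the length of the data. The `expand_to_n()` helper is a hypothetical sketch of the kind of "special logic" mentioned, assuming predictions arrive as a draws-by-cases matrix; it is not how rethinking actually does it, and the `density` values are illustrative, not the real reedfrogs column.

```r
density <- c(10, 10, 25, 35)   # illustrative values, not the real reedfrogs data
p <- 0.3                       # one posterior draw of the intercept

length(plogis(p))                # 1: logit(prob) <- p yields a single value
length(plogis(p + 0 * density))  # 4: the 0*density hack yields one value per case

# Hypothetical fix: if predictions are draws x 1, recycle them to draws x N
expand_to_n <- function(pred, n_cases) {
  pred <- as.matrix(pred)
  if (ncol(pred) == 1L && n_cases > 1L) {
    # matrix() fills column-wise, so every column is a copy of the draws
    pred <- matrix(pred[, 1], nrow = nrow(pred), ncol = n_cases)
  }
  pred
}

pred_scalar <- matrix(plogis(rnorm(5)), ncol = 1)  # 5 draws, 1 column
dim(expand_to_n(pred_scalar, length(density)))     # 5 4
```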

rmcelreath commented 3 months ago

I have been thinking of replacing the old postcheck with Aki's PIT plots as default. But I need to make more time in my life to work through my dev notes.
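For reference, LOO-PIT swaps the posterior predictive distribution in an ordinary PIT for an approximate leave-one-out predictive distribution, obtained by reweighting each draw with (Pareto-smoothed) importance weights. A minimal base-R sketch, assuming you already have a draws-by-cases matrix of log weights `lw` (for example from `loo::psis()` on the negated log-likelihood matrix, then its `weights(..., log = TRUE)` method); `loo_pit()` is a hypothetical helper, not the loo or bayesplot implementation, and it only handles continuous outcomes.

```r
# y: observed outcomes (length N)
# yrep: S x N matrix of posterior predictive draws (rows = draws)
# lw: S x N matrix of log importance weights for leaving out each case
loo_pit <- function(y, yrep, lw) {
  stopifnot(ncol(yrep) == length(y), identical(dim(lw), dim(yrep)))
  vapply(seq_along(y), function(i) {
    w <- exp(lw[, i] - max(lw[, i]))   # subtract max before exponentiating
    w <- w / sum(w)                    # self-normalise the weights
    sum(w * (yrep[, i] <= y[i]))       # weighted predictive CDF at y[i]
  }, numeric(1))
}

# With equal weights this reduces to the ordinary (in-sample) PIT
set.seed(1)
y <- rnorm(20)
yrep <- matrix(rnorm(1000 * 20), nrow = 1000)
lw0 <- matrix(0, nrow = 1000, ncol = 20)
pit <- loo_pit(y, yrep, lw0)
```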

jebyrnes commented 3 months ago

Huh, I had not run across LOO-PIT yet. Damn, this field keeps evolving. I had, however, re-implemented quantile residual checks, much like those in the DHARMa package.

quantile_residuals <- function(fit, n = 1000){
  # get observed y (same approach as rethinking::postcheck)
  lik <- fit@formula[[1]]
  outcome <- as.character(lik[[2]])
  if (inherits(fit, "ulam"))
    outcome <- undot(outcome)
  y <- fit@data[[outcome]]

  # simulate posterior predictions: rows are draws, columns are observations
  s <- sim(fit, n = n)

  # one ecdf per observation, so apply over columns (margin 2), not rows
  dists <- apply(s, 2, ecdf)

  # quantile residual = predictive CDF evaluated at the observed value
  quant_res <- numeric(length(y))
  for(i in seq_along(y)){
    quant_res[i] <- dists[[i]](y[i])
  }

  return(quant_res)
}

This is nice because then you can make QQ plots like so:

data(cars)
flist <- alist(
  dist ~ dnorm( mu , sigma ) ,
  mu <- a+b*speed ,
  c(a,b) ~ dnorm(0,1) , 
  sigma ~ dexp(1)
)
fit <- quap( flist , start=list(a=40,b=0.1,sigma=20) , data=cars )

quantile_residuals(fit) |>
  gap::qqunif(logscale=FALSE)

[Figure: uniform QQ plot of the quantile residuals from the cars model]
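If you'd rather not depend on gap, a base-R uniform QQ plot takes a few lines: sort the residuals and plot them against uniform plotting positions from `ppoints()`. This is a sketch, not a reimplementation of `gap::qqunif()` (no log scale, no confidence band), and `qq_unif_base()` is a hypothetical name.

```r
# Plot sorted quantile residuals against expected uniform quantiles
qq_unif_base <- function(q, ...) {
  q <- q[is.finite(q)]
  expected <- ppoints(length(q))   # (i - a) / (n + 1 - 2a) plotting positions
  observed <- sort(q)
  plot(expected, observed, xlim = c(0, 1), ylim = c(0, 1),
       xlab = "Expected (uniform)", ylab = "Observed quantile residual", ...)
  abline(0, 1, lty = 2)            # reference line for perfect calibration
  invisible(data.frame(expected = expected, observed = observed))
}

# Usage with residuals simulated from a well-calibrated model
set.seed(42)
qq <- qq_unif_base(runif(50))
```

With a fit in hand, the call would be `qq_unif_base(quantile_residuals(fit))`.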

But I'm guessing this will have similar issues for binomial models? And I'd imagine calculating LOO-PITs must be somewhat similar? Need to read up, though.
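On the binomial question: with a discrete outcome the plain ECDF value is not uniform even when the model is correct, because the predictive CDF jumps at each integer. The approach DHARMa describes for integer-valued responses is to randomise within the jump, drawing uniformly between F(y - 1) and F(y). A base-R sketch on simulated draws (no rethinking fit needed); `randomized_pit()` is a hypothetical helper and the binomial setup is made up for the check.

```r
# y: observed counts (length N); yrep: S x N matrix of predictive draws
randomized_pit <- function(y, yrep) {
  stopifnot(ncol(yrep) == length(y))
  vapply(seq_along(y), function(i) {
    lower <- mean(yrep[, i] <  y[i])   # F(y - 1) for integer outcomes
    upper <- mean(yrep[, i] <= y[i])   # F(y)
    runif(1, lower, upper)             # uniform draw within the CDF jump
  }, numeric(1))
}

# Simulated check: binomial data drawn from the same model as the draws
set.seed(7)
N <- 200
size <- rep(c(10, 25, 35), length.out = N)
y <- rbinom(N, size = size, prob = 0.6)
# column i holds 4000 draws with trial count size[i] (column-major fill)
yrep <- matrix(rbinom(4000 * N, size = rep(size, each = 4000), prob = 0.6),
               nrow = 4000)
u_rand <- randomized_pit(y, yrep)
u_plain <- colMeans(sweep(yrep, 2, y, "<="))
```

Under a correct model `u_rand` should look uniform, while `u_plain` is biased upward because it always takes the top of each jump.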

And I suppose if you want to make a hot plot like the ones I now see in bayesplot, well, you could do it with quantile residuals as well:


# plot against unif dists
# draw `samples` replicate sets of `n` uniforms each
get_rep_unifs <- function(samples, n = 100){
  data.frame(
    id = rep(seq_len(samples), each = n),
    unif = runif(samples * n)
  )
}

library(ggplot2)
bayes_qr_plot <- function(fit, n = 1000, n_unif = 100, ...){
  q <- quantile_residuals(fit, n = n)
  # n_unif reference sets, each the same size as the data
  get_rep_unifs(n_unif, length(q)) |>
    ggplot(aes(x = unif, group = id)) +
    geom_density(color = alpha("black", 0.3), ...) +
    geom_density(data = data.frame(unif = q, id = 1),
                 linewidth = 3, ...)
}

bayes_qr_plot(fit)

[Figure: density of the quantile residuals (thick line) over densities of replicate uniform draws]

Which I suppose could serve as the basis of a plotting function. Yes, yes, I know, I prefer ggplot2. I'm very tidy. Or try to be. I'm sure this could be done in base R, but, wow, my base brain is stale.
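For anyone who does want it in base R, here is a sketch of the same overlay: densities of replicate uniform samples (each the size of the data) in translucent grey, with the residual density drawn thick on top. `bayes_qr_plot_base()` is a hypothetical name, and it takes a vector of quantile residuals rather than a fit so it runs without rethinking.

```r
bayes_qr_plot_base <- function(q, n_unif = 100, ...) {
  # reference densities: n_unif samples of uniforms, each the size of the data
  ref <- lapply(seq_len(n_unif), function(k) density(runif(length(q))))
  obs <- density(q)
  ymax <- max(obs$y, vapply(ref, function(d) max(d$y), numeric(1)))
  # empty frame sized to fit every curve
  plot(obs, ylim = c(0, ymax), type = "n", main = "",
       xlab = "Quantile residual", ...)
  for (d in ref) lines(d, col = adjustcolor("black", alpha.f = 0.3))
  lines(obs, lwd = 3)              # observed residual density on top
  invisible(list(observed = obs, reference = ref))
}

set.seed(3)
res <- bayes_qr_plot_base(runif(50))
```

With a fit, the call would be `bayes_qr_plot_base(quantile_residuals(fit))`.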