nimble-dev / nimble

The base NIMBLE package for R
http://R-nimble.org
BSD 3-Clause "New" or "Revised" License

bug in conjugacy processing for a stickbreaking case #1409

Closed: paciorek closed this issue 4 months ago

paciorek commented 8 months ago

In this user example, we don't properly handle the fact that eta has two deterministic calc node dependencies (v and pi) when trying to set up conjugate sampling for eta.

codeModel <- nimbleCode({
  for(i in 1:N) {
    for(j in 1:J) {
      y[i, j] ~ dbern(pp[z[i], j])
    }
    z[i] ~ dcat(pi[i,1:M])
  }
  for(m in 1:M) {
    for(j in 1:J) {
      pp[m, j] ~ dbeta(shape1=1, shape2=1)
    }
  }
  for(m in 1:(M-1)) {
    knot[m] ~ dunif(0,max_x)
    invband[m] ~ dgamma(1,1)
  }
  alpha ~ dgamma(1,1)
  for(i in 1:N){
    for(m in 1:(M-1)){
      eta[i,m] ~ dbeta(1,alpha)
    }
  }
  for(i in 1:N) {
    for(m in 1:(M-1)) {
      v[i,m] <- exp(-0.5*invband[m]*(knot[m]-x[i])^2)*eta[i,m]
    }
    pi[i,1:M] <- stick_breaking(v[i,1:(M-1)])
  }
})

#============================
# Simulate my dataset
#============================
dat.sim <- function(J, K, N, props, ips) {
  trueclass <- sample(1:K, prob=props, size=N, replace=TRUE)
  dat <- matrix(0, nrow=N, ncol=J)
  for(i in 1:N) {
    dat[i,] <- rbinom(n=J, size=1, prob=ips[trueclass[i],])
  }
  dat <- as.data.frame(dat)
  colnames(dat) <- paste0('Y',1:J)
  res <- list(dat = dat, trueclass = trueclass)
  return(res)
}
K <- 3
J <- 6
N <- 400
P <- c(0.5,0.3,0.2)  
pp <- matrix(c(rep(0.9,J), rep(c(0.9,0.1),c(J/2,J/2)), rep(0.1,J)), nrow=K, ncol=J, byrow=TRUE)
set.seed(123)
simdat <- dat.sim(J, K, N, P, pp)
outcome <- simdat$dat
predictor <- rnorm(N,0,1)

N=400; J=6; M=10
constsList <- list(N=400, J=6, M=10, max_x=max(predictor,na.rm=TRUE), x=predictor)
# eta, v, knot and invband are indexed 1:(M-1) in the model, so size their inits to match
initsList <- list(pp=matrix(rbeta(M*J,1,1),M,J),
                  alpha=1, eta=matrix(0,N,M-1),
                  v=matrix(1/M,nrow=N,ncol=M-1),
                  pi=matrix(1/M,nrow=N,ncol=M),
                  invband=rgamma(M-1,1,1),
                  knot=runif(M-1,0,max(predictor,na.rm=TRUE)),
                  z=sample(1:8,size=N,replace=TRUE))
Data <- list(y=as.matrix(outcome))
model <- nimbleModel(code=codeModel, data=Data, inits=initsList, constants=constsList)
cmodel <- compileNimble(model)
conf <- configureMCMC(model, monitors=c('pp'), print=TRUE)
mcmc <- buildMCMC(conf)
# Error: Unexpected error in processing model node: v[1, 2]
# Unexpected error in processing model node: pi[1, 1:10]

At line 628 of MCMC_conjugacy.R, calcNodesDeterm has two elements, which causes problems in getValueExpr:

                        stickbreakingCheckExpr <- model$getValueExpr(calcNodesDeterm)
                        stickbreakingCheckExpr <- cc_expandDetermNodesInExpr(model, stickbreakingCheckExpr, targetNode = 

We need to figure out whether we should just bail out of conjugacy detection earlier or whether we can accommodate this case. I haven't thought about it further.
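A minimal sketch of the problematic structure, assuming a hypothetical cut-down model (not the user's): one eta element feeds a deterministic v, which feeds a deterministic pi through stick_breaking(), so getDependencies() with determOnly = TRUE should return both a v node and a pi node, mirroring the two-element calcNodesDeterm. As a possible user-side stopgap, not a fix, configureMCMC() can skip conjugacy detection via useConjugacy = FALSE, so eta gets non-conjugate samplers. The names code, w and determDeps are illustrative only.

```r
library(nimble)

# Hypothetical minimal model (not from the issue) with the same chain:
# eta[m] -> v[m] (deterministic) -> pi (deterministic, stick_breaking) -> z.
code <- nimbleCode({
  for(m in 1:(M-1)) {
    eta[m] ~ dbeta(1, 1)
    v[m] <- w[m] * eta[m]   # stands in for the kernel weight times eta[i,m]
  }
  pi[1:M] <- stick_breaking(v[1:(M-1)])
  z ~ dcat(pi[1:M])
})
model <- nimbleModel(code,
                     constants = list(M = 4, w = rep(0.5, 3)),
                     inits = list(eta = rep(0.5, 3), z = 1))

# Deterministic dependencies of a single eta element: both v and pi show up.
determDeps <- model$getDependencies("eta[1]", self = FALSE, determOnly = TRUE)
print(determDeps)

# Possible stopgap: turn off conjugacy detection so no conjugate sampler
# is attempted for eta.
conf <- configureMCMC(model, useConjugacy = FALSE, print = FALSE)
samplerNames <- sapply(conf$getSamplers(), function(s) s$name)
print(samplerNames)
mcmc <- buildMCMC(conf)
```

Turning off conjugacy applies to every node in the model, so this trades the crash for possibly slower mixing elsewhere; it sidesteps the check rather than resolving the stick-breaking case.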

The issue is that

danielturek commented 8 months ago

@paciorek Please let me know if this runs into conjugacy code that you want me to update or deal with; otherwise I'll presume it's sufficiently in the stick-breaking realm to be handled by you.