nimble-dev / nimble

The base NIMBLE package for R
http://R-nimble.org
BSD 3-Clause "New" or "Revised" License

more general CRP sampling needs to check for unequal number of obs per group #1128

Closed. paciorek closed this issue 3 years ago.

paciorek commented 3 years ago

We overlooked the possibility that different groups might have different numbers of observations (in the example below, different numbers of measurements per subject), and CRP_sampler assumes they are the same without checking. The MCMC can be built and run, but it produces incorrect results. The only indication seems to be a hard-to-interpret warning.
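To make the failure mode concrete, here is a hypothetical illustration in plain R. It is not nimble's sampler code; it only shows what can go wrong when a flat vector of observations is split into equal-sized blocks per group even though the group sizes differ. The names n, obsGroup, perGroup and naiveBlocks are assumptions made up for this sketch.

```r
# Hypothetical illustration only; this is not code from CRP_samplers.R.
# Three groups with unequal numbers of observations (ragged).
n <- c(2, 5, 2)

# Flat vector recording which group each observation truly belongs to.
obsGroup <- rep(seq_along(n), times = n)   # 1 1 2 2 2 2 2 3 3

# Naive approach: assume every group has the same number of observations.
perGroup <- length(obsGroup) / length(n)   # 9 / 3 = 3

# Indices of the observations the naive approach attributes to each group.
naiveBlocks <- lapply(seq_along(n), function(g) {
  ((g - 1) * perGroup + 1):(g * perGroup)
})

# Compare the naive assignment with the true group labels; nothing errors,
# but groups 1 and 3 are silently given another group's observations.
for (g in seq_along(n)) {
  owners <- obsGroup[naiveBlocks[[g]]]
  cat("group", g, "-> true owners of its assumed observations:", owners, "\n")
}
```

Because the total number of observations happens to be divisible by the number of groups, nothing fails loudly, which mirrors the symptom described above: the MCMC runs, but on misassigned data.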

Presumably we can put the check at line 1448 of CRP_samplers.R (we may also want to check for a constant number of intermediate nodes per group).
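A minimal sketch of such a check, assuming the sampler's setup code has already collected the data node names for each group into a list. The function name checkEqualGroupSizes and the list layout are hypothetical; the real check would live in CRP_samplers.R and operate on whatever node structures the sampler builds there.

```r
# Hypothetical helper (not part of nimble): stop with an informative message
# if the groups clustered by a CRP do not all have the same number of nodes.
checkEqualGroupSizes <- function(nodesPerGroup) {
  # nodesPerGroup: a list with one character vector of node names per group
  sizes <- lengths(nodesPerGroup)
  if (length(unique(sizes)) > 1) {
    stop("CRP sampler requires the same number of observations per group; ",
         "found group sizes ranging from ", min(sizes), " to ", max(sizes), ".",
         call. = FALSE)
  }
  invisible(sizes[1])  # common group size, returned invisibly
}

# Usage with made-up node names.
ok <- list(c("y[1, 1]", "y[1, 2]"), c("y[2, 1]", "y[2, 2]"))
ragged <- list("y[1, 1]", c("y[2, 1]", "y[2, 2]", "y[2, 3]"))
checkEqualGroupSizes(ok)                  # passes silently
res <- try(checkEqualGroupSizes(ragged))  # error: group sizes from 1 to 3
```

The same helper could be applied to a per-group list of intermediate nodes (such as mu in the example below) to cover the second check mentioned above.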

Ideally we would allow this, but it would require more complicated indexing to accommodate the raggedness.
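Supporting ragged groups would mean tracking where each group's observations start and how many there are, instead of assuming a fixed block size. Here is a minimal sketch of that bookkeeping in plain R, using hypothetical names (groupStart, groupLength, groupObs); an actual implementation in the sampler would need to build these from the model's node structure.

```r
# Hypothetical sketch: index ragged groups stored in one flat vector.
n <- c(2, 5, 2)                   # observations per group (ragged)
y <- seq_len(sum(n)) * 10         # flat vector of all observations: 10, 20, ..., 90

groupLength <- n
groupStart <- cumsum(c(1, head(n, -1)))  # first index of each group: 1, 3, 8

# Return the observations belonging to group g.
groupObs <- function(g) {
  y[groupStart[g] + seq_len(groupLength[g]) - 1]
}

groupObs(1)  # 10 20
groupObs(2)  # 30 40 50 60 70
groupObs(3)  # 80 90
```

Storing a start offset and a length per group keeps the data in one flat vector while letting each group have its own size, at the cost of an extra lookup per group.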

paciorek commented 3 years ago

Here's a reproducible example in which subjects have different numbers of measurements (n[i] varies by subject):

set.seed(1)
library(nimble)
library(coda)

## ---- basic params ----
subject.num <- 10 # Number of subjects
nmin <- 3
nmax <- 15
n <- ceiling(runif(subject.num, nmin, nmax)) # Number of measurements
maxtime <- 5
cluster_num <- 4

## ---- defining true betas ----
betas <- matrix(c(rep(2, 3),
                c(3.3, 2.2, -0.5),
                c(1.5, 1.75, 2),
                c(2.9, 1.6, 1)),
                nrow = 4,
                byrow = TRUE)
rownames(betas) <- paste("cluster:", seq(1, 4))
colnames(betas) <- paste0("beta_", seq(1, 3))

## ---- Simulating data ----
times.mat <- matrix(nrow = 0, ncol = max(n))
sim.vals <- matrix(nrow = 0, ncol = max(n) + 2) # + 2 for subject num & true cluster

for (subject in 1:subject.num) {
  clust.assign <- ceiling(cluster_num * runif(1))

  times <- runif(n[subject], 0, maxtime)
  times <- sort(times)
  times <- c(times, rep(0, max(n) - n[subject]))
  times.mat <- rbind(times.mat, times)

  betas.val <- betas[clust.assign, ]

  values <- rep(0, n[subject])
  errors <- rnorm(n = n[subject], mean = 0, sd = sqrt(1))

  for (i in seq_len(n[subject])) {
    values[i] <- betas.val[1] + (times[i] * betas.val[2]) + ((times[i] ^ 2) * betas.val[3]) + errors[i]
  }

  values <- c(subject, clust.assign, values, rep(0, max(n) - n[subject]))
  sim.vals <- rbind(sim.vals, values)
}

colnames(sim.vals) <- c("subject", "cluster", paste0("time_", seq(1, max(n))))

## ---- MCMC settings ----
nburnin <- 100
thin <- 1
niter <- 10000

## ---- MCMC ----

subject.num <- nrow(sim.vals)

code <- nimbleCode({
  # Priors
  sigma2 ~ dinvgamma(shape = 1, scale = 2)

  # Vector of class memberships for each subject
  xi[1:subject.num] ~ dCRP(conc = alpha, size = subject.num)
  #alpha ~ dgamma(shape = 1, rate = 1)
  alpha ~ dinvgamma(shape = 1, scale = 1)

  # betas
  for (i in 1:subject.num) { # for subject i

    for (j in 1:3) {
      betaTilde[i, j] ~ dnorm(0, 1)
      beta[i, j] <- betaTilde[xi[i], j]
    }

    for (t in 1:n[i]) { # n[i] varies by subject, so the observations are ragged
      logy[i, t] ~ dnorm(mean = mu[i, t], sd = sqrt(sigma2))
      mu[i, t] <- beta[i, 1] + (beta[i, 2] * times[i, t]) + beta[i, 3] * (times[i, t] ^ 2)
    }
  }
})

constants <- list(subject.num = subject.num,
                  n = n)
data <- list(logy = sim.vals[, c(-1, -2)], times = times.mat)
inits <- list(beta = matrix(0, nrow = subject.num, ncol = 3),
              alpha = 1,
              mu = matrix(1, nrow = subject.num, ncol = max(n)),
              betaTilde = matrix(0, nrow = subject.num, ncol = 3),
              xi = rep(1, subject.num),
              sigma2 = 2.5)

## ---- build model ----
model <- nimbleModel(code = code,
                     data = data,
                     inits = inits,
                     constants = constants)
model$initializeInfo()

## ---- compile model ----
cmodel <- compileNimble(model)

## ---- configure MCMC ----
conf <- configureMCMC(model, monitors = c('xi', 'beta', 'alpha'), print = TRUE)

## ---- build and compile MCMC ----
mcmc <- buildMCMC(conf)
cmcmc <- compileNimble(mcmc, project = model, resetFunctions = TRUE)

## ---- run MCMC ----
nimbleMCMC_samples_sim <- runMCMC(cmcmc,
                                  nburnin = nburnin,
                                  thin = thin,
                                  niter = niter)

paciorek commented 3 years ago

Being fixed in PR #1152.