nimble-dev / nimble

The base NIMBLE package for R
http://R-nimble.org
BSD 3-Clause "New" or "Revised" License

Bug in `makeModelDerivsInfo` for multiply split LHSinferred nodes #1511

Open perrydv opened 2 weeks ago

perrydv commented 2 weeks ago

makeModelDerivsInfo takes steps to find immediate parent nodes that are needed for the AD tape and would not otherwise be used (e.g. they are not in wrt or calc nodes).

Note that what is needed is not really parent nodes but just any arbitrary parent elements involved in calculations. For example x[1:5] might be a vector node but if a calculation involved in AD only needs x[3], then that (and not x[1:5]) is the arbitrary parent element that needs to be found in makeModelDerivsInfo.

We have model$getParents (which calls C++), but this does not appear to have an option to get just the elements. Perhaps for that reason (I don't fully remember), makeModelDerivsInfo calls a bespoke version called getImmediateParentNodes. In pure R it manipulates pieces of the modelDef$maps and can return just the needed elements. So its processing is completely distinct from model$getParents.

The problem is that "just the needed elements" can be nodes of type LHSinferred. (E.g. if x[1:5] was declared as a stochastic node, while potentially arbitrary subsets of its elements were used in RHS calculations, such as z[1] <- foo(x[1]), then x[1] is an LHSinferred node.) For particular types of splitting, the node name can include %.s%, which is designed purely as a syntax scheme and is never intended to be evaluated in one of the environments where we evaluate node names to get graphIDs etc.

getImmediateParentNodes can end up with one of the %.s% LHSinferred names and then pass it into further node processing, which causes an error.
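To make the failure mode concrete, here is a minimal base R sketch (no nimble needed) of what happens when a name containing %.s% is parsed and evaluated. The name "x[2 %.s% 4]" is an assumed form used only for illustration; the names nimble actually generates may look different, and this is not a trace of nimble's own evaluation code.

```r
# Assumption: a split LHSinferred name looks roughly like "x[2 %.s% 4]".
x <- 1:5
nodeName <- "x[2 %.s% 4]"

# Parsing succeeds, because %.s% is syntactically a valid custom infix operator.
parsed <- parse(text = nodeName)[[1]]

# Evaluation fails, because no function named `%.s%` is defined.
msg <- tryCatch({
  eval(parsed)
  "no error"
}, error = function(e) conditionMessage(e))

print(msg)
```

The point is that such a name is fine as a label but breaks as soon as it reaches any step that evaluates node names.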

Here is a reproducible example.

library(nimble)

modelCode <- nimbleCode({
  for(i in 1:5) alpha[i] ~ dhalfflat()
  x[1:5] ~ ddirch(alpha[1:5])
  z[1] <- x[1]
  z[2] ~ dunif(0, mean(x[1:5]))
  z[3] <- x[3]
  z[4] ~ dunif(0, mean(x[1:5]))
  z[5] <- x[5]
  for(i in 1:5) {
    y[i] ~ dbinom(z[i], size = 1)
  }
})

m <- nimbleModel(modelCode, data = list(y = c(0,1,0,1,0)))
wrt <- c("x")
cn <- m$getDependencies(wrt)
makeModelDerivsInfo(m, wrt = wrt, calcNodes = cn) # error
# Break down the steps:
nimble:::getImmediateParentNodes("z", m) # returns a name in %.s% notation. This triggers the later error.
nimble:::getImmediateParentNodes("z[1]", m) # has x[1]. This is correct.
nimble:::getImmediateParentNodes("z[2]", m) # has the bad %.s%.
# Can we switch to m$getParents?
m$getParents("z[1]", includeRHSonly = TRUE, immediateOnly = TRUE) # x[1:5]. This is valid but will be inefficient later because from z[1] we only need x[1].
m$getParents("z[2]", includeRHSonly = TRUE, immediateOnly = TRUE) # x[1:5]. This is valid and correct because from z[2] we really do need x[1:5].

Notice that the double split arising for z[2] and z[4] is necessary to trigger the problem.

Possible solutions:

  1. Stick with the two processing pathways and fix getImmediateParentNodes to never return a "%.s%" case.
  2. Use m$getParents as is and accept that sometimes this will result in inefficiency.
  3. Modify m$getParents to allow returning LHSinferred elements but never "%.s%" cases. I think this would require changes at the C++ level.

I'm thinking option 1 is the way to go but haven't tried it.
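One possible shape for option 1 is a post-filter on the parent names, sketched below in base R and untested against nimble itself. Both isSplitName and sanitizeParents are hypothetical helpers, not existing nimble functions, and the name "x[2 %.s% 4]" is an assumed form. The fallback argument stands in for whatever would supply a safe replacement (for example, the enclosing declared node).

```r
# Hypothetical helper: TRUE for names that use the "%.s%" split notation.
isSplitName <- function(nodeNames) grepl("%.s%", nodeNames, fixed = TRUE)

# Hypothetical wrapper: keep clean element names and swap any "%.s%" name
# for whatever `fallback` returns for it.
sanitizeParents <- function(parents, fallback) {
  bad <- isSplitName(parents)
  if (!any(bad)) return(parents)
  unique(c(parents[!bad], fallback(parents[bad])))
}

# Stand-in fallback for illustration only: map every split name to "x[1:5]".
parents <- c("x[1]", "x[2 %.s% 4]")
cleaned <- sanitizeParents(parents, function(nm) rep("x[1:5]", length(nm)))
print(cleaned)
```

This keeps the efficient element-level parents such as x[1] where they are valid and only falls back to the coarser node for the problematic names, at the cost of some of the inefficiency noted under option 2.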