nimble-dev / nimble

The base NIMBLE package for R
http://R-nimble.org
BSD 3-Clause "New" or "Revised" License

Strange AD behavior for custom distributions involving dbinom for scalar and vector inputs #1424

Closed: weizhangstats closed this issue 3 months ago

weizhangstats commented 6 months ago

While enabling AD for the N-mixture models in nimbleEcology, I found AD behavior that seems odd to me. A custom distribution that calls dbinom fails to compile when x is a scalar but compiles fine when x is a vector. Here is the code that fails when a scalar is passed to dbinom:

library(nimble)
dfoo <- nimbleFunction(
  run = function(x = double(),
                 N = integer(),
                 prob = double(),
                 log = integer(0, default = 0)) {
    logProb <- dbinom(x, size = N, prob = prob, log = TRUE)
    if (log) return(logProb)
    else return(exp(logProb))
    returnType(double())
  },
  buildDerivs = list(run = list())
)
registerDistributions("dfoo")
nc <- nimbleCode({
  prob ~ dunif(0, 1)
  x ~ dfoo(N = N, prob = prob)
})
m <- nimbleModel(nc, constants = list(N=10), data = list(x=5), inits = list(prob = 0.1),
                 buildDerivs = TRUE)
cm <- compileNimble(m) ## Fails

In a similar case where x is a vector, it compiles fine:

dfoo2 <- nimbleFunction(
  run = function(x = double(1),
                 N = integer(),
                 prob = double(),
                 log = integer(0, default = 0)) {
    logProb <- sum(dbinom(x, size = N, prob = prob, log = TRUE))
    if (log) return(logProb)
    else return(exp(logProb))
    returnType(double())
  },
  buildDerivs = list(run = list())
)
registerDistributions("dfoo2")
nc2 <- nimbleCode({
  prob ~ dunif(0, 1)
  x[1:2] ~ dfoo2(N = N, prob = prob)
})
m <- nimbleModel(nc2, constants = list(N=10), data = list(x=c(4,5)), inits = list(prob = 0.1),
                 buildDerivs = TRUE)
cm <- compileNimble(m)

@perrydv Is this expected?

perrydv commented 4 months ago

Hi @weizhangstats, thanks for this issue, and my apologies for not responding to it earlier. It is a good and puzzling case. The problem actually has to do with the type of N. All model variables are doubles. Outside of AD, relying on automatic casting (conversion) among scalar types is harmless: for example, you can declare N as an integer in dfoo and expect the model's value to be passed in and cast safely, even though N is a double in the model. Types are more rigid under AD, however, so declaring N as an integer creates a problem. In the second case, I think (without having looked in great detail) that the vector x makes the C++ compilation use a different template, which happens to handle the type of N.

Clearly this is a case where either more flexible handling or better error-trapping and messaging would be helpful.
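A minimal sketch of the fix implied by the reply above: the failing scalar example compiles once N is declared double() so that it matches the model, where every variable is a double (weizhangstats confirms this further down). The name dfooDbl is hypothetical, chosen only to avoid clashing with the original dfoo. The value check at the end assumes the log-density of prob ~ dunif(0, 1) is 0.

```r
library(nimble)

## Same as dfoo from the report, except N is declared double() to match
## the model, where all variables are doubles.
dfooDbl <- nimbleFunction(
  run = function(x = double(),
                 N = double(),   ## was integer(); that breaks AD compilation
                 prob = double(),
                 log = integer(0, default = 0)) {
    logProb <- dbinom(x, size = N, prob = prob, log = TRUE)
    if (log) return(logProb)
    else return(exp(logProb))
    returnType(double())
  },
  buildDerivs = list(run = list())
)
registerDistributions("dfooDbl")

nc <- nimbleCode({
  prob ~ dunif(0, 1)
  x ~ dfooDbl(N = N, prob = prob)
})
m <- nimbleModel(nc, constants = list(N = 10), data = list(x = 5),
                 inits = list(prob = 0.1), buildDerivs = TRUE)
cm <- compileNimble(m)   ## compiles with N declared double()

## Total log-probability: log dunif(0.1) = 0 plus log dbinom(5, 10, 0.1)
total <- cm$calculate()
print(total)
```

Keeping N as double() avoids relying on an implicit double-to-int cast, which the AD types do not allow here.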

weizhangstats commented 4 months ago

Thanks @perrydv. I found that the first example works if N is declared as double in the distribution. In my more complicated scenario, however, N was also used in indices for a vector, and declaring N as double did not work there. The inelegant workaround I found for dfoo was:

xx <- nimNumeric(length = 2, value = x)  ## copy scalar x into a length-2 vector
logProb <- sum(dbinom(xx, size = Nmin, prob = prob, log = TRUE))/2  ## two identical terms, so halve the sum

Agreed that a more flexible handling or better error-trapping and messaging would be helpful.

perrydv commented 4 months ago

@weizhangstats I'm sorry I'm not following why it won't work for you to declare N = double(). (I'm assuming Nmin is the same as N.)

Even if N is a double, it should still work as an index. It should be handled as a stochastic or fixed index, as appropriate. All model variables are doubles.

(If you do need your workaround, I think you could use length = 1, and xx would still be a vector.)

weizhangstats commented 4 months ago

@perrydv I think I figured out what was happening in the N-mixture distributions, where I had thought I needed to declare Nmin and Nmax as integers instead of doubles. The compilation problem came from the line for (i in (Nmin+1):Nmax) in those functions: with Nmin and Nmax declared as doubles, compilation failed because an AD double cannot be converted to an int. ADbreak will solve the problem nicely. I will update the code.