stan-dev / math

The Stan Math Library is a C++ template library for automatic differentiation of any order using forward, reverse, and mixed modes. It includes a range of built-in functions for probabilistic modeling, linear algebra, and equation solving.
https://mc-stan.org
BSD 3-Clause "New" or "Revised" License

multinomial_rng throws exception when total count N is zero #3057

Closed: chvandorp closed this issue 4 months ago

chvandorp commented 5 months ago

Description

When the total count is set to zero in the RNG for the multinomial distribution, an exception is thrown:

Exception: multinomial_rng: number of trials variables is 0, but must be positive!

However, the multinomial_lpmf function accepts all-zero integer arrays (and the log-probability is always 0.0, as expected). I think it would make more sense if the RNG also accepted a zero total count and simply returned a zero count vector: if you don't sample anything, then every per-category count should also be zero. Allowing this can be useful, e.g., for encoding missing data.
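To see why the log-probability of an all-zero count vector is exactly 0, here is a minimal Python sketch of the multinomial log-pmf. The helper `multinomial_lpmf` below is hypothetical and written for illustration only; it is not Stan's C++ implementation. With N = sum(n) = 0, the term log N! = log 0! = 0, every log n_k! = 0, and every n_k log θ_k term vanishes.

```python
import math

def multinomial_lpmf(n, theta):
    """Hypothetical sketch: log p(n | theta) for n ~ Multinomial(sum(n), theta).

    log p = log N! - sum_k log n_k! + sum_k n_k * log(theta_k), with N = sum(n).
    """
    if len(n) != len(theta):
        raise ValueError("n and theta must have the same length")
    if any(c < 0 for c in n):
        raise ValueError("counts must be non-negative")
    total = sum(n)
    lp = math.lgamma(total + 1)  # log N!
    for count, prob in zip(n, theta):
        lp -= math.lgamma(count + 1)  # log n_k!
        if count > 0:
            # A zero count contributes 0 * log(theta_k) = 0, even when theta_k == 0.
            lp += count * math.log(prob)
    return lp

print(multinomial_lpmf([0, 0, 0], [0.2, 0.3, 0.5]))  # 0.0
print(multinomial_lpmf([3, 2, 5], [0.2, 0.3, 0.5]))
```

The zero-count case therefore has probability 1: there is exactly one way to distribute zero trials over the categories, which is the all-zero vector the issue proposes multinomial_rng should return.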

If you agree, I can submit a simple PR.

Example

Stan model. In the generated quantities block, sum(n) can't be zero; otherwise multinomial_rng throws the exception above.

data {
    int<lower=1> k; // number of categories
    array[k] int<lower=0> n; // number of samples per category
}

parameters {
    simplex[k] p; // sampling probability per category
}

model {
    n ~ multinomial(p);
}

generated quantities {
    array[k] int n_sim = multinomial_rng(p, sum(n)); // simulated samples
}

Python code to run the model. Running the model with good_data works fine. With bad_data, exceptions are thrown and the generated quantities are all NaN.

import cmdstanpy
cmdstanpy.show_versions()
sm = cmdstanpy.CmdStanModel(stan_file="multinomial_zero_error.stan")

good_data = {
    "k" : 3,
    "n" : [3,2,5]
}

sam = sm.sample(data=good_data, iter_sampling=100)
print(sam.summary())

# set all counts to zero. Not an issue for multinomial_lpmf
# but multinomial_rng throws an exception

bad_data = {
    "k" : 3,
    "n" : [0,0,0]
}

sam = sm.sample(data=bad_data, iter_sampling=100)
print(sam.summary())

Expected output

Result with good_data:

               Mean      MCSE    StdDev  ...    N_Eff  N_Eff/s     R_hat
lp__     -14.756800  0.060222  0.915617  ...  231.160  38526.7  1.000060
p[1]       0.313440  0.006897  0.119988  ...  302.615  50435.9  1.002250
p[2]       0.220428  0.006402  0.103776  ...  262.742  43790.3  0.998186
p[3]       0.466132  0.006631  0.131760  ...  394.827  65804.5  1.003410
n_sim[1]   3.150000  0.108731  1.825060  ...  281.735  46955.9  1.001100
n_sim[2]   2.265000  0.096356  1.631390  ...  286.655  47775.9  0.997755
n_sim[3]   4.585000  0.128205  2.011940  ...  246.273  41045.4  1.008090

Result with bad_data

Exception: multinomial_rng: number of trials variables is 0, but must be positive! (in 'multinomial_zero_error.stan', line 15, column 4 to column 52)
    Exception: multinomial_rng: number of trials variables is 0, but must be positive! (in 'multinomial_zero_error.stan', line 15, column 4 to column 52)
...
              Mean      MCSE    StdDev  ...    N_Eff   N_Eff/s     R_hat
lp__     -4.470830  0.091875  1.141650  ...  154.409   5514.62  1.013780
p[1]      0.347875  0.012957  0.233269  ...  324.102  11575.10  0.996473
p[2]      0.335542  0.012238  0.229005  ...  350.137  12504.90  0.999920
p[3]      0.316583  0.011180  0.234400  ...  439.615  15700.50  0.996259
n_sim[1]       NaN       NaN       NaN  ...      NaN       NaN       NaN
n_sim[2]       NaN       NaN       NaN  ...      NaN       NaN       NaN
n_sim[3]       NaN       NaN       NaN  ...      NaN       NaN       NaN
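As a rough sanity check on these summaries (not part of the original report): the model states no explicit prior on p, which, assuming Stan's usual behaviour for an unconstrained-prior simplex, makes p uniform over the simplex, i.e. Dirichlet(1, 1, 1). The multinomial likelihood is conjugate, so the posterior is Dirichlet(1 + n). With n = [0, 0, 0] the likelihood is constant and the posterior equals the prior (mean 1/3, sd about 0.236), which is consistent with the p rows above. In other words, sampling p works fine with zero counts; only n_sim breaks. The helper name below is hypothetical.

```python
import math

def dirichlet_posterior_summary(counts, prior=None):
    """Hypothetical helper: posterior mean and sd of each p_k under
    a Dirichlet(prior) prior and multinomial counts (conjugate update)."""
    if prior is None:
        prior = [1.0] * len(counts)  # assumed uniform prior on the simplex
    alpha = [a + c for a, c in zip(prior, counts)]  # Dirichlet(prior + counts)
    a0 = sum(alpha)
    means = [a / a0 for a in alpha]
    # Marginal of a Dirichlet is Beta(a, a0 - a); this is its standard deviation.
    sds = [math.sqrt(a * (a0 - a) / (a0 ** 2 * (a0 + 1))) for a in alpha]
    return means, sds

for label, n in [("good_data", [3, 2, 5]), ("bad_data", [0, 0, 0])]:
    means, sds = dirichlet_posterior_summary(n)
    N = sum(n)
    # Posterior-predictive mean of n_sim[k] is N * E[p_k].
    print(label, [round(m, 3) for m in means], [round(s, 3) for s in sds],
          [round(N * m, 3) for m in means])
```

For good_data this gives posterior means 4/13, 3/13 and 6/13 (about 0.308, 0.231, 0.462) and predictive means for n_sim of about 3.08, 2.31 and 4.62, in line with the Monte Carlo estimates in the first table.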

Current Version:

python: 3.10.6 (v3.10.6:9c7b4bd164, Aug 1 2022, 17:13:48) [Clang 13.0.0 (clang-1300.0.29.30)] cmdstan: (2, 34) cmdstanpy: 1.2.1

andrjohns commented 5 months ago

I think it's fine to update the RNG to check that N is non-negative instead of positive, since this is what binomial_rng does.
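A minimal Python sketch of the proposed behaviour, assuming NumPy is available. The function below is a hypothetical stand-in, not the Stan Math C++ code: it samples via the conditional-binomial decomposition n_k ~ Binomial(N − Σ_{j<k} n_j, θ_k / (1 − Σ_{j<k} θ_j)), and only rejects negative N. With N = 0 the loop allocates nothing and the result is the zero vector, mirroring how a binomial with zero trials always returns 0.

```python
import numpy as np

def multinomial_rng(theta, N, rng):
    """Hypothetical sketch: draw counts from Multinomial(N, theta), allowing N == 0."""
    theta = np.asarray(theta, dtype=float)
    # Proposed check: N must be non-negative (previously: positive).
    if N < 0:
        raise ValueError(f"multinomial_rng: number of trials is {N}, but must be non-negative")
    if np.any(theta < 0) or not np.isclose(theta.sum(), 1.0):
        raise ValueError("multinomial_rng: theta must be a simplex")
    counts = np.zeros(theta.size, dtype=int)
    n_left, mass_left = N, 1.0
    for k in range(theta.size):
        if n_left == 0:
            break  # nothing left to allocate; remaining counts stay 0
        if k == theta.size - 1:
            counts[k] = n_left  # last category takes the remainder
            break
        # Conditional probability of category k given the earlier categories.
        p = min(theta[k] / mass_left, 1.0) if mass_left > 0 else 1.0
        counts[k] = rng.binomial(n_left, p)
        n_left -= counts[k]
        mass_left -= theta[k]
    return counts

rng = np.random.default_rng(1)
print(multinomial_rng([0.2, 0.3, 0.5], 0, rng))   # [0 0 0]
print(multinomial_rng([0.2, 0.3, 0.5], 10, rng))  # some counts summing to 10
```

Relaxing the check is enough here because the sampling loop already handles zero trials naturally; no special case is needed beyond accepting N = 0.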