stan-dev / cmdstanpy

CmdStanPy is a lightweight interface to Stan for Python users which provides the necessary objects and functions to compile a Stan program and fit the model to data using CmdStan.
BSD 3-Clause "New" or "Revised" License

Python and CmdStan quantiles differ #762

Closed: bob-carpenter closed this issue 1 month ago

bob-carpenter commented 1 month ago

Summary:

I fit a model and calculated quantiles using CmdStan (through cmdstanpy) and using Python. They give close but different answers. I suspect the quantiles are broken in CmdStan, but I thought I'd file here first; the issue can be moved to CmdStan if necessary.

Description:

See above. Here's a minimal-ish working example, which I put in sim.py:

import cmdstanpy as csp
import numpy as np
import logging
csp.utils.get_logger().setLevel(logging.ERROR)

model = csp.CmdStanModel(stan_file='funnel.stan')
init = {'double_log_scale': 0, 'alpha': np.zeros(9)}
mass_matrix = {'inv_metric': np.ones(10)}
epsilon = 0.0025
print(f"\n\n epsilon={epsilon:6.4f}")
fit = model.sample(inits=init, chains=1, step_size=epsilon, iter_warmup=0, adapt_engaged=False, iter_sampling=2_000, metric=mass_matrix, show_progress=False) 
print(fit.summary(percentiles=(2.5, 50, 97.5)))
metadata_param_draws = fit.draws(concat_chains=True)
print(f"{metadata_param_draws.shape=}")
# drop the 7 sampler diagnostic columns (lp__ through energy__)
draws = metadata_param_draws[:, 7:]
print(f"{draws.shape=}")
quantiles = np.quantile(draws, [0.025, 0.5, 0.975], axis=0)
# theta[0] is double_log_scale; theta[1] to theta[9] are alpha[1] to alpha[9]
for n in range(10):
    print(f"     theta[{n}] : ({quantiles[0, n]:7.3f}, {quantiles[1, n]:7.3f}  {quantiles[2, n]:7.3f})")
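The slice [:, 7:] relies on CmdStan writing seven sampler diagnostic columns (lp__, accept_stat__, stepsize__, treedepth__, n_leapfrog__, divergent__, energy__) ahead of the parameters, which matches the (2000, 17) shape printed below. A minimal sketch of selecting the parameter columns by name instead, assuming that every diagnostic column name ends in "__" (as in CmdStan's CSV output). The helper param_columns is hypothetical, and the column names and draws here are synthetic stand-ins; with a real fit you would pass fit.column_names and fit.draws(concat_chains=True).

```python
import numpy as np


def param_columns(column_names, draws):
    """Keep only the columns whose names do not end in "__".

    Hypothetical helper (not part of CmdStanPy). Assumes, as in CmdStan's
    CSV output, that sampler diagnostics (lp__, accept_stat__, ...,
    energy__) are exactly the columns ending in a double underscore.
    """
    keep = [i for i, name in enumerate(column_names) if not name.endswith("__")]
    return [column_names[i] for i in keep], draws[:, keep]


# Synthetic stand-ins for fit.column_names and fit.draws(concat_chains=True)
column_names = (
    "lp__", "accept_stat__", "stepsize__", "treedepth__",
    "n_leapfrog__", "divergent__", "energy__", "double_log_scale",
) + tuple(f"alpha[{k}]" for k in range(1, 10))
fake_draws = np.arange(2000 * 17, dtype=float).reshape(2000, 17)

names, draws = param_columns(column_names, fake_draws)
print(names[0], draws.shape)  # double_log_scale (2000, 10)
```

Selecting by name keeps working if the sampler adds or drops diagnostic columns; note that it would also keep any transformed parameters or generated quantities, which this model does not have.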

Here's the Stan program in funnel.stan:

parameters {
  real double_log_scale;
  vector[9] alpha;
}
model {
  double_log_scale ~ normal(0, 3);
  alpha ~ normal(0, exp(double_log_scale / 2));
}

And here's what I get:

~/temp2/nuts-funnel$ python3 sim.py
 epsilon=0.0025
                      Mean      MCSE    StdDev      2.5%       50%     97.5%    N_Eff   N_Eff/s     R_hat
lp__             -3.095650  3.124210  12.31130 -24.04040 -2.915350  21.59580  15.5285   8.36663  1.273310
double_log_scale -0.402755  0.690198   2.76107  -6.29412 -0.337422   4.11488  16.0032   8.62241  1.275270
alpha[1]         -0.225077  0.216606   2.06038  -5.99571 -0.009239   4.09132  90.4802  48.75010  1.034430
alpha[2]          0.058789  0.339706   2.27934  -5.39011  0.012140   5.83297  45.0208  24.25690  1.036400
alpha[3]          0.093776  0.302027   2.27817  -5.80824 -0.000979   5.73368  56.8957  30.65500  0.999740
alpha[4]         -0.509345  0.618755   3.10420 -10.09710 -0.018771   4.64869  25.1688  13.56080  1.020510
alpha[5]          0.186080  0.329267   2.37797  -4.13391 -0.013949   8.00979  52.1574  28.10210  0.999694
alpha[6]         -0.096148  0.318258   2.25711  -6.20796 -0.008866   4.91267  50.2976  27.10000  1.022940
alpha[7]         -0.418351  0.279737   2.40307  -6.62655 -0.054300   4.88067  73.7961  39.76080  1.008840
alpha[8]         -0.591723  0.536642   3.12854 -12.05920 -0.014100   3.39054  33.9871  18.31200  1.033610
alpha[9]          0.159583  0.362457   2.72899  -5.90080  0.010879   6.93319  56.6879  30.54310  1.005150
metadata_param_draws.shape=(2000, 17)
draws.shape=(2000, 10)
     theta[0] : ( -6.287,  -0.341    4.115)
     theta[1] : ( -5.995,  -0.010    4.092)
     theta[2] : ( -5.347,   0.012    5.834)
     theta[3] : ( -5.764,  -0.001    5.734)
     theta[4] : ( -9.998,  -0.019    4.651)
     theta[5] : ( -4.110,  -0.014    8.010)
     theta[6] : ( -6.201,  -0.009    4.915)
     theta[7] : ( -6.604,  -0.055    4.883)
     theta[8] : (-12.016,  -0.014    3.391)
     theta[9] : ( -5.831,   0.011    6.938)

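You can see that the quantiles are close but not quite spot on; for example, CmdStan reports -6.29412 for the 2.5% quantile of double_log_scale, while np.quantile gives -6.287. At first I thought CmdStan might be rounding 2.5 and 97.5 to whole percentiles, but its 2.5% values don't match NumPy's 0.02 or 0.03 quantiles either.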
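A small sketch of why "the 2.5% quantile" of 2,000 draws is not uniquely defined, using synthetic normal draws as a stand-in for the sampler output (reproducing the actual run needs CmdStan). np.quantile defaults to method='linear', which interpolates between two sorted draws, while 'lower', 'higher', and 'nearest' each return one of the draws itself. The answers differ whenever (n - 1) * p is not an integer. The method keyword assumes NumPy 1.22 or later (the reporter has 1.23.4).

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=2000)  # synthetic stand-in for one column of 2,000 draws
p = 0.025

# The virtual index (n - 1) * p is about 49.975, i.e. between sorted draws 49 and 50
print((len(x) - 1) * p)

# Same draws, same p: the reported quantile depends on the method
for method in ["linear", "lower", "higher", "nearest"]:
    print(f"{method:>8}: {np.quantile(x, p, method=method):.5f}")
```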

Current Version:

>>> import cmdstanpy; cmdstanpy.show_versions()
INSTALLED VERSIONS
---------------------
python: 3.10.8 (main, Oct 13 2022, 09:48:40) [Clang 14.0.0 (clang-1400.0.29.102)]
python-bits: 64
OS: Darwin
OS-release: 23.5.0
machine: arm64
processor: arm
byteorder: little
LC_ALL: None
LANG: en_US.UTF-8
LOCALE: ('en_US', 'UTF-8')
cmdstan_folder: /Users/bcarpenter/.cmdstan/cmdstan-2.35.0
cmdstan: (2, 35)
cmdstanpy: 1.2.2
pandas: 2.0.1
xarray: None
tqdm: 4.65.0
numpy: 1.23.4

WardBrian commented 1 month ago

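Investigating now. I suspect that neither is actually "incorrect" and that the difference lies in how ties are broken, i.e., which value is reported when the requested quantile falls between two sorted draws. The documentation for the method argument of np.quantile is pretty dense.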
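For the four methods relevant here, the np.quantile documentation boils down to one rule: compute the 0-based virtual index h = (n - 1) * p into the sorted draws, then round it down ('lower'), round it up ('higher'), round it to the nearest integer ('nearest', where NumPy sends exact halves to the even index), or interpolate linearly between the two neighbouring draws ('linear', the default). A minimal sketch of that rule for a 1-D array and a scalar p; manual_quantile is a hypothetical name, not a NumPy function.

```python
import numpy as np


def manual_quantile(x, p, method="linear"):
    """Hand-rolled version of np.quantile for 1-D x and scalar p (sketch)."""
    xs = np.sort(np.asarray(x, dtype=float))
    h = (len(xs) - 1) * p  # virtual index into the sorted draws
    lo, hi = int(np.floor(h)), int(np.ceil(h))
    if method == "lower":
        return xs[lo]
    if method == "higher":
        return xs[hi]
    if method == "nearest":
        return xs[int(np.rint(h))]  # np.rint rounds exact halves to even
    if method == "linear":
        return xs[lo] + (h - lo) * (xs[hi] - xs[lo])
    raise ValueError(f"unsupported method: {method}")


rng = np.random.default_rng(1)
x = rng.normal(size=2000)
for method in ["linear", "lower", "higher", "nearest"]:
    print(method, manual_quantile(x, 0.975, method), np.quantile(x, 0.975, method=method))
```

The 'linear' branch can differ from NumPy in the last few bits because NumPy evaluates the interpolation slightly differently, so compare it with a tolerance rather than exact equality.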

WardBrian commented 1 month ago

My instinct was right, but it is still a bit odd.

The Stan behavior changes at p = 0.5, so no single value of the method argument to np.quantile reproduces it, but this does:

import cmdstanpy as csp
import numpy as np
import logging
csp.utils.get_logger().setLevel(logging.ERROR)

# Reproduce Stan's percentiles: method='lower' for p < 0.5, method='nearest' otherwise
def stan_like_quantiles(a, q, axis=None):
    out = []
    for p in q:
        if p < 0.5:
            out.append(np.quantile(a, p, axis=axis, method='lower'))
        else:
            out.append(np.quantile(a, p, axis=axis, method='nearest'))
    return np.array(out)

model = csp.CmdStanModel(stan_file='funnel.stan')
init = {'double_log_scale': 0, 'alpha': np.zeros(9)}
mass_matrix = {'inv_metric': np.ones(10)}
epsilon = 0.0025
print(f"\n\n epsilon={epsilon:6.4f}")
fit = model.sample(inits=init, chains=1, step_size=epsilon, iter_warmup=0, adapt_engaged=False, iter_sampling=2_000, metric=mass_matrix, show_progress=False)
print(fit.summary(percentiles=(2.5, 50, 97.5)))
metadata_param_draws = fit.draws(concat_chains=True)
print(f"{metadata_param_draws.shape=}")
draws = metadata_param_draws[:, 7:]
print(f"{draws.shape=}")

quantiles = stan_like_quantiles(draws, [0.025,0.5,0.975], axis=0)

for n in range(10):
    print(f"     theta[{n}] : ({quantiles[0, n]:7.5f}, {quantiles[1, n]:7.5f}  {quantiles[2, n]:7.5f})")

prints

                      Mean      MCSE    StdDev      2.5%       50%     97.5%     N_Eff   N_Eff/s     R_hat
lp__              1.845180  2.485300  10.53630 -17.95740  2.236310  21.90610   17.9728   10.3949  1.002140
double_log_scale -1.488500  0.567439   2.41004  -6.39507 -1.487180   3.04691   18.0390   10.4332  1.002080
alpha[1]         -0.122363  0.063339   1.10582  -3.16798 -0.016352   2.04890  304.8030  176.2890  1.004150
alpha[2]          0.076506  0.178563   1.63905  -3.63104  0.005829   4.66082   84.2563   48.7312  1.005280
alpha[3]          0.039439  0.115609   1.37654  -2.92992  0.006639   3.48889  141.7750   81.9981  1.006530
alpha[4]          0.161731  0.145244   1.50736  -2.23265  0.007861   4.41500  107.7050   62.2930  1.013910
alpha[5]         -0.017582  0.079393   1.21976  -2.73845  0.007058   2.45055  236.0390  136.5180  1.018140
alpha[6]         -0.090614  0.126625   1.44340  -4.65026 -0.001944   2.47281  129.9370   75.1514  0.999589
alpha[7]         -0.187405  0.121376   1.42942  -3.91546 -0.010271   2.35533  138.6920   80.2152  1.018600
alpha[8]          0.167995  0.132302   1.47003  -2.54510  0.009134   4.60312  123.4580   71.4042  1.006310
alpha[9]         -0.073094  0.069720   1.12135  -2.79059 -0.006392   2.29814  258.6800  149.6130  1.003620
metadata_param_draws.shape=(2000, 17)
draws.shape=(2000, 10)
     theta[0] : (-6.39507, -1.48718  3.04691)
     theta[1] : (-3.16798, -0.01635  2.04890)
     theta[2] : (-3.63104, 0.00583  4.66082)
     theta[3] : (-2.92992, 0.00664  3.48889)
     theta[4] : (-2.23265, 0.00786  4.41500)
     theta[5] : (-2.73845, 0.00706  2.45055)
     theta[6] : (-4.65026, -0.00194  2.47281)
     theta[7] : (-3.91546, -0.01027  2.35533)
     theta[8] : (-2.54510, 0.00913  4.60312)
     theta[9] : (-2.79059, -0.00639  2.29814)
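To make the rule in stan_like_quantiles concrete for this run's n = 2000 draws, the sketch below computes which sorted draw (0-based) each branch selects, using NumPy's virtual index (n - 1) * p. At p = 0.5 the index is exactly 999.5, and 'nearest' rounds it to the even index 1000, i.e. the upper of the two middle draws rather than their average. The helper selected_index is hypothetical; it restates the empirical rule found in this thread, not Stan's source code.

```python
import numpy as np


def selected_index(n, p):
    """0-based index into the sorted draws picked by the stan_like_quantiles rule.

    Hypothetical helper: restates the empirical rule from this thread
    (method='lower' for p < 0.5, method='nearest' otherwise) using
    NumPy's virtual index (n - 1) * p.
    """
    h = (n - 1) * p
    if p < 0.5:
        return int(np.floor(h))  # 'lower': round the index down
    return int(np.rint(h))  # 'nearest': round, with exact halves going to even


for p in (0.025, 0.5, 0.975):
    print(p, selected_index(2000, p))
# prints:
# 0.025 49
# 0.5 1000
# 0.975 1949
```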
WardBrian commented 1 month ago

If the above behavior is something we want to change, I think the right place to open an issue is Stan itself. The code in question appears to have been essentially unchanged since 2013.