pymc-devs / pymc

Bayesian Modeling and Probabilistic Programming in Python
https://docs.pymc.io/

NUTS sampler chain breaks if observational error is less than a factor of 10^-4 of the data #4038

Closed HinLeung622 closed 2 years ago

HinLeung622 commented 4 years ago

Description of your problem

The NUTS sampler's chain breaks with the error "The derivative of RV variable.ravel()[0] is zero." whenever the observational error is smaller than about 10^-4 times the observed data.

Please provide a minimal, self-contained, and reproducible example.

import pymc3 as pm
import numpy as np
import seaborn as sns
import theano.tensor as T

# creating fake data
true_mean = 10
true_sigma = 1
N=100
obs_values = np.random.randn(N)*true_sigma+true_mean

model = pm.Model()
with model:
    mean = pm.Normal('mean',15,5)
    sigma = pm.Lognormal('sigma',T.log(1),0.4)
    true_value = pm.Normal('true_value',mean,sigma,shape=N)
    obs_value = pm.Normal('obs_value',true_value,0.0001,observed=obs_values)

with model:
    trace = pm.sample(500,tune=2000, init='adapt_diag', target_accept=0.999, cores=1, chains=4)

Please provide the full traceback.

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-5-73c05cda1b48> in <module>
      1 with model:
----> 2     trace = pm.sample(500,tune=2000, init='adapt_diag', target_accept=0.999, cores=1, chains=4)

~\Anaconda3\lib\site-packages\pymc3\sampling.py in sample(draws, step, init, n_init, start, trace, chain_idx, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, **kwargs)
    510             _log.info("Sequential sampling ({} chains in 1 job)".format(chains))
    511             _print_step_hierarchy(step)
--> 512             trace = _sample_many(**sample_args)
    513 
    514     discard = tune if discard_tuned_samples else 0

~\Anaconda3\lib\site-packages\pymc3\sampling.py in _sample_many(draws, chain, chains, start, random_seed, step, **kwargs)
    560             step=step,
    561             random_seed=random_seed[i],
--> 562             **kwargs
    563         )
    564         if trace is None:

~\Anaconda3\lib\site-packages\pymc3\sampling.py in _sample(chain, progressbar, random_seed, start, draws, step, trace, tune, model, **kwargs)
    634     try:
    635         strace = None
--> 636         for it, (strace, diverging) in enumerate(sampling):
    637             if it >= skip_first:
    638                 trace = MultiTrace([strace])

~\Anaconda3\lib\site-packages\fastprogress\fastprogress.py in __iter__(self)
     45         except Exception as e:
     46             self.on_interrupt()
---> 47             raise e
     48 
     49     def update(self, val):

~\Anaconda3\lib\site-packages\fastprogress\fastprogress.py in __iter__(self)
     39         if self.total != 0: self.update(0)
     40         try:
---> 41             for i,o in enumerate(self.gen):
     42                 if i >= self.total: break
     43                 yield o

~\Anaconda3\lib\site-packages\pymc3\sampling.py in _iter_sample(draws, step, start, trace, chain, tune, model, random_seed)
    735                 step = stop_tuning(step)
    736             if step.generates_stats:
--> 737                 point, stats = step.step(point)
    738                 if strace.supports_sampler_stats:
    739                     strace.record(point, stats)

~\Anaconda3\lib\site-packages\pymc3\step_methods\arraystep.py in step(self, point)
    247 
    248         if self.generates_stats:
--> 249             apoint, stats = self.astep(array)
    250             point = self._logp_dlogp_func.array_to_full_dict(apoint)
    251             return point, stats

~\Anaconda3\lib\site-packages\pymc3\step_methods\hmc\base_hmc.py in astep(self, q0)
    128                 (np.abs(check_test_point) >= 1e20) | np.isnan(check_test_point)
    129             ]
--> 130             self.potential.raise_ok(self._logp_dlogp_func._ordering.vmap)
    131             message_energy = (
    132                 "Bad initial energy, check any log probabilities that "

~\Anaconda3\lib\site-packages\pymc3\step_methods\hmc\quadpotential.py in raise_ok(self, vmap)
    239                 errmsg.append('The derivative of RV `{}`.ravel()[{}]'
    240                               ' is zero.'.format(*name_slc[ii]))
--> 241             raise ValueError('\n'.join(errmsg))
    242 
    243         if np.any(~np.isfinite(self._stds)):

ValueError: Mass matrix contains zeros on the diagonal. 
The derivative of RV `mean`.ravel()[0] is zero.

Please provide any additional information below.

Scaling both the data and the obs error up by the same factor does not solve the problem, so it is likely a relative problem.

michaelosthege commented 4 years ago

With a sd=0.0001 you're just running into floating point problems here. A few things to try - a subset of which could solve the problem:

On another note, my gut feeling is that with observational uncertainty orders of magnitude narrower than the priors, you may effectively get the same result as a simple frequentist mean & standard error over the data. (In the example above.)
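
For the toy example, that frequentist estimate can be sketched with plain NumPy. This is a minimal illustration that regenerates fake data the same way as the original post; the seed and the variable names `sample_mean`, `sample_sd`, and `standard_error` are assumptions made for this sketch.

```python
import numpy as np

# Regenerate fake data as in the original example (seed chosen arbitrarily).
rng = np.random.default_rng(0)
true_mean, true_sigma, N = 10, 1, 100
obs_values = rng.normal(true_mean, true_sigma, size=N)

# Frequentist point estimate of the population mean and its standard error.
sample_mean = obs_values.mean()
sample_sd = obs_values.std(ddof=1)       # unbiased sample standard deviation
standard_error = sample_sd / np.sqrt(N)  # uncertainty of the mean estimate

print(f"mean = {sample_mean:.3f} +/- {standard_error:.3f}")
print(f"population sd estimate = {sample_sd:.3f}")
```

With an observational sd of 0.0001, the measurement noise is negligible next to the population scatter, so these numbers should be close to what the hierarchical model infers for `mean` and `sigma`.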

HinLeung622 commented 4 years ago

Thanks for your reply, I have tried the first two suggestions, but the combination of the two still results in the same error.

How do I initialize the sampling such that true_value already equals obs_value? I need the true_value line to generate the population and the obs_value line to supply the observational errors and observed data. How do I combine them into one?

To answer your last thought: yes, a frequentist approach would probably be equally good for the model in this example. However, the actual model in my project has 2 different observed variables, and the fundamental values are 2 or 3 layers of variable conversion away from them, with 2 trained forward-modelling neural networks in the middle handling parts of the conversion (there is no direct physical theory for converting between them), so a hierarchical Bayesian model does fit my problem.

For reference, here is the full model in my project:

model = pm.Model()
with model:
    #population-wide fundamentals
    Age_mu = pm.Deterministic('mean_age',pm.Beta('a',10,10)*2+2.5)
    feh_mu = pm.Deterministic('mean_feh',pm.Beta('e',10,10)*0.4-0.2)
    Y_mu = pm.Deterministic('mean_Y',pm.Beta('f',10,10)*0.04+0.24)
    MLT_mu = pm.Deterministic('mean_MLT',pm.Beta('g',10,10)*0.6+1.7)

    #individual fundamentals
    M = pm.Deterministic('mass', pm.Beta('d',10,10,shape=N)*(1.33-0.8)+0.8)
    Age = pm.Deterministic('age',T.ones(N)*Age_mu)
    feh = pm.Deterministic('feh',T.ones(N)*feh_mu)
    Y = pm.Deterministic('Y',T.ones(N)*Y_mu)
    MLT = pm.Deterministic('MLT',T.ones(N)*MLT_mu)

    #first NN conversion
    obs = pm.Deterministic('obs',m1.manualPredict(T.log10([M, Age, 10**feh, Y, MLT])))

    #intermediate variables
    radius = pm.Deterministic('radius', 10**obs[0])
    Teff = pm.Deterministic('Teff', (10**obs[1])*5000)
    Q = pm.Uniform('binary_frac', lower=0.0, upper=0.5, shape=N)
    L = pm.Deterministic('L', (1-Q)*(radius**2)*((Teff/Teff_sun)**4)+Q*2*(radius**2)*((Teff/Teff_sun)**4))
    logg = pm.Deterministic('logg', T.log10(100*constants.G.value*(M/radius**2)*(constants.M_sun.value/constants.R_sun.value**2)))
    Av_list = pm.Deterministic('Av', T.ones(N)*Av)

    #second NN conversion
    BCs = pm.Deterministic('BCs', t1.manualPredict(T.as_tensor_variable([T.log10(Teff), logg, feh, Av_list])))

    BCg = pm.Deterministic('BCg', BCs[5,:])
    BCbp = pm.Deterministic('BCbp', BCs[7,:])
    BCrp = pm.Deterministic('BCrp', BCs[8,:])

    true_mG = pm.Deterministic('true_mG', -2.5*T.log10(L)+Mbol-BCg+dist_mod)
    true_Bp_Rp = pm.Deterministic('true_Bp_Rp', BCrp-BCbp)

    #observed data, M67['g_mag_err'] is on the order of 10^-4 of M67['g_mag']
    obs_mG = pm.Normal('obs_mG', true_mG, M67['g_mag_err']*100, observed=M67['g_mag'])
    obs_Bp_Rp = pm.Normal('obs_Bp_Rp', true_Bp_Rp, M67['Bp_Rp_err'], observed=M67['Bp_Rp'])

Thanks.

michaelosthege commented 4 years ago

With pm.sample(start=...) you can pass a dict that maps the name of a random variable in your model to a float/array for its initial value in the sampling. However, even if the sampling starts at a point that gives a nice likelihood and gradient, it could still drift off into a numerically pathological region of the parameter space!
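
As a rough sketch for the toy model above, such a dict could be built like this. The data are regenerated here with an arbitrary seed so the snippet runs without PyMC3; the `pm.sample` call is left as a comment because it needs the `model` object from the original post.

```python
import numpy as np

# Stand-in for the observed data from the original example (seed is arbitrary).
rng = np.random.default_rng(0)
obs_values = rng.normal(10, 1, size=100)

# Map random-variable names to initial values. Keys must match the names
# used in the model ('true_value' and 'mean' in the toy example).
start = {
    "true_value": obs_values.copy(),    # start the latent values at the data
    "mean": float(obs_values.mean()),   # start the population mean nearby
}

# With the model from the original post in scope, this would be used as:
#     with model:
#         trace = pm.sample(500, tune=2000, init='adapt_diag',
#                           start=start, cores=1, chains=4)
```

Variables left out of the dict (here `sigma`) are not given an explicit starting value by this sketch.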

You also might want to check the value of the likelihood and gradient at the model.test_point:

model.logp(model.test_point)
model.dlogp(model.free_RVs)(model.test_point)

HinLeung622 commented 4 years ago

Thanks for the explanation. In the mock model I have applied both the float64 setting and the observed data as the starting point for true_value, and that fixes the issue! However, it does not fix my full-scale model.

I don't know how to read the outputs of logp and dlogp, so here is the output of the full-scale model:

logp: -777890323.28662
dlogp: [-7.83470629e+07  2.12049029e+08 -1.35129507e+08 -2.28975881e+06
 -5.72165659e+05  3.04145722e+06 -8.08268766e+06 -6.50011030e+06
 -3.31067207e+06  4.59959730e+06 -3.80918530e+06 -6.56930932e+06
 -8.43126743e+06 -6.76844366e+06 -6.12452571e+06 -4.88378686e+06
 -7.82217555e+06 -6.10546548e+06 -7.23430117e+06  2.37862218e+06
 -6.36485759e+06 -6.64689946e+06  2.19509209e+06 -6.50092200e+06
 -6.93251893e+06 -6.14472264e+06 -5.88982469e+06 -2.36942741e+06
 -8.61828936e+06 -6.45810708e+06 -8.19639619e+06  3.18719153e+06
 -6.98822282e+06 -8.87784084e+06  2.59704129e+06 -6.42374952e+06
 -2.65006939e+05 -7.55971837e+06 -7.14602197e+06 -1.08523159e+07
  3.95566591e+06  6.13902002e+06 -5.94245877e+06  1.25514433e+05
  3.08769503e+05 -8.44474550e+06 -7.21845604e+06 -1.67457189e+06
 -6.78266835e+06 -6.61544399e+06 -6.59726163e+06 -1.02292657e+07
 -7.32956651e+06 -4.51349840e+06 -9.63997007e+06 -7.16461525e+06
 -7.92737253e+06 -8.16563360e+06 -7.12999461e+06 -3.27201303e+06
 -4.76942371e+06 -7.60179922e+06 -6.95694710e+06  6.61332778e+06
 -6.82613239e+06 -6.81786379e+06 -5.45344136e+06  3.49184839e+06
 -2.37459793e+06 -4.67844284e+06 -8.93297502e+06 -3.35333422e+06
 -5.87890026e+06 -7.05217739e+06 -8.61440480e+06  1.98315122e+06
  2.84087612e+06 -4.79751293e+05 -2.89187916e+06  2.85695742e+06
 -3.02504046e+06 -4.42277019e+06 -7.29230853e+06 -7.90707504e+06
  5.52700623e+06 -6.90377946e+06  4.63554173e+06 -8.64695618e+06
 -7.00881349e+06 -5.72284348e+06 -4.35488809e+06  5.55674675e+06
 -4.09657781e+06 -7.37222362e+05 -6.91140389e+06 -1.40801020e+06
  3.15035504e+06 -5.58112274e+06 -2.31386335e+06 -2.00725625e+06
 -4.74579401e+06 -4.03931000e+06 -5.60053675e+06 -1.74780037e+06
 -3.63340910e+06 -4.03663711e+06 -1.14988023e+06 -3.29699591e+06
 -1.54303040e+06 -3.56698689e+06  4.57374021e+06  3.18512283e+06
 -3.86952570e+06 -7.00918532e+06 -5.19530897e+06 -2.93311446e+06
 -4.31610584e+06 -2.64913296e+06 -4.44932880e+06 -8.09745171e+06
 -5.32565604e+06 -6.16804817e+06 -3.38144689e+06 -5.56140450e+06
 -2.93689517e+06 -3.72993875e+06 -7.07034857e+06 -5.40707469e+06
 -6.50711365e+06  2.50504582e+06 -2.23718535e+06  3.79573959e+06
 -1.77228950e+06 -5.94571199e+06 -3.74378928e+06 -2.56515622e+06
 -1.81560621e+06  2.26994597e+06 -1.15407951e+06 -5.45236294e+06
 -4.88472480e+06 -1.41863160e+06 -7.88147241e+05 -3.75526044e+06
 -6.12762766e+06  4.20933738e+06 -3.94118751e+06  4.33150573e+06
 -4.60721694e+06  6.33519415e+06 -6.54219482e+06 -5.59988940e+06
 -6.41223959e+06 -5.00402044e+06 -5.52785690e+06 -4.78360247e+06
 -2.27601746e+06 -3.71297052e+06  2.23261410e+06 -4.68685937e+06
 -4.35504781e+06  6.66916427e+06 -3.96414672e+06  3.34313871e+05
  2.28752763e+06 -3.72624095e+06 -3.96122237e+06 -3.71820140e+06
  2.53221503e+06 -4.72911890e+06 -4.36932936e+06 -5.12431622e+06
  3.56159807e+06 -5.16072044e+06 -5.43976919e+06 -4.53494584e+06
  2.96698703e+06 -6.14760972e+06 -4.52611450e+06 -5.53339511e+06
 -5.29109193e+06 -5.04247268e+06 -5.79805705e+06 -5.98113396e+06
 -4.44153551e+06 -5.66575171e+06 -1.41730049e+06  8.06588187e+06
  3.48479753e+06  6.37624723e+06 -8.85460066e+06 -6.26228025e+06
 -4.06765889e+06 -5.88100176e+06 -6.25792576e+06 -6.15988107e+06
  2.95957763e+06 -7.70740921e+06 -6.84963274e+06  5.88945732e+06
 -2.07582892e+06 -2.67897766e+06 -4.96436122e+06 -3.01223282e+05
  7.42820159e+06 -4.81270080e+06 -6.20353142e+06 -6.31164606e+06
 -8.73706958e+06 -8.56258700e+06 -3.71339580e+06  1.63347830e+06
 -7.41539468e+06 -4.82377141e+06 -6.74537019e+06 -6.82937828e+06
 -6.39148202e+06 -2.05444112e+06 -2.82057600e+06 -6.33945051e+06
 -5.31935645e+06  3.17093881e+06 -5.55895242e+06 -6.05050152e+06
 -7.04739126e+06  6.10345916e+06  5.00155980e+06 -5.67977844e+06
 -5.84395997e+06 -5.89494321e+06 -2.21144134e+06 -6.43793227e+06
 -7.27032933e+06 -2.87473989e+06 -5.61966716e+06 -5.54724155e+06
  2.28973728e+06 -7.79051416e+06 -6.58057361e+06 -1.78508723e+06
 -4.95650181e+06 -5.83841216e+06 -3.94369911e+06 -5.88643321e+06
 -5.15041949e+06 -9.07446743e+06 -8.86957028e+06 -5.88580175e+06
 -7.69623798e+06 -7.27874296e+06 -5.42446349e+06 -7.69412375e+06
  9.00752402e+06 -6.05821451e+06 -6.89384222e+06 -6.31357949e+06
 -1.50965806e+06 -6.70159242e+06 -7.91817381e+06 -6.31925124e+06
 -5.54105187e+06  2.55236383e+06  2.72276725e+06  3.67282627e+06
 -5.75455919e+06 -6.65381778e+06 -3.97545024e+06 -6.45450526e+06
 -4.53109395e+06 -5.91966173e+06  3.37430622e+06 -4.71735265e+06
 -5.72392841e+05 -7.07485011e+06 -7.03239565e+06  1.82047710e+06
 -5.31765009e+06 -5.27417543e+06 -7.59675801e+06  5.51416425e+05
 -5.59093384e+06 -4.43179118e+06 -5.71673829e+05 -4.89187046e+06
  3.47511008e+06 -4.88619014e+06 -6.29074658e+06 -8.18809151e+06
 -7.61366506e+06 -3.99211191e+06 -6.89301782e+06 -5.20634380e+06
 -8.15950777e+06 -6.67331861e+06  9.51983027e+05 -6.87029692e+06
  2.33672446e+06 -3.89074476e+06 -5.02253011e+06]

michaelosthege commented 4 years ago

These numbers are very large, indicating extremely strong gradients. This is likely throwing NUTS off. In addition, I wouldn't trust the floating point precision at such extreme values.
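
To illustrate how gradients of this size can arise: for a Normal likelihood, the derivative of the log-density with respect to its mean is (y - mu) / sigma^2, so a small sigma inflates it quadratically. The sketch below uses plain NumPy with made-up values; the helper name `normal_dlogp_dmu` and the numbers are illustrative assumptions, not taken from the full model.

```python
import numpy as np

def normal_dlogp_dmu(y, mu, sigma):
    """Derivative of log N(y | mu, sigma) with respect to mu."""
    return (y - mu) / sigma**2

# Same residual (0.5) evaluated under progressively smaller observational errors.
sigmas = [1.0, 1e-2, 1e-4]
grads = [normal_dlogp_dmu(10.5, 10.0, s) for s in sigmas]

for s, g in zip(sigmas, grads):
    print(f"sigma = {s:g}: dlogp/dmu = {g:.3g}")
```

A residual of only half a unit already gives a gradient on the order of 10^7 when sigma is 10^-4, which is the same range as the dlogp values posted above.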

HinLeung622 commented 4 years ago

So you mean... The combination of this model plus the data and its error fundamentally challenges the NUTS sampler? Is there anything I can do other than non-scientifically scaling the observational errors up by some big factor?

ricardoV94 commented 2 years ago

Closing this due to inactivity. If anyone finds similar problems, feel free to (re)open an issue.