justindomke / pangolin

probabilistic programming focused on fun
GNU Affero General Public License v3.0

Checking distribution's parent / observed value in proper range #10

Open xidulu opened 8 months ago

xidulu commented 8 months ago

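The code below runs without raising any error, which is itself a problem: the Bernoulli parameter `w1 - 100` is far outside [0, 1], and the observed value `y = -1` is not a possible Bernoulli outcome.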

```python
w1 = normal_scale(0, 1)
x = bernoulli(w1 - 100)
y = -1
calc = Calculate("numpyro", niter=10000)
ys = calc.sample(w1, x, y)
```

It may be worth adding a validation mechanism like NumPyro's `validate_args` option: https://num.pyro.ai/en/stable/utilities.html#enable-validation (or could we use NumPyro's validation system directly?)

But it seems to require a significant amount of engineering effort (we would have to implement conditions for every single CondDist), and it is not a high priority IMO.
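To give a sense of what "conditions for every single CondDist" might involve, here is a minimal sketch of a per-distribution constraint table. `CONSTRAINTS`, `check_params` and `check_value` are hypothetical names, not part of Pangolin or NumPyro; only two distributions are covered, the parameter names `loc`/`scale` for `normal_scale` are assumptions, and parameters are plain Python floats.

```python
import math

# Hypothetical table: for each distribution, a predicate per parameter
# and one predicate describing the support of an observed value.
CONSTRAINTS = {
    "bernoulli": {
        "params": [("p", lambda p: 0.0 <= p <= 1.0)],
        "support": lambda v: v in (0, 1),
    },
    "normal_scale": {
        "params": [
            ("loc", math.isfinite),
            ("scale", lambda s: math.isfinite(s) and s > 0.0),
        ],
        "support": math.isfinite,
    },
}


def check_params(dist_name, *params):
    """Raise ValueError if any parameter violates its constraint."""
    spec = CONSTRAINTS[dist_name]
    for (name, ok), value in zip(spec["params"], params):
        if not ok(value):
            raise ValueError(f"{dist_name}: invalid {name}={value!r}")


def check_value(dist_name, value):
    """Raise ValueError if an observed value is outside the support."""
    if not CONSTRAINTS[dist_name]["support"](value):
        raise ValueError(f"{dist_name}: observed value {value!r} outside support")


check_params("bernoulli", 0.3)   # passes
check_value("bernoulli", 1)      # passes
try:
    check_params("bernoulli", -100.0)  # a parameter like the one in the issue
except ValueError as err:
    print(err)  # bernoulli: invalid p=-100.0
```

Each new CondDist would need its own entry, and these checks only work once concrete numbers are available, not on symbolic random variables.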

justindomke commented 8 months ago

I think that for Pangolin "proper" this is a WONTFIX: I don't think there's any known algorithm that can reliably decide, through program analysis alone, whether a collection of random variables is "valid" in this sense (and I somewhat suspect the problem is NP-hard).
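As an illustration of why program analysis alone struggles, consider one conservative technique, interval propagation. This is a swapped-in example, not something Pangolin implements: it tracks a range each value must lie in and accepts a Bernoulli only if its whole parameter range sits inside [0, 1]. It correctly flags `w1 - 100`, but because it forgets dependencies between values it also flags some valid programs. The `Interval` class and the `uniform(-1, 1)` variable are hypothetical.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class Interval:
    lo: float
    hi: float

    def __sub__(self, c: float) -> "Interval":
        # Subtracting a constant shifts both endpoints.
        return Interval(self.lo - c, self.hi - c)

    def __mul__(self, other: "Interval") -> "Interval":
        # Min/max over endpoint products. The operands are treated as
        # independent, so x * x is over-approximated.
        products = [a * b for a in (self.lo, self.hi) for b in (other.lo, other.hi)]
        return Interval(min(products), max(products))

    def within(self, lo: float, hi: float) -> bool:
        return lo <= self.lo and self.hi <= hi


def bernoulli_param_ok(p: Interval) -> bool:
    """Conservative check: True only if every possible p lies in [0, 1]."""
    return p.within(0.0, 1.0)


# w1 ~ normal_scale(0, 1) can take any real value.
w1 = Interval(-math.inf, math.inf)
print(bernoulli_param_ok(w1 - 100))  # False: correctly flagged

# Hypothetical u ~ uniform(-1, 1): u * u always lies in [0, 1], so
# bernoulli(u * u) is valid, but the analysis cannot see that.
u = Interval(-1.0, 1.0)
print(bernoulli_param_ok(u * u))     # False: a false alarm
```

Special-casing patterns like `u * u` helps only for those specific patterns, which is in line with the view that a fully general static check is out of reach.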

So I think these checks probably need to be done at inference time, which means they need to be done by the inference backend. But I think this can be improved for all the current backends.
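At inference time the check itself is simple, because the backend works with concrete numbers. A minimal sketch of the idea, assuming a toy ancestral sampler written in plain Python (`sample_model`, `checked_bernoulli` and the `link` argument are hypothetical and do not reflect how any Pangolin backend is implemented): each Bernoulli parameter is validated right before it is used.

```python
import math
import random


def checked_bernoulli(p: float, rng: random.Random) -> int:
    # A concrete value is available here, so the range check is a comparison.
    if not 0.0 <= p <= 1.0:
        raise ValueError(f"bernoulli got invalid probability {p!r}")
    return int(rng.random() < p)


def sample_model(link, niter: int, seed: int = 0) -> list:
    """Draw (w1, x) pairs from w1 ~ Normal(0, 1), x ~ Bernoulli(link(w1))."""
    rng = random.Random(seed)
    draws = []
    for _ in range(niter):
        w1 = rng.gauss(0.0, 1.0)
        x = checked_bernoulli(link(w1), rng)
        draws.append((w1, x))
    return draws


# The model from the issue: the check fires on the first draw.
try:
    sample_model(lambda w: w - 100, niter=10)
except ValueError as err:
    print(err)

# A valid model: the logistic function maps any real w1 into (0, 1).
draws = sample_model(lambda w: 1 / (1 + math.exp(-w)), niter=1000)
```

The cost is one extra comparison per draw; the hard part is wiring equivalent checks into each backend.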

xidulu commented 8 months ago

OK, it seems that simply adding

```python
import numpyro
numpyro.enable_validation(True)
```

solves this issue on the inference end.

justindomke commented 8 months ago

So that can work without any modification needed here?

One thing that's not clear to me from the NumPyro docs: it looks like those checks are skipped under JAX's JIT, so I'm not 100% sure which of NumPyro's inference routines actually use JIT...

xidulu commented 8 months ago

From my limited experience, when this flag is turned on, it seems to catch some "basic" cases:

[Screenshot 2024-03-14 at 13:59:37]
```python
w1 = normal_scale(1, 1)
x = bernoulli(w1 - 10)
y = 1
calc = Calculate("numpyro", niter=10000)
ys = calc.sample(w1, [x], [y - 5])
```

```
File ~/work/anaconda3/envs/pangolin/lib/python3.11/site-packages/numpyro/distributions/distribution.py:239, in Distribution.__init__(self, batch_shape, event_shape, validate_args)
    237         if not_jax_tracer(is_valid):
    238             if not np.all(is_valid):
--> 239                 raise ValueError(
    240                     "{} distribution got invalid {} parameter.".format(
    241                         self.__class__.__name__, param
    242                     )
    243                 )
    244 super(Distribution, self).__init__()

ValueError: BernoulliProbs distribution got invalid probs parameter.
```

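But yes, once JIT is involved, this validation mechanism could be ineffective: as the traceback shows, the check only runs when `not_jax_tracer(is_valid)` holds, i.e. when the value is concrete rather than a JAX tracer.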
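One way to keep some form of checking under JIT is to encode invalidity in the returned value instead of raising a Python exception, since a traced value has no concrete number to compare at trace time. The sketch below is a swapped-in technique, not NumPyro's mechanism and not a Pangolin feature: a hand-written Bernoulli log-pmf (`bernoulli_logpmf` is a hypothetical name) returns NaN when the probability or the observed value is out of range, and it behaves the same inside `jax.jit`.

```python
import jax
import jax.numpy as jnp


def bernoulli_logpmf(y, p):
    """Bernoulli log-pmf that returns NaN for out-of-range inputs."""
    # Comparisons on traced arrays yield traced booleans, so this works under jit.
    valid = (p >= 0) & (p <= 1) & ((y == 0) | (y == 1))
    logp = jnp.where(y == 1, jnp.log(p), jnp.log1p(-p))
    # Select NaN wherever the inputs were invalid instead of raising.
    return jnp.where(valid, logp, jnp.nan)


jitted = jax.jit(bernoulli_logpmf)
print(jitted(1.0, 0.25))    # log(0.25), about -1.386
print(jitted(-1.0, 0.25))   # nan: observed value outside {0, 1}
print(jitted(1.0, -99.0))   # nan: probability outside [0, 1]
```

The trade-off is that a NaN log-density has to be noticed downstream (for example with `jnp.isnan` on the result) instead of stopping immediately with a clear error message.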