Status: Open. xidulu opened this issue 8 months ago.
I think that for Pangolin "proper" this is a WONTFIX. I don't think there's any known algorithm that can properly resolve whether a collection of random variables is "valid" in this sense through program analysis alone (and I somewhat suspect the problem is NP-hard).
So I think these checks probably need to be done at inference time, which means they need to be done by the inference backend. But I think this can be improved for all the current backends:
OK, it seems that simply adding
import numpyro
numpyro.enable_validation(True)
can solve this issue on the inference end.
So that can work without any modification needed here?
One thing that's not clear to me from the numpyro docs: it looks like those checks are skipped under JAX's JIT, and I'm not 100% sure which of numpyro's inference routines actually use JIT...
From my limited experience, when this flag is turned on, it seems to handle some "basic" cases:
w1 = normal_scale(1, 1)
x = bernoulli(w1 - 10)
y = 1
calc = Calculate("numpyro",niter=10000)
ys = calc.sample(w1, [x], [y - 5])
File ~/work/anaconda3/envs/pangolin/lib/python3.11/site-packages/numpyro/distributions/distribution.py:239, in Distribution.__init__(self, batch_shape, event_shape, validate_args)
237 if not_jax_tracer(is_valid):
238 if not np.all(is_valid):
--> 239 raise ValueError(
240 "{} distribution got invalid {} parameter.".format(
241 self.__class__.__name__, param
242 )
243 )
244 super(Distribution, self).__init__()
ValueError: BernoulliProbs distribution got invalid probs parameter.
But yes, when JIT is introduced, this validation mechanism could be ineffective...
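To make the JIT concern concrete, here is a minimal sketch that mirrors the `not_jax_tracer(is_valid)` / `np.all(is_valid)` pattern visible in the traceback above. It is not numpyro's actual code: `is_concrete` and `checked_bernoulli_logprob` are hypothetical names made up for illustration, and the log-probability formula is a plain textbook Bernoulli. The point is only that a Python-level `raise` guarded this way fires in eager mode but is bypassed once the parameter becomes a tracer under `jax.jit`.

```python
import jax
import jax.numpy as jnp
import numpy as np


def is_concrete(x):
    # Hypothetical stand-in for the not_jax_tracer helper seen in the traceback:
    # True for ordinary values, False for values being traced by JAX.
    return not isinstance(x, jax.core.Tracer)


def checked_bernoulli_logprob(probs, value):
    # Illustrative only: validate probs, then compute a Bernoulli log-probability.
    is_valid = (probs >= 0) & (probs <= 1)
    if is_concrete(is_valid):
        # This branch only runs when is_valid holds a concrete value.
        if not np.all(is_valid):
            raise ValueError("invalid probs parameter")
    return value * jnp.log(probs) + (1 - value) * jnp.log1p(-probs)


bad_probs = jnp.asarray(-9.0)

# Eager mode: the check sees a concrete value and raises.
eager_raised = False
try:
    checked_bernoulli_logprob(bad_probs, 1)
except ValueError:
    eager_raised = True

# Under jit: probs is a tracer, the check is skipped, and no error is raised.
jitted_out = jax.jit(checked_bernoulli_logprob)(bad_probs, 1)
```

In this sketch `eager_raised` ends up `True`, while the jitted call returns silently with a NaN, which is the failure mode being discussed. Whether a given numpyro inference routine hits the eager or the traced path is still something to verify per routine.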
The code below runs without issues, which is an issue.
Maybe worth adding some validation mechanism, like NumPyro's validate_args option: https://num.pyro.ai/en/stable/utilities.html#enable-validation (or can we directly use NumPyro's validation system?)
But it seems to require a significant amount of engineering effort (we'd have to implement conditions for every single CondDist), and it is not high priority IMO.
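For a rough sense of what "conditions for every single CondDist" could look like, here is a small pure-Python sketch. Everything in it is an assumption for illustration: `PARAM_CONSTRAINTS` and `validate_params` are hypothetical names, not part of Pangolin or NumPyro, and the distribution names are just string keys borrowed from the example in this thread. It also only works on concrete parameter values, so it would have to run at inference time rather than during program analysis.

```python
import math

# Hypothetical registry: distribution name -> one predicate per parameter.
# A real implementation would need an entry for every CondDist.
PARAM_CONSTRAINTS = {
    "bernoulli": (lambda p: 0.0 <= p <= 1.0,),
    "normal_scale": (
        lambda loc: math.isfinite(loc),
        lambda scale: scale > 0.0,
    ),
}


def validate_params(dist_name, *params):
    """Raise ValueError if any concrete parameter violates its constraint."""
    checks = PARAM_CONSTRAINTS[dist_name]
    if len(checks) != len(params):
        raise TypeError(
            f"{dist_name} expects {len(checks)} parameters, got {len(params)}"
        )
    for index, (check, param) in enumerate(zip(checks, params)):
        if not check(param):
            raise ValueError(
                f"{dist_name} got invalid parameter {index}: {param!r}"
            )


# Valid parameters pass silently.
validate_params("normal_scale", 1.0, 1.0)
validate_params("bernoulli", 0.3)
```

Calling `validate_params("bernoulli", -9.0)` raises a `ValueError`, which is roughly the situation in the `bernoulli(w1 - 10)` example whenever `w1 - 10` falls outside `[0, 1]`. The engineering cost mentioned above shows up in the registry: each distribution needs its own hand-written predicates.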