pymc-devs / pymc

Bayesian Modeling and Probabilistic Programming in Python
https://docs.pymc.io/

BUG: OrderedMultinomial occasionally fails to sample with numpyro sampler #7155

Closed: velochy closed this issue 7 months ago.

velochy commented 7 months ago

Describe the issue:

I am running very big regression models (with 20+ regressors and complex pooling), and with the bigger models the numpyro sampler occasionally fails completely on some chains.

I am relatively sure the issue is related to a category getting zero probability due to floating-point imprecision, but I have failed to find a small, reproducible example. Even in my big examples the failures are random, happening about 1/3 of the time.
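To make the suspected failure mode concrete, here is a minimal NumPy sketch. It assumes the usual cumulative-logit construction of ordered category probabilities, P(y = k) = sigmoid(c_k - eta) - sigmoid(c_{k-1} - eta) with c_0 = -inf and c_K = +inf. This illustrates the arithmetic and is not PyMC's exact code; the helper `ordered_probs` and the cutpoint and predictor values are hypothetical. When `eta` sits far below the cutpoints, the sigmoids saturate to exactly 1.0 in float64. The differences for the upper categories then become exactly 0.0, so any observed count in those categories gives a log-likelihood of -inf.

```python
import numpy as np

def sigmoid(x):
    # Logistic function; adequate for the moderate arguments used here.
    return 1.0 / (1.0 + np.exp(-x))

def ordered_probs(eta, cutpoints):
    """Hypothetical helper: cumulative-logit category probabilities.

    P(y <= k) = sigmoid(c_k - eta), padded with 0 and 1 at the ends;
    each category probability is the difference of neighbouring values.
    """
    cutpoints = np.asarray(cutpoints, dtype=np.float64)
    p_cum = np.concatenate([[0.0], sigmoid(cutpoints - eta), [1.0]])
    return np.diff(p_cum)

cutpoints = [-1.0, 1.0]

# Moderate predictor: every category keeps a positive probability.
print(ordered_probs(0.0, cutpoints))

# Extreme predictor: sigmoid(39) and sigmoid(41) both round to exactly 1.0,
# so the two upper categories get probability exactly 0.0.
probs = ordered_probs(-40.0, cutpoints)
print(probs)

# Multinomial log-likelihood kernel (constant term omitted) for counts that
# fall into the categories whose probability underflowed to zero.
counts = np.array([0, 3, 1])
with np.errstate(divide="ignore"):
    loglik = np.sum(counts * np.log(probs))
print(loglik)
```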

What fixed it for me was to re-implement the ordered part myself (copy-pasting from _OrderedMultinomial) and then add probs += 1e-30 * (probs.shape[-1]) before feeding the probabilities to pm.Multinomial. This seems to prevent inf/nan values from being created, so the logp stays finite and the chain no longer fails.
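A NumPy sketch of the workaround described above, using a hypothetical cumulative-logit helper: every category probability gets the same tiny amount added (here `floor * n_categories`, mirroring `probs += 1e-30 * (probs.shape[-1])`) before the multinomial log-likelihood is computed. The helper names and the example values are assumptions for illustration. In an actual PyMC model, the same addition would be applied to the probability tensor before it is passed to pm.Multinomial.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def ordered_probs_with_floor(eta, cutpoints, floor=1e-30):
    """Hypothetical helper: cumulative-logit probabilities plus a tiny floor.

    Mirrors the workaround from the issue: every bin gets
    floor * n_categories added, so no bin is exactly zero.
    """
    cutpoints = np.asarray(cutpoints, dtype=np.float64)
    p_cum = np.concatenate([[0.0], sigmoid(cutpoints - eta), [1.0]])
    probs = np.diff(p_cum)
    probs += floor * probs.shape[-1]
    return probs

def multinomial_loglik_kernel(counts, probs):
    # sum_k n_k * log(p_k); the multinomial coefficient is omitted
    # because it does not depend on the probabilities.
    return np.sum(np.asarray(counts) * np.log(probs))

cutpoints = [-1.0, 1.0]
counts = [0, 3, 1]

probs = ordered_probs_with_floor(-40.0, cutpoints)
loglik = multinomial_loglik_kernel(counts, probs)
print(probs)   # upper bins are now about 3e-30 instead of 0.0
print(loglik)  # very negative, but finite
```

Because the floor is far below float64 resolution near 1.0, the largest bin is unchanged and the probabilities still sum to 1.0 to machine precision.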

It's worth noting that the regular PyMC sampler does not seem to have the same problem, although since it takes 6+ hours to sample the models I am running, I have not tested it enough to be sure.

Reproducible code example:

# Could not create a self-contained reproducible example

Error message:

No response

PyMC version information:

pymc 5.10.2

Context for the issue:

I don't think this is a high priority issue.

That being said, if you are fine with me using the same fix of adding an infinitesimal value to all probabilities, I'd be happy to open a PR that implements it in OrderedMultinomial (and possibly OrderedLogistic and OrderedProbit, as they are likely to exhibit similar behavior).

That or just close the issue as not reproducible. I already have a fix for myself so I'm ok either way :)

ricardoV94 commented 7 months ago

Those kinds of hacks are sometimes needed, especially because the gradients can otherwise come out as nan. But I don't think they should be added by default, since it depends on the use case. I can see cases that require zero-event probabilities to come out as exactly zero when you evaluate probabilities (e.g., in ZeroInflated models).
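A small NumPy illustration of why an eps floor is not a neutral change. It assumes a categorical distribution with a structural zero; the probabilities are made up, loosely in the spirit of a zero-inflated setting. The exact expression reports the impossible outcome with log-probability -inf. The floored one gives it a finite value, so `logp` no longer says "impossible", only "very unlikely". This matters whenever probabilities are evaluated directly rather than only used for NUTS gradients.

```python
import numpy as np

# Hypothetical categorical distribution with a structural zero:
# category 2 is impossible by construction.
probs_exact = np.array([0.6, 0.4, 0.0])

# The same probabilities after an eps floor like the one proposed in the issue.
eps = 1e-30
probs_floored = probs_exact + eps

def log_prob(category, probs):
    # Log-probability of a single categorical draw; log(0) gives -inf.
    with np.errstate(divide="ignore"):
        return np.log(probs[category])

# Exact expression: the impossible outcome has log-probability -inf.
print(log_prob(2, probs_exact))
# Floored expression: the impossible outcome now looks merely unlikely (about -69).
print(log_prob(2, probs_floored))
# Exponentiating back: exactly 0.0 versus roughly 1e-30.
print(np.exp(log_prob(2, probs_exact)), np.exp(log_prob(2, probs_floored)))
```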

PyMC is not used only for NUTS sampling.

However, if we see a way to make the computation more stable without an eps hack, that could be pursued, but we would probably need an MWE (minimal working example) for that.

velochy commented 7 months ago

I've thought about this a fair bit and even tried rewriting sigmoid(x) - sigmoid(y) with a multiplicative tanh identity in the hope that it would preserve small values, but it did not seem to improve on the status quo.
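For reference, here is one multiplicative identity of this kind. Which identity was actually used is an assumption, since the comment does not spell it out. It follows from \sigma(x) = \tfrac{1}{2}(1 + \tanh(x/2)) and \tanh a - \tanh b = \sinh(a - b) / (\cosh a \cosh b):

```latex
\sigma(x) - \sigma(y)
  = \tfrac{1}{2}\left(\tanh\tfrac{x}{2} - \tanh\tfrac{y}{2}\right)
  = \frac{\sinh\!\left(\tfrac{x-y}{2}\right)}
         {2\cosh\!\left(\tfrac{x}{2}\right)\cosh\!\left(\tfrac{y}{2}\right)}
```

The right-hand side avoids directly subtracting two nearly equal sigmoid values. As reported above, though, it did not help in practice for this model.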

Maybe it still makes sense to add the eps hack, but with a parameter flag that allows turning it off for the niche use cases it interferes with?

ricardoV94 commented 7 months ago

I don't think we should add it, even less by default.

When someone defines a distribution and requests its logp, we should try to honor the correct expression, not something that is hacked to work in some edge cases. Floating-point precision issues happen everywhere; sigmas can also underflow to 0 sometimes. The modeller is free to tweak values as they see fit before using a given distribution, to work around any issues their particular model or data have.

If you have evidence that this is a widespread problem with this distribution, such that most people won't be able to do anything useful with it, that would be a more compelling case. These kinds of hacks tend to creep into the source code and make things messier over time.

If you think this is a general problem with this distribution, a reproducible example would help here. The first step would be to see whether there is room to define things in a more numerically precise way without hacks.

velochy commented 7 months ago

OK, fair enough. I don't think it's widespread, although it did take me a couple of days to figure out why it was failing, and I was hoping to spare others the same fate. But I see your point about this essentially building technical debt and understand why you want to keep the code base clean :) Feel free to close this issue. As I said, I have already resolved it for myself.

ricardoV94 commented 7 months ago

One thing we do try to do is provide better debugging tools. If you have a problematic point, you can pass it to model.debug(pt, fn="dlogp"), for example, to see if there are nan grads. The issue is that numpyro is sometimes a bit too much of a black box, and you can't easily find out what is failing where.
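model.debug(pt, fn="dlogp") reports nan gradients at a given point. The following NumPy-only sketch is not PyMC code and only illustrates where such a nan can come from. It assumes a cumulative-logit model in which P(y = k) = sigmoid(upper - eta) - sigmoid(lower - eta) for the cutpoints bounding category k. The helper name and values are hypothetical. The analytic derivative of log P(y = k) with respect to eta is a ratio. Once both sigmoids saturate to exactly 1.0, its numerator and denominator are both exactly 0, giving 0/0 = nan.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def dlogp_category(eta, lower, upper):
    """Hypothetical helper: d/d(eta) of log P(y = k) in a cumulative-logit model.

    P(y = k) = sigmoid(upper - eta) - sigmoid(lower - eta), where lower and
    upper are the cutpoints bounding category k. Uses sigmoid'(z) = s * (1 - s).
    """
    s_up = sigmoid(np.float64(upper) - eta)
    s_lo = sigmoid(np.float64(lower) - eta)
    p_k = s_up - s_lo
    dp_k = -s_up * (1.0 - s_up) + s_lo * (1.0 - s_lo)
    # np.divide returns nan (rather than raising) for 0/0 on float64 inputs.
    with np.errstate(divide="ignore", invalid="ignore"):
        return np.divide(dp_k, p_k)

# Middle category, bounded by cutpoints -1 and 1.
print(dlogp_category(0.5, -1.0, 1.0))    # finite gradient
print(dlogp_category(-40.0, -1.0, 1.0))  # nan: both sigmoids saturate to 1.0
```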

This is compounded by the fact that JAX does not allow any runtime checks/asserts, although point 5 here could help a bit: https://github.com/pymc-devs/pymc/issues/6841#issuecomment-1661680771

I'll leave the issue open a bit longer to see if anyone else has had issues with the OrderedMultinomial.

ricardoV94 commented 7 months ago

We can also add a note in the documentation of the distribution if you think that would have spared you some trouble

velochy commented 7 months ago

> We can also add a note in the documentation of the distribution if you think that would have spared you some trouble

It probably would have, but I'm not sure how to phrase it.

"This distribution can result in samplers failing because the probability for a bin can become zero. The workaround for this is to implement it yourself and add 1e-30 to all bins."

would sound very out of place, don't you think?

ricardoV94 commented 7 months ago

Yeah, a bit.