aesara-devs / aeppl

Tools for an Aesara-based PPL.
https://aeppl.readthedocs.io
MIT License

AssertionError in aeppl.logp breaks sampling in pymc #84

Closed. ferrine closed this issue 2 years ago.

ferrine commented 2 years ago

Description of your problem or feature request

Please provide a minimal, self-contained, and reproducible example. https://github.com/aesara-devs/aeppl/blob/18275e2bba83f21da5cde256bd67d86fa5b32451/aeppl/logprob.py#L92

Please provide the full traceback of any errors.

AssertionError: sigma > 0
Apply node that caused the error: Assert{msg='sigma > 0'}(Elemwise{Composite{((i0 + (i1 * sqr(i2))) - log(i3))}}[(0, 2)].0, All.0)

Please provide any additional information below. Sampling with PyMC (master) fails with assertion errors; the sampler does not treat an AssertionError as a divergent sample.

Expected Behaviour

logp should return -inf instead of raising an AssertionError.

Possible Solution

This snippet solved my issue

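import aesara.assert_op
import aesara.tensor as at
import numpy as np

# Replace Assert so failed conditions give -inf instead of raising (apply before importing aeppl/pymc so modules that import Assert by name see the patch).
aesara.assert_op.Assert = lambda msg="": (lambda res, *cond: at.switch(
    at.all(at.stack([c.all() for c in cond])),
    res,
    -np.inf,
))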

Versions and main components

aseyboldt commented 2 years ago

I think what happened in an example we tried was this: during the first few iterations of NUTS, we computed the model's logp at a point where a log_sigma parameter was so negative that exp(log_sigma) underflowed to 0 in float64.

This triggered the assertion error, which stopped the sampler. In previous versions of pymc, logp would simply return -inf, so the sampler treated this as a divergence and kept running (and eventually converged).
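The underflow is easy to reproduce with Python's `math` module, which uses the same float64 arithmetic; the value `-800.0` is an illustrative extreme proposal, not one taken from the run described above:

```python
import math

# An illustrative, very negative value for the unconstrained log_sigma.
log_sigma = -800.0

# The true value e**-800 (roughly 3.7e-348) is below the smallest positive
# float64 (about 4.9e-324), so exp() underflows to exactly 0.0.
sigma = math.exp(log_sigma)

print(sigma)      # 0.0
print(sigma > 0)  # False: a `sigma > 0` parameter check now fails
```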

I actually quite like the idea of distinguishing invalid parameters (where we raise an error) from invalid values (where we return -inf); however, it does seem to cause some issues here.
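A minimal plain-Python sketch of that distinction, using a HalfNormal log-density. `halfnormal_logp` is a hypothetical helper, not aeppl's API, and it raises `ValueError` where aeppl inserts an `Assert`:

```python
import math


def halfnormal_logp(x: float, sigma: float) -> float:
    """Log-density of HalfNormal(sigma) at x (hypothetical helper)."""
    # Invalid *parameter*: the model itself is ill-defined, so raise.
    # `not sigma > 0` also rejects nan, since nan comparisons are False.
    if not sigma > 0:
        raise ValueError("sigma > 0")
    # Invalid *value*: x lies outside the support, which is a legitimate
    # zero-probability event, so return -inf.
    if x < 0:
        return -math.inf
    return 0.5 * math.log(2.0 / math.pi) - math.log(sigma) - x**2 / (2.0 * sigma**2)
```

With this split, a caller can tell a misspecified model (an exception) apart from an impossible observation (`-inf`).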

I currently see the following options:

brandonwillard commented 2 years ago

It sounds like you're describing an issue with the sampling process in PyMC, and not the general evaluation of a log-likelihood. If I'm understanding this correctly, it would make more sense to catch the AssertionErrors in the sampling process and handle them as divergences.

brandonwillard commented 2 years ago

Yes, @aseyboldt describes the issue as I understand it.

  • We could keep this behaviour in aeppl as it is, but change the NUTS sampler so that it catches the assertion errors and converts them to divergences. This wouldn't work for samplers that are written in aesara, however (I don't think there is a way to catch exceptions within a graph).

This sounds like the most reasonable option, because it doesn't compromise the meaning of a log-likelihood in any way (e.g. if sigma <= 0 is invalid, then we don't want to treat it as anything else); however, we should raise a more precise exception type (e.g. a ValueError, or a new, dedicated exception type).
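A sketch of this first option for a sampler written in Python. The exception class, the `logp` model, and the toy Metropolis step are all hypothetical stand-ins (PyMC's NUTS is far more involved); the point is only that a wrapper turns an invalid-parameter exception into `-inf`, which the sampler then rejects instead of crashing. As noted above, this is not available to a sampler compiled into an Aesara graph.

```python
import math
import random


class InvalidParameterError(ValueError):
    """Hypothetical, more precise exception type for invalid parameters."""


def logp(log_sigma: float, x: float = 1.0) -> float:
    """Normal(0, sigma) log-density of a fixed observation x."""
    sigma = math.exp(log_sigma)  # may underflow to 0.0
    if not sigma > 0:
        raise InvalidParameterError("sigma > 0")
    return -0.5 * math.log(2 * math.pi) - math.log(sigma) - x**2 / (2 * sigma**2)


def safe_logp(point: float) -> float:
    """Sampler-side wrapper: treat invalid parameters as zero probability."""
    try:
        return logp(point)
    except InvalidParameterError:
        return -math.inf


def metropolis_step(current: float, rng: random.Random, scale: float = 1.0) -> float:
    """One random-walk Metropolis step on log_sigma."""
    proposal = current + rng.gauss(0.0, scale)
    log_ratio = safe_logp(proposal) - safe_logp(current)
    # 1 - rng.random() lies in (0, 1], so the log is always defined.
    # A -inf proposal gives log_ratio == -inf and is simply rejected
    # (NUTS would record it as a divergence) instead of stopping the run.
    if math.log(1.0 - rng.random()) < log_ratio:
        return proposal
    return current
```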

brandonwillard commented 2 years ago
  • We remove (or adapt) the assertions in pymc using a graph rewrite, but keep them in aeppl as they are. In that case, I think it would be great if the assertions carried some form of metadata, so that we only remove (or change to return np.nan) the invalid-parameter assertions and not something a user might have asserted in the model.

This option is also quite possible/reasonable to do.
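A toy, pure-Python model of such a metadata-aware rewrite. None of these classes exist in Aesara or aeppl: `ValidParameterAssert` is a hypothetical marker subclass (borrowing the name floated later in this thread), `NanSwitch` stands in for a switch-based replacement, and the "graph" is just a list of nodes. Only the tagged parameter check is replaced; the user's own `Assert` comes back unchanged.

```python
import math
from dataclasses import dataclass
from typing import Callable, Dict, List

Env = Dict[str, float]


@dataclass
class Assert:
    """Generic assertion node, e.g. one a user added to their model."""
    msg: str
    value: Callable[[Env], float]  # computes the guarded result
    cond: Callable[[Env], bool]    # must hold, otherwise we raise

    def __call__(self, env: Env) -> float:
        if not self.cond(env):
            raise AssertionError(self.msg)
        return self.value(env)


@dataclass
class ValidParameterAssert(Assert):
    """Hypothetical marker: this assert checks a distribution parameter."""


@dataclass
class NanSwitch:
    """Replacement node: returns nan instead of raising when cond fails."""
    value: Callable[[Env], float]
    cond: Callable[[Env], bool]

    def __call__(self, env: Env) -> float:
        return self.value(env) if self.cond(env) else math.nan


Node = Callable[[Env], float]


def replace_param_asserts(nodes: List[Node]) -> List[Node]:
    """Toy rewrite: swap only parameter asserts; keep every other node."""
    return [
        NanSwitch(n.value, n.cond) if isinstance(n, ValidParameterAssert) else n
        for n in nodes
    ]


# Usage: a parameter check on sigma and a user-level check on x.
param_check = ValidParameterAssert(
    "sigma > 0",
    lambda e: -math.log(e["sigma"]) - e["x"] ** 2 / (2 * e["sigma"] ** 2),
    lambda e: e["sigma"] > 0,
)
user_check = Assert("x < 10", lambda e: e["x"], lambda e: e["x"] < 10)

rewritten = replace_param_asserts([param_check, user_check])
env = {"sigma": 0.0, "x": 1.0}
print(rewritten[0](env))  # nan
print(rewritten[1](env))  # 1.0
```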

aseyboldt commented 2 years ago

I think I like the last option most (and replacing the Assert ops with something like ValidParameterAssert). Ideally, though, there would be a permissive and a non-permissive version (e.g. sigma >= 0 and sigma > 0). The boundary cases are a bit different, because they can arise from the finite precision of floats, not only from invalid models.
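One way to sketch the permissive/strict pair in plain Python. The `permissive` flag and the choice to map the `sigma == 0` boundary to `-inf` are assumptions for illustration, not a settled design:

```python
import math


def check_sigma(sigma: float, permissive: bool) -> None:
    """Strict mode requires sigma > 0; permissive mode only sigma >= 0."""
    ok = sigma >= 0 if permissive else sigma > 0
    if not ok:  # also catches nan, since nan comparisons are False
        raise ValueError("sigma >= 0" if permissive else "sigma > 0")


def normal_logp(x: float, mu: float, sigma: float, permissive: bool = False) -> float:
    check_sigma(sigma, permissive)
    if sigma == 0:
        # Only reachable in permissive mode, e.g. after exp() underflow.
        # Mapped to -inf here by choice, so a sampler simply rejects it.
        return -math.inf
    z = (x - mu) / sigma
    return -0.5 * math.log(2 * math.pi) - math.log(sigma) - 0.5 * z * z
```

A negative sigma still raises in both modes, so a genuinely misspecified model is not silently hidden.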

brandonwillard commented 2 years ago

Yeah, this is always a tricky area, because the problem could be addressed—in some cases—by using the completion of a distribution's log-probability in some space (i.e. include the limit points, like sigma == 0) or any other extension of the domain/range (e.g. simply mapping undefined values to -inf).

Both can be fine, but, since the former is not always possible/reasonable and the latter is arguably a strong and potentially questionable design choice (e.g. it becomes impossible to distinguish between invalid values and numerical limits when they both map to the same value), it's best that a general library like AePPL makes the more broadly applicable choice as its standard.
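The two extensions can be compared on the Normal log-density in plain Python. These helpers are illustrative only: the first uses the limit as sigma tends to 0 from above (+inf at x == mu, -inf elsewhere), the second maps every undefined case to -inf.

```python
import math


def normal_logp_completed(x: float, mu: float, sigma: float) -> float:
    """Extend the domain to sigma == 0 using the limit sigma -> 0+."""
    if not sigma >= 0:
        raise ValueError("sigma >= 0")  # still a genuine parameter error
    if sigma == 0:
        # -log(sigma) -> +inf dominates at x == mu; the quadratic term
        # -(x - mu)**2 / (2 * sigma**2) -> -inf dominates elsewhere.
        return math.inf if x == mu else -math.inf
    z = (x - mu) / sigma
    return -0.5 * math.log(2 * math.pi) - math.log(sigma) - 0.5 * z * z


def normal_logp_mapped(x: float, mu: float, sigma: float) -> float:
    """Map every undefined case (sigma <= 0 or nan) to -inf."""
    if not sigma > 0:
        return -math.inf
    z = (x - mu) / sigma
    return -0.5 * math.log(2 * math.pi) - math.log(sigma) - 0.5 * z * z


# With the mapping, an invalid parameter and a numerical limit look the same:
print(normal_logp_mapped(1.0, 0.0, -1.0) == normal_logp_mapped(1.0, 0.0, 0.0))  # True
```

The completed version still raises for sigma < 0, so invalid parameters stay distinguishable from the boundary.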

For AePPL, the real requirement is that we honor these domain errors in some non-fixed and externally manageable way, and the fact that the Asserts can be handled/caught and removed/replaced altogether is a great example of that.

ricardoV94 commented 2 years ago

What's the plan here? Should we introduce a specialized Assert Op, something like AssertValidParameter, for easier/safer parsing?

brandonwillard commented 2 years ago

If we're going with the last option, then we need a rewrite that converts the Asserts into nans when they fail (e.g. replace them with switch/ifelse graphs).

ricardoV94 commented 2 years ago


That should be easy enough. Would such a rewrite be provided as an aeppl.util or be implemented in pymc?

Also, are we on board with creating a more specialized AssertOp to distinguish it from other generic Asserts that may be introduced by Aesara, users, or third-party libraries?

brandonwillard commented 2 years ago


That should be easy enough. Would such a rewrite be provided as an aeppl.util or be implemented in pymc?

In PyMC, where it would be used.

Also, are we on board with creating a more specialized AssertOp to distinguish it from other generic Asserts that may be introduced by Aesara, users, or third-party libraries?

Yes, this is better than converting to nan or raising AssertionErrors.