ferrine closed this issue 2 years ago.
I think what happened in the example we tried is that, during the first couple of NUTS iterations, we attempted to compute the logp of a model where the log of a sigma parameter was so small that `exp(log_sigma)` underflowed to 0 in float64.
This triggered the assertion error, which then stopped the sampler. In previous versions of PyMC we would just return `-inf`, so the sampler would treat this as a divergence and continue running (to eventually converge).
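To illustrate the underflow with a minimal sketch (the numbers here are illustrative, not from our model):

```python
import numpy as np

# The smallest positive float64 is a subnormal around 5e-324, so
# exp() underflows to exactly 0.0 once its argument drops below ~-745.
log_sigma = np.float64(-746.0)
sigma = np.exp(log_sigma)

print(sigma)      # 0.0
print(sigma > 0)  # False: a strict positivity check on sigma now fails
```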
I actually quite like the idea of distinguishing invalid parameters (where we throw an error) from invalid values (where we return `-inf`); it does seem to cause some issues here, however.
I currently see the following options:
It sounds like you're describing an issue with the sampling process in PyMC, and not the general evaluation of a log-likelihood. If I'm understanding this correctly, it would make more sense to catch the `AssertionError`s in the sampling process and handle them as divergences.
Yes, @aseyboldt describes the issue as I understand it.
> - We could keep this behaviour in AePPL as it is, but change the NUTS sampler so that it catches the assertion errors and converts them to divergences. This wouldn't work for samplers that are written in Aesara, however (I don't think there is a way to catch exceptions within a graph).
This sounds like the most reasonable option, because it doesn't involve compromising the meaning of a log-likelihood in any way (e.g. if `sigma <= 0` is invalid, then we don't want to treat it in any other way); however, we should throw a more precise exception type (e.g. `ValueError` is more appropriate, or a new exception type).
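As a rough sketch of what catching these could look like on the Python side (the helper below is hypothetical, not an existing PyMC function):

```python
import numpy as np

def eval_logp_or_diverge(logp_fn, point):
    """Evaluate a compiled logp function, mapping parameter-validity
    errors to -inf so the step method records a divergence instead of
    aborting the run. `logp_fn` and `point` stand in for a compiled
    Aesara function and its inputs.
    """
    try:
        return logp_fn(point)
    except AssertionError:  # or ValueError/a new type, per the comment above
        return -np.inf
```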
> - We remove (or adapt) the assertions in PyMC using a graph rewrite, but keep them in AePPL as they are. In that case, I think it would be great if the assertions had some form of metadata, so that we only remove (or change to return `np.nan`) the invalid-parameter assertions and not something a user might have asserted in the model.
This option is also quite possible/reasonable to do.
I think I like the last option most (and would replace the `Assert` ops with something like `ValidParameterAssert`). Ideally, though, there would be a permissive and a non-permissive version (e.g. `sigma >= 0` and `sigma > 0`). The boundary cases are just a bit different, in that they can happen due to the finite precision of floats, not only because of an invalid model.
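A tiny illustration of the two variants (using Aesara tensors; the names are just for exposition):

```python
import aesara.tensor as at

sigma = at.dscalar("sigma")

# Non-permissive: rejects the boundary, so sigma == 0 means an invalid model.
strict_ok = sigma > 0

# Permissive: tolerates sigma == 0, which can arise purely from float
# underflow (e.g. exp() of a very negative log_sigma), not a bad model.
permissive_ok = sigma >= 0
```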
Yeah, this is always a tricky area, because the problem could be addressed, in some cases, by using the completion of a distribution's log-probability in some space (i.e. including the limit points, like `sigma == 0`) or some other extension of the domain/range (e.g. simply mapping undefined values to `-inf`).
Both can be fine; but, since the former is not always possible/reasonable, and the latter is arguably a strong and potentially questionable design choice (e.g. it becomes impossible to distinguish between invalid values and numerical limits when they both map to the same value), it's best that a general library like AePPL makes the more broadly applicable choice its standard.
For AePPL, the real requirement is that we honor these domain errors in some non-fixed, externally manageable way, and the fact that the `Assert`s can be handled/caught and removed/replaced altogether is a great example of that.
What's the plan here? Should we introduce a specialized `Assert` `Op`, something like `AssertValidParameter`, for easier/safer parsing?
If we're going with the last option, then we need a rewrite that converts the `Assert`s into `nan`s when they fail (e.g. replace them with `switch`/`ifelse` graphs).
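A minimal sketch of such a rewrite, assuming the node-rewriter API in `aesara.graph.rewriting.basic` (older Aesara versions expose this as `local_optimizer`). Note that, as written, it would match every `Assert`, including ones users put into their models, which is exactly why a specialized subclass is discussed below:

```python
import numpy as np
import aesara.tensor as at
from aesara.graph.rewriting.basic import node_rewriter
from aesara.raise_op import Assert

@node_rewriter([Assert])
def local_assert_to_nan(fgraph, node):
    """Replace `assert_op(value, *conds)` with a `switch` that yields
    nan when any condition fails, instead of raising at runtime."""
    value, *conds = node.inputs
    # Reduce each condition to a scalar boolean and require all of them.
    all_ok = at.all(at.stack([at.all(c) for c in conds]))
    return [at.switch(all_ok, value, np.nan)]
```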
> If we're going with the last option, then we need a rewrite that converts the `Assert`s into `nan`s when they fail (e.g. replace them with `switch`/`ifelse` graphs).
That should be easy enough. Would such a rewrite be provided in `aeppl.util`, or be implemented in `pymc`?

Also, are we on board with creating a more specialized `Assert` `Op`, to distinguish it from other generic `Assert`s that may be introduced by Aesara, users, or third-party libraries?
> That should be easy enough. Would such a rewrite be provided in `aeppl.util`, or be implemented in `pymc`?
In PyMC, where it would be used.
> Also, are we on board with creating a more specialized `Assert` `Op`, to distinguish it from other generic `Assert`s that may be introduced by Aesara, users, or third-party libraries?
Yes, this is better than converting to `nan` or raising `AssertionError`s.
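For concreteness, here is one way such a specialized `Op` could look. This is a sketch assuming Aesara's `CheckAndRaise` (which the stock `Assert` also subclasses); the names `ParameterValueError` and `CheckParameterValue` are illustrative:

```python
from aesara.raise_op import CheckAndRaise


class ParameterValueError(ValueError):
    """Raised when a distribution parameter takes an invalid value."""


class CheckParameterValue(CheckAndRaise):
    """An `Assert`-like `Op` used only for parameter-validity checks,
    so that rewrites can target it without touching user assertions."""

    def __init__(self, msg=""):
        super().__init__(ParameterValueError, msg)
```

A rewrite like the one sketched earlier could then track `CheckParameterValue` instead of the generic `Assert`, leaving user assertions (and Aesara-internal ones) untouched.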
Description of your problem or feature request
Please provide a minimal, self-contained, and reproducible example.

https://github.com/aesara-devs/aeppl/blob/18275e2bba83f21da5cde256bd67d86fa5b32451/aeppl/logprob.py#L92
Please provide the full traceback of any errors.
Please provide any additional information below.

Sampling with PyMC (master) fails with assertion errors and does not treat the `AssertionError` as a divergent sample.
Expected Behaviour
Possible Solution
This snippet solved my issue
Versions and main components
python -c "import aesara; print(aesara.config)"
)