flatironinstitute / bayes-kit

Bayesian inference and posterior analysis for Python
MIT License

Handle user-provided log densities that throw exceptions #43

Open WardBrian opened 1 year ago

WardBrian commented 1 year ago

This came up in a discussion with @gil2rok during #39, but it applies to all of our algorithms. The user-provided log_density_gradient function may throw an exception (we may even expect it to, e.g. if the parameters are out of support).

Currently, if this happens, the exception propagates all the way up out of the algorithm and into the caller's code. This is probably not what we want. Instead, we could follow Stan and treat any exception in the log density calculation as a rejection (assuming the algorithm contains an accept-reject step).
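A minimal sketch of that Stan-style behavior for one random-walk Metropolis step, assuming the log density takes a NumPy array and returns a float. `rwm_step` and `half_normal_log_density` are hypothetical names for illustration, not bayes-kit's API.

```python
import numpy as np


def rwm_step(log_density, theta, lp_theta, step_size, rng):
    """One random-walk Metropolis step (hypothetical helper, not bayes-kit's API).

    If evaluating the proposal's log density raises, the proposal is
    rejected and the chain stays at theta.
    """
    proposal = theta + step_size * rng.normal(size=theta.shape)
    try:
        lp_proposal = log_density(proposal)
    except Exception:
        # Stan-style: a failed evaluation counts as a rejection
        return theta, lp_theta, False
    # Standard Metropolis accept-reject on the log scale
    if np.log(rng.uniform()) < lp_proposal - lp_theta:
        return proposal, lp_proposal, True
    return theta, lp_theta, False


def half_normal_log_density(x):
    # Hypothetical user density that raises outside its support (x >= 0)
    if np.any(x < 0):
        raise ValueError("parameter out of support")
    return -0.5 * float(np.sum(x ** 2))


rng = np.random.default_rng(0)
theta = np.array([0.5])
lp = half_normal_log_density(theta)
draws = []
for _ in range(1000):
    theta, lp, _accepted = rwm_step(half_normal_log_density, theta, lp, 0.5, rng)
    draws.append(theta[0])
```

Every proposal that lands below zero raises inside the density and is simply rejected, so the chain never leaves the support. Catching `Exception` this broadly also turns genuine bugs into silent rejections.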

bob-carpenter commented 1 year ago

Stan separates two kinds of errors. If there are indexing errors, we tend to throw things that stop the algorithm because those aren't recoverable. If they are numerical errors like underflow, we treat them as rejections in the Stan algorithms. One thing to do would be to convert an exception into a negative infinite log density, which in any Metropolis setting will reject.
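A minimal sketch of the "exception becomes negative infinity" idea as a decorator on the user's log density. `reject_on_error` and `exponential_lpdf` are hypothetical names; the density mirrors the exponential(1) case from reject2.stan below, whose log density is -y for y ≥ 0.

```python
import functools
import math


def reject_on_error(log_density):
    """Hypothetical decorator: turn any exception into a log density of -inf."""

    @functools.wraps(log_density)
    def wrapped(theta):
        try:
            return log_density(theta)
        except Exception:
            # -inf makes the Metropolis acceptance probability exactly 0
            return -math.inf

    return wrapped


@reject_on_error
def exponential_lpdf(y):
    # Mirrors Stan's exponential_lpdf(y | 1): raises for negative y
    if y < 0:
        raise ValueError(f"Random variable is {y}, but must be nonnegative!")
    return -y


lp_current = exponential_lpdf(1.0)         # -1.0
lp_proposal = exponential_lpdf(-0.155166)  # -inf rather than an exception
accept_prob = min(1.0, math.exp(lp_proposal - lp_current))  # 0.0
```

Since exp(-inf) is 0, any Metropolis-style acceptance probability computed from the wrapped density is 0 for a failed evaluation. Wrapping every exception this way does not distinguish unrecoverable errors such as bad indexing from numerical ones.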

reject.stan

parameters {
  vector[2] y;
}
model {
  y[1:3] ~ normal(0, 1); // reject
} 

Running this produces an error and sampling doesn't happen:

RuntimeError: Error during sampling:
Exception: vector[min_max] max indexing: accessing element out of range. index 3 out of range; expecting index to be between 1 and 2 (in 'reject.stan', line 5, column 2 to column 24)

reject2.stan

parameters {
  vector<upper=0>[2] y;
}
model {
  y ~ exponential(1);  // out of range
} 

Running this produces a bunch of rejections of draws:

RuntimeError: Error during sampling:
    Exception: exponential_lpdf: Random variable[1] is -0.155166, but must be nonnegative! (in 'reject2.stan', line 5, column 2 to column 21)
    Exception: exponential_lpdf: Random variable[1] is -4.05851, but must be nonnegative! (in 'reject2.stan', line 5, column 2 to column 21)
    Exception: exponential_lpdf: Random variable[1] is -0.75561, but must be nonnegative! (in 'reject2.stan', line 5, column 2 to column 21)
    Exception: exponential_lpdf: Random variable[1] is -3.93886, but must be nonnegative! (in 'reject2.stan', line 5, column 2 to column 21)
    Exception: exponential_lpdf: Random variable[1] is -0.16518, but must be nonnegative! (in 'reject2.stan', line 5, column 2 to column 21)
    Exception: exponential_lpdf: Random variable[1] is -0.633151, but must be nonnegative! (in 'reject2.stan', line 5, column 2 to column 21)
    Exception: exponential_lpdf: Random variable[1] is -1.43768, but must be nonnegative! (in 'reject2.stan', line 5, column 2 to column 21)
    Exception: exponential_lpdf: Random variable[1] is -0.61375, but must be nonnegative! (in 'reject2.stan', line 5, column 2 to column 21)
...

If it could've gotten past initialization, it would've given very different answers.

WardBrian commented 1 year ago

We could try to separate those errors at the Python level (e.g., we could have it so IndexError and others are not caught, but ArithmeticError and RuntimeError are). Python has a well-established hierarchy of built-in exception types.
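A minimal sketch of that split using the built-in hierarchy, assuming a log density that takes a sequence of floats. `RECOVERABLE`, `log_density_or_reject`, and the two toy densities are hypothetical names for illustration.

```python
import math

# Hypothetical policy: numerical failures become rejections (-inf);
# anything else, e.g. IndexError from a bug in the model, propagates.
RECOVERABLE = (ArithmeticError, RuntimeError)


def log_density_or_reject(log_density, theta, recoverable=RECOVERABLE):
    try:
        return log_density(theta)
    except recoverable:
        return -math.inf


def overflowing(theta):
    # math.exp raises OverflowError (a subclass of ArithmeticError) for large input
    return math.exp(theta[0])


def out_of_range(theta):
    # IndexError for a length-2 parameter vector, like y[1:3] in reject.stan
    return theta[2]


lp = log_density_or_reject(overflowing, [1000.0])  # -inf
try:
    log_density_or_reject(out_of_range, [0.0, 1.0])
except IndexError:
    print("IndexError propagated to the caller")
```

One wrinkle: `ValueError` (raised, for example, by `math.log(-1.0)`) is not an `ArithmeticError`, so under this policy it propagates unless it is added to the tuple.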

Unfortunately, I believe that, as implemented today, BridgeStan raises all errors as RuntimeErrors, but other user-defined model classes could follow this pattern.
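If every BridgeStan error does arrive as a `RuntimeError`, catching `RuntimeError` would also catch its indexing errors. A user-defined model can avoid that by raising its own exception type for out-of-support parameters. A minimal sketch; `OutOfSupportError` and `ExponentialModel` (with its `dims`/`log_density` methods) are hypothetical and only assumed to resemble the interface a bayes-kit model provides.

```python
import numpy as np


class OutOfSupportError(ArithmeticError):
    """Hypothetical error for parameters outside the model's support.

    Subclassing ArithmeticError puts it in the "recoverable" branch of the
    built-in hierarchy, so a sampler catching ArithmeticError would reject.
    """


class ExponentialModel:
    """Hypothetical user-defined model for y ~ exponential(1), y in R^2."""

    def dims(self) -> int:
        return 2

    def log_density(self, y) -> float:
        y = np.asarray(y, dtype=float)
        if np.any(y < 0):
            # Support violation: signal a recoverable error
            raise OutOfSupportError(f"y = {y}, but must be nonnegative")
        # Indexing y[2] here instead would raise IndexError, which stays fatal
        return -float(np.sum(y))


model = ExponentialModel()
print(model.log_density([0.5, 1.5]))  # -2.0
try:
    model.log_density([-0.155166, 1.0])
except ArithmeticError as err:
    print("recoverable:", err)
```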