flatironinstitute / bayes-kit

Bayesian inference and posterior analysis for Python
MIT License

Adds Metropolis and MH samplers with tests. #17

Closed jsoules closed 1 year ago

jsoules commented 1 year ago

This PR implements Metropolis and Metropolis-Hastings samplers in the Bayes-Kit framework, and incorporates a couple of drive-by fixes that PyLance thought were necessary. (Rude of me, sorry.)

Major issues for discussion:

The major contributions requiring attention are bayes_kit/metropolis.py (which implements the two samplers, or really 1.5), test/test_metropolis.py with the corresponding tests, and test/models/skew_normal.py, which implements a standard skew-normal model for testing (essentially a pass-through to the scipy implementation; see the sketch below).
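For a sense of what that pass-through looks like, here is a minimal sketch; the class name, shape parameter, and exact method set are illustrative assumptions, not a copy of the file (bayes-kit models expose dims() and log_density()):

import numpy as np
from scipy import stats


class SkewNormal:
    """Pass-through to scipy's skew normal; shape a=4 gives mean ~0.77, var ~0.4."""

    def __init__(self, a: float = 4.0):
        self._dist = stats.skewnorm(a)

    def dims(self) -> int:
        return 1

    def log_density(self, theta: np.ndarray) -> float:
        return float(self._dist.logpdf(theta[0]))

    def posterior_mean(self) -> float:
        return float(self._dist.mean())

    def posterior_variance(self) -> float:
        return float(self._dist.var())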

Minor changes

WardBrian commented 1 year ago

Am I correct that this makes bayes_kit/rwm.py and tests/test_rwm.py no longer necessary?

On the main "substance" point, could you sketch out or put in a gist somewhere how you think this code would look? Ideally, I can imagine something like

class MetropolisHastings:
    ...

# the complete code of Metropolis is then just:
class Metropolis(MetropolisHastings):
    def __init__(self, ...):
        super().__init__(...)
        self.test_proposal = metropolis_proposal_test

(or have them both have a common private base class, if the direct inheritance is objectionable)

This would be more than acceptable to me, but as I've written it here it would require some rethinking of the argument structure for the tests, etc.

Re: testing, we may want to come up with "time budgets" for each test. Currently test_rwm.py performs 10k iterations for std_normal, which seems to be fairly solid in terms of stochastic failures.

Finally, do you mind formatting the changes with black?

jsoules commented 1 year ago

Am I correct that this makes bayes_kit/rwm.py and tests/test_rwm.py no longer necessary?

It seems to me that all the functionality of rwm.py is incorporated, but I don't rule out the possibility that there's some other reason you want to keep them?

Finally, do you mind formatting the changes with black?

Done.

On the main "substance" point, could you sketch out or put in a gist somewhere how you think this code would look?

I was referring to the MetropolisHastingsCombo class in the present bayes_kit/metropolis.py (which would obviously be renamed, I just needed to distinguish the redundant implementation for now). I don't see a strong need to set up a formal inheritance relationship, but as you say it would be pretty trivial along the lines of what you suggested.

I'll point up the details in inline commentary.

WardBrian commented 1 year ago

Ah, it seems I did not scroll down far enough.

Personally, I would prefer that the user-facing API still present itself as two classes, one called Metropolis and one called MetropolisHastings. If those classes have some code sharing that is great, but I don't think calling MetropolisHastings with correction=None is ergonomic for someone wanting a Metropolis sampler, and it opens up a whole can of worms in terms of input checking ("What do you mean I forgot an argument and now my sampler is actually a different algorithm?")
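Something along these lines is what I have in mind if direct inheritance is objectionable, with a private shared base; the names here are purely illustrative, not code from the PR:

class _MetropolisBase:
    """Shared machinery for both samplers; illustrative sketch only."""

    def __init__(self, model, proposal_rng):
        self._model = model
        self._proposal_rng = proposal_rng


class Metropolis(_MetropolisBase):
    """Symmetric proposals: the Hastings correction cancels."""


class MetropolisHastings(_MetropolisBase):
    """General proposals: needs the proposal transition log-density."""

    def __init__(self, model, proposal_rng, proposal_log_density):
        super().__init__(model, proposal_rng)
        self._proposal_log_density = proposal_log_density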

Edit: Also, feel free to delete the rwm files as part of this

WardBrian commented 1 year ago

Re testing: The fact that the latest commit introduced a bug in MH which is causing the skew normal test to estimate an answer of ~24 when the true answer is 0.77 suggests to me that we can increase the absolute tolerance above 0.1 and still catch bugs.

(The bug seems to be a swapping of the forward/reverse directions in the acceptance criteria compared to the previous code)
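For reference, the intended acceptance test is just the standard MH ratio; a minimal sketch, where the function and argument names below are illustrative, not the PR's:

import numpy as np


def mh_accept(rng, lp_proposed, lp_current, lq_reverse, lq_forward):
    # lq_reverse = log q(theta_current | theta_proposed)
    # lq_forward = log q(theta_proposed | theta_current)
    # swapping these two terms is the forward/reverse bug described above
    log_alpha = lp_proposed - lp_current + lq_reverse - lq_forward
    return np.log(rng.uniform()) < log_alpha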

codecov-commenter commented 1 year ago

Codecov Report

Merging #17 (4933846) into main (fbb71e7) will increase coverage by 3.48%. The diff coverage is 100.00%.

@@            Coverage Diff             @@
##             main      #17      +/-   ##
==========================================
+ Coverage   92.20%   95.68%   +3.48%     
==========================================
  Files           9        9              
  Lines         218      255      +37     
==========================================
+ Hits          201      244      +43     
+ Misses         17       11       -6     
Impacted Files             Coverage            Δ
bayes_kit/__init__.py      100.00% <ø>         (ø)
bayes_kit/hmc.py           94.73% <100.00%>    (+0.45%)
bayes_kit/mala.py          93.54% <100.00%>    (+1.24%)
bayes_kit/metropolis.py    100.00% <100.00%>   (ø)
bayes_kit/model_types.py   100.00% <100.00%>   (+28.57%)
bayes_kit/rhat.py          100.00% <100.00%>   (ø)
bayes_kit/smc.py           95.55% <100.00%>    (ø)


jsoules commented 1 year ago

Re testing: The fact that the latest commit introduced a bug in MH which is causing the skew normal test to estimate an answer of ~24 when the true answer is 0.77 suggests to me that we can increase the absolute tolerance above 0.1 and still catch bugs.

Yeah, still zeroing in on the right tolerances here. I do note that scipy.stats.skewnorm.var(4) is about 0.4, so maybe looking for a tolerance of 0.1 is a bit tight :)
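(Back-of-the-envelope, assuming roughly independent draws: the Monte Carlo standard error of the posterior-mean estimate is sqrt(0.4 / N_eff), i.e. about 0.063 at an effective sample size of 100 and about 0.02 at 1000, so an absolute tolerance of 0.1 only leaves a couple of standard errors of slack unless the effective sample size is well into the hundreds.)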

I'll fuss with that some more tomorrow, but assuming I can get the right tolerances so the tests don't have stochastic failures I think this is ready to be considered for real.

WardBrian commented 1 year ago

I'm a bit wary of fixing random seeds in general, but I think it is acceptable for something like the skew normal test. I think we should add a comment to the effect of "Note: We don't generally want to fix test seeds but are making an exception for this algorithm because..."

We previously didn't have many issues with failures from the Metropolis tests on the StdNormal model, is the fixed seed just to allow us to lower the number of iterations?

jsoules commented 1 year ago

I'm a bit wary of fixing random seeds in general, but I think it is acceptable for something like the skew normal test. I think we should add a comment to the effect of "Note: We don't generally want to fix test seeds but are making an exception for this algorithm because..."

Sure, happy to add a note. I'm viewing this as a trade-off: fixed seeds mean we can run the same integration test with many fewer iterations and tighter tolerances, while still catching major bugs. (Modulo the ones that would always come out in the wash.)

We previously didn't have many issues with failures from the Metropolis tests on the StdNormal model, is the fixed seed just to allow us to lower the number of iterations?

Exactly. I was seeing failures on the order of 1 in 25ish (?) runs, which doesn't seem like it indicates an actual error, but does seem likely to be an annoyance. I could up the tolerances or number of iterations, but that's taking the batteries out of the smoke alarm just as much as fixing the seed is--I figured it would be both more efficient and more diagnostic to fix the seed and enable using tighter tolerances and fewer iterations.
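To make the trade-off concrete, here is a self-contained sketch of what a fixed-seed integration test could look like; the throwaway random-walk sampler, seed, iteration count, and tolerance are all illustrative assumptions, not the PR's actual test:

import numpy as np
from scipy import stats


def test_skew_normal_fixed_seed():
    # Note: we don't generally want to fix test seeds, but doing so here
    # lets the test use fewer iterations and a tighter tolerance without
    # stochastic failures.
    rng = np.random.default_rng(12345)
    dist = stats.skewnorm(4)
    theta, draws = 0.0, []
    for _ in range(5_000):  # far fewer iterations than an unseeded test
        proposal = theta + rng.normal()
        if np.log(rng.uniform()) < dist.logpdf(proposal) - dist.logpdf(theta):
            theta = proposal
        draws.append(theta)
    # with the seed fixed, the tolerance can be pinned down after inspecting
    # one seeded run; 0.2 is a deliberately loose placeholder here
    np.testing.assert_allclose(np.mean(draws), dist.mean(), atol=0.2)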

WardBrian commented 1 year ago

Sounds good to me. I personally think this is in a good enough place to merge; I just wanted to ask how many of the various TODOs you'd like to do now versus come back to later? Particularly things like error checking in the constructor.

bob-carpenter commented 1 year ago

I could up the tolerances or number of iterations, but that's taking the batteries out of the smoke alarm just as much as fixing the seed is--I figured it would be both more efficient and more diagnostic to fix the seed and enable using tighter tolerances and fewer iterations.

This is the fundamental tension in stochastic testing! And in hypothesis testing in general where you need to make a binary decision. It's always a sensitivity/specificity tradeoff.

bob-carpenter commented 1 year ago

I think we should put error checking in the constructors into the code. I'd like to start encapsulating these checks so they can be reused, since a lot of our samplers are going to do the same thing, like checking that a value is a non-negative integer and throwing an exception if it's not. But I'm not sure how to deal with formatting, since we don't want to eagerly format an error message. Do either of you know when formatted strings like f"{a=}" get evaluated? Specifically, could I do something like this:

def validate_non_neg(x, msg):
    if not x >= 0:
        raise ValueError(msg)

def foo(x):
    validate_non_neg(x, f"foo() requires x >= 0, found {x=}")

What I'm hoping is that the format string only gets evaluated if it's actually used in validate_non_neg. All I can find online is "at runtime," which is obvious. But will it evaluate on call or on use?

If this isn't lazy, we'll have to go with the older formatting style.

WardBrian commented 1 year ago

All arguments are fully evaluated before the function call. This would also be true if you used % formatting or what have you. What you'd probably want to do in this situation is pass the name of the caller and the name of the variable and then let the check function produce the message as-needed.
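A quick way to see this (made-up names, just a demonstration):

def build_msg(x):
    print("formatting message")  # side effect shows when formatting happens
    return f"found {x=}"


def validate_non_neg(x, msg):
    if not x >= 0:
        raise ValueError(msg)


validate_non_neg(1, build_msg(1))  # prints "formatting message" even though
                                   # the message is never used by the check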

Jeff and I just talked about this offline, but it's not clear for these simple samplers how much validation is needed in the constructor. In particular, we call both dims and log_density right away, so if your model doesn't have those, it will raise a TypeError (without us needing to explicitly check/raise ourselves).

bob-carpenter commented 1 year ago

it's not clear for these simple samplers how much validation is needed in the constructor. In particular, we call both dims and log_density right away, so if your model doesn't have those, it will raise a TypeError (without us needing to explicitly check/raise ourselves).

The problem I have with delegated calls failing is that the error messages are confusing to the API user.

It looks like NumPy checks its arguments and I'd like to follow that practice. One case where I may be OK with delegated failures is when accessing a member variable that doesn't exist.

All arguments are fully evaluated before the function call. This would also be true if you used % formatting or what have you.

OK. I suspected as much, but couldn't find documentation or a way to check it.

I can just do this then:

def validate_non_neg(x, fun_name, var_name):
    if not x >= 0:
        raise ValueError(
            "in call to function {0}(...), variable {1} must be "
            "non-negative, but found {2}.".format(fun_name, var_name, x)
        )

That will presumably not evaluate the formatted string unless x < 0 or x is NaN.

And then run

>>> validate_non_neg(1, "foo", "x")

>>> validate_non_neg(-1, "foo", "x")
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<stdin>", line 3, in validate_non_neg
ValueError: in call to function foo(...), variable x must be non-negative, but found -1.

I love writing this kind of validation code and don't mind doing it. I'll start by moving validations into a file called validate_args or something like that.

WardBrian commented 1 year ago

That will presumably not evaluate the formatted string unless x < 0 or x is NaN.

That is true, but neither would an f-string there, so you should probably prefer to use that.
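i.e., the same check with an f-string (a minor rewrite of the snippet above, not code from the repository):

def validate_non_neg(x, fun_name, var_name):
    if not x >= 0:
        raise ValueError(
            f"in call to function {fun_name}(...), variable {var_name} "
            f"must be non-negative, but found {x}."
        )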

On more general validation, I think checking our own preconditions, like that a number of steps is positive, is good. I'm less keen to do things like check that the return value of proposal_rng() is something castable to a float, since numpy is going to do that for us and there's no need to check it twice.

That said, anything we can check in the constructor is pretty cheap, and I won't complain too much. Checking things in the hot loop (especially where the code would naturally fail anyway), less so.

bob-carpenter commented 1 year ago

I think checking our own preconditions, like that a number of steps is positive, is good. I'm less keen to do things like check that the return value of proposal_rng() ...

Checking things in the hot loop (especially if the code will naturally fail) less so

Absolutely. We only check the top-level API calls by clients that we do not control. So we don't do error checking on internal functions, inside loops, etc. If we happen to have an external API function foo() that we also call inside a loop we control, then what I've always done is define a non-checked variant _foo() that gets called by the client-facing foo().

def foo(...):
    validate_args(...)
    _foo(...)

def _foo(...):
    # no error check ...
    ... do the work ...

I'm less certain this is worth it in an interpreted language like Python, where there's no aggressive compiler inlining.