regehr / guided-tree-search

heuristically and dynamically sample (more) uniformly from large decision trees of unknown shape
Mozilla Public License 2.0

How do we handle rejection sampling? #2

Open DRMacIver opened 2 years ago

DRMacIver commented 2 years ago

Rejection sampling is pretty common in generators, e.g. something like:

loop {
    // Draw a candidate and keep it only if it passes the filter.
    let thing = generate_a_thing();
    if is_good(&thing) { return thing; }
}

The problem with this for uniform sampling is that, to a naive version of the algorithm, the rejection branch looks very interesting: there are all these exciting trees under there to explore. In fact we should ideally never run this loop more than once, and anything beyond that is a waste.
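To make the size of the rejection branch concrete, here is a small counting sketch. It assumes a toy generator (not anything from this repo) that picks one of `outcomes` values per attempt, of which `good` pass the filter, and that retries at most `max_attempts` times. The function names are hypothetical.

```rust
/// Number of distinct choice sequences that are accepted on exactly the
/// given attempt (1-based): every earlier attempt picked a bad outcome,
/// and the final attempt picked a good one.
/// Assumes `attempt >= 1` and `good <= outcomes`.
pub fn accepting_paths(outcomes: u64, good: u64, attempt: u32) -> u64 {
    let bad = outcomes - good;
    bad.pow(attempt - 1) * good
}

/// Fraction of all accepting paths (up to `max_attempts`) that pass
/// through the rejection branch at least once.
pub fn retry_share(outcomes: u64, good: u64, max_attempts: u32) -> f64 {
    let first = accepting_paths(outcomes, good, 1);
    let total: u64 = (1..=max_attempts)
        .map(|attempt| accepting_paths(outcomes, good, attempt))
        .sum();
    (total - first) as f64 / total as f64
}
```

With 4 outcomes of which 2 are good, the first attempt holds 2 accepting paths, the second 4, and the third 8, so a sampler that tries to be uniform over paths would spend most of its effort below the rejection branch even though none of those paths yields a value the first attempt could not.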

Do we want a special API for indicating when this is the case?

In Hypothesis we have an API that looks roughly as follows:

loop {
    let thing = generate_a_thing();
    if is_good(&thing) { return thing; }
    // Tell the engine this attempt was rejected before retrying.
    reject();
}

The way this works is that generation continues as normal after you've called reject(), but we mark the choice sequence (i.e. the sequence of flips) leading up to the first reject() as one to be excluded from future generation. This has two advantages: we don't waste a lot of time exploring dead parts of the generation space, and our rejection rate goes down over time, ideally leading to faster generation.
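The bookkeeping described above could look roughly like the following. This is a minimal sketch under my own assumptions, not Hypothesis's implementation or this repo's API: `Chooser`, `start_run`, `flip`, `reject`, and `generate_nonzero` are hypothetical names, choices are single bits, and the random source is a small xorshift generator to keep the example dependency-free.

```rust
use std::collections::HashSet;

/// Hypothetical source of binary choices that remembers rejected prefixes.
pub struct Chooser {
    rejected: HashSet<Vec<bool>>, // prefixes marked by an earlier reject()
    current: Vec<bool>,           // choices made so far in this run
    marked: bool,                 // has reject() been called in this run?
    state: u64,                   // xorshift PRNG state (must be nonzero)
}

impl Chooser {
    pub fn new(seed: u64) -> Self {
        Chooser { rejected: HashSet::new(), current: Vec::new(), marked: false, state: seed | 1 }
    }

    /// Begin a fresh top-level generation.
    pub fn start_run(&mut self) {
        self.current.clear();
        self.marked = false;
    }

    fn random_bit(&mut self) -> bool {
        self.state ^= self.state << 13;
        self.state ^= self.state >> 7;
        self.state ^= self.state << 17;
        (self.state >> 32) & 1 == 1
    }

    /// Make one binary choice, steering away from known-rejected prefixes.
    /// After reject() has been called, generation continues unconstrained.
    pub fn flip(&mut self) -> bool {
        let mut bit = self.random_bit();
        if !self.marked {
            self.current.push(bit);
            if self.rejected.contains(&self.current) {
                // Take the sibling instead. If the sibling were rejected too,
                // reject() would already have marked the parent, so this can
                // only happen when the whole tree is dead.
                bit = !bit;
                if let Some(last) = self.current.last_mut() {
                    *last = bit;
                }
            }
        }
        bit
    }

    /// Mark the choices leading up to the first rejection in this run.
    pub fn reject(&mut self) {
        if self.marked {
            return;
        }
        self.marked = true;
        let mut prefix = self.current.clone();
        loop {
            self.rejected.insert(prefix.clone());
            // If the sibling is dead as well, the parent is dead: move up.
            let last = match prefix.pop() {
                Some(bit) => bit,
                None => break,
            };
            prefix.push(!last);
            let sibling_dead = self.rejected.contains(&prefix);
            prefix.pop();
            if !sibling_dead {
                break;
            }
        }
    }

    pub fn rejected_count(&self) -> usize {
        self.rejected.len()
    }
}

/// Toy generator: draw a 2-bit number and reject zero.
pub fn generate_nonzero(chooser: &mut Chooser) -> u8 {
    chooser.start_run();
    loop {
        let hi = chooser.flip() as u8;
        let lo = chooser.flip() as u8;
        let n = (hi << 1) | lo;
        if n != 0 {
            return n;
        }
        chooser.reject();
    }
}
```

In this toy setup the only prefix that can ever be marked is `[false, false]`. Once it has been marked, the first attempt of every later run is steered to a nonzero value, so the retry loop stops being entered at all. A real implementation would also need to handle non-binary choices and bound the memory used by the rejected set, which this sketch ignores.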

This has the downside that people have to specifically annotate their code wherever they use rejection sampling in order to get the benefit. This API is approximately never used by Hypothesis users (it's not even really accessible to them outside of a few special-case wrappers around it).

In a beautiful ideal world we would be able to automatically guess that rejection sampling is going on and do something appropriate without any user intervention, but I haven't the first idea of how that would work.

regehr commented 2 years ago

I like this and believe that an API for reject() is a totally reasonable thing to ask people to use.