probcomp / Gen.jl

A general-purpose probabilistic programming system with programmable inference
https://gen.dev
Apache License 2.0

choicemap constraints not being enforced in the generate function #138

Closed: rameshputalapattu closed this issue 4 years ago

rameshputalapattu commented 4 years ago

I'm trying to solve the polyhedral dice problem using Gen. Suppose we have six dice with 4, 6, 8, 10, 12, and 20 sides. One of the dice is chosen at random, and a single throw shows 7. What is the posterior probability that each die was the one chosen? I have written the following model in Gen.

@gen function single_roll()
    dice_all = [4,6,8,10,12,20]
    dice_chosen = @trace(uniform_discrete(1,length(dice_all)),:choose)
    dice_sides = dice_all[dice_chosen]
    roll = @trace(uniform_discrete(1,dice_sides),:roll)
    return roll
end 

The conditioning on the observed value is performed as below.

constraints = choicemap()
constraints[:roll] = 7

Given that the observed value is 7, the 4-sided and 6-sided dice are clearly ruled out. However, the trace returned by generate sometimes has :choose equal to 1 (the 4-sided die):

trace,_ = generate(single_roll,(),constraints)
println(trace[:choose])  # sometimes prints 1 or 2, which the observation rules out

Also, when I estimate the posterior probability that the 4-sided die was chosen, I get 0.1675, which is approximately 1/6. This suggests that the constraints are not being applied in the generate call. The code is below.

constraints = choicemap()
constraints[:roll] = 7
traces,_,_ = importance_sampling(single_roll,(),constraints,100000)
sum([trace[:choose]==1 for trace in traces])/100000 # returns 0.1675
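
For reference, a model this small can be solved exactly with Bayes' rule, which gives a ground truth to compare sampling estimates against. The sketch below is plain Julia (no Gen), and its variable names are illustrative: it multiplies the uniform prior over the six dice by the likelihood of a 7 under each die, then normalizes.

```julia
# Exact posterior for the dice problem via Bayes' rule (plain Julia, no Gen).
dice_all = [4, 6, 8, 10, 12, 20]   # number of sides on each die
observed = 7                       # the observed roll

# Prior: each die is chosen with probability 1/6.
prior = fill(1 / length(dice_all), length(dice_all))

# Likelihood: a fair d-sided die shows `observed` with probability 1/d if d >= observed, else 0.
likelihood = [observed <= d ? 1 / d : 0.0 for d in dice_all]

joint = prior .* likelihood        # P(die, roll = 7)
marginal = sum(joint)              # P(roll = 7)
posterior = joint ./ marginal      # P(die | roll = 7)

for (d, p) in zip(dice_all, posterior)
    println("$d-sided: ", round(p; digits = 4))
end
println("P(roll = 7) = ", round(marginal; digits = 4))
```

The 4- and 6-sided dice get posterior probability exactly 0, the 10-sided die gets about 0.279, and P(roll = 7) is about 0.0597. The unweighted fraction of 0.1675 is instead close to the prior probability 1/6.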

I'm not sure whether I missed something in my code, or whether issues are the right forum for questions like this. Thanks in advance for your time and help.

alex-lew commented 4 years ago

Thanks for trying out Gen!

When using importance sampling, you must compute a weighted average of your test function, using the normalized weights it returns:

traces, log_weights, lml = importance_sampling(single_roll,(),constraints,100000)
weights = exp.(log_weights)
sum([weights[i] * (traces[i][:choose]==1) for i in 1:100000]) # returns 0
sum([weights[i] * (traces[i][:choose]==4) for i in 1:100000]) # returns 0.277
exp(lml) # returns 0.059, the marginal probability of a 7
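
To see why the weights matter, here is a plain-Julia sketch of self-normalized importance sampling for this model. It is not Gen's implementation; it assumes, as Gen does by default when no custom proposal is given, that :choose is drawn from its prior, so each sample's log weight is just the log-likelihood of the observed 7. The seed and variable names are illustrative.

```julia
using Random

# Self-normalized importance sampling for the dice model, in plain Julia.
# Assumption: the proposal is the prior, so :choose is drawn uniformly and each
# sample's log weight is the log-likelihood of the observed roll.
dice_all = [4, 6, 8, 10, 12, 20]
observed = 7
n = 100_000

rng = MersenneTwister(1)
chosen = rand(rng, 1:length(dice_all), n)     # index of the chosen die, from the prior
log_w = [observed <= dice_all[c] ? log(1 / dice_all[c]) : -Inf for c in chosen]

# Normalize in log space (log-sum-exp) so the weights sum to 1.
m = maximum(log_w)
lse = m + log(sum(exp.(log_w .- m)))
w = exp.(log_w .- lse)
lml = lse - log(n)                             # log marginal likelihood estimate

unweighted_1 = count(==(1), chosen) / n        # close to the prior 1/6
weighted_1 = sum(w[chosen .== 1])              # exactly 0: every such weight is 0
weighted_4 = sum(w[chosen .== 4])              # estimate for the 10-sided die

println((unweighted_1, weighted_1, weighted_4, exp(lml)))
```

The unweighted fraction of 4-sided draws stays near 1/6, just like the 0.1675 in the question, while the weighted estimate is exactly 0. The weighted estimate for the 10-sided die and exp(lml) land near the exact values 0.279 and 0.0597.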

Alternatively, you can use importance_resampling to get approximate posterior samples, and take the (regular) average of those.
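
The resampling idea can be sketched in plain Julia as well, again assuming the prior as proposal and written for illustration only. Gen's importance_resampling returns a single trace per call, so the loop below mimics calling it repeatedly; the helper resample_die is hypothetical.

```julia
using Random

# Sampling importance resampling (SIR) for the dice model, in plain Julia.
# Assumption: the prior is the proposal. `resample_die` is a hypothetical helper:
# it draws `num_particles` dice from the prior, weights each by the likelihood of
# the observed roll, and returns one die index chosen in proportion to its weight.
function resample_die(rng, dice_all, observed, num_particles)
    chosen = rand(rng, 1:length(dice_all), num_particles)
    w = [observed <= dice_all[c] ? 1 / dice_all[c] : 0.0 for c in chosen]
    cw = cumsum(w)
    cw[end] > 0 || error("no particle is consistent with the observation")
    u = rand(rng) * cw[end]              # uniform in [0, total weight)
    i = searchsortedlast(cw, u) + 1      # first index with cw[i] > u, so w[i] > 0
    return chosen[i]
end

dice_all = [4, 6, 8, 10, 12, 20]
rng = MersenneTwister(2)
samples = [resample_die(rng, dice_all, 7, 100) for _ in 1:20_000]

# These are approximate posterior samples, so a plain (unweighted) average is appropriate.
frac_1 = count(==(1), samples) / length(samples)   # always 0: zero-weight particles are never picked
frac_4 = count(==(4), samples) / length(samples)   # approximates P(10-sided die | roll = 7)
println((frac_1, frac_4))
```

Because a particle with zero weight can never be selected, the 4-sided die never appears among the samples, and their plain average approximates the posterior.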

Gen.generate does not always produce traces consistent with the observation because your model violates one of the assumptions that Gen makes: a choice at a given address always has the same support. Here, the support of :roll depends on which die was chosen. It's OK to violate this assumption sometimes, as the importance sampling experiment above demonstrates, but doing so makes Gen's inference less efficient.

Note that when the above assumption is violated, it is possible to write probabilistic programs that would require arbitrarily difficult constraint solving just to produce a trace consistent with the observations. See also #134.

Hope that helps!

rameshputalapattu commented 4 years ago

Hi @alex-lew, thank you for your response. It clarifies everything. Also, thank you and congratulations on your amazing talk at Strange Loop 2019; it inspired me to try out Gen. I'm waiting for the release of PClean. It would be of great help for some data quality problems I'm facing at my workplace. 😀

My previous (limited) experience with probabilistic programming systems involved PyMC3, Stan, and WebPPL, where inference is a black box ("inference buttons"). I'm hoping that Gen's novel programmable inference will help me learn and explore different inference techniques. Thank you 🙏 once again. I'm happy to close this issue if you agree.

marcoct commented 4 years ago

@rameshputalapattu Thanks for experimenting with Gen!

I would like to add that generate returns a weight along with a trace. Even though generate does sometimes return traces containing the 4- or 6-sided die, those traces are given weight negative infinity:

using Gen

@gen function single_roll()
    dice_all = [4,6,8,10,12,20]
    dice_chosen = @trace(uniform_discrete(1,length(dice_all)),:choose)
    dice_sides = dice_all[dice_chosen]
    roll = @trace(uniform_discrete(1,dice_sides),:roll)
    return roll
end 

constraints = choicemap()
constraints[:roll] = 7

for i=1:1000
    trace, w = generate(single_roll,(),constraints)
    println("choose: $(trace[:choose]), w: $w")  # sometimes prints choose 1 or 2, which the observation rules out
end

produces:

choose: 2, w: -Inf
choose: 3, w: -2.0794415416798357
choose: 5, w: -2.4849066497880004
choose: 6, w: -2.995732273553991
choose: 2, w: -Inf
choose: 6, w: -2.995732273553991
choose: 4, w: -2.3025850929940455
choose: 5, w: -2.4849066497880004
choose: 3, w: -2.0794415416798357
choose: 1, w: -Inf
choose: 6, w: -2.995732273553991
choose: 4, w: -2.3025850929940455
choose: 2, w: -Inf
choose: 2, w: -Inf
...
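
The weights in this output can be reproduced by hand. The sketch below is plain Julia, not Gen's code, and assumes that because :choose is drawn from its prior, the weight reduces to the log-probability of the constrained choice :roll = 7 given the chosen die. The helper generate_weight is hypothetical.

```julia
# Reproducing the weights printed above by hand (plain Julia, not Gen's code).
# Assumption: because :choose is drawn from its prior, the weight reduces to the
# log-probability of the constrained choice :roll = 7 given the chosen die.
# `generate_weight` is a hypothetical helper.
function generate_weight(dice_all, observed, choose)
    sides = dice_all[choose]
    # uniform_discrete(1, sides) gives each of 1..sides probability 1/sides
    return 1 <= observed <= sides ? log(1 / sides) : -Inf
end

dice_all = [4, 6, 8, 10, 12, 20]
for choose in eachindex(dice_all)
    println("choose: $choose, w: ", generate_weight(dice_all, 7, choose))
end
```

Dice 1 and 2 cannot show a 7 and get -Inf; the others get log(1/sides), matching the weights printed above up to floating-point rounding.
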

rameshputalapattu commented 4 years ago

Hi @marcoct, thank you for the clarification. Also, thank you 🙏 for your wonderful talk at Strange Loop 2019. Much to learn.