probcomp / Gen.jl

A general-purpose probabilistic programming system with programmable inference
https://gen.dev
Apache License 2.0

Error in `discard` returned by `update` for DynamicDSLTrace with hierarchical addresses #512

Closed fsaad closed 3 months ago

fsaad commented 11 months ago

Consider the following simple programs:

using Gen

@gen function model_ok()
    k ~ poisson(5)
    for i=1:k
        {i} ~ uniform(0,1)
    end
end

@gen function model_bad()
    k ~ poisson(5)
    for i=1:k
        {:value => i} ~ uniform(0,1)
    end
end

Do we expect Gen.update to operate differently in these two cases? For model_ok, if we call update in a way that reduces k, the removed addresses {i} are correctly placed in the discard. For model_bad, the removed addresses {:value => i} are missing from the discard. See the following example:

Discard for model_ok

julia> tr, = Gen.generate(model_ok, (), choicemap(:k=>3));
julia> (new_trace, weight, _, discard) = Gen.update(tr, Gen.choicemap(:k=>1));
julia> display(discard)
│
├── :k : 3
│
├── 2 : 0.552591934644268
│
└── 3 : 0.6198635352707209

Discard for model_bad

julia> tr, = Gen.generate(model_bad, (), choicemap(:k=>3));
julia> (new_trace, weight, _, discard) = Gen.update(tr, Gen.choicemap(:k=>1));
julia> display(discard)
│
└── :k : 3
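The discrepancy can also be checked programmatically rather than by reading the display output. The following is a minimal sketch, assuming the same model_bad as above; the names removed and missing_from_discard are illustrative only. It lists which of the removed addresses are absent from the discard, without assuming which Gen version is installed.

```julia
using Gen

@gen function model_bad()
    k ~ poisson(5)
    for i = 1:k
        {:value => i} ~ uniform(0, 1)
    end
end

tr, = Gen.generate(model_bad, (), choicemap(:k => 3))
(new_trace, weight, _, discard) = Gen.update(tr, choicemap(:k => 1))

# Addresses that leave the trace when k drops from 3 to 1.
removed = [:value => i for i in 2:3]

# Any address listed here was removed from the trace but not reported in the discard.
missing_from_discard = filter(addr -> !has_value(discard, addr), removed)
println(missing_from_discard)
```

An empty result would mean every removed choice shows up in the discard, which is the behaviour model_ok already exhibits.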

Possible Reason

The following lines seem to be the culprit: the recursive call to add_unvisited_to_discard! uses an anonymous choicemap when subdiscard is empty, so the call to set_submap! inside it fills out a subdiscard that we never access again.

https://github.com/probcomp/Gen.jl/blob/7955b07e6df273633da6e72d434884b73fbdffe6/src/dynamic/update.jl#L177-L181

Changing these lines to

                subdiscard = get_submap(discard, key)
                subdiscard_recursive = isempty(subdiscard) ? choicemap() : subdiscard
                add_unvisited_to_discard!(
                    subdiscard_recursive,
                    subvisited, submap)
                set_submap!(discard, key, subdiscard_recursive)

seems to fix the issue.

fsaad commented 11 months ago

Related #506

ztangent commented 11 months ago

Thanks for catching this!

So I'm trying to understand why the line of code you identified was introduced in the first place, and it looks like @alex-lew added it in this old commit: c8ce0d735a775749052c12db657e92cba761219c

Apparently it was introduced to address cases where things were being discarded, but shouldn't be. So maybe we can check if the fix you suggested @fsaad also passes the original test case that @alex-lew added?

This is the test case:

https://github.com/probcomp/Gen.jl/blob/05709187d79316464623625e370b7cf59706d2ea/test/dsl/dynamic_dsl.jl#L153-L183

ztangent commented 11 months ago

This issue also seems highly related, but @alex-lew identified a different line of code in add_unvisited_to_discard! as the likely culprit, so maybe it's a different issue?

https://github.com/probcomp/Gen.jl/issues/237

mlb2251 commented 3 months ago

@yifr and I just ran into this same issue. We tried out @fsaad's solution and it works.

We initially also thought it was the code @alex-lew points to, but that branch doesn't run in @fsaad's (or our own) examples, because key in visited is only true when the hierarchical address comes from a GenerativeFunction call, not when you write one directly like {:value => i} ~ ...

@fsaad's explanation of why it is wrong is correct: the original code has

add_unvisited_to_discard!( 
     isempty(subdiscard) ? choicemap() : subdiscard, 
     subvisited, submap)

But since add_unvisited_to_discard! works by mutating its first argument, it doesn't make sense to pass in a fresh choicemap without first keeping a reference to it, as they do in their fix.
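The same pitfall can be reproduced without Gen. Below is a rough sketch using plain Dicts; record_unvisited!, discard_buggy and discard_fixed are hypothetical stand-ins, not Gen.jl names, and using nothing for a missing sub-entry is an assumption made only for this illustration.

```julia
# Hypothetical stand-in for a mutating helper: it records every unvisited
# entry of `choices` into `out` and returns nothing.
function record_unvisited!(out, visited, choices)
    for (addr, value) in choices
        addr in visited || (out[addr] = value)
    end
    return nothing
end

choices = Dict(1 => 0.3, 2 => 0.5, 3 => 0.7)
visited = Set([1])

# Buggy pattern: when the parent has no entry yet, a fresh Dict is created
# inline, filled in, and then lost because nothing refers to it.
discard_buggy = Dict{Symbol,Dict{Int,Float64}}()
sub = get(discard_buggy, :value, nothing)
record_unvisited!(sub === nothing ? Dict{Int,Float64}() : sub, visited, choices)

# Fixed pattern: bind the target first, fill it, then attach it to the parent.
discard_fixed = Dict{Symbol,Dict{Int,Float64}}()
sub = get(discard_fixed, :value, nothing)
target = sub === nothing ? Dict{Int,Float64}() : sub
record_unvisited!(target, visited, choices)
discard_fixed[:value] = target

println(isempty(discard_buggy))   # the buggy parent never receives the entries
println(discard_fixed[:value])    # the fixed parent holds addresses 2 and 3
```

The only difference between the two patterns is whether the freshly created container is bound to a name before being mutated, which mirrors the subdiscard_recursive variable in the proposed fix.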

I'll make a quick PR with the fix!