Closed: torfjelde closed this pull request 11 months ago.
| Files with Coverage Reduction | New Missed Lines | % |
|---|---|---|
| src/utils.jl | 33 | 82.28% |
| Totals | |
|---|---|
| Change from base Build 6906848787 | 0.04% |
| Covered Lines | 2541 |
| Relevant Lines | 3165 |
All modified and coverable lines are covered by tests :white_check_mark:
Comparison is base (c9489aa) 80.24% compared to head (1cfe231) 80.28%.
AFAIK it's not breaking; we're still keeping the old `acclogp!!` definition too.

EDIT: bumped patch version.
Some samplers have very specific ways of accumulating log-probabilities, which they implement by overloading parts of the tilde-pipeline. For example, `PG` in Turing overloads both `assume` and `observe` because it needs to accumulate the log-probabilities in the task-local `varinfo` rather than the "global" `varinfo`.

But this means that in certain scenarios, e.g. with `PG`, using `@addlogprob!` in a model will silently result in targeting the incorrect model, since the extra log-probability ends up in the "global" `varinfo` rather than in the task-local one the sampler actually uses! One way to fix this is to also pass the `context` to `@addlogprob!` (and `acclogp!!`), which can then alter how the log-probs are accumulated. A few examples where this is useful:
1. `PG` in Turing.jl can then correctly handle `@addlogprob!` by accumulating the log-probabilities in the task-local `varinfo` rather than the "global" `varinfo`.
2. Rather than overloading `assume` and returning 0 in place of the actual `logp` value (as is done here: https://github.com/TuringLang/Turing.jl/blob/6649f10c48917a27531214f02777408d2ab82928/src/mcmc/particle_mcmc.jl#L376), we can simply return the `logp` value as we do in all other implementations of `assume`, and let the `acclogp!!` call in `tilde_assume` handle the accumulation into the task-local `varinfo`. This would make particle samplers compatible with features such as re-weighting of log-probabilities, e.g. using `MiniBatchContext`, or the `GibbsContext` in https://github.com/TuringLang/Turing.jl/pull/2099, which overrides the `assume` statements to instead return a conditioned value plus its log-probability, without hitting `observe`.
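A minimal sketch of the idea in plain Julia, assuming nothing about DynamicPPL's actual internals: every type and function name below (`AbstractCtx`, `DefaultCtx`, `ParticleCtx`, `SimpleVarInfo`, `toy_acclogp!!`, `toy_model!!`) is a hypothetical stand-in. The accumulator takes the context as its first argument, so a particle-style context can redirect the log-probability into a task-local varinfo (stored here with Base's `task_local_storage`), while every other context accumulates into the varinfo passed in. The context-free two-argument method is kept and forwards to the default context, mirroring the "keep the old definition" point above.

```julia
# Hypothetical stand-ins; these are NOT DynamicPPL's real types or functions.
abstract type AbstractCtx end
struct DefaultCtx <: AbstractCtx end
struct ParticleCtx <: AbstractCtx end   # plays the role of a PG-like context

mutable struct SimpleVarInfo
    logp::Float64
end

# Default: accumulate into the varinfo threaded through the model.
function toy_acclogp!!(::AbstractCtx, vi::SimpleVarInfo, logp)
    vi.logp += logp
    return vi
end

# Old context-free method, kept so existing callers keep working.
toy_acclogp!!(vi::SimpleVarInfo, logp) = toy_acclogp!!(DefaultCtx(), vi, logp)

# Particle-style override: accumulate into the task-local varinfo instead.
function toy_acclogp!!(::ParticleCtx, vi::SimpleVarInfo, logp)
    local_vi = task_local_storage(:varinfo)::SimpleVarInfo
    local_vi.logp += logp
    return vi   # the "global" varinfo is left untouched
end

# Toy model body: an @addlogprob!-like call that forwards the current context.
toy_model!!(ctx::AbstractCtx, vi::SimpleVarInfo) = toy_acclogp!!(ctx, vi, -1.5)

# Usage: default context accumulates into the varinfo it is given.
default_vi = SimpleVarInfo(0.0)
toy_model!!(DefaultCtx(), default_vi)

# Usage: particle context accumulates into the task-local varinfo instead.
global_vi = SimpleVarInfo(0.0)
particle_vi = SimpleVarInfo(0.0)
task_local_storage(:varinfo, particle_vi) do
    toy_model!!(ParticleCtx(), global_vi)
end
# global_vi.logp stays 0.0; particle_vi.logp is now -1.5
```

The design choice is plain multiple dispatch on the context argument: samplers only need to add a method for their own context type, instead of overloading `assume`/`observe` to sidestep the accumulation.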