pymc-devs / pymc-experimental

https://pymc-experimental.readthedocs.io

Implement utility to recover marginalized variables from `MarginalModel` #285

Closed by zaxtax 9 months ago

zaxtax commented 9 months ago

This PR adds support for a `recover_marginals` method, which lets us sample values for the discrete variables that were marginalized out during sampling and access their logps.

Closes #286
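The idea behind recovering a marginalized variable can be sketched in plain Python. For each posterior draw, evaluate the log-probability of every value the discrete variable can take, normalize those values with log-sum-exp, and then sample from the resulting categorical distribution. This is a minimal sketch under a hypothetical toy model (`idx ~ Bernoulli(p)`, `y_i ~ Normal(mu[idx], sigma)`), not the `MarginalModel` implementation. `recover_idx`, `normal_logpdf`, and the dict-based posterior draws are illustrative assumptions.

```python
import math
import random


def normal_logpdf(x, mu, sigma):
    """Log density of Normal(mu, sigma) evaluated at x."""
    return -0.5 * math.log(2 * math.pi) - math.log(sigma) - 0.5 * ((x - mu) / sigma) ** 2


def logsumexp(values):
    """Numerically stable log(sum(exp(values)))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def recover_idx(draw, y, p=0.3, sigma=1.0, rng=random):
    """Hypothetical helper for the toy model idx ~ Bernoulli(p), y_i ~ Normal(mu[idx], sigma).

    Returns the normalized log-probabilities of idx given one posterior draw of mu,
    plus a value of idx sampled from them.
    """
    lps = []
    for value in (0, 1):  # enumerate every value the marginalized variable can take
        lp = math.log(p if value == 1 else 1 - p)  # prior term of idx
        lp += sum(normal_logpdf(yi, draw["mu"][value], sigma) for yi in y)  # likelihood given idx
        lps.append(lp)
    norm = logsumexp(lps)
    normalized = [lp - norm for lp in lps]  # log p(idx = value | y, mu)
    # Draw idx from the categorical distribution defined by the normalized probabilities
    sampled = 1 if rng.random() < math.exp(normalized[1]) else 0
    return normalized, sampled


# Usage with two made-up posterior draws of mu and three observations
posterior = [{"mu": [-2.0, 2.0]}, {"mu": [-1.5, 2.5]}]
y = [1.8, 2.2, 2.1]
rng = random.Random(0)
recovered = [recover_idx(draw, y, rng=rng) for draw in posterior]
```

Because the returned lps are normalized, their exponentials sum to one for every draw, so they can be read directly as posterior probabilities of each value.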

ricardoV94 commented 9 months ago

Ah, one reason I see for why we may want to normalize the lps is that we don't actually need to evaluate the joint logp of the whole model, only the terms of the variables that depend on the marginalized one.

In the future we may want to be more efficient and compile a logp with `vars=[marginalized, *dependent_RVs]`, where the unnormalized lps don't make as much sense. The normalized lps, in contrast, should come out exactly the same.
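Why normalization makes the cheaper route safe: logp terms that do not involve the marginalized variable add the same constant to every candidate value, so they cancel once the lps are normalized. A small pure-Python check under a hypothetical toy model (every name here is illustrative, not pymc-experimental API):

```python
import math


def normal_logpdf(x, mu, sigma):
    """Log density of Normal(mu, sigma) evaluated at x."""
    return -0.5 * math.log(2 * math.pi) - math.log(sigma) - 0.5 * ((x - mu) / sigma) ** 2


def logsumexp(values):
    """Numerically stable log(sum(exp(values)))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def normalize(lps):
    """Subtract log(sum(exp(lps))) so the exponentiated values sum to one."""
    norm = logsumexp(lps)
    return [lp - norm for lp in lps]


# Hypothetical toy model: idx ~ Bernoulli(0.3), mu_k ~ Normal(0, 5),
# y ~ Normal(mu[idx], 1), and an unrelated z ~ Normal(0, 1).
p, mu, y, z = 0.3, [-2.0, 2.0], 1.5, 0.4


def dependent_lp(value):
    # Only the terms that involve idx: its own prior and the likelihood of y.
    return math.log(p if value == 1 else 1 - p) + normal_logpdf(y, mu[value], 1.0)


def joint_lp(value):
    # Whole-model logp: also includes terms that do not depend on idx.
    other = sum(normal_logpdf(m, 0.0, 5.0) for m in mu) + normal_logpdf(z, 0.0, 1.0)
    return dependent_lp(value) + other


full = [joint_lp(v) for v in (0, 1)]
partial = [dependent_lp(v) for v in (0, 1)]
# The unnormalized lps differ by a constant, but normalize(full) == normalize(partial).
```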

zaxtax commented 9 months ago

Yep, I'm convinced. Will make the changes.

zaxtax commented 9 months ago

I think in the future we can add the optimisations where we compile the `joint_logps` all at once, as well as using only a logp that includes the terms that contain `marginalized_value`.

But this should work for now.

ricardoV94 commented 9 months ago

As a follow-up, we may want to standardize the signatures of `marginalize` and `recover_marginals` so that both accept either variable names (strings) or the variables themselves. Right now each is restricted to a different type, which feels suboptimal.
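One way to standardize the signatures is a shared helper that both methods call first, turning a name, a variable, or a mixed list of either into a list of variables. The sketch below is hypothetical: `Var`, `ToyModel`, and `as_variables` are stand-ins, not pymc-experimental API, although the `named_vars` mapping mimics the name-to-variable dict that PyMC models expose.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Sequence, Union


@dataclass(frozen=True)
class Var:
    """Stand-in for a model random variable (hypothetical)."""
    name: str


@dataclass
class ToyModel:
    """Stand-in for a model exposing a name -> variable mapping (hypothetical)."""
    named_vars: Dict[str, Var] = field(default_factory=dict)

    def add(self, name: str) -> Var:
        var = Var(name)
        self.named_vars[name] = var
        return var


VarLike = Union[str, Var]


def as_variables(model: ToyModel, rvs: Union[VarLike, Sequence[VarLike]]) -> List[Var]:
    """Normalize a single name/variable, or a sequence of them, to a list of model variables."""
    if isinstance(rvs, (str, Var)):
        rvs = [rvs]  # wrap a single item so both cases share one code path
    resolved = []
    for rv in rvs:
        if isinstance(rv, str):
            # Look names up in the model; unknown names are an error
            if rv not in model.named_vars:
                raise KeyError(f"No variable named {rv!r} in the model")
            rv = model.named_vars[rv]
        elif rv not in model.named_vars.values():
            # Variables passed directly must belong to this model
            raise ValueError(f"{rv!r} does not belong to the model")
        resolved.append(rv)
    return resolved


# Usage: both call styles resolve to the same variables
model = ToyModel()
idx = model.add("idx")
mu = model.add("mu")
```

With such a helper, `marginalize(...)` and `recover_marginals(...)` could each accept `"idx"`, `idx`, or `["idx", mu]` interchangeably.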