pymc-devs / pymc

Bayesian Modeling and Probabilistic Programming in Python
https://docs.pymc.io/

inferencedata.log_likelihood is summing observations #5236

Closed ricardoV94 closed 2 years ago

ricardoV94 commented 2 years ago

While talking with @lucianopaz I realized that we completely broke the log_likelihood computation in V4.

import pymc as pm
with pm.Model() as m:
    y = pm.Normal("y")
    x = pm.Normal("x", y, 1, observed=[5, 2])    
    idata = pm.sample(tune=5, draws=5, chains=2)
print(idata.log_likelihood['x'].values.shape)
# (2, 5, 1)

Whereas in V3:

import pymc3 as pm
with pm.Model() as m:
    y = pm.Normal("y")
    x = pm.Normal("x", y, 1, observed=[5, 2])    
    idata = pm.sample(tune=5, draws=5, chains=2, return_inferencedata=True)
print(idata.log_likelihood['x'].values.shape)
# (2, 5, 2)

This happened because model.logpt now returns the summed logp by default, whereas before it returned the vectorized (elementwise) logp by default. The change was made in https://github.com/pymc-devs/pymc/commit/0a172c87e39ee64bf5101a5887281ad6548e6ea4

Although that is a saner default, we have to reintroduce an easy helper, logp_elemwiset (which I think is pretty much broken right now as well), that calls logpt with sum=False.
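As a reference point, the intended usage would look roughly like this. This is only a sketch and assumes logpt can be called as a method with a sum keyword, as described above:

import pymc as pm

with pm.Model() as m:
    y = pm.Normal("y")
    x = pm.Normal("x", y, 1, observed=[5, 2])

# Current default: a single scalar graph with all terms summed.
joint_logp = m.logpt()

# What logp_elemwiset should expose again: the unsummed logp terms, one per
# value variable, which is what the log-likelihood computation needs.
elemwise_logps = m.logpt(sum=False)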

Alternatively, in this case we might want to just return the logprob terms as the dictionary items returned by aeppl.factorized_joint_logprob and let the end user decide how they want to combine them. The dictionary maps {value variable: logp term}. The default of calling at.add on all variables when sum=False is seldom useful (that's why we switched the default), due to potential unwanted broadcasting across variables with different dimensions.
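For reference, a minimal sketch of what those dictionary items look like when calling aeppl directly; the variable names here are only illustrative:

import aesara.tensor as at
from aeppl import factorized_joint_logprob

# The model from the report above, written directly in aesara/aeppl terms.
y_rv = at.random.normal(0.0, 1.0, name="y")
x_rv = at.random.normal(y_rv, 1.0, size=2, name="x")

# Value variables in terms of which the logp graphs are expressed.
y_vv = y_rv.clone()
x_vv = x_rv.clone()

# Dictionary {value variable: elementwise logp term}; no at.add is applied
# across the entries.
logp_terms = factorized_joint_logprob({y_rv: y_vv, x_rv: x_vv})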

One extra advantage of returning the dictionary items is that we don't need to create nearly duplicated graphs for each observed variable when computing the log-likelihood here:

https://github.com/pymc-devs/pymc/blob/fe2d101bb27e05b889eafda7e54b07e05250faee/pymc/backends/arviz.py#L268

We can request the terms for any number of observed variables at the same time and then simply compile a single function that has each variable's logp term as an output but otherwise shares the common nodes, saving on compilation, computation, and memory footprint when a model has more than one observed variable.
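A rough sketch of that idea, again going through aeppl directly and with made-up variable names; the actual integration in pymc/backends/arviz.py would of course look different:

import aesara
import aesara.tensor as at
from aeppl import factorized_joint_logprob

# Two observed variables that share a common parent y.
y_rv = at.random.normal(0.0, 1.0, name="y")
x1_rv = at.random.normal(y_rv, 1.0, size=2, name="x1")
x2_rv = at.random.normal(y_rv, 1.0, size=3, name="x2")

y_vv = y_rv.clone()
x1_vv = x1_rv.clone()
x2_vv = x2_rv.clone()

logp_terms = factorized_joint_logprob({y_rv: y_vv, x1_rv: x1_vv, x2_rv: x2_vv})

# One compiled function with one output per observed variable; the nodes
# shared by x1's and x2's terms (everything involving y) are compiled once.
logp_fn = aesara.function(
    [y_vv, x1_vv, x2_vv],
    [logp_terms[x1_vv], logp_terms[x2_vv]],
)

# Pointwise log-likelihoods for one posterior draw of y and the observed data.
x1_logps, x2_logps = logp_fn(0.1, [5.0, 2.0], [1.0, 3.0, 4.0])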

For instance, this nested loop would no longer be needed:

https://github.com/pymc-devs/pymc/blob/fe2d101bb27e05b889eafda7e54b07e05250faee/pymc/backends/arviz.py#L276-L282

CC @OriolAbril

OriolAbril commented 2 years ago

this nested loop would no longer be needed

That was originally the goal in #4489. My Aesara knowledge was (and still is) very limited, so after a while of being stuck @brandonwillard took over and kept the nested loops. It seems like the description already outlines a clear path forward, but he might also have extra insight on this.

brandonwillard commented 2 years ago

That was originally the goal in #4489. My Aesara knowledge was (and still is) very limited, so after a while of being stuck @brandonwillard took over and kept the nested loops. It seems like the description already outlines a clear path forward, but he might also have extra insight on this.

A lot has changed since my original v4 branch, so I can't imagine that many/any considerations based on #4489 will be relevant now. Regardless, if the description above is correct, it would appear as though the problem is due to other changes that now require an update to this logic.

Aside from that, @ricardoV94 seems to be proposing some potential paths for improvement. If so, they involve design decisions that need to be considered carefully by the people responsible for making them.

If there are any questions about how the basic machinery works, don't hesitate to ask; otherwise, I don't know how else I can help here.