> this nested loop would no longer be needed
That was originally the goal in #4489. My Aesara knowledge was (and still is) very limited, so after a while of being stuck @brandonwillard took over and kept the nested loops. It seems like the description already outlines a clear path forward, but he might also have extra insight on this.
A lot has changed since my original v4 branch, so I can't imagine that many/any considerations based on #4489 will be relevant now. Regardless, if the description above is correct, it would appear as though the problem is due to other changes that now require an update to this logic.
Aside from that, @ricardoV94 seems to be proposing some potential paths for improvement. If so, they involve design decisions that need to be considered carefully by the people responsible for making them.
If there are any questions about how the basic machinery works, don't hesitate to ask; otherwise, I don't know how else I can help here.
When talking with @lucianopaz I realized we completely broke log_likelihood computation in V4.
In V3 this worked as expected. This happened because `model.logpt` now returns the summed logp by default, whereas before it returned the vectorized logp. The change was made in https://github.com/pymc-devs/pymc/commit/0a172c87e39ee64bf5101a5887281ad6548e6ea4.

Although that is a more sane default, we have to reintroduce an easy helper, `logp_elemwiset` (I think this is pretty much broken right now as well), which calls `logpt` with `sum=False`.
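A minimal sketch of what that helper could look like (assuming the `Model.logpt` method with the `sum` keyword described above; exact names and signatures in the current V4 branch may differ):

```python
import pymc as pm

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 1.0)
    pm.Normal("y", mu, 1.0, observed=[0.1, -0.3, 0.5])

summed_logp = model.logpt()  # the new default: a single summed scalar graph

def logp_elemwiset(model):
    # Hypothetical reintroduced helper: just forwards sum=False to get
    # the unsummed (vectorized) logp terms back
    return model.logpt(sum=False)

elemwise_logp = logp_elemwiset(model)
```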
Also, in this case we might want to just return the logprob terms as the dictionary items that are returned by `aeppl.factorized_joint_logprob`, and let the end user decide how they want to combine them. These dictionaries map `{value variable: logp term}`. The default of calling `at.add` on all variables when `sum=False` is seldom useful (that's why we switched the default), due to potential unwanted broadcasting across variables with different dimensions.

One extra advantage of returning the dictionary items is that we don't need to create nearly duplicated graphs for each observed variable when computing the log-likelihood here:
https://github.com/pymc-devs/pymc/blob/fe2d101bb27e05b889eafda7e54b07e05250faee/pymc/backends/arviz.py#L268
We can request the logp terms for any number of observed variables at the same time, and then simply compile a single function that has each variable's logp term as an output but otherwise shares the common nodes. This saves on compilation, computation, and memory footprint when a model has more than one observed variable.
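A rough sketch of the idea, using `aeppl.factorized_joint_logprob` directly (the toy model and variable names are made up; this assumes the dict-in/dict-out API described above):

```python
import aesara
import aesara.tensor as at
from aeppl import factorized_joint_logprob

# Toy model with two observed variables sharing a common parent
mu = at.vector("mu")
y_rv = at.random.normal(mu, 1.0, name="y")
z_rv = at.random.normal(mu.sum(), 1.0, name="z")

# Value variables standing in for the observed data
y_vv = y_rv.clone()
z_vv = z_rv.clone()

# One call returns {value variable: logp term} for all requested variables
logp_terms = factorized_joint_logprob({y_rv: y_vv, z_rv: z_vv})

# A single compiled function with one output per observed variable;
# subgraphs shared between the terms (everything involving `mu`) are
# compiled and evaluated only once
logp_fn = aesara.function([mu, y_vv, z_vv], [logp_terms[y_vv], logp_terms[z_vv]])
```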
For instance, this nested loop would no longer be needed:
https://github.com/pymc-devs/pymc/blob/fe2d101bb27e05b889eafda7e54b07e05250faee/pymc/backends/arviz.py#L276-L282
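Continuing the hypothetical sketch above (with made-up posterior draws), the per-variable compilation and evaluation would collapse into a single loop over draws:

```python
import numpy as np

y_obs = np.array([0.1, -0.3, 0.5])
z_obs = np.array(0.2)
posterior_mu = [np.random.normal(size=3) for _ in range(4)]  # fake draws

# One function call per draw yields the pointwise logp of *every*
# observed variable, instead of one compiled function per variable
log_lik = {"y": [], "z": []}
for mu_draw in posterior_mu:
    y_ll, z_ll = logp_fn(mu_draw, y_obs, z_obs)
    log_lik["y"].append(y_ll)
    log_lik["z"].append(z_ll)
```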
CC @OriolAbril