Closed · loreloc closed this issue 1 year ago
For the partition function, let's try to materialize it as a circuit. I was thinking of having something like this:

- A method `log_partition_function` added to our input layers that computes the marginalization over all variables; since parameters are usually tensorized, for the distributions we have now it would just output a zero tensor.
- A `MarginalizedInputLayer` (another name?) that takes any input layer and a list of variables to marginalize. Its forward should combine the output of the layer and its marginalization by applying a mask constructed from the list of variables to marginalize.
- A method `from_marginalized_variables` (another name?) that constructs another tensorized circuit having a `MarginalizedInputLayer` as input layer, passing the list of inner layers by reference.
- When `model.eval()` is called (e.g., during evaluation), a query result is cached (e.g., the computation of the partition function).

We need to distinguish two cases for marginalization, and hence for `MarginalizedInputLayer`s.
First, we want to create a new circuit marginalizing out a fixed set of vars. This is essentially implementing marginalization as an operation in the atlas. For it, it is fine to create a fresh new computational graph (layer).
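To make the first case concrete, here is a minimal sketch of baking a fixed set of marginalized variables into a freshly constructed layer, in the spirit of the proposed `from_marginalized_variables`. The helper name `materialize_marginalized_layer`, the NumPy stand-in for a tensorized layer, and the Bernoulli leaves are all assumptions for illustration, not the library's API:

```python
import numpy as np

def materialize_marginalized_layer(leaf_log_fn, marg_vars, num_vars):
    """Build a new input-layer function with a *fixed* set of marginalized
    variables decided at construction time (hypothetical helper).

    leaf_log_fn: maps an input x of shape (num_vars,) to per-variable
                 log-densities of shape (num_vars,).
    marg_vars:   iterable of variable indices to marginalize out.
    """
    mask = np.zeros(num_vars, dtype=bool)
    mask[list(marg_vars)] = True  # fixed once, part of the new graph

    def forward(x):
        # A marginalized leaf of a normalized distribution integrates to 1,
        # i.e. it outputs log 1 = 0 regardless of the input.
        return np.where(mask, 0.0, leaf_log_fn(x))

    return forward

# Example: Bernoulli leaves with P(X_i = 1) = p_i, marginalizing variable 1.
p = np.array([0.7, 0.3, 0.9])
leaf = lambda x: np.log(np.where(np.asarray(x) == 1, p, 1.0 - p))
layer = materialize_marginalized_layer(leaf, marg_vars=[1], num_vars=3)
out = layer([1, 1, 0])
```

Since the mask is frozen at construction, the returned `forward` is a new computational graph and the inner layers of the circuit can be shared by reference, as described above.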
However, if we want to allow for marginalizing many different (not known a priori) sets of variables (which is the use case of the current marginalization routines), having a different input layer per var set can be wasteful. Instead, we need to use sets of masks. The difference with the current implementation shall be that i) masks should not be stored as state properties of the layer and ii) we shall allow for batched computations. I.e., having a batch of input examples, each one with its possibly different mask of variables to marginalize.
Add the following methods to the PC class:

- `log_likelihood`, computing the log-likelihood of some input.
- `log_partition_function`, computing the logarithm of the partition function.

Materialize marginalized circuits to compute the partition function and any marginal probability.
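As a rough sketch of how the two methods relate, here is a toy fully factorized circuit over Bernoulli variables; the class shape is a stand-in for the real tensorized PC class, and computing `log_partition_function` by marginalizing *all* variables is exactly the mask trick above (everything here is assumed names, not the library's API):

```python
import numpy as np

class PC:
    """Toy fully factorized circuit over Bernoulli leaves (illustrative only)."""

    def __init__(self, probs):
        self.probs = np.asarray(probs)  # P(X_i = 1) for each variable

    def log_likelihood(self, x):
        """Log-likelihood of a complete assignment x in {0, 1}^n."""
        x = np.asarray(x)
        leaf = np.where(x == 1, self.probs, 1.0 - self.probs)
        return np.log(leaf).sum(axis=-1)

    def log_partition_function(self):
        """Marginalize all variables: every normalized leaf contributes
        log 1 = 0, so a normalized circuit yields log Z = 0."""
        full_mask = np.ones_like(self.probs, dtype=bool)
        leaf_logs = np.where(full_mask, 0.0, np.log(self.probs))
        return leaf_logs.sum()

pc = PC([0.7, 0.3])
ll = pc.log_likelihood([1, 0])
lz = pc.log_partition_function()
```

For a normalized circuit `lz` is zero; for unnormalized parameterizations, `log_likelihood` minus `log_partition_function` would give the normalized log-probability, which is why caching the partition function on `model.eval()` pays off.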