google-deepmind / PGMax

Loopy belief propagation for factor graphs on discrete variables in JAX
Apache License 2.0

Calculating the Log Partition Function in LBP #6

Closed cuent closed 5 months ago

cuent commented 1 year ago

How can we calculate the log partition function, given that we can obtain node marginals through inferer.get_beliefs? Is it also feasible to obtain factor marginals for this purpose? The typical approach to compute the log partition function is:

$$\log Z \approx \sum_{i\in V} (1-d_i)H_i + \sum_{(i,j)\in E} I_{ij}$$

This formula requires both the entropy $H_i$ for individual nodes and the mutual information $I_{ij}$ for edges. To compute these, we need the node marginals $b_i(x_i)$ and the joint marginal probabilities $b_{ij}(x_i, x_j)$.
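
For concreteness, here is a minimal sketch of evaluating this formula once node and pairwise marginals are available; the array layout and the helper name are assumptions for illustration, not part of the PGMax API:

```python
import numpy as np

def bethe_style_log_z(node_marginals, pair_marginals, edges, eps=1e-12):
    # node_marginals: {i: b_i}, each b_i an array of shape (n_states_i,).
    # pair_marginals: {(i, j): b_ij}, each b_ij of shape (n_states_i, n_states_j).
    # edges: list of (i, j) pairs, i.e. the edge set E.
    degrees = {i: 0 for i in node_marginals}
    for i, j in edges:
        degrees[i] += 1
        degrees[j] += 1

    # Node entropies H_i = -sum_{x_i} b_i(x_i) log b_i(x_i).
    entropies = {i: -np.sum(b * np.log(b + eps)) for i, b in node_marginals.items()}

    # sum_i (1 - d_i) H_i
    log_z = sum((1 - degrees[i]) * entropies[i] for i in node_marginals)

    # Mutual informations I_ij = sum_{x_i, x_j} b_ij log(b_ij / (b_i b_j)).
    for i, j in edges:
        b_ij = pair_marginals[(i, j)]
        outer = np.outer(node_marginals[i], node_marginals[j])
        log_z += np.sum(b_ij * np.log(b_ij / (outer + eps) + eps))
    return log_z
```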

cuent commented 12 months ago

Given that infer.compute_energy computes the energy $E(x)$ of a state $x$ in a graphical model, the probability of a state $x$ under the Boltzmann distribution can be expressed as:

$$P(x) = \frac{\exp(-E(x))}{Z}$$

where $Z$ is the partition function, and $E(x)$ is the energy of state $x$. The partition function $Z$ is the sum over all possible states:

$$Z = \sum_{x} \exp(-E(x))$$

To find the logarithm of the partition function, $\log Z$, use the log-sum-exp (LSE) function over the negative energies of all states: $$\log Z = \log \left( \sum_{x} \exp(-E(x)) \right) = \text{LSE}\big(\{-E(x)\}_{x}\big)$$
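
For a model small enough to enumerate, this could look like the following sketch; here `energy_fn` is a stand-in for evaluating $E(x)$ on a full assignment (e.g. via infer.compute_energy), and its exact signature is an assumption:

```python
import itertools
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def brute_force_log_z(energy_fn, num_states_per_variable):
    # Enumerate every joint configuration x = (x_1, ..., x_n).
    configs = itertools.product(*[range(k) for k in num_states_per_variable])
    # Stack the energies E(x) of all configurations.
    energies = jnp.array([energy_fn(x) for x in configs])
    # log Z = LSE over the negative energies.
    return logsumexp(-energies)
```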

would that be an appropriate way to compute $\log Z$?

antoine-dedieu commented 5 months ago

We thank you for this relevant question! The approach that you suggest is correct for small PGMs. However, it would not scale to larger PGMs.

For larger models, a good upper bound is the perturb-and-MAP one: see https://arxiv.org/pdf/1206.6410.pdf, Corollary 1. This upper bound is the expectation of the energy of the MAP solution of a perturbed model (where the unaries are perturbed with Gumbel noise).

Please note that, in practice, (a) belief propagation only approximates the MAP solution, and (b) we use a finite number of samples to approximate the expectation. So we end up with an approximation of the upper bound (and of the log partition function).
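
A minimal sketch of this sampling loop is below. It assumes a `map_score_fn` that runs (max-product) BP on the perturbed model and returns the score $-E(x^\star)$ of the decoded MAP state; these helper names are illustrative, and the actual PGMax code lives in the notebook cell mentioned below:

```python
import jax
import jax.numpy as jnp

def perturb_and_map_log_z(key, unaries, map_score_fn, num_samples=100):
    # unaries: list of unary log-potentials, one array per variable.
    # map_score_fn: given perturbed unaries, returns -E(x*) for the (approximate)
    #   MAP state x* of the perturbed model, e.g. decoded with max-product BP.
    samples = []
    for sample_key in jax.random.split(key, num_samples):
        keys = jax.random.split(sample_key, len(unaries))
        # Perturb every unary entry with i.i.d. Gumbel noise.
        perturbed = [u + jax.random.gumbel(k, u.shape) for u, k in zip(unaries, keys)]
        samples.append(map_score_fn(perturbed))
    # Average over perturbations to approximate the expectation, which upper-bounds log Z.
    return jnp.mean(jnp.array(samples))
```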

Following your question, we have added a cell at the bottom of the RBM example notebook to compute this approximation of the log-partition function for our pretrained RBM.