PrincetonLIPS / MaM

Official code for Generative Marginalization Models [ICML 2024] [SPGIM 2023 Workshop Oral]

Generative Marginalization Models

This repository contains the official PyTorch implementation for the paper: Generative Marginalization Models [paper link], by Sulin Liu, Peter J. Ramadge, and Ryan P. Adams.

We introduce marginalization models (MaMs), a new family of generative models for high-dimensional discrete data.

What are Marginalization Models?

MaMs directly model the marginal distribution $p_\theta(x_s)$ for any subset of variables $x_s$ in $x$.

The learned marginals should satisfy the "marginalization self-consistency":

$$p_\theta(x_s) = \sum\nolimits_{x_{s^c}} p_\theta(x_s, x_{s^c})$$

where $x_{s^c}$ are the variables that are "marginalized out". See Figure 1 below for a concrete example for the binary case.

[Figure 1: marginalization self-consistency in the binary case]
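As a minimal sketch of the self-consistency constraint, the snippet below builds a small random joint over three binary variables as a lookup table and checks that a marginal equals the sum over the variable that is marginalized out. The table, the toy size `D = 3`, and the helper name `marginal` are illustrative assumptions and are not part of this repository.

```python
import itertools
import random

D = 3  # number of binary variables (toy size, chosen for illustration)

# Build a random normalized joint p(x) over {0,1}^D as a lookup table.
rng = random.Random(0)
states = list(itertools.product([0, 1], repeat=D))
weights = [rng.random() for _ in states]
total = sum(weights)
joint = {x: w / total for x, w in zip(states, weights)}


def marginal(assignment):
    """Exact marginal p(x_s) for a partial assignment {index: value}.

    Sums the joint over every completion of the unassigned variables x_{s^c}.
    """
    return sum(
        p for x, p in joint.items()
        if all(x[i] == v for i, v in assignment.items())
    )


# p(x_0 = 1) should equal p(x_0 = 1, x_1 = 0) + p(x_0 = 1, x_1 = 1).
lhs = marginal({0: 1})
rhs = marginal({0: 1, 1: 0}) + marginal({0: 1, 1: 1})
print(abs(lhs - rhs) < 1e-12)
```

A learned $p_\theta$ replaces the exact table lookup with a neural network, so the equality is no longer automatic and has to be encouraged during training.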

To learn self-consistent marginals, we propose scalable training objectives that minimize the error of the following one-step self-consistency constraints imposed on marginals and conditionals over all possible orderings:

$$ \begin{gather} p_\theta(x_{\sigma (< d)}) p_\theta(x_{\sigma (d)} | x_{\sigma (< d)}) = p_\theta(x_{\sigma (\leq d)}), \\ \quad \text{for any ordering } \sigma, \text{any }x \in [1:K]^D, d \in [1:D]. \nonumber \end{gather} $$
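The one-step constraint can be checked by brute force on a toy problem. The sketch below assumes an exact tabular joint (so the constraint should hold up to floating-point error) and enumerates every ordering $\sigma$, every $x$, and every step $d$. The names `log_marginal` and `conditional` are hypothetical helpers, and `d` is 0-indexed in the code.

```python
import itertools
import math
import random

D, K = 3, 2  # toy sizes, chosen for illustration

rng = random.Random(0)
states = list(itertools.product(range(K), repeat=D))
weights = [rng.random() + 0.1 for _ in states]  # strictly positive masses
total = sum(weights)
joint = {x: w / total for x, w in zip(states, weights)}


def log_marginal(x, observed):
    """log p(x_s), where `observed` is the set of indices in s."""
    return math.log(sum(
        p for y, p in joint.items() if all(y[i] == x[i] for i in observed)
    ))


def conditional(x, target, observed):
    """p(x_target | x_observed), normalized over the K values of x_target."""
    def mass(v):
        return sum(p for y, p in joint.items()
                   if y[target] == v and all(y[i] == x[i] for i in observed))
    return mass(x[target]) / sum(mass(v) for v in range(K))


max_residual = 0.0
for sigma in itertools.permutations(range(D)):
    for x in states:
        for d in range(D):
            prefix = frozenset(sigma[:d])  # indices sigma(<d)
            nxt = sigma[d]                 # index sigma(d)
            residual = (log_marginal(x, prefix)
                        + math.log(conditional(x, nxt, prefix))
                        - log_marginal(x, prefix | {nxt}))
            max_residual = max(max_residual, abs(residual))

print(max_residual < 1e-9)
```

Exhaustive enumeration costs $D! \cdot K^D \cdot D$ checks, which is only feasible at toy scale; this is why the constraint is enforced on sampled $(\sigma, x, d)$ triples during training.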

Why Marginalization Models?

Marginals are order-agnostic, hence MaMs allow any-order generation. Any-order autoregressive models (AO-ARMs) [1,2,3] also allow any-order marginal inference by factorizing $p(x)$ into univariate conditionals. Compared to AO-ARMs, direct modeling of marginals has two main advantages:

  1. significantly faster test-time inference of any marginal, by orders of magnitude: $\mathcal{O}(1)$ versus $\mathcal{O}(D)$ with autoregressive models.
  2. scalable training of any-order generative models on high-dimensional problems under energy-based training, a common setting in the physical sciences and in applications with a target reward function. In contrast, training ARMs requires $\mathcal{O}(D)$ evaluations (a sequence of conditionals) to get the log-likelihood of a single data point, which energy-based training needs.
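The evaluation-count argument in the two points above can be illustrated with a toy call counter. In this sketch, `marginal_model` and `conditional_model` are hypothetical stand-ins for one forward pass of a marginal network and a conditional network; both are backed by an exact table here, so only the number of calls is meaningful, not the timing.

```python
import itertools
import math
import random

D = 6  # toy dimension, chosen for illustration

rng = random.Random(0)
states = list(itertools.product([0, 1], repeat=D))
weights = [rng.random() + 0.1 for _ in states]
total = sum(weights)
joint = {x: w / total for x, w in zip(states, weights)}

calls = {"marginal": 0, "conditional": 0}


def table_mass(assignment):
    """Exact probability mass of a partial assignment {index: value}."""
    return sum(p for x, p in joint.items()
               if all(x[i] == v for i, v in assignment.items()))


def marginal_model(assignment):
    """Stand-in for one forward pass of a marginal network p_theta(x_s)."""
    calls["marginal"] += 1
    return math.log(table_mass(assignment))


def conditional_model(index, value, context):
    """Stand-in for one forward pass of a conditional network."""
    calls["conditional"] += 1
    return math.log(table_mass({**context, index: value}) / table_mass(context))


def log_marginal_autoregressive(assignment):
    """Chain rule: one conditional evaluation per observed variable."""
    context, logp = {}, 0.0
    for index, value in assignment.items():
        logp += conditional_model(index, value, context)
        context[index] = value
    return logp


query = {0: 1, 2: 0, 3: 1, 5: 0}
direct = marginal_model(query)                # a single evaluation
chained = log_marginal_autoregressive(query)  # one evaluation per variable in the query
print(calls, abs(direct - chained) < 1e-9)
```

Both routes return the same log-marginal, but the chain-rule route needs one conditional evaluation per queried variable, which grows to $D$ evaluations for a full data point.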

Installation

git clone https://github.com/PrincetonLIPS/MaM.git
cd MaM
# optional virtual env
python -m venv env 
source env/bin/activate
python -m pip install -r requirements.txt

Maximum likelihood estimation problems

To train MaMs for maximum likelihood estimation, we fit the marginals by maximizing the expected log-likelihood over the data distribution while enforcing the marginalization self-consistency.

$$ \begin{align} \max_{\theta, \phi} \quad & \mathbb E_{x \sim p_{\text{data}}} \log p_\theta(x) \\ \text{s.t.} \quad & p_\theta(x_{\sigma (< d)}) p_\phi(x_{\sigma (d)} | x_{\sigma (< d)}) = p_\theta(x_{\sigma (\leq d)}), \quad \forall \sigma \in S_D, x \in \{1,\cdots,K\}^D, d \in [1:D]. \end{align} $$

For the most efficient training, the marginals can be learned in two steps:

1. Fit the conditionals $\phi$: maximize the log-likelihood following the objective for training AO-ARMs.

$$ \max_\phi \quad \mathbb E_{x \sim p_{\text{data}}} \mathbb E_{\sigma \sim \mathcal{U}(S_D)} \sum\nolimits_{d=1}^D \log p_\phi \left( x_{\sigma(d)} | x_{\sigma(< d)} \right) $$

cd ao_arm
python image_main.py # MNIST-Binary dataset
python text_main.py # text8 language modeling
python mol_main.py load_full=True # MOSES molecule string dataset

2. Fit the marginals $\theta$: minimize the errors of self-consistency in Eq. (1).

$$ \min_{\theta} \quad \mathbb E_{x \sim q(x)} \mathbb E_{\sigma \sim \mathcal{U}(S_D)} \mathbb E_{d \sim \mathcal{U}(1,\cdots,D)} \left( \log p_\theta(x_{\sigma (< d)}) + \log p_\phi(x_{\sigma (d)} | x_{\sigma (< d)}) - \log p_\theta(x_{\sigma (\leq d)}) \right)^2. $$

cd mam
python image_main.py load_pretrain=True # MNIST-Binary
python text_main.py load_pretrain=True # text8
python mol_main.py load_pretrain=True # MOSES molecule string
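The two objectives above can be sketched at toy scale without any neural network. The snippet below is not the repository's training code: it assumes exact tabular conditionals in place of $p_\phi$, a uniform $q(x)$, and two hypothetical marginal functions (`exact_log_marginal` and a deliberately inconsistent `biased_log_marginal`) in place of $p_\theta$. It evaluates the any-order log-likelihood from step 1 and a Monte Carlo estimate of the squared self-consistency error from step 2.

```python
import itertools
import math
import random

D, K = 3, 2  # toy sizes, chosen for illustration
rng = random.Random(0)
states = list(itertools.product(range(K), repeat=D))
weights = [rng.random() + 0.1 for _ in states]
total = sum(weights)
joint = {x: w / total for x, w in zip(states, weights)}


def mass(x, observed):
    """Exact probability of x restricted to the indices in `observed`."""
    return sum(p for y, p in joint.items() if all(y[i] == x[i] for i in observed))


def log_cond(x, target, observed):
    """Stand-in for log p_phi(x_target | x_observed); exact in this sketch."""
    return math.log(mass(x, observed | {target}) / mass(x, observed))


def exact_log_marginal(x, observed):
    """Stand-in for a perfectly trained log p_theta(x_s)."""
    return math.log(mass(x, observed))


def biased_log_marginal(x, observed):
    """Inconsistent marginal: adds 0.1 to the log-mass per observed variable."""
    return math.log(mass(x, observed)) + 0.1 * len(observed)


def ao_arm_log_likelihood(x, sigma):
    """Step 1 summand: sum_d log p_phi(x_sigma(d) | x_sigma(<d)) for one ordering."""
    return sum(log_cond(x, sigma[d], frozenset(sigma[:d])) for d in range(D))


def self_consistency_loss(log_marginal, num_samples=200):
    """Step 2: Monte Carlo estimate of the squared one-step residual."""
    total_sq = 0.0
    for _ in range(num_samples):
        x = rng.choice(states)           # x ~ q(x), uniform here (an assumption)
        sigma = rng.sample(range(D), D)  # sigma ~ U(S_D)
        d = rng.randrange(D)             # d ~ U(1..D), 0-indexed in code
        prefix = frozenset(sigma[:d])
        residual = (log_marginal(x, prefix)
                    + log_cond(x, sigma[d], prefix)
                    - log_marginal(x, prefix | {sigma[d]}))
        total_sq += residual ** 2
    return total_sq / num_samples


print(ao_arm_log_likelihood(states[0], (2, 0, 1)), math.log(joint[states[0]]))
print(self_consistency_loss(exact_log_marginal))
print(self_consistency_loss(biased_log_marginal))
```

With exact conditionals, the any-order log-likelihood equals $\log p(x)$ for every ordering. The exact marginals give a loss at floating-point zero, while the biased marginals give a residual of $-0.1$ at every step and therefore a loss of about $0.01$.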

Coming soon: code and model checkpoints for more image datasets, including CIFAR-10 and ImageNet-32.

Energy-based training problems

In this setting, we do not have data samples from the distribution of interest. Instead, we can evaluate the unnormalized (log) probability mass function $f$, usually in the form of a reward function or energy function that is defined by humans or by a physical system to specify how likely a sample is. The goal is to match the learned distribution $p_\theta(x)$ to the target distribution defined by $f(x)$, so that we can sample from the target efficiently with a generative model. This setting is commonly encountered in modeling the thermodynamic equilibrium ensemble of physical systems [4] and in goal-driven generative design problems with reward functions [5].

Training ARMs is expensive because $p_\theta(x)$ must be computed with a sequence of conditionals. MaMs circumvent this by training directly with the marginals while enforcing the marginalization self-consistency.

$$ \begin{align} \min_{\theta, \phi} \quad D_\text{KL}\big( p_{\theta} (x) \parallel p (x) \big) + \lambda \,\mathbb E_{x \sim q(x)} \mathbb E_{\sigma} \mathbb E_{d} \left( \log p_\theta (x_{\sigma \left(< d\right)} ) + \log p_\phi (x_{\sigma \left(d\right)} | x_{\sigma \left(< d\right)}) - \log p_\theta (x_{\sigma \left(\leq d\right)} ) \right)^2. \end{align} $$

cd mam
# ising model energy-based training
python ising_eb_main.py
# molecule property energy-based training with a given reward function
python mol_property_eb_main.py 
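To make the KL term of the energy-based objective concrete, the sketch below assumes that $f$ is an unnormalized log-probability, so that $p(x) = \exp(f(x)) / Z$ and $D_\text{KL}(p_\theta \parallel p) = \mathbb E_{x \sim p_\theta}[\log p_\theta(x) - f(x)] + \log Z$. It uses a toy 1D Ising chain with a hypothetical coupling constant and exact enumeration, which is only feasible at this scale; it is not the code in `ising_eb_main.py`.

```python
import itertools
import math

D = 4           # toy chain length (an assumption for illustration)
COUPLING = 0.5  # hypothetical coupling strength

states = list(itertools.product([-1, 1], repeat=D))


def f(x):
    """Unnormalized log-probability: negative energy of a 1D Ising chain."""
    return COUPLING * sum(x[i] * x[i + 1] for i in range(D - 1))


# log Z is a constant with respect to the model, so training only needs f.
log_z = math.log(sum(math.exp(f(x)) for x in states))


def kl_to_target(model):
    """KL(p_theta || p) with p(x) = exp(f(x)) / Z, by exact enumeration."""
    return sum(q * (math.log(q) - f(x)) for x, q in model.items() if q > 0) + log_z


target = {x: math.exp(f(x) - log_z) for x in states}  # the ideal p_theta
uniform = {x: 1.0 / len(states) for x in states}      # an untrained guess

print(kl_to_target(target))   # zero up to floating-point error
print(kl_to_target(uniform))  # strictly positive
```

In the full objective, this KL term is estimated from samples of the model rather than by enumeration, and the $\lambda$-weighted self-consistency penalty is added so that the marginals used to score $x$ stay consistent with the conditionals used to sample it.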

Citation

Please check the paper for technical details and experimental results. Please consider citing our work if you find it helpful:

@article{liu2023mam,
  title={Generative Marginalization Models},
  author={Liu, Sulin and Ramadge, Peter J and Adams, Ryan P},
  journal={arXiv preprint arXiv:2310.12920},
  year={2023}
}

Acknowledgement

The code for training any-order conditionals of autoregressive models (in ao_arm/) is adapted from https://github.com/AndyShih12/mac, using the original any-order masking strategy proposed for training AO-ARMs without the [mask] token in the output.