pyro-ppl / pyro

Deep universal probabilistic programming with Python and PyTorch
http://pyro.ai
Apache License 2.0

Support for non-scalar poutine.scale with TraceEnum_ELBO #1897

Open · mbabadi opened this issue 5 years ago

mbabadi commented 5 years ago

It would be useful to support non-scalar poutine.scale with TraceEnum_ELBO. In general, importance sampling (such as stratified sampling) across different plates requires scaling different samples in the ELBO differently in order to obtain an unbiased estimator of the ELBO. This can currently be done using poutine.scale and Trace_ELBO. However, TraceEnum_ELBO does not support non-scalar scale factors at the moment, so one cannot perform stratified mini-batching together with full enumeration on models with discrete latent variables.
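
For reference, here is a minimal toy sketch (model, guide, and numbers are illustrative, not from this issue) of the pattern that already works today: a non-scalar per-datum scale passed through poutine.scale and trained with Trace_ELBO:

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

data = torch.randn(5)
weights = torch.tensor([2., 2., 1., 1., 1.])  # per-datum importance weights (illustrative)

def scaled_model(data):
    loc = pyro.sample("loc", dist.Normal(0., 1.))
    # the non-scalar scale broadcasts against the per-datum log_prob of shape [5]
    with poutine.scale(scale=weights), pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(loc, 1.), obs=data)

def scaled_guide(data):
    loc_q = pyro.param("loc_q", torch.tensor(0.))
    pyro.sample("loc", dist.Normal(loc_q, 1.))

svi = SVI(scaled_model, scaled_guide, Adam({"lr": 1e-2}), loss=Trace_ELBO())
svi.step(data)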

So long as the scale factors have the correct batch shape, I believe the operation is well-defined: as usual, enumeration dimensions are created on the left, logsumexp is performed to reduce the log_prob to the batch shape, and the non-scalar scale factor is applied at the very end.
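
As a shape-level illustration of that order of operations (a sketch with made-up sizes, not actual TraceEnum_ELBO internals):

import torch

enum_size, batch_size = 4, 3                   # illustrative sizes
log_prob = torch.randn(enum_size, batch_size)  # enumeration dim added on the left of the batch shape
scale = torch.tensor([200., 20., 2.])          # non-scalar scale, shape [batch_size]

reduced = torch.logsumexp(log_prob, dim=0)     # reduce the enumeration dim -> [batch_size]
scaled = scale * reduced                       # apply the non-scalar scale last
elbo_term = scaled.sum()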

mbabadi commented 5 years ago

@fritzo what do you think about this? :)

fritzo commented 5 years ago

I think this makes sense. TraceEnum_ELBO is quite complex, so it would take significant effort to add this feature. Can you provide a simple example model to help me see exactly what you're requesting?

mbabadi commented 5 years ago

Sure! Consider LDA:

import torch
import pyro
import pyro.distributions as dist

def model(data, args):
    # global variables: per-topic weights and per-topic word distributions
    with pyro.plate("topics", args.num_topics):
        topic_weights = pyro.sample("topic_weights", dist.Gamma(1. / args.num_topics, 1.))
        topic_words = pyro.sample(
            "topic_words", dist.Dirichlet(torch.ones(args.num_words) / args.num_words))

    # local variables: per-document topic mixtures and per-word topic assignments
    with pyro.plate("documents", args.num_docs):
        doc_topics = pyro.sample("doc_topics", dist.Dirichlet(topic_weights))
        with pyro.plate("words", args.num_words_per_doc):
            word_topics = pyro.sample("word_topics", dist.Categorical(doc_topics))
            pyro.sample("doc_words", dist.Categorical(topic_words[word_topics]), obs=data)

Let us assume that the corpus contains 3 volumes of 1000, 100, and 10 documents, respectively. The first volume contains boringly similar legal proceedings, whereas the second and third volumes consist of beautiful and novel papers in probabilistic programming and botany, respectively.

Given the heterogeneity of the corpus, uniform mini-batching across all 1110 documents leads to a high-variance ELBO estimate. With a batch size of 10 documents, only about 1 in 100 sampled documents will be a botany paper, and one needs to use a very small learning rate.

A more desirable mini-batching scheme is to manually balance the mini-batches so that each volume is equally represented, say by picking 5 documents from legal, 5 from probabilistic programming, and 5 from botany each time. To get an unbiased estimator of the ELBO, though, we need to weight each legal document by 1000 / 5 = 200, each probabilistic programming document by 100 / 5 = 20, and each botany paper by 10 / 5 = 2. This can be achieved, for instance, by computing these normalizing weights as part of the mini-batching strategy, passing them to the model (and the guide), and using poutine.scale to scale each document accordingly:

from pyro import poutine

def model(data, doc_weights, args):
    with pyro.plate("topics", args.num_topics):
        topic_weights = pyro.sample("topic_weights", dist.Gamma(1. / args.num_topics, 1.))
        topic_words = pyro.sample(
            "topic_words", dist.Dirichlet(torch.ones(args.num_words) / args.num_words))

    # scale every site inside by the per-document weights (note the scale= keyword)
    with poutine.scale(scale=doc_weights):
        with pyro.plate("documents", args.num_docs):
            doc_topics = pyro.sample("doc_topics", dist.Dirichlet(topic_weights))
            with pyro.plate("words", args.num_words_per_doc):
                word_topics = pyro.sample("word_topics", dist.Categorical(doc_topics))
                pyro.sample("doc_words", dist.Categorical(topic_words[word_topics]), obs=data)

In the above example, data has shape [num_words_per_doc, num_docs], num_docs is 15 (the mini-batch size), and doc_weights has shape [num_docs] and specifies the weight of each document in the log-likelihood:

doc_weights = (200, 200, 200, 200, 200, 20, 20, 20, 20, 20, 2, 2, 2, 2, 2)
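
For concreteness, here is a rough sketch (a hypothetical helper, not part of Pyro) of a stratified mini-batcher that produces both the sampled document indices and the matching doc_weights:

import torch

def stratified_minibatch(volume_doc_ids, per_volume=5):
    """volume_doc_ids: dict mapping volume name -> LongTensor of document indices."""
    idx, weights = [], []
    for name, doc_ids in volume_doc_ids.items():
        perm = torch.randperm(len(doc_ids))[:per_volume]
        idx.append(doc_ids[perm])
        # unbiasedness weight = N_volume / n_drawn, e.g. 1000 / 5 = 200 for the legal volume
        weights.append(torch.full((per_volume,), len(doc_ids) / per_volume))
    return torch.cat(idx), torch.cat(weights)

volumes = {"legal": torch.arange(1000),
           "prob_prog": torch.arange(1000, 1100),
           "botany": torch.arange(1100, 1110)}
doc_ids, doc_weights = stratified_minibatch(volumes)  # doc_weights = (200, ..., 20, ..., 2, ...)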

The desired behavior with parallel enumeration is to calculate the log_prob over every configuration, with shape [num_words, num_topics, num_words_per_doc, num_docs], to reduce it to [num_words_per_doc, num_docs] via two logsumexp reductions, and to multiply the reduced log_prob by doc_weights at the very end.
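
In tensor terms, that reduction would look roughly like this (shapes only, with illustrative sizes; a sketch rather than actual TraceEnum_ELBO internals):

import torch

num_words, num_topics, num_words_per_doc, num_docs = 100, 8, 64, 15   # illustrative sizes
doc_weights = torch.tensor([200.] * 5 + [20.] * 5 + [2.] * 5)         # shape [num_docs]

# log_prob over every enumerated configuration
log_prob = torch.randn(num_words, num_topics, num_words_per_doc, num_docs)

# two logsumexp reductions over the enumeration dims -> [num_words_per_doc, num_docs]
reduced = torch.logsumexp(torch.logsumexp(log_prob, dim=0), dim=0)

# the non-scalar scale is applied only after the enumeration dims are reduced
weighted = reduced * doc_weights                  # doc_weights broadcasts over the words dim
elbo_contribution = weighted.sum()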

ps1> @fritzo, this is a low-priority feature request :-) That said, smart stratified sampling, and better yet importance sampling of mini-batches, can speed up training by many orders of magnitude in certain problems (see https://arxiv.org/abs/1602.02283).

ps2> My personal use case is working with single-cell gene expression data, where the dynamic range of gene expression is very wide: some genes are expressed in every cell with ~200 copies, whereas some genes are expressed in less than 1% of all cells. Naïve (uniform) mini-batching of genes results in the almost complete disappearance of the lowly expressed (yet important) genes.