mbabadi opened this issue 5 years ago
@fritzo what do you think about this? :)
I think this makes sense. `TraceEnum_ELBO` is quite complex, so it would take significant effort to add this feature. Can you provide a simple example model to help me see exactly what you're requesting?
Sure! Consider LDA:

```python
import torch
import pyro
import pyro.distributions as dist

def model(data, args):
    with pyro.plate("topics", args.num_topics):
        topic_weights = pyro.sample("topic_weights", dist.Gamma(1. / args.num_topics, 1.))
        topic_words = pyro.sample(
            "topic_words", dist.Dirichlet(torch.ones(args.num_words) / args.num_words))
    with pyro.plate("documents", args.num_docs):
        doc_topics = pyro.sample("doc_topics", dist.Dirichlet(topic_weights))
        with pyro.plate("words", args.num_words_per_doc):
            word_topics = pyro.sample("word_topics", dist.Categorical(doc_topics))
            pyro.sample("doc_words", dist.Categorical(topic_words[word_topics]), obs=data)
```
Let us assume that the corpus contains 3 volumes of 1000, 100, and 10 documents, respectively. The first volume contains boringly similar legal proceedings, whereas the second and third volumes consist of beautiful and novel papers in probabilistic programming and botany, respectively.
Given the heterogeneity of the corpus, uniform mini-batching across all 1110 documents leads to a high-variance ELBO estimate. With a batch size of 10 documents, only about 1 in 111 draws will be a botany paper, and one needs to use a very small learning rate.
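As a quick sanity check on those odds, here is a stdlib-Python calculation using the corpus sizes above (the exact batch-level probability is just illustrative arithmetic, not part of the proposal):

```python
from math import comb

# 1110 documents total, of which 10 are botany; uniform mini-batches of 10.
p_draw = 10 / 1110                              # chance a single draw is botany
p_batch = 1 - comb(1100, 10) / comb(1110, 10)   # chance a batch contains any botany paper

print(f"per-draw: {p_draw:.4f}, per-batch: {p_batch:.4f}")
```

So under uniform mini-batching, roughly one mini-batch in twelve contains a botany paper at all.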
A more desirable mini-batching scheme is to manually balance the mini-batches to ensure equal representation of each volume, say by picking 5 documents from legal, 5 from probabilistic programming, and 5 from botany each time. To keep the ELBO estimator unbiased, though, we need to weight each legal document by 1000 / 5 = 200, each probabilistic programming document by 100 / 5 = 20, and each botany paper by 10 / 5 = 2. This can be achieved, for instance, by computing these normalizing weights as part of the mini-batching strategy, passing them to the model (and the guide), and using `poutine.scale` to scale each document accordingly:
```python
from pyro import poutine

def model(data, doc_weights, args):
    with pyro.plate("topics", args.num_topics):
        topic_weights = pyro.sample("topic_weights", dist.Gamma(1. / args.num_topics, 1.))
        topic_words = pyro.sample(
            "topic_words", dist.Dirichlet(torch.ones(args.num_words) / args.num_words))
    with poutine.scale(scale=doc_weights):
        with pyro.plate("documents", args.num_docs):
            doc_topics = pyro.sample("doc_topics", dist.Dirichlet(topic_weights))
            with pyro.plate("words", args.num_words_per_doc):
                word_topics = pyro.sample("word_topics", dist.Categorical(doc_topics))
                pyro.sample("doc_words", dist.Categorical(topic_words[word_topics]), obs=data)
```
In the above example, `data` has shape `[num_words_per_doc, num_docs]`, `num_docs` is 15, and `doc_weights` has shape `[num_docs]` and specifies the relative weight of each document in the log-likelihood:

```python
doc_weights = (200, 200, 200, 200, 200, 20, 20, 20, 20, 20, 2, 2, 2, 2, 2)
```
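For concreteness, here is a sketch (plain Python; the helper name and its interface are made up for illustration) of how such weights could be computed as part of the stratified mini-batching strategy: each sampled document in a stratum gets weight N_stratum / n_stratum.

```python
def stratified_doc_weights(stratum_sizes, batch_counts):
    """Weight each sampled document by N_stratum / n_stratum so that the
    weighted mini-batch log-likelihood is an unbiased estimate of the
    full-corpus log-likelihood."""
    weights = []
    for n_total, n_batch in zip(stratum_sizes, batch_counts):
        weights.extend([n_total / n_batch] * n_batch)
    return weights

# 1000 legal, 100 probabilistic-programming, and 10 botany documents,
# sampling 5 from each stratum per mini-batch:
print(stratified_doc_weights([1000, 100, 10], [5, 5, 5]))
```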
The desirable behavior with parallel enumeration is to compute the `log_prob` for every configuration, with shape `[num_words, num_topics, num_words_per_doc, num_docs]`, reduce it to `[num_words_per_doc, num_docs]` via two `logsumexp` reductions over the enumeration dimensions, and multiply the reduced `log_prob` by `doc_weights` at the very end.
PS1: @fritzo, this is a low-priority feature request :-) Still, smart stratified sampling, and better yet, importance sampling of mini-batches, can speed up training by many orders of magnitude in certain problems (see https://arxiv.org/abs/1602.02283).
PS2: My personal use case is single-cell gene expression data, where the dynamic range of gene expression is very wide: some genes are expressed in every cell with ~200 copies, whereas some genes are expressed in less than 1% of all cells. Naïve (uniform) mini-batching of genes results in the almost complete disappearance of the lowly expressed (yet important) genes.
It would be useful to consider support for non-scalar `poutine.scale` with `TraceEnum_ELBO`. In general, importance sampling (such as stratified sampling) across different plates requires scaling different samples in the ELBO differently in order to obtain an unbiased estimator of the ELBO. This can currently be done using `poutine.scale` and `Trace_ELBO`. However, `TraceEnum_ELBO` does not support non-scalar scale factors at the moment. Therefore, one cannot perform stratified mini-batching and full enumeration on models with discrete latent variables.

So long as the scale factors have the correct batch dimension, I believe the operation is well-defined: as usual, enumeration dimensions are created on the left, `logsumexp` is performed to reduce the `log_prob` to the batch shape, and ultimately the non-scalar scale factor is multiplied in.