gui11aume opened 1 year ago
@fritzo It seems that you wrote the following comment in pyro/distributions/score_parts.py (git blame output below):
71b7d3da0 (Fritz Obermeyer 2018-08-26 10:51:42 -0700 17) def scale_and_mask(self, scale=1.0, mask=None):
71b7d3da0 (Fritz Obermeyer 2018-08-26 10:51:42 -0700 19) Scale and mask appropriate terms of a gradient estimator by a data multiplicity factor.
71b7d3da0 (Fritz Obermeyer 2018-08-26 10:51:42 -0700 20) Note that the `score_function` term should not be scaled or masked.
I ran the tests in pyro/tests/infer/test_inference.py with masking and/or scaling of score_function, but there was no difference, so I could not understand why score_function should not be masked. Maybe there is a test somewhere else that depends on this? If not, do you remember the case you had in mind when you designed this part?
If you point me in the right direction I can write additional tests and work on a pull request for this issue. Thanks!
Hi @gui11aume, responding to your last question: the best tests we have of ScoreParts behavior are in tests/infer/test_gradient.py. We have relied heavily on those tests because they are much faster than end-to-end inference-as-optimization tests, since the gradient tests compute only a single gradient update. Because the gradient tests are fast, we can run them over a large grid of model and inference configurations via @pytest.mark.parametrize.
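For illustration, here is a rough, hypothetical sketch of that single-gradient-step pattern (not the actual tests in tests/infer/test_gradient.py, which are more thorough and compare against analytic gradients):

import pytest
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import Trace_ELBO, TraceGraph_ELBO

@pytest.mark.parametrize("Elbo", [Trace_ELBO, TraceGraph_ELBO])
@pytest.mark.parametrize("scale", [1.0, 2.0])
def test_one_gradient_step(Elbo, scale):
    pyro.clear_param_store()
    data = torch.tensor(2.0)

    def model():
        z = pyro.sample("z", dist.Normal(0.0, 1.0))
        with pyro.poutine.scale(scale=scale):
            pyro.sample("x", dist.Normal(z, 1.0), obs=data)

    def guide():
        loc = pyro.param("loc", torch.tensor(0.0))
        pyro.sample("z", dist.Normal(loc, 1.0))

    # A single gradient computation, no optimization loop.
    Elbo(num_particles=10).loss_and_grads(model, guide)
    grad = pyro.param("loc").unconstrained().grad
    assert grad is not None and torch.isfinite(grad)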
For example, this incorrect change to ScoreParts:
diff --git a/pyro/distributions/score_parts.py b/pyro/distributions/score_parts.py
index 15d39156..bb758d82 100644
--- a/pyro/distributions/score_parts.py
+++ b/pyro/distributions/score_parts.py
@@ -25,6 +25,6 @@ class ScoreParts(
:type mask: torch.BoolTensor or None
"""
log_prob = scale_and_mask(self.log_prob, scale, mask)
- score_function = self.score_function # not scaled
+ score_function = scale_and_mask(self.score_function, scale, mask)
entropy_term = scale_and_mask(self.entropy_term, scale, mask)
return ScoreParts(log_prob, score_function, entropy_term)
results in a test failure in around one second.
Thanks @fritzo! That is very useful for understanding how the code works. I'll start from there and think of ways to address the issue without breaking everything. By the way, I have been reading your work from the past few years and I am very impressed. I don't say this very often, but you are an example to look up to.
I think I understand why the score_function terms are not scaled. Let me know if this is correct, @fritzo.
What you call log_r in the code must refer to $\log p(x,z) - \log q_\theta(z)$, i.e., the integrand of the ELBO as in this article. With the REINFORCE estimator, the target is to compute $\nabla_\theta \log q_\theta(z) \cdot \log r + \nabla_\theta \log r$, which is done by differentiating the surrogate loss $\log q_\theta(z) \cdot \overline{\log r} + \log r$, where the horizontal bar means that the term is treated as a constant. The score_function terms store the values of $\log q_\theta(z)$, and $\overline{\log r}$ is computed by the function _compute_log_r (in the file infer/trace_elbo.py). The term $\log r$ is computed beforehand by subtracting the $\log q_\theta(z)$ terms (i.e., log_prob in the guide sites) from the $\log p(x,z)$ terms (i.e., log_prob in the model sites).
Loosely speaking, if we estimate the gradient of the ELBO with only half of the terms, then we have to multiply $\log r$ by 2 to maintain the expected value. But we should not multiply $\log q_\theta(z)$ by 2, otherwise we estimate $4 \nabla_\theta \log q_\theta(z) \cdot \log r + 2 \nabla_\theta \log r$. So, internally, you multiply the log_prob terms by 2 (in the model and in the guide), which updates $\log r$ but leaves $\log q_\theta(z)$ as is.
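As a toy, self-contained illustration of that bookkeeping (my own sketch, not Pyro's internal code):

import torch

theta = torch.tensor(0.5, requires_grad=True)   # a guide parameter
z = torch.tensor(1.3)                            # a non-reparameterized sample from q_theta
log_q = -0.5 * (z - theta) ** 2                  # log q_theta(z), up to a constant
log_p = -0.5 * z ** 2 - 0.5 * (2.0 - z) ** 2     # log p(x, z) for an observation x = 2
scale = 2.0                                      # e.g. we kept only half of the terms

# The scaled log_prob terms (model and guide) enter log_r ...
log_r = scale * log_p - scale * log_q
# ... but the score_function term log_q is left unscaled in the surrogate loss.
surrogate = log_q * log_r.detach() + log_r
surrogate.backward()
print(theta.grad)  # = scale * (grad(log_q) * (log_p - log_q) + grad(log_p - log_q)): scaled once, not twice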
As far as I understand, scaling and masking are processed together in scale_and_mask, but I think they obey different rules. I'll clarify this in a separate comment. For now I just want to mention that the tests seem to pass when you mask the score_function terms but do not scale them.
- score_function = self.score_function # not scaled
+ score_function = scale_and_mask(self.score_function, 1.0, mask) # not scaled
I believe this is because there are no tests that check the gradient when masking is active. I'll see if I can write some for this case, along the lines of those in test_subsample_gradient.
@gui11aume yes, your explanation sounds right. I'm sorry I don't recall why masking behaves differently from scaling; I vaguely recall there was a reason, but I forget whether that reason was due to deep mathematics or merely incidental complexity in our implementation. Additional tests would be great!
Thanks for confirming @fritzo. Below is a very long post; I don't expect anyone to read it. It is mostly here for reference, to keep track of my rationale.
After some thinking, my opinion is that it should be allowed to have sites in the model but not in the guide (the bad_sites), because some cases are fully legitimate. To build one, let us start with a typical example with a global Gaussian parameter and local Gaussian observations (no missing observations or masking for now).
import pyro
import pyro.distributions as dist
import torch

def model(data):
    z = pyro.sample("z", dist.Normal(0, 1))
    with pyro.plate("data", len(data)):
        pyro.sample("x", dist.Normal(z, 1), obs=data)
    return

def guide(data):
    loc = pyro.param("loc", lambda: torch.tensor([0.]))
    scale = pyro.param("scale", lambda: torch.tensor([1.]))
    z_dist = dist.Normal(loc, scale)
    pyro.sample("z", z_dist)

svi = pyro.infer.SVI(model, guide, pyro.optim.Adam({"lr": 0.01}), pyro.infer.Trace_ELBO())
for step in range(2000):
    svi.step(torch.tensor([2.]))

print(pyro.param("loc"), pyro.param("scale"))
# tensor([0.9798], requires_grad=True) tensor([0.7382], requires_grad=True)
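For reference, this is a conjugate model, so the posterior can be checked by hand (prior $N(0,1)$, likelihood $N(z,1)$, a single observation $x$):

$$p(z \mid x) \propto \exp\!\left(-\tfrac{1}{2}z^2\right)\exp\!\left(-\tfrac{1}{2}(x-z)^2\right) \propto \exp\!\left(-\frac{(z - x/2)^2}{2 \cdot \tfrac{1}{2}}\right).$$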
The inference is correct: the posterior distribution is $N(\mu=x/2, \sigma=1/\sqrt{2})$, i.e., $\mu = 1$ and $\sigma \approx 0.707$ for $x = 2$, matching the fitted loc and scale above. Now, say that some observations are missing. We just need to add an obs_mask argument to the sample statement for x.
import pyro
import pyro.distributions as dist
import torch

def model(data, mask):
    z = pyro.sample("z", dist.Normal(0, 1))
    with pyro.plate("data", len(data)):
        pyro.sample("x", dist.Normal(z, 1), obs=data, obs_mask=mask)
    return

def guide(data, mask):
    loc = pyro.param("loc", lambda: torch.tensor([0.]))
    scale = pyro.param("scale", lambda: torch.tensor([1.]))
    z_dist = dist.Normal(loc, scale)
    pyro.sample("z", z_dist)

svi = pyro.infer.SVI(model, guide, pyro.optim.Adam({"lr": 0.01}), pyro.infer.Trace_ELBO())
for step in range(2000):
    svi.step(torch.tensor([2., 0.]), torch.tensor([True, False]))

print(pyro.param("loc"), pyro.param("scale"))
# (...)/pyro/util.py:303: UserWarning: Found vars in model but not guide: {'x_unobserved'}
# tensor([0.9631], requires_grad=True) tensor([0.7807], requires_grad=True)
We get a warning but the inference is correct. In the first model, we have only one observation (x = 2.0). In the second model, we have two observations, but only the first is known (still x = 2.0). The two cases are indistinguishable. Notice how, with masking, some log-likelihood terms have to be removed entirely, which shifts the expected value of the ELBO. This is an important difference from scaling.
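In symbols (my notation, not Pyro's): scaling a random subset $B$ of the $N$ likelihood terms by $N/|B|$ preserves the expected value of the full sum, whereas masking with $m_i \in \{0,1\}$ simply drops terms:

$$\mathbb{E}_B\!\left[\frac{N}{|B|}\sum_{i\in B}\log p(x_i\mid z)\right] = \sum_{i=1}^{N}\log p(x_i\mid z), \qquad \sum_{i=1}^{N} m_i\log p(x_i\mid z) = \sum_{i:\,m_i=1}\log p(x_i\mid z).$$

The first estimator targets the original ELBO; the second targets the ELBO of a model with fewer likelihood terms.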
Moving on: if a masked variable has no rsample (it cannot be reparametrized to obtain pathwise derivatives), then Pyro crashes. Below we artificially set has_rsample to False for the Gaussian to force Pyro to use the REINFORCE estimator.
import pyro
import pyro.distributions as dist
import torch

def model(data, mask):
    z = pyro.sample("z", dist.Normal(0, 1))
    with pyro.plate("data", len(data)):
        pyro.sample("x", dist.Normal(z, 1), obs=data, obs_mask=mask)
    return

def guide(data, mask):
    loc = pyro.param("loc", lambda: torch.tensor([0.]))
    scale = pyro.param("scale", lambda: torch.tensor([1.]))
    z_dist = dist.Normal(loc, scale)
    z_dist.has_rsample = False  # <== Pretend we cannot use the reparametrization trick.
    pyro.sample("z", z_dist)

svi = pyro.infer.SVI(model, guide, pyro.optim.Adam({"lr": 0.01}), pyro.infer.Trace_ELBO())
for step in range(2000):
    svi.step(torch.tensor([2., 0.]), torch.tensor([True, False]))

print(pyro.param("loc"), pyro.param("scale"))
# (...)/pyro/util.py:303: UserWarning: Found vars in model but not guide: {'x_unobserved'}
# warnings.warn(f"Found vars in model but not guide: {bad_sites}")
# Traceback (most recent call last):
# File "tmp3.py", line 20, in <module>
# svi.step(torch.tensor([2., 0.]), torch.tensor([True, False]))
# File "(...)/pyro/infer/svi.py", line 145, in step
# loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
# File "(...)/pyro/infer/trace_elbo.py", line 141, in loss_and_grads
# loss_particle, surrogate_loss_particle = self._differentiable_loss_particle(
# File "(...)/pyro/infer/trace_elbo.py", line 106, in _differentiable_loss_particle
# log_r = _compute_log_r(model_trace, guide_trace)
# File "(...)/pyro/infer/trace_elbo.py", line 27, in _compute_log_r
# log_r_term = log_r_term - guide_trace.nodes[name]["log_prob"]
# KeyError: 'x_unobserved'
The code should behave the same as before: in theory, masking has nothing to do with rsample. I found three places where the code can fail: pyro/infer/trace_elbo.py:27, pyro/infer/trace_mean_field_elbo.py:112 and pyro/infer/tracegraph_elbo.py:221. In all three cases, I think the solution is to do nothing when a site of the model is not found in the guide (tests are on the way).
# https://github.com/pyro-ppl/pyro/blob/0e82cad30f75b892a07e6c9a5f9e24f2cb5d0d81/pyro/infer/trace_elbo.py#L20C1-L29C17
def _compute_log_r(model_trace, guide_trace):
    log_r = MultiFrameTensor()
    stacks = get_plate_stacks(model_trace)
    for name, model_site in model_trace.nodes.items():
        if model_site["type"] == "sample":
            log_r_term = model_site["log_prob"]
            if not model_site["is_observed"]:
                log_r_term = log_r_term - guide_trace.nodes[name]["log_prob"]  # <== This can fail.
            log_r.add((stacks[name], log_r_term.detach()))
    return log_r
# https://github.com/pyro-ppl/pyro/blob/0e82cad30f75b892a07e6c9a5f9e24f2cb5d0d81/pyro/infer/trace_mean_field_elbo.py#L107C1-L114C63
for name, model_site in model_trace.nodes.items():
    if model_site["type"] == "sample":
        if model_site["is_observed"]:
            elbo_particle = elbo_particle + model_site["log_prob_sum"]
        else:
            guide_site = guide_trace.nodes[name]  # <== This can fail.
            if is_validation_enabled():
                check_fully_reparametrized(guide_site)
# https://github.com/pyro-ppl/pyro/blob/0e82cad30f75b892a07e6c9a5f9e24f2cb5d0d81/pyro/infer/tracegraph_elbo.py#L217C1-L223C66
# construct all the reinforce-like terms.
# we include only downstream costs to reduce variance
# optionally include baselines to further reduce variance
for node, downstream_cost in downstream_costs.items():
    guide_site = guide_trace.nodes[node]  # <== This can fail.
    downstream_cost = downstream_cost.sum_to(guide_site["cond_indep_stack"])
    score_function = guide_site["score_parts"].score_function
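For the first of these, a fail-safe version could look something like the sketch below (only one possible reading of "do nothing", namely skipping the guide correction when the site is absent; this is a sketch, not an actual patch):

def _compute_log_r(model_trace, guide_trace):
    log_r = MultiFrameTensor()
    stacks = get_plate_stacks(model_trace)
    for name, model_site in model_trace.nodes.items():
        if model_site["type"] == "sample":
            log_r_term = model_site["log_prob"]
            # Only subtract the guide log_prob when the site exists in the guide,
            # e.g. skip an auto-generated "x_unobserved" site missing from the guide.
            if not model_site["is_observed"] and name in guide_trace.nodes:
                log_r_term = log_r_term - guide_trace.nodes[name]["log_prob"]
            log_r.add((stacks[name], log_r_term.detach()))
    return log_r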
Let us go back to the second case, where the reparametrization trick is available, and let us try to infer the distribution of the missing value of x. As we have seen, we need to add to the guide a sample called x_unobserved. We also wrap it in a poutine.mask whose mask is the complement of the observation mask, because we have to sample values for the observed positions even though we do not need them. This way, the dummy samples have no influence at all.
import pyro
import pyro.distributions as dist
import torch

def model(data, mask):
    z = pyro.sample("z", dist.Normal(0, 1))
    with pyro.plate("data", len(data)):
        pyro.sample("x", dist.Normal(z, 1), obs=data, obs_mask=mask)
    return

def guide(data, mask):
    loc = pyro.param("loc", lambda: torch.tensor([0.]))
    scale = pyro.param("scale", lambda: torch.tensor([1.]))
    z_dist = dist.Normal(loc, scale)
    pyro.sample("z", z_dist)
    with pyro.plate("data", len(data)):
        loc_x = pyro.param("loc_x", lambda: torch.tensor([0., 0.]))
        scale_x = pyro.param("scale_x", lambda: torch.tensor([1., 1.]))
        with pyro.poutine.mask(mask=~mask):
            pyro.sample("x_unobserved", dist.Normal(loc_x, scale_x))

svi = pyro.infer.SVI(model, guide, pyro.optim.Adam({"lr": 0.01}), pyro.infer.Trace_ELBO())
for step in range(2000):
    svi.step(torch.tensor([2., 0.]), torch.tensor([True, False]))

print(pyro.param("loc_x"), pyro.param("scale_x"))
# tensor([0.0000, 0.9787], requires_grad=True) tensor([1.0000, 1.0632], requires_grad=True)
As expected, the warning is gone. We also see that the parameters for the observed values are exactly as they were initialized, meaning that for these values the gradient was 0 throughout, as expected. This is the behavior that should also be achieved when the variables have no rsample. It can be obtained with the update mentioned above, where the score_function terms are masked but not scaled.
- score_function = self.score_function # not scaled
+ score_function = scale_and_mask(self.score_function, 1.0, mask) # not scaled
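For reference, here is a rough sketch of the kind of check such a gradient test could perform (hypothetical, and not the tests from the pull request): after a single gradient computation, the guide parameters attached to masked-out positions should have zero gradient.

import torch
import pyro
import pyro.distributions as dist
from pyro.distributions import constraints

def model(data, mask):
    z = pyro.sample("z", dist.Normal(0., 1.))
    with pyro.plate("data", len(data)):
        pyro.sample("x", dist.Normal(z, 1.), obs=data, obs_mask=mask)

def guide(data, mask):
    loc = pyro.param("loc", lambda: torch.tensor([0.]))
    scale = pyro.param("scale", lambda: torch.tensor([1.]), constraint=constraints.positive)
    pyro.sample("z", dist.Normal(loc, scale))
    with pyro.plate("data", len(data)):
        loc_x = pyro.param("loc_x", lambda: torch.zeros(len(data)))
        with pyro.poutine.mask(mask=~mask):
            pyro.sample("x_unobserved", dist.Normal(loc_x, 1.))

data = torch.tensor([2., 0.])
mask = torch.tensor([True, False])
pyro.clear_param_store()
pyro.infer.Trace_ELBO().loss_and_grads(model, guide, data, mask)
grad = pyro.param("loc_x").unconstrained().grad
assert grad[0] == 0.  # the dummy sample at the observed position has no effect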
I have written some new tests and I will open a pull request draft shortly.
I have opened pull request https://github.com/pyro-ppl/pyro/pull/3265 with the changes discussed above and some new tests involving gradients with poutine.mask.
Summary
Inference for partially observed discrete variables occasionally produces counter-intuitive results. Those are not bugs, but users may waste a lot of time dealing with them or trying to understand them. The behavior has been tested on Pyro 1.8.5 and 1.8.6.
A simple example with coins
The example below is meant to show in which kind of context the issues appear. It is artificial and has no practical application, but it is inspired by real examples I stumbled upon. In the model, we flip a fair coin and do not show the result; if it lands 'heads' we flip a coin with bias 0.05; if it lands 'tails' we flip a coin with bias 0.95. We always observe the result of the biased coin (but not which coin was flipped). In the guide, we simply sample the unbiased coin.
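The original snippet is not reproduced here; the following is a hedged reconstruction of the kind of code described, using the site and parameter names (unbiased, obs, post_p) that appear in the discussion below:

import torch
import pyro
import pyro.distributions as dist
from pyro.distributions import constraints

def model(obs):
    # Hidden fair coin: 0 = 'heads', 1 = 'tails'.
    unbiased = pyro.sample("unbiased", dist.Bernoulli(0.5))
    p = 0.05 + 0.90 * unbiased           # bias 0.05 after 'heads', 0.95 after 'tails'
    pyro.sample("obs", dist.Bernoulli(p), obs=obs)

def guide(obs):
    post_p = pyro.param("post_p", torch.tensor(0.5), constraint=constraints.unit_interval)
    pyro.sample("unbiased", dist.Bernoulli(post_p))

svi = pyro.infer.SVI(model, guide, pyro.optim.Adam({"lr": 0.01}), pyro.infer.Trace_ELBO())
for step in range(5000):
    svi.step(torch.tensor(1.))           # the biased coin landed 'tails'
print(pyro.param("post_p"))              # should approach 0.95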
The result is correct: the second coin landed 'tails', so the posterior probability that the unbiased coin also landed 'tails' is 0.95.
Issue 1: Code failure when masking
If the second coin is sometimes observed, we can introduce an observation mask for the obs sample. Let us modify the code and run the same example, i.e., we specify that the second coin landed 'tails' and that this is observed; a sketch of the modification is shown below, and running it fails.
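Concretely, under the same assumed names as the sketch above, the change amounts to something like:

import torch
import pyro
import pyro.distributions as dist
from pyro.distributions import constraints

def model(obs, obs_is_observed):
    unbiased = pyro.sample("unbiased", dist.Bernoulli(0.5))
    p = 0.05 + 0.90 * unbiased
    # obs_mask makes Pyro create an auxiliary "obs_unobserved" site behind the scenes.
    pyro.sample("obs", dist.Bernoulli(p), obs=obs, obs_mask=obs_is_observed)

def guide(obs, obs_is_observed):
    post_p = pyro.param("post_p", torch.tensor(0.5), constraint=constraints.unit_interval)
    pyro.sample("unbiased", dist.Bernoulli(post_p))

svi = pyro.infer.SVI(model, guide, pyro.optim.Adam({"lr": 0.01}), pyro.infer.Trace_ELBO())
svi.step(torch.tensor(1.), torch.tensor(True))   # fails with a KeyError for 'obs_unobserved'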
First there is a warning about a missing site in the guide, and then a KeyError for the same reason. This is counter-intuitive: either guides with missing sites should be allowed (warn only), or they should not (raise an error for every guide with missing sites).

The error comes from the part of the code that evaluates the loss using the REINFORCE estimator (i.e., when the reparametrization trick cannot be used, as is the case for discrete random variables). Line 27 in trace_elbo.py assumes that every unobserved site in the model also exists in the guide. The user may not be aware that the site obs_unobserved is created in the model (but not in the guide) as soon as the argument obs_mask is not None.

The solution is to define the sample obs_unobserved in the guide (see how below), but there are barely any mentions of this, so we cannot assume that users will do it. If guides with missing sites are allowed, line 27 in trace_elbo.py should be replaced with a fail-safe version. Ideally, a message could point users in the right direction when Pyro creates an _unobserved site that is not in the guide.

Issue 2: Counter-intuitive gradient
Now, if the unbiased coin is sometimes observed, we can introduce an observation mask for the unbiased sample, together with some observations when they are available. As mentioned above, we need to add a site in the guide called unbiased_unobserved explaining what to do when the coin is not observed (i.e., sample it as we were doing until now). We have to sample the whole tensor; Pyro will automatically mix in observed and sampled values for us as needed.

Some values in unbiased_unobserved are sampled for nothing: they will be replaced with the observed values when they are available. In this case, the sampled values have no effect on the inference, but just to be sure, we are going to mask them in the guide to set their log_prob terms to 0. We do this by using poutine.mask, where we invert the observation mask with ~mask. A sketch of the resulting model and guide is shown below.
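This is again a hedged reconstruction of the kind of code described (the exact original snippet is not reproduced here, and the observed values are only placeholders):

import torch
import pyro
import pyro.distributions as dist
from pyro.distributions import constraints

def model(obs, unbiased_obs, unbiased_mask):
    with pyro.plate("flips", len(obs)):
        unbiased = pyro.sample(
            "unbiased", dist.Bernoulli(0.5), obs=unbiased_obs, obs_mask=unbiased_mask
        )
        p = 0.05 + 0.90 * unbiased
        pyro.sample("obs", dist.Bernoulli(p), obs=obs)

def guide(obs, unbiased_obs, unbiased_mask):
    with pyro.plate("flips", len(obs)):
        post_p = pyro.param(
            "post_p", lambda: torch.full((len(obs),), 0.5),
            constraint=constraints.unit_interval,
        )
        with pyro.poutine.mask(mask=~unbiased_mask):
            pyro.sample("unbiased_unobserved", dist.Bernoulli(post_p))

obs = torch.tensor([1., 0., 1.])            # results of the biased coin (placeholder values)
unbiased_obs = torch.tensor([1., 0., 0.])   # first two flips observed; third entry is an unused placeholder ('heads' = 0)
unbiased_mask = torch.tensor([True, True, False])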
In this example, we observed the first two flips of the unbiased coin, but not the third. We set the missing value to 'heads' (0), but this is irrelevant because the value is never used during the inference. The inference is correct for the third flip, and there is nothing to infer for the first two flips because their values were observed... So why did the values of post_p change from the initial 0.5, and what do the current values represent?
As far as I understand, the values have no special meaning. There is nothing to infer anyway. So why did they change? Once again, this has to do with the way Pyro evaluates the loss using the REINFORCE estimator. Internally, it keeps track of a log_prob term and a score_function term for the sites of the guide. The log_prob terms are masked but not the score_function terms, so all the values of unbiased_unobserved contribute to the gradient, even those that are overwritten by observed values.

I don't think that this has side effects, so it is not really a bug. The issue is that Pyro is difficult enough to debug, and erratic behaviors make it harder. It would help if parameters that have no effect on the inference had a gradient of 0, so that the user is alerted when there is an error in the model (e.g., when values that should have no effect on the inference do in fact have an effect).