AlexanderAivazidis opened this issue 2 years ago
Hi @AlexanderAivazidis, thanks for pointing out this bug. I wonder if the issue is with .expand()? Could you try removing the .expand() (instead expanding the arguments before creating the distribution)? And could you check whether the result of .expand() is still a RelaxedBernoulliStraightThrough object, or whether it has perhaps been type-changed to a RelaxedBernoulli object?
Hi,
after removing the expand I still get the same problems and don't see discrete samples anywhere during training:
import pyro
import pyro.distributions as dist
from pyro.distributions import RelaxedBernoulliStraightThrough

RelaxedBernoulliStraightThrough.mean = property(lambda self: self.probs)

# Excerpt from the model method: the self.* attributes and obs_plate are defined elsewhere in the class.
p_m = pyro.sample('p_m', dist.Beta(self.activation_probability_alpha,
                                   self.activation_probability_beta
                                   ).expand([1, 1, self.n_modules]).to_event(3))
with obs_plate:
    I_cm = pyro.sample('I_cm',
                       RelaxedBernoulliStraightThrough(probs=p_m,
                                                       temperature=self.one / 1000.))
    print('I_cm', I_cm)
    I_cm_tracking = pyro.deterministic('I_cm_tracking', I_cm)
I have also made this minimum example with the RelaxedBernoulliStraightThrough used in a Gaussian Mixture Model:
https://github.com/AlexanderAivazidis/Minimum-Example/blob/main/RelaxedBernoulliMinimalExample.ipynb
Best wishes,
Alexander
We are using AutoHierarchicalNormalMessenger/AutoNormalMessenger guides. Could it be that RelaxedBernoulliStraightThrough needs a custom guide for this variable that does hard sampling (rather than the usual correct_transform(Normal sample))? Could this be the source of the problem, and a custom guide the solution?
@vitkl i don't know all the details of how AutoHierarchicalNormalMessenger/AutoNormalMessenger are constructed, but basically yes. maybe you can block the bernoulli sites when creating your autoguide and use an AutoGuideList etc.
@karalets do you recall your return value assumptions on RelaxedBernoulliStraightThrough?
I am considering trying something like this:
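from pyro import poutine
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoGuideList, AutoHierarchicalNormalMessenger, AutoDiscreteParallel

guide = AutoGuideList(model)
guide.append(AutoHierarchicalNormalMessenger(poutine.block(model, hide=["bernoulli_site"])))
guide.append(AutoDiscreteParallel(poutine.block(model, expose=["bernoulli_site"])))
svi = SVI(model, guide, optim, Trace_ELBO())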
But I wonder what would be the correct guide for the "bernoulli_site"? Here I have put in AutoDiscreteParallel, just to have a complete example.
@AlexanderAivazidis depending on details (e.g. how many sites with bernoulli latent variables you have) it's probably easier to build the "discrete" guide by hand; afaik it doesn't need to be an AutoGuide, it just needs to be callable (e.g. a function).
Hi @martinjankowiak
By manually building a discrete guide, do you mean something like the following?
import torch
import pyro
import pyro.distributions as dist
from pyro.distributions import constraints
from pyro.infer.autoguide import AutoNormalMessenger

def model(data):
    a = pyro.sample("a", dist.Normal(0, 1))
    b = pyro.sample("b", dist.Normal(0, 1))
    # RelaxedBernoulliStraightThrough takes a temperature plus probs (or logits).
    c = pyro.sample("c", dist.RelaxedBernoulliStraightThrough(torch.tensor(1.0), probs=torch.tensor(0.5)))
    pyro.sample("obs", dist.Normal((a + b) * c, 1), obs=data)

class MyGuideMessenger(AutoNormalMessenger):
    def get_posterior(self, name, prior):
        if name == "c":
            # Use a custom distribution at site c.
            loc = pyro.param("c_loc", lambda: torch.zeros(()))
            scale = pyro.param("c_scale", lambda: torch.ones(()),
                               constraint=constraints.positive)
            return transform_that_yeilds_discrete_samples(dist.Normal(loc, scale))
        # Fall back to mean field.
        return super().get_posterior(name, prior)
What would you use in place of transform_that_yeilds_discrete_samples?
@vitkl i believe you can do something like this
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoGuideList, AutoHierarchicalNormalMessenger
from pyro.optim import ClippedAdam
from pyro.poutine import block

def model():
    pyro.sample("z", dist.Normal(0.0, 1.0))
    pyro.sample("b", dist.Bernoulli(probs=0.5))

def b_guide():
    pyro.sample("b", dist.RelaxedBernoulliStraightThrough(probs=0.5, temperature=torch.tensor(1.0)))

guide = AutoGuideList(model)
guide.append(AutoHierarchicalNormalMessenger(block(model, hide=["b"])))
guide.append(b_guide)
svi = SVI(model, guide, ClippedAdam({"lr": 0.001}), Trace_ELBO())
svi.step()
I see! So using Bernoulli in the model and RelaxedBernoulliStraightThrough in the guide.
Do you think it is necessary to use AutoGuideList, not the approach below?
class MyGuideMessenger(AutoNormalMessenger):
    def get_posterior(self, name, prior):
        if name == "b":
            b_probs = pyro.param("b_probs", lambda: torch.tensor(0.5), constrains=...)
            dist.RelaxedBernoulliStraightThrough(probs=b_probs, temperature=torch.tensor(1.0))
        return super().get_posterior(name, prior)
@vitkl i'm not very familiar with the internals of AutoNormalMessenger, but something like that (but returning a distribution in the if branch) may be ok too. if it runs and if you trace the guide and get the samples you expect, it's probably ok.
@fritzo would have a better idea.
We could try to get my minimal Gaussian mixture model example to work:
https://github.com/AlexanderAivazidis/Minimum-Example/blob/main/RelaxedBernoulliMinimalExample.ipynb (I made sure the notebook can be rendered this time)
I have added the guide that @martinjankowiak suggested there, but so far I classify samples with only 50% accuracy. So no learning seems to take place and I must have done something wrong. But I do see discrete samples during training. I am wondering also how this kind of guide can take into account prior probabilities for the Bernoulli sampled from another distribution, like "ps" in the minimum example model:
import torch
import pyro
import pyro.distributions as dist

def model(data, n_components, n_observations, i):
    ps = pyro.sample('ps', dist.Dirichlet(torch.ones(n_components)/10.))
    mus = pyro.sample('mus', dist.Gamma(20,2).expand([n_components]).to_event(1))
    b = pyro.sample('b', dist.Bernoulli(probs = ps[0]).expand([n_observations]).to_event(1))
    mean = mus[0] + mus[1]*b
    pyro.sample("data_target", dist.Normal(loc = mean, scale = torch.tensor(1.)).to_event(1), obs = data)
@vitkl, do you think it is enough to override the get_posterior method? Maybe this does not change what the guide does during training (unless get_posterior is also called during training).
@AlexanderAivazidis i think it's important to distinguish between something being formally correct (code does the sequence of expected mathematical operations) and actually working (the code converges to an acceptable solution after appropriate hyperparameters have been chosen, enough steps have been done, etc). for a sufficiently difficult stochastic optimization you can have a formally "correct" algorithm that still doesn't work because of bad hyperparameters etc. e.g. in the case of these relaxed distributions the computed gradients are biased and potentially high-variance, which can lead to all kinds of issues
in your case i think the main problem is that you're trying to be too bayesian by putting priors on ps and mus. if you make these parameters i seem to get reasonable results:
import torch
import pyro
import pyro.distributions as dist
from pyro.distributions import constraints

def model(data, n_components, n_observations, i):
    ps = pyro.param('ps', torch.ones(n_components)/10., constraint=constraints.simplex)
    mus = pyro.param('mus', torch.ones(n_components), constraint=constraints.positive)
    b = pyro.sample('b', dist.Bernoulli(probs = ps[0]).expand([n_observations]).to_event(1))
    mean = mus[0] + mus[1]*b
    pyro.sample("data_target", dist.Normal(loc = mean, scale = torch.tensor(1.)).to_event(1), obs = data)
can we take a step back? what is your actual goal with these bernoulli latent variables? what do they encode? are they cell-level local latent variables? something else? do they encode component membership in a mixture distribution? if the latter why can't you use a mixture distribution directly?
I see your point, that we expect the RelaxedBernoulliStraightThrough distribution to give worse results than exact enumeration for example. And it could be that the results we get are the fundamental performance limit of this kind of setup. I have tried to understand what "acceptable" performance would mean by solving the Minimal Gaussian Mixture Problem also with the RelaxedBernoulli and with Bernoulli+exact enumeration. With exact enumeration we assign 98% of points to the right component, see here:
https://github.com/AlexanderAivazidis/Minimum-Example/blob/main/MinimalExample_Enumeration.ipynb
If we round the RelaxedBernoulli posterior parameters to 0 and 1, we assign 82% to the right component:
My expectation is that the RelaxedBernoulliStraightThrough distribution should classify between 82% and 98% correct, do you agree? And I would also expect that this value increases with lower temperature, if we always train to convergence. @vitkl is trying to implement hierarchical guides for the RelaxedBernoulliStraightThrough, based on our discussion. He will probably give an update at some point, but I think so far he has seen worse performance than 82%. @martinjankowiak did you get better performance than 82%? Do you agree with how we investigate this problem?
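The accuracy figures above compare rounded assignment probabilities with the true component labels. A small helper for that metric is sketched below. Rounding at 0.5 and taking the better of the two labelings (in case the components come out swapped) are assumptions about how the comparison is done, and the function name is hypothetical.

```python
import torch


def assignment_accuracy(b_probs: torch.Tensor, true_labels: torch.Tensor) -> float:
    """Fraction of observations assigned to the right component after rounding at 0.5.

    Returns the better of the two labelings, in case the components are swapped.
    """
    hard = (b_probs > 0.5).to(true_labels.dtype)
    acc = (hard == true_labels).float().mean().item()
    return max(acc, 1.0 - acc)


b_probs = torch.tensor([0.9, 0.2, 0.7, 0.4])
labels = torch.tensor([1.0, 0.0, 1.0, 1.0])
print(assignment_accuracy(b_probs, labels))  # 0.75
```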
I would be happy to explain my model in more detail if necessary, but it would take more time. For now, let's just say that the Bernoulli variables are indeed cell-level local latent variables. There are typically tens to hundreds of these variables, but usually only 2 or 3 are "ON" (= 1) in a given cell. So we are essentially trying to infer the correct combination of latent variables for each cell. I found that I cannot use enumeration/mixture distributions, because it is computationally too expensive to account for all possible combinations of latent variable states. And I cannot use the RelaxedBernoulli, because even small values above 0, when the variable should really be 0, result in completely wrong inference of the other variables in my model. Even though this sounds like an almost impossible optimization problem, I am quite confident it should work in practice, because I can use an informative initialisation for the Bernoulli variables that will be very close to the true result.
If we get "acceptable" performance from the RelaxedBernoulliStraightThrough distribution, I would be happy to write this up as a tutorial for the Pyro community. I would also try to include examples where the RelaxedBernoulliStraightThrough is essential, versus others where the RelaxedBernoulli or enumeration works fine.
Hi,
I have used the RelaxedBernoulliStraightThrough distribution in my model in a code block like this:
However, I do not see discrete samples in any of the following:
1.) the posterior samples for I_cm
2.) the posterior samples for I_cm_tracking when I add
I_cm_tracking = pyro.deterministic('I_cm_tracking', I_cm)
3.) the printed output during training when I add:
print(I_cm)
So I wonder, did you test that RelaxedBernoulliStraightThrough indeed gives discrete samples in the forward pass during training?
Thanks!
Alexander