pyro-ppl / pyro

Deep universal probabilistic programming with Python and PyTorch
http://pyro.ai
Apache License 2.0

RelaxedBernoulliStraightThrough seems to give continuous samples when used in conjunction with AutoHierarchicalNormalMessenger/AutoNormalMessenger #3127

Open AlexanderAivazidis opened 2 years ago

AlexanderAivazidis commented 2 years ago

Hi,

I have used the RelaxedBernoulliStraightThrough distribution in my model in a code block like this:

        with obs_plate:
            I_cm = pyro.sample('I_cm',
                               RelaxedBernoulliStraightThrough(probs=p_m,
                                                               temperature=self.one / 1000.
                                                               ).expand([batch_size, 1, self.n_modules]))

However, I do not see discrete samples in any of:

1.) the posterior samples for I_cm

2.) the posterior samples for I_cm_tracking when I add

I_cm_tracking = pyro.deterministic('I_cm_tracking', I_cm)

3.) the printed output during training when I add:

print(I_cm)

So I wonder: did you test that RelaxedBernoulliStraightThrough indeed gives discrete samples in the forward pass during training?

Thanks!

Alexander

fritzo commented 2 years ago

Hi @AlexanderAivazidis thanks for pointing out this bug. I wonder if the issue is with .expand()? Could you try removing the .expand() (instead expanding the arguments before creating the distribution)? And could you check to see if the result of .expand() is still a RelaxedBernoulliStraightThrough object, or if perhaps it has been type-changed to a RelaxedBernoulli object?
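
For example, a quick standalone check like the sketch below (the probabilities and temperature here are just placeholders) would show both whether the subclass survives .expand() and whether the forward-pass samples are already rounded:

import torch
from pyro.distributions import RelaxedBernoulliStraightThrough

d = RelaxedBernoulliStraightThrough(temperature=torch.tensor(0.1),
                                    probs=torch.tensor([0.3, 0.7]))
print(type(d))                 # the distribution before expanding
print(type(d.expand([5, 2])))  # does .expand() keep the StraightThrough subclass?
print(d.rsample())             # forward-pass samples; should be hard 0./1. values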

AlexanderAivazidis commented 2 years ago

Hi,

After removing the .expand() I still get the same problems and don't see discrete samples anywhere during training:

        from pyro.distributions import RelaxedBernoulliStraightThrough
        RelaxedBernoulliStraightThrough.mean = property(lambda self: self.probs)

        p_m = pyro.sample('p_m', dist.Beta(self.activation_probability_alpha,
                                           self.activation_probability_beta
                                          ).expand([1,1,self.n_modules]).to_event(3))
        with obs_plate:
            I_cm = pyro.sample('I_cm',
                               RelaxedBernoulliStraightThrough(probs = p_m,
                                                               temperature = self.one/1000.))

        print('I_cm', I_cm)

        I_cm_tracking = pyro.deterministic('I_cm_tracking', I_cm)

I have also made this minimal example with the RelaxedBernoulliStraightThrough used in a Gaussian mixture model:

https://github.com/AlexanderAivazidis/Minimum-Example/blob/main/RelaxedBernoulliMinimalExample.ipynb

Best wishes,

Alexander

vitkl commented 2 years ago

We are using AutoHierarchicalNormalMessenger/AutoNormalMessenger guides. Could it be that RelaxedBernoulliStraightThrough needs a custom guide for this variable that does hard sampling (rather than the usual correct_transform(Normal sample))? Could this be both the source of the problem and its solution?

martinjankowiak commented 2 years ago

@vitkl i don't know all the details of how AutoHierarchicalNormalMessenger/AutoNormalMessenger are constructed but basically yes. maybe you can block the bernoulli sites when creating your autoguide and use an AutoGuideList etc

fritzo commented 2 years ago

@karalets do you recall your return value assumptions on RelaxedBernoulliStraightThrough?

AlexanderAivazidis commented 2 years ago

I am considering trying something like this:

guide = AutoGuideList(model)
guide.append(AutoHierarchicalNormalMessenger(poutine.block(model, hide=["bernoulli_site"])))
guide.append(AutoDiscreteParallel(poutine.block(model, expose=["bernoulli_site"])))
svi = SVI(model, guide, optim, Trace_ELBO())

But I wonder what would be the correct guide for the "bernoulli_site"? Here I have put in AutoDiscreteParallel, just to have a complete example.

martinjankowiak commented 2 years ago

@AlexanderAivazidis depending on details (e.g. how many sites with bernoulli latent variables you have) it's probably easier to build the "discrete" guide by hand; afaik it doesn't need to be an AutoGuide, it just needs to be callable (e.g. a function)
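
for example, a hand-written guide for a single bernoulli site could look something like the sketch below (the site name "b" and the parameter name "b_probs" are only illustrative):

import torch
import pyro
import pyro.distributions as dist
from pyro.distributions import constraints

# a plain callable guide that covers only the Bernoulli site "b";
# b_probs is the learnable variational parameter
def b_guide():
    b_probs = pyro.param("b_probs", lambda: torch.tensor(0.5),
                         constraint=constraints.unit_interval)
    pyro.sample("b", dist.RelaxedBernoulliStraightThrough(
        temperature=torch.tensor(0.1), probs=b_probs))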

vitkl commented 2 years ago

Hi @martinjankowiak

By manually building a discrete guide, you mean something like the example below?

def model(data):
    a = pyro.sample("a", dist.Normal(0, 1))
    b = pyro.sample("b", dist.Normal(0, 1))
    c = pyro.sample("c", dist.RelaxedBernoulliStraightThrough(0.5))
    pyro.sample("obs", dist.Normal((a+b) * c, 1), obs=data)

class MyGuideMessenger(AutoNormalMessenger):
    def get_posterior(self, name, prior):
        if name == "c":
            # Use a custom distribution at site c.
            loc = pyro.param("c_loc", lambda: torch.zeros(()))
            scale = pyro.param("c_scale", lambda: torch.ones(()),
                               constraint=constraints.positive)
            return transform_that_yields_discrete_samples(dist.Normal(loc, scale))
        # Fall back to mean field.
        return super().get_posterior(name, prior)

What would you use in place of transform_that_yields_discrete_samples?

martinjankowiak commented 2 years ago

@vitkl i believe you can do something like this

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoGuideList, AutoHierarchicalNormalMessenger
from pyro.optim import ClippedAdam
from pyro.poutine import block

def model():
    pyro.sample("z", dist.Normal(0.0, 1.0))
    pyro.sample("b", dist.Bernoulli(probs=0.5))

def b_guide():
    pyro.sample("b", dist.RelaxedBernoulliStraightThrough(probs=0.5, temperature=torch.tensor(1.0)))

guide = AutoGuideList(model)
guide.append(AutoHierarchicalNormalMessenger(block(model, hide=["b"])))
guide.append(b_guide)

svi = SVI(model, guide, ClippedAdam({"lr": 0.001}), Trace_ELBO())
svi.step()

vitkl commented 2 years ago

I see! So using Bernoulli in the model and RelaxedBernoulliStraightThrough in the guide.

Do you think it is necessary to use AutoGuideList, not the approach below?

class MyGuideMessenger(AutoNormalMessenger):
    def get_posterior(self, name, prior):
        if name == "b":
            b_probs = pyro.param("b_probs", lambda: torch.tensor(0.5), constrains=...)
            dist.RelaxedBernoulliStraightThrough(probs=b_probs, temperature=torch.tensor(1.0))
        return super().get_posterior(name, prior)

martinjankowiak commented 2 years ago

@vitkl i'm not very familiar with the internals of AutoNormalMessenger but something like that (but returning a distribution in the if branch) may be ok too. if it runs and if you trace the guide and get the samples you expect it's probably ok.
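
for example (a rough sketch, reusing the model and AutoGuideList guide from the earlier comment):

from pyro import poutine

# run the guide once under a trace and look at the value recorded for site "b";
# with a straight-through sampler the recorded value should already be 0. or 1.
guide_trace = poutine.trace(guide).get_trace()
print(guide_trace.nodes["b"]["value"])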

@fritzo would have a better idea.

AlexanderAivazidis commented 2 years ago

We could try to get my minimal example Gaussian mixture model to work:

https://github.com/AlexanderAivazidis/Minimum-Example/blob/main/RelaxedBernoulliMinimalExample.ipynb (I made sure the notebook can be rendered this time)

I have added the guide that @martinjankowiak suggested there, but so far I classify samples with only 50% accuracy. So no learning seems to take place and I must have done something wrong. But I do see discrete samples during training. I am also wondering how this kind of guide can take into account prior probabilities for the Bernoulli that are sampled from another distribution, like "ps" in the minimal example model:

def model(data, n_components, n_observations, i):

    ps = pyro.sample('ps', dist.Dirichlet(torch.ones(n_components)/10.))
    mus = pyro.sample('mus', dist.Gamma(20,2).expand([n_components]).to_event(1))
    b = pyro.sample('b', dist.Bernoulli(probs = ps[0]).expand([n_observations]).to_event(1))                    

    mean = mus[0] + mus[1]*b

    pyro.sample("data_target", dist.Normal(loc = mean, scale = torch.tensor(1.)).to_event(1), obs = data)

AlexanderAivazidis commented 2 years ago

I see! So using Bernoulli in the model and RelaxedBernoulliStraightThrough in the guide.

Do you think it is necessary to use AutoGuideList, not the approach below?

class MyGuideMessenger(AutoNormalMessenger):
    def get_posterior(self, name, prior):
        if name == "b":
            b_probs = pyro.param("b_probs", lambda: torch.tensor(0.5), constrains=...)
            dist.RelaxedBernoulliStraightThrough(probs=b_probs, temperature=torch.tensor(1.0))
        return super().get_posterior(name, prior)

@vitkl, do you think it is enough to adjust the "get_posterior" method? Maybe this does not change what the guide does during training (unless get_posterior is also called during training).

martinjankowiak commented 2 years ago

@AlexanderAivazidis i think it's important to distinguish between something being formally correct (code does the sequence of expected mathematical operations) and actually working (the code converges to an acceptable solution after appropriate hyperparameters have been chosen, enough steps have been done, etc). for a sufficiently difficult stochastic optimization you can have a formally "correct" algorithm that still doesn't work because of bad hyperparameters etc. e.g. in the case of these relaxed distributions the computed gradients are biased and potentially high-variance, which can lead to all kinds of issues

in your case i think the main problem is that you're trying to be too bayesian by putting priors on ps and mus. if you make these parameters i seem to get reasonable results:

def model(data, n_components, n_observations, i):    
    ps = pyro.param('ps', torch.ones(n_components)/10., constraint=constraints.simplex)
    mus = pyro.param('mus', torch.ones(n_components), constraint=constraints.positive)
    b = pyro.sample('b', dist.Bernoulli(probs = ps[0]).expand([n_observations]).to_event(1))

    mean = mus[0] + mus[1]*b

    pyro.sample("data_target", dist.Normal(loc = mean, scale = torch.tensor(1.)).to_event(1), obs = data)

can we take a step back? what is your actual goal with these bernoulli latent variables? what do they encode? are they cell-level local latent variables? something else? do they encode component membership in a mixture distribution? if the latter why can't you use a mixture distribution directly?

AlexanderAivazidis commented 2 years ago

I see your point: we expect the RelaxedBernoulliStraightThrough distribution to give worse results than, for example, exact enumeration. And it could be that the results we get are the fundamental performance limit of this kind of setup. To understand what "acceptable" performance would mean, I have also solved the minimal Gaussian mixture problem with the RelaxedBernoulli and with Bernoulli + exact enumeration. With exact enumeration we assign 98% of points to the right component, see here:

https://github.com/AlexanderAivazidis/Minimum-Example/blob/main/MinimalExample_Enumeration.ipynb

If we round the RelaxedBernoulli posterior parameters to 0 and 1, we assign 82% to the right component:

https://github.com/AlexanderAivazidis/Minimum-Example/blob/main/MinimalExample_RelaxedBernoulli.ipynb

My expectation is that the RelaxedBernoulliStraightThrough distribution should classify between 82% and 98% of points correctly. Do you agree? I would also expect this value to increase with lower temperature, if we always train to convergence. @vitkl is trying to implement hierarchical guides for the RelaxedBernoulliStraightThrough, based on our discussion. He will probably give an update at some point, but I think so far he has seen worse performance than 82%. @martinjankowiak, did you get better performance than 82%? Do you agree with how we are investigating this problem?

I would be happy to explain my model in more detail if necessary, but it would take more time. For now let's just say that the Bernoulli variables are indeed cell-level local latent variables. Typically there are 10s to 100s of these variables, but only 2 or 3 are "ON" (= 1) in a cell. So we are essentially trying to infer the correct combination of latent variable states for each cell. I found that I cannot use enumeration/mixture distributions, because it is computationally too expensive to account for all possible combinations of latent variable states. And I cannot use the RelaxedBernoulli, because even small values above 0 - when the variable should really be 0 - result in completely wrong parameter inference for other variables in my model. Even though this sounds like an almost impossible optimization problem, I am quite confident it should work in practice, because I can use an informative initialisation for the Bernoulli variables that will be very close to the true result.

If we get "acceptable" performance of the RelaxedBernoulliStraightThrough distribution I would be happy to write this up into a tutorial for the pyro community. I would also try to include such examples where the RelaxedBernoulliStraightThrough is essential vs others where the Relaxed Bernoullis is fine or enumeration works fine.