pyro-ppl / pyro

Deep universal probabilistic programming with Python and PyTorch
http://pyro.ai
Apache License 2.0

Mixture models converging to a single mode #635

Closed kjchalup closed 6 years ago

kjchalup commented 6 years ago

Hi, thank you for developing this amazing project! I would really appreciate some help in clarifying how to define a GMM in Pyro.

I've been trying to implement various mixture models in Pyro, and in all cases they seem to do something unexpected.
1) It would be amazing if you guys added a simple Gaussian Mixture Model to the tutorials. It is a model many people are familiar with, so it would clarify a lot about how to use Pyro.
2) Below is a working example of a very simple "GMM" I tried to implement. For simplicity, I keep the component probabilities and sigmas constant and try to learn only the means of the Gaussians.

The true component means are -2 and 4. Even though I initialize training from that global optimum, the SVI optimizer pulls both means towards the value of 1, which is the mean of the two modes. Is this a problem with the variational approach, pyro, or am I doing something wrong?

import numpy as np
import pyro
import pyro.infer
import pyro.optim
import pyro.distributions as dist
import torch
from torch.utils.data import TensorDataset, DataLoader
from torch.autograd import Variable

def make_data(n_data=1000):
    """ Sample n_data datapoints from a 1-d mixture of Gaussians. """
    p = [.5, .5]
    means = np.random.choice([-2, 4], n_data, p=p)
    x = torch.Tensor(np.random.randn(n_data) + means).view(n_data, 1)
    # Use TensorDataset, but only make use of the 'data_tensor' param.
    return DataLoader(TensorDataset(data_tensor=x, target_tensor=x),
        batch_size=n_data, shuffle=True)

def model(data):
    x = Variable(data[0]).squeeze()
    n_data = x.size()[0]

    # Define p_{theta}(z).
    mu = pyro.param('mu', Variable(torch.Tensor([-2, 4]), requires_grad=True))
    sigma = Variable(torch.ones(n_data)) # Constant variances.
    ps = Variable(torch.Tensor([.5, .5])) # Constant component probabilities.
    zs = pyro.sample('zs', dist.categorical, ps, batch_size=n_data)

    # Define p_{theta}(x | z).
    comp_ids = torch.max(zs, 1)[1] # Choose a component for each datapoint.
    mus = mu[comp_ids] # Assign the appropriate mean to each datapoint.
    pyro.observe('obs', dist.normal, x, mus, sigma) # Condition on data.

    return mu.data.numpy() # We will monitor mu as training progresses.

def guide(data):
    n_data = data[0].size()[0]

    # Sample q(z) from a categorical distribution.
    ps = Variable(torch.Tensor([.5, .5]))
    pyro.sample('zs', dist.categorical, ps, batch_size=n_data)

svi = pyro.infer.SVI(model=model,
                     guide=guide,
                     optim=pyro.optim.Adam({'lr': 0.1}),
                     loss='ELBO')

dataloader = make_data()
for t in range(100):
    losses = []
    for batch in dataloader:
        losses.append(svi.step(batch))
    mus = model(batch)
    print('err={}. mus={}'.format(np.mean(losses), mus))
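A plain-Python sketch of why both means drift to 1 (this is not Pyro code; it assumes unit variances and equal weights as in the example above, and the helper names are hypothetical). Because the guide's ps are fixed at [.5, .5], every datum is assigned to each component with probability 0.5 no matter where it lies, so in expectation the ELBO gradient for each mean is sum_i 0.5 * (x_i - mu_k). Both gradients vanish only at the sample mean, which is about 1 for this data, even when training starts at the true means.

```python
import random


def make_data(n_data=1000, seed=0):
    """Sample a 1-d equal-weight mixture of N(-2, 1) and N(4, 1)."""
    rng = random.Random(seed)
    return [rng.gauss(rng.choice([-2.0, 4.0]), 1.0) for _ in range(n_data)]


def expected_grad(xs, mus, q):
    """Gradient of E_q[sum_i log N(x_i | mu_{z_i}, 1)] with respect to each mean.

    q[i][k] is the probability the guide assigns datum i to component k, so
    d/d mu_k = sum_i q[i][k] * (x_i - mu_k).
    """
    return [sum(q_i[k] * (x - mu_k) for x, q_i in zip(xs, q))
            for k, mu_k in enumerate(mus)]


def fit_means(xs, q, mus, lr=1e-3, steps=100):
    """Plain gradient ascent on the expected log-likelihood, holding q fixed."""
    mus = list(mus)
    for _ in range(steps):
        grads = expected_grad(xs, mus, q)
        mus = [m + lr * g for m, g in zip(mus, grads)]
    return mus


xs = make_data()
fixed_q = [[0.5, 0.5]] * len(xs)            # the guide's constant ps, same for every datum
mus = fit_means(xs, fixed_q, [-2.0, 4.0])   # start at the true means
data_mean = sum(xs) / len(xs)
```

After fitting, both entries of mus equal data_mean: with the assignments held fixed, nothing in the objective distinguishes the two components.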
fritzo commented 6 years ago

Hi Krzysztof, thanks for the feature request. I have an old gmm-tutorial branch that's a bit out-of-date, so I'll get that working and create a PR.

In your particular example, I suggest running SVI with enum_discrete=True to improve convergence. I'll take a closer look once I fix up my old tutorial.
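As we understand it, enum_discrete=True makes SVI sum over every value of a discrete site such as zs instead of drawing a single sample, so the expectation over z in the ELBO is computed exactly. A plain-Python sketch for a single datum of a 1-d mixture with unit variances (function names are hypothetical) shows that the enumerated ELBO equals log p(x) when q(z) is the exact posterior, and is strictly lower for any other q(z).

```python
import math


def log_normal(x, mu, sigma=1.0):
    """log N(x | mu, sigma^2)."""
    return -0.5 * math.log(2 * math.pi * sigma ** 2) - 0.5 * ((x - mu) / sigma) ** 2


def log_joint(x, k, weights, mus):
    """log p(x, z=k) for a 1-d Gaussian mixture with unit variances."""
    return math.log(weights[k]) + log_normal(x, mus[k])


def enumerated_elbo(x, q, weights, mus):
    """ELBO for one datum, with E_q over z computed exactly by summing over
    every component instead of sampling one."""
    return sum(q[k] * (log_joint(x, k, weights, mus) - math.log(q[k]))
               for k in range(len(q)) if q[k] > 0)


def log_marginal(x, weights, mus):
    """log p(x) = log sum_k p(x, z=k), via log-sum-exp for stability."""
    terms = [log_joint(x, k, weights, mus) for k in range(len(mus))]
    top = max(terms)
    return top + math.log(sum(math.exp(t - top) for t in terms))


def posterior(x, weights, mus):
    """Exact p(z=k | x) for each component k."""
    lm = log_marginal(x, weights, mus)
    return [math.exp(log_joint(x, k, weights, mus) - lm) for k in range(len(mus))]


weights, mus, x = [0.5, 0.5], [-2.0, 4.0], 3.5
q_post = posterior(x, weights, mus)
elbo_post = enumerated_elbo(x, q_post, weights, mus)           # equals log p(x)
elbo_uniform = enumerated_elbo(x, [0.5, 0.5], weights, mus)    # the fixed guide: lower
lm = log_marginal(x, weights, mus)
```

Enumeration removes the sampling noise, but it cannot fix a guide whose q(z) is held away from the posterior, which may be why the problem persisted in the next comment.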

kjchalup commented 6 years ago

I modified my code to use enum_discrete=True, but unfortunately the problem persists. The same problem, with all the modes collapsing into one, also appears when I implement mixture density networks, which were always very easy to train with maximum likelihood when I implemented them by hand. I suspect I don't understand how to use variational methods and Pyro well enough.

I'm looking forward to seeing the tutorial!

fritzo commented 6 years ago

Ok, #636 is the start of a Mixture Model tutorial. It's super simple, but it should give you an idea of how to use pyro.iarange() with enum_discrete=True to get mixture models working.

kjchalup commented 6 years ago

This is great; it helped me figure out my issue. Thank you so much! I had to modify your code from

mu_z = mu.index_select(0, z).unsqueeze(-1)

to

mu_z = torch.masked_select(mu, z.type(torch.ByteTensor)).unsqueeze(-1)

(the first version doesn't work on PyTorch 0.2 or 0.3; I didn't check other versions).
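The two lines pick the same means as long as z is interpreted consistently: index_select expects integer component labels, while masked_select with a ByteTensor mask expects z to be one-hot, broadcasting against mu (we assume that is what z was in this setup). The plain-Python stand-ins below are hypothetical and only mimic the selection logic, not the PyTorch functions themselves. Note that the masked form yields one value per datum only if every row of z contains exactly one 1.

```python
def select_by_index(mu, z_idx):
    """Mimics mu.index_select(0, z) for integer labels: one mean per datum."""
    return [mu[k] for k in z_idx]


def select_by_onehot_mask(mu, z_onehot):
    """Mimics masked_select(mu, mask) for an (N, K) one-hot mask broadcast
    against mu of shape (K,): keeps mu[k] wherever a row has a 1, row by row."""
    return [mu[k] for row in z_onehot for k, bit in enumerate(row) if bit]


def onehot_to_index(z_onehot):
    """Mimics torch.max(zs, 1)[1] from the original model: the label of each row."""
    return [row.index(max(row)) for row in z_onehot]


mu = [-2.0, 4.0]
z_idx = [1, 0, 0, 1]
z_onehot = [[0, 1], [1, 0], [1, 0], [0, 1]]
```

With these inputs, select_by_index(mu, z_idx) and select_by_onehot_mask(mu, z_onehot) both return [4.0, -2.0, -2.0, 4.0].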

In case someone else stumbles upon this with a problem similar to mine: when doing this kind of inference, data order matters! My problem was that I used DataLoader(..., shuffle=True) to serve my data. Working with fritzo's script made me realize that's no good, and setting shuffle=False fixed the problem. As I see it now, the reason you can't reshuffle your data on each training iteration is that the z's are local variables, each assigned to a specific x instance.
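A plain-Python sketch of the alignment problem, assuming each datum i owns a local quantity local[i] (for example its row of the guide's assignment probabilities); the helper name and toy values are hypothetical. If a shuffled batch is paired with local quantities by position, data and local variables get mismatched; pairing by each datum's original index keeps them aligned even after shuffling.

```python
def pair_with_local(batch_idx, xs, local, track_indices=True):
    """Pair each datum in a (possibly shuffled) batch with its local quantity.

    batch_idx lists the original indices of the data in this batch.
    With track_indices=False the local quantities are taken by position,
    which silently goes wrong when the loader reshuffles the data.
    """
    batch = [xs[i] for i in batch_idx]
    if track_indices:
        per_datum = [local[i] for i in batch_idx]
    else:
        per_datum = local[:len(batch)]
    return list(zip(batch, per_datum))


xs = [-2.1, 3.9, -1.8, 4.2]
labels = [0, 1, 0, 1]   # hypothetical per-datum assignments, aligned with xs
perm = [2, 0, 3, 1]     # one reshuffling of the data

right = pair_with_local(perm, xs, labels)
wrong = pair_with_local(perm, xs, labels, track_indices=False)
```

Here right pairs every negative datum with label 0 and every positive one with label 1, while wrong attaches label 1 to -2.1 and label 0 to 4.2.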

fritzo commented 6 years ago

We agree that params-inside-iarange or inside-minibatching is confusing. See also #238.

1Reinier commented 6 years ago

I think I'm running into a similar issue. Should I avoid shuffling when working with params inside an iarange?

fritzo commented 6 years ago

Shuffling is fine, but make sure you are correctly using the parameter returned by pyro.param(). The returned param tensor is the same each pass through the loop, so to use it correctly you'll need to pull out the relevant slices with the shuffled indices.

  def model(data):
      assert len(data) == 100
      with pyro.iarange('data', len(data), subsample_size=10) as ind:
          subsample = data[ind]
          assert len(subsample) == 10
-         param = pyro.param('param', Variable(torch.zeros(10)))         # Wrong
+         param = pyro.param('param', Variable(torch.zeros(100)))[ind]   # Right
          assert len(param) == 10
          pyro.sample("obs", MyDist(param), obs=subsample)
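Besides slicing per-datum params with ind, it helps to keep in mind what subsampling does to the objective: as far as we know, iarange with subsample_size rescales the log-probabilities inside it by size / subsample_size, so a minibatch gives an unbiased estimate of the full-data log-likelihood. A plain-Python sketch with hypothetical helper names and toy values; averaging the estimate over the ten disjoint minibatches of 10 out of 100 data recovers the full sum exactly.

```python
def scaled_minibatch_loglik(loglik_terms, ind, size):
    """Minibatch estimate of sum(loglik_terms), scaled by size / len(ind)."""
    return size / len(ind) * sum(loglik_terms[i] for i in ind)


def local_slice(full_param, ind):
    """Pick this minibatch's per-datum entries, like param[ind] above."""
    return [full_param[i] for i in ind]


size = 100
terms = [-0.01 * i for i in range(size)]        # hypothetical per-datum log-likelihoods
full_param = [float(i) for i in range(size)]    # hypothetical per-datum param of length 100
batches = [list(range(j, j + 10)) for j in range(0, size, 10)]
estimates = [scaled_minibatch_loglik(terms, ind, size) for ind in batches]
```

Each minibatch touches only 10 entries of full_param, but the param itself always keeps its full length of 100, which is the point of the Right line in the diff.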
vincent6606 commented 6 years ago

I am confused about putting pyro.param(...) in the model instead of the guide. All the tutorials I have seen put params in the guide. Can you explain why this is not the case here?

In the GMM in #636 there is an unconstrained_p and an unconstrained_ps; can you explain what is happening there?

I am trying to work out a Chinese restaurant process example, but the difference is that the number of clusters grows and you don't know in advance how many there are. Since it keeps growing, it is hard to register a parameter (a Variable whose number of elements increases), backpropagate through it, and store it as a param for output. Can you give any suggestions about this?
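To make the difficulty concrete, here is a plain-Python sketch of sampling cluster assignments from a Chinese restaurant process prior with concentration alpha (a prior sampler only, not a Pyro model or guide; the function name is hypothetical). Datum i joins an existing cluster k with probability counts[k] / (i + alpha) or opens a new one with probability alpha / (i + alpha), so the number of clusters, and with it the length of any per-cluster parameter, is only known after the data are seen.

```python
import random


def sample_crp(n, alpha, seed=0):
    """Sample cluster labels for n data from a CRP(alpha) prior."""
    rng = random.Random(seed)
    counts = []  # counts[k] = number of data already in cluster k
    labels = []
    for _ in range(n):
        # Existing clusters are weighted by their size, a new cluster by alpha;
        # random.choices normalizes the weights, whose total is i + alpha.
        k = rng.choices(range(len(counts) + 1), weights=counts + [alpha])[0]
        if k == len(counts):
            counts.append(1)  # open a new cluster
        else:
            counts[k] += 1
        labels.append(k)
    return labels, counts


labels, counts = sample_crp(200, alpha=2.0)
```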

Thanks

fritzo commented 6 years ago

@vincent6606 pyro.param(...) in the model instead of the guide

Pyro supports optimizing over both guide parameters (for variational inference) and model parameters (for maximum likelihood or MAP inference). Putting a pyro.param('x') statement in the model is equivalent to putting the pair of statements x_param = pyro.param('x_param'); pyro.sample('x', Delta(x_param)) in the guide and a pyro.sample('x', ...) statement with a flat prior in the model.
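A plain-Python sketch of why the two forms agree, using a toy model with a single unknown x and data ys_i ~ N(x, 1) (all names hypothetical). With a point-mass guide at x_param, the guide's log-density term contributes nothing (we assume the usual convention that a Delta scores 0 at its own value), so the ELBO is just log p(ys, x = x_param). Gradient ascent on it finds the MAP estimate under a proper prior and, with a flat prior, the maximum likelihood estimate, which is what a pyro.param in the model optimizes directly.

```python
def log_joint_grad(x, ys, prior_var=None):
    """d/dx log p(ys, x) for ys_i ~ N(x, 1).

    prior_var=None stands for a flat (improper) prior on x, i.e. maximum
    likelihood; otherwise x ~ N(0, prior_var) and the optimum is the MAP estimate.
    """
    grad = sum(y - x for y in ys)
    if prior_var is not None:
        grad -= x / prior_var
    return grad


def fit_point_estimate(ys, prior_var=None, lr=0.05, steps=500):
    """Gradient ascent on the Delta-guide ELBO, i.e. on log p(ys, x)."""
    x = 0.0
    for _ in range(steps):
        x += lr * log_joint_grad(x, ys, prior_var)
    return x


ys = [1.0, 2.0, 3.0, 4.0]
x_mle = fit_point_estimate(ys)                  # flat prior: the sample mean, 2.5
x_map = fit_point_estimate(ys, prior_var=1.0)   # N(0, 1) prior: sum(ys) / (len(ys) + 1) = 2.0
```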

In the GMM in #636 there is an unconstrained_p and an unconstrained_ps

In the GMM tutorial, the guide learns a tensor p of datum-cluster assignment probabilities with shape (len(data), num_clusters). These are the same soft-assignment probabilities that are learned in EM clustering. The model learns a set of cluster weights ps with shape (num_clusters,). The ps are then expanded into a prior probability that each datum is assigned to each cluster. In this sense, the ps are prior assignment probabilities (before the cluster means are known) and the p are per-datum posterior assignment probabilities (after the cluster means are estimated).
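The division of labor is the same as in classical EM for a GMM, sketched below in plain Python (this is EM, not SVI, and not the tutorial's code; names are hypothetical, and unit variances are assumed). weights plays the role of the model's ps, with shape (K,), and each E-step produces p, with shape (N, K), the per-datum posterior assignment probabilities. Because p is re-estimated per datum rather than held fixed, the means separate toward -2 and 4 even from a poor initialization instead of collapsing to the overall mean.

```python
import math
import random


def log_normal(x, mu):
    """log N(x | mu, 1)."""
    return -0.5 * math.log(2 * math.pi) - 0.5 * (x - mu) ** 2


def e_step(xs, weights, mus):
    """Posterior assignment probabilities p: one row of length K per datum."""
    p = []
    for x in xs:
        logs = [math.log(w) + log_normal(x, m) for w, m in zip(weights, mus)]
        top = max(logs)
        unnorm = [math.exp(lg - top) for lg in logs]
        total = sum(unnorm)
        p.append([u / total for u in unnorm])
    return p


def m_step(xs, p):
    """Re-estimate the prior weights (length K) and the means from p."""
    K = len(p[0])
    nk = [sum(row[k] for row in p) for k in range(K)]
    weights = [n / len(xs) for n in nk]
    mus = [sum(row[k] * x for row, x in zip(p, xs)) / nk[k] for k in range(K)]
    return weights, mus


def em(xs, mus, steps=50):
    """Alternate E- and M-steps, starting from uniform weights."""
    weights = [1.0 / len(mus)] * len(mus)
    for _ in range(steps):
        p = e_step(xs, weights, mus)
        weights, mus = m_step(xs, p)
    return weights, mus, p


rng = random.Random(0)
xs = [rng.gauss(rng.choice([-2.0, 4.0]), 1.0) for _ in range(1000)]
weights, mus, p = em(xs, mus=[-1.0, 1.0])
```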

I've added issue #879 to clarify the tutorial.

(sorry for the slow reply; we've been very busy working towards our 0.2 release)