pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0
2.13k stars 233 forks

Plates with random vectors #1018

Closed jeremiecoullon closed 3 years ago

jeremiecoullon commented 3 years ago

Hello!


I'm trying to implement a probabilistic matrix factorization (PMF) model in numpyro, and I'm stuck on a particular bit to do with plates and vector multiplication.

I’m using the following simplified model: with i from 1 to N, and j from 1 to M:
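The formula itself is not reproduced in the text here. Read off from the code further down, the model is presumably the following, where the symbol $R_{ij}$ for the rating of film $j$ by user $i$ is an assumed name and the unit observation variance matches the first pair of models:

```latex
\begin{aligned}
U_i &\sim \mathcal{N}(0, I_D), & i &= 1, \dots, N, \\
V_j &\sim \mathcal{N}(0, I_D), & j &= 1, \dots, M, \\
R_{ij} &\sim \mathcal{N}\!\left(U_i^\top V_j,\; 1\right).
\end{aligned}
```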

I show two models below: one uses jnp.dot for the product of U_i and V_j, and the other uses element-wise multiplication (*). Since D=1, these two models should be identical. However, the trace plots show that they are not.

I also tried the even simpler model of having a common U and V for all the data (so without using plates for U and V). In that case jnp.dot and element-wise multiplication give identical samples. So the issue seems to be the interaction between plates and vector random variables.

I looked at the examples in the docs but couldn't find any that use both plates and vector random variables. So I'm not sure whether I'm defining my model incorrectly or whether this is a bug.

Here is the code that completely reproduces the issue (I use a small sample from the Movielens dataset):

import matplotlib.pyplot as plt

import numpy as np

import jax.numpy as jnp
from jax import random, vmap
from jax.scipy.special import logsumexp

import numpyro
from numpyro.diagnostics import hpdi
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.infer import MCMC, NUTS

# a small sample of MovieLens data
user_IDs = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 4, 6, 6, 6, 6, 6, 6,
             7, 7, 7, 7, 8, 9])
film_IDs = np.array([ 1,  2,  3,  4,  5,  7,  8,  9, 11, 13,  1, 10, 14, 11,  1,
              7,  8,  9, 12, 13,  4,  9, 11, 12, 11,  7])
ratings = np.array([ 1.4716499, -0.5283501,  0.4716499, -0.5283501, -0.5283501,
              0.4716499, -2.52835  ,  1.4716499, -1.5283501,  1.4716499,
              0.4716499, -1.5283501,  0.4716499,  0.4716499,  0.4716499,
             -1.5283501,  0.4716499,  0.4716499,  0.4716499, -1.5283501,
              1.4716499,  1.4716499, -0.5283501,  1.4716499, -0.5283501,
              0.4716499])

# define 2 models that should be exactly the same.
def pmf_model_1(user_IDs, film_IDs, ratings):
    D = 1
    num_users = len(np.unique(user_IDs))
    num_films = len(np.unique(film_IDs))

    with numpyro.plate('plate_user', num_users):
        U = numpyro.sample("U_i", dist.Normal(jnp.zeros(D), jnp.ones(D)))
    with numpyro.plate('plate_film', num_films):
        V = numpyro.sample("V_j", dist.Normal(jnp.zeros(D), jnp.ones(D)))

    # use matrix multiplication
    est_rating = jnp.dot(U[user_IDs], V[film_IDs])

    with numpyro.plate("data", len(ratings)):
        numpyro.sample("obs", dist.Normal(est_rating, 1.), obs=ratings)

def pmf_model_2(user_IDs, film_IDs, ratings):
    D = 1
    num_users = len(np.unique(user_IDs))
    num_films = len(np.unique(film_IDs))

    with numpyro.plate('plate_user', num_users):
        U = numpyro.sample("U_i", dist.Normal(jnp.zeros(D), jnp.ones(D)))
    with numpyro.plate('plate_film', num_films):
        V = numpyro.sample("V_j", dist.Normal(jnp.zeros(D), jnp.ones(D)))

    # float multiplication
    est_rating = U[user_IDs]*V[film_IDs]

    with numpyro.plate("data", len(ratings)):
        numpyro.sample("obs", dist.Normal(est_rating, 1.), obs=ratings)

# run model 1
print("Running model 1..")
hmc_kernel1 = NUTS(pmf_model_1)
mcmc1 = MCMC(hmc_kernel1, num_samples=1000, num_warmup=1000, num_chains=1, progress_bar=False)

key = random.PRNGKey(0)
mcmc1.run(key, user_IDs=user_IDs, film_IDs=film_IDs, ratings=ratings)

posterior_pmf1 = mcmc1.get_samples()

# run model 2
print("Running model 2..")
hmc_kernel2 = NUTS(pmf_model_2)
mcmc2 = MCMC(hmc_kernel2, num_samples=1000, num_warmup=1000, num_chains=1, progress_bar=False)

key = random.PRNGKey(0)
mcmc2.run(key, user_IDs=user_IDs, film_IDs=film_IDs, ratings=ratings)

posterior_pmf2 = mcmc2.get_samples()

# plot one of the trace plots
plt.plot(posterior_pmf1['U_i'][:,0], label="model 1")
plt.plot(posterior_pmf2['U_i'][:,0], label="model 2", marker="+")
plt.legend(fontsize=16)
plt.show()

fehiepsi commented 3 years ago

@jeremiecoullon It seems to me that jnp.dot(x, y) is different from x * y. Could you add some print statements to see if you have the same U, V, and est_rating? FYI, with

with numpyro.plate('plate_user', num_users):
    U = numpyro.sample("U_i", dist.Normal(jnp.zeros(D), jnp.ones(D)))

the site U_i will have shape (D,) and num_users == D. If you want U_i to have shape (num_users, D), then you can use

with numpyro.plate('plate_user', num_users):
    U = numpyro.sample("U_i", dist.Normal(jnp.zeros(D), jnp.ones(D)).to_event(1))
    # or U = numpyro.sample("U_i", dist.Normal(0, 1).expand([D]).to_event(1))

jeremiecoullon commented 3 years ago

Ah I didn't know about .to_event(), thanks!

I had tried yesterday printing out the shapes of the variables but for some reason it didn't print anything. I tried it again today and it works, so I must have been doing something wrong yesterday.. :p

I modified my model to be the following:

def do_inner(U_i, V_j):
    return jnp.dot(U_i, V_j)

batch_inner = vmap(do_inner, in_axes=(0,0))

def pmf_model_1(user_IDs, film_IDs, ratings):
    alpha = 2
    D = 1
    num_users = len(np.unique(user_IDs))
    num_films = len(np.unique(film_IDs))

    with numpyro.plate('plate_user', num_users):
        U = numpyro.sample("U_i", dist.Normal(jnp.zeros(D), jnp.ones(D)).to_event(1))
    with numpyro.plate('plate_film', num_films):
        V = numpyro.sample("V_j", dist.Normal(jnp.zeros(D), jnp.ones(D)).to_event(1))

    est_rating = batch_inner(U[user_IDs], V[film_IDs])
    with numpyro.plate("data", len(ratings)):
        numpyro.sample("obs", dist.Normal(est_rating, 1/alpha), obs=ratings)

So now everything has the correct shapes, and changing the dimension D works fine.
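
As a side note, the vmap-based row-wise inner product can also be written without vmap. The sketch below, using randomly generated factors and made-up zero-based index arrays rather than the MovieLens sample, checks that three formulations agree; which one to prefer is a matter of taste.

```python
import jax.numpy as jnp
from jax import random, vmap

D = 3
key_u, key_v = random.split(random.PRNGKey(0))
U = random.normal(key_u, (7, D))   # one row per user (illustrative sizes)
V = random.normal(key_v, (13, D))  # one row per film

# Made-up zero-based indices pairing a user with a film for each rating.
user_idx = jnp.array([0, 0, 2, 6])
film_idx = jnp.array([1, 12, 5, 5])

# Same construction as batch_inner in the model: map jnp.dot over rows.
batch_inner = vmap(jnp.dot, in_axes=(0, 0))
via_vmap = batch_inner(U[user_idx], V[film_idx])

# Element-wise product followed by a sum over the latent dimension.
via_sum = jnp.sum(U[user_idx] * V[film_idx], axis=-1)

# Explicit index notation: multiply over d, keep the rating index n.
via_einsum = jnp.einsum("nd,nd->n", U[user_idx], V[film_idx])

print(via_vmap.shape)  # (4,)
print(bool(jnp.allclose(via_vmap, via_sum, atol=1e-5)))
print(bool(jnp.allclose(via_vmap, via_einsum, atol=1e-5)))
```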

The only thing I'm still a bit confused about is what to_event does and why it's there. I've read the docs but I don't get what "dependent event dimensions" means. It seems that if I don't include to_event(1) it just ignores the dimension D; is this correct?

fritzo commented 3 years ago

> I'm still a bit confused

@jeremiecoullon you might take a look at Pyro's Tensor Shapes Tutorial. That tutorial is based on Pyro rather than NumPyro, but the shape concepts are common.

jeremiecoullon commented 3 years ago

@fritzo : ah ok thanks for the link!