Closed jeremiecoullon closed 3 years ago
@jeremiecoullon It seems to me that jnp.dot(x, y) is different from x * y. Could you add some print statements to see if you have the same U, V, and est_rating? FYI, with
with numpyro.plate('plate_user', num_users):
    U = numpyro.sample("U_i", dist.Normal(jnp.zeros(D), jnp.ones(D)))
the site U_i will have shape (D,) and num_users == D. If you want U_i to have shape (num_users, D), then you can use
with numpyro.plate('plate_user', num_users):
    U = numpyro.sample("U_i", dist.Normal(jnp.zeros(D), jnp.ones(D)).to_event(1))
    # or U = numpyro.sample("U_i", dist.Normal(0, 1).expand([D]).to_event(1))
Ah I didn't know about .to_event(), thanks!
I had tried yesterday printing out the shapes of the variables but for some reason it didn't print anything. I tried it again today and it works, so I must have been doing something wrong yesterday.. :p
I modified my model to be the following:
import numpy as np
import jax.numpy as jnp
from jax import vmap
import numpyro
import numpyro.distributions as dist

def do_inner(U_i, V_j):
    return jnp.dot(U_i, V_j)

# inner product of matching rows: (n, D), (n, D) -> (n,)
batch_inner = vmap(do_inner, in_axes=(0, 0))

def pmf_model_1(user_IDs, film_IDs, ratings):
    alpha = 2
    D = 1
    num_users = len(np.unique(user_IDs))
    num_films = len(np.unique(film_IDs))
    with numpyro.plate('plate_user', num_users):
        U = numpyro.sample("U_i", dist.Normal(jnp.zeros(D), jnp.ones(D)).to_event(1))
    with numpyro.plate('plate_film', num_films):
        V = numpyro.sample("V_j", dist.Normal(jnp.zeros(D), jnp.ones(D)).to_event(1))
    est_rating = batch_inner(U[user_IDs], V[film_IDs])
    with numpyro.plate("data", len(ratings)):
        numpyro.sample("obs", dist.Normal(est_rating, 1/alpha), obs=ratings)
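As a quick check of what batch_inner computes, the following sketch compares it with an elementwise product summed over the last axis, and with plain jnp.dot on the same arrays. The arrays U and V here are small made-up values, not model samples.

```python
import jax.numpy as jnp
from jax import vmap


def do_inner(u, v):
    return jnp.dot(u, v)


batch_inner = vmap(do_inner, in_axes=(0, 0))

# Illustrative data: 3 rows, D = 2.
U = jnp.arange(6.0).reshape(3, 2)  # rows [0, 1], [2, 3], [4, 5]
V = jnp.ones((3, 2))

rowwise = batch_inner(U, V)           # one inner product per row, shape (3,)
same = jnp.sum(U * V, axis=-1)        # elementwise product, summed over D
matrix = jnp.dot(U, V.T)              # all pairs of rows, shape (3, 3)
scalar = jnp.dot(U[:, 0], V[:, 0])    # two 1-D arrays collapse to a scalar

print(rowwise.shape, matrix.shape, scalar.shape)
```

The last line is the case to watch for: when both arguments are 1-D, jnp.dot sums over every entry and returns a single number, whereas x * y keeps one value per row.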
So now everything has the correct shapes, and changing the dimension D works fine:
- U has shape (num_users, D)
- U[user_IDs] has shape (len(user_IDs), D)
- est_rating has shape (len(user_IDs),). Note that len(user_IDs) == len(ratings)
The only thing I'm still a bit confused about is what to_event does and why it's there. I've read the docs but I don't get what "dependent event dimensions" means.
It seems that if I don't include to_event(1) it just ignores the dimension D; is this correct?
I'm still a bit confused
@jeremiecoullon you might take a look at Pyro's Tensor Shapes Tutorial. That tutorial is based on Pyro rather than NumPyro, but the shape concepts are common.
@fritzo : ah ok thanks for the link!
Hello!
I'm trying to implement a model in numpyro (PMF), and I'm stuck on a particular bit to do with plates and vector multiplication.
I’m using the following simplified model: with i from 1 to N, and j from 1 to M:
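The formula itself is not shown here. Judging from the pmf_model_1 code elsewhere in this thread (standard normal priors on the latent vectors, and 1/alpha passed as the scale of the observation), the model is presumably the following. Treat it as a reconstruction from that code; the symbol R_ij for the rating of film j by user i is an assumed name.

```latex
\begin{aligned}
U_i &\sim \mathcal{N}(0, I_D), \qquad i = 1, \dots, N \\
V_j &\sim \mathcal{N}(0, I_D), \qquad j = 1, \dots, M \\
R_{ij} \mid U_i, V_j &\sim \mathcal{N}\!\left(U_i^\top V_j,\ \sigma^2\right), \qquad \sigma = 1/\alpha
\end{aligned}
```

Here sigma is a standard deviation, matching how the code calls dist.Normal(est_rating, 1/alpha).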
I show 2 models below, one uses jnp.dot for the multiplication of U_i and V_j, and the other one uses standard float multiplication. As D=1, these two models should be identical. However, from looking at the trace plots we can see that they are not.
I also tried the even simpler model of having a common U and V for all the data (so without using plates for U and V). In that case using jnp.dot and standard float multiplication gives identical samples. So the issue seems to be the interactions between plates and vector random variables.
I looked at the examples in the docs but couldn't find any examples that use both plates and vector random variables. So I'm not sure if I'm not defining my model correctly or if this is a bug.
Here is the code that completely reproduces the issue (I use a small sample from the Movielens dataset):