rlouf / mcx

Express & compile probabilistic programs for performant inference on CPU & GPU. Powered by JAX.
https://rlouf.github.io/mcx
Apache License 2.0

Example in readme running #57

Closed: jeremiecoullon closed this issue 3 years ago

jeremiecoullon commented 3 years ago

The example in the README returns this error:

SyntaxError: Expected a random variable of a constant to initialize distribution, got np.zeros(np.shape(x)[-1])
 instead.
Maybe you are trying to initialize a distribution directly, or call a function inside the distribution initialization. While this would be a perfectly legitimate move, it is currently not supported in mcx. Use an intermediate variable instead: 

Do not do `x <~ Normal(Normal(0, 1), 1)` or `x <~ Normal(my_function(10), 1)`, instead do `y <~ Normal(0, 1) & x <~ Normal(y, 1)` and `y = my_function(10) & x <~ Normal(y, 1)`

This is caused by the linear_regression function in the README. The version of that function in hmc_test.py works fine, though.

It would also be good to generate the data in this example so that it is self-contained.
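A minimal sketch of such self-contained data generation, assuming a seeded NumPy Generator is acceptable for the README (the helper name make_linear_data is hypothetical, not part of mcx). It keeps the targets 1-D so they have the same shape as np.dot(x, coeffs) inside the model.

```python
import numpy as np


def make_linear_data(n=1000, slope=3.0, x_scale=5.0, seed=0):
    """Hypothetical helper: simulate y = slope * x + N(0, 1) noise."""
    rng = np.random.default_rng(seed)  # seeded so the example is reproducible
    # Inputs as a column, shape (n, 1), so x.shape[-1] is the number of features.
    x = rng.normal(0.0, x_scale, size=n).reshape(-1, 1)
    # Targets kept 1-D, shape (n,), matching the shape of np.dot(x, coeffs).
    y = slope * x[:, 0] + rng.normal(size=n)
    return x, y


x_data, y_data = make_linear_data()
observations = {'x': x_data, 'predictions': y_data, 'lmbda': 3.}
```

Seeding the generator means a copy-and-paste run produces the same dataset every time, which makes results easier to compare across versions.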

The following script includes both fixes and should work:

from jax import numpy as np
import jax
import numpy as onp
import mcx
from mcx.distributions import Exponential, Normal

rng_key = jax.random.PRNGKey(0)

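# Simulated data: y = 3 * x + unit Gaussian noise.
# y_data is kept 1-D so its shape matches np.dot(x, coeffs) in the model;
# a (1000, 1) array would silently broadcast against the (1000,) mean.
x_data = onp.random.normal(0, 5, size=1000).reshape(-1, 1)
y_data = 3 * x_data[:, 0] + onp.random.normal(size=x_data.shape[0])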

observations = {'x': x_data, 'predictions': y_data, 'lmbda': 3.}

@mcx.model
def linear_regression(x, lmbda=1.0):
    sigma <~ Exponential(lmbda)
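    # Compute the prior mean in an intermediate variable: mcx does not accept
    # a function call directly inside a distribution's arguments.
    coeffs_init = np.ones(x.shape[-1])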
    coeffs <~ Normal(coeffs_init, sigma)
    y = np.dot(x, coeffs)
    predictions <~ Normal(y, sigma)
    return predictions

kernel = mcx.HMC(100)
sampler = mcx.sampler(
    rng_key,
    linear_regression,
    kernel,
    **observations
)
posterior = sampler.run()
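To make explicit what linear_regression expresses, here is a minimal NumPy sketch of its log joint density. It assumes Exponential(lmbda) is parametrized by its rate and Normal(mu, sigma) by its standard deviation; the function names are hypothetical and are not mcx's API. It also raises an error when predictions and the mean np.dot(x, coeffs) have different shapes, which is where a (n, 1) target would otherwise broadcast silently.

```python
import numpy as np


def normal_logpdf(value, loc, scale):
    # Elementwise log-density of Normal(loc, scale), scale = standard deviation.
    return -0.5 * ((value - loc) / scale) ** 2 - np.log(scale) - 0.5 * np.log(2 * np.pi)


def exponential_logpdf(value, rate):
    # Log-density of an Exponential with the given rate (assumed parametrization), value > 0.
    return np.log(rate) - rate * value


def linear_regression_logjoint(sigma, coeffs, x, predictions, lmbda=1.0):
    """Hypothetical helper: log p(sigma, coeffs, predictions | x, lmbda)."""
    if sigma <= 0:
        return -np.inf  # outside the support of the Exponential prior
    logp = exponential_logpdf(sigma, lmbda)
    # Prior on the coefficients, centred on ones as in the model above.
    coeffs_init = np.ones(x.shape[-1])
    logp += normal_logpdf(coeffs, coeffs_init, sigma).sum()
    # Likelihood of the observed predictions around the linear mean.
    y = x @ coeffs
    if predictions.shape != y.shape:
        raise ValueError(f"predictions shape {predictions.shape} != mean shape {y.shape}")
    logp += normal_logpdf(predictions, y, sigma).sum()
    return logp
```

HMC explores exactly this kind of function (plus its gradient). mcx derives an equivalent density from the model definition and evaluates it with JAX, so this NumPy version is only meant for reading and sanity checks.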

rlouf commented 3 years ago

In a way, I am happy the error message was clear enough for you to find a solution.

The example will actually work once the core refactor is merged, but it might be a good idea to fix it for the current version. And I like the idea of including data generation for quick copy-and-paste. You can give it a go if you'd like!

jeremiecoullon commented 3 years ago

OK, I'll do this!

rlouf commented 3 years ago

Solved by #60

rlouf commented 3 years ago

Reopening, since I inadvertently reverted part of the changes on master.

rlouf commented 3 years ago

Reverted it to a self-contained example.