rlouf / mcx

Express & compile probabilistic programs for performant inference on CPU & GPU. Powered by JAX.
https://rlouf.github.io/mcx
Apache License 2.0

Incorrect results when sampling from the prior #90

Open elanmart opened 3 years ago

elanmart commented 3 years ago

While going through Statistical Rethinking, I wanted to run a prior-predictive simulation, but the results did not match the textbook example (see below).

I also tried a few synthetic examples, and they give unintuitive results as well (see further down).

Examples

Example from Statistical Rethinking

Code

import seaborn as sns
import matplotlib.pyplot as plt
import jax
import mcx

from mcx import distributions as dist
from mcx import sample_joint

@mcx.model
def model():
    μ <~ dist.Normal(178, 20)
    σ <~ dist.Uniform(0, 50)
    h <~ dist.Normal(μ, σ)

    return h

rng_key = jax.random.PRNGKey(0)

prior_predictive = sample_joint(
    rng_key=rng_key, 
    model=model, 
    model_args=(), 
    num_samples=10_000
)

fig, axes = plt.subplots(2, 2, figsize=(7, 5), dpi=128)
axes = axes.reshape(-1)

sns.kdeplot(prior_predictive["μ"], ax=axes[0])
sns.kdeplot(prior_predictive["σ"], ax=axes[1])
sns.kdeplot(prior_predictive["h"], ax=axes[2])

plt.tight_layout()

Result

image

Expected

image
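
For reference, the same prior-predictive simulation can be sketched in plain JAX, without MCX. This is only a minimal sketch, not MCX's generated code: the function name `sample_height_prior` and the choice of one independent key per variable are assumptions made for illustration.

```python
import jax


def sample_height_prior(rng_key, num_samples):
    # Assumption of this sketch: every random variable gets its own key.
    key_mu, key_sigma, key_h = jax.random.split(rng_key, 3)

    # μ ~ Normal(178, 20)
    mu = 178.0 + 20.0 * jax.random.normal(key_mu, (num_samples,))
    # σ ~ Uniform(0, 50)
    sigma = jax.random.uniform(key_sigma, (num_samples,), minval=0.0, maxval=50.0)
    # h ~ Normal(μ, σ), drawn elementwise from the sampled μ and σ
    h = mu + sigma * jax.random.normal(key_h, (num_samples,))

    return {"μ": mu, "σ": sigma, "h": h}


reference = sample_height_prior(jax.random.PRNGKey(0), 10_000)
```

Plotting `reference["μ"]`, `reference["σ"]` and `reference["h"]` with the same `sns.kdeplot` calls as above gives a baseline to compare the MCX output against.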

Synthetic example 1

In this example I sample an offset from Uniform(0, 1), then sample the outcome from Uniform(12 - offset, 12 + offset). I therefore expect the samples to lie in the range [11, 13], but I get samples in the range [-15, 15].

Code

import seaborn as sns
import matplotlib.pyplot as plt
import jax
import mcx

from mcx import distributions as dist
from mcx import sample_joint

@mcx.model
def example_1():

    center = 12
    offset <~ dist.Uniform(0, 1)

    low = (center - offset)
    high = (center + offset)

    outcome <~ dist.Uniform(low, high)

rng_key = jax.random.PRNGKey(0)

prior_predictive = sample_joint(
    rng_key=rng_key, 
    model=example_1, 
    model_args=(), 
    num_samples=10_000
)

ax = sns.kdeplot(prior_predictive["outcome"]);
ax.set_title("Outcome");

Result

image

Synthetic example 2

This is the same example as above, but the center variable is passed as an argument instead of being hardcoded. The results are different, although still not in the range [11, 13].

Code

import seaborn as sns
import matplotlib.pyplot as plt
import jax
import mcx

from mcx import distributions as dist
from mcx import sample_joint

@mcx.model
def example_2(center):

    offset <~ dist.Uniform(0, 1)

    low = (center - offset)
    high = (center + offset)

    outcome <~ dist.Uniform(low, high)

rng_key = jax.random.PRNGKey(0)

prior_predictive = sample_joint(
    rng_key=rng_key, 
    model=example_2, 
    model_args=(12, ), 
    num_samples=10_000
)

ax = sns.kdeplot(prior_predictive["outcome"]);
ax.set_title("Outcome");

Result

image

Expectation

For examples 1 and 2, here's what I'd expect to get:

image
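
The expected behaviour of examples 1 and 2 can also be written down in plain JAX. The sketch below is not MCX code: the helper name `sample_outcome` is hypothetical, and Uniform(low, high) is drawn by scaling and shifting a Uniform(0, 1) sample.

```python
import jax


def sample_outcome(rng_key, center, num_samples):
    # Separate keys for the two random variables (assumption of this sketch).
    key_offset, key_outcome = jax.random.split(rng_key)

    # offset ~ Uniform(0, 1)
    offset = jax.random.uniform(key_offset, (num_samples,), minval=0.0, maxval=1.0)
    low = center - offset
    high = center + offset

    # outcome ~ Uniform(low, high), via scale-and-shift of a Uniform(0, 1) draw
    u = jax.random.uniform(key_outcome, (num_samples,))
    return low + (high - low) * u


outcome = sample_outcome(jax.random.PRNGKey(0), 12.0, 10_000)
```

With `center = 12`, every draw should fall within [11, 13] up to floating-point rounding, whether the center is hardcoded or passed as an argument.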

Environment

Linux-5.8.0-44-generic-x86_64-with-glibc2.10
Python 3.8.5 (default, Sep  4 2020, 07:30:14) 
[GCC 7.3.0]
JAX 0.2.8
NetworkX 2.5
JAXlib 0.1.58
mcx 2a2b94801e68d94d86826863eeee80f0b84c390d

elanmart commented 3 years ago

Hi @rlouf

I've looked into this a bit more and identified two issues:

  1. In the mcx models, the same random key seems to be used for multiple distributions, giving incorrect results.
  2. The subtraction Op and negation Op seem broken.

Please find the examples of the two issues in the notebook: https://gist.github.com/elanmart/9ab0ba21f282f6b24d972cbfb76b4578
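
The first issue can be reproduced with JAX alone. JAX random functions are deterministic given a key, so a minimal sketch like the one below (the variable names are illustrative only) shows that two draws made with the same key are identical rather than independent:

```python
import jax

key = jax.random.PRNGKey(0)

# Reusing one key: both calls return exactly the same draws.
a = jax.random.normal(key, (5,))
b = jax.random.normal(key, (5,))
print(bool((a == b).all()))  # True

# Splitting the key first gives each variable its own random stream.
key_a, key_b = jax.random.split(key)
c = jax.random.normal(key_a, (5,))
d = jax.random.normal(key_b, (5,))
```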

Hope this is helpful

rlouf commented 3 years ago

Hi @elanmart,

Thank you for taking the time to share this with me! Regarding what you identified:

  1. Indeed, I noticed that recently. It is a problem whenever the same distribution is used more than once in a model. This should be corrected soon.
  2. In what sense? Would you mind pasting the output of print(example_1.sample_joint_src) and print(example_2.sample_joint_src)?

elanmart commented 3 years ago

Thanks for the answer! I was wondering how I could inspect the models, and sample_joint_src does indeed reveal what goes wrong!

The following model

@mcx.model
def example_2_mcx_v1():

    offset  <~ dist.Uniform(0, 5)
    low     =  12 - offset
    outcome <~ dist.Uniform(low, 12)

    return outcome

is transformed into

def example_2_mcx_v1_sample_forward(rng_key):
    offset = dist.Uniform(0, 5).sample(rng_key)
    low = offset - 12
    outcome = dist.Uniform(low, 12).sample(rng_key)
    forward_samples = {'offset': offset, 'outcome': outcome}
    return forward_samples

Notice how

low = 12 - offset

became

low = offset - 12
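
To see how much this swap changes the samples, here is a small plain-JAX sketch (not MCX code; the helper `uniform_between` is hypothetical) that draws the outcome with both versions of `low`, using distinct keys so that only the operand order differs:

```python
import jax

key_offset, key_outcome = jax.random.split(jax.random.PRNGKey(0))
n = 10_000

# offset ~ Uniform(0, 5)
offset = jax.random.uniform(key_offset, (n,), minval=0.0, maxval=5.0)
u = jax.random.uniform(key_outcome, (n,))


def uniform_between(low, high, u):
    # Map Uniform(0, 1) draws onto the interval [low, high).
    return low + (high - low) * u


intended = uniform_between(12.0 - offset, 12.0, u)  # low lies in [7, 12]
swapped = uniform_between(offset - 12.0, 12.0, u)   # low lies in [-12, -7]
```

The intended model keeps the outcome within [7, 12], while the swapped version spreads it over roughly [-12, 12].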

EDIT

The issue is not limited to constants. The arguments in subtraction are switched to match the order in which they were defined, so

A <~ ...
B <~ ...
B - A

becomes

A - B

and so the model here

@mcx.model
def example():
    A <~ dist.Normal(0, 1)
    B <~ dist.Normal(0, 2)

    μ = B - A
    Y <~ dist.Normal(μ, 1)

    return Y

becomes

def example_sample_forward(rng_key):
    B = dist.Normal(0, 2).sample(rng_key)
    A = dist.Normal(0, 1).sample(rng_key)
    μ = A - B
    Y = dist.Normal(μ, 1).sample(rng_key)
    forward_samples = {'A': A, 'B': B, 'Y': Y}
    return forward_samples

elanmart commented 3 years ago

Ah, and also regarding point 1. (same rng_key used many times): is there any simple workaround I could use as a temporary solution, however hacky?

rlouf commented 3 years ago

That's strange regarding A-B, I identified the problem 10 days ago and I thought I'd fixed it. Are you running the latest version (latest commit)?

Unfortunately no workaround for the rng_key but I can try to push a fix next week. I'll make sure it works on these examples. In the meantime you can keep moving forward, checking the source code each time there's something weird. You'd just have to re-run your notebooks once the fixes are made.

Now I see how convenient compiling to a python function is for debugging 😄 Thank you for dealing with the teething problems here, it is really helpful for us.

elanmart commented 3 years ago

OK, so my poetry.lock file indicated that I had the latest commit, but after a clean re-install the issue is indeed resolved... I'm really sorry for generating noise :cry:

Do you want me to close this ticket and open a clean one for the rng_key topic? Out of curiosity -- what is the fix you envision there? Adding a _, key = jax.random.split(key) statement to the graph after each sample() call? Or is there a nicer solution?

> Thank you for dealing with the teething problems

No worries, I would love to understand the compiler a bit better to be able to debug similar issues myself.

rlouf commented 3 years ago

> I'm really sorry for generating noise :cry:

No worries, you're really helpful :)

> Do you want me to close this ticket and open a clean one for rng_key topic?

Yes please! Leave this one open until we solve the issue completely.

> Out of curiosity -- what is the fix you envision there? Adding _, key = jax.random.split(key) statement to the graph after each sample() call? Or is there a nicer solution?

So that would be the quick and dirty solution. I think that I might instead generate as many keys as needed at the beginning of the function.
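
The two options discussed here could look roughly as follows for the A/B/Y model above. This is a hand-written sketch in plain JAX under stated assumptions, not the code MCX generates or will generate: both function names are hypothetical, and `Normal(loc, scale)` is sampled as `loc + scale * z` with `z` standard normal.

```python
import jax


def sample_forward_split_each_time(rng_key):
    # Option 1 ("quick and dirty"): split off a fresh subkey before each draw.
    rng_key, subkey = jax.random.split(rng_key)
    A = 0.0 + 1.0 * jax.random.normal(subkey)
    rng_key, subkey = jax.random.split(rng_key)
    B = 0.0 + 2.0 * jax.random.normal(subkey)
    mu = B - A
    rng_key, subkey = jax.random.split(rng_key)
    Y = mu + 1.0 * jax.random.normal(subkey)
    return {"A": A, "B": B, "Y": Y}


def sample_forward_keys_up_front(rng_key):
    # Option 2: generate as many keys as needed at the start of the function.
    key_A, key_B, key_Y = jax.random.split(rng_key, 3)
    A = 0.0 + 1.0 * jax.random.normal(key_A)
    B = 0.0 + 2.0 * jax.random.normal(key_B)
    mu = B - A
    Y = mu + 1.0 * jax.random.normal(key_Y)
    return {"A": A, "B": B, "Y": Y}


samples_1 = sample_forward_split_each_time(jax.random.PRNGKey(0))
samples_2 = sample_forward_keys_up_front(jax.random.PRNGKey(0))
```

Either way, no key is consumed twice, so A and B are no longer driven by the same underlying draw. The second form keeps all key handling in one place, which may be easier to emit from a compiler.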

> No worries, I would love to understand the compiler a bit better to be able to debug similar issues myself.

Well, now you know that you can at least print the code generated by the compiler. It's a good starting point.

tblazina commented 3 years ago

@elanmart - with regards to the compiler - I made some (not really organized) notes here, some of which I suppose are (or will soon be) invalid after the <~ operator is phased out. In any case, maybe they will be helpful to you!

rlouf commented 3 years ago

> some of which I suppose are (or will soon be) invalid after the <~ operator is phased out.

Actually, the general principle will stay exactly the same.

elanmart commented 3 years ago

Thank you @tblazina ! This looks extremely useful, will go through it over the weekend!

rlouf commented 3 years ago

Just to let you know, I'll make some time to work on this and the other issue once my NUTS PR is merged on BlackJAX (which means MCX will support NUTS). How is the implementation on SR going?

elanmart commented 3 years ago

Thanks for the update! Looking forward to the NUTS sampler as well.

I've decided to first go through the theory, and then make a second pass implementing the examples.

I've just finished the book, so I'm going back to the code, which hopefully should go faster now.

There were a few places in the book where some advanced Stan features were used. I'm a bit worried about those, but we'll see how it goes.

rlouf commented 3 years ago

Great! If you remember which ones, don't hesitate to open issues now.