rlouf / mcx

Express & compile probabilistic programs for performant inference on CPU & GPU. Powered by JAX.
https://rlouf.github.io/mcx
Apache License 2.0

logpdf cannot handle more than 3 RVs #101

Closed: elanmart closed this issue 3 years ago

elanmart commented 3 years ago

On mcx master branch:

import mcx
from mcx import distributions as dist

@mcx.model
def model():
    a <~ dist.Normal(0, 1)
    b <~ dist.Normal(0, 1)
    c <~ dist.Normal(0, 1)
    d <~ dist.Normal(0, 1)

print(model.logpdf_src)

produces

def model_logpdf(a, b, c, d):
    logpdf_model_d = dist.Normal(0, 1).logpdf_sum(d)
    logpdf_model_c = dist.Normal(0, 1).logpdf_sum(c)
    logpdf_model_b = dist.Normal(0, 1).logpdf_sum(b)
    logpdf_model_a = dist.Normal(0, 1).logpdf_sum(a)
    logpdf = logpdf_model_c + logpdf_model_d + logpdf_model_b
    return logpdf

Note that logpdf_model_a is missing in the summation.
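To make the effect of the missing term concrete, here is a minimal sketch in plain Python. It does not use MCX: `normal_logpdf`, `expected_logpdf`, and `reported_logpdf` are hypothetical hand-written stand-ins. The first sums all four terms, the second mirrors the generated code above.

```python
import math


def normal_logpdf(x, mu=0.0, sigma=1.0):
    """Log-density of Normal(mu, sigma) at x (hand-written stand-in, not MCX)."""
    z = (x - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)


def expected_logpdf(a, b, c, d):
    # All four independent variables contribute one term each.
    return sum(normal_logpdf(v) for v in (a, b, c, d))


def reported_logpdf(a, b, c, d):
    # Mirrors the generated source in this report: the term for `a` is dropped.
    return normal_logpdf(c) + normal_logpdf(d) + normal_logpdf(b)


# The gap between the two is exactly the log-density of the missing variable.
gap = expected_logpdf(0.5, 0.0, 0.0, 0.0) - reported_logpdf(0.5, 0.0, 0.0, 0.0)
```

With the term dropped, the joint log-density no longer depends on `a` at all, so any sampler relying on it would treat `a` as having a flat density.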

rlouf commented 3 years ago

Thanks for finding this! I just pushed a simple fix, should work now. How's Statistical Rethinking coming along?

elanmart commented 3 years ago

Thanks for the quick fix! Slower than I expected :( Maybe I'll be able to focus more on it soon, I've just started Chapter 5, need to push Chapter 4 code.

I love the simplicity of mcx syntax so far!

Here's another issue I found with logpdf, should I open another ticket :grinning: ?

import mcx
from mcx import distributions as dist

@mcx.model
def model(x, xbar):
    beta <~ dist.Normal(0, 1)
    mu = beta * (x - xbar)
    yhat <~ dist.Normal(mu, 1)

    return yhat

print(model.logpdf_src)

produces

def model_logpdf(x, xbar, beta, yhat):
    logpdf_model_beta = dist.Normal(0, 1).logpdf_sum(beta)
    mu = beta * x - xbar
    logpdf_model_yhat = dist.Normal(mu, 1).logpdf_sum(yhat)
    logpdf = logpdf_model_beta + logpdf_model_yhat
    return logpdf

So beta * (x - xbar) changes to beta * x - xbar because the parentheses are discarded. But I think I saw a warning about this somewhere in the docs...?
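For illustration, here is a small sketch of why the dropped parentheses matter, and how a source generator can keep them. It uses only Python's standard `ast` module (`ast.unparse` needs Python 3.9+). This is an assumption-laden stand-in, not a description of how MCX generates its code; the variable names and sample values are made up.

```python
import ast

# Parse the expression from the model and turn the syntax tree back into source.
# ast.unparse re-inserts parentheses wherever operator precedence requires them.
source = "beta * (x - xbar)"
tree = ast.parse(source, mode="eval")
regenerated = ast.unparse(tree)

# Arbitrary sample values showing that the two expressions are not equivalent.
beta, x, xbar = 2.0, 3.0, 1.0
intended = beta * (x - xbar)   # what the model specifies
flattened = beta * x - xbar    # what the generated logpdf computes
```

Because the parentheses are implicit in the tree structure rather than stored as tokens, a generator that prints operands left to right without checking precedence yields the flattened form.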

rlouf commented 3 years ago

There is a warning in the docs but that constraint is actually not justified anymore. I'll see what I can do!

rlouf commented 3 years ago

This was, again, an easy fix. Thank you for spotting and reporting! By the end of the book MCX should be very robust :)