blackjax-devs / blackjax

BlackJAX is a Bayesian Inference library designed for ease of use, speed and modularity.
https://blackjax-devs.github.io/blackjax/
Apache License 2.0

Difference in behavior between pip and github blackjax versions on example code #151

Closed: bbbales2 closed this issue 3 years ago

bbbales2 commented 3 years ago

Bug Description

I'm running the example code from the README. The version of BlackJAX on PyPI (pip install blackjax) does not work for me, but the version from GitHub (installed with git clone and pip install .) does.

I don't really know what goes into packaging things for pip or what the difference might be, so feel free to close this if it's just a problem with my system!

Thanks for the nice package!

Steps/Code to Reproduce

This is the code I'm running (the same as the README example, with a plot added):

import jax
import jax.numpy as jnp
import jax.scipy.stats as stats
import numpy as np
import pandas
import plotnine

import blackjax.nuts as nuts

observed = np.random.normal(10, 20, size=1_000)
def logprob_fn(x):
    # Log-likelihood of the data under a normal with the given loc and scale
    logpdf = stats.norm.logpdf(observed, x["loc"], x["scale"])
    return jnp.sum(logpdf)

# Build the kernel
step_size = 1e-3
inverse_mass_matrix = jnp.array([1., 1.])
kernel = nuts.kernel(logprob_fn, step_size, inverse_mass_matrix)
kernel = jax.jit(kernel)  # try without to see the speedup

# Initialize the state
initial_position = {"loc": 1., "scale": 2.}
state = nuts.new_state(initial_position, logprob_fn)

# Iterate
rng_key = jax.random.PRNGKey(0)
positions = []
for _ in range(1_000):
    # Split off a fresh key for each step rather than reusing the key that gets split next
    rng_key, step_key = jax.random.split(rng_key)
    state, _ = kernel(step_key, state)
    positions.append(state.position["loc"])

df = pandas.DataFrame({"draw": range(len(positions)), "x": np.array(positions)})

(
    plotnine.ggplot(df)
    + plotnine.geom_point(plotnine.aes("draw", "x"))
).draw(show=True)

Expected Results

Here's what I get with Github blackjax:

[Image from_github: plot of the loc draws against draw number, GitHub install]

Actual Results

Here's what I get with pip blackjax:

[Image from_pip: plot of the loc draws against draw number, pip install]

I repeated this multiple times with different seeds.

Versions

This is the version the GitHub install reports:

>>> import blackjax; print("BlackJAX", blackjax.__version__)
BlackJAX 0.2.1

This is what the pip install reports, along with the other library versions:

>>> import blackjax; print("BlackJAX", blackjax.__version__)
BlackJAX 0.2.1
>>> import sys; print("Python", sys.version)
Python 3.9.7 (default, Oct 10 2021, 15:13:22) 
[GCC 11.1.0]
>>> import jax; print("Jax", jax.__version__)
Jax 0.2.25
>>> import jaxlib; print("Jaxlib", jaxlib.__version__)
Jaxlib 0.1.74
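
Because both installs report BlackJAX 0.2.1, the version string alone cannot tell them apart. As a rough diagnostic, one can check whether pip recorded a direct (URL or local directory) install. The sketch below assumes a recent pip that writes the PEP 610 direct_url.json file for installs such as pip install .; installs from an index like PyPI do not get that file. install_source is a hypothetical helper, not part of BlackJAX, and other installers (for example conda) may not write the file at all.

```python
import json
from importlib import metadata


def install_source(dist_name):
    """Hypothetical helper: describe where a distribution was installed from.

    Returns None if the distribution is not installed. Relies on the PEP 610
    ``direct_url.json`` file, which recent pip versions write for installs
    from a URL or a local directory, but not for installs from an index.
    """
    try:
        dist = metadata.distribution(dist_name)
    except metadata.PackageNotFoundError:
        return None
    raw = dist.read_text("direct_url.json")  # None when the file is absent
    if raw is None:
        return "index (e.g. PyPI)"
    info = json.loads(raw)
    return "direct: " + info.get("url", "<unknown url>")


# Prints None if blackjax is not installed in the current environment.
print(install_source("blackjax"))
```

For a pip install . from a clone, the reported URL should point at the local checkout directory.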

rlouf commented 3 years ago

Thank you for the great bug report! This works with the latest commit but not with the pip version because the README example is out of sync with the release on PyPI. Since the 0.2.1 release, we ask users to provide a log-probability function instead of a potential function; the 0.2.1 package on pip still expects the potential, that is, the negative log-probability. If you return -jnp.sum(logpdf) instead of jnp.sum(logpdf), you should get the correct results with the pip version.

I'll cut a new release in the next few days.
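
To make the sign convention above concrete, here is a minimal sketch built on the issue's logprob_fn: the potential is the negated log-probability, and its gradient is the negated gradient. as_potential is a hypothetical helper for illustration, not a BlackJAX function, and the synthetic data uses a fixed seed only so the example is repeatable.

```python
import jax
import jax.numpy as jnp
import jax.scipy.stats as stats
import numpy as np

# Synthetic data as in the issue's reproducer, with a fixed seed for repeatability.
observed = np.random.default_rng(0).normal(10, 20, size=1_000)


def logprob_fn(x):
    """Log-probability of the data (the convention the README example uses)."""
    logpdf = stats.norm.logpdf(observed, x["loc"], x["scale"])
    return jnp.sum(logpdf)


def as_potential(logprob):
    """Hypothetical helper: wrap a log-probability as a potential (its negative)."""

    def potential_fn(x):
        return -logprob(x)

    return potential_fn


potential_fn = as_potential(logprob_fn)

position = {"loc": 1.0, "scale": 2.0}
print(logprob_fn(position), potential_fn(position))

# The two gradients point in opposite directions, so a gradient-based sampler
# such as NUTS that is handed the wrong sign is steered away from, rather than
# toward, the high-probability region.
grad_logprob = jax.grad(logprob_fn)(position)
grad_potential = jax.grad(potential_fn)(position)
print(grad_logprob["loc"], grad_potential["loc"])
```

Wrapping logprob_fn rather than editing it keeps a single definition of the model: it can be passed as-is under the README's log-probability convention, or wrapped for the 0.2.1 release.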

bbbales2 commented 3 years ago

Ah, makes sense, thanks. Feel free to close whenever; I'll use the GitHub clone for now!