blackjax-devs / blackjax

BlackJAX is a Bayesian Inference library designed for ease of use, speed and modularity.
https://blackjax-devs.github.io/blackjax/
Apache License 2.0

NUTS performance concerns on GPU #597

Open PaulScemama opened 11 months ago

PaulScemama commented 11 months ago

Describe the issue as clearly as possible:

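On a trivial example (the one from quickstart.md), I'm seeing what looks like a bug with the NUTS sampler when running on a GPU: it is far slower on the GPU than on the CPU, and the run time jumps sharply between 200 and 300 steps.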

When I run the script (copied below) on a GPU for 200 steps, I get:

Jax sees these devices: [gpu(id=0)]
Starting to run nuts for 200 steps
Nuts took 0.050431712468465166 minutes

When I run the script on a GPU for 300 steps, I get:

Jax sees these devices: [gpu(id=0)]
Starting to run nuts for 300 steps
Nuts took 0.8048396507898966 minutes

When I run the script on a GPU for 500 steps, I get:

Jax sees these devices: [gpu(id=0)]
Starting to run nuts for 500 steps
Nuts took 1.2937044938405355 minutes

When I run the script on a CPU for 1000 steps, I get:

Jax sees these devices: [CpuDevice(id=0)]
Starting to run nuts for 1000 steps
Nuts took 0.06121724049250285 minutes

Steps/code to reproduce the bug:

import time
from datetime import date

import numpy as np

import jax
import jax.numpy as jnp
import jax.scipy.stats as stats

import blackjax

# Seed the PRNG with today's date (YYYYMMDD as an integer).
rng_key = jax.random.key(int(date.today().strftime("%Y%m%d")))

# Synthetic data: 1,000 draws from Normal(loc=10, scale=20).
loc, scale = 10, 20
observed = np.random.normal(loc, scale, size=1_000)

def logdensity_fn(loc, log_scale, observed=observed):
    """Univariate Normal log-likelihood, with the scale parameterized on the log scale."""
    scale = jnp.exp(log_scale)
    logpdf = stats.norm.logpdf(observed, loc, scale)
    return jnp.sum(logpdf)

# BlackJAX expects a log-density of a single (pytree) position argument.
logdensity = lambda x: logdensity_fn(**x)

def inference_loop(rng_key, kernel, initial_state, num_samples):
    @jax.jit
    def one_step(state, rng_key):
        state, _ = kernel(rng_key, state)
        return state, state

    # One PRNG key per step; lax.scan threads the sampler state through the chain.
    keys = jax.random.split(rng_key, num_samples)
    _, states = jax.lax.scan(one_step, initial_state, keys)

    return states

# Hand-picked step size and inverse mass matrix (no warmup adaptation).
inv_mass_matrix = np.array([0.5, 0.01])
step_size = 1e-3

nuts = blackjax.nuts(logdensity, step_size, inv_mass_matrix)

initial_position = {"loc": 1.0, "log_scale": 1.0}
initial_state = nuts.init(initial_position)

rng_key, sample_key = jax.random.split(rng_key)

# TIMING NUTS
start = time.time()
num_steps = 500
print(f"Jax sees these devices: {jax.devices()}")
print(f"Starting to run nuts for {num_steps} steps")
states = inference_loop(sample_key, nuts.step, initial_state, num_steps)
end = time.time()
print(f"Nuts took {(end-start)/60} minutes")

Expected result:

A shorter run time on the GPU. I'm not very familiar with the CPU/GPU trade-offs for MCMC samplers like NUTS. Maybe the CPU is simply much faster here? If so, a warning would be helpful: consider a user who runs variational inference on a GPU, then switches to NUTS and finds that it takes forever.

Error message:

No response

Blackjax/JAX/jaxlib/Python version information:

BlackJAX 0.1.dev454+g164a4dd
Python 3.9.17 (main, Nov 28 2023, 23:51:11) 
[GCC 7.5.0]
Jax 0.4.16
Jaxlib 0.4.16

Context for the issue:

No response

junpenglao commented 11 months ago

This is not my experience; what is your environment?

Also, an important note on benchmarking JAX: https://jax.readthedocs.io/en/latest/async_dispatch.html
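Because of asynchronous dispatch, a JAX call can return before the device has finished, and the first call of a jitted function also pays for tracing and XLA compilation. Wall-clock timings should therefore block on the result and report the first call separately. The following is a minimal sketch of such a timing helper, using only `jax.block_until_ready` and `time.perf_counter`. The helper name `timed` and the toy `cumulative_sum_of_squares` workload are hypothetical stand-ins, not part of BlackJAX.

```python
import time

import jax
import jax.numpy as jnp


def timed(fn, *args, repeats=3):
    """Return (first_call_s, best_later_call_s) for a JAX callable (hypothetical helper).

    The first call includes tracing and XLA compilation; later calls reuse the
    compiled executable. jax.block_until_ready waits for the asynchronously
    dispatched computation to finish, so the timer measures execution rather
    than just dispatch.
    """
    start = time.perf_counter()
    jax.block_until_ready(fn(*args))
    first = time.perf_counter() - start

    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        jax.block_until_ready(fn(*args))
        best = min(best, time.perf_counter() - start)
    return first, best


@jax.jit
def cumulative_sum_of_squares(x):
    # Hypothetical stand-in for a sequential sampler loop: lax.scan threads a
    # running total (the carry) through x and records it at every step.
    def step(total, xi):
        total = total + xi**2
        return total, total

    _, running_totals = jax.lax.scan(step, jnp.zeros((), dtype=x.dtype), x)
    return running_totals


x = jnp.arange(1_000, dtype=jnp.float32)
first, steady = timed(cumulative_sum_of_squares, x)
print(f"first call (trace + compile + run): {first:.4f} s")
print(f"best later call (run only):          {steady:.6f} s")
```

To time the reproduction script in the same way, one option (an assumption, not something tested in this thread) is to jit the loop with the kernel and the number of samples marked static, for example `loop = jax.jit(inference_loop, static_argnums=(1, 3))`, and then call `timed(loop, sample_key, nuts.step, initial_state, num_steps)`.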

PaulScemama commented 11 months ago

@junpenglao I'll check back on this later this weekend; it's possible that it's an environment problem. I'll get back to you then.

DanWaxman commented 6 months ago

This is replicable on Colab, so I don't think it's an environment issue.

Output for CPU:

Jax sees these devices: [CpuDevice(id=0)]
Starting to run nuts for 500 steps
NUTS Call took 0.12702747980753581 minutes

Output for GPU:

Jax sees these devices: [cuda(id=0)]
Starting to run nuts for 500 steps
NUTS Call took 0.7922836542129517 minutes

I think this is more or less expected behavior, though, when the problem is small and doesn't involve the kinds of operations GPUs are particularly good at, such as large batched or dense array computations. There was a similar discussion for NumPyro here, and the takeaway was that JAX is particularly efficient on CPU and that GPU acceleration only makes sense for certain problems.
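One way to check the "problem is small" explanation is to time the issue's log-density alone, once it is compiled, on each available device for a small and a large dataset. The following is a minimal sketch that does this. Its assumption is that the per-evaluation cost of the log-density is what matters, so it does not run NUTS. The names `available_devices`, `normal_loglik` and `steady_state_seconds`, as well as the dataset sizes, are hypothetical choices for illustration.

```python
import time

import jax
import jax.numpy as jnp
import jax.scipy.stats as stats


def available_devices():
    """The first CPU device, plus the first GPU device if JAX can see one."""
    devices = jax.devices("cpu")[:1]
    try:
        devices += jax.devices("gpu")[:1]
    except RuntimeError:
        pass  # no GPU backend installed or visible
    return devices


@jax.jit
def normal_loglik(loc, log_scale, observed):
    # Same model as logdensity_fn in the issue.
    return jnp.sum(stats.norm.logpdf(observed, loc, jnp.exp(log_scale)))


def steady_state_seconds(device, n, repeats=20):
    """Average seconds per evaluation after compilation, with all inputs on `device`."""
    observed = jax.device_put(jnp.linspace(-30.0, 50.0, n), device)
    loc = jax.device_put(jnp.float32(1.0), device)
    log_scale = jax.device_put(jnp.float32(1.0), device)
    jax.block_until_ready(normal_loglik(loc, log_scale, observed))  # compile + warm up
    start = time.perf_counter()
    for _ in range(repeats):
        out = normal_loglik(loc, log_scale, observed)
    jax.block_until_ready(out)
    return (time.perf_counter() - start) / repeats


for device in available_devices():
    for n in (1_000, 1_000_000):
        micros = steady_state_seconds(device, n) * 1e6
        print(f"{device.platform}: n={n:>10,}  {micros:9.1f} us/eval")
```

The sketch times a single evaluation rather than a full sampler so that the comparison isolates how the per-call cost scales with data size on each backend. With only 1,000 observations, fixed per-call overheads are likely to dominate on either device.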

gil2rok commented 5 months ago

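Note that NUTS is control-flow heavy (each iteration keeps extending its trajectory in a data-dependent loop until the U-turn criterion fires), which makes it hard to run fast on a GPU.

See the ChEES algorithm, implemented in BlackJAX, for a NUTS-like sampler that avoids this problem.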
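The structural property behind ChEES-style samplers can be sketched in pure JAX. Every chain takes the same fixed number of leapfrog steps, so `jax.vmap` can batch all chains into one computation with no per-chain control flow. By contrast, under `vmap` a data-dependent `lax.while_loop`, such as a NUTS trajectory, keeps running until the slowest chain stops. This sketch is not BlackJAX's ChEES implementation and performs no adaptation. The identity mass matrix, the step size, the 10 leapfrog steps and the names `hmc_step`/`run_chains` are all hypothetical choices for illustration.

```python
from functools import partial

import jax
import jax.numpy as jnp
import jax.scipy.stats as stats


def logdensity(theta, observed):
    # theta = [loc, log_scale], the same model as in the issue.
    loc, log_scale = theta[0], theta[1]
    return jnp.sum(stats.norm.logpdf(observed, loc, jnp.exp(log_scale)))


def hmc_step(key, theta, observed, step_size, num_leapfrog):
    """One HMC transition with a FIXED number of leapfrog steps (identity mass matrix)."""
    grad = jax.grad(logdensity)
    key_momentum, key_accept = jax.random.split(key)
    p0 = jax.random.normal(key_momentum, theta.shape)

    def leapfrog(_, carry):
        q, p = carry
        p = p + 0.5 * step_size * grad(q, observed)
        q = q + step_size * p
        p = p + 0.5 * step_size * grad(q, observed)
        return q, p

    # Static trip count: every chain does exactly the same amount of work.
    q1, p1 = jax.lax.fori_loop(0, num_leapfrog, leapfrog, (theta, p0))
    log_accept = (logdensity(q1, observed) - 0.5 * p1 @ p1) - (
        logdensity(theta, observed) - 0.5 * p0 @ p0
    )
    # Metropolis correction; a NaN log_accept compares False, so theta is kept.
    accept = jnp.log(jax.random.uniform(key_accept)) < log_accept
    return jnp.where(accept, q1, theta)


@partial(jax.jit, static_argnames=("num_steps", "num_leapfrog"))
def run_chains(key, thetas, observed, step_size, num_steps, num_leapfrog):
    """Advance all chains in lock-step; vmap batches them into one computation."""
    batched_step = jax.vmap(
        lambda k, th: hmc_step(k, th, observed, step_size, num_leapfrog)
    )

    def one_iteration(thetas, key):
        keys = jax.random.split(key, thetas.shape[0])
        thetas = batched_step(keys, thetas)
        return thetas, thetas

    keys = jax.random.split(key, num_steps)
    _, trace = jax.lax.scan(one_iteration, thetas, keys)
    return trace  # shape (num_steps, num_chains, 2)


key_data, key_run = jax.random.split(jax.random.key(0))
observed = 10.0 + 20.0 * jax.random.normal(key_data, (1_000,))
init = jnp.ones((64, 2))  # 64 chains, all starting at loc=1, log_scale=1
trace = run_chains(key_run, init, observed, 1e-3, num_steps=200, num_leapfrog=10)
print(trace.shape)  # (200, 64, 2)
```

Keeping the leapfrog count identical across chains is the design choice that matters here. It trades NUTS's per-chain trajectory lengths for work that maps cleanly onto wide, SIMD-style hardware.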