pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

Improve the speed of gamma sampler #955

Closed gustavehug closed 3 years ago

gustavehug commented 3 years ago

First of all, thanks for this awesome library! I have been going through the Rethinking translation (along with the book), and it's been great.

I have developed a Bayesian linear regression model that I fit on 40k observations. I ran a single chain of 500 samples (+ 500 warmup) on the CPU using NUTS, which takes about 1:30. This is great.

It takes approximately 1 ms to make a prediction with (jitted) Predictive. This is also working very well for my use-case.

However, when I make batch predictions for the whole dataset (40k observations), Predictive takes about 10 seconds, not including the compile-time.

Now, I would like to serve the trained model with an API, and minimize latency.

Is it possible to speed up the jitted predictive function for tall data?

AndreRoehrig commented 3 years ago

Hi @gustavehug, could you explain what a 'jitted Predictive' is and how I can learn more about it? I searched the documentation and tutorials without success. I am nowhere near 1 ms for even a single prediction, and this would be very helpful for my use case. I hope someone else can help you with the batch prediction issue.

gustavehug commented 3 years ago

Hi @AndreRoehrig, it is a way of just-in-time compiling the function to XLA with JAX.

The way I use it:

import jax
from numpyro.infer import Predictive

def model(...):
    ...

# train the model here and collect posterior samples into `samples`
...

# Predictive is callable, so the whole prediction step can be jit-compiled
predict_fn = jax.jit(Predictive(model, samples))

In my case, I first forgot to jit the function, and it took 1 second for each run. Then I added it, and the time dropped by a factor of about 1000 for all runs subsequent to the first one; the first run is longer (approximately 1 second) since the function has to be compiled.

Bear in mind that this is a linear regression, though; more complex models might require more time.
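Here is a minimal self-contained version of the pattern with a toy model (the model and data below are purely illustrative, not the actual ones from this thread):

import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro.infer import MCMC, NUTS, Predictive

def model(x, y=None):
    a = numpyro.sample("a", dist.Normal(0.0, 1.0))
    b = numpyro.sample("b", dist.Normal(0.0, 1.0))
    sigma = numpyro.sample("sigma", dist.Exponential(1.0))
    numpyro.sample("y", dist.Normal(a + b * x, sigma), obs=y)

x = jnp.linspace(0.0, 1.0, 100)
y = 2.0 * x + 0.1 * random.normal(random.PRNGKey(1), (100,))

mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=500)
mcmc.run(random.PRNGKey(0), x, y)
samples = mcmc.get_samples()

# the first call compiles the prediction step to XLA; later calls reuse it
predict_fn = jax.jit(Predictive(model, samples))
preds = predict_fn(random.PRNGKey(2), x)["y"]  # shape (500, 100)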

martinjankowiak commented 3 years ago

knowing nothing about your model it's hard to say if there's additional room for speed-ups. perhaps not.

one thing you could do is use the thinning argument in MCMC, e.g. thinning=2. adjacent MCMC samples are correlated so computationally it can make sense to throw some samples out.
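a minimal sketch of what this would look like (model and data names here are placeholders, not from your code):

from jax import random
from numpyro.infer import MCMC, NUTS

# thinning=2 keeps every other post-warmup sample, so Predictive later has
# roughly half as many posterior samples to push through the model
mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=500, thinning=2)
mcmc.run(random.PRNGKey(0), x, y)
samples = mcmc.get_samples()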

gustavehug commented 3 years ago

@martinjankowiak Thank you for your answer, this does improve performance. The complexity seems linear in the number of samples with little additional overhead (other than the compile time), which makes sense.

So I guess you recommend thinning=2? I have tried 2 and 3, but where should I stop in order to retain an "acceptable" prediction? I have 500 samples to start with, which is not huge, I guess.

I was also wondering if one could trade memory for speed, since memory usage is quite low in my case. I tried to modify the _predictive and soft_vmap bits, but did not succeed since I am not very skilled with JAX or NumPyro.
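For reference, Predictive also has a parallel flag that vectorizes over the posterior samples with jax.vmap instead of mapping over them in chunks; whether that actually trades memory for speed in this case is untested, but it would look like:

# uses jax.vmap over posterior samples (more memory, possibly faster);
# assumes `model` and `samples` are defined as above
predict_fn = jax.jit(Predictive(model, samples, parallel=True))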

The model has the following form:

import numpyro as ny
import numpyro.distributions as dist
import jax.numpy as jnp

def reparametrize_gamma(mean, std):
    var = std ** 2
    alpha = mean ** 2 / var
    beta = mean / var
    return alpha, beta

def Gamma_2(mean, std):
    return dist.Gamma(*reparametrize_gamma(mean, std))

def model(n, m, i, j, x, y, target=None):
    a_1 = ny.sample("a_1", Gamma_2(100, 25).expand([n]))
    a_2 = ny.sample("a_2", Gamma_2(100, 10).expand([m]))
    b = ny.sample("b", Gamma_2(40, 10))
    c = ny.sample("c", Gamma_2(60, 10))
    mu_raw = a_1[i] + a_2[j] + b * x + c * y

    # gamma needs a positive mean
    mu = jnp.clip(mu_raw, 0.1)
    sigma = ny.sample("sigma", dist.Exponential(1.0))
    ny.sample("target", Gamma_2(mu, sigma), obs=target)

I use Gammas because all the parameters are positive, and for the observation as well, since I expect a somewhat fat tail.

martinjankowiak commented 3 years ago

i'm not making specific recommendations because i don't know your requirements. if you care much more about test time than train time then you can draw far more samples (e.g. 10k) and then thin more aggressively (e.g. thinning=100). it's impossible for me to say what trade-offs are best for your use case.

i suspect it won't be easy to get additional speed-ups (apart perhaps from small ones).

fehiepsi commented 3 years ago

For a gamma likelihood, drawing 500 x 40000 samples takes about 10 s. So to speed up, you can reduce the number of posterior samples to e.g. 50.

%%time
from jax import random
import numpyro.distributions as dist
d = dist.Gamma(1., 1.)
x = d.sample(random.PRNGKey(0), (500, 40000)).copy()
# CPU times: user 12 s, sys: 205 ms, total: 12.2 s
# Wall time: 10.8 s

Alternatively, you can use

import numpyro.contrib.tfp.distributions as tfd
tfd.Gamma

which seems to give a 4x speed-up for the Gamma sampler.
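For example, the helper above could be swapped like this (a sketch; it assumes tensorflow-probability is installed, and TFP's Gamma takes concentration and rate in the same order as the alpha/beta above):

import numpyro.contrib.tfp.distributions as tfd

def Gamma_2_tfp(mean, std):
    # same moment-matching as Gamma_2, but backed by the TFP gamma sampler
    alpha, beta = reparametrize_gamma(mean, std)
    return tfd.Gamma(alpha, beta)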

gustavehug commented 3 years ago

Thank you for your responses. I must mention that I am quite new to probabilistic programming.

@martinjankowiak OK, thank you for these suggestions, I will run some tests with thinning to see if I can accelerate further without losing too much precision. I tried the 10k samples / thinning=100 combination; it is indeed faster and has good enough precision.

@fehiepsi I tried the tfd version, but unfortunately it did not work. MCMC seems to run well, but then the program hangs when compiling the predictive function and never terminates. I also had to enable x64, but even then it did not work.

I tried to jit-compile the Gamma_2 definition above; it seemed to reduce the MCMC run time, but not the prediction time.

I also tried changing the final sample call of my model to a reparameterized LogNormal; it is much faster for prediction (10-20x) and for MCMC. I guess there is inherently more complexity in implementations of the gamma distribution. The problem is that it changes the predictions a lot... I have to check the model performance to see if it is acceptable and fine-tune the model.
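For context, a moment-matched LogNormal substitution can be written like this (a sketch of one way to do it; the exact parameterization used above is not shown in the thread):

import jax.numpy as jnp
import numpyro.distributions as dist

def LogNormal_2(mean, std):
    # solve for the underlying normal's loc/scale so that the LogNormal
    # has the requested mean and standard deviation
    var_log = jnp.log1p((std / mean) ** 2)
    loc = jnp.log(mean) - var_log / 2.0
    return dist.LogNormal(loc, jnp.sqrt(var_log))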

fehiepsi commented 3 years ago

the program hangs when compiling the predictive function and never terminates

Interesting! I tried

import numpyro
import numpyro.contrib.tfp.distributions as tfd
from numpyro.infer import Predictive
from jax import random

def model():
    numpyro.sample("x", tfd.Gamma(1, 1))

x = Predictive(model, {}, num_samples=10)(random.PRNGKey(0))["x"]
print(x.shape)

and got the expected result. It would be nice if you could provide a reproducible snippet, so we can try to isolate the issue.

I tried to optimize the gamma sampler implementation in JAX, but it is still slow. :(

fehiepsi commented 3 years ago

Closing because it is expected that the JAX gamma sampler is slow. As mentioned above, the TFP gamma sampler can be used as an alternative to get a 4x speed-up. Please feel free to reopen this issue if follow-ups are needed.