rlouf / mcx

Express & compile probabilistic programs for performant inference on CPU & GPU. Powered by JAX.
https://rlouf.github.io/mcx
Apache License 2.0

Added dynamic rhat #77

Closed: sidravi1 closed this 3 years ago

sidravi1 commented 3 years ago

Overview

Displays Rhat values in the progress bar. Partially addresses #8.

Details

Here it is in action:

https://user-images.githubusercontent.com/14125957/107138740-18a17a00-68e4-11eb-854f-7ed6cc243858.mov

To do

  1. Probably should add variable names to Rhat
  2. Probably move it to a new line, otherwise it will get messy with a lot of vars
  3. Basic offline Rhat and a warning if it is too high (like PyMC3)
  4. More testing, especially for mvn 🙊
rlouf commented 3 years ago

The code looks great and I really like the result!

  1. Probably should add variable names to Rhat
  2. Probably move it to a new line, otherwise it will get messy with a lot of vars

Since this is a rough indicator, I thought we could display only the "worst" Rhat value among all variables (the one farthest from 1). Other values can be shown in the inference summary. The ideal would be to update a graph with all the Rhat values over time, but that's a project in itself.
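A minimal sketch of that "worst value" rule, assuming per-variable Rhat values are already available as a dict that maps each variable name to one value per scalar component (multivariate variables contribute several). The function name worst_rhat is hypothetical and not part of mcx.

```python
from __future__ import annotations

import numpy as np


def worst_rhat(rhats: dict) -> tuple:
    """Return (variable name, Rhat) for the component farthest from 1.

    Hypothetical helper: `rhats` maps a variable name to a scalar or an
    array of Rhat values (one per component of a multivariate variable).
    """
    best = None
    for name, values in rhats.items():
        values = np.atleast_1d(np.asarray(values, dtype=float))
        # Component of this variable with the largest deviation from 1.
        idx = int(np.argmax(np.abs(values - 1.0)))
        candidate = (abs(values[idx] - 1.0), name, float(values[idx]))
        if best is None or candidate[0] > best[0]:
            best = candidate
    return best[1], best[2]
```

The returned pair could then be formatted as a single progress-bar entry such as f"Rhat({name})={value:.3f}", which stays compact however many variables the model has.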

  3. Basic offline Rhat and a warning if it is too high (like PyMC3)

As discussed, implementing the rank-normalized Rhat for the inference summary would be best. Adding a warning if a value is too high is a good idea, and it is even better if that warning is actionable: what can I do as a modeler with this information?

Where should it be displayed? After the progress bar or do we print (at least part of) the inference summary first?
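For reference, a NumPy sketch of the rank-normalized split-Rhat described by Vehtari et al. (2021): split each chain in half, replace the pooled draws by normal scores of their ranks, compute the classic Rhat on those scores (bulk) and on the folded draws |θ - median(θ)| (tail), and report the larger of the two. It assumes draws for one scalar parameter laid out as (chains, draws). Ties in the ranks are broken by position rather than averaged, which is a simplification. All function names are hypothetical.

```python
from statistics import NormalDist

import numpy as np


def _split_chains(draws):
    """Split each chain in half: (m, n) -> (2m, n // 2); a middle draw is dropped if n is odd."""
    half = draws.shape[1] // 2
    return np.concatenate([draws[:, :half], draws[:, -half:]], axis=0)


def _classic_rhat(chains):
    """Potential scale reduction factor for chains of shape (m, n)."""
    n = chains.shape[1]
    w = chains.var(axis=1, ddof=1).mean()      # within-chain variance W
    b = n * chains.mean(axis=1).var(ddof=1)    # between-chain variance B
    var_plus = (n - 1) / n * w + b / n
    return float(np.sqrt(var_plus / w))


def _rank_normalize(values):
    """Replace pooled values by normal scores of their ranks (ties broken by position)."""
    flat = values.ravel()
    s = flat.size
    ranks = np.empty(s)
    ranks[np.argsort(flat, kind="stable")] = np.arange(1, s + 1)
    inv_cdf = NormalDist().inv_cdf
    # Blom-style offset keeps every argument strictly inside (0, 1).
    scores = np.array([inv_cdf((r - 0.375) / (s + 0.25)) for r in ranks])
    return scores.reshape(values.shape)


def rank_normalized_rhat(draws):
    """Rank-normalized split-Rhat for one scalar parameter; draws has shape (chains, draws)."""
    chains = _split_chains(np.asarray(draws, dtype=float))
    bulk = _classic_rhat(_rank_normalize(chains))
    folded = np.abs(chains - np.median(chains))
    tail = _classic_rhat(_rank_normalize(folded))
    return max(bulk, tail)
```

A warning could then be raised when this value exceeds 1.01, the threshold recommended in that paper.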

  4. More testing, especially for mvn 🙊

Indeed, multivariate random variables are more error-prone :) The best approach is probably to take examples from the paper and check that computing Rhat on those chains gives the expected result.
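One way to set up such a check for multivariate variables, sketched under the assumption that the trace is stored as (chains, draws, *event_shape) (mcx's actual layout may differ): compute split-Rhat independently for each component along the trailing axes, then verify that a well-mixed component gives a value close to 1 while a component with one shifted chain is flagged. split_rhat is a hypothetical helper, and it uses the classic split-Rhat rather than the rank-normalized version.

```python
import numpy as np


def split_rhat(draws):
    """Split-Rhat computed independently for every trailing component.

    Hypothetical helper: draws has shape (chains, samples, *event_shape)
    and the result has shape event_shape.
    """
    draws = np.asarray(draws, dtype=float)
    half = draws.shape[1] // 2
    # Split each chain in half along the sample axis: 2 * chains chains.
    chains = np.concatenate([draws[:, :half], draws[:, -half:]], axis=0)
    n = chains.shape[1]
    b = n * chains.mean(axis=1).var(axis=0, ddof=1)   # between-chain variance
    w = chains.var(axis=1, ddof=1).mean(axis=0)       # within-chain variance
    return np.sqrt(((n - 1) / n * w + b / n) / w)
```

Keeping the chain axis first and the sample axis second means the same reductions work unchanged for scalar and multivariate variables, which is exactly where axis mix-ups tend to hide.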

rlouf commented 3 years ago

Btw, for the sake of making incremental changes, it would be better to address (3) in a separate PR.

rlouf commented 3 years ago

Hey @sidravi1 what's the status on this PR?

sidravi1 commented 3 years ago

Hey Rémi, thanks for the push. I've been too occupied by work and kids these last few weeks (months?). I haven’t worked on this since the refactor. I’ll work on it on Thursday.

On Mon, Apr 12, 2021 at 2:18 AM Rémi Louf wrote:

> Hey @sidravi1 what's the status on this PR?

rlouf commented 3 years ago

No problem, this is open source, not paid work 🙂

sidravi1 commented 3 years ago

Hi @rlouf, I got dynamic Rhat working, though using set_postfix slows sampling down by ~50%.
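Dynamic Rhat does not require recomputing statistics over the full trace at every step. A minimal sketch, assuming one scalar variable and one new draw per chain per step (OnlineRhat is a hypothetical name, and this is not necessarily what the PR does): keep Welford running means and sums of squared deviations per chain, and derive the classic (non-split) Rhat from them, so each step costs O(number of chains).

```python
import numpy as np


class OnlineRhat:
    """Hypothetical running (non-split) Rhat for one scalar variable across chains."""

    def __init__(self, num_chains):
        self.count = 0
        self.mean = np.zeros(num_chains)   # running mean per chain
        self.m2 = np.zeros(num_chains)     # running sum of squared deviations per chain

    def update(self, draws):
        """Add one new draw per chain (shape (num_chains,)) with Welford's update."""
        draws = np.asarray(draws, dtype=float)
        self.count += 1
        delta = draws - self.mean
        self.mean += delta / self.count
        self.m2 += delta * (draws - self.mean)

    def value(self):
        """Current Rhat estimate, or nan before two draws per chain are available."""
        n = self.count
        if n < 2:
            return float("nan")
        w = np.mean(self.m2 / (n - 1))        # within-chain variance
        b = n * np.var(self.mean, ddof=1)     # between-chain variance
        return float(np.sqrt(((n - 1) / n * w + b / n) / w))
```

Inside the sampling loop one would call acc.update(draws) and then, for example, pbar.set_postfix(rhat=f"{acc.value():.3f}"). The split and rank-normalized variants need the whole trace (halves shift as chains grow, and ranks are over all pooled draws), which is one reason to keep them for the offline summary.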

I also tested it with this multivariate normal model:

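import jax
import jax.numpy as jnp
import mcx
import mcx.distributions as dist
from mcx.inference import HMC

# x_data_mvn (shape (n, 2)) and y_data_mvn (shape (n,)) are assumed to be defined already.
rng_key = jax.random.PRNGKey(0)

@mcx.model
def linear_regression_mvn(x, lmbda=1.):
    sigma <~ dist.Exponential(lmbda)
    sigma2 <~ dist.Exponential(lmbda)
    rho <~ dist.Uniform(-1, 1)
    # 2x2 covariance built from two scales and a correlation, so x must have 2 columns
    cov = jnp.array([[sigma**2, rho*sigma*sigma2], [rho*sigma*sigma2, sigma2**2]])
    coeffs_init = jnp.ones(x.shape[-1])
    coeffs <~ dist.MvNormal(coeffs_init, cov)
    y = jnp.dot(x, coeffs.T)
    preds <~ dist.Normal(y, sigma)
    return preds

sampler = mcx.sampler(
    rng_key,
    linear_regression_mvn,
    (x_data_mvn,),
    {'preds': y_data_mvn},
    HMC(10),
)
posterior = sampler.run()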

If all looks good, I'll clean up the commit history before the merge.

rlouf commented 3 years ago

Great! Could you try using the mininterval flag and setting it to something like .5s or 1s and report the slowdown then? (https://github.com/tqdm/tqdm/blob/master/tqdm/std.py#L873-L880)
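A sketch of how both the display and the computation could be throttled, assuming a hypothetical step_fn that advances the sampler and a hypothetical rhat_fn that returns the current Rhat. Note that tqdm's set_postfix redraws immediately by default (refresh=True), which bypasses mininterval; passing refresh=False leaves redrawing to the next update() call, which does respect mininterval. Recomputing Rhat only every `every` steps addresses the cost of the statistic itself.

```python
from tqdm import tqdm


def run_with_progress(step_fn, num_steps, rhat_fn, every=50, mininterval=0.5):
    """Hypothetical sampling loop with a throttled Rhat postfix.

    step_fn(i) advances the sampler by one step; rhat_fn() returns the
    current Rhat. Returns how many times rhat_fn was evaluated.
    """
    evaluations = 0
    with tqdm(total=num_steps, mininterval=mininterval) as pbar:
        for i in range(num_steps):
            step_fn(i)
            if i % every == 0:
                # Only pay for the Rhat computation every `every` steps.
                evaluations += 1
                # refresh=False: store the postfix and let update() redraw it,
                # which respects mininterval.
                pbar.set_postfix(rhat=f"{rhat_fn():.3f}", refresh=False)
            pbar.update(1)
    return evaluations
```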

sidravi1 commented 3 years ago

mininterval doesn't help much: the bottleneck is the Rhat update itself, not tqdm. What are your thoughts on making it optional? We could also use a pattern where you register callbacks to compute other online statistics.

I should also point out that the slowdown is most noticeable when the model is simple (the linear example); when the model is more complex (the multivariate example), the relative slowdown is much smaller.

rlouf commented 3 years ago

It's all a question of user interface. The original idea was that, since we spend 99% of our time debugging models, the sample function would be interactive by default: it displays as much information as possible so we can see when issues arise, and it can be interrupted at any time to diagnose them. compile=True would show nothing but the progress bar and would correspond to situations where we need inference to be as fast as possible; we could also define a fast_sample function for that purpose.

Now, if simple models take a few extra seconds but large models are barely affected, it is not really a problem.
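The "interrupt at any time" behaviour can be sketched as a loop that catches KeyboardInterrupt (raised by Ctrl-C) and returns the draws collected so far instead of discarding them. interactive_sample and step_fn are hypothetical names. This only works when the loop steps in Python; a fully compiled loop (the compile=True case) cannot hand back partial results this way.

```python
import numpy as np


def interactive_sample(step_fn, state, num_steps):
    """Hypothetical loop that returns the partial trace when interrupted with Ctrl-C.

    step_fn(state) -> (new_state, draw) performs one sampling step.
    """
    draws = []
    try:
        for _ in range(num_steps):
            state, draw = step_fn(state)
            draws.append(draw)
    except KeyboardInterrupt:
        # Stop sampling but keep everything collected so far for diagnosis.
        print(f"Interrupted after {len(draws)} draws; returning the partial trace.")
    return state, np.asarray(draws)
```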

Nevertheless, I like your idea of designing these online metrics as callbacks. This would allow users to customize the metrics being displayed and/or track their own metrics. It is also cleaner from a code perspective. This way, sample would be called with callbacks=[rhat, ess, divergences] by default.
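A minimal sketch of the callback idea, under assumptions not taken from mcx: a callback is any function that takes the trace collected so far for one scalar variable, shaped (chains, draws), and returns a dict of named metrics; the sampler merges the dicts and hands them to the progress bar. rhat_callback, mean_callback and collect_metrics are hypothetical names.

```python
from typing import Callable, Dict, List

import numpy as np

# Hypothetical callback type: trace so far, shape (chains, draws) -> named metrics.
Callback = Callable[[np.ndarray], Dict[str, float]]


def rhat_callback(trace: np.ndarray) -> Dict[str, float]:
    """Classic (non-split) Rhat; needs at least 2 chains and 2 draws."""
    chains, n = trace.shape
    if chains < 2 or n < 2:
        return {}
    w = trace.var(axis=1, ddof=1).mean()       # within-chain variance
    b = n * trace.mean(axis=1).var(ddof=1)     # between-chain variance
    return {"rhat": float(np.sqrt(((n - 1) / n * w + b / n) / w))}


def mean_callback(trace: np.ndarray) -> Dict[str, float]:
    """Posterior mean over all chains and draws."""
    return {"mean": float(trace.mean())}


def collect_metrics(trace: np.ndarray, callbacks: List[Callback]) -> Dict[str, float]:
    """Run every callback and merge the results (later keys overwrite earlier ones)."""
    metrics: Dict[str, float] = {}
    for callback in callbacks:
        metrics.update(callback(trace))
    return metrics
```

Since tqdm's set_postfix accepts a dict, the merged metrics could be passed straight to it, for example pbar.set_postfix(collect_metrics(trace, callbacks), refresh=False), and users could append their own functions to the list.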

PS: is the multiple progress bar a bug?

sidravi1 commented 3 years ago

OK, makes sense.

Should we merge this in and switch to the callback design pattern in another PR (when we implement ESS or divergences), or do you want me to update this one?

The multiple progress bars come from the %%timeit cell magic at the top of the cell, which runs the cell several times to get the average run time.

rlouf commented 3 years ago

Would you mind updating this one?

sidravi1 commented 3 years ago

Yep! Can do :)

sidravi1 commented 3 years ago

Thanks for your patience, @rlouf. I've made those changes. Let me know what you think.

sidravi1 commented 3 years ago

@rlouf, thanks for reviewing. I've made that one docstring fix and squashed all the commits.

rlouf commented 3 years ago

Great work, the code was really clean and self-explanatory!