rlouf / mcx

Express & compile probabilistic programs for performant inference on CPU & GPU. Powered by JAX.
https://rlouf.github.io/mcx
Apache License 2.0

Bugfix for Poisson distribution #105

Closed mkretsch327 closed 3 years ago

mkretsch327 commented 3 years ago

@rlouf - just pushing the changes to the Poisson distribution and its tests from last week; it looks like a subsequent PR inadvertently overwrote them.

I'm pushing the same changes from earlier, just rebased on top of the current master branch.

rlouf commented 3 years ago

Thank you for spotting this; I did not pay attention when I merged the other branch. Merging this one.

How did you notice? Are you using mcx?

mkretsch327 commented 3 years ago

Yes, I'm using it for my own educational purposes :). Specifically, I was playing around with the Bayesian switch-point example from TensorFlow Probability. Here's a link to my notebook. I had put it on Google Colab too, to see what speedups I could get with a GPU. I'd be happy to share/add this here at some point too if you think it's useful for others!
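
For context, here's roughly what the model looks like. This is a minimal sketch using the `<~` syntax from the README; the priors and distribution arguments are illustrative, and I'm assuming Uniform, Exponential, and Poisson are all available in mcx.distributions:

import jax.numpy as jnp
import mcx
import mcx.distributions as dist

@mcx.model
def disaster_rate_model(years):
    # Year at which the disaster rate switches (illustrative prior)
    switchpoint <~ dist.Uniform(jnp.min(years), jnp.max(years))
    early_rate <~ dist.Exponential(1.0)
    late_rate <~ dist.Exponential(1.0)
    rate = jnp.where(years < switchpoint, early_rate, late_rate)
    disasters <~ dist.Poisson(rate)
    return disasters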

rlouf commented 3 years ago

Nice, it's great to see someone using it :) A couple of remarks:

  1. You cannot currently use MCX with a TPU, nor is it necessary for small models. I would need to add a sampler that uses JAX's pmap, which is not hard to do, but I haven't had the need so far (see the sketch after this list).
  2. For GPUs to be useful you will need a HUGE dataset. You can take the linear regression example in the README and increase the size of the dataset; it is quite impressive when it kicks in.
  3. I see that you needed to add a #args comment, which is a design smell to me. In the next UI version you won't need this and the UI will look like:
model = disaster_rate_model(years.astype(np.float32)).condition(observations)
sampler = mcx.sampler(rng_key, model, hmc_kernel)
If there's anything else that doesn't feel natural to you, please let me know; I'm happy to get feedback.

mkretsch327 commented 3 years ago

On the TPU note, I saw this too: there was no change in timing, but I wasn't sure whether it was the code or the free version of Colab. With a GPU enabled I was able to see (some) speedups, but they were somewhat small. I bumped the number of chains to a ludicrous number that I'd never think of trying in PyMC3 (like 250), and the slowdown was pretty small. And I like that UI example; .condition(obs) is intuitive, IMO.

rlouf commented 3 years ago

Yeah, you should start seeing a slowdown at around 1,000 chains, and even then it's quite fast. One short-term application of this could be to distribute the same model over hundreds of datasets (as many as your GPU memory allows) to speed up simple models that would otherwise need to be distributed.
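
A minimal sketch of that idea with JAX's vmap, where logpdf is an illustrative stand-in for a model's log-density and not MCX's API:

import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Illustrative stand-in for a model's log-density; not MCX's API.
def logpdf(theta, data):
    return jnp.sum(norm.logpdf(data, loc=theta))

datasets = jnp.ones((100, 1000))  # 100 datasets of 1,000 points each

# vmap over the dataset axis: one log-density per dataset, same parameters
log_densities = jax.vmap(logpdf, in_axes=(None, 0))(0.5, datasets)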

Thanks again for contributing; don't hesitate to open an issue if you see something wrong or if something in the UI doesn't feel right!