rlouf / mcx

Express & compile probabilistic programs for performant inference on CPU & GPU. Powered by JAX.
https://rlouf.github.io/mcx
Apache License 2.0

sampler tests #59

Closed: jeremiecoullon closed this issue 3 years ago.

jeremiecoullon commented 3 years ago

It might be nice to add some more rigorous tests for the samplers (though the current mean and variance tests are good!).

For example, the kernel goodness-of-fit test (here or here).
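For reference, one widely used kernel goodness-of-fit statistic is the kernel Stein discrepancy (KSD). It only needs the target's score function ∇ log p, so it works with unnormalized densities. The sketch below is a minimal NumPy version under stated assumptions: a Gaussian RBF kernel with a fixed, hand-picked bandwidth, a target whose score is known in closed form, and a hypothetical helper name `ksd_rbf`. It returns only the U-statistic estimate of the squared KSD; turning that into an actual test would still need a calibrated threshold (for example via a bootstrap), and correlated MCMC draws need extra care. It is not necessarily the exact test referenced above.

```python
import numpy as np


def ksd_rbf(samples, score_fn, bandwidth=1.0):
    """Hypothetical helper: U-statistic estimate of the squared kernel Stein
    discrepancy between `samples` and a target p with score grad log p = score_fn,
    using the RBF kernel k(x, y) = exp(-|x - y|^2 / (2 h^2))."""
    x = np.asarray(samples, dtype=float)
    x = x.reshape(x.shape[0], -1)              # (n, d)
    n, d = x.shape
    s = score_fn(x)                            # (n, d): grad log p at each sample
    h2 = bandwidth ** 2

    diff = x[:, None, :] - x[None, :, :]       # (n, n, d): x_i - x_j
    sqdist = np.sum(diff ** 2, axis=-1)        # (n, n)
    k = np.exp(-sqdist / (2 * h2))

    # Stein kernel u_p(x_i, x_j), with every term written as a multiple of k:
    term1 = s @ s.T                                    # s(x_i) . s(x_j)
    term2 = np.einsum("id,ijd->ij", s, diff) / h2      # s(x_i) . grad_y k / k
    term3 = -np.einsum("jd,ijd->ij", s, diff) / h2     # s(x_j) . grad_x k / k
    term4 = d / h2 - sqdist / h2 ** 2                  # trace(grad_x grad_y k) / k
    u = k * (term1 + term2 + term3 + term4)

    np.fill_diagonal(u, 0.0)                   # U-statistic: drop i == j pairs
    return u.sum() / (n * (n - 1))


# Usage: check draws against a standard normal target, whose score is -x.
rng = np.random.default_rng(0)
good = rng.normal(0.0, 1.0, size=500)          # drawn from the target
bad = rng.normal(1.0, 1.0, size=500)           # shifted by one unit
std_normal_score = lambda z: -z
print(ksd_rbf(good, std_normal_score))         # close to zero
print(ksd_rbf(bad, std_normal_score))          # clearly positive
```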

This is obviously not urgent; I just thought it might be a nice thing. I don't know how Stan and other packages do integration tests for their samplers; perhaps they also test the first few moments of the samples?

rlouf commented 3 years ago

I totally agree! Rainier has an interesting approach to testing inference by using Simulation Based Calibration:

https://github.com/stripe/rainier/blob/develop/rainier-test/src/main/scala/com/stripe/rainier/core/SBCModel.scala

Which I quite like. What do you think?

jeremiecoullon commented 3 years ago

Nice! Is this that technique? I've come across a similar method, but my supervisor called it the prior reproduction test.

rlouf commented 3 years ago

Yes, that's the article. That's essentially what you write about in your blog post, except that instead of comparing distributions directly, you compare the rank statistic to the uniform distribution. And then I guess you use a chi-squared test?
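A minimal sketch of that rank check, assuming a toy conjugate model (θ ~ N(0, 1), y_i | θ ~ N(θ, σ²)) in which exact posterior draws stand in for the sampler under test. The function names, the 99 posterior draws per simulation, and the 20 histogram bins are illustrative choices, not Rainier's or MCX's API. For each simulation we draw θ from the prior, simulate data, draw from the posterior, and count how many posterior draws fall below the true θ; if the sampler is correct, that rank is uniform on {0, ..., L}, which a chi-squared test on the binned counts can check.

```python
import numpy as np
from scipy import stats


def sbc_ranks(sample_posterior, n_sims=1000, n_post=99, n_obs=10, sigma=1.0, seed=0):
    """Hypothetical SBC loop: returns one rank in {0, ..., n_post} per simulation."""
    rng = np.random.default_rng(seed)
    ranks = np.empty(n_sims, dtype=int)
    for i in range(n_sims):
        theta_true = rng.normal(0.0, 1.0)                  # 1. draw from the prior
        y = rng.normal(theta_true, sigma, size=n_obs)      # 2. simulate data
        draws = sample_posterior(y, sigma, n_post, rng)    # 3. run the "sampler"
        ranks[i] = np.sum(draws < theta_true)              # 4. rank of the true value
    return ranks


def exact_posterior(y, sigma, n_post, rng):
    # Conjugate normal model: the posterior is N(mean, 1 / precision).
    precision = 1.0 + len(y) / sigma ** 2
    mean = (y.sum() / sigma ** 2) / precision
    return rng.normal(mean, 1.0 / np.sqrt(precision), size=n_post)


def overconfident_posterior(y, sigma, n_post, rng):
    # Deliberately broken "sampler": right mean, half the correct spread.
    precision = 1.0 + len(y) / sigma ** 2
    mean = (y.sum() / sigma ** 2) / precision
    return rng.normal(mean, 0.5 / np.sqrt(precision), size=n_post)


def rank_uniformity_pvalue(ranks, n_post, n_bins=20):
    # n_post + 1 possible ranks; with n_post = 99 and 20 bins, each bin covers 5 ranks.
    counts, _ = np.histogram(ranks, bins=n_bins, range=(0, n_post + 1))
    return stats.chisquare(counts).pvalue    # H0: counts are uniform across bins


print(rank_uniformity_pvalue(sbc_ranks(exact_posterior), n_post=99))          # typically not small
print(rank_uniformity_pvalue(sbc_ranks(overconfident_posterior), n_post=99))  # very small
```

With real MCMC output, the posterior draws should be thinned to be roughly independent first; otherwise autocorrelation alone can make the rank histogram look non-uniform.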

By the way, I don't know if you saw on Twitter, but we're discussing with PyMC3 and Numpyro devs about moving the inference code into a separate repository. At first it will essentially have MCX's code structure plus Numpyro's NUTS implementation. Tests based on simulation-based calibration would be a great addition.

jeremiecoullon commented 3 years ago

Separating the inference from the other stuff sounds like a good idea! I'll have a closer look at SBC at some point; it would be good to add this.

rlouf commented 3 years ago

Yes, and it shouldn't be that computationally challenging using JAX.
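To illustrate why this is cheap in JAX: every SBC simulation is independent, so a single-simulation function can be vectorized with `jax.vmap` and compiled with `jax.jit`. This sketch assumes the same toy conjugate normal model, with exact posterior draws again standing in for a real sampler (in practice that would be an MCMC kernel run inside `one_simulation`); the names and constants are hypothetical.

```python
import jax
import jax.numpy as jnp

N_OBS, N_POST, SIGMA = 10, 99, 1.0   # hypothetical settings for the toy model


def one_simulation(key):
    """One SBC replication: prior draw, simulated data, posterior draws, rank."""
    k_prior, k_data, k_post = jax.random.split(key, 3)
    theta_true = jax.random.normal(k_prior)                        # theta ~ N(0, 1)
    y = theta_true + SIGMA * jax.random.normal(k_data, (N_OBS,))   # y_i ~ N(theta, sigma^2)
    # Exact conjugate posterior stands in for the sampler under test.
    precision = 1.0 + N_OBS / SIGMA ** 2
    mean = (y.sum() / SIGMA ** 2) / precision
    draws = mean + jax.random.normal(k_post, (N_POST,)) / jnp.sqrt(precision)
    return jnp.sum(draws < theta_true)                             # rank in {0, ..., N_POST}


# Vectorize over independent keys and compile once.
sbc_ranks = jax.jit(jax.vmap(one_simulation))
keys = jax.random.split(jax.random.PRNGKey(0), 10_000)
ranks = sbc_ranks(keys)   # shape (10_000,), one rank per simulation
```

The resulting ranks can then be binned and tested for uniformity exactly as in the NumPy version.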

rlouf commented 3 years ago

@jeremiecoullon In the short term, development of inference algorithms will move to BlackJAX. The library is developed in collaboration with the PyMC and Numpyro teams. We could use your help to develop more robust tests for sampling algorithms!