rlouf / mcx

Express & compile probabilistic programs for performant inference on CPU & GPU. Powered by JAX.
https://rlouf.github.io/mcx
Apache License 2.0

Find the right boundary between runtimes and programs #13

Closed rlouf closed 4 years ago

rlouf commented 4 years ago

I am not yet sure what the respective responsibilities of runtimes and programs should be. This PR is an experiment to find the right API.

Requirements

Slow initialisation

The initialisation is unacceptably slow. It is caused by calling ravel_pytree separately for each chain. Luckily for us, ravel_pytree arranges the values of the dictionary in a deterministic order, sorting them by key. To fix the performance, we stack the arrays in the dictionary returned by sample_forward to obtain the flattened positions, and apply ravel_pytree to a single dictionary to get the unraveling function.
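A minimal sketch of the idea, assuming (hypothetically) that sample_forward returns a dict mapping each variable name to an array whose leading axis indexes chains. The helper name `flatten_chains` is illustrative, not part of mcx. It calls `jax.flatten_util.ravel_pytree` once, on the first chain only, to get the unravel function, then builds all flat positions in one vectorised concatenation, iterating over keys in sorted order to match the order in which JAX flattens dicts.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def flatten_chains(samples):
    """Flatten per-chain samples without calling ravel_pytree per chain.

    Hypothetical helper. `samples` is assumed to map variable names to
    arrays of shape (num_chains, *event_shape).
    """
    # Call ravel_pytree once, on one chain's values, only to get the
    # function that maps a flat vector back to the original dict.
    first_chain = {name: value[0] for name, value in samples.items()}
    _, unravel_fn = ravel_pytree(first_chain)

    # JAX flattens dicts in sorted key order, so concatenate in that order.
    keys = sorted(samples)
    num_chains = samples[keys[0]].shape[0]
    flat_positions = jnp.concatenate(
        [jnp.reshape(samples[k], (num_chains, -1)) for k in keys], axis=1
    )
    return flat_positions, unravel_fn
```

Each row of `flat_positions` matches what `ravel_pytree` would return for that chain's dict, so a single unravel function serves every chain.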

This PR follows issue #12.

rlouf commented 4 years ago

Good enough for now, merging.