rlouf / mcx

Express & compile probabilistic programs for performant inference on CPU & GPU. Powered by JAX.
https://rlouf.github.io/mcx
Apache License 2.0

WIP: Implement a `denormalize` custom Jaxpr operator simplifying MCX logpdfs #71

Open balancap opened 3 years ago

balancap commented 3 years ago

Overview

This PR implements a generic `denormalize` decorator that removes normalizing constants from a logpdf, following the call for contributions in #65.
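For context, here is a small illustration (not from the PR) of why dropping normalizing constants is safe for inference. A term that does not depend on the parameters only shifts the logpdf by a constant, so it changes neither the gradient nor Metropolis acceptance ratios. The hand-written `denormalized_logpdf` below is a hypothetical stand-in for what such a decorator would produce.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


# Full Normal logpdf, including the -0.5 * log(2 * pi) normalizing constant.
def full_logpdf(mu, x):
    return norm.logpdf(x, loc=mu, scale=1.0).sum()


# Same model with the normalizing constant removed: it no longer
# integrates to one, but it differs from full_logpdf by a constant.
def denormalized_logpdf(mu, x):
    return (-0.5 * (x - mu) ** 2).sum()


x = jnp.array([0.3, -1.2, 2.0])
mu = 0.5

# The dropped constant shifts the value by n * 0.5 * log(2 * pi) ...
gap = full_logpdf(mu, x) - denormalized_logpdf(mu, x)

# ... but leaves the gradient w.r.t. mu unchanged, which is what
# gradient-based samplers such as HMC actually consume.
g_full = jax.grad(full_logpdf)(mu, x)
g_denorm = jax.grad(denormalized_logpdf)(mu, x)
```

The same reasoning applies to differences of logpdfs: in a Metropolis acceptance ratio the constant appears in both terms and cancels.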

Implementation

The current implementation is a two-pass algorithm:

Once the simplifying mapping is found, the rest of the decorator code is simply an execution pass over the Jaxpr that skips the operations for which a simplifying mapping exists.
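A minimal sketch of this kind of Jaxpr rewrite, under stated assumptions: this is not the PR's code, and the names `denormalize`, `_depends_on_inputs` and `_additive_vars` are hypothetical. It assumes a scalar-output logpdf built only from first-order primitives (no nested jaxprs such as `pjit` or `custom_jvp` calls), and it builds the simplifying mapping with two analysis passes: a forward pass marking values that depend on the inputs, and a backward pass marking values that reach the output only through `add`/`sub`/`neg`. The execution pass then forwards the non-constant operand of any such `add`/`sub` whose other operand is constant. The interpreter loop follows the pattern from JAX's "Writing custom Jaxpr interpreters" guide.

```python
import jax
import jax.numpy as jnp

try:  # Recent JAX releases expose Literal under jax.extend.core.
    from jax.extend.core import Literal
except ImportError:  # Older releases only have jax.core.Literal.
    from jax.core import Literal

# Primitives through which a constant term reaches the output unchanged
# (up to sign), so dropping it shifts the result by a constant.
_ADDITIVE = {"add", "sub", "neg"}


def _depends_on_inputs(jaxpr):
    """Analysis pass 1 (forward): variables whose value depends on an input."""
    dep = set(jaxpr.invars)
    for eqn in jaxpr.eqns:
        if any(not isinstance(v, Literal) and v in dep for v in eqn.invars):
            dep.update(eqn.outvars)
    return dep


def _additive_vars(jaxpr):
    """Analysis pass 2 (backward): variables whose every use reaches the
    output through additive primitives only."""
    ok = {v: True for v in jaxpr.outvars if not isinstance(v, Literal)}
    for eqn in reversed(jaxpr.eqns):
        out_ok = len(eqn.outvars) == 1 and ok.get(eqn.outvars[0], False)
        use_ok = eqn.primitive.name in _ADDITIVE and out_ok
        for v in eqn.invars:
            if not isinstance(v, Literal):  # Literals are not hashable
                ok[v] = ok.get(v, True) and use_ok
    return {v for v, flag in ok.items() if flag}


def denormalize(logpdf):
    """Hypothetical decorator: drop additive terms that depend on no input."""

    def wrapped(*args):
        args = tuple(jnp.asarray(a) for a in args)
        closed = jax.make_jaxpr(logpdf)(*args)
        jaxpr = closed.jaxpr
        dep = _depends_on_inputs(jaxpr)
        additive = _additive_vars(jaxpr)

        env = dict(zip(jaxpr.constvars, closed.consts))
        env.update(zip(jaxpr.invars, args))

        def read(v):
            return v.val if isinstance(v, Literal) else env[v]

        def is_const(v):
            return isinstance(v, Literal) or v not in dep

        # Execution pass: evaluate every equation, except that an add/sub on
        # the additive path with one constant operand forwards the other one.
        for eqn in jaxpr.eqns:
            invals = [read(v) for v in eqn.invars]
            name = eqn.primitive.name
            if name in ("add", "sub") and eqn.outvars[0] in additive and eqn.outvars[0] in dep:
                out = eqn.outvars[0]
                lhs, rhs = eqn.invars
                if is_const(rhs):  # x + c -> x and x - c -> x
                    val = invals[0]
                elif is_const(lhs):  # c + x -> x and c - x -> -x
                    val = invals[1] if name == "add" else -invals[1]
                else:
                    val = eqn.primitive.bind(*invals, **eqn.params)
                # Keep the original output shape and dtype.
                env[out] = jnp.broadcast_to(val, out.aval.shape).astype(out.aval.dtype)
                continue
            outvals = eqn.primitive.bind(*invals, **eqn.params)
            if not eqn.primitive.multiple_results:
                outvals = [outvals]
            env.update(zip(eqn.outvars, outvals))
        return read(jaxpr.outvars[0])

    return wrapped


def normal_logpdf(x, mu, sigma):
    z = (x - mu) / sigma
    return -0.5 * z * z - jnp.log(sigma) - 0.5 * jnp.log(2.0 * jnp.pi)


denormalized = denormalize(normal_logpdf)
print(denormalized(1.0, 0.0, 2.0))  # ≈ -0.8181 (= -0.125 - log 2)
```

Because `log(sigma)` depends on an input, it is kept; only the `0.5 * log(2 * pi)` term is skipped, so gradients w.r.t. `mu` and `sigma` match the original logpdf. Propagation through `mul` or `concat` and recursion into nested jaxprs are deliberately left out, in line with the limitations listed in the PR.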

Limitations

Even though we aim for a fairly generic implementation, some simplifications are not supported at the moment. For instance, we do not propagate constant simplification through `concat` or `mul` operations. These cases could be supported in the future if they turn out to be a performance bottleneck in MCX.
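A short worked argument, not taken from the PR, for why `mul` needs more care than `add`: write a term of the logpdf as a factor $a$ times a sum $f(\theta) + c$, where $c$ is a constant and $\theta$ denotes the inputs. Dropping $c$ changes the term by

```latex
\begin{aligned}
a\,\bigl(f(\theta) + c\bigr) - a\,f(\theta) &= a\,c
  && \text{constant if } a \text{ does not depend on } \theta,\\
a(\theta)\,\bigl(f(\theta) + c\bigr) - a(\theta)\,f(\theta) &= c\,a(\theta)
  && \text{depends on } \theta \text{ in general.}
\end{aligned}
```

So propagating the simplification through `mul` is only sound when the other factor is itself constant, which is an extra check on top of the purely additive case.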

balancap commented 3 years ago

@rlouf As we discussed on Slack, this PR needs quite a bit of additional complexity to properly handle the support `select` condition that appears in the logpdf of many distributions.
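To see where that condition comes from, here is a small sketch with a hypothetical half-normal logpdf written with `jnp.where` (this is not MCX's implementation). JAX traces `jnp.where` to a select primitive (`select_n` in recent releases), so the normalizing constant $\tfrac12\log(2/\pi)$ sits inside one branch of the select rather than on a purely additive path to the output.

```python
import jax
import jax.numpy as jnp


def halfnormal_logpdf(x, sigma):
    # 0.5 * log(2 / pi) is a normalizing constant, but it only appears
    # inside the in-support branch of the select below.
    logp = 0.5 * jnp.log(2.0 / jnp.pi) - jnp.log(sigma) - x * x / (2.0 * sigma * sigma)
    # Outside the support (x < 0) the logpdf is -inf.
    return jnp.where(x >= 0, logp, -jnp.inf)


jaxpr = jax.make_jaxpr(halfnormal_logpdf)(1.0, 1.0)
print(jaxpr)  # the traced program contains a select primitive
```

Dropping the constant from the in-support branch is still sound in this case: the other branch is the constant $-\infty$, so the two densities differ by the factor $e^{c}$ on the support and are both zero off it. A rewrite rule for `select` has to establish that before treating it as additive.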

I'll start with a fairly naive implementation and get it working; then I think we can iterate on it to make it less naive and to use symbolic programming concepts more properly (I have started looking at the Oryx codebase for that).

rlouf commented 3 years ago

That sounds like a very good plan to me! I'll have a closer look too when my big PR is merged.