rlouf / mcx

Express & compile probabilistic programs for performant inference on CPU & GPU. Powered by JAX.
https://rlouf.github.io/mcx
Apache License 2.0

Refactor the core #53

Closed rlouf closed 3 years ago

rlouf commented 3 years ago

The current version of the core parses the model definition into a tree and compiles it back into Python functions that compute the distribution's logpdf or draw samples. This implementation is brittle and rather limited: for instance, if/else statements are difficult to express in this framework.

In this PR we introduce a new intermediate representation, the MCX Syntax Tree (McxST), which is a slightly augmented version of Python's Abstract Syntax Tree. Model definitions are parsed into the McxST, and their syntax is checked along the way. Because the McxST is a static graph, all the operations that were possible before remain possible (do operator, conjugacy detection, random variable transformation, etc.). It also retains if/else statements and for loops, so the compilers can translate them into a form that JAX can JIT-compile and differentiate.

Since the McxST is a thin wrapper around Python's AST, the resulting code should be more robust and allow any operation that Python allows.
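To make the idea of building on Python's AST concrete, here is a minimal sketch using only the standard `ast` module. It is not the McxST itself: the model source, the `ControlFlowFinder` class and the variable names are hypothetical, and any MCX-specific syntax is left out. It only shows that if/else statements and for loops survive parsing as ordinary nodes, and that the tree can be compiled back into a Python function.

```python
import ast
import textwrap

# Hypothetical model source, written with plain assignments only.
MODEL_SOURCE = textwrap.dedent(
    """
    def model(x):
        if x > 0:
            scale = 1.0
        else:
            scale = 2.0
        for i in range(3):
            scale = scale + i
        return scale
    """
)


class ControlFlowFinder(ast.NodeVisitor):
    """Collect the kinds of control-flow statements found in the tree."""

    def __init__(self):
        self.found = []

    def visit_If(self, node):
        self.found.append("if")
        self.generic_visit(node)  # keep walking into nested statements

    def visit_For(self, node):
        self.found.append("for")
        self.generic_visit(node)


# Parse the source into a syntax tree and inspect it.
tree = ast.parse(MODEL_SOURCE)
finder = ControlFlowFinder()
finder.visit(tree)
control_flow = finder.found

# Compile the (unmodified) tree back into a callable Python function.
namespace = {}
exec(compile(tree, filename="<model>", mode="exec"), namespace)
model = namespace["model"]
```

A compiler working on such a tree could rewrite the `If` and `For` nodes before the final `compile` step, which is roughly where a translation to JAX control flow would take place.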

rlouf commented 3 years ago

I am aware that SymJAX makes it possible to build a symbolic graph from code written with JAX constructs. But:

  1. Such a graph (like Theano's) gives me more granularity than I need, at a fairly high complexity cost.
  2. The resulting code is not as close to Python/NumPy: it requires the “T” notation and cannot use Python control flow constructs directly.

The future might prove me wrong, but I don't think that I need a full symbolic graph. However, I can always build a symbolic graph with the granularity I need. This would have the following benefits:

rlouf commented 3 years ago

Control flow, in particular, gets more complicated. I believe we should represent the different branches as separate graphs:

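from typing import Callable
import networkx as nx

class If(ControlFlow):
    operand: Callable[..., bool]  # predicate that selects the branch
    true_branch: nx.DiGraph       # graph evaluated when operand is true
    false_branch: nx.DiGraph      # graph evaluated when operand is false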

Placeholders in branches are connected to nodes in the main function iff they share the same name. Depending on the content of the If branches, we should be able to convert the node to a cond Op, which can in turn be converted to a JAX cond. If that is not possible, it will be converted to a standard Python if/else, but we won't be able to JIT the compiled function. The same applies to JAX's switch and scan/fori_loop (though these are more complex to implement).

Put simply, we can translate an If node to a cond iff each branch's graph has exactly one leaf and the two leaves have the same name.
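The rule above can be sketched as a small check. This is an illustration under assumptions, not the code in this PR: plain dictionaries stand in for the branch graphs, and `Graph`, `leaves` and `can_become_cond` are hypothetical names.

```python
from typing import Dict, List

# Stand-in for a directed graph: node name -> list of successor names.
# Every node must appear as a key, even when it has no successor.
Graph = Dict[str, List[str]]


def leaves(graph: Graph) -> List[str]:
    """Return the nodes that have no outgoing edge, sorted by name."""
    return sorted(node for node, successors in graph.items() if not successors)


def can_become_cond(true_branch: Graph, false_branch: Graph) -> bool:
    """Apply the rule: exactly one leaf per branch, and both share a name."""
    true_leaves = leaves(true_branch)
    false_leaves = leaves(false_branch)
    return len(true_leaves) == 1 and true_leaves == false_leaves


# Both branches end in a single node named "scale": convertible to a cond.
convertible = can_become_cond(
    {"x": ["scale"], "scale": []},
    {"y": ["scale"], "scale": []},
)

# The false branch ends in a differently named node: fall back to if/else.
not_convertible = can_become_cond(
    {"x": ["scale"], "scale": []},
    {"y": ["shift"], "shift": []},
)
```

When the check fails, the compiler would emit a regular Python if/else, at the cost of not being able to JIT the compiled function.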

rlouf commented 3 years ago

Current restriction

Note

It is possible to store a live version of the Ops/Constants via the namespace and getattr, thus creating a full symbolic graph.
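One possible reading of this note, as a minimal sketch: a dotted name found in the model source is resolved to the live object by looking up its first part in a namespace and following the remaining parts with getattr. The `resolve` helper and the example namespace are hypothetical, and the standard `math` module stands in for a module of Ops.

```python
import math
from functools import reduce


def resolve(dotted_name: str, namespace: dict):
    """Look up the first name in `namespace`, then follow the attributes."""
    head, *attributes = dotted_name.split(".")
    return reduce(getattr, attributes, namespace[head])


# Hypothetical namespace; in practice it could be the model's globals.
namespace = {"math": math}

# The string from the source is turned into the live function object.
sqrt = resolve("math.sqrt", namespace)
```

Storing the resolved object on the corresponding graph node would give each node access to the actual Op rather than only its name.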

rlouf commented 3 years ago

Everything that worked before still works as expected. I still need to fix a bug in the state initialization.