rlouf / mcx

Express & compile probabilistic programs for performant inference on CPU & GPU. Powered by JAX.
https://rlouf.github.io/mcx
Apache License 2.0

Implement a do operator #103

Open · elanmart opened this issue 3 years ago

elanmart commented 3 years ago

I thought I would open an issue to follow up on https://github.com/rlouf/mcx/issues/1#issuecomment-812733048 , even though I'm not sure if I would be able to contribute a PR.

@rlouf Do I understand correctly that you wouldn't like to modify the model's graph, but instead partially apply the argument to target functions?

I can see how that would work for mcx.generative_function (posterior sampling), but for example the model.sample_joint function does not take any arguments, so we couldn't use the partial application approach there...

rlouf commented 3 years ago

> I thought I would open an issue to follow up on #1 (comment), even though I'm not sure if I would be able to contribute a PR.

Yes, that’s a good idea, thanks. I’m happy to guide you through the PR if you want to contribute.

> @rlouf Do I understand correctly that you wouldn't like to modify the model's graph, but instead partially apply the argument to target functions?

You would have to modify the model graph, of course. For instance, for the model

```python
import mcx
from mcx.distributions import HalfNormal, Normal

@mcx.model
def example_model():
    a <~ Normal(0, 1)
    b <~ HalfNormal(1)
    c <~ Normal(a, b)
    d = b + c
    return d
```

the model `example_model.do(c=1)` would have the following forward sampling function:

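```python
from mcx.distributions import HalfNormal

def sample_example_model(rng_key, c):
    # `a` is gone and `c` is now an input; only `b` is still sampled.
    b = HalfNormal(1).sample(rng_key)
    d = b + c
    return d
```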

We modified the graph: we removed `a` (it only influenced `d` through `c`, whose value is now fixed), and `c` has become an input variable. You can then partially apply `c = 1` to obtain the sampling function for the model with the intervention. So some graph modifications are involved.
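To make that graph surgery concrete, here is a minimal pure-Python sketch. It assumes a hypothetical toy representation (a dict mapping each variable to its parents and a sampling function, with Python's `random` standing in for JAX PRNG keys and mcx distributions); this is not mcx's internal graph format or API. The hypothetical `do` helper turns `c` into an input with no parents and prunes every node that is no longer an ancestor of the returned variable, and `functools.partial` then binds the intervened value.

```python
import functools
import random

# Hypothetical toy encoding of example_model (not mcx's internal format):
# name -> (parent names, fn(rng, *parent_values)). Insertion order is topological.
MODEL = {
    "a": ((), lambda rng: rng.gauss(0.0, 1.0)),            # a <~ Normal(0, 1)
    "b": ((), lambda rng: abs(rng.gauss(0.0, 1.0))),       # b <~ HalfNormal(1)
    "c": (("a", "b"), lambda rng, a, b: rng.gauss(a, b)),  # c <~ Normal(a, b)
    "d": (("b", "c"), lambda rng, b, c: b + c),            # d = b + c
}


def do(model, output, var):
    """Hypothetical do operator: make `var` an input, then prune dead nodes."""
    graph = dict(model)
    graph[var] = ((), None)  # cut incoming edges; the caller supplies the value
    keep, stack = set(), [output]
    while stack:  # collect the ancestors of `output` in the modified graph
        node = stack.pop()
        if node not in keep:
            keep.add(node)
            stack.extend(graph[node][0])
    return {name: spec for name, spec in graph.items() if name in keep}


def make_sampler(graph, output):
    """Build a forward sampler; variables with fn=None must be passed as kwargs."""
    def sample(rng, **inputs):
        env = dict(inputs)
        for name, (parents, fn) in graph.items():
            if fn is not None:
                env[name] = fn(rng, *(env[p] for p in parents))
        return env[output]
    return sample


intervened = do(MODEL, "d", "c")
print(list(intervened))  # ['b', 'c', 'd']: `a` has been pruned

sample_joint_like = make_sampler(MODEL, "d")  # takes no model arguments
sample_do_c = functools.partial(make_sampler(intervened, "d"), c=1.0)
print(sample_do_c(random.Random(0)))  # d = b + 1 with b >= 0
```

The unmodified sampler, like `sample_joint`, has no argument to bind; it is the intervention that introduces the `c` input, and only then does partial application apply.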

The reason we cannot directly inject `c = 1` into the model's code is that the internal representation contains the model's AST, and it is not possible to recover the AST of a live expression in Python, i.e. the AST of an argument passed to a function. Is that clear?
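The AST side of this can be illustrated with the standard library alone. In the minimal sketch below, a plain-Python stand-in for the model is used (an assumption: mcx parses its own `<~` syntax, and the hypothetical `MakeInput` transformer is not part of mcx). The assignment to `c` is removed from the function's AST and `c` is added as a parameter; the intervened value is then supplied through `functools.partial` at call time, since a value object passed to a function carries no source code that could be parsed and spliced in.

```python
import ast
import functools
import random

# Hypothetical plain-Python stand-in for example_model's source; mcx works on
# its own `<~` syntax and internal representation, which this does not reproduce.
MODEL_SOURCE = """
def example_model(rng):
    a = rng.gauss(0.0, 1.0)
    b = abs(rng.gauss(0.0, 1.0))
    c = rng.gauss(a, b)
    d = b + c
    return d
"""


class MakeInput(ast.NodeTransformer):
    """Delete the assignment to `var` and append `var` as a function parameter."""

    def __init__(self, var):
        self.var = var

    def visit_FunctionDef(self, node):
        self.generic_visit(node)  # rewrites the body via visit_Assign below
        node.args.args.append(ast.arg(arg=self.var, annotation=None))
        return node

    def visit_Assign(self, node):
        names = [t.id for t in node.targets if isinstance(t, ast.Name)]
        return None if names == [self.var] else node  # None removes the statement


tree = ast.fix_missing_locations(MakeInput("c").visit(ast.parse(MODEL_SOURCE)))
namespace = {}
exec(compile(tree, "<do>", "exec"), namespace)
print(ast.unparse(tree))  # Python 3.9+; shows `def example_model(rng, c):`

# The value is bound at call time, never spliced into the AST.
sample_do_c = functools.partial(namespace["example_model"], c=1.0)
print(sample_do_c(random.Random(0)))
```

Unlike the graph version, this sketch does not prune dead code: the line that samples `a` still runs even though its result is no longer used.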