pyro-ppl / pyro

Deep universal probabilistic programming with Python and PyTorch
http://pyro.ai
Apache License 2.0

Effect handler that conditions a model on sample sites having the same value #3395

Closed BenZickel closed 2 months ago

BenZickel commented 2 months ago

Problem Description

It would be helpful to have an effect handler that can condition a model on sample sites having the same value.

Suggested Solution

Use the EqualizeMessenger effect handler with a newly added option, keep_dist. When set to True, it keeps the original distributions of the sample sites; the default behavior (keep_dist=False) instead converts the second and subsequent sites to deterministic sites.

Usage Example

Consider the model

import pyro

def model():
    x = pyro.sample('x', pyro.distributions.Normal(0, 1))
    y = pyro.sample('y', pyro.distributions.Normal(5, 3))

The model can be conditioned on 'x' and 'y' having the same value by

conditioned_model = pyro.poutine.equalize(model, ['x', 'y'], keep_dist=True)

which is equivalent to

def conditioned_model():
    x = pyro.sample('x', pyro.distributions.Normal(0, 1))
    y = pyro.sample('y', pyro.distributions.Normal(5, 3), obs=x)

This differs from the default behavior of EqualizeMessenger, with keep_dist equal to False, where

equalized_model = pyro.poutine.equalize(model, ['x', 'y'])

which is equivalent to

def equalized_model():
    x = pyro.sample('x', pyro.distributions.Normal(0, 1))
    y = pyro.deterministic('y', x)

Note that the conditioned model defined above calculates the correct unnormalized log-probability density, but sampling from it correctly requires inference techniques such as SVI or MCMC; running the model forward does not draw from the conditioned distribution.
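To illustrate what "unnormalized" means here, the following is a minimal sketch in plain Python (no Pyro) of the log density the conditioned model assigns when 'x' and 'y' share one value. The helper names normal_log_prob and conditioned_log_density are hypothetical and not part of Pyro; the grid integration is only a sanity check, not how Pyro evaluates anything.

```python
import math

def normal_log_prob(value, loc, scale):
    """Log density of Normal(loc, scale) evaluated at value."""
    z = (value - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2 * math.pi)

def conditioned_log_density(x):
    """Unnormalized log density of the conditioned model at x == y.

    Sum of the log densities of both sites, Normal(0, 1) and Normal(5, 3),
    evaluated at the shared value.
    """
    return normal_log_prob(x, 0.0, 1.0) + normal_log_prob(x, 5.0, 3.0)

# Riemann sum over a wide grid to estimate the integral of the density.
step = 0.001
grid = [-10.0 + i * step for i in range(20001)]
normalizer = sum(math.exp(conditioned_log_density(x)) for x in grid) * step
```

The integral comes out far below 1 (about 0.036), so the density is not a normalized distribution over the shared value, which is why an inference algorithm is needed to draw samples.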

Testing

I've added a test for the conditioned model case, with two normally distributed random variables, which allows for analytic calculation of the expected resulting normal distribution.
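The analytic result follows from the standard identity that the product of two normal densities in the same variable is proportional to a normal density whose precision is the sum of the two precisions. A minimal sketch of that calculation, assuming the Normal(0, 1) and Normal(5, 3) sites from the usage example; product_of_normals is a hypothetical helper and this is not the test code from the repository:

```python
import math

def product_of_normals(loc_a, scale_a, loc_b, scale_b):
    """Return (loc, scale) of the normal proportional to the product
    of the Normal(loc_a, scale_a) and Normal(loc_b, scale_b) densities."""
    precision_a = 1.0 / scale_a ** 2
    precision_b = 1.0 / scale_b ** 2
    precision = precision_a + precision_b
    # Precision-weighted average of the two means.
    loc = (precision_a * loc_a + precision_b * loc_b) / precision
    scale = math.sqrt(1.0 / precision)
    return loc, scale

expected_loc, expected_scale = product_of_normals(0.0, 1.0, 5.0, 3.0)
```

For these parameters the expected mean is 0.5 and the expected variance is 0.9 (standard deviation about 0.949), which is what samples of the shared value drawn by SVI or MCMC from the conditioned model should approach.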