pyro-ppl / pyro

Deep universal probabilistic programming with Python and PyTorch
http://pyro.ai
Apache License 2.0

Support for deterministic dependent samples in PyroSample [enhancement] #3376

Closed by BenZickel 5 months ago

BenZickel commented 5 months ago

Make pyro.nn.PyroSample(prior) statements produce deterministic values when the prior is a function that returns a torch.Tensor rather than a distribution, as in the example below:

from pyro.nn import PyroSample, PyroModule
from pyro import distributions as dist

class Location(PyroModule):
    def __init__(self):
        super().__init__()
        # Independent priors
        self.radius = PyroSample(dist.LogNormal(0,1))
        self.theta = PyroSample(dist.Normal(0.5, 0.1))
        # Dependent deterministic
        self.true_x = PyroSample(lambda self: self.radius * self.theta.cos())
        self.true_y = PyroSample(lambda self: self.radius * self.theta.sin())
        # Dependent samples
        self.observed_x = PyroSample(lambda self: dist.Normal(self.true_x, 0.05))
        self.observed_y = PyroSample(lambda self: dist.Normal(self.true_y, 0.05))

    def forward(self):
        return self.true_x, self.true_y, self.observed_x, self.observed_y

In the code above, the dependent deterministic values (true_x and true_y) can be turned into dependent random samples by changing their pyro.nn.PyroSample(prior) priors to functions that return a proper distribution instead of a tensor. This makes it simple to explore models of different complexity for the same problem.

The updated docs can be reviewed here.

fehiepsi commented 5 months ago

Thanks, @BenZickel! The feature looks great.