pyro-ppl / pyro

Deep universal probabilistic programming with Python and PyTorch
http://pyro.ai
Apache License 2.0

Feature request: add a parameter that allows Predictive to propagate gradients #3404

Open · pwsiegel opened 1 month ago

pwsiegel commented 1 month ago

Issue Description

When generating samples with pyro.infer.predictive.Predictive, gradients are dropped: the returned samples no longer require grad, even when the model, the guide, and the inputs do. From the code this looks like an intentional design choice, but I'm not sure why. If there's a good reason, feel free to ignore this.
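The effect can be reproduced in plain PyTorch. This is a minimal sketch, assuming (as the Pyro source appears to do) that Predictive runs the model and guide inside torch.no_grad(). The helper sample_site is hypothetical and only stands in for a reparameterized Pyro sample site whose location depends on x: the same draw stays connected to x outside the no_grad context and is detached inside it.

```python
import torch

# Hypothetical stand-in for a model/guide sample site: a reparameterized
# Normal draw whose location depends on the input x.
def sample_site(x):
    return torch.distributions.Normal(x, 1.0).rsample()

x = torch.tensor(0.5, requires_grad=True)

# Outside no_grad, autograd records rsample, so the draw depends on x.
with_grad = sample_site(x)

# Inside no_grad (assumed to be how Predictive runs model and guide),
# autograd records nothing, so the draw comes back detached.
with torch.no_grad():
    without_grad = sample_site(x)

print(with_grad.requires_grad, without_grad.requires_grad)  # True False
```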

Code Snippet

Create a model and fit a guide such that model(x).requires_grad and guide(x)['some_site'].requires_grad are both True when x requires gradients. Then run:

from pyro.infer import Predictive

predictive = Predictive(model, guide=guide, num_samples=100)
posterior_samples = predictive(x)

Then posterior_samples['some_site'].requires_grad is False.