pyro-ppl / pyro

Deep universal probabilistic programming with Python and PyTorch
http://pyro.ai
Apache License 2.0

[FR] Predictive with deterministic site in the guide #3358

Open OlaRonning opened 7 months ago

OlaRonning commented 7 months ago

Hi,

I'm working on a project where we would like to access the output of an NN in the guide when using Predictive. We've implemented it using a deterministic site in the guide. The program boils down to the following.

import pyro
from pyro.infer import Predictive
from pyro.distributions import Normal
import torch

def model():
    pyro.deterministic('m_deter', torch.tensor(1.))
    pyro.sample('x', Normal(torch.zeros(()), torch.ones(())))

def guide():
    pyro.deterministic('g_deter', torch.tensor(1.))
    pyro.sample('x', Normal(torch.zeros(()), torch.ones(())))

Predictive(
    model=model,
    guide=guide,
    return_sites=('m_deter', 'g_deter', 'x'),
    num_samples=1)()  # Includes m_deter but not g_deter

We would like for both m_deter and g_deter to be included. It looks like Predictive currently only considers model sites for return sites. Would it be possible to expand it so we can include deterministic sites from the guide?

fritzo commented 7 months ago

This looks reasonable to me. Since deterministic guide sites are usually ignored (e.g. in AutoGuides) I think we may want to gate this new behavior by an arg like return_deterministic_guide_sites: bool or something.

OlaRonning commented 7 months ago

A guard makes sense. I'll give an implementation a shot.

SarthakNikhal commented 7 months ago

@OlaRonning Can I help with this issue?

OlaRonning commented 7 months ago

@SarthakNikhal Absolutely. Feel free to look at my WIP PR. I wrote the unit test fairly quickly; you can probably develop a more suitable one.

SarthakNikhal commented 7 months ago

@OlaRonning Okay. What can I do better? Also, what other unit tests can you think of?

OlaRonning commented 7 months ago

I would make the test cover four cases:

  1. return_deterministic is true and no return_sites are given. The returned samples should include all deterministic sites in the guide.
  2. return_deterministic is true and return_sites includes one of two deterministic sites in the guide. The returned samples should include only the deterministic guide site named in return_sites.
  3. return_deterministic is true and there are no deterministic sites in the guide. The returned samples should be the same as when return_deterministic is false.
  4. return_deterministic is false and there is a deterministic site in the guide. The returned samples should not include the deterministic site from the guide.

You'd probably want to check both that sites are included in the returned samples and that their values are as expected. I believe you can work directly on the aleatory_science branch.
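
The four cases above can be written down as a small, Pyro-free reference function that a test could compare against. This is only a sketch of the intended behavior as described in this thread, not the implementation in the PR; the function name expected_guide_deterministic_sites and the site names g1, g2 are hypothetical.

```python
from typing import Iterable, Optional, Set


def expected_guide_deterministic_sites(
    guide_deterministic: Iterable[str],
    return_sites: Optional[Iterable[str]],
    return_deterministic: bool,
) -> Set[str]:
    """Names of guide deterministic sites the four cases expect in the output."""
    guide_deterministic = set(guide_deterministic)
    if not return_deterministic:
        # Case 4: the flag is off, so no guide deterministic site is returned.
        return set()
    if not return_sites:
        # Case 1: no explicit return_sites, so every guide deterministic site
        # is returned. Case 3 also lands here with an empty set.
        return guide_deterministic
    # Case 2: only the guide deterministic sites that were asked for.
    return guide_deterministic & set(return_sites)


case_1 = expected_guide_deterministic_sites({"g1", "g2"}, None, True)
case_2 = expected_guide_deterministic_sites({"g1", "g2"}, ("g1", "x"), True)
case_3_on = expected_guide_deterministic_sites(set(), None, True)
case_3_off = expected_guide_deterministic_sites(set(), None, False)
case_4 = expected_guide_deterministic_sites({"g1"}, None, False)
```

A real test would additionally compare the returned tensors against the values the guide is known to produce, as suggested above.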