pyro-ppl / pyro

Deep universal probabilistic programming with Python and PyTorch
http://pyro.ai
Apache License 2.0

[Docs] Adding a doc page on debugging through a model (and notebook on Probabilistic PCA for tutorials) #3031

Open nipunbatra opened 2 years ago

nipunbatra commented 2 years ago

Hi, I was preparing a tutorial on Probabilistic PCA mirroring the tutorial in TFP.

After non-trivial debugging and toying with to_event, I managed to get the example working in Pyro (though, to be honest, I'm still not very confident!).

I would like to suggest a tutorial that debugs some common issues in models (I'd say pyro.clear_param_store() and getting shapes right would be two such candidates). While there is already an excellent guide on tensor shapes in Pyro, I think an example-driven tutorial page focused on showing the errors in modelling and then working through them would be very useful (and could reduce similar questions on the forums).
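For instance (a minimal sketch of the first candidate, not taken from my notebook): Pyro's param store is global and survives across runs, so a re-run with different dimensions can silently pick up a stale parameter.

import pyro
import torch

# Parameters live in a global store and survive across model runs ...
pyro.param("w", torch.zeros(3, 2))
print(list(pyro.get_param_store().keys()))  # ['w']

# ... so asking again with a different init just returns the stale value.
w = pyro.param("w", torch.zeros(5, 2))
print(w.shape)  # torch.Size([3, 2]), not the requested (5, 2)

# Clearing the store between experiments avoids this class of bug.
pyro.clear_param_store()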

As an example, I'm copying the code I used for Probabilistic PCA. I created two versions (with and without plates). Again, it may be nice for such a tutorial to discuss the differences between the two models and when to use which.

import torch
import matplotlib.pyplot as plt

import pyro
import pyro.distributions as dist

# Simulate 200 two-dimensional points with low-rank structure.
torch.manual_seed(10)
W_gt = torch.rand(2, 2)
Z_gt = torch.randn(200, 2)

X = Z_gt @ W_gt
plt.scatter(X[:, 0], X[:, 1])
plt.axis('equal');

[plot: scatter of the simulated 2-D data]

pyro.clear_param_store()

def ppca_model_without_plate(data, latent_dim):
    N, data_dim = data.shape
    # Loadings W: batch shape () and event shape (data_dim, latent_dim)
    # after to_event(2).
    W = pyro.sample(
        "W",
        dist.Normal(
            loc=torch.zeros([data_dim, latent_dim]),
            scale=5.0 * torch.ones([data_dim, latent_dim]),
        ).to_event(2),
    )
    # Latents Z: to_event(1) folds only the rightmost dim (N) into the event
    # shape, leaving latent_dim behind as a batch dimension.
    Z = pyro.sample(
        "Z",
        dist.Normal(
            loc=torch.zeros([latent_dim, N]),
            scale=torch.ones([latent_dim, N]),
        ).to_event(1),
    )

    mean = (W @ Z).t()

    # Likelihood: all N x data_dim observations form a single event.
    ob = dist.Normal(mean, 1.0).to_event(2)
    return pyro.sample("obs", ob, obs=data)

pyro.render_model(
    ppca_model_without_plate, model_args=(X, 1), render_distributions=True
)

[rendered model graph for ppca_model_without_plate]

ppca_model_without_plate(X, 1).shape

/Users/nipun/miniforge3/lib/python3.9/site-packages/pyro/primitives.py:137: RuntimeWarning: trying to observe a value outside of inference at obs
  warnings.warn(

torch.Size([200, 2])

import pyro.poutine as poutine

trace = poutine.trace(ppca_model_without_plate).get_trace(X, 1)
trace.compute_log_prob()  # optional, but allows printing of log_prob shapes
print(trace.format_shapes())
Trace Shapes:          
 Param Sites:          
Sample Sites:          
       W dist   |   2 1
        value   |   2 1
     log_prob   |      
       Z dist 1 | 200  
        value 1 | 200  
     log_prob 1 |      
     obs dist   | 200 2
        value   | 200 2
     log_prob   |      
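To spell out how to read this table (a quick sketch using the same trace): shapes to the left of the | are batch dimensions, shapes to the right are event dimensions, and both can be inspected programmatically per site.

print(trace.nodes["W"]["fn"].batch_shape, trace.nodes["W"]["fn"].event_shape)
# torch.Size([]) torch.Size([2, 1])
print(trace.nodes["Z"]["fn"].batch_shape, trace.nodes["Z"]["fn"].event_shape)
# torch.Size([1]) torch.Size([200])

The leftover batch dimension of size 1 on Z is exactly the kind of thing such a tutorial could flag: to_event(1) only converted the rightmost of Z's two dimensions into an event dimension.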

pyro.clear_param_store()
auto_guide = pyro.infer.autoguide.AutoNormal(ppca_model_without_plate)
trace = poutine.trace(auto_guide).get_trace(X, 1)
trace.compute_log_prob()  # optional, but allows printing of log_prob shapes
print(trace.format_shapes())
       Trace Shapes:            
        Param Sites:            
   AutoNormal.locs.W 2   1      
 AutoNormal.scales.W 2   1      
   AutoNormal.locs.Z 1 200      
 AutoNormal.scales.Z 1 200      
       Sample Sites:            
W_unconstrained dist     |   2 1
               value     |   2 1
            log_prob     |      
              W dist     |   2 1
               value     |   2 1
            log_prob     |      
Z_unconstrained dist 1   | 200  
               value 1   | 200  
            log_prob 1   |      
              Z dist 1   | 200  
               value 1   | 200  
            log_prob 1   |      

import logging
logging.basicConfig(format="%(message)s", level=logging.INFO)  # without this, the logging.info calls below print nothing
adam = pyro.optim.Adam({"lr": 0.02})  # Consider decreasing learning rate.
elbo = pyro.infer.Trace_ELBO()
svi = pyro.infer.SVI(ppca_model_without_plate, auto_guide, adam, elbo)

losses = []
for step in range(1000):  # Consider running for more steps.
    loss = svi.step(X, 1)
    losses.append(loss)
    if step % 100 == 0:
        logging.info("Elbo loss: {}".format(loss))

plt.figure(figsize=(5, 2))
plt.plot(losses)
plt.xlabel("SVI step")
plt.ylabel("ELBO loss");

[plot: ELBO loss vs. SVI step]
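It may also be worth showing how to inspect what was learned; a sketch, assuming the AutoNormal guide defined above (its median() method returns a point estimate per latent site):

with torch.no_grad():
    medians = auto_guide.median()
print(medians["W"].shape, medians["Z"].shape)
# torch.Size([2, 1]) torch.Size([1, 200])

# The raw variational parameters sit in the global param store.
for name, value in pyro.get_param_store().items():
    print(name, tuple(value.shape))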

pyro.clear_param_store()

def ppca_model(data, latent_dim):
    N, data_dim = data.shape
    W = pyro.sample(
        "W",
        dist.Normal(
            loc=torch.zeros([data_dim, latent_dim]),
            scale=5.0 * torch.ones([data_dim, latent_dim]),
        ).to_event(2),
    )

    with pyro.plate("data", len(data)):
        # The plate occupies the rightmost batch dim, broadcasting z_n from
        # batch shape (1, latent_dim) out to (1, N). Note this broadcast only
        # works because latent_dim == 1 in this example.
        z_n = pyro.sample(
            "z",
            dist.Normal(loc=torch.zeros([1, latent_dim]), scale=torch.ones([1, latent_dim])),
        )

        mean = (W @ z_n).t()
        # data_dim goes into the event shape; the plate accounts for N.
        pyro.sample("obs", dist.Normal(mean, 1.0).to_event(1), obs=data)

pyro.render_model(
    ppca_model, model_args=(X, 1), render_distributions=True
)

[rendered model graph for ppca_model]
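The same shape-tracing trick is useful for the plated version (a sketch; the variable name trace2 is mine):

trace2 = poutine.trace(ppca_model).get_trace(X, 1)
trace2.compute_log_prob()
print(trace2.format_shapes())
print(trace2.nodes["z"]["value"].shape)  # torch.Size([1, 200]): dim -1 is the "data" plate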

auto_guide2 = pyro.infer.autoguide.AutoNormal(ppca_model)
pyro.clear_param_store()
adam = pyro.optim.Adam({"lr": 0.02})  # Consider decreasing learning rate.
elbo = pyro.infer.Trace_ELBO()
svi = pyro.infer.SVI(ppca_model, auto_guide2, adam, elbo)

losses = []
for step in range(1000):  # Consider running for more steps.
    loss = svi.step(X, 1)
    losses.append(loss)
    if step % 100 == 0:
        logging.info("Elbo loss: {}".format(loss))

plt.figure(figsize=(5, 2))
plt.plot(losses)
plt.xlabel("SVI step")
plt.ylabel("ELBO loss");

[plot: ELBO loss vs. SVI step for the plated model]
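Finally, a possible closing sanity check for the tutorial (a sketch; PPCA is identifiable only up to a rotation of the latent space, and the toy data was generated without observation noise, so only a rough covariance match can be expected):

# Model-implied data covariance is W W^T + I (unit observation noise);
# the generating process has x_n = W_gt^T z_n, hence covariance W_gt^T W_gt.
W_hat = auto_guide2.median()["W"].detach()
print(W_hat @ W_hat.t() + torch.eye(2))
print(W_gt.t() @ W_gt)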

Related #3030

fritzo commented 2 years ago

I love the idea of a debugging-focused tutorial. And I suspect core Pyro devs are the worst people to write such a tutorial since it's been so long since we first stubbed our toes 🤣

@nipunbatra this would be a great tutorial to contribute!

nipunbatra commented 2 years ago

Hi @fritzo, thanks!

I have added a long-ish notebook here

As you might notice, this will need some inputs from you, especially on