pyro-ppl / pyro

Deep universal probabilistic programming with Python and PyTorch
http://pyro.ai
Apache License 2.0

[bug] JitTrace_ELBO fails with AutoNormalizingFlow #3132

Open rafaol opened 2 years ago

rafaol commented 2 years ago

Issue Description

Hi all, I'm having issues using normalizing flows with JIT and Pyro's SVI. The code runs fine with the standard Trace_ELBO (or even TraceEnum_ELBO in models with discrete variables), but it fails when I replace the ELBO with its JIT-compiled version. The issue seems to be caused by the use of weak references (weakref) when checking unconstrained parameter values. I was wondering whether anyone else has run into the same issue, and whether a fix is on the horizon.

Environment

Code Snippet

Here is a minimal example. The following code runs normally if I use Trace_ELBO as the loss function for SVI, but fails with JitTrace_ELBO.

from functools import partial

import pyro
import pyro.optim
import torch
from pyro import distributions
from pyro.distributions import transforms
from pyro.infer import SVI, JitTrace_ELBO
from pyro.infer.autoguide import AutoNormalizingFlow
from tqdm import trange

def test_model(y):
    m = pyro.sample("m", distributions.MultivariateNormal(torch.zeros(2), torch.eye(2)))
    n = y.shape[0]
    with pyro.plate("observations_plate", n):
        r = pyro.sample("y", distributions.MultivariateNormal(m, 0.01*torch.eye(2)), obs=y)
    return r

if __name__ == "__main__":
    pyro.set_rng_seed(0)

    pyro.clear_param_store()

    transform = partial(transforms.iterated, 1, transforms.block_autoregressive)
    guide = AutoNormalizingFlow(test_model, transform)

    # With JitTrace_ELBO this raises the AssertionError shown below; loss=Trace_ELBO() runs normally.
    svi = SVI(test_model, guide, pyro.optim.Adam(dict(lr=5e-3)), loss=JitTrace_ELBO())

    n_steps = 1000
    t_iter = trange(n_steps)
    n_data = 100

    test_y = torch.randn(n_data, 2)

    for t in t_iter:
        loss = svi.step(test_y)
        t_iter.set_postfix(loss=loss)

Console output

  0%|          | 0/1000 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/home/rafael/opt/anaconda3/envs/latest_torch/lib/python3.9/site-packages/pyro/infer/svi.py", line 145, in step
    loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
  File "/home/rafael/opt/anaconda3/envs/latest_torch/lib/python3.9/site-packages/pyro/infer/trace_elbo.py", line 250, in loss_and_grads
    loss, surrogate_loss = self.loss_and_surrogate_loss(
  File "/home/rafael/opt/anaconda3/envs/latest_torch/lib/python3.9/site-packages/pyro/infer/trace_elbo.py", line 239, in loss_and_surrogate_loss
    return self._loss_and_surrogate_loss(*args, **kwargs)
  File "/home/rafael/opt/anaconda3/envs/latest_torch/lib/python3.9/site-packages/pyro/ops/jit.py", line 107, in __call__
    self.compiled[key] = torch.jit.trace(
  File "/home/rafael/opt/anaconda3/envs/latest_torch/lib/python3.9/site-packages/torch/jit/_trace.py", line 795, in trace
    traced = torch._C._create_function_from_trace(
  File "/home/rafael/opt/anaconda3/envs/latest_torch/lib/python3.9/site-packages/pyro/ops/jit.py", line 96, in compiled
    assert constrained_param.unconstrained() is unconstrained_param
AssertionError

When I inspect the runtime state in the debugger, I see that constrained_param.unconstrained() and unconstrained_param have the same values. However, constrained_param.unconstrained() returns a Parameter object (obtained through a weakref), while unconstrained_param is a plain Tensor. Their gradient information also differs: the Parameter is a leaf tensor with requires_grad=True, whereas unconstrained_param has grad_fn=<ViewBackward0>.

Values

name: "AutoNormalizingFlow._prototype_tensor"
constrained_param.unconstrained():
Parameter containing:
tensor([[-0.1316,  0.0000],
        [ 0.0864,  0.0000],
        [ 0.7393,  0.0000],
        [-0.7574,  0.0000],
        [-0.5140,  0.0000],
        [-0.2067,  0.0000],
        [-0.3183,  0.0000],
        [ 0.7055,  0.0000],
        [-0.5021, -0.3566],
        [-0.5412, -0.7255],
        [-0.4522,  0.6658],
        [ 0.3456,  0.3754],
        [ 0.0407, -0.3971],
        [ 0.1310, -0.7232],
        [-0.5597, -0.3993],
        [ 0.4887,  0.4542]], requires_grad=True)
unconstrained_param:
tensor([[-0.1316,  0.0000],
        [ 0.0864,  0.0000],
        [ 0.7393,  0.0000],
        [-0.7574,  0.0000],
        [-0.5140,  0.0000],
        [-0.2067,  0.0000],
        [-0.3183,  0.0000],
        [ 0.7055,  0.0000],
        [-0.5021, -0.3566],
        [-0.5412, -0.7255],
        [-0.4522,  0.6658],
        [ 0.3456,  0.3754],
        [ 0.0407, -0.3971],
        [ 0.1310, -0.7232],
        [-0.5597, -0.3993],
        [ 0.4887,  0.4542]], grad_fn=<ViewBackward0>)
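The following torch-only sketch shows why an identity assertion like the one in pyro/ops/jit.py can fail even when the printed values match. `WeakUnconstrained` is a hypothetical stand-in, not Pyro's param-store code. It only mimics the pattern described above, where a constrained value reaches its unconstrained tensor through a weakref. The assert uses `is`, which compares object identity, so a different tensor object with equal values fails the check. The idea that a view is what gets passed in during tracing is an assumption based on the grad_fn=<ViewBackward0> printed above.

```python
import weakref

import torch


class WeakUnconstrained:
    """Hypothetical stand-in for a constrained value that remembers its
    unconstrained source through a weak reference."""

    def __init__(self, unconstrained):
        self._ref = weakref.ref(unconstrained)

    def unconstrained(self):
        # Dereference the weakref; returns None if the source was garbage-collected.
        return self._ref()


param = torch.nn.Parameter(torch.randn(16, 2))  # leaf tensor, requires_grad=True
holder = WeakUnconstrained(param)

# A view has the same values but is a distinct plain Tensor with a grad_fn.
view = param.view(16, 2)

print(holder.unconstrained() is param)            # True: same object
print(holder.unconstrained() is view)             # False: the identity check fails
print(torch.equal(holder.unconstrained(), view))  # True: the values still match
print(param.grad_fn, view.grad_fn)                # None vs. a ViewBackward node
```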

Pointer addresses:

fritzo commented 2 years ago

Hi @rafaol, tl;dr: don't use the jit; it isn't useful.

I have found that JitTrace_ELBO works only on very simple models and provides little or no speed improvement. I believe this is because the PyTorch team's goals are more about executing models outside a Python runtime, in contrast to the JAX team's goals, which are more performance-oriented. Historically, torch.jit.trace slightly sped up some models, but the jit has gradually grown slower, so it is no longer useful in Pyro. I've also found that torch.jit changes so often that I can't rely on my jitted models remaining jittable across minor PyTorch releases.

rafaol commented 2 years ago

Thanks @fritzo! I was actually planning to use JIT on a larger and more complex model (a random-effects/hierarchical model), which got a significant (10x) speedup in SVI when I tried JIT with simpler guides (AutoMultivariateNormal and AutoLaplaceApproximation). I wish I could get the same kind of speedup with normalizing flows, since they're far more flexible and expressive, but I'll look at other options within Pyro.