pyro-ppl / pyro

Deep universal probabilistic programming with Python and PyTorch
http://pyro.ai
Apache License 2.0

NeuTraReparam does not work as supposed: AssertionError thrown #2317

Closed ameshkovskiy closed 4 years ago

ameshkovskiy commented 4 years ago

Hi @fehiepsi, following the discussion on the forum.

Issue Description

The issue appears when I try to use NeuTraReparam as suggested in the docs: (1) train the guide using SVI, then (2) apply the Neural Transport reparameterization (NeuTraReparam), as described here.

Environment

Code Snippet

# load dependencies
import torch
import pyro
import pyro.distributions as dist

from numpyro.examples.datasets import SP500, load_dataset # the one from pyro didn't work
from pyro import poutine
from pyro.infer import NUTS, HMC
from pyro.infer.mcmc.api import MCMC
from pyro.infer.autoguide.guides import AutoIAFNormal
from pyro.infer.reparam import NeuTraReparam
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import SGD

# load data
_, fetch = load_dataset(SP500, shuffle=False)
dates, returns = fetch()
returns = returns[:100]

# define model
def model(returns):
    # init parameters
    phi = pyro.sample("phi", dist.Beta(20, 1.5))
    phi = 2 * phi - 1
    sigma2 = pyro.sample('sigma2', dist.InverseGamma(2.5, 0.025))
    mu = pyro.sample("mu", dist.Normal(0, 10))

    h = torch.empty(len(returns))

    for t in pyro.poutine.markov(range(len(returns))):
        if t == 0:
            h[t] = pyro.sample(f'h_{t}', dist.Normal(mu, sigma2**0.5).to_event(0))
        else:
            h[t] = pyro.sample(f'h_{t}', dist.Normal(mu + phi * (h[t-1] - mu), sigma2**0.5).to_event(0))

    y = pyro.sample('y', dist.Normal(0., (h / 2.).exp()), obs=returns)

# define main()
def main(model, tseries):
    guide = AutoIAFNormal(model)
    optimizer = SGD({"lr": 0.001})
    svi = SVI(model, guide, optimizer, Trace_ELBO())
    n_steps = 500
    for step in range(n_steps):
        # take one optimization step and reuse its loss for logging
        loss = svi.step(torch.Tensor(tseries))
        if step % 100 == 0:
            print(loss)
    neutra = NeuTraReparam(guide)
    model = poutine.reparam(model, config=lambda _: neutra)
    hmc_kernel = HMC(model)
    hmc = MCMC(hmc_kernel,
               num_samples=100,
               warmup_steps=10,
               num_chains=1)

    hmc.run(torch.Tensor(tseries))

    return hmc

# run
run_output = main(model, returns)

In my case I get the error: AssertionError: NeuTraReparam does not support observe statements

While debugging, I found that the error is thrown during hmc.run() and that obs is not None for the variable 'y'.

Thanks a lot,

Arthur

fritzo commented 4 years ago

@ameshkovskiy thanks for the runnable script!

@fehiepsi are you able to reproduce the error? On my system the script runs without error:

$ python temp.py
Downloading - https://d2hg8soec8ck9v.cloudfront.net/datasets/SP500.csv.
Download complete.
157.42280922830105
132.9466832280159
127.23020279407501
119.74025339446962
116.92387166619301
Sample: 100%|██████████████████████████████████████████| 110/110 [03:20,  1.82s/it, step size=2.50e-01, acc. prob=0.720]

That's using Pyro dev

>>> import pyro
>>> pyro.__version__
'1.2.1+8b0f91c1'
>>> import torch
>>> torch.__version__
'1.4.0'

neerajprad commented 4 years ago

Hmm, I think this is a problem with the release branch; it is fixed in dev by https://github.com/pyro-ppl/pyro/pull/2295.

fehiepsi commented 4 years ago

Yes, I can confirm that this is fixed in dev branch.

ameshkovskiy commented 4 years ago

Thanks a lot, the dev version does indeed work. However, when I run it and call run_output.summary(), all the variable names are replaced by phi_shared_latent[i]:

                            mean       std    median      5.0%     95.0%     n_eff     r_hat
  phi_shared_latent[0]      3.57      2.54      3.36     -0.45      7.18    195.27      1.00
  phi_shared_latent[1]     -5.76      1.87     -5.88     -8.93     -2.80    110.65      1.01
  phi_shared_latent[2]     -0.35      1.36     -0.27     -2.48      1.58    244.20      1.00
  phi_shared_latent[3]      0.00      0.80      0.02     -1.23      1.26    275.24      1.01
  phi_shared_latent[4]      0.10      0.80      0.10     -1.29      1.38    213.14      1.00
  phi_shared_latent[5]      0.00      0.74      0.00     -1.07      1.29    422.45      1.00
  phi_shared_latent[6]      0.02      0.74      0.03     -1.35      1.02    393.41      1.00
  phi_shared_latent[7]     -0.20      0.79     -0.17     -1.43      1.14    416.51      1.00
  phi_shared_latent[8]      0.05      0.73      0.02     -1.00      1.34    188.56      1.02

...

phi_shared_latent[100]      0.32      0.84      0.30     -0.93      1.81    385.45      1.00
phi_shared_latent[101]     -0.27      0.78     -0.26     -1.58      0.92    334.07      1.00
phi_shared_latent[102]     -0.14      0.82     -0.14     -1.46      1.10    238.38      1.00

Number of divergences: 0

Is this the expected behavior?

Arthur

fehiepsi commented 4 years ago

@ameshkovskiy The behavior is expected (I think we get "neutralized" samples here), but the result is not useful as it stands. We need an extra step to transform those samples back to the original space.

ameshkovskiy commented 4 years ago

Thanks, got it. I see the method transform_sample(). Could you kindly elaborate on where and how this method should be called?

neerajprad commented 4 years ago

@ameshkovskiy - you might find this example on the dev branch useful.

ameshkovskiy commented 4 years ago

Thanks a lot for the reply and the example. It seems to work. However, I find it a bit confusing that all the variable names (sigma, h[i], etc.) are renamed to phi_shared_latent. Am I right that the order of variables is preserved in the .summary() output?

ameshkovskiy commented 4 years ago

Guys, I have a follow-up question on this topic, this time about the slow performance of pyro, compared to pymc3 in particular. Please let me know if you want me to move this to a separate thread.

I came across the following related discussion on slow runs, and this discussion on a comparison to pymc3 for a very similar model; however, it seems that the difference in computation times still persists, and it is huge.

Do you happen to know the underlying cause of such dramatically slow performance, compared to pymc3 in particular?

Below I copy the pyro code based on the example you provided earlier. If needed, I can provide the pymc3 one as well.

Thanks a lot!

import argparse
import logging
from functools import partial
import torch
import pyro

from pyro import optim, poutine
import pyro.distributions as dist
from pyro.distributions.transforms import iterated, planar
from pyro.infer import MCMC, NUTS, SVI, Trace_ELBO
from pyro.infer.autoguide import AutoNormalizingFlow
from pyro.infer.reparam import NeuTraReparam
from numpyro.examples.datasets import SP500, load_dataset

# load data
_, fetch = load_dataset(SP500, shuffle=False)
dates, returns = fetch()
returns = returns[:101]

logging.basicConfig(format='%(message)s', level=logging.INFO)

def model(returns):

    phi = pyro.sample("phi", dist.Beta(20, 1.5))
    phi = 2 * phi - 1
    sigma2 = pyro.sample('sigma2', dist.InverseGamma(2.5, 0.025))
    mu = pyro.sample("mu", dist.Normal(0, 10))

    h = torch.empty(len(returns))
    for t in pyro.poutine.markov(range(len(returns))):
        if t == 0:
            h[t] = pyro.sample(f'h_{t}', dist.Normal(mu, sigma2**0.5/torch.sqrt(1. - phi * phi)).to_event(0))
        else:
            h[t] = pyro.sample(f'h_{t}', dist.Normal(mu + phi * (h[t-1] - mu), sigma2**0.5).to_event(0))

    y = pyro.sample('y', dist.Normal(0., (h / 2.).exp()), obs=returns)

# define aux fn to fit the guide
def fit_guide(guide, args):
    pyro.clear_param_store()
    adam = optim.Adam({'lr': args.learning_rate})
    svi = SVI(model, guide, adam, Trace_ELBO())
    for i in range(args.num_steps):
        loss = svi.step(args.data)
        if i % 500 == 0:
            logging.info("[{}]Elbo loss = {:.2f}".format(i, loss))

# define aux fn to run HMC
def run_hmc(args, model, print_summary=False):
    nuts_kernel = NUTS(model)
    mcmc = MCMC(nuts_kernel, warmup_steps=args.num_warmup, num_samples=args.num_samples)
    mcmc.run(args.data)
    if print_summary:
        mcmc.summary()
    return mcmc

def main(args):
    pyro.set_rng_seed(args.rng_seed)

    outDict = {}

    # If we want the Normalizing Flow
    # fit autoguide
    logging.info('\nFitting a BNAF autoguide ...')
    guide = AutoNormalizingFlow(model, partial(iterated, args.num_flows, planar))
    fit_guide(guide, args)

    # Draw samples using NeuTra HMC
    logging.info('\nDrawing samples using BNAF autoguide + NeuTra HMC ...')
    neutra = NeuTraReparam(guide.requires_grad_(False))
    neutra_model = poutine.reparam(model, config=lambda _: neutra)
    mcmc = run_hmc(args, neutra_model)
    zs = mcmc.get_samples()['phi_shared_latent']

    samples = neutra.transform_sample(zs)

    outDict['nf_neutra_mcmc'] = mcmc
    outDict['nf_neutra_samples'] = samples

    return outDict

if __name__ == '__main__':
    assert pyro.__version__.startswith('1.2.1')
    parser = argparse.ArgumentParser(description='Example illustrating NeuTra Reparametrizer')
    parser.add_argument('-n', '--num-steps', default=1000, type=int,
                        help='number of SVI steps')
    parser.add_argument('-lr', '--learning-rate', default=1e-4, type=float,
                        help='learning rate for the Adam optimizer')
    parser.add_argument('--rng-seed', default=123, type=int,
                        help='RNG seed')
    parser.add_argument('--num-warmup', default=1000, type=int,
                        help='number of warmup steps for NUTS')
    parser.add_argument('--num-samples', default=2000, type=int,
                        help='number of samples to be drawn from NUTS')
    parser.add_argument('--data', default=torch.Tensor(returns), type=float,
                        help='Time-series of returns')
    parser.add_argument('--num-flows', default=4, type=int,
                        help='number of flows in the BNAF autoguide')

    args = parser.parse_args(args=[])

    outDict = main(args)

fritzo commented 4 years ago

Hi @ameshkovskiy, could you ask this question at http://forum.pyro.ai? The short answer is: you should vectorize your model.
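To illustrate what vectorizing means for the latent AR(1) part of the model above, here is a minimal sketch in plain PyTorch (not a Pyro model). It scores the whole path h with one batched Normal call instead of one call per time step. The function names are hypothetical, and the thread does not say how the Pyro model itself should be restructured; this only shows the batching idea.

```python
import torch
from torch.distributions import Normal


def ar1_log_prob_loop(h, mu, phi, sigma):
    """Log-density of an AR(1) path, one Normal per time step (slow)."""
    total = Normal(mu, sigma / torch.sqrt(1 - phi ** 2)).log_prob(h[0])
    for t in range(1, len(h)):
        total = total + Normal(mu + phi * (h[t - 1] - mu), sigma).log_prob(h[t])
    return total


def ar1_log_prob_vectorized(h, mu, phi, sigma):
    """Same log-density, with all transitions scored in one batched call."""
    first = Normal(mu, sigma / torch.sqrt(1 - phi ** 2)).log_prob(h[0])
    # h[:-1] holds every h[t-1]; h[1:] holds every h[t]
    transitions = Normal(mu + phi * (h[:-1] - mu), sigma).log_prob(h[1:])
    return first + transitions.sum()


torch.manual_seed(0)
h = torch.randn(100, dtype=torch.float64)
mu = torch.tensor(0.3, dtype=torch.float64)
phi = torch.tensor(0.9, dtype=torch.float64)
sigma = torch.tensor(0.2, dtype=torch.float64)

slow = ar1_log_prob_loop(h, mu, phi, sigma)
fast = ar1_log_prob_vectorized(h, mu, phi, sigma)
```

Both functions return the same scalar; the vectorized one replaces roughly len(h) small tensor operations (and their autograd nodes) with a handful of batched ones, which is typically where the per-step Python loop loses its time.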

ameshkovskiy commented 4 years ago

Thanks, created here.

fehiepsi commented 4 years ago

The original issue has been fixed. Let's continue the discussion on the forum.