pyro-ppl / pyro

Deep universal probabilistic programming with Python and PyTorch
http://pyro.ai
Apache License 2.0

Trouble with basic parameter estimation with the Stable distribution #3280

Closed fehiepsi closed 5 months ago

fehiepsi commented 1 year ago

[Moved the discussion from this forum thread to here.]

Doing MLE on the skewness parameter of the Stable distribution does not recover the original parameter. The skewness tends to stay around 0 regardless of the initial value.

# adapted from https://github.com/fritzo/notebooks/blob/master/stable_mle.ipynb
import math
import matplotlib.pyplot as plt
import torch
import pyro
import pyro.distributions as dist
from pyro.distributions import constraints
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.reparam import MinimalReparam
from pyro.infer.autoguide import AutoNormal, AutoGaussian

torch.set_default_dtype(torch.float64)
pyro.set_rng_seed(20230928)

# Define true parameters and number of datapoints
alpha = 1.5
beta = 0.8
c = 1.0
mu = 0.0
n = 10000

# sample data
data = dist.Stable(alpha, beta, c, mu).sample((n,))

@MinimalReparam()
def model(data):    
    alpha = 1.5  # pyro.param("alpha", torch.tensor(1.99), constraint=constraints.interval(0, 2))
    beta = pyro.param("beta", torch.tensor(0.5), constraint=constraints.interval(-1, 1))
    c = 1.0  # pyro.param("c", torch.tensor(1.0), constraint=constraints.positive)
    mu = 0.0  # pyro.param("mu", torch.tensor(0.0), constraint=constraints.real)
    with pyro.plate("data", data.shape[0]):
        pyro.sample("obs", dist.Stable(alpha, beta, c, mu), obs=data)

def train(model, guide, num_steps=1001, lr=0.1):
    pyro.clear_param_store()
    pyro.set_rng_seed(20230928)

    # set up ELBO, and optimizer
    elbo = Trace_ELBO()
    elbo.loss(model, guide, data=data)
    optim = pyro.optim.Adam({"lr": lr})
    svi = SVI(model, guide, optim, loss=elbo)

    # optimize
    losses = []
    for i in range(num_steps):
        loss = svi.step(data) / data.numel()
        losses.append(loss)
        if i % 100 == 0:
            print(f"step {i} loss = {loss:0.6g}")

    print(f"Parameter estimates (n = {n}):")
    print(f"beta: Estimate = {pyro.param('beta')}, true = {beta}")
    return losses

guide = AutoNormal(model)
train(model, guide);

gives us something like

step 0 loss = 57.484
step 100 loss = 2.82283
step 200 loss = 2.64291
step 300 loss = 2.57734
step 400 loss = 2.57825
step 500 loss = 2.55413
step 600 loss = 2.58021
step 700 loss = 2.55925
step 800 loss = 2.56627
step 900 loss = 2.55161
step 1000 loss = 2.54688
Parameter estimates (n = 10000):
beta: Estimate = -0.0007896144629401247, true = 0.8
fritzo commented 1 year ago

See the possibly related #3214. I'll also take a look at this.

fehiepsi commented 1 year ago

I tried using PyTorch < 2.0 but still got the same issue. I'm double-checking the math of StableReparam. So far, _safe_shift/_unsafe_shift is unrelated because the stability is 1.5.

fritzo commented 1 year ago

One possibility is that the reparametrizer is correct, but the mean-field AutoNormal variational posterior is really bad. We can't rely on the Bernstein-von-Mises theorem here because StableReparam introduces four new latent variables per data point, so there is never any concentration in the parameter space of those latent variables, and hence Gaussian guides don't get better with more data.
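
To make that concrete, tracing the reparameterized model shows the extra per-datapoint sample sites (a rough sketch, not from the original discussion; the exact site names are an internal detail of StableReparam):

import pyro.poutine as poutine

# List the non-observed sample sites that the reparameterizer adds for "obs"
# when the model runs on a handful of datapoints.
tr = poutine.trace(model).get_trace(data[:5])
for name, site in tr.nodes.items():
    if site["type"] == "sample" and name.startswith("obs") and not site["is_observed"]:
        print(name, tuple(site["value"].shape))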

Here are some ideas we might try to improve the variational approximation. One way we could validate this hypothesis is to see if HMC recovers the correct parameters; if so, that would imply the reparametrizers are correct, so the variational approximation is at fault.
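
Roughly, that check could look like the sketch below (untested here; the skewness is lifted to a latent site with a flat prior instead of a pyro.param, and the other parameters stay fixed at their true values):

from pyro.infer import MCMC, NUTS

@MinimalReparam()
def hmc_model(data):
    # Skewness as a latent with a flat prior; stability, scale, and loc are
    # held at their true values, matching the SVI model above.
    beta = pyro.sample("beta", dist.Uniform(-1.0, 1.0))
    with pyro.plate("data", data.shape[0]):
        pyro.sample("obs", dist.Stable(1.5, beta, 1.0, 0.0), obs=data)

mcmc = MCMC(NUTS(hmc_model), num_samples=500, warmup_steps=500)
mcmc.run(data[:1000])  # a subset keeps this cheap
print("posterior mean of beta:", mcmc.get_samples()["beta"].mean().item())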

fritzo commented 1 year ago

Here's a notebook examining the posteriors over the latents introduced by StableReparam. Indeed the posteriors from HMC look quite non-Gaussian for observations in the tail, for example:

[image: HMC posterior of one of the auxiliary latents introduced by StableReparam for a tail observation, visibly non-Gaussian]

This suggests the SVI estimates of parameters may be off.

fehiepsi commented 1 year ago

I used a custom guide which includes the auxiliary variables (so the prior and guide cancel out) and fitted the skewness to maximize the (MC-estimated) likelihood - but no luck. Doing some random things like detaching abs(skewness) in the reparam implementation seems to help, so it could be that some gradient is wrong (due to clipping or something). I'll fit a JAX version tomorrow to see how things look.

fehiepsi commented 1 year ago

Turns out that there's something wrong with my understanding of auxiliary methods. Let's consider a simple example: the observation is A + X, where A ~ Normal(0, 1) is an auxiliary variable and X ~ Normal(0, scale), so marginally A + X ~ Normal(0, sqrt(1 + scale^2)). The data are generated with scale = 1 (marginal std sqrt(2)), so the fitted scale should come back as 1. In the following, I used 4 approaches:

import jax
import jax.numpy as jnp
from jax import random
import matplotlib.pyplot as plt
import numpyro
import numpyro.distributions as dist
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoDelta, AutoNormal
import optax

def model(data=None, scale_init=1.0):
    scale = numpyro.param("scale", scale_init, constraint=dist.constraints.positive)
    # jax.debug.print("scale={scale}", scale=scale)
    with numpyro.plate("N", data.shape[0] if data is not None else 10000):
        auxiliary = numpyro.sample("auxiliary", dist.Normal(0, 1))
        return numpyro.sample("obs", dist.Normal(auxiliary, scale), obs=data)

data = numpyro.handlers.seed(model, rng_seed=0)(scale_init=1.0)
print("Data std:", jnp.std(data))
svi = SVI(model, lambda _: None, optax.adam(0.1), Trace_ELBO())
svi_results = svi.run(random.PRNGKey(2), 201, data, progress_bar=False)
plt.plot(svi_results.losses)
print('scale using no guide', svi_results.params['scale'])

svi = SVI(model, AutoDelta(model), optax.adam(0.1), Trace_ELBO())
svi_results = svi.run(random.PRNGKey(2), 201, data, progress_bar=False)
plt.plot(svi_results.losses)
print('scale using AutoDelta', svi_results.params['scale'])

svi = SVI(model, AutoNormal(model), optax.adam(0.1), Trace_ELBO())
svi_results = svi.run(random.PRNGKey(2), 201, data, progress_bar=False)
plt.plot(svi_results.losses)
print('scale using AutoNormal', svi_results.params['scale'])

def guide(data=None, scale_init=1.0):
    with numpyro.plate("N", data.shape[0]):
        loc = numpyro.param("loc", jnp.zeros_like(data))
        numpyro.sample("auxiliary", dist.Normal(loc, 1))

svi = SVI(model, guide, optax.adam(0.1), Trace_ELBO())
svi_results = svi.run(random.PRNGKey(2), 201, data, progress_bar=False)
plt.plot(svi_results.losses)
print('scale using CustomGuide', svi_results.params['scale'])

gives us

Data std: 1.4132578
scale using no guide 1.7334453
scale using AutoDelta 0.02048848
scale using AutoNormal 1.0833057
scale using CustomGuide 1.3865443

So I think that to make the auxiliary-variable method work, we need to know the geometry of the posterior of the auxiliary variables. In the stable case, the posteriors of the auxiliary variables zu, ze, tu, te are likely non-Gaussian. I'm not sure what a good approach is here.
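
For the toy model, though, the auxiliary posterior is available in closed form: p(auxiliary | obs, scale) = Normal(obs / (1 + scale^2), scale / sqrt(1 + scale^2)). So a guide that learns a per-point loc and scale, instead of fixing the guide scale at 1 as CustomGuide does, can in principle make the ELBO tight. A sketch along those lines (names like flexible_guide, q_loc, q_scale are just placeholders; not run here):

def flexible_guide(data=None, scale_init=1.0):
    with numpyro.plate("N", data.shape[0]):
        # Per-datapoint Normal guide with learnable loc and scale; the exact
        # posterior here is Normal(obs / (1 + scale**2), scale / sqrt(1 + scale**2)).
        loc = numpyro.param("q_loc", jnp.zeros_like(data))
        scale_q = numpyro.param("q_scale", jnp.ones_like(data),
                                constraint=dist.constraints.positive)
        numpyro.sample("auxiliary", dist.Normal(loc, scale_q))

svi = SVI(model, flexible_guide, optax.adam(0.1), Trace_ELBO())
svi_results = svi.run(random.PRNGKey(2), 1000, data, progress_bar=False)
print('scale using flexible guide', svi_results.params['scale'])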

fehiepsi commented 1 year ago

> Indeed the posteriors from HMC look quite non-Gaussian for observations in the tail

Nice, @fritzo! I understand your comment better now after playing with the above toy example. So we need to use a more sophisticated guide here?

fritzo commented 1 year ago

@fehiepsi correct, we need a more sophisticated guide. I just played around with Pyro's conditional normalizing flows, but I haven't gotten anything working yet 😞. I can't even seem to get an amortized diagonal normal guide working (see the bottom of this notebook).

mawright commented 1 year ago

I missed that you guys moved to GitHub. Repeating what I wrote on the Pyro board: I implemented a log_prob() for Stable that seems to work pretty well, and I have a similar MLE example in [this notebook](https://github.com/mawright/torchstable/blob/main/stable_demo.ipynb):

[image: MLE parameter estimates from the stable_demo notebook]

After some testing it seems to get a little unstable near $\alpha=1$ (e.g., for a true value of 1.1 it converges to 1.15 with noiseless data), which I'm looking into fixing.
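
For reference, once a differentiable log_prob is available, the skewness fit reduces to direct gradient-based MLE. A minimal sketch, where stable_log_prob is a stand-in for whatever numerically integrated log-density is used (e.g. the one in the torchstable repo above); its name and argument order are placeholders:

import torch

def fit_beta(data, stable_log_prob, num_steps=500, lr=0.05):
    # `stable_log_prob(x, alpha, beta, scale, loc)` is a placeholder for a
    # differentiable Stable log-density; alpha, scale, and loc are held fixed.
    raw_beta = torch.tensor(0.0, requires_grad=True)  # beta = tanh(raw_beta) in (-1, 1)
    opt = torch.optim.Adam([raw_beta], lr=lr)
    for _ in range(num_steps):
        opt.zero_grad()
        beta = torch.tanh(raw_beta)
        nll = -stable_log_prob(data, 1.5, beta, 1.0, 0.0).mean()
        nll.backward()
        opt.step()
    return torch.tanh(raw_beta).item()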