Closed: fehiepsi closed this issue 5 months ago.
See the possibly related #3214. I'll also take a look at this.
I tried using pytorch < 2.0 but still got the same issue. I'm double-checking the math of StableReparam. So far, `_safe_shift`/`_unsafe_shift` looks unrelated because the stability is 1.5.
One possibility is that the reparametrizer is correct, but the mean-field AutoNormal
variational posterior is really bad. We can't rely on the Bernstein-von-Mises theorem here because StableReparam
introduces four new latent variables per data point, so there is never any concentration in the parameter space of those latent variables, and hence Gaussian guides don't get better with more data.
Here are some ideas we might try to improve the variational approximation:

- Beta posteriors for the latent uniform random variables and/or Gamma posteriors for the latent exponential variables. These might stick to the constraint boundaries better than AutoNormal.
- An amortized guide `nn: datum -> params`, where we simply fit a curve in parameter space ranging from datum=-inf to datum=inf. In the case of StableReparam with an AutoNormal guide this would be a function R -> R^8 (or R -> R^4 x (0,inf)^4) that could be fit via independent splines. This trick won't really help the variational fit, but it will speed up the variational fitting computation, so we could use it together with richer reparametrizations.

One way we could validate this hypothesis is to see if HMC recovers the correct parameters. If so, that would imply the reparametrizers are correct, and hence that the variational approximation is at fault.
Here's a notebook examining the posteriors over the latents introduced by StableReparam. Indeed, the posteriors from HMC look quite non-Gaussian for observations in the tail; for example: [plot omitted]
This suggests the SVI estimates of parameters may be off.
I used a custom guide which includes auxiliary variables (so prior and guide cancel out) and fitted the skewness to maximize the (MC-estimated) likelihood, but no luck. Doing some random things like detaching abs(skewness) in the reparam implementation seems to help, so it could be that some gradient is wrong (due to clipping or something). I'll try a JAX version tomorrow to see how things look.
Turns out that there's something wrong with my understanding of auxiliary methods. Let's consider a simple example: A + X = Normal(0, 1) + Normal(0, scale) ~ Normal(0, sqrt(2)), where A is an auxiliary variable. It's clear that the expected scale is 1. In the following, I used 4 approaches:

- No guide (the auxiliary variable is sampled from its prior): scale ~ sqrt(3)
- AutoDelta: scale ~ 0.02 - under the hood, I think the MAP auxiliary point tends to be near the data point, hence the scale leans to be very small.
- AutoNormal: scale ~ 1, which is what we want.
- CustomGuide (learnable per-datum loc, guide scale fixed at 1): scale ~ 1.38
```python
import jax
import jax.numpy as jnp
from jax import random
import matplotlib.pyplot as plt
import numpyro
import numpyro.distributions as dist
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoDelta, AutoNormal
import optax


def model(data=None, scale_init=1.0):
    scale = numpyro.param("scale", scale_init, constraint=dist.constraints.positive)
    # jax.debug.print("scale={scale}", scale=scale)
    with numpyro.plate("N", data.shape[0] if data is not None else 10000):
        auxiliary = numpyro.sample("auxiliary", dist.Normal(0, 1))
        return numpyro.sample("obs", dist.Normal(auxiliary, scale), obs=data)


data = numpyro.handlers.seed(model, rng_seed=0)(scale_init=1.0)
print("Data std:", jnp.std(data))

# 1. No guide: the auxiliary variable is sampled from its prior.
svi = SVI(model, lambda _: None, optax.adam(0.1), Trace_ELBO())
svi_results = svi.run(random.PRNGKey(2), 201, data, progress_bar=False)
plt.plot(svi_results.losses)
print("scale using no guide", svi_results.params["scale"])

# 2. AutoDelta: MAP point estimates for the auxiliary variables.
svi = SVI(model, AutoDelta(model), optax.adam(0.1), Trace_ELBO())
svi_results = svi.run(random.PRNGKey(2), 201, data, progress_bar=False)
plt.plot(svi_results.losses)
print("scale using AutoDelta", svi_results.params["scale"])

# 3. AutoNormal: mean-field normal guide with learnable loc and scale.
svi = SVI(model, AutoNormal(model), optax.adam(0.1), Trace_ELBO())
svi_results = svi.run(random.PRNGKey(2), 201, data, progress_bar=False)
plt.plot(svi_results.losses)
print("scale using AutoNormal", svi_results.params["scale"])

# 4. Custom guide: learnable per-datum loc, guide scale fixed at 1.
def guide(data=None, scale_init=1.0):
    with numpyro.plate("N", data.shape[0]):
        loc = numpyro.param("loc", jnp.zeros_like(data))
        numpyro.sample("auxiliary", dist.Normal(loc, 1))


svi = SVI(model, guide, optax.adam(0.1), Trace_ELBO())
svi_results = svi.run(random.PRNGKey(2), 201, data, progress_bar=False)
plt.plot(svi_results.losses)
print("scale using CustomGuide", svi_results.params["scale"])
```
gives us
Data std: 1.4132578
scale using no guide 1.7334453
scale using AutoDelta 0.02048848
scale using AutoNormal 1.0833057
scale using CustomGuide 1.3865443
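A back-of-envelope calculation (my derivation, not part of the original thread) explains all four numbers; the key fact is that the exact posterior over the auxiliary is Gaussian:

```latex
% Data: x ~ Normal(0, sqrt(2));  model: a ~ Normal(0, 1),  x | a ~ Normal(a, s).
% Exact posterior of the auxiliary variable:
p(a \mid x, s) = \mathcal{N}\!\Big(\tfrac{x}{1+s^2},\; \tfrac{s^2}{1+s^2}\Big)

% No guide: maximize E_{a ~ N(0,1)}[log N(x; a, s)]  =>  s^2 = Var(x) + 1 = 3,
%   so s = sqrt(3) ~ 1.73.
% AutoDelta: a point-mass guide contributes no entropy term, so the ELBO is
%   maximized by a_i -> x_i and s -> 0 (the observed 0.02).
% AutoNormal: a Gaussian guide with learnable loc and scale can represent the
%   exact (Gaussian) posterior above, so the ELBO is tight and s -> 1.
% CustomGuide (learnable loc, guide scale fixed at 1): the optimal loc is
%   loc_i = x_i / (1 + s^2); substituting back, the stationarity condition in
%   u = s^2 is
u^3 - u^2 - u - 1 = 0 \quad\Rightarrow\quad u \approx 1.84,\; s \approx 1.36,
% close to the observed 1.39.
```

So only the guide family that can match the posterior geometry (here AutoNormal) yields a consistent estimate of the model parameter.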
So I think that to make the auxiliary method work, we need to know the geometry of the posterior of the auxiliary variables. In the stable case, the posteriors of the auxiliary variables zu, ze, tu, te are likely non-Gaussian. I'm not sure what a good approach is here.
> Indeed the posteriors from HMC look quite non-Gaussian for observations in the tail
Nice, @fritzo! I understand your comment better now after playing with the above toy example. So we need to use a more sophisticated guide here?
@fehiepsi correct, we need a more sophisticated guide. I just played around with Pyro's conditional normalizing flows, but I haven't gotten anything working yet 😞. I can't even seem to get an amortized diagonal normal guide working (see the bottom of this notebook).
I missed that you guys moved to GitHub. Repeating what I wrote on the Pyro board: I implemented a log_prob() for Stable that seems to work pretty well. I have a similar MLE example in [this notebook](https://github.com/mawright/torchstable/blob/main/stable_demo.ipynb). After some testing, it seems to get a little unstable near $\alpha=1$ (e.g., for a true value of 1.1 it converges to 1.15 with noiseless data), which I'm looking into fixing.
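As an independent sanity check of this kind of likelihood-based fitting (my sketch, not from the notebook), SciPy's `levy_stable` density can stand in for a Stable log_prob(); with an accurate density, the true stability should score best on a coarse grid:

```python
import numpy as np
from scipy.stats import levy_stable

rng = np.random.default_rng(0)
alpha_true, beta_true = 1.7, 0.5
data = levy_stable.rvs(alpha_true, beta_true, size=200, random_state=rng)


def loglik(alpha, beta):
    # Sum of log densities under SciPy's numerically integrated stable pdf.
    return levy_stable.logpdf(data, alpha, beta).sum()


grid = [1.1, 1.4, 1.7, 2.0]
scores = {a: loglik(a, beta_true) for a in grid}
best = max(scores, key=scores.get)
print(best)
```

This kind of grid comparison separates density-accuracy problems from optimizer problems: if the true parameter doesn't even win on a grid, the density itself is suspect.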
[Moved the discussion in this forum thread to here.]
Doing MLE on the skewness parameter of the Stable distribution does not recover the original parameter. The skewness tends to stay around 0 regardless of the initial value.