pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

AutoContinuous not working correctly with TransformedDistribution #594

Closed lumip closed 4 years ago

lumip commented 4 years ago

The implementation of AutoContinuous does not work with TransformedDistribution objects, because the transformations specified for the distribution never seem to be applied. We encountered this issue with a LogNormal distribution; here is a minimal example:

import jax.numpy as np
import jax

import numpyro.distributions as dist
from numpyro.primitives import sample
from numpyro.infer import SVI, ELBO
from numpyro.contrib.autoguide import AutoDiagonalNormal
from numpyro.optim import SGD

import numpy as onp

data = .3 * onp.random.randn(100) + 1.

def model(x):
    mu = sample('mus', dist.Normal(0., 1.))
    sig = sample('sig', dist.LogNormal(0., 1.))
    sample('x', dist.Normal(mu, sig), obs=x)

guide = AutoDiagonalNormal(model)

optim = SGD(1e-4)
svi = SVI(model, guide, optim, ELBO())

rng_key = jax.random.PRNGKey(0) # works with 1, does not work with 0
svi_state = svi.init(rng_key, data)
init_svi_state = svi_state

for _ in range(10):
    svi_state, loss = svi.update(svi_state, data)
    print(loss)

If the key is seeded with 0, this produces nan values (it works fine for a seed of 1, and even for seed 0 if the order of mu and sig in the model is swapped). The reason is that for a seed of 0 the guide samples a negative initial value for sig, to which LogNormal.log_prob then applies a log.
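To make the failure mode concrete, here is a minimal sketch of the log-normal log density in plain Python (no NumPyro involved). The helper name `lognormal_log_prob` is hypothetical, and returning nan for non-positive input is an assumption that mimics how array libraries evaluate the log of a negative number:

```python
import math

def lognormal_log_prob(x, loc=0.0, scale=1.0):
    """Log density of LogNormal(loc, scale) at x (hypothetical helper)."""
    # Array libraries return nan for the log of a negative number; math.log
    # raises instead, so non-positive input is mapped to nan explicitly here
    # (an assumption made for illustration).
    log_x = math.log(x) if x > 0 else float("nan")
    z = (log_x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2 * math.pi) - log_x

ok = lognormal_log_prob(0.5)    # 0.5 lies in the support, result is finite
bad = lognormal_log_prob(-0.5)  # negative value is outside the support, result is nan
print(ok, bad)
```

Once a single nan enters the loss, every subsequent gradient step carries it forward, which would match the behaviour of the loop above.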

We believe the issue lies in AutoContinuous._setup_prototype (particularly line 154ff) where, for sample sites with intermediates, the guide just extracts a transform from unconstrained optimization space to the support of the base distribution and stores it in its _inv_transforms lookup, ignoring all other transformations the site's distribution applies. The __call__ implementation therefore applies only the transform from unconstrained space to the support of the base distribution, not to the support of the target distribution.

The LogNormal is implemented as TransformedDistribution(Normal, ExpTransform()), so its base distribution has real support and the guide therefore extracts an IdentityTransform. The ExpTransform never seems to be applied (at least we could not pinpoint a location where that would happen), so the value returned by the guide can be negative.
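The difference between the two mappings can be sketched in plain Python. This is an illustration only, not the NumPyro code: `identity`, `exp_transform` and `compose` are hypothetical stand-ins for the library's transform objects.

```python
import math

def identity(u):
    return u

def exp_transform(u):
    return math.exp(u)

def compose(*fns):
    """Apply fns left to right (hypothetical stand-in for a composed transform)."""
    def composed(u):
        for fn in fns:
            u = fn(u)
        return u
    return composed

# The Normal base distribution has real support, so the extracted transform
# is the identity.
guide_transform = identity
# Chaining the distribution's own exp transform after it lands in the
# support of the LogNormal.
full_transform = compose(identity, exp_transform)

u = -0.7  # an unconstrained value the guide might sample
print(guide_transform(u))  # stays negative
print(full_transform(u))   # strictly positive
```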

Note: Pareto is an interesting case, because it doesn't fall into the same trap. It is a TransformedDistribution that applies an ExpTransform and an AffineTransform, both of which the guide ignores, but its base distribution is Exponential, which has positive support. The guide therefore extracts an ExpTransform anyway, ensuring that the values it samples are positive (albeit probably not scaled properly).
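A small numeric sketch of the Pareto case, again in plain Python rather than NumPyro. It assumes the distribution's chain is an exp followed by an affine map that multiplies by the scale parameter with zero offset; the value of `scale` and the function names are hypothetical.

```python
import math

scale = 2.0  # hypothetical Pareto scale; the support is then x >= scale

def to_base_support(u):
    # Unconstrained value -> positive reals (support of the Exponential base)
    return math.exp(u)

def pareto_transforms(z):
    # The distribution's own chain (assumed): exp, then multiply by `scale`
    return scale * math.exp(z)

u = -1.0
guide_value = to_base_support(u)                    # only the base-support transform
full_value = pareto_transforms(to_base_support(u))  # with the ignored transforms applied
print(guide_value, full_value)
```

Under these assumptions the guide's value is positive but smaller than `scale`, so it lies outside the Pareto support, whereas the fully transformed value does not.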

A pull request with a suggested bugfix will follow soon. Edit: We're not sure what would be the best fix for this.

From how the __call__ function is written, it seems that a straightforward fix would be to change line 153 from transform = biject_to(site['fn'].base_dist.support) to transform = ComposeTransform([biject_to(site['fn'].base_dist.support)] + site['fn'].transforms). However, this comes with several caveats:

Edit 2: We do not have much experience with Pyro, but from a quick look the implementation there seems almost identical (minus the obvious structural differences and some case distinctions), so it would be interesting to know whether the same problem persists there. We currently lack the time to check.

fehiepsi commented 4 years ago

Hi @lumip, we haven't moved autoguide out of contrib to the main numpyro.infer module because there are some issues with it. One of them is the one you describe above. Could you try AutoContinuousELBO, as in the NeuTra example? I believe we can move autoguide to the main inference module after #511 is resolved (then we won't need AutoContinuousELBO anymore).

lumip commented 4 years ago

AutoContinuousELBO works, although it makes our life a tiny bit more difficult. I guess that's what we get from using "non-standardized" modules :) Looking forward to the final integration of the autoguide.

I have to say, however, that it seems like a somewhat odd choice to implement the autoguide such that it violates assumptions made by existing parts of the library and then, rather than addressing that, to create variants of those other parts that only work(?) with the autoguide. Is there a specific reason for that?

fehiepsi commented 4 years ago

Great question! Back then, Pyro did not have reparameterization implemented, so an autoguide would give incorrect results if the model had dynamic support (e.g. a Uniform(0, a) prior, where a is another latent variable). When we developed NumPyro, we wanted to address this by automatically reparameterizing all priors with dynamic support. The solution works, but it adds a lot of complexity to the codebase. A few months ago, a reparameterization module was developed in Pyro. It solves the dynamic support issue without adding complexity to the main codebase. We'll incorporate it into NumPyro. I believe it is also time to move those autoguides to the main library.
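For readers unfamiliar with the dynamic support problem, here is a rough sketch of the Uniform(0, a) case in plain Python. It is not the Pyro reparameterization module, only an illustration of the idea; all names and numbers are hypothetical.

```python
import math

def sigmoid(u):
    return 1.0 / (1.0 + math.exp(-u))

def fixed_interval_transform(u, upper):
    # Maps an unconstrained value into (0, upper) for a bound fixed at setup time
    return upper * sigmoid(u)

a_at_setup = 5.0  # value of the latent `a` seen when the transform was built
a_now = 1.0       # value of `a` in a later draw
u = 2.0           # unconstrained value proposed by the guide

# A transform built once for (0, a_at_setup) can exceed the current bound.
stale = fixed_interval_transform(u, a_at_setup)
# Reparameterized: draw a unit-interval variable, then scale by the current `a`.
reparam = a_now * sigmoid(u)
print(stale, reparam)
```

In the reparameterized form the latent variable always lives on the fixed interval (0, 1), so a transform extracted once stays valid no matter what value `a` takes.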