lumip closed this issue 4 years ago.
Hi @lumip, we haven't moved autoguide out of contrib into the main numpyro.infer module because there are still some issues with it, one of which is the one you mentioned above. Could you try AutoContinuousELBO, as in the neutra example? I believe we can move autoguide into the main inference module once #511 is resolved (at which point we won't need AutoContinuousELBO anymore).
AutoContinuousELBO works, although it makes our lives a tiny bit more difficult, but I guess that's what we get for using "non-standardized" modules :) Looking forward to the final integration of the autoguide.
I have to say, however, that it seems like a somewhat odd choice to implement the autoguide such that it violates assumptions of existing library parts and then, rather than addressing that, to create variants of those other parts that only work(?) with the autoguide. Is there a specific reason for that?
Great question! Back then, Pyro did not have reparameterization implemented, so an autoguide would give incorrect results if the model has dynamic support (e.g. a Uniform(0, a) prior, where a is another latent variable). When we developed NumPyro, we wanted to address this by automatically reparameterizing all priors with dynamic support. That solution works, but it adds a lot of complexity to the codebase. A few months ago, a reparameterization module was developed in Pyro. It solves the dynamic support issue without adding complexity to the main codebase. We'll incorporate it into NumPyro, and I believe that will also be the time to move the autoguides into the main library.
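To make the dynamic-support problem concrete: for a Uniform(0, a) prior, the bijection from unconstrained space onto the support depends on the current value of a, so a transform fixed once (for one value of a) can map into the wrong interval after a changes. A minimal pure-Python sketch, assuming a sigmoid-then-scale bijection onto (0, a); sigmoid and to_interval are hypothetical helpers, not NumPyro code, and the numbers are arbitrary:

```python
import math


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def to_interval(z: float, upper: float) -> float:
    # Hypothetical helper: maps the real line onto (0, upper).
    # The map depends on `upper`, i.e. on the other latent variable `a`.
    return upper * sigmoid(z)


z = 2.0            # unconstrained value proposed by the guide
a_at_setup = 1.0   # value of the latent `a` when the transform was fixed
a_now = 0.5        # value of `a` at a later step

# Transform frozen at setup time: uses the stale bound a_at_setup.
x_stale = to_interval(z, a_at_setup)
# Transform rebuilt from the current value of `a`.
x_current = to_interval(z, a_now)

inside_stale = 0.0 < x_stale < a_now      # False: outside the support of Uniform(0, a_now)
inside_current = 0.0 < x_current < a_now  # True
```

The stale value would get zero density under the current prior, which is the kind of incorrect result described above.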
The implementation of AutoContinuous does not work with TransformedDistribution objects, because the transformations specified for the distribution never seem to be applied. We encountered this issue with a LogNormal distribution; here is a minimal example:

If the key is seeded with 0, this produces nan values (but it works fine for a seed of 1, or even for a seed of 0 if the order of mu and sig in the model is changed). The reason is that for a seed of 0 the guide samples negative initial values for the sigmas (to which LogNormal.log_prob then applies a log).

We believe the issue lies in AutoContinuous._setup_prototype (particularly line 154ff) where, for sample sites with intermediates, the guide only extracts a transform from the unconstrained optimization space to the support of the base distribution and stores it in its _inv_transforms lookup, ignoring all other transformations the site's distribution applies. The __call__ implementation therefore applies only the transform from unconstrained space to the constraint support of the base distribution, not of the target distribution.

The LogNormal is implemented as TransformedDistribution(Normal, ExpTransform()), so its base distribution has real support and the guide therefore extracts an IdentityTransform. The ExpTransform never seems to be applied (at least we could not pinpoint a location where that would happen), so the value returned by the guide can be negative.
Pareto is an interesting case, because it doesn't fall into the same trap. It is a TransformedDistribution applying an ExpTransform and an AffineTransform, which the guide ignores, but its base distribution is Exponential, which has positive support, so the guide extracts an ExpTransform anyway, ensuring that the values it samples are positive (albeit probably not scaled properly).

A pull request with a suggested bugfix will follow soon.

Edit: We're not sure what the best fix for this would be. From how the __call__ function is written, it seems that a straightforward fix would be to change line 153 from

    transform = biject_to(site['fn'].base_dist.support)

to

    transform = ComposeTransform([biject_to(site['fn'].base_dist.support)] + site['fn'].transforms)
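The effect of that proposed change can be sketched in plain Python, with functions standing in for transforms. This is an illustration under assumptions, not NumPyro code: compose, identity, exp_transform and affine are hypothetical stand-ins for ComposeTransform, IdentityTransform, ExpTransform and AffineTransform; Pareto is modelled as described above (Exponential base, then ExpTransform, then an AffineTransform, here assumed to be an affine map with loc 0 and scale 2.0); and z = -0.7 is an arbitrary unconstrained value.

```python
import math
from functools import reduce
from typing import Callable, List

Transform = Callable[[float], float]


def compose(parts: List[Transform]) -> Transform:
    # Apply the parts left to right, in the spirit of ComposeTransform(parts).
    return lambda z: reduce(lambda value, part: part(value), parts, z)


def identity(z: float) -> float:
    return z


def exp_transform(x: float) -> float:
    return math.exp(x)


def affine(loc: float, scale: float) -> Transform:
    return lambda x: loc + scale * x


pareto_scale = 2.0  # assumed parameter value for illustration

# Per-site pieces: biject_to(base_dist.support) and the distribution's own transforms.
lognormal_base_bijection, lognormal_transforms = identity, [exp_transform]
pareto_base_bijection = exp_transform  # Exponential base: positive support
pareto_transforms = [exp_transform, affine(0.0, pareto_scale)]

# Current behaviour: only the base-support bijection is applied.
lognormal_now = lognormal_base_bijection
pareto_now = pareto_base_bijection

# Proposed change: base-support bijection followed by the distribution's transforms.
lognormal_fixed = compose([lognormal_base_bijection] + lognormal_transforms)
pareto_fixed = compose([pareto_base_bijection] + pareto_transforms)

z = -0.7
values = {
    "lognormal_now": lognormal_now(z),      # negative
    "lognormal_fixed": lognormal_fixed(z),  # positive
    "pareto_now": pareto_now(z),            # positive but below pareto_scale
    "pareto_fixed": pareto_fixed(z),        # above pareto_scale
}
```

With only the base-support bijection, the LogNormal value is negative and the Pareto value is positive but can fall below the scale, i.e. outside the Pareto support x >= scale; composing the distribution's own transforms afterwards fixes both cases in this toy setting.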
However, this comes with several caveats:
- it only works for TransformedDistribution classes that define a transforms attribute and will crash for all other distributions that sample intermediates
- … SVIState do that already. _unpack_and_constrain maybe seems to be doing that, but we're not quite sure and wouldn't know how to integrate it into __call__, which currently does not use _unpack_and_constrain but does its own thing.

Edit 2: We don't have much experience with Pyro, but from a quick look it seems that the implementation there is almost identical (minus the obvious structural differences and some case distinctions), so it would be interesting to know whether the same problem persists there, but we currently lack the time to check.