
poutine.reparam incompatible with poutine.scale #2938

Open fritzo opened 3 years ago

fritzo commented 3 years ago

Consider the model

import pyro
from pyro.distributions import Normal, TransformedDistribution, Uniform
from torch.distributions.transforms import ExpTransform

def model1(data):
    # x = log(u) for u ~ Uniform(0, 1), so -x is exponentially distributed
    x = pyro.sample("x", TransformedDistribution(Uniform(0, 1), ExpTransform().inv))
    pyro.sample("obs", Normal(0, -x), obs=data)

and its reparametrized version

model2 = poutine.reparam(model1, {"x": TransformReparam()})

equivalent to

def model3(data):
    x_base = pyro.sample("x_base", Uniform(0, 1))
    # x = log(x_base) would be recovered deterministically by TransformReparam
    pyro.sample("obs", Normal(0, -x_base.log()), obs=data)

Now observe that wrapping the first model in poutine.scale changes the effective rate parameter of x, whereas scaling model2 has no effect on the prior (the Uniform(0, 1) log density being scaled is identically zero).
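
For concreteness, here is a rough sketch of how the asymmetry shows up in traces (the dummy data tensor and the scale factor 10.0 are made up for illustration):

import torch
from pyro import poutine

data = torch.tensor(1.0)  # dummy observation, only for illustration

# scale the original and the hand-reparametrized model by the same factor
scaled1 = poutine.scale(model1, scale=10.0)
scaled3 = poutine.scale(model3, scale=10.0)

tr1 = poutine.trace(scaled1).get_trace(data)
tr3 = poutine.trace(scaled3).get_trace(data)
tr1.compute_log_prob()
tr3.compute_log_prob()

# in tr1 the site "x" contributes 10 * log p(x), which depends on x and hence
# reshapes the prior (rate 1 -> rate 10, up to normalization);
# in tr3 the site "x_base" contributes 10 * Uniform(0, 1).log_prob(x_base) = 0,
# so the prior is untouched
print(tr1.nodes["x"]["log_prob"], tr3.nodes["x_base"]["log_prob"])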

Am I missing something? Does this mean poutine.reparam is incompatible with subsampling? Can we fix it somehow?

martinjankowiak commented 3 years ago

yeah i think this is inherent. changing coordinate systems changes densities. scaling is raising a density to a power. raising densities to a given power in different coordinate systems will have different effects.
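
concretely (notation here is just for illustration): write the reparameterization as x = T(u), where u has base density p_0. then

    p_X(x) = p_0(u)\,\left|\frac{du}{dx}\right|
    \quad\Longrightarrow\quad
    p_X(x)^s = p_0(u)^s\,\left|\frac{du}{dx}\right|^{s},

whereas tempering the base density first and then changing variables gives p_0(u)^s\,\left|\frac{du}{dx}\right|. the two differ by a factor \left|\frac{du}{dx}\right|^{s-1}, which is non-trivial whenever the Jacobian is non-constant. in the example above, u ~ Uniform(0, 1) and x = log(u), so p_X(x) = e^x on x < 0: tempering in x-space gives e^{s x}, i.e. -x ~ Exponential(s), while tempering the flat base density changes nothing.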

still, in the prototypical case with a global latent variable and N iid likelihood terms, if you reparam the latent variable and then do SVI + data subsampling, everything is still fine, provided that the scaling happens after the reparameterization. you're in a different coordinate system, so of course the (say) mean field approximation may work better or worse, but it'll still be a valid inference procedure in both coordinate systems.
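
for example, a rough sketch of that prototypical setup (all names, sizes and optimizer settings below are illustrative, not from this thread):

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoNormal
from pyro.infer.reparam import TransformReparam
from pyro.optim import Adam
from torch.distributions.transforms import ExpTransform

def model(data):
    # global latent with a transformed prior, as in model1 above
    z = pyro.sample("z", dist.TransformedDistribution(
        dist.Uniform(0, 1), ExpTransform().inv))
    # N iid likelihood terms; the plate rescales only the "obs" site
    # to account for the subsample
    with pyro.plate("data", len(data), subsample_size=16) as ind:
        pyro.sample("obs", dist.Normal(0.0, -z), obs=data[ind])

# reparam the latent first; the subsampling scale only ever touches "obs",
# so the reparametrized prior is never raised to a power
reparam_model = poutine.reparam(model, config={"z": TransformReparam()})

guide = AutoNormal(reparam_model)
svi = SVI(reparam_model, guide, Adam({"lr": 1e-2}), Trace_ELBO())
data = torch.randn(100)
for step in range(10):
    svi.step(data)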