danielward27 / flowjax

https://danielward27.github.io/flowjax/
MIT License

Q: Normal Distribution Compatibility with Reparameterization Trick? #181

gil2rok opened 2 weeks ago

gil2rok commented 2 weeks ago

I am interested in using Optax's stochastic gradient estimators and control variates with FlowJAX. In particular, I want to know whether FlowJAX is compatible with the reparameterization gradient (a.k.a. the pathwise estimator).

The reparameterization gradient requires the "reparameterization trick" to compute the gradient of an expectation. For a normally distributed variable $x \sim N(\mu, \sigma)$, this is implemented by rewriting it as $x = \sigma z + \mu$ for mean $\mu$, scale $\sigma$, and $z \sim N(0, 1)$. For more details see: https://gregorygundersen.com/blog/2018/04/29/reparameterization/.
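For concreteness, here is a minimal sketch of the trick in plain JAX (not FlowJAX-specific; the function name reparam_sample is just illustrative): the sample is written as a deterministic, differentiable function of the parameters and an independent standard normal draw, so jax.grad can differentiate through it.

import jax
import jax.random as jr

def reparam_sample(mu, sigma, key):
    # Rewrite x ~ N(mu, sigma) as a deterministic function of (mu, sigma)
    # and an independent standard normal draw z.
    z = jr.normal(key)
    return sigma * z + mu

# Differentiating the sample w.r.t. (mu, sigma) gives dx/dmu = 1 and dx/dsigma = z.
grads = jax.grad(reparam_sample, argnums=(0, 1))(0.0, 1.0, jr.key(0))
print(grads)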

Are the Normal and MultivariateNormal distributions in FlowJAX compatible with the reparameterization trick by default when using jax.grad? I believe the answer is yes because they both take the StandardNormal distribution (equivalent to z above) and transform it with some (affine?) bijection.
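For reference, the composition can be inspected directly; this quick check assumes the Normal exposes base_dist and bijection attributes:

from flowjax.distributions import Normal

dist = Normal()
print(type(dist.base_dist))  # expecting the StandardNormal base distribution
print(type(dist.bijection))  # expecting an affine-style bijection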

I was wondering if @danielward27 could confirm whether this is true. If so, it may be worth mentioning somewhere in the docs. Thank you!

danielward27 commented 2 weeks ago

I haven't explicitly tested, but I would expect all the distributions thus far in FlowJAX to support reparameterized gradients, as the sampling operations become differentiable deterministic functions after setting the key. An example:

import jax.random as jr
from flowjax.distributions import Normal
import equinox as eqx

@eqx.filter_grad
def sample_model(dist, **kwargs):
    # Differentiate the sample w.r.t. the distribution's array parameters.
    return dist.sample(**kwargs)

dist = Normal()
grad = sample_model(dist, key=jr.key(0))
# grad has the same structure as dist; since x = loc + scale * z, dx/dloc = 1.
assert grad.bijection.loc == 1

You can also get the gradient of the scale parameter with grad.bijection.scale.args[0], but as the scale parameter is constrained to be positive, this corresponds to the gradient of the sample w.r.t. the unconstrained scale representation (prior to applying softplus). This is generally what we want for optimization, as e.g. it lets us perform updates without risking invalid parameter values (a small chain-rule sketch follows the JAX example below). It's also worth noting that JAX itself generally supports reparameterized gradients, e.g.:

import jax
import jax.random as jr

@jax.grad
def sample_beta(a, **kwargs):
    # jax.random.beta is differentiable w.r.t. its concentration parameters.
    return jr.beta(a=a, **kwargs)

sample_beta(0.1, b=0.2, key=jr.key(0))
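To make the earlier point about the constrained scale concrete, here is a minimal sketch of that chain rule in plain JAX (it does not touch FlowJAX internals, and the names are illustrative): when the scale is parameterized as softplus of an unconstrained value, the gradient of a sample w.r.t. the unconstrained value picks up a sigmoid factor.

import jax
import jax.numpy as jnp
import jax.random as jr

def sample_from_unconstrained(unconstrained, loc, z):
    # scale = softplus(unconstrained) keeps the scale positive during optimization.
    scale = jax.nn.softplus(unconstrained)
    return loc + scale * z

z = jr.normal(jr.key(0))
g = jax.grad(sample_from_unconstrained)(0.5, 0.0, z)
# Chain rule: d(sample)/d(unconstrained) = z * d softplus/d unconstrained = z * sigmoid(unconstrained).
assert jnp.isclose(g, z * jax.nn.sigmoid(0.5))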

So generally, FlowJAX isn't doing anything clever to support it, but inherits this property from JAX. There are cases where reparameterized gradients are not possible (e.g. discrete distributions or non-differentiable functions), but AFAIK, currently, all the distributions (and flows) in FlowJAX are naturally compatible with reparameterized gradients.
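As a rough sketch of the original use case, the same property lets you differentiate a Monte Carlo estimate of an expectation through a FlowJAX distribution; the resulting gradient pytree could then be fed to an Optax optimizer. This is untested and assumes Normal(loc, scale) accepts scalar floats and that sample takes a sample_shape argument (only key= is used in the example above).

import equinox as eqx
import jax.numpy as jnp
import jax.random as jr
from flowjax.distributions import Normal

@eqx.filter_grad
def monte_carlo_expectation(dist, key):
    # Reparameterized Monte Carlo estimate of E[x^2] under dist.
    samples = dist.sample(key, (10_000,))
    return jnp.mean(samples**2)

dist = Normal(loc=1.0, scale=0.5)
grads = monte_carlo_expectation(dist, jr.key(0))
# For x ~ N(mu, sigma^2), E[x^2] = mu^2 + sigma^2, so d/dmu = 2 * mu = 2.0.
print(grads.bijection.loc)  # approximately 2.0, up to Monte Carlo error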

I'll leave this open for now and would be happy to take a pull request improving the documentation. Maybe it's worth adding an example like the Normal example above to the "Distributions and Bijections as PyTrees" section. I'll likely get around to it at some point regardless.

gil2rok commented 2 weeks ago

Thanks so much for the thorough response! Your explanation makes sense.

Closing this now and thanks again for this incredible library.