tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
https://www.tensorflow.org/probability/
Apache License 2.0

VI: Discrete + Continuous RVs inference: gradient estimation? #1534

Open LouisRouillard opened 2 years ago

LouisRouillard commented 2 years ago
tensorflow==2.7.0
tensorflow-probability==0.14.1

TLDR

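To perform VI on discrete RVs, should I use:

- the REINFORCE (score-function) gradient estimator, or
- the Gumbel-Softmax reparameterization,

and how should I implement it?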

Problem statement

Sorry in advance for the long issue, but I believe the problem requires some explaining.

I want to implement a hierarchical Bayesian model involving both continuous and discrete random variables (RVs). A minimal example is a Gaussian mixture model:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

G = 2

p = tfd.JointDistributionNamed(
    model=dict(
        mu=tfd.Sample(
            tfd.Normal(0., 1.),
            sample_shape=(G,)
        ),
        z=tfd.Categorical(
            probs=tf.ones((G,)) / G
        ),
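        x=lambda mu, z: tfd.Normal(
            # one_hot + reduce_sum selects mu[z] and stays correct when mu and z
            # carry leading sample dimensions (plain mu[z] would not)
            loc=tf.reduce_sum(tf.one_hot(z, G) * mu, axis=-1),
            scale=1.
        )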
    )
)

In this example I deliberately avoid the tfd.Mixture API so that the Categorical label is exposed. I want to perform Variational Inference in this setting: for instance, given an observed x, fit a Categorical distribution with parametric probabilities to the posterior of z:

q_probs = tfp.util.TransformedVariable(
    tf.ones((G,)) / G,
    tfb.SoftmaxCentered(),
    name="q_probs"
)
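q_loc = tf.Variable(tf.zeros((G,)), name="q_loc")
q_scale = tfp.util.TransformedVariable(
    1.,
    tfb.Exp(),
    name="q_scale"
)

q = tfd.JointDistributionNamed(
    model=dict(
        # mu has event shape (G,) under p, so the surrogate must match it
        mu=tfd.Independent(
            tfd.Normal(q_loc, q_scale),
            reinterpreted_batch_ndims=1
        ),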
        z=tfd.Categorical(probs=q_probs)
    )
)

The issue: when computing the ELBO and optimizing q_probs, I cannot use the reparameterization gradient estimator. AFAIK this is because z is a discrete RV:


def log_prob_fn(**kwargs):
    return p.log_prob(
        **kwargs,
        x=tf.constant([2.])
    )

optimizer = tf.optimizers.SGD()

@tf.function
def fit_vi():
    return tfp.vi.fit_surrogate_posterior(
        target_log_prob_fn=log_prob_fn,
        surrogate_posterior=q,
        optimizer=optimizer,
        num_steps=10,
        sample_size=8
    )

_ = fit_vi() 
# This last line raises:
# ValueError: Distribution `surrogate_posterior` must be reparameterized, i.e.,a diffeomorphic transformation
# of a parameterless distribution. (Otherwise this function has a biased gradient.)
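Why the reparameterization path is useless for a Categorical (whose reparameterization type in TFP is NOT_REPARAMETERIZED) can be seen on a toy example. The following NumPy sketch is an illustration, not TFP code; the helper names and values are hypothetical. A Categorical sample written as a deterministic inverse-CDF function of the probabilities and a fixed uniform noise u is piecewise constant in the probabilities, so its pathwise derivative is zero almost everywhere, while the gradient of the expectation is not.

```python
import numpy as np

def inverse_cdf_sample(probs, u):
    """Category index as a deterministic function of probs and a fixed u in [0, 1)."""
    return int(np.searchsorted(np.cumsum(probs), u, side="right"))

def f_of_sample(probs, u, values):
    """Objective evaluated at the sampled category (hypothetical value per category)."""
    return values[inverse_cdf_sample(probs, u)]

probs = np.array([0.3, 0.7])
values = np.array([-1.0, 2.0])
direction = np.array([1.0, -1.0])  # move probability mass from category 1 to category 0
eps = 1e-6
u = 0.25  # fixed noise, as in a reparameterized sample

# Pathwise ("reparameterization") derivative along `direction`: the sample does not move.
pathwise = (f_of_sample(probs + eps * direction, u, values)
            - f_of_sample(probs - eps * direction, u, values)) / (2 * eps)

# Derivative of the expectation E[f] = probs @ values along the same direction.
true_grad = values @ direction
```

The pathwise derivative is 0 while the true directional derivative is -3, which is why an estimator that differentiates through the sample cannot be used for z.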

I'm looking for a way to make this work. I've identified at least two ways to circumvent the issue: the REINFORCE gradient estimator or the Gumbel-Softmax reparameterization.

A- REINFORCE gradient

Cf. this TFP API link: a classical result in VI is that the REINFORCE gradient can handle an objective that is not differentiable with respect to the samples, for instance because of discrete RVs.
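For reference, the score-function (REINFORCE) identity in standard notation (a surrogate q_\phi with parameters \phi and a generic integrand f; the notation is an assumption, not taken from the TFP docs). Only samples z_s and the gradient of \log q_\phi are needed, never a derivative of z_s with respect to \phi. For the ELBO, f_\phi(z) = \log p(x, z) - \log q_\phi(z) itself depends on \phi, but the extra term vanishes because, for discrete z, \mathbb{E}_{q_\phi}[\nabla_\phi \log q_\phi(z)] = \nabla_\phi \sum_z q_\phi(z) = 0:

```latex
\nabla_\phi \, \mathbb{E}_{q_\phi(z)}[f(z)]
  = \mathbb{E}_{q_\phi(z)}\big[ f(z) \, \nabla_\phi \log q_\phi(z) \big]
  \approx \frac{1}{S} \sum_{s=1}^{S} f(z_s) \, \nabla_\phi \log q_\phi(z_s),
  \qquad z_s \sim q_\phi

\nabla_\phi \, \mathrm{ELBO}(\phi)
  = \mathbb{E}_{q_\phi(z)}\big[ \big( \log p(x, z) - \log q_\phi(z) \big) \, \nabla_\phi \log q_\phi(z) \big]
```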

Can I use a tfp.vi.GradientEstimators.SCORE_FUNCTION estimator instead of the tfp.vi.GradientEstimators.REPARAMETERIZATION one, via the lower-level tfp.vi.monte_carlo_variational_loss function? With the REINFORCE gradient, I only need the log_prob method of q to be differentiable; the sample method does not need to be differentiated.

As far as I understand, sampling from a Categorical distribution breaks the gradient, but its log_prob method does not. Am I correct to assume that this would solve my issue? Am I missing something here?
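The reasoning above can be checked on a toy problem. Below is a minimal NumPy sketch of the score-function estimator for a Categorical q parameterized by softmax logits; it is not the TFP implementation and does not use tfp.vi.GradientEstimators. Samples are drawn with no gradient at all, and only the derivative of log q(z) with respect to the logits, which for a softmax is one_hot(z) - probs, enters the estimate. The objective values f are hypothetical.

```python
import numpy as np

def softmax(logits):
    """Numerically stable softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=-1, keepdims=True)

def score_function_grad(logits, f, num_samples, rng):
    """REINFORCE estimate of d/d(logits) E_{z ~ Cat(softmax(logits))}[f(z)].

    The sampling step is never differentiated; only grad log q(z) is used.
    """
    probs = softmax(logits)
    z = rng.choice(len(probs), size=num_samples, p=probs)  # plain, gradient-free sampling
    score = np.eye(len(probs))[z] - probs  # d log softmax(logits)[z] / d logits
    return (f[z][:, None] * score).mean(axis=0)

def exact_grad(logits, f):
    """Closed form for comparison: q_k * (f_k - E_q[f])."""
    probs = softmax(logits)
    return probs * (f - probs @ f)

rng = np.random.default_rng(0)
logits = np.array([0.5, -0.3, 0.1])
f = np.array([1.0, 3.0, -2.0])  # hypothetical objective value for each category
estimate = score_function_grad(logits, f, num_samples=200_000, rng=rng)
exact = exact_grad(logits, f)
```

The closed-form exact_grad is only available because the toy has three categories; it is there to check that the Monte Carlo estimate is unbiased.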

Also, I wonder: why is this option not exposed in the tfp.vi.fit_surrogate_posterior API? Is the performance bad, i.e. is the variance of the estimator too large for practical purposes?
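On the variance question: the plain score-function estimator's variance grows with the magnitude of f, and the standard mitigation is to subtract a baseline b, which keeps the estimator unbiased because E_q[grad log q(z)] = 0. A minimal NumPy sketch under toy assumptions: f is hypothetical and has a large common offset, and the baseline is the exact E_q[f], which a real implementation would have to estimate (for example with a running mean).

```python
import numpy as np

def softmax(logits):
    """Numerically stable softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=-1, keepdims=True)

def per_sample_grads(logits, f, num_samples, rng, baseline=0.0):
    """One REINFORCE gradient estimate per sample: (f(z) - b) * grad log q(z)."""
    probs = softmax(logits)
    z = rng.choice(len(probs), size=num_samples, p=probs)
    score = np.eye(len(probs))[z] - probs
    return (f[z] - baseline)[:, None] * score

rng = np.random.default_rng(1)
logits = np.array([0.2, 0.0, -0.4])
f = np.array([10.0, 11.0, 12.0])  # hypothetical objective with a large common offset
probs = softmax(logits)

plain = per_sample_grads(logits, f, 200_000, rng)
centered = per_sample_grads(logits, f, 200_000, rng, baseline=probs @ f)

# Same expected gradient, very different spread across samples.
total_var_plain = plain.var(axis=0).sum()
total_var_centered = centered.var(axis=0).sum()
```

The REBAR and RELAX methods mentioned at the end of this thread push the same idea further by learning the control variate instead of using a fixed baseline.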

B- Gumbel-Softmax reparameterization

Cf. this TFP API link: I could also reparameterize z through a variable y = tfd.RelaxedOneHotCategorical(...). The issue is that I need a proper categorical label to define x, so AFAIK I need to do the following:

p_GS = tfd.JointDistributionNamed(
    model=dict(
        mu=tfd.Sample(
            tfd.Normal(0., 1.),
            sample_shape=(G,)
        ),
        y=tfd.RelaxedOneHotCategorical(
            temperature=1.,
            probs=tf.ones((G,)) / G
        ),
        x=lambda mu, y: tfd.Normal(
            loc=tf.reduce_sum(tf.one_hot(tf.argmax(y, axis=-1), G) * mu, axis=-1),
            scale=1.
        )
    )
)

...but this would just move the gradient-breaking problem to tf.argmax. This is where I may be missing something. Following the Gumbel-Softmax paper (Jang et al., 2016), I could then use the "straight-through" (ST) strategy and "plug" the gradients of the variable tf.one_hot(tf.argmax(y)), the "discrete y", onto y, the "continuous y".
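The straight-through strategy can be written out by hand to make the two directions explicit. The NumPy sketch below uses hypothetical helper names and a manual backward pass instead of autodiff. It draws Gumbel noise, forms the relaxed sample y = softmax((logits + g) / T), uses its one-hot argmax in the forward pass, and in the backward pass copies the upstream gradient from the hard sample onto y before backpropagating through the softmax. By the Gumbel-max trick, argmax(y) is an exact Categorical(softmax(logits)) sample for any temperature T, so the forward value is a proper label; only the gradient is a (biased) approximation.

```python
import numpy as np

def softmax(x):
    """Numerically stable softmax over the last axis."""
    shifted = x - x.max(axis=-1, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=-1, keepdims=True)

def sample_gumbel(shape, rng):
    """Standard Gumbel noise via -log(-log(U)), with U kept away from 0."""
    u = rng.uniform(low=np.finfo(float).tiny, high=1.0, size=shape)
    return -np.log(-np.log(u))

def st_forward(logits, gumbel, temperature):
    """Forward pass: relaxed sample y and its hard one-hot version."""
    y = softmax((logits + gumbel) / temperature)
    hard = np.eye(logits.shape[-1])[np.argmax(y, axis=-1)]
    return y, hard

def st_backward(y, upstream, temperature):
    """Backward pass to the logits.

    Straight-through: d(hard)/d(y) is treated as the identity, so the
    upstream gradient w.r.t. `hard` is reused as the gradient w.r.t. `y`.
    It is then pushed through y = softmax((logits + g) / T), whose Jacobian
    w.r.t. the logits is (diag(y) - y y^T) / T.
    """
    grad_y = upstream
    inner = (grad_y * y).sum(axis=-1, keepdims=True)
    return y * (grad_y - inner) / temperature

# One draw with loss = hard @ mu, so d(loss)/d(hard) = mu.
rng = np.random.default_rng(0)
logits = np.array([0.3, -0.2, 0.5])
mu = np.array([1.0, -2.0, 0.5])
temperature = 0.7
g = sample_gumbel(logits.shape, rng)
y, hard = st_forward(logits, g, temperature)
loss = hard @ mu
grad_logits = st_backward(y, mu, temperature)
```

In the mixture model, the loss term hard @ mu plays the role of the selected component mean, so the upstream gradient with respect to the hard sample is simply mu; the resulting logits gradient equals the gradient of the relaxed loss y @ mu.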

But again I wonder: how do I do this properly? I don't want to mix and match gradients by hand, and an autodiff backend is precisely meant to spare me that. How could I create a distribution whose forward direction samples a "discrete y" while its backward direction computes the gradient using the "continuous y"? I guess this is the intended usage of the tfd.RelaxedOneHotCategorical distribution, but I don't see it implemented anywhere in the API.

Should I implement this myself? How? Could I use something along the lines of tf.custom_gradient?

Actual question

Which solution (A, B, or another) is meant to be used with the TFP API, if any? How should I implement it efficiently?

LouisRouillard commented 2 years ago

For closure: I looked into this issue for a couple of days, and here are my conclusions:

First off, we need to reparameterize the joint distribution p, since the KL divergence between a discrete and a continuous distribution is ill-defined (as explained in Maddison et al. (2017)). To avoid breaking the gradients, I implemented a simple one_hot_straight_through operation that converts the continuous RV y into a discrete (one-hot) RV z in the forward pass and passes the upstream gradient straight through to y in the backward pass:

G = 2

@tf.custom_gradient
def one_hot_straight_through(y):
    depth = y.shape[-1]
    z = tf.one_hot(
        tf.argmax(
            y,
            axis=-1
        ),
        depth=depth
    )

    def grad(upstream):
        # Straight-through: treat d(z)/d(y) as the identity
        return upstream

    return z, grad

p = tfd.JointDistributionNamed(
    model=dict(
        mu=tfd.Sample(
            tfd.Normal(0., 1.),
            sample_shape=(G,)
        ),
        y=tfd.RelaxedOneHotCategorical(
            temperature=1.,
            probs=tf.ones((G,)) / G
        ),
        x=lambda mu, y: tfd.Normal(
            loc=tf.reduce_sum(
                one_hot_straight_through(y)
                * mu,
                axis=-1
            ),
            scale=1.
        )
    )
)

The variational distribution q follows the same reparameterization, and the following snippet does run:

q_probs = tfp.util.TransformedVariable(
    tf.ones((G,)) / G,
    tfb.SoftmaxCentered(),
    name="q_probs"
)
q_loc = tf.Variable(tf.zeros((G,)), name="q_loc")
q_scale = tfp.util.TransformedVariable(
    1.,
    tfb.Exp(),
    name="q_scale"
)

q = tfd.JointDistributionNamed(
    model=dict(
        mu=tfd.Independent(
            tfd.Normal(q_loc, q_scale),
            reinterpreted_batch_ndims=1
        ),
        y=tfd.RelaxedOneHotCategorical(
            temperature=1.,
            probs=q_probs
        )
    )
)

def log_prob_fn(**kwargs):
    return p.log_prob(
        **kwargs,
        x=tf.constant([2.])
    )

optimizer = tf.optimizers.SGD()

@tf.function
def fit_vi():
    return tfp.vi.fit_surrogate_posterior(
        target_log_prob_fn=log_prob_fn,
        surrogate_posterior=q,
        optimizer=optimizer,
        num_steps=10,
        sample_size=8
    )

_ = fit_vi()

Now there are several issues with that design:

Recent methods such as REBAR (Tucker et al., 2017) or RELAX (Grathwohl et al., 2018) can instead obtain unbiased estimators with lower variance than the original REINFORCE, but they do so at the cost of introducing learnable control variates trained with separate losses. Modifications of the one_hot_straight_through function could probably implement this.

In conclusion, my opinion is that support for optimizing over discrete RVs is too limited at the moment, and that the API lacks native functions and tutorials to make this easier for users. I don't think the issue should be closed.