tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
https://www.tensorflow.org/probability/
Apache License 2.0

How exactly should we scale the KL divergence of a layer when doing stochastic gradient descent? #651

Open nbro opened 4 years ago

nbro commented 4 years ago

The documentation of the losses property of the Convolution2DFlipout class states:

Upon being built, this layer adds losses (accessible via the losses property) representing the divergences of kernel and/or bias surrogate posteriors and their respective priors. When doing minibatch stochastic optimization, make sure to scale this loss such that it is applied just once per epoch (e.g. if kl is the sum of losses for each element of the batch, you should pass kl / num_examples_per_epoch to your optimizer).
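A minimal sketch of the arithmetic behind this advice, written in plain Python so it runs without TensorFlow. It assumes the per-step loss is the mean negative log-likelihood over the batch (Keras' default reduction) plus the summed KL; the dataset size, batch size, NLL values and KL value are all made up. In TFP's own BNN example the division is done inside each layer's kernel_divergence_fn, roughly lambda q, p, _: tfd.kl_divergence(q, p) / num_examples.

```python
import random

# Hypothetical setup: N training examples split into M mini-batches of size B.
N, B = 1200, 100
M = N // B

random.seed(0)
# Made-up per-example negative log-likelihoods, standing in for -log p(y_i | x_i, w).
nll = [random.uniform(0.5, 2.0) for _ in range(N)]
# Made-up total KL(q || p), summed over all Bayesian layers (one scalar per step).
kl = 350.0

# Full-dataset negative ELBO: the KL appears exactly once.
neg_elbo = sum(nll) + kl

def batch_loss(batch_nll, kl, num_examples):
    """Per-step loss: mean NLL over the batch plus KL / num_examples_per_epoch."""
    return sum(batch_nll) / len(batch_nll) + kl / num_examples

epoch_losses = [batch_loss(nll[b * B:(b + 1) * B], kl, N) for b in range(M)]

# Averaging the M step losses recovers the full negative ELBO divided by N.
mean_step_loss = sum(epoch_losses) / M
print(mean_step_loss, neg_elbo / N)
```

Because the KL term is the same in every step, averaging the step losses over an epoch gives the full negative ELBO divided by N, so the KL is counted exactly once relative to the likelihood.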

The documentation for Convolution2DReparameterization, DenseReparameterization, etc., states a similar thing. I've read the paper "Weight Uncertainty in Neural Networks", so I am quite familiar with the theoretical topics behind Bayesian neural networks, including the ELBO loss (which is composed of a KL divergence part and the likelihood part).

What exactly does the losses property keep track of, in mathematical terms? For example, in the first forward pass (of the first epoch) with one training example (a batch of size one), what will losses contain? In the second forward pass with another single training example, what will losses contain? Will it accumulate the KL part of the loss for each training example fed into the network, or does it reset the previous losses field before calculating the new one? If in the first forward pass we use a batch of size K > 1, rather than just one training example, what will losses contain? And in the second iteration, again with a batch of size K > 1? And in the M-th iteration?

The KL part of the ELBO loss does not depend on the input data, but it depends on the specifically sampled weights during the forward pass.

In your description (the documentation), the relationship between an epoch (which I assume to be a forward pass for all training examples), a batch and kl (which you assume to be the sum of the losses) is unclear. It is not even clear what kl really is. You say it is the sum of losses, but, again, it is unclear how losses is computed (for each layer and for the model and for 1, 2, or more training examples) or what it really is. Nonetheless, I know that the losses field of a model is a list of size M, where M is the number of layers of this model that contain the field losses (i.e. Bayesian layers or layers that have a regularisation term).

There's a related issue https://github.com/tensorflow/probability/issues/282, where @SiegeLordEx suggests dividing the divergence between the prior and the posterior by the total number of training examples, which is also suggested by @junpenglao in another related issue https://github.com/tensorflow/probability/issues/396. See also https://github.com/tensorflow/probability/issues/127.

bgroenks96 commented 4 years ago

The BNN example does something even worse. It rescales by the number of examples in the whole dataset. This makes the KL term very small and as a result the estimated posterior variance is artificially small. Just add this line somewhere in the training loop:

print(f'Layer 0 mean posterior variance: {tf.math.reduce_mean(model.layers[0].kernel_posterior.variance())}')

I am seeing variances of about 0.0013, which is a standard deviation of about 0.036. That is nearly three orders of magnitude lower than the default prior variance of 1.

nbro commented 4 years ago

@bgroenks96 I am actually still wondering how scaling affects learning. As you say, it can affect our perception of the quality of the model, but not only that: if it also affects the variance and mean of the posteriors, then, e.g., you will be sampling the wrong weights during the forward pass.

bgroenks96 commented 4 years ago

Right. As the variance of the posterior approaches zero, the model will converge to the MLE solution. Perhaps we might not consider this to be a bad thing. It would be sensible for neural networks to have tight priors anyway, given the high dimensionality of the parameter space. However, ignoring this issue in evaluating the posterior samples will result in a kind of overestimate of the model's confidence.

nbro commented 4 years ago

@bgroenks96 Yes, good points.

But, even though you showed that the variance of the posteriors is quite small (in the Github example you linked), are you sure that the scaling (directly) affects the variance and means of the posteriors directly and/or proportional to the scale of the KL divergence? (This is more or less one of the questions that I've always had)

If I understood correctly, the KL divergences are added to the losses field of the model. If we assume that each layer of the Bayesian neural network (BNN) computes such loss individually and if we assume that we scale it by the number of examples, then, if Keras automatically adds the sum of these KL divergences to the likelihood loss, what will really happen? How are we going to update the posteriors given this likelihood + sum of scaled KL divergences?

Now, it isn't clear to me how the losses play a role in back-propagating errors (especially, in this case), actually.

bgroenks96 commented 4 years ago

are you sure that the scaling (directly) affects the variance and means of the posteriors directly and/or proportional to the scale of the KL divergence?

Fairly certain, yes. If you increase the scaling factor (i.e. larger KL loss component) and repeat this experiment, you will observe larger variances in the variational posterior. This makes sense intuitively as well. If you reweight the KL-divergence to be very small, the easiest way for the optimizer to minimize the NLL is to center the parameter distributions around the MLE and make the variance as small as possible.
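The intuition can be checked on a one-dimensional toy problem. Assume, purely for illustration, a single weight with a quadratic negative log-likelihood (a/2)(w - w_mle)^2, a Gaussian surrogate posterior q = N(mu, s^2), a standard normal prior, and a KL weight beta. Setting the derivatives of E_q[NLL] + beta * KL(q || p) to zero gives s^2 = beta / (beta + a) and mu = a * w_mle / (a + beta), so shrinking beta drives the posterior toward a point mass at the MLE. The sketch below evaluates that closed form for a few values of beta; a, w_mle and the beta values are hypothetical.

```python
import math

def objective(mu, var, a, w_mle, beta):
    """E_q[(a/2)(w - w_mle)^2] + beta * KL(N(mu, var) || N(0, 1))."""
    expected_nll = 0.5 * a * ((mu - w_mle) ** 2 + var)
    kl = 0.5 * (var + mu ** 2 - 1.0 - math.log(var))
    return expected_nll + beta * kl

def optimal_posterior(a, w_mle, beta):
    """Closed-form minimizer (mu, var) of the objective above."""
    mu = a * w_mle / (a + beta)
    var = beta / (beta + a)
    return mu, var

a, w_mle = 4.0, 1.5  # hypothetical likelihood curvature and MLE
for beta in (1.0, 0.1, 0.001):
    mu, var = optimal_posterior(a, w_mle, beta)
    print(f"beta={beta:<6} mu={mu:.4f} var={var:.5f}")
```

Incidentally, if the loss is the mean NLL (so a is the per-example curvature) and beta = 1/N, this gives var = 1/(1 + N a), which is the exact conjugate posterior variance for this toy model.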

if Keras automatically adds the sum of these KL divergences to the likelihood loss, what will really happen? How are we going to update the posteriors given this likelihood + sum of scaled KL divergences?

Yes. Keras automatically sums all of the losses together to get a total loss. This total loss is then averaged over the remaining dimensions. Thus, if you want the KL-divergence to be applied only once per input, it would make sense to normalize it by the number of batches, since Keras will already divide by the batch size internally.

Now, it isn't clear to me how the losses play a role in back-propagating errors (especially, in this case), actually.

They are included in the final loss which is backpropagated.

bgroenks96 commented 4 years ago

Although per #883, the convolutional KL losses are actually not included when using tf.function, which seems to be a bug.

nbro commented 4 years ago

@bgroenks96

If you increase the scaling factor

You mean if you decrease the scaling factor, i.e. you divide the KL divergence by a small number. Right? Because if you divide the KL divergence by a bigger number, the KL term will clearly be smaller.

If you reweight the KL-divergence to be very small, the easiest way for the optimizer to minimize the NLL is to center the parameter distributions around the MLE and make the variance as small as possible.

I trained the same model but with different scaling factors and the distribution of the variances of the posteriors is quite different.

However, I didn't get what you were expecting.

With a smaller KL divergence, the variances are actually more spread out rather than concentrated around a certain value. With a higher KL term (i.e. dividing by a smaller number), some variances are higher than the corresponding variances with a smaller KL term, but the most visible pattern is that they are more concentrated around the same value. Actually, this behaviour seems more intuitive, because you expect the first layers to have higher variance.

I don't know if this is task-specific or not. Maybe you can also try yourself and tell me what you observe. For example, try this with MNIST and tell me if you get something similar. I think so.

Thus, if you want the KL-divergence to be applied only once per input

I am not sure I understand. In which sense do you mean "apply the KL divergence once per input"? And why would you do that?

If you are performing mini-batch GD, I suppose that, every time you forward pass an input, the KL divergence is computed and added to the final loss. Then what happens?

it would make sense to normalize it by the number of batches, since Keras will already divide by the batch size internally.

Let's assume that for M examples in your mini-batch, you compute M KL divergences kl1, kl2, ..., klM, where each kli is actually the sum of the KL divergences between the prior and the posterior of every unit of the net. Then, if what you say is correct, Keras will compute the final loss FOR A SINGLE MINI-BATCH by also averaging these M KL divergences, i.e. 1/M*(kl1 + ... + klM). This should represent the average KL divergence of the M examples.

So, I don't get why you would divide by the number of batches. Maybe I am not seeing something that you are seeing.

bgroenks96 commented 4 years ago

However, I didn't get what you were expecting.

I think you misunderstood. What you described is pretty much exactly what I would expect. A smaller KL-divergence term leads to more variance (overloading the term here) in the variational parameters, and if you check, you see that the average magnitude of your variational posterior variances (the actual values of your variances) are smaller (significantly smaller than the prior). A larger KL-divergence term will force the variational posterior closer to the prior, thus making them more "concentrated around a single value" (the prior variance) as you put it.

And why would you do that?

I think I meant once per epoch? Basically the KL-divergence should be applied once over the full objective, which means dividing it by the number of updates in one epoch for batched gradient descent.

you compute M KL divergences

This is not the default behavior in TFP. If you use analytical KL-divergences, you will get one KL loss per mini-batch.

Keras will compute the final loss FOR A SINGLE MINI-BATCH by also averaging these M KL divergences

No. There is only one KL value per mini-batch, which in the case of reparameterized Keras layers, is the sum of KL values over the whole kernel.

bgroenks96 commented 4 years ago

An alternative approach is to just treat the scaling factor like what it really is: a hyperparameter.

A technique that I (and other authors) have used in the past with VAEs is to set the KL divergence scaling factor very low for some "burn-in" or "warmup" period (e.g. 10 epochs) and then linearly anneal it up to 1.0 over the course of 50-100 additional epochs. This often works fairly well because it helps the optimizer get into a good region of the parameter space (w.r.t. the likelihood) before optimizing the KL term.
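A minimal sketch of such a warm-up plus linear annealing schedule. The floor value, warm-up length and annealing length are illustrative defaults, not values from the thread beyond the "10 epochs" and "50-100 epochs" ranges mentioned above. The returned factor would multiply the already normalized KL (e.g. KL / N); one way to wire it into Keras, assumed here rather than shown, is a non-trainable variable read by the divergence function and updated from an epoch-end callback.

```python
def kl_weight(epoch, warmup_epochs=10, anneal_epochs=50, floor=1e-3, final=1.0):
    """KL scaling factor for a given (0-based) epoch.

    Held at `floor` during warm-up, then increased linearly to `final`
    over `anneal_epochs` epochs, and kept at `final` afterwards.
    """
    if epoch < warmup_epochs:
        return floor
    progress = min(1.0, (epoch - warmup_epochs) / anneal_epochs)
    return floor + progress * (final - floor)

for epoch in (0, 9, 10, 35, 60, 100):
    print(epoch, round(kl_weight(epoch), 4))
```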

nbro commented 4 years ago

@bgroenks96

Basically the KL-divergence should be applied once over the full objective, which means dividing it by the number of updates in one epoch for batched gradient descent.

Why should the KL divergence be applied once over the full objective? Is it because it doesn't depend on the data, so we shouldn't count it as many times as we see the data? I think so. The BbB paper also suggests this.

which means dividing it by the number of updates in one epoch for batched gradient descent.

Right, so it means that we should scale the KL divergence at each mini-batch by 1/M, where M is the number of mini-batches.
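Whether 1/M or 1/N is the right per-step factor depends on how the likelihood term is reduced. A minimal sketch, assuming N = M * B examples and made-up per-example NLLs and KL: with the NLL summed over the batch (the form used in the BbB paper), weighting the KL by 1/M makes the M step losses add up to the full negative ELBO; with the NLL averaged over the batch (the Keras default), the equivalent per-step factor is 1/N, and each step loss is just the summed-form loss divided by B.

```python
import random

N, B = 600, 50          # hypothetical dataset and batch sizes
M = N // B              # number of mini-batches (updates) per epoch
kl = 120.0              # made-up total KL(q || p) for the current variational parameters

random.seed(1)
nll = [random.uniform(0.1, 3.0) for _ in range(N)]
batches = [nll[i * B:(i + 1) * B] for i in range(M)]
neg_elbo = sum(nll) + kl

# Summed NLL + KL / M per step: the step losses add up to the full negative ELBO.
summed_form = [sum(b) + kl / M for b in batches]

# Mean NLL + KL / N per step: each step loss is the summed-form loss divided by B.
mean_form = [sum(b) / B + kl / N for b in batches]

print(sum(summed_form), neg_elbo)
print(all(abs(s / B - m) < 1e-9 for s, m in zip(summed_form, mean_form)))
```

Under these assumptions, using 1/M together with a mean-reduced NLL weights the KL B times more heavily than the ELBO does.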

No. There is only one KL value per mini-batch, which in the case of reparameterized Keras layers, is the sum of KL values over the whole kernel.

But how is this KL value computed at each mini-batch? That's my question! Is it a sum of all KL divergences between priors and posteriors for each unit, or what exactly (see the next question)?

Before (somewhere), you had said that Keras already divides the KL term (which is stored in the losses property of each layer) by the number of examples in mini-batch, but now you're saying that only one KL term is computed for each mini-batch. Again, how? Where can I see this in the code? I actually asked this question on Stack Overflow and Github. I want to make sure how the KL is treated at each step of GD in order to choose the appropriate weighting factor for the KL term.

A technique that I (and other authors) have used in the past with VAEs is to set the KL divergence scaling factor very low for some "burn-in" or "warmup" period (e.g. 10 epochs) and then linearly anneal it up to 1.0 over the course of 50-100 additional epochs. This often works fairly well because it helps the optimizer get into a good region of the parameter space (w.r.t. the likelihood) before optimizing the KL term.

So, are you saying that I should start with very small KL losses and then eventually let the KL term increase (which corresponds to increasing the factor by which I multiply the KL term towards 1)?

bgroenks96 commented 4 years ago

Why should the KL divergence be applied once over the full objective?

Because in the ELBO, the likelihood term is implicitly over the full dataset. We're doing iterative learning via batched gradient descent, so this needs to be accounted for.

But how is this KL value computed at each mini-batch?

If your variational parameters are free variables, then it's simply the KL-divergence between the surrogate posterior and the prior summed over the whole kernel (assuming we're talking about BNNs and not VAEs).
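For a factorized Gaussian surrogate posterior and a standard normal prior (the usual defaults, assumed here), the per-layer value can be sketched as below. The kernel means and standard deviations are made up. No input batch appears anywhere: the result is one scalar per layer, whatever the batch size.

```python
import math

def gaussian_kl(mu_q, sigma_q, mu_p=0.0, sigma_p=1.0):
    """Analytic KL(N(mu_q, sigma_q^2) || N(mu_p, sigma_p^2)) for one weight."""
    return (math.log(sigma_p / sigma_q)
            + (sigma_q ** 2 + (mu_q - mu_p) ** 2) / (2.0 * sigma_p ** 2)
            - 0.5)

def layer_kl(kernel_mu, kernel_sigma):
    """One scalar per layer: the per-weight KLs summed over the whole kernel."""
    return sum(gaussian_kl(m, s)
               for row_mu, row_sigma in zip(kernel_mu, kernel_sigma)
               for m, s in zip(row_mu, row_sigma))

# Hypothetical 2x3 kernel of a factorized Gaussian surrogate posterior.
kernel_mu = [[0.1, -0.2, 0.0], [0.3, 0.05, -0.1]]
kernel_sigma = [[0.5, 0.4, 1.0], [0.2, 0.9, 0.3]]
print(layer_kl(kernel_mu, kernel_sigma))
```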

Before (somewhere), you had said that Keras already divides the KL term (which is stored in the losses property of each layer) by the number of examples in mini-batch, but now you're saying that only one KL term is computed for each mini-batch

Keras computes a final average over the total loss. However, in this case, the KL-divergence is a single value for the whole batch, so it will be an average over 1. If you were to compute empirical KL-divergence values, it would be an average over the batch.

The Keras source code can be a bit difficult and time-consuming to understand, so you might consider just playing with a toy example to see how this works:

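import tensorflow as tf

class MyLayer(tf.keras.layers.Layer):
    def call(self, inputs):
        # Register a constant extra loss of 1.0 every time the layer is called.
        self.add_loss(1.0)
        return inputs

inputs = tf.keras.layers.Input((3, 3))
layer = MyLayer()(inputs)
model = tf.keras.Model(inputs=inputs, outputs=layer)
model.compile(loss='mae', optimizer='sgd', metrics=['mae'])
# The model is the identity, so the MAE between the 1s and the 2s is 1.0; compare it with the reported loss.
loss, mae = model.evaluate(tf.ones((10, 3, 3)), 2 * tf.ones((10, 3, 3)))
print(loss, mae)

nbro commented 4 years ago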

Because in the ELBO, the likelihood term is implicitly over the full dataset. We're doing iterative learning via batched gradient descent, so this needs to be accounted for.

Well, you should read section 3.4 of "Weight Uncertainty in Neural Networks". In this paper, they actually break down the ELBO loss for mini-batch GD. There they introduce the scaling factor not because the likelihood is of all data (that's already an expectation there), but I think it's because of the reasons I mentioned in my last comment above (i.e. the calculation of the KL does not depend on the data, i.e. you can compute the KL without the data, you only need the distributions).
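Section 3.4 of that paper splits the data into M equally sized mini-batches and writes the cost of mini-batch i as pi_i KL[q(w|theta) || P(w)] - E_q[log P(D_i|w)], where the weights pi_i must sum to 1 over an epoch. Besides the uniform choice pi_i = 1/M, the paper also discusses pi_i = 2^(M-i) / (2^M - 1), which puts most of the KL weight on the first mini-batches. A small sketch of both schedules (M = 5 is arbitrary):

```python
def uniform_weights(M):
    """pi_i = 1/M for every mini-batch i = 1..M."""
    return [1.0 / M] * M

def geometric_weights(M):
    """pi_i = 2^(M-i) / (2^M - 1) for i = 1..M."""
    denom = 2 ** M - 1
    return [2 ** (M - i) / denom for i in range(1, M + 1)]

M = 5
print(uniform_weights(M))
print(geometric_weights(M))  # the first mini-batches carry most of the KL weight
print(sum(uniform_weights(M)), sum(geometric_weights(M)))
```

Either way, summing the M mini-batch costs over an epoch counts the KL exactly once.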

it's simply the KL-divergence between the surrogate posterior and the prior summed over the whole kernel

Are you talking about a single layer? Each layer has its own kernel. Even if you're talking about a single layer, are you sure it's a sum and not an average?

Anyway, let's suppose it's a sum. If what you say is correct, then ONCE PER MINI-BATCH the KL divergence of each layer is added to the losses property, so at the end of the mini-batch we have T KL divergences, where T is the number of Bayesian layers. In fact, if you execute the following code:

    import tensorflow as tf
    import tensorflow_probability as tfp

    def get_model():
        inp = tf.keras.layers.Input(shape=(1,))
        x = tfp.layers.DenseFlipout(8)(inp)
        x = tfp.layers.DenseFlipout(16)(x)
        out = tfp.layers.DenseFlipout(1)(x)
        model = tf.keras.Model(inputs=inp, outputs=out)
        model.summary()
        return model

    def example0():
        my_model = get_model()
        my_model.compile(optimizer="adam", loss="mse")
        print(len(my_model.losses))

    if __name__ == '__main__':
        example0()

It will print 3.

Keras computes a final average over the total loss. However, in this case, the KL-divergence is a single value for the whole batch, so it will be an average over 1

I don't think that Keras computes only 1 KL divergence for each mini-batch. As the example above illustrates, it will have a KL loss for each layer of the network.

OK, even if you meant one KL divergence for each layer, what will Keras do with my_model.losses when fit is called? Will it sum the elements of my_model.losses, average them, or do something else? Are you saying that it will average them? I don't think so, as my example below (based on yours) illustrates.

The Keras source code can be a bit difficult and time-consuming to understand, so you might consider just playing with a toy example to see how this works:

I've played with your example. It confirms what I said above. Every time self.add_loss is called inside a layer, an element is added to model.losses. So, in your specific example above, model.losses will have one element. But if you define the layer class as

class MyLayer(tf.keras.layers.Layer):
    def call(self, inputs):
        self.add_loss(2.0)
        self.add_loss(2.0)
        return inputs

It will contain two elements, and so on.

Actually, the elements of model.losses are not averaged; they are simply summed into the total loss. In fact, with the following example

import tensorflow as tf
import tensorflow_probability as tfp

class MyLayer(tf.keras.layers.Layer):
    def call(self, inputs):
        self.add_loss(2.0)
        self.add_loss(2.0)
        return inputs

def example1():
    inputs = tf.keras.layers.Input((3, 3))
    layer = MyLayer()(inputs)
    model = tf.keras.Model(inputs=inputs, outputs=layer)
    model.compile(loss='mae', optimizer='sgd', metrics=['mae'])
    model.summary()
    print("number of losses =", len(model.losses))
    loss, mae = model.evaluate(tf.ones((10, 3, 3)), 2 * tf.ones((10, 3, 3)))

if __name__ == '__main__':
    example1()

You get a loss of 5.0 = 1 + 2 + 2. The two 2s come from the two calls to self.add_loss inside the layer's call method, and the 1 is the mean absolute error between tf.ones((10, 3, 3)) and 2 * tf.ones((10, 3, 3)).
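The arithmetic of that experiment can be mirrored in plain Python. This is a hypothetical re-implementation, not Keras code: it assumes, as the experiment suggests, that the compiled loss is averaged over the batch and spatial dimensions, while each entry of model.losses is a scalar that is added once, unscaled.

```python
def mae(y_true, y_pred):
    """Mean absolute error averaged over every element (batch and spatial dims)."""
    diffs = [abs(t - p) for t, p in zip(y_true, y_pred)]
    return sum(diffs) / len(diffs)

def total_loss(y_true, y_pred, added_losses):
    """Hypothetical mirror of the observed behaviour: mean data loss + sum of add_loss terms."""
    return mae(y_true, y_pred) + sum(added_losses)

# Flattened stand-ins for tf.ones((10, 3, 3)) (the identity model's output)
# and 2 * tf.ones((10, 3, 3)) (the targets).
y_pred = [1.0] * (10 * 3 * 3)
y_true = [2.0] * (10 * 3 * 3)

print(total_loss(y_true, y_pred, added_losses=[2.0, 2.0]))  # prints 5.0
```

Because the added terms are not divided by the batch size, a KL registered through add_loss enters every step at full strength unless it has been scaled beforehand, e.g. by 1/N as the documentation suggests.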

bgroenks96 commented 4 years ago

Yes, I meant one KL per layer. These KLs are then summed to get the final loss. Averages are only computed over the remaining dimensions of each loss (batch and spatial dimensions), if necessary.

the calculation of the KL does not depend on the data, i.e. you can compute the KL without the data, you only need the distributions

This is not necessarily true. For complex priors, sometimes we need empirical KLs. Some architectures, like VAEs, also use the data to produce the variational parameters, in which case the surrogate posterior depends on the data.

nbro commented 4 years ago

@bgroenks96

This is not necessarily true. For complex priors, sometimes we need empirical KLs. Some architectures, like VAEs, also use the data to produce the variational parameters, in which case the surrogate posterior depends on the data.

Well, I was talking about BNNs, not the VAE. In the VAE, the mean and variance are the outputs of the encoder, so, of course, they depend on the data/input to the encoder.

bgroenks96 commented 4 years ago

It's also true for a BNN when the prior has no well defined KL divergence, like a mixture of Gaussians.

nbro commented 4 years ago

@bgroenks96

It's also true for a BNN when the prior has no well defined KL divergence, like a mixture of Gaussians.

Why? In the BbB paper, the authors use a mixture of Gaussians as the prior, and the KL term doesn't depend on the data.

bgroenks96 commented 4 years ago

Never mind, in the specific example I am working on there is an inference function which depends on the data. But for free parameters, you can approximate the KL-divergence by drawing samples from the surrogate posterior and evaluating them under both the surrogate posterior and the prior. So it depends on the samples but not on the data.
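A minimal sketch of such a sample-based estimate for a single weight, assuming a Gaussian surrogate posterior and a two-component scale-mixture-of-Gaussians prior; the mixture weight, scales, posterior parameters and sample count are made up. The estimate averages log q(w) - log p(w) over samples w ~ q, so it is random through the samples but never touches the training data. In the TFP layers, the divergence function is documented as also receiving the sampled weights, which is what allows a per-sample estimate there; that detail is worth confirming in the layer docs.

```python
import math
import random

def normal_log_pdf(x, mu, sigma):
    return -0.5 * math.log(2.0 * math.pi * sigma ** 2) - (x - mu) ** 2 / (2.0 * sigma ** 2)

def mixture_log_pdf(x, pi=0.5, sigma1=1.0, sigma2=0.05):
    """log of pi * N(x; 0, sigma1^2) + (1 - pi) * N(x; 0, sigma2^2), computed stably."""
    a = math.log(pi) + normal_log_pdf(x, 0.0, sigma1)
    b = math.log(1.0 - pi) + normal_log_pdf(x, 0.0, sigma2)
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))

def mc_kl(mu_q, sigma_q, log_prior, num_samples=10000, seed=0):
    """Monte Carlo estimate of KL(q || p) = E_q[log q(w) - log p(w)]."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_samples):
        w = rng.gauss(mu_q, sigma_q)  # sample from the surrogate posterior
        total += normal_log_pdf(w, mu_q, sigma_q) - log_prior(w)
    return total / num_samples

print(mc_kl(0.2, 0.1, mixture_log_pdf))
```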