tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
https://www.tensorflow.org/probability/
Apache License 2.0

Keras model.compile uses incorrect loss when using Bayesian layers. #282

Open xht033 opened 5 years ago

xht033 commented 5 years ago

When I use my custom loss function, I get a wrong loss output if I use Keras model.compile:

import tensorflow as tf
from tensorflow import keras
import tensorflow_probability as tfp
from tensorflow.examples.tutorials.mnist import input_data as mnist_data

# Load MNIST (TF1-style helper) and one-hot encode the labels.
train, valid, test = mnist_data.read_data_sets('~/code/Python')
num_classes = 10

x_train, y_train = train.images, train.labels
x_test, y_test = test.images, test.labels
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

# A single Bayesian layer with a softmax output.
model = keras.Sequential()
model.add(tfp.layers.DenseReparameterization(10, activation='softmax', input_shape=(784,)))

# Custom loss: plain categorical cross-entropy.
def my_loss(y_true, y_pred):
    return tf.keras.losses.categorical_crossentropy(y_true, y_pred)

sgd = keras.optimizers.SGD(lr=.1, momentum=0.9, nesterov=True)
model.compile(loss=my_loss, optimizer=sgd, metrics=['accuracy'])

model.fit(x_train, y_train,
          batch_size=128,
          epochs=10,
          validation_data=(x_test, y_test),
          shuffle=True)

However, if I replace the Bayesian layer with a standard Dense layer, everything seems to be correct.

SiegeLordEx commented 5 years ago

So the issue is that correct usage of TFP's layers involves scaling the KL divergence (stored in the model's losses property) by the number of training examples (see here). Unfortunately, when you use model.compile that scaling is not done, so the model ends up being over-regularized.
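For reference, here is roughly what that manual scaling looks like with a custom training step. This is only a sketch, not the linked example's exact code; NUM_TRAIN_EXAMPLES, the optimizer settings, and the one-layer model are placeholder assumptions:

import tensorflow as tf
import tensorflow_probability as tfp

NUM_TRAIN_EXAMPLES = 55000  # placeholder for the size of your training set

model = tf.keras.Sequential([
    tfp.layers.DenseReparameterization(10, activation='softmax', input_shape=(784,)),
])
optimizer = tf.keras.optimizers.SGD(0.1)

@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        probs = model(x)
        # Data term: per-example negative log-likelihood, averaged over the batch.
        nll = tf.reduce_mean(tf.keras.losses.categorical_crossentropy(y, probs))
        # KL terms collected in model.losses, scaled to a per-example quantity.
        kl = sum(model.losses) / NUM_TRAIN_EXAMPLES
        loss = nll + kl  # negative ELBO divided by the dataset size
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss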

Here's one way to resolve this: override the DenseReparameterization kernel_divergence_fn argument to perform the necessary scaling:

model.add(tfp.layers.DenseReparameterization(
    10,
    activation='softmax',
    input_shape=(784,),
    kernel_divergence_fn=lambda q, p, _: (
        tfp.distributions.kl_divergence(q, p) / tf.to_float(train.num_examples))))

This should do the right thing, at least for the training loss. Depending on how you interpret what the validation loss should be, it will do the right thing for that as well.
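On newer TF versions tf.to_float has been removed, but the same fix can be written with tf.cast. A minimal sketch, with NUM_TRAIN_EXAMPLES standing in for train.num_examples:

import tensorflow as tf
from tensorflow import keras
import tensorflow_probability as tfp

NUM_TRAIN_EXAMPLES = 55000  # placeholder for train.num_examples

model = keras.Sequential()
model.add(tfp.layers.DenseReparameterization(
    10,
    activation='softmax',
    input_shape=(784,),
    kernel_divergence_fn=lambda q, p, _: (
        tfp.distributions.kl_divergence(q, p) /
        tf.cast(NUM_TRAIN_EXAMPLES, tf.float32))))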

xht033 commented 5 years ago

Can I ask why we divide by the total number of examples in the training data set?

alexv1247 commented 5 years ago

Hi there, thanks for the advice about how to handle the Keras compile issue. At first I thought it solved my problem, but after three training iterations, adding new data to the training set at each iteration (I am doing active learning), the losses for both model outputs were getting bigger again. I chose to scale the kl_divergence by the batch size. Is that correct, or do I need to scale it by the total number of training examples the model will see during a training iteration?

nbro commented 4 years ago

@SiegeLordEx Could you please answer the questions above from @xht033 and @alexv1247?

If we perform mini-batch stochastic gradient descent, shouldn't we divide the KL loss by the size of the mini-batch rather than by the size of the whole training dataset? I think the answer depends on when the KL loss (regularisation) is added to the final loss, which should be at each training step (so after each mini-batch).

Apparently, in the example https://github.com/tensorflow/probability/blob/master/tensorflow_probability/examples/bayesian_neural_network.py, you also divide the KL term by the total number of training examples, even though you perform mini-batch stochastic gradient descent.
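A sketch of the usual reasoning behind the 1/N factor (not an official answer): the negative ELBO for the whole dataset, divided through by N, is

$$\frac{1}{N}\left(-\sum_{i=1}^{N} \mathbb{E}_{q(w)}\big[\log p(y_i \mid x_i, w)\big] + \mathrm{KL}\big(q(w)\,\|\,p(w)\big)\right)$$

Keras averages the data term per example over each mini-batch, which is an unbiased estimate of the first sum divided by N, so the KL added at every step has to be KL/N to stay on the same per-example scale; dividing by the batch size B instead would weight the prior roughly N/B times too strongly.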

bjarthur commented 4 years ago

I have a similar problem with DistributionLambda layers when using a custom metric: what gets passed as yhat to the metric function is a sample drawn from the distribution. It would be better to pass the actual distribution, as is done for custom loss functions. The workaround is to specify convert_to_tensor_fn as lambda t: t.loc.
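A minimal sketch of that workaround, assuming a Normal output head and a made-up input shape (not the original model):

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

# convert_to_tensor_fn controls what Keras metrics see. Returning the
# distribution's loc (its mean here) instead of the default sample makes the
# metric deterministic.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(2, input_shape=(4,)),
    tfp.layers.DistributionLambda(
        lambda t: tfd.Normal(loc=t[..., :1], scale=tf.nn.softplus(t[..., 1:])),
        convert_to_tensor_fn=lambda d: d.loc),
])
model.compile(optimizer='adam',
              loss=lambda y_true, dist: -dist.log_prob(y_true),  # loss still sees the distribution
              metrics=['mae'])                                    # metrics see d.loc

As described above, the custom loss still receives the distribution object, so log_prob-based losses are unaffected; only the tensor conversion that metrics see changes.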

Should I open a separate issue, or is this similar enough to the problem above with DenseReparameterization?