tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
https://www.tensorflow.org/probability/
Apache License 2.0

Variational conv layers divergence not applied in tf.function #883

Open bgroenks96 opened 4 years ago

bgroenks96 commented 4 years ago

For some reason, _ConvolutionVariational uses a boolean flag to avoid calling _apply_divergence again on subsequent evaluations of call. Unfortunately, this breaks the layer when it is used inside a tf.function: the change in Python state triggers retracing, and the KL-divergence terms no longer appear in model.losses. Note that this is not a problem for the _DenseVariational layers, because they use no such flag.

Simple example:

import tensorflow as tf

optim = tf.keras.optimizers.Adam()

@tf.function
def train_batch(x, y, kl_weight=tf.constant(0.01)):
    logits = model(x)
    print(model.losses)      # printed at trace time; the KL terms go missing here after the flag flips
    print(x.shape, y.shape)
    nll = tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits)
    kl = tf.math.add_n(model.losses)  # no longer includes the conv layer's KL divergence
    loss = nll + kl_weight * kl
    grads = tf.gradients(loss, model.trainable_variables)
    optim.apply_gradients(zip(grads, model.trainable_variables))
    return nll, kl

model can be any Keras Model with a variational conv layer.
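For reference, a minimal model that should reproduce this; the architecture below is arbitrary (layer choice, input shape, and output size are just placeholders), since any model containing a variational conv layer will do:

import tensorflow as tf
import tensorflow_probability as tfp

# Arbitrary minimal model with a variational conv layer, for reproduction only.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(28, 28, 1)),
    tfp.layers.Convolution2DFlipout(8, kernel_size=3, activation="relu"),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10),  # logits consumed by softmax_cross_entropy_with_logits above
])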

This should be a simple fix. We just need to remove the flag and call _apply_divergence unconditionally in call.
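Schematically, the change I have in mind looks like the following. This is a toy sketch, not the actual _ConvolutionVariational code (the real attribute and method names differ, and _apply_divergence here is just a stand-in that registers a loss):

import tensorflow as tf

# Toy layer illustrating the proposed pattern: register the divergence on every
# call instead of guarding it with a Python-side boolean flag.
class ToyVariationalLayer(tf.keras.layers.Layer):
    def build(self, input_shape):
        self.kernel = self.add_weight("kernel", shape=(int(input_shape[-1]), 4))

    def _apply_divergence(self):
        # Stand-in for the real method that adds the KL(q || p) term via add_loss().
        self.add_loss(0.5 * tf.reduce_sum(tf.square(self.kernel)))

    def call(self, inputs):
        outputs = tf.matmul(inputs, self.kernel)  # dense op used for brevity
        # Previously (schematically): `if not self._divergence_built: ...`, with the
        # flag flipped afterwards, which mutates Python state between traces.
        # Proposed: no flag, apply the divergence unconditionally.
        self._apply_divergence()
        return outputs

With the flag gone, every trace of train_batch sees the same losses, so retracing can no longer drop the KL terms.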

nbro commented 4 years ago

@bgroenks96 Just a guess: try passing experimental_run_tf_function=False when compiling your model.
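Roughly like this (assuming a TF version whose Model.compile still accepts that argument; the optimizer and loss below are just placeholders):

# Placeholder compile call; the experimental_run_tf_function flag is the only point here.
model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
    experimental_run_tf_function=False,  # ask Keras not to wrap the train step in a tf.function
)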

Also, maybe provide a complete example of your model so that we can run it.