For some reason, `_ConvolutionVariational` uses a boolean flag to avoid calling `_apply_divergence` again when evaluating the `call` function. Unfortunately, this breaks the layer when it is used inside a `tf.function`: the change of Python state forces function retracing, and the KL-divergence terms no longer appear in `model.losses`. Note that this is not a problem for the `_DenseVariational` layers, because no such flag is applied there.
Simple example:
`model` can be any Keras Model with a variational conv layer.

This should be a simple fix: we just need to remove the flag and call `_apply_divergence` unconditionally in `call`.