Create a KL-divergence function that accepts two tfp `StudentT` distributions, following the pattern of the `_kl_normal_normal` function registered for tfp's `Normal` distribution, reproduced below as an example.
import tensorflow as tf
from tensorflow_probability.python.distributions import kullback_leibler
from tensorflow_probability.python.distributions.normal import Normal


@kullback_leibler.RegisterKL(Normal, Normal)
def _kl_normal_normal(a, b, name=None):
  """Calculate the batched KL divergence KL(a || b) with a and b Normal.

  Args:
    a: instance of a Normal distribution object.
    b: instance of a Normal distribution object.
    name: Name to use for created operations.
      Default value: `None` (i.e., `'kl_normal_normal'`).

  Returns:
    kl_div: Batchwise KL(a || b)
  """
  with tf.name_scope(name or 'kl_normal_normal'):
    b_scale = tf.convert_to_tensor(b.scale)  # We'll read it thrice.
    diff_log_scale = tf.math.log(a.scale) - tf.math.log(b_scale)
    return (
        0.5 * tf.math.squared_difference(a.loc / b_scale, b.loc / b_scale) +
        0.5 * tf.math.expm1(2. * diff_log_scale) -
        diff_log_scale)