Photrek / Nonlinear-Statistical-Coupling

Apache License 2.0

StudentT - KL function #10

Closed Kevin-Chen0 closed 3 years ago

Kevin-Chen0 commented 3 years ago

Create a KL-divergence function that accepts two tfp StudentT distributions. See the _kl_normal_normal function registered for tfp's Normal distribution as an example:

# Imports as used in the TFP source where this function is defined:
import tensorflow as tf
from tensorflow_probability.python.distributions import kullback_leibler
from tensorflow_probability.python.distributions.normal import Normal

@kullback_leibler.RegisterKL(Normal, Normal)
def _kl_normal_normal(a, b, name=None):
  """Calculate the batched KL divergence KL(a || b) with a and b Normal.
  Args:
    a: instance of a Normal distribution object.
    b: instance of a Normal distribution object.
    name: Name to use for created operations.
      Default value: `None` (i.e., `'kl_normal_normal'`).
  Returns:
    kl_div: Batchwise KL(a || b)
  """
  with tf.name_scope(name or 'kl_normal_normal'):
    b_scale = tf.convert_to_tensor(b.scale)  # We'll read it thrice.
    diff_log_scale = tf.math.log(a.scale) - tf.math.log(b_scale)
    return (
        0.5 * tf.math.squared_difference(a.loc / b_scale, b.loc / b_scale) +
        0.5 * tf.math.expm1(2. * diff_log_scale) -
        diff_log_scale)
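Unlike the Normal–Normal case above, the KL divergence between two Student-t distributions has no closed form in general, so a registered StudentT KL would need a numerical approach. As a minimal sketch (not part of this repo or of tfp), here is a Monte Carlo estimate KL(a || b) ≈ E_{x~a}[log p_a(x) − log p_b(x)] using scipy's Student-t; the function name and parameters are hypothetical:

```python
import numpy as np
from scipy import stats

def kl_student_t_mc(df_a, loc_a, scale_a, df_b, loc_b, scale_b,
                    n_samples=100_000, seed=0):
    """Monte Carlo estimate of KL(a || b) for two Student-t distributions.

    Draws samples from a and averages the log-density ratio, which is
    an unbiased estimator of the KL divergence.
    """
    rng = np.random.default_rng(seed)
    a = stats.t(df=df_a, loc=loc_a, scale=scale_a)
    b = stats.t(df=df_b, loc=loc_b, scale=scale_b)
    x = a.rvs(size=n_samples, random_state=rng)
    return np.mean(a.logpdf(x) - b.logpdf(x))

# KL of a distribution with itself should be (approximately) 0,
# and shifting the location of b should give a positive divergence.
print(kl_student_t_mc(5.0, 0.0, 1.0, 5.0, 0.0, 1.0))
print(kl_student_t_mc(5.0, 0.0, 1.0, 5.0, 1.0, 1.0))
```

A tfp version would wrap the same log-prob-ratio average in a @kullback_leibler.RegisterKL(StudentT, StudentT) function, sampling with a.sample(n) inside a tf.name_scope, at the cost of a stochastic (non-deterministic) KL value.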
Kevin-Chen0 commented 3 years ago

This is not necessary to do atm.