KL balancing is implemented as a weighted average of two terms: the KL with a stop-gradient on the prior and the KL with a stop-gradient on the posterior. The value you found in the code is only used for logging; it is not what the gradient is computed from.
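
For reference, a minimal sketch of what that weighted average looks like. The Gaussian parameterization, the helper names, and the `mix` value are illustrative assumptions, not the repository's exact code (DreamerV2 uses categorical latents by default):

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

def balanced_kl_loss(post, prior, mix=0.8):
    # `post` and `prior` are dicts of distribution parameters, e.g.
    # {'mean': ..., 'std': ...} (an assumed parameterization).
    sg = lambda x: tf.nest.map_structure(tf.stop_gradient, x)
    dist = lambda p: tfd.Independent(tfd.Normal(p['mean'], p['std']), 1)
    # KL with a stop-gradient on the posterior: trains the prior
    # toward the fixed posterior.
    prior_term = tfd.kl_divergence(dist(sg(post)), dist(prior))
    # KL with a stop-gradient on the prior: regularizes the posterior
    # toward the fixed prior.
    post_term = tfd.kl_divergence(dist(post), dist(sg(prior)))
    return mix * prior_term + (1 - mix) * post_term
```

Only this weighted sum is differentiated; a plain `kld(dist(post), dist(prior))` can still be computed alongside it purely for logging.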
The algorithm for KL balancing in the paper gives the posterior and prior terms as `kl_loss = compute_kl(stop_grad(posterior), prior)`, so I had assumed the code would compute the loss as `value = kld(dist(sg(post)), dist(prior))`. But instead the code has the terms reversed, formulating the KL loss as `value = kld(dist(prior), dist(sg(post)))` (in networks.py, line 168). Does that have something to do with the implementation of the KL divergence function in tensorflow_probability?
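
For context on why the argument order matters here: `tfd.kl_divergence(p, q)` in tensorflow_probability computes KL(p || q), with the first argument on the left, and KL is asymmetric, so swapping the arguments generally changes the value. A quick check:

```python
import tensorflow_probability as tfp

tfd = tfp.distributions

p = tfd.Normal(0.0, 1.0)
q = tfd.Normal(1.0, 2.0)

print(tfd.kl_divergence(p, q).numpy())  # KL(p || q)
print(tfd.kl_divergence(q, p).numpy())  # KL(q || p), a different value
```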