danijar / dreamerv2

Mastering Atari with Discrete World Models
https://danijar.com/dreamerv2
MIT License

Difference in the KL loss terms in the paper and the code #4

Closed shivakanthsujit closed 3 years ago

shivakanthsujit commented 3 years ago

The KL balancing algorithm in the paper gives the posterior and prior terms as kl_loss = compute_kl(stop_grad(posterior), prior), so I had assumed the code would compute the loss as value = kld(dist(sg(post)), dist(prior)).

But the code has the terms reversed, formulating the KL loss (in networks.py, line 168) as value = kld(dist(prior), dist(sg(post))).

Does that have something to do with how the KL divergence function is implemented in tensorflow_probability?

danijar commented 3 years ago

KL balancing is implemented as a weighted average of two terms: the KL with a stop-gradient on the prior and the KL with a stop-gradient on the posterior.

The value you found in the code is only used for logging; it is not the quantity the gradient is computed from.
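
For concreteness, here is a minimal sketch of that weighted average, assuming categorical latents as in DreamerV2. The function name kl_balance_loss, the argument names post_logits / prior_logits, and the mixing weight alpha are hypothetical and not taken from the repository; the actual code in networks.py is structured differently.

```python
# Minimal sketch of KL balancing (not the repository's actual code).
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
sg = tf.stop_gradient

def kl_balance_loss(post_logits, prior_logits, alpha=0.8):
    """Weighted average of two KL terms with alternating stop-gradients."""
    dist = lambda logits: tfd.Independent(
        tfd.OneHotCategorical(logits=logits), 1)
    # Term that trains the prior: the posterior is held fixed via stop-gradient.
    prior_term = tfd.kl_divergence(dist(sg(post_logits)), dist(prior_logits))
    # Term that trains the posterior: the prior is held fixed via stop-gradient.
    post_term = tfd.kl_divergence(dist(post_logits), dist(sg(prior_logits)))
    return alpha * prior_term + (1 - alpha) * post_term

# Example usage: batch of 2, 4 categorical latents with 8 classes each.
post = tf.random.normal([2, 4, 8])
prior = tf.random.normal([2, 4, 8])
loss = kl_balance_loss(post, prior)
```

With alpha > 0.5, the prior is pulled toward the posterior more strongly than the posterior is regularized toward the prior, which is the point of KL balancing; the single KL value logged in the code is separate from this weighted training loss.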