danijar / dreamerv2

Mastering Atari with Discrete World Models
https://danijar.com/dreamerv2
MIT License

Fix: sum KL divergence across both latent dims. #32

Closed: fvisin closed this issue 2 years ago

fvisin commented 2 years ago

The logits have shape TensorShape([2, 32, 32]); I believe you want to sum over the last two latent dimensions when computing the KL divergence loss.

fvisin commented 2 years ago

My bad! As discussed offline, the Categorical distribution consumes one dimension as its event, so the existing Independent wrapper was already correct.
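
For reference, a minimal sketch of how the shapes work out in TensorFlow Probability (not the repo's exact code; the `logits` and `other` tensors here are illustrative random values):

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# Logits for 32 categorical latents with 32 classes each, batch of 2
# (the shape from the issue: TensorShape([2, 32, 32])).
logits = tf.random.normal([2, 32, 32])

# OneHotCategorical consumes the last (class) dimension as its event:
# batch_shape=[2, 32], event_shape=[32].
dist = tfd.OneHotCategorical(logits=logits)

# Independent reinterprets the trailing batch dim (the 32 latents) as
# part of the event: batch_shape=[2], event_shape=[32, 32].
dist = tfd.Independent(dist, reinterpreted_batch_ndims=1)

# The KL between two such distributions therefore already sums over
# both latent dimensions, returning one value per batch element.
other = tfd.Independent(
    tfd.OneHotCategorical(logits=tf.random.normal([2, 32, 32])), 1)
kl = tfd.kl_divergence(dist, other)
print(kl.shape)  # (2,)
```

So no extra sum is needed: the Categorical takes care of the class dimension, and the Independent wrapper folds the 32 latent variables into the event, which is why the original code was correct.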