Closed fvisin closed 2 years ago
Logits have shape TensorShape([2, 32, 32]); I believe you want to sum over the last two latent dimensions when computing the KL divergence loss.
My bad! As discussed offline, the Categorical consumes one dimension, so the Independent distribution was correct.
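For context, here is a minimal NumPy sketch (not the project's actual code; the logit arrays are hypothetical) of the shape reasoning: a Categorical over logits of shape [2, 32, 32] consumes the last axis as the category axis, leaving a KL value per latent of shape [2, 32], and wrapping in Independent with reinterpreted_batch_ndims=1 then sums over the one remaining latent axis.

```python
import numpy as np

def softmax(logits, axis=-1):
    # Numerically stable softmax over the given axis.
    z = logits - logits.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def categorical_kl(logits_p, logits_q):
    # KL(p || q) for Categorical distributions; the last axis is the
    # category axis that Categorical "consumes", so the result has
    # shape logits.shape[:-1].
    p = softmax(logits_p)
    log_p = np.log(p)
    log_q = np.log(softmax(logits_q))
    return (p * (log_p - log_q)).sum(axis=-1)

rng = np.random.default_rng(0)
logits_p = rng.normal(size=(2, 32, 32))  # hypothetical logits, as in the issue
logits_q = rng.normal(size=(2, 32, 32))

kl = categorical_kl(logits_p, logits_q)
print(kl.shape)  # (2, 32): one KL per latent, category axis consumed

# Independent(..., reinterpreted_batch_ndims=1) sums over the one
# remaining latent axis, giving one KL value per batch element.
kl_total = kl.sum(axis=-1)
print(kl_total.shape)  # (2,)
```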