In your code the KL divergence is calculated by:

```python
KLD = torch.sum(qy * (log_qy - 1. / categorical_dim), dim=-1).mean()
```

I think the `1. / categorical_dim` term should be replaced by its logarithm, `torch.log(1. / categorical_dim)`; otherwise the expression is not the KL divergence. The KL divergence against a uniform prior is `sum(q * (log q - log(1/K)))`, so the second term inside the parentheses must be the log-probability of the uniform prior, not the probability itself.
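To illustrate the difference, here is a minimal plain-Python sketch (using `math.log` rather than torch tensors, purely for demonstration). When `q` equals the uniform prior, the true KL divergence must be exactly zero; the original expression instead yields a negative value, which no KL divergence can produce:

```python
import math

def kld_original(qy, categorical_dim):
    # Expression as in the repo: subtracts the probability 1/K instead of its log
    return sum(q * (math.log(q) - 1.0 / categorical_dim) for q in qy)

def kld_fixed(qy, categorical_dim):
    # Proposed fix: KL(q || Uniform(K)) = sum_k q_k * (log q_k - log(1/K))
    return sum(q * (math.log(q) - math.log(1.0 / categorical_dim)) for q in qy)

K = 4
uniform_q = [1.0 / K] * K          # q equal to the uniform prior
print(kld_fixed(uniform_q, K))     # 0.0: KL of a distribution with itself
print(kld_original(uniform_q, K))  # negative, so it cannot be a KL divergence
```

(Note that in recent PyTorch versions `torch.log` expects a tensor, so in the actual code the constant could be written as `math.log(1. / categorical_dim)` or precomputed once.)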