codyshen0000 closed this issue 3 years ago.
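Please see Eq. (9) in the paper. z follows a standard Gaussian distribution, so its negative log-likelihood is ||z||^2 / 2 plus a constant that does not depend on z. Minimizing ||z||^2 is therefore equivalent to maximizing the likelihood of z. The CE loss does not by itself guarantee that (Y, Z) follows the target distribution; it is only a weaker surrogate loss used for stable training (refer to https://github.com/pkuxmq/Invertible-Image-Rescaling/issues/14).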
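To illustrate why a sum of squares can stand in for the Gaussian log-likelihood, here is a minimal sketch. It assumes z is modeled as a standard Gaussian N(0, I) and uses plain Python instead of PyTorch. The function names `gaussian_nll` and `sum_of_squares` are hypothetical and not taken from the repository, and the repo's actual loss may scale or average the term differently. The sketch checks that the exact per-element negative log-likelihood equals `0.5 * sum(z**2)` plus a constant.

```python
import math
import random
from statistics import NormalDist


def gaussian_nll(z):
    """Exact negative log-likelihood of z under N(0, I), summed over elements."""
    std_normal = NormalDist(mu=0.0, sigma=1.0)
    return sum(-math.log(std_normal.pdf(v)) for v in z)


def sum_of_squares(z):
    """Plain-Python analogue of torch.sum(z**2) (hypothetical helper)."""
    return sum(v * v for v in z)


random.seed(0)
z = [random.gauss(0.0, 1.0) for _ in range(8)]

# The constant term depends only on the dimension, not on the values of z.
const = 0.5 * len(z) * math.log(2 * math.pi)

nll = gaussian_nll(z)
approx = 0.5 * sum_of_squares(z) + const
print(abs(nll - approx) < 1e-9)
```

Because the constant and the factor of 1/2 do not change where the minimum lies, optimizing `torch.sum(z**2)` pushes z toward the same solution as maximizing its Gaussian likelihood.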
Hi, I'm confused about why `torch.sum(z**2)` is used to calculate loss_ce. Also, how does calculating loss_distr on Y × Z reduce to calculating loss_ce only on the Z space?