Closed SerezD closed 6 months ago
Nevermind, I got it after checking section 3.2 of the paper (paragraph: Residual Normal Distributions).
In the code, `self.mu` and `self.sigma` are the parameters of the posterior distribution, while `prior.mu` and `prior.sigma` are the parameters of the prior distribution.
p(z_i | z_{<i}) is the prior, defined as N(μ_p, σ_p), where both parameters are conditioned on all z_{<i}.
q(z_i | z_{<i}, x) is the distribution from the encoder (`self`), defined as:
q = N(μ_p + Δμ_q, σ_p * Δσ_q), where Δμ_q and Δσ_q are the relative shift and scale given by the hierarchical nature of the distribution.
So basically, `self.mu` and `self.sigma` are the parameters of the posterior:
self.mu = μ_p + Δμ_q
self.sigma = σ_p * Δσ_q
The KL Loss between two normal distributions a = N(μ_1, σ_1), b = N(μ_2, σ_2) is given by:
0.5 [ (μ_2 - μ_1)**2 / σ_2**2 ] + 0.5 (σ_1**2 / σ_2**2) - 0.5 [ln(σ_1**2 / σ_2**2)] - 0.5
Proof: https://statproofbook.github.io/P/norm-kl.html
In our case: μ_1 = self.mu; μ_2 = prior.mu; σ_1 = self.sigma; σ_2 = prior.sigma.
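The closed form above can be checked directly with a small helper (a sketch, not the repository's implementation):

```python
import math

# General KL between two univariate Gaussians a = N(mu1, sigma1), b = N(mu2, sigma2),
# following the closed form above (see the statproofbook link).
def kl_normal(mu1, sigma1, mu2, sigma2):
    return (0.5 * (mu2 - mu1) ** 2 / sigma2 ** 2
            + 0.5 * sigma1 ** 2 / sigma2 ** 2
            - 0.5 * math.log(sigma1 ** 2 / sigma2 ** 2)
            - 0.5)

print(kl_normal(0.0, 1.0, 0.0, 1.0))  # KL of a distribution with itself is 0
```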
So the three terms in the formula above become:
0.5 [ (μ_p - (μ_p + Δμ_q))**2 / σ_p**2 ] = 0.5 [ Δμ_q**2 / σ_p**2 ]
0.5 [ (σ_p * Δσ_q)**2 / σ_p**2 ] = 0.5 [ Δσ_q**2 ]
-0.5 [ ln((σ_p * Δσ_q)**2 / σ_p**2) ] = -0.5 ln(Δσ_q**2)
The final formula is thus the one written in Equation 2, where (in the code):
Δμ_q = self.mu - prior.mu
Δσ_q = self.sigma / prior.sigma
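The equivalence of the residual form (Equation 2) and the general Gaussian KL can be verified numerically (a sketch with arbitrary example numbers, using the substitutions derived above):

```python
import math

mu_p, sigma_p = 0.3, 1.5          # prior parameters (example values)
delta_mu, delta_sigma = 0.2, 0.8  # residual shift and scale (example values)

mu_q = mu_p + delta_mu            # self.mu
sigma_q = sigma_p * delta_sigma   # self.sigma

# General formula with mu1 = self.mu, sigma1 = self.sigma,
# mu2 = prior.mu, sigma2 = prior.sigma.
kl_general = (0.5 * (mu_p - mu_q) ** 2 / sigma_p ** 2
              + 0.5 * sigma_q ** 2 / sigma_p ** 2
              - 0.5 * math.log(sigma_q ** 2 / sigma_p ** 2)
              - 0.5)

# Residual form (Equation 2):
# 0.5 * (delta_mu**2 / sigma_p**2 + delta_sigma**2 - ln(delta_sigma**2) - 1)
kl_residual = 0.5 * (delta_mu ** 2 / sigma_p ** 2
                     + delta_sigma ** 2
                     - math.log(delta_sigma ** 2)
                     - 1.0)

assert math.isclose(kl_general, kl_residual)
```

Both expressions agree, which confirms the substitution Δσ_q = self.sigma / prior.sigma (a ratio, not a difference) is what makes the σ_p terms cancel.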
In `distributions.py`, the KL is computed as indicated in section 3.2 of the paper (residual normal distributions, Equation 2). What I don't understand is why you compute `term2 = self.sigma / normal_dist.sigma`. Shouldn't it be `term2 = self.sigma - normal_dist.sigma`?