I need the normal distribution to support second-order gradients of the Gaussian distribution, since the current implementation relies on broadcasting, which breaks down in that case.
I have quickly hacked it together as follows, but I would be happy if it got merged with proper tests, so that I can be sure it does what it is supposed to do.
"""
    _l(x::Matrix{T}, n, μ, σ2) where {T}

Column-wise Gaussian log-density kernel.

For every column of `x` this sums `(x - μ)^2 / σ2 + log(σ2)` over the rows, adds
`n * log(2π)`, and multiplies the result by `-1/2`. `σ2` is broadcast against the
residual `x - μ`, while `μ` is subtracted without broadcasting.

# Arguments
- `x`: matrix of observations; the reduction runs over its first dimension.
- `n`: multiplier of the `log(2π)` normalisation term.
- `μ`: means, subtracted from `x`.
- `σ2`: variances.

# Returns
A vector with one entry per column of `x`.
"""
function _l(x::Matrix{T}, n, μ, σ2) where {T}
    # `T(2π)` throws `InexactError` when `T` is an integer type, so convert the
    # constant to the floating-point counterpart of `T` instead. For
    # floating-point `T` this yields exactly the same value as `T(2π)`.
    log2π = log(convert(float(T), 2π))
    r = x - μ
    s = vec(sum((r .^ 2) ./ σ2 .+ log.(σ2), dims=1))
    return -(s .+ n * log2π) / 2
end
"""
    _∇l(Δ, x, n, μ, σ2)

Pullback of `_l`: map the output cotangent `Δ` to cotangents for `(x, n, μ, σ2)`.

`Δ` is expected to hold one entry per column of `x`; it is transposed so that it
broadcasts along the rows. The cotangent for `n` is `nothing`. The cotangents for
`μ` and `σ2` are summed over any dimension along which the parameter was
broadcast, so each has the shape of its parameter (a scalar for a scalar `σ2`).

# Returns
A 4-tuple `(∂x, nothing, ∂μ, ∂σ2)`.
"""
function _∇l(Δ, x, n, μ, σ2)
    Δ = Δ'
    r = x - μ
    δ = Δ .* r ./ σ2
    # Fully broadcast (`.-`): the plain `-` used previously threw whenever `σ2`
    # was a vector or a scalar, shapes the forward pass accepts.
    ∇σ2 = Δ .* ((r .^ 2 ./ (σ2 .^ 2)) .- 1 ./ σ2) / 2
    return (-δ, nothing, _sumto(μ, δ), _sumto(σ2, ∇σ2))
end

# Reduce a gradient `g` computed at the broadcast shape back to the shape of the
# parameter `p` by summing over the dimensions along which `p` was broadcast.
_sumto(p::Number, g) = sum(g)
function _sumto(p::AbstractArray, g)
    size(p) == size(g) && return g
    if size(p, 1) == 1 && size(g, 1) != 1
        g = sum(g; dims=1)
    end
    if size(p, 2) == 1 && size(g, 2) != 1
        g = sum(g; dims=2)
    end
    return reshape(g, size(p))
end
# Log-density of the matrix `x` under the batched normal `d`. The distribution's
# mean and variance, together with the first dimension of `d.μ`, are handed to
# `_l`, which carries the hand-written Zygote adjoint.
function Distributions.logpdf(d::ConditionalDists.BMN, x::Matrix{T}) where T<:Real
    return _l(x, size(d.μ, 1), mean(d), var(d))
end
# Custom Zygote rule for `_l`: the primal value comes from `_l` itself and the
# pullback delegates to the hand-written `_∇l`, which returns one cotangent per
# argument `(x, n, μ, σ2)`.
Zygote.@adjoint function _l(x, n, μ, σ2)
    back = Δ -> _∇l(Δ, x, n, μ, σ2)
    return _l(x, n, μ, σ2), back
end
I need the normal distribution to support second-order gradients of the Gaussian distribution, since the current implementation relies on broadcasting, which breaks down in that case.
I have quickly hacked it together as follows, but I would be happy if it got merged with proper tests, so that I can be sure it does what it is supposed to do.