aicenter / ConditionalDists.jl

Conditional probability distributions powered by DistributionsAD.jl
MIT License
21 stars 4 forks source link

Second order gradients of Normal distribution #33

Open pevnak opened 4 years ago

pevnak commented 4 years ago

I need the Normal distribution to support second-order gradients. The current implementation relies on broadcasting, which breaks down when differentiated a second time.

I have quickly hacked it as follows, but I would be happy if it got merged with proper tests, so that I can be sure it does what it is supposed to do.


"""
    _l(x, n, μ, σ2)

Return the log-density of a diagonal-covariance Gaussian for each column of `x`.

Column `j` of `x` is one sample. Entry `j` of the returned vector is

    -(Σᵢ (x[i,j] - μ[i,j])^2 / σ2[i,j] + log σ2[i,j] + n * log(2π)) / 2

# Arguments
- `x`: samples, one per column.
- `n`: dimensionality used in the normalising constant (expected to be `size(x, 1)`).
- `μ`, `σ2`: means and variances. `x - μ` uses non-broadcasting subtraction, so
  array-valued `μ` must have the same size as `x`. This matches the hand-written
  pullback, whose gradients are shaped like `x`.
"""
function _l(x::AbstractMatrix{T}, n, μ, σ2) where {T}
    # `float(T)` keeps Float32 inputs in Float32 and avoids the InexactError
    # that `T(2π)` raises for integer element types.
    log2π = log(float(T)(2π))
    return -(vec(sum(((x - μ).^2) ./ σ2 .+ log.(σ2), dims=1)) .+ n * log2π) / 2
end

"""
    _∇l(Δ, x, n, μ, σ2)

Hand-written pullback of `_l`. Return the tuple `(∂x, ∂n, ∂μ, ∂σ2)`, where `∂n` is
`nothing` because `n` is treated as non-differentiable.

`Δ` is the output cotangent. It is expected to be a vector with one entry per
column of `x`. It is adjointed into a row so that it scales each column of the
per-element gradients.
"""
function _∇l(Δ, x, n, μ, σ2)
    Δrow = Δ'
    resid = x - μ
    # ∂/∂μ of -(r^2/σ2)/2 is r/σ2; ∂/∂x is its negation.
    gμ = Δrow .* resid ./ σ2
    # ∂/∂σ2 of -(r^2/σ2 + log σ2)/2 is (r^2/σ2^2 - 1/σ2)/2.
    gσ2 = Δrow .* (resid .^ 2 ./ (σ2 .^ 2) - 1 ./ σ2) / 2
    return (-gμ, nothing, gμ, gσ2)
end

"""
    logpdf(d::ConditionalDists.BMN, x::Matrix{<:Real})

Evaluate the Gaussian log-density of every column of `x` under `d` by calling `_l`.
That function has a hand-written Zygote adjoint, so it supports second-order gradients.

The dimensionality passed to `_l` is `size(d.μ, 1)`. This is presumably the event
dimension of `d`; confirm against the `BMN` definition. The code also assumes that
`mean(d)` and `var(d)` return arrays the same size as `x` (TODO confirm).
"""
function Distributions.logpdf(d::ConditionalDists.BMN, x::Matrix{T}) where T<:Real
    return _l(x, size(d.μ, 1), mean(d), var(d))
end

# Custom Zygote rule for `_l`: return the primal value together with a pullback
# that delegates to the explicit gradient formulas in `_∇l`. Zygote then does not
# trace the broadcasting inside `_l` itself.
Zygote.@adjoint function _l(x, n, μ, σ2)
    y = _l(x, n, μ, σ2)
    back(Δ) = _∇l(Δ, x, n, μ, σ2)
    return y, back
end
nmheim commented 4 years ago

Do you have an MWE (minimal working example) for this?