TuringLang / DistributionsAD.jl

Automatic differentiation of Distributions using Tracker, Zygote, ForwardDiff and ReverseDiff
MIT License

MixtureModel gives a lot of numerical errors while sampling #110

Open krishvishal opened 4 years ago

krishvishal commented 4 years ago

@mohamed82008

using Turing, ReverseDiff, Memoization
using ReverseDiff: @grad, TrackedArray

Turing.setadbackend(:reversediff)
Turing.setrdcache(true)

# Numerically stable log-sum-exp along the first dimension.
function logsumexp(mat)
    maxima = maximum(mat, dims=1)
    exp_mat = exp.(mat .- maxima)
    sum_exp_mat = sum(exp_mat, dims=1)
    return maxima .+ log.(sum_exp_mat)
end

# Custom ReverseDiff rule for logsumexp. The pullback closes over the
# primal value of `x` rather than the tracked array itself, which is the
# usual idiom for ReverseDiff custom rules.
logsumexp(x::TrackedArray) = ReverseDiff.track(logsumexp, x)
@grad function logsumexp(x::AbstractArray)
    xv = ReverseDiff.value(x)
    lse = logsumexp(xv)
    return lse, Δ -> (Δ .* exp.(xv .- lse),)
end
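As a quick sanity check (not part of the original report), the custom rule can be compared against ForwardDiff on the same closure; a minimal sketch:

using ForwardDiff

# The two backends should agree if the adjoint above is correct.
M = rand(3, 4)
g_rd = ReverseDiff.gradient(m -> sum(logsumexp(m)), M)
g_fd = ForwardDiff.gradient(m -> sum(logsumexp(m)), M)
@assert g_rd ≈ g_fd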

@model function mwe(y, A, ::Type{T} = Vector{Float64}) where {T}
    n = size(A, 1)
    σ ~ truncated(Normal(0, 2), 0, Inf)
    # Mixture prior: this is the line that triggers the numerical errors.
    x ~ filldist(MixtureModel(Normal[Normal(10, 5), Normal(45, 3)], [1/3, 2/3]), n)
    # x ~ filldist(Normal(), n)  # swapping in this prior samples cleanly
    μ = vec(logsumexp(A .- x))
    y .~ Normal.(μ, σ)
    return x
end
y = rand(10)
A1 = rand(10, 10)
model = mwe(y, A1)
chain = sample(model, NUTS(.65), 100)

If you run this, you get a lot of numerical errors while sampling.

Virtually no numerical errors occur if:

  1. you disable caching (see the sketch after this list), or
  2. you use any distribution other than MixtureModel, with caching enabled.
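For reference, workaround 1 is a one-line change to the setup at the top of the script; a minimal sketch:

# Workaround 1: keep the ReverseDiff backend but turn off the compiled-tape cache.
Turing.setadbackend(:reversediff)
Turing.setrdcache(false)

chain = sample(model, NUTS(.65), 100)  # now samples with virtually no numerical errors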
mohamed82008 commented 4 years ago

MixtureModel's logpdf has branches, so we need a custom adjoint for it to work with cached (compiled) tapes.
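Until such an adjoint exists, one way to sidestep the branches entirely (a rough, untested sketch, assuming a Turing version that provides Flat and @addlogprob!) is to accumulate the mixture log-density by hand with a broadcast log-sum-exp, since ReverseDiff re-executes broadcasts on every tape run:

@model function mwe_manual(y, A)
    n = size(A, 1)
    σ ~ truncated(Normal(0, 2), 0, Inf)
    x ~ filldist(Flat(), n)  # placeholder prior; the real one is added below
    # Two-component mixture log-density via an elementwise, numerically
    # stable log-sum-exp, avoiding MixtureModel's branchy logpdf.
    lp1 = log(1/3) .+ logpdf.(Normal(10, 5), x)
    lp2 = log(2/3) .+ logpdf.(Normal(45, 3), x)
    m = max.(lp1, lp2)
    Turing.@addlogprob! sum(m .+ log.(exp.(lp1 .- m) .+ exp.(lp2 .- m)))
    μ = vec(logsumexp(A .- x))
    y .~ Normal.(μ, σ)
    return x
end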