Status: Open — opened by NiklasGustafsson 5 years ago.
```jl
#
# Repro case code was based on code from: https://github.com/awf/autodiff
#
using Pkg
using Printf
using SpecialFunctions
using LinearAlgebra
using Zygote
using Zygote: @adjoint

# Hyperparameters of the Wishart prior used by `log_wishart_prior`.
struct Wishart
    gamma::Float64
    m::Int
end

# Flatten `alphas`, `means` and `icf` into one vector (column-major order).
function pack(alphas, means, icf)
    return [alphas[:]; means[:]; icf[:]]
end

# Inverse of `pack` for dimension `d` and `k` components.
# Returns `(alphas, means, icf)` shaped 1×k, d×k and (d(d+1)/2)×k.
function unpack(d, k, packed)
    alphas = reshape(packed[1:k], 1, k)
    off = k
    means = reshape(packed[(1:d*k) .+ off], d, k)
    icf_sz = div(d * (d + 1), 2)
    off += d * k
    icf = reshape(packed[off+1:end], icf_sz, k)
    return (alphas, means, icf)
end

sumsq(v) = sum(abs2, v)

# Build a d×d lower-triangular matrix whose diagonal is `D` and whose
# strictly-lower entries are read row by row from `LT`.
function ltri_unpack(D, LT)
    d = length(D)
    # Row r: (r-1) entries from L, then D[r], then zero padding.
    make_row(r::Int, L) =
        hcat(reshape([L[i] for i = 1:r-1], 1, r - 1), D[r], zeros(1, d - r))
    # Rows 1..r-1 hold (r-1)(r-2)/2 off-diagonal entries in total.
    row_start(r::Int) = div((r - 1) * (r - 2), 2)
    inds(r) = row_start(r) .+ (1:r-1)
    return vcat([make_row(r, LT[inds(r)]) for r = 1:d]...)
end

# Per-component factor: exp of icf[1:d] on the diagonal, icf[d+1:end] below it.
function get_Q(d, icf)
    return ltri_unpack(exp.(icf[1:d]), icf[d+1:end])
end

# log of the multivariate gamma function:
# log Γ_p(a) = p(p-1)/4 · log(π) + Σ_{j=1}^{p} log Γ(a + (1-j)/2)
function log_gamma_distrib(a, p)
    out = 0.25 * p * (p - 1) * 1.1447298858494002  # 1.1447... == log(pi)
    for j in 1:p
        # `loggamma` replaces the removed `lgamma`; arguments here are >= 1.
        out += loggamma(a + 0.5 * (1 - j))
    end
    return out
end

# Wishart prior term. `Qs` is d×d×k with exponentiated diagonals;
# `sum_qs` holds the per-component sums of the log-space diagonals.
function log_wishart_prior(wishart::Wishart, sum_qs, Qs)
    # `Qs` is a 3-D array, so `Qs[1]` would be a scalar (size 1);
    # take the matrix dimension from the array itself.
    p = size(Qs, 1)
    k = size(Qs, 3)  # number of components (was read from the global `k`)
    n = p + wishart.m + 1
    C = n * p * (log(wishart.gamma) - 0.5 * log(2)) - log_gamma_distrib(0.5 * n, p)
    frobenius = sum(abs2, Qs)
    return 0.5 * wishart.gamma^2 * frobenius - wishart.m * sum(sum_qs) - k * C
end

# input should be 1 dimensional
function logsumexp(x)
    mx = maximum(x)
    return log(sum(exp.(x .- mx))) + mx
end

# Sum of the diagonal of each d×d slice of `Qs`; result is 1×1×k.
function diagsums(Qs)
    return mapslices(slice -> sum(diag(slice)), Qs; dims=[1, 2])
end

# Pullback: each slice's diagonal receives that slice's incoming gradient.
@adjoint function diagsums(Qs)
    diagsums(Qs), function (Δ)
        Δ′ = zero(Qs)
        for (i, δ) in enumerate(Δ)
            for j in 1:size(Qs, 1)
                Δ′[j, j, i] = δ
            end
        end
        (Δ′,)
    end
end

# Exponentiate the diagonal of every d×d slice, leaving off-diagonals as-is.
function expdiags(Qs)
    mapslices(Qs; dims=[1, 2]) do slice
        slice[diagind(slice)] .= exp.(slice[diagind(slice)])
        slice
    end
end

# Pullback: identity off the diagonal, times exp(Q[j,j,i]) on the diagonal.
@adjoint function expdiags(Qs)
    expdiags(Qs), function (Δ)
        Δ′ = zero(Qs)
        Δ′ .= Δ
        for i in 1:size(Qs, 3)
            for j in 1:size(Qs, 1)
                Δ′[j, j, i] *= exp(Qs[j, j, i])
            end
        end
        (Δ′,)
    end
end

# NOTE(review): type piracy on Base; kept as the workaround used when the
# error was reported (presumably for `nothing` gradients from Zygote).
Base.:*(::Float64, ::Nothing) = nothing

# GMM objective for data `x` (d×n), `alphas` (1×k), `means` (d×k) and
# factors `Qs` (d×d×k) whose diagonals are in log space (exponentiated here).
function gmm_objective(alphas, means, Qs, x, wishart::Wishart)
    d = size(x, 1)
    n = size(x, 2)
    k = size(means, 2)  # number of components (was read from the global `k`)
    CONSTANT = -n * d * 0.5 * log(2 * pi)
    sum_qs = reshape(diagsums(Qs), 1, size(Qs, 3))
    # NOTE(review): `Qs` is rebound here and captured by `formula` below, so
    # Julia boxes it; kept unchanged since it may matter for the repro.
    Qs = expdiags(Qs)
    slse = 0.0
    for ix = 1:n
        formula(ik) = -0.5 * sum(abs2, Qs[:, :, ik] * (x[:, ix] .- means[:, ik]))
        sumexp = 0.0
        for ik = 1:k
            sumexp += exp(formula(ik) + alphas[ik] + sum_qs[ik])
        end
        slse += log(sumexp)
    end
    return CONSTANT + slse - n * logsumexp(alphas) + log_wishart_prior(wishart, sum_qs, Qs)
end

alphas = randn(1, 10)
means = rand(10, 10)
icf = randn(55, 10)
x = randn(10, 10000)
wishart = Wishart(1.0, 0)
d = size(means, 1)
k = size(means, 2)
n = size(x, 2)
const Qs = cat([get_Q(d, icf[:, ik]) for ik in 1:k]...; dims=[3])

# Objective
# Call once in case of precompilation etc
err = gmm_objective(alphas, means, Qs, x, wishart)

function wrapper_gmm_objective(alphas, means, Qs)
    gmm_objective(alphas, means, Qs, x, wishart)
end

# Gradient
g = (alphas, means, Qs) -> Zygote.gradient(wrapper_gmm_objective, alphas, means, Qs)
J = g(alphas, means, Qs)
```
Running the repro script above produces an internal error.
I found this while experimenting with Zygote, adapting a GMM-objective benchmark that already works with ForwardDiff and Flux.Tracker. The script was run on Windows 10.
This was originally filed against the Julia repository; Keno asked that it be moved to Zygote for further triage.