TuringLang / DistributionsAD.jl

Automatic differentiation of Distributions using Tracker, Zygote, ForwardDiff and ReverseDiff
MIT License

logpdf_with_trans different for Zygote and Forward/ReverseDiff in Dirichlet case #86

Closed: paschermayr closed this issue 4 years ago

paschermayr commented 4 years ago

Hi guys,

I am using Bijectors v0.8.0, and updated all modules used below yesterday.

  [76274a88] Bijectors v0.8.0
  [163ba53b] DiffResults v1.0.2
  [31c24e10] Distributions v0.23.4
  [ced4e74d] DistributionsAD v0.6.0
  [f6369f11] ForwardDiff v0.10.10
  [37e2e3b7] ReverseDiff v1.2.0
  [e88e6eb3] Zygote v0.4.21

I am getting different gradient results depending on whether I use Zygote or Forward/ReverseDiff for the Dirichlet distribution in the logpdf_with_trans calculation. MWE below:

using Distributions, DistributionsAD, Bijectors
using ForwardDiff, ReverseDiff, Zygote

function get_loglik(data)
    function loglik(theta_transformed::AbstractVector{<:Real})
        p = Bijectors.invlink(Dirichlet(3, 3), theta_transformed)
        lprior = logpdf_with_trans(Dirichlet(3, 3), p, true)  # !!! Difference occurs here
        llik = logpdf(Categorical(p), data)
        return llik + lprior
    end
    return loglik
end
theta_transformed = randn(3)
ll = get_loglik(3)
ll(theta_transformed)

ForwardDiff.gradient(ll, theta_transformed)
ReverseDiff.gradient(ll, theta_transformed) # same as ForwardDiff
Zygote.gradient(ll, theta_transformed) # different from both

I am unsure whether this should be posted to Bijectors, Distributions, DistributionsAD, or Zygote, so I started here :D. Apologies if this turns out to be unrelated to Bijectors.

Best regards, Patrick

devmotion commented 4 years ago

It seems this is a bug in DistributionsAD:

julia> using Distributions, DistributionsAD, Tracker, FiniteDiff, ForwardDiff, ReverseDiff, Zygote;

julia> g(x) = logpdf(Dirichlet(3, 3), x);

julia> x = (y = rand(3); y ./ sum(y))
3-element Array{Float64,1}:
 0.3996881361942186
 0.11886701124092888
 0.48144485256485253

julia> Tracker.gradient(g, x)[1]
Tracked 3-element Array{Float64,1}:
  5.0039013392885625
 16.825526099467957
  4.154162183571361

julia> ForwardDiff.gradient(g, x)
3-element Array{Float64,1}:
  5.0039013392885625
 16.825526099467957
  4.154162183571361

julia> FiniteDiff.finite_difference_gradient(g, x)
3-element Array{Float64,1}:
  5.003901339640391
 16.82552611394734
  4.154162183750874

julia> ReverseDiff.gradient(g, x)
3-element Array{Float64,1}:
  5.0039013392885625
 16.825526099467957
  4.154162183571361

julia> Zygote.gradient(g, x)
([2.0, 2.0, 2.0],)

I used

  [31c24e10] Distributions v0.23.4
  [ced4e74d] DistributionsAD v0.6.0
  [6a86dc24] FiniteDiff v2.3.2
  [f6369f11] ForwardDiff v0.10.10
  [37e2e3b7] ReverseDiff v1.2.0
  [9f7883ad] Tracker v0.2.7
  [e88e6eb3] Zygote v0.4.22
devmotion commented 4 years ago

BTW the error went unnoticed because of the special parameter choices in the tests (which I got rid of for a bunch of multivariate and matrixvariate distributions in https://github.com/TuringLang/Bijectors.jl/pull/116, but precisely not for the Dirichlet distribution). The choice Dirichlet(ones(3)) in the tests yields the uniform distribution on the probability simplex, so the gradient of the log probability with respect to the sample is just a vector of zeros. In this special case Zygote happens to return the correct result.
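To spell out why that test case masks the bug (this is just the standard Dirichlet density, restated for reference): on the simplex,

```latex
\log p(x \mid \alpha) = \sum_{i} (\alpha_i - 1) \log x_i - \log B(\alpha),
\qquad
\frac{\partial \log p}{\partial x_i} = \frac{\alpha_i - 1}{x_i} .
```

For $\alpha = (1,1,1)$ every $\alpha_i - 1 = 0$, so the log density is constant in $x$ and its gradient is the zero vector. Any adjoint whose sample gradient is a multiple of $\alpha - 1$ therefore passes a test built on Dirichlet(ones(3)), whether or not the division by $x$ is present.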

paschermayr commented 4 years ago

Thanks a lot! Is it possible to move this issue to the DistributionsAD repository manually, or shall I open a new issue there?

devmotion commented 4 years ago

Found the culprit: the custom adjoints in https://github.com/TuringLang/DistributionsAD.jl/blob/696fe18923a98c3060200455dd059d48d60c3a8d/src/multivariate.jl#L83 and in https://github.com/TuringLang/DistributionsAD.jl/blob/696fe18923a98c3060200455dd059d48d60c3a8d/src/multivariate.jl#L88 are incorrect: a ./ x is missing in the third element of the tuple returned by the pullback (the element corresponding to the gradient with respect to the sample).
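Concretely, the sample-dependent part of the simplex log density is $\sum_j (\alpha_j - 1)\log x_j$, so differentiating with respect to $x_i$ gives

```latex
\frac{\partial}{\partial x_i} \sum_j (\alpha_j - 1)\log x_j = \frac{\alpha_i - 1}{x_i} ,
```

i.e. the pullback's third element should be `Δ .* (alpha .- 1) ./ x`, not `Δ .* (alpha .- 1)`. That also explains the observed output above: with `Dirichlet(3, 3)` the broken adjoint returns `alpha .- 1 = [2.0, 2.0, 2.0]` for a unit seed, which is exactly what Zygote produced.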

devmotion commented 4 years ago

I'll open a PR with a fix.

torfjelde commented 4 years ago

Haha, I literally did the same thing and then suddenly the issue was gone :sweat_smile: @devmotion is too darn quick!

For reference: the following works

julia> using ZygoteRules

julia> ZygoteRules.@adjoint function DistributionsAD.simplex_logpdf(alpha, lmnB, x::AbstractVector)
           DistributionsAD.simplex_logpdf(alpha, lmnB, x), Δ -> (Δ .* log.(x), -Δ, Δ .* (alpha .- 1) ./ x)
       end

julia> Zygote.gradient(g, x)
([14.1467456104497, 14.360457585043832, 2.7802746168862287],)

julia> FiniteDiff.finite_difference_gradient(g, x)
3-element Array{Float64,1}:
 14.146745619193984
 14.36045759402944
  2.7802746169657384
torfjelde commented 4 years ago

@devmotion Let me know when the PR is up, and I'll review asap :+1: