paschermayr closed this issue 4 years ago
It seems this is a bug in DistributionsAD:
julia> using Distributions, DistributionsAD, Tracker, FiniteDiff, ForwardDiff, ReverseDiff, Zygote;
julia> g(x) = logpdf(Dirichlet(3, 3), x);
julia> x = (y = rand(3); y ./ sum(y))
3-element Array{Float64,1}:
0.3996881361942186
0.11886701124092888
0.48144485256485253
julia> Tracker.gradient(g, x)[1]
Tracked 3-element Array{Float64,1}:
5.0039013392885625
16.825526099467957
4.154162183571361
julia> ForwardDiff.gradient(g, x)
3-element Array{Float64,1}:
5.0039013392885625
16.825526099467957
4.154162183571361
julia> FiniteDiff.finite_difference_gradient(g, x)
3-element Array{Float64,1}:
5.003901339640391
16.82552611394734
4.154162183750874
julia> ReverseDiff.gradient(g, x)
3-element Array{Float64,1}:
5.0039013392885625
16.825526099467957
4.154162183571361
julia> Zygote.gradient(g, x)
([2.0, 2.0, 2.0],)
I used
[31c24e10] Distributions v0.23.4
[ced4e74d] DistributionsAD v0.6.0
[6a86dc24] FiniteDiff v2.3.2
[f6369f11] ForwardDiff v0.10.10
[37e2e3b7] ReverseDiff v1.2.0
[9f7883ad] Tracker v0.2.7
[e88e6eb3] Zygote v0.4.22
BTW, the error went unnoticed because of the special parameter choices in the tests (which I got rid of for a bunch of multivariate and matrixvariate distributions in https://github.com/TuringLang/Bijectors.jl/pull/116, but precisely not for the Dirichlet distribution). The choice Dirichlet(ones(3)) in the tests just yields the uniform distribution on the probability simplex, so the gradient of the log probability with respect to the sample is a vector of zeros. In this special case Zygote happens to yield the correct result.
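To see why the buggy adjoint slipped through the tests: the sample-gradient of the Dirichlet log-density is (alpha .- 1) ./ x, which vanishes identically when alpha = ones(3). A quick sketch (in Python, purely illustrative):

```python
# For alpha = ones(3) (the uniform distribution on the simplex), the
# x-dependent term of the log-density, sum((alpha_i - 1) * log x_i),
# vanishes, so the gradient w.r.t. the sample is identically zero --
# an adjoint missing the "/ x" factor still returns zeros here.
alpha = [1.0, 1.0, 1.0]
x = [0.2, 0.3, 0.5]  # any point on the probability simplex

grad = [(a - 1.0) / xi for a, xi in zip(alpha, x)]
print(grad)  # [0.0, 0.0, 0.0]
```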
Thanks a lot! Is it possible to move this issue to the DistributionsAD repository manually, or shall I open a new issue there?
Found the culprit: the custom adjoints in https://github.com/TuringLang/DistributionsAD.jl/blob/696fe18923a98c3060200455dd059d48d60c3a8d/src/multivariate.jl#L83 and https://github.com/TuringLang/DistributionsAD.jl/blob/696fe18923a98c3060200455dd059d48d60c3a8d/src/multivariate.jl#L88 are incorrect: a ./ x is missing in the third element of the tuple returned by the pullback (the gradient with respect to the sample).
I'll open a PR with a fix.
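For illustration, the missing factor can be checked numerically. The x-dependent part of the Dirichlet log-density is sum((alpha_i - 1) * log x_i), whose gradient is (alpha .- 1) ./ x; dropping the division leaves the constant vector alpha .- 1, which for alpha = 3 is exactly Zygote's wrong answer [2.0, 2.0, 2.0]. A sketch in Python (not the package's code):

```python
import math

# Same alpha and sample x as in the MWE above.
alpha = [3.0, 3.0, 3.0]
x = [0.3996881361942186, 0.11886701124092888, 0.48144485256485253]

def simplex_term(x):
    # x-dependent part of the Dirichlet log-density
    return sum((a - 1.0) * math.log(xi) for a, xi in zip(alpha, x))

# analytic gradient: (alpha - 1) / x
grad = [(a - 1.0) / xi for a, xi in zip(alpha, x)]

# central finite differences as an independent check
h = 1e-7
fd = []
for i in range(3):
    xp, xm = list(x), list(x)
    xp[i] += h
    xm[i] -= h
    fd.append((simplex_term(xp) - simplex_term(xm)) / (2 * h))

print(grad)  # matches the ForwardDiff/ReverseDiff/Tracker values above
print(fd)
```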
Haha, I literally did the same thing and then suddenly the issue was gone :sweat_smile: @devmotion is too darn quick!
For reference: the following works
julia> using ZygoteRules
julia> ZygoteRules.@adjoint function DistributionsAD.simplex_logpdf(alpha, lmnB, x::AbstractVector)
DistributionsAD.simplex_logpdf(alpha, lmnB, x), Δ -> (Δ .* log.(x), -Δ, Δ .* (alpha .- 1) ./ x)
end
julia> Zygote.gradient(g, x)
([14.1467456104497, 14.360457585043832, 2.7802746168862287],)
julia> FiniteDiff.finite_difference_gradient(g, x)
3-element Array{Float64,1}:
14.146745619193984
14.36045759402944
2.7802746169657384
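For completeness, all three pullback components can be sanity-checked against the definition simplex_logpdf(alpha, lmnB, x) = sum((alpha .- 1) .* log.(x)) - lmnB, which is what the adjoint above implies (an assumption here, not quoted from the package). A Python sketch with hypothetical values:

```python
import math

alpha = [3.0, 3.0, 3.0]
lmnB = 1.5               # arbitrary stand-in for the log multivariate beta term
x = [0.2, 0.3, 0.5]

def f(alpha, lmnB, x):
    # assumed definition of simplex_logpdf, inferred from the adjoint
    return sum((a - 1.0) * math.log(xi) for a, xi in zip(alpha, x)) - lmnB

# the three pullback components (with seed Delta = 1):
g_alpha = [math.log(xi) for xi in x]           # d/d(alpha_i) = log x_i
g_lmnB = -1.0                                  # d/d(lmnB)    = -1
g_x = [(a - 1.0) / xi for a, xi in zip(alpha, x)]  # d/d(x_i) = (alpha_i - 1) / x_i

# finite-difference spot checks
h = 1e-7
fd_lmnB = (f(alpha, lmnB + h, x) - f(alpha, lmnB - h, x)) / (2 * h)
ap, am = list(alpha), list(alpha)
ap[0] += h
am[0] -= h
fd_alpha0 = (f(ap, lmnB, x) - f(am, lmnB, x)) / (2 * h)

print(g_alpha, g_lmnB, g_x)
print(fd_lmnB, fd_alpha0)
```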
@devmotion Let me know when the PR is up, and I'll review asap :+1:
Hi guys,
I am using Bijectors v0.8.0, and updated all modules used below yesterday.
I am getting different gradient results depending on whether I use Zygote or Forward/ReverseDiff for the Dirichlet distribution in the logpdf_with_trans calculation; MWE below. I am unsure whether this should be posted to Bijectors, Distributions, DistributionsAD, or Zygote, but I started here :D. Apologies if this is not related to Bijectors.
Best regards, Patrick