JuliaStats / Distributions.jl

A Julia package for probability distributions and associated functions.

Compatibility with Zygote AD #1516

Open Uroc327 opened 2 years ago

Uroc327 commented 2 years ago

It seems that only ForwardDiff is supported for differentiating through sampling. Is it possible to implement rules for Zygote as well?

julia> ForwardDiff.derivative(x -> sum(rand(Normal(x, 10), 10)), 0.)
10.0

julia> Zygote.gradient(x -> sum(rand(Normal(x, 10), 10)), 0.)
(nothing,)

trahflow commented 1 year ago

Is it possible to implement rules for Zygote as well?

It is possible, e.g. (just a quick hack):

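using ChainRulesCore, Distributions

# Reverse rule for drawing n samples from a Normal. Each sample is
# x_i = μ + σ * z_i with z_i ~ N(0, 1), so ∂x_i/∂μ = 1 and ∂x_i/∂σ = z_i.
function ChainRulesCore.rrule(::typeof(rand), d::Normal{T}, n::Integer) where {T}
    vals = rand(d, n)
    function rand_pullback(rand_bar)
        ȳ = unthunk(rand_bar)        # cotangent of the n samples
        z = (vals .- d.μ) ./ d.σ     # recover the standard-normal noise
        d_bar = Tangent{Normal{T}}(; μ=sum(ȳ), σ=sum(ȳ .* z))
        return NoTangent(), d_bar, NoTangent()
    end
    return vals, rand_pullback
end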
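Where the pullback comes from: assuming each draw is generated as x_i = μ + σ z_i from standard-normal noise z_i (i = 1, ..., n), and writing ȳ_i for the incoming cotangent of sample i, the chain rule gives

$$
\bar{\mu} = \sum_{i=1}^{n} \bar{y}_i \frac{\partial x_i}{\partial \mu} = \sum_{i=1}^{n} \bar{y}_i,
\qquad
\bar{\sigma} = \sum_{i=1}^{n} \bar{y}_i \frac{\partial x_i}{\partial \sigma} = \sum_{i=1}^{n} \bar{y}_i z_i = \sum_{i=1}^{n} \bar{y}_i \, \frac{x_i - \mu}{\sigma}.
$$

For the `sum` loss in the issue, ȳ_i = 1 for all i, so μ̄ = n and σ̄ = Σ_i (x_i − μ)/σ; for any other loss the pullback has to weight each term by the incoming cotangent.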

But I wonder why a dedicated rule is necessary at all. Looking at the definition of rand for Normal, this should be easily differentiable, so I don't really see why Zygote returns nothing here.

devmotion commented 1 year ago

It is due to https://github.com/JuliaDiff/ChainRules.jl/blob/158ca756ef99ccf3f1dde2e66b5855e8e68e0363/src/rulesets/Random/random.jl#L23-L25. Marking rand methods as non-differentiable in ChainRules (which Zygote uses) is a deliberate design decision. Those rules were made more restrictive so that they explicitly do not cover, e.g., rand(Normal()) (see the discussion in https://github.com/JuliaDiff/ChainRules.jl/issues/262). However, I assume the problem with the example above is that Distributions draws multiple samples by pre-allocating the output and filling it with rand!, and rand! is marked non-differentiable regardless of the argument types. See also https://github.com/JuliaDiff/ChainRules.jl/issues/603.
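One way around the non-differentiable rand! path is to draw the standard-normal noise separately and build the samples with a plain broadcast, which Zygote can differentiate. A minimal sketch, assuming Zygote is installed; `sample_normal` is a hypothetical helper for illustration, not part of Distributions.jl:

```julia
using Random, Zygote

# Hypothetical helper (not a Distributions.jl API): n draws from Normal(μ, σ)
# written as an affine function of standard-normal noise z, so the result is
# differentiable in μ and σ without any custom rule.
sample_normal(μ, σ, z) = μ .+ σ .* z

rng = MersenneTwister(42)
z = randn(rng, 10)  # noise drawn outside the differentiated closure

# Same loss as in the issue, differentiated with respect to both parameters.
dμ, dσ = Zygote.gradient((μ, σ) -> sum(sample_normal(μ, σ, z)), 0.0, 10.0)
# dμ is 10.0 (one unit per sample) and dσ equals sum(z), up to rounding
```

Drawing `z` outside the closure keeps the noise fixed between gradient calls; whether that is acceptable depends on the use case.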

trahflow commented 1 year ago

sampling multiple samples in Distributions is done via pre-allocating the output and sampling with rand!

Ah, I missed that. Then everything makes sense.

Sort of a meta question: where would one put AD rules for such things? Here, in the ext/DistributionsChainRulesCoreExt module? I also found TuringLang/DistributionsAD.jl#123, but I'm not sure whether that package is still meant to be used.