Open Uroc327 opened 2 years ago
Is it possible to implement rules for Zygote as well?
It is possible, e.g. (just a quick hack):
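using ChainRulesCore, Distributions

function ChainRulesCore.rrule(::typeof(rand), d::Normal{T}, n::Integer) where {T}
    vals = rand(d, n)
    function rand_pullback(rand_bar)
        Δ = unthunk(rand_bar)
        # Reparameterization: vals = μ .+ σ .* z, so ∂vals/∂μ = 1 and ∂vals/∂σ = z
        z = (vals .- d.μ) ./ d.σ
        d_bar = Tangent{Normal{T}}(; μ=sum(Δ), σ=sum(Δ .* z))
        return NoTangent(), d_bar, NoTangent()
    end
    return vals, rand_pullback
end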
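The gradient behind a rule like this comes from the reparameterization x = μ + σz with z ~ N(0, 1): each sample contributes 1 to the μ component and z to the σ component, weighted by the incoming cotangent Δ. For a loss that simply sums the samples, Δ is all ones and the μ component reduces to n. A minimal sketch that checks these formulas against central finite differences, using only the Random standard library (the names `reparam`, `normal_pullback`, and `loss` are hypothetical helpers for this illustration, not part of Distributions or ChainRules):

```julia
using Random

# Reparameterized draw: x = μ .+ σ .* z, with the noise z ~ N(0, 1) held fixed.
reparam(μ, σ, z) = μ .+ σ .* z

# Analytic pullback for an upstream cotangent Δ (same length as z).
normal_pullback(Δ, z) = (dμ = sum(Δ), dσ = sum(Δ .* z))

z = randn(MersenneTwister(42), 5)   # fixed noise
μ, σ = 0.5, 2.0
w = [1.0, -2.0, 0.5, 3.0, 1.5]      # arbitrary weights; loss = sum(w .* x), so Δ = w

loss(μ, σ) = sum(w .* reparam(μ, σ, z))
analytic = normal_pullback(w, z)

# Central finite differences as an independent check.
h = 1e-6
fd_μ = (loss(μ + h, σ) - loss(μ - h, σ)) / (2h)
fd_σ = (loss(μ, σ + h) - loss(μ, σ - h)) / (2h)

println(isapprox(analytic.dμ, fd_μ; atol=1e-6))
println(isapprox(analytic.dσ, fd_σ; atol=1e-6))
```

Because the loss is linear in μ and σ, the finite differences agree with the analytic values up to rounding error.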
But I wonder why a dedicated rule is necessary at all. Looking at the definition of rand for Normal, this should be easily differentiable. I don't really see why Zygote returns nothing here.
It is due to https://github.com/JuliaDiff/ChainRules.jl/blob/158ca756ef99ccf3f1dde2e66b5855e8e68e0363/src/rulesets/Random/random.jl#L23-L25. Marking rand methods as non-differentiable in ChainRules (which is used by Zygote) is a deliberate design decision. These definitions were made more restrictive so that they explicitly do not cover e.g. rand(Normal()) (see e.g. the discussion in https://github.com/JuliaDiff/ChainRules.jl/issues/262). However, I assume the problem with the example above is that sampling multiple samples in Distributions is done via pre-allocating the output and sampling with rand!, and that one is marked non-differentiable regardless of the type of the arguments. See also https://github.com/JuliaDiff/ChainRules.jl/issues/603.
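One way to sidestep rand! entirely, without adding a rule, is to draw the standard-normal noise outside the differentiated function and apply the location-scale transform by hand. The following is only a sketch under that assumption; `loss` is a hypothetical example function, and the noise `z` is treated as a constant by Zygote because it is captured from the enclosing scope:

```julia
using Random
using Zygote

# Noise is drawn once, outside the function being differentiated.
z = randn(MersenneTwister(1), 5)

# Samples from Normal(μ, σ) written as a differentiable transform of z.
loss(μ, σ) = sum(μ .+ σ .* z)

dμ, dσ = Zygote.gradient(loss, 0.0, 1.0)

println(dμ ≈ length(z))   # each of the 5 samples contributes 1 to the μ gradient
println(dσ ≈ sum(z))      # each sample contributes its noise value to the σ gradient
```

This keeps all mutation and random number generation out of the code that Zygote has to differentiate.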
sampling multiple samples in Distributions is done via pre-allocating the output and sampling with rand!
Ah, I missed that. Then everything makes sense.
Sort of a meta question: Where would one put AD rules for such things? Here in the ext/DistributionsChainRulesCoreExt module?
I also found TuringLang/DistributionsAD.jl#123, but I'm not sure whether that package is still meant to be used.
It seems that only ForwardDiff is supported for ADing sampling. Is it possible to implement rules for Zygote as well?