Closed: gaurav-arya closed this issue 3 months ago.
It is quite common for other packages (e.g. `Distributions.jl`) to overload `Random.rand` with parameter-dependent distributions, for example:
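```julia
using Enzyme
import Random

# N(x, 1): a normal distribution with mean x and unit variance
struct MyDistribution
    x::Float64
end

# Draw from the caller-supplied rng rather than the global one
Random.rand(rng::Random.AbstractRNG, d::MyDistribution) = d.x + randn(rng)
Random.rand(d::MyDistribution) = rand(Random.default_rng(), d)

@show autodiff(Enzyme.Reverse, x -> rand(MyDistribution(x)), Active, Active(1.0))
# ((0.0,),), although d/dx of (x + noise) is 1
```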
or even
```julia
# A point mass at x: every draw returns x itself
struct DiracDelta
    x::Float64
end

Random.rand(rng::Random.AbstractRNG, d::DiracDelta) = d.x
Random.rand(d::DiracDelta) = rand(Random.default_rng(), d)

@show autodiff(Enzyme.Reverse, x -> rand(DiracDelta(x)), Active, Active(1.0))
# ((0.0,),), although the function is the identity in x
```
Zero gradients are returned because the outer `rand` call is marked inactive, so Enzyme never differentiates through the user-defined method:
https://github.com/EnzymeAD/Enzyme.jl/blob/1bf16f8217f2f0e516666f5dff2deb27a653302d/src/internal_rules.jl#L69
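As a sanity check that does not depend on Enzyme, the sketch below estimates the derivative of `rand(MyDistribution(x))` by central finite differences. It reseeds an identical RNG on both sides so that the same noise draw cancels. This is a minimal sketch assuming Julia 1.7 or later (for `Random.Xoshiro`); `fd_grad` is a hypothetical helper, not part of Enzyme or `Random`. The estimate comes out at approximately 1.0, which is the gradient one would expect instead of 0.0.

```julia
import Random

# N(x, 1), sampled with the caller-supplied rng
struct MyDistribution
    x::Float64
end
Random.rand(rng::Random.AbstractRNG, d::MyDistribution) = d.x + randn(rng)

# Hypothetical helper: central finite difference in x.
# Both evaluations use a fresh RNG with the same seed, so they draw the
# same noise value and only the dependence on x remains.
function fd_grad(x; h = 1e-6, seed = 42)
    fplus  = rand(Random.Xoshiro(seed), MyDistribution(x + h))
    fminus = rand(Random.Xoshiro(seed), MyDistribution(x - h))
    return (fplus - fminus) / (2h)
end

@show fd_grad(1.0)  # approximately 1.0
```

Fixing the noise this way follows the reparameterization view: for a fixed draw `ε`, a sample is `x + ε`, whose derivative in `x` is 1 regardless of `ε`.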
Closed by #1388