EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License

Overly broad inactive marker for Random.rand #1387

Closed. gaurav-arya closed this issue 3 months ago

gaurav-arya commented 3 months ago

It is quite common for other packages (e.g. Distributions.jl) to add Random.rand methods for distribution types whose samples depend on differentiable parameters, for example:

using Enzyme
import Random

# N(x, 1) 
struct MyDistribution
    x::Float64
end

Random.rand(rng::Random.AbstractRNG, d::MyDistribution) = d.x + randn(rng)  # draw the noise from the rng that was passed in
Random.rand(d::MyDistribution) = rand(Random.default_rng(), d)

@show autodiff(Enzyme.Reverse, x -> rand(MyDistribution(x)), Active, Active(1.0)) # ((0.0,),)

or even, with no randomness at all:

struct DiracDelta 
    x::Float64
end

Random.rand(rng::Random.AbstractRNG, d::DiracDelta) = d.x
Random.rand(d::DiracDelta) = rand(Random.default_rng(), d)

@show autodiff(Enzyme.Reverse, x -> rand(DiracDelta(x)), Active, Active(1.0)) # ((0.0,),)

In both cases the correct derivative with respect to x is 1, but zero gradients are returned because the outer rand call is marked inactive:

https://github.com/EnzymeAD/Enzyme.jl/blob/1bf16f8217f2f0e516666f5dff2deb27a653302d/src/internal_rules.jl#L69
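A quick way to confirm the expected value without involving Enzyme is a central finite difference with a fixed seed, so that both evaluations see the same noise. This is only a sketch: `fd_derivative` is a hypothetical helper written for this illustration, and the step size, seed, and choice of `MersenneTwister` are arbitrary assumptions.

```julia
import Random

# Same distributions as in the report, redefined so the block is self-contained.
struct MyDistribution
    x::Float64
end

struct DiracDelta
    x::Float64
end

Random.rand(rng::Random.AbstractRNG, d::MyDistribution) = d.x + randn(rng)
Random.rand(rng::Random.AbstractRNG, d::DiracDelta) = d.x

# Hypothetical helper (not part of Enzyme): central finite difference of
# x -> rand(rng, D(x)). A fresh RNG with the same seed is built for each
# evaluation, so the random draw is identical and cancels in the difference.
function fd_derivative(D, x; h = 1e-3, seed = 0)
    f(y) = rand(Random.MersenneTwister(seed), D(y))
    return (f(x + h) - f(x - h)) / (2h)
end

@show fd_derivative(MyDistribution, 1.0)
@show fd_derivative(DiracDelta, 1.0)
```

Both values should come out approximately 1 (up to floating-point rounding), which is what the autodiff calls above would be expected to return in place of 0.0.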

gaurav-arya commented 3 months ago

Closed by #1388