gaurav-arya / StochasticAD.jl

Research package for automatic differentiation of programs containing discrete randomness.
MIT License

Zygote.gradient returns nothing when differentiating vectorized sampling. #71

Open arnauqb opened 1 year ago

arnauqb commented 1 year ago

Hi there, first of all, thanks for this great package. I'm trying to understand how StochasticAD integrates with Zygote and I ran into a problem when trying to differentiate a "dot" operation. Here is the code:

using StochasticAD, Distributions, Zygote

function sample_n_bernoullis(p, n)
    return sum([rand(Bernoulli(p)) for i in 1:n])
end

function sample_n_bernoullis_vectorized(p, n)
    probs = p * ones(n)
    return sum(rand.(Bernoulli.(probs)))
end

n = 100
p = 0.5

derivative_estimate(p -> sample_n_bernoullis(p, n), p) # this works
derivative_estimate(p -> sample_n_bernoullis_vectorized(p, n), p) # this works

Zygote.gradient(p -> sample_n_bernoullis(p, n), p) # this works
Zygote.gradient(p -> sample_n_bernoullis_vectorized(p, n), p) # this doesn't work (returns nothing)

Are vectorized operations not supported with Zygote?

Thanks!

gaurav-arya commented 1 year ago

Hey, apologies for the late reply. I've been looking into this, and it seems to be an interaction between Zygote's broadcast differentiation machinery and the fact that rand ∘ Bernoulli produces a boolean. Things work fine for a geometric distribution, whose samples are integers:

function sample_n_geometrics_vectorized(p, n)
    probs = p * ones(n)
    return sum(rand.(Geometric.(probs)))
end
Zygote.gradient(p -> sample_n_geometrics_vectorized(p, n), p)  # works fine
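
The element types involved can be checked without any AD at all. The snippet below is a minimal sketch that only needs Distributions; the variable names (`bernoulli_draws`, `geometric_draws`, `scaled_draws`) are made up for illustration and do not come from either package.

```julia
using Distributions

probs = fill(0.5, 3)

# Broadcasting rand over Bernoulli distributions yields boolean samples.
bernoulli_draws = rand.(Bernoulli.(probs))

# Broadcasting rand over Geometric distributions yields integer samples.
geometric_draws = rand.(Geometric.(probs))

# Multiplying each boolean sample by 1 promotes it to an integer.
scaled_draws = (x -> rand(Bernoulli(x)) * 1).(probs)

@show eltype(bernoulli_draws)  # Bool
@show eltype(geometric_draws)  # an integer type
@show eltype(scaled_draws)     # an integer type
```

This only shows what the broadcast returns. It does not by itself show where Zygote drops the gradient.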

Until this is fixed, there are a couple of workarounds: make sure the broadcasted function doesn't return a boolean, or rewrite the broadcast using map:

function sample_n_bernoullis_vectorized_1(p, n)
    probs = p * ones(n)
    return (x -> rand(Bernoulli(x)) * 1).(probs) |> sum
end

function sample_n_bernoullis_vectorized_2(p, n)
    probs = p * ones(n)
    return map(rand ∘ Bernoulli, probs) |> sum
end

Zygote.gradient(p -> sample_n_bernoullis_vectorized_1(p, n), p)  # works fine
Zygote.gradient(p -> sample_n_bernoullis_vectorized_2(p, n), p)  # works fine
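
One way to sanity-check a workaround is to average many gradient samples and compare against the exact value. The sum of n Bernoulli(p) draws has expectation n * p, so its derivative with respect to p is n. The sketch below assumes the map-based version behaves as described above; `sample_n_bernoullis_map`, `nsamples`, and `grads` are illustrative names, and the sample count is an arbitrary choice.

```julia
using StochasticAD, Distributions, Zygote
using Statistics: mean

# Same as the map-based workaround, redefined here so the snippet is self-contained.
function sample_n_bernoullis_map(p, n)
    probs = p * ones(n)
    return sum(map(rand ∘ Bernoulli, probs))
end

n = 100
p = 0.5
nsamples = 1_000  # arbitrary; more samples give a tighter average

# Each call draws fresh randomness, so each gradient is one noisy estimate.
grads = [Zygote.gradient(q -> sample_n_bernoullis_map(q, n), p)[1] for _ in 1:nsamples]

@show mean(grads)  # compare against the exact derivative, n
```

If the estimator is unbiased, the printed mean should land near n. Individual samples are noisy, so the value will differ from run to run.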

Thanks for catching this!