gaurav-arya / StochasticAD.jl

Research package for automatic differentiation of programs containing discrete randomness.
MIT License

PMF of Bernoullis #77

Status: Open. FHoltorf opened this issue 1 year ago.

FHoltorf commented 1 year ago

Hey!

I am trying to do variational inference (VI) with models involving discrete random variables (RVs). For that purpose it would be very handy to get derivative estimators for the PMFs of Bernoulli (and other discrete) RVs. Consider the following toy example:

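```julia
using StochasticAD, Distributions

# Evaluates the Bernoulli PMF via Distributions.jl
function func(p)
    x = rand(Bernoulli(p))
    pdf(Bernoulli(p), x)
end

# Same quantity, written arithmetically as p^x * (1 - p)^(1 - x)
function func_alt(p)
    x = rand(Bernoulli(p))
    p^x * (1 - p)^(1 - x)
end
```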

It seems that I can only get derivative estimators for `func_alt`. Propagating derivative information through `func` appears to fail because of how `pdf(::Bernoulli, ::Bool/Real)` is implemented.
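To make the difference between the two functions concrete, here is a small dependency-free sketch. The names `pmf_branch` and `pmf_arith` are hypothetical; `pmf_branch` only imitates the branching style of Distributions.jl's Bernoulli `pdf` (it is not a copy of the library's source), while `pmf_arith` is the expression used in `func_alt`. On the support {0, 1} both return the same values, so the difference lies only in how `x` enters the computation.

```julia
# Branching form (assumed to resemble Distributions.jl's style, not copied):
# the output is chosen by a Boolean test on x.
pmf_branch(p, x) = x == 1 ? p : 1 - p

# Arithmetic form from func_alt: x enters only through ^ and -.
pmf_arith(p, x) = p^x * (1 - p)^(1 - x)

# The two forms agree on the Bernoulli support {0, 1}.
for p in (0.1, 0.5, 0.7), x in (0, 1)
    @assert pmf_branch(p, x) ≈ pmf_arith(p, x)
end
```

The arithmetic form depends on `x` only through ordinary numeric operations, which is presumably why a stochastic triple can flow through `func_alt`; the branching form instead picks its output with a test on `x`, the kind of construct the reply below identifies as the obstacle.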

Now my questions:

  1. Am I using StochasticAD incorrectly?
  2. If not, would it be easy to support propagating stochastic triples through pdf(::Bernoulli, ::Bool/Real) (and perhaps through the equivalent methods for other distributions with discrete support, where I assume similar problems arise)?

Thanks! Flemming

gaurav-arya commented 1 year ago

The issue is that, unfortunately, stochastic triples cannot propagate through the ternary operator (`cond ? a : b`) used in Distributions.jl's implementation of the Bernoulli PMF. You can work around this by overloading `Distributions.pdf` to catch stochastic-triple inputs and pass them to the experimental (undocumented) `StochasticAD.propagate` interface:

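```julia
using StochasticAD, Distributions

# Make Bernoulli's fields traversable so propagate can find triples nested inside d
using Functors; @functor Bernoulli

# Register an overload of pdf that routes stochastic-triple inputs through propagate
Distributions.pdf(d::Bernoulli, x::StochasticAD.StochasticTriple) = StochasticAD.propagate(pdf, d, x; keep_deltas = Val{true}())

derivative_estimate(func, 0.7)  # func as defined in the original post
```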

This should work on the about-to-be-released 0.1.14. (Performance may not be ideal; StochasticAD.propagate is still experimental functionality.)
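As a sanity check on the estimates (a worked calculation, not part of the original reply): for $X \sim \mathrm{Bernoulli}(p)$, `func(p)` returns $p$ when $X = 1$ and $1 - p$ when $X = 0$, so its expectation and exact derivative are:

```latex
\mathbb{E}[\mathrm{func}(p)] = p \cdot p + (1-p)(1-p) = p^2 + (1-p)^2,
\qquad
\frac{d}{dp}\,\mathbb{E}[\mathrm{func}(p)] = 2p - 2(1-p) = 4p - 2,
\qquad
\left.\frac{d}{dp}\,\mathbb{E}[\mathrm{func}(p)]\right|_{p=0.7} = 0.8 .
```

Since each call to `derivative_estimate` returns a single stochastic sample, averaging many calls (for example `mean(derivative_estimate(func, 0.7) for _ in 1:10_000)`, with `mean` from the Statistics standard library) should come out close to 0.8 if the overload behaves as intended. `func_alt` has the same expectation, so it should target the same value.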

Let me know if you have any questions! In any case, let's leave this issue open until this works out of the box.

(Edit: added keep_deltas = Val{true}() to the propagate call; the previous version was not correct.)