Research package for automatic differentiation of programs containing discrete randomness.
Problems with Categoricals and AbstractVectors #38

AlCap23 commented 1 year ago

Probably not an issue now, but related to the discussion in #32.

When I try to compute the gradient with respect to the vector of probabilities of a categorical distribution, I run into an error:

ERROR: UndefVarError: right_nonzero not defined
  [1] δtoΔs(d::Categorical{Float64, Vector{Float64}}, val::Int64, δs::Vector{Float64}, Δs::StochasticAD.DictFIsBackend.DictFIs{Float64})
    @ StochasticAD ~/.julia/packages/StochasticAD/pyr3C/src/discrete_randomness.jl:168
  [2] rand(rng::Random._GLOBAL_RNG, d_st::Categorical{StochasticTriple{StochasticAD.Tag{typeof(g), Float64}, Float64, StochasticAD.DictFIsBackend.DictFIs{Float64}}, Vector{StochasticTriple{StochasticAD.Tag{typeof(g), Float64}, Float64, StochasticAD.DictFIsBackend.DictFIs{Float64}}}})
    @ StochasticAD ~/.julia/packages/StochasticAD/pyr3C/src/discrete_randomness.jl:189
  [3] rand
    @ ~/.julia/packages/Distributions/7iOJp/src/genericrand.jl:22 [inlined]
  [4] g(x::Vector{StochasticTriple{StochasticAD.Tag{typeof(g), Float64}, Float64, StochasticAD.DictFIsBackend.DictFIs{Float64}}})
    @ Main ~/.julia/dev/RLSR/benchmark/ode/predator_prey.jl:249
  [5] (::var"#59#60"{UnionAll, typeof(g), Vector{Float64}})(i::Int64)
    @ Main ~/.julia/dev/RLSR/benchmark/ode/predator_prey.jl:235
  [6] iterate
    @ ./generator.jl:47 [inlined]
  [7] _collect(c::Base.OneTo{Int64}, itr::Base.Generator{Base.OneTo{Int64}, var"#59#60"{UnionAll, typeof(g), Vector{Float64}}}, #unused#::Base.EltypeUnknown, isz::Base.HasShape{1})
    @ Base ./array.jl:807
  [8] collect_similar(cont::Base.OneTo{Int64}, itr::Base.Generator{Base.OneTo{Int64}, var"#59#60"{UnionAll, typeof(g), Vector{Float64}}})
    @ Base ./array.jl:716
  [9] map(f::Function, A::Base.OneTo{Int64})
    @ Base ./abstractarray.jl:2933
 [10] #stochastic_triple_vec#58
    @ ~/.julia/dev/RLSR/benchmark/ode/predator_prey.jl:234 [inlined]
 [11] derivative_contribution(f::Function, p::Vector{Float64}; backend::Type)
    @ Main ~/.julia/dev/RLSR/benchmark/ode/predator_prey.jl:241
 [12] top-level scope
    @ ~/.julia/dev/RLSR/benchmark/ode/predator_prey.jl:265

The MWE is:

using StochasticAD
using StochasticAD: StochasticTriple
using Distributions

# Triple whose perturbation ("d/dx" term) is 1
function stochastic_triple1(f, p::V; backend = StochasticAD.PrunedFIs) where {V}
    StochasticAD.StochasticTriple{StochasticAD.Tag{typeof(f), V}}(p, one(p), backend)
end

# Triple whose perturbation ("d/dx" term) is 0
function stochastic_triple0(f, p::V; backend = StochasticAD.PrunedFIs) where {V}
    StochasticAD.StochasticTriple{StochasticAD.Tag{typeof(f), V}}(p, zero(p), backend)
end

# ix: which index to set "d/dx" term to 1
function stochastic_triple_vec(f, p, ix; backend = StochasticAD.PrunedFIs)
    [i == ix ? stochastic_triple1(f, p[i]; backend = backend) : stochastic_triple0(f, p[i]; backend = backend) for i in 1:length(p)]
end

# Evaluate f once per input index, perturbing only that index
function stochastic_triple_vec(f::Function, p::P; backend = StochasticAD.PrunedFIs) where {P <: AbstractVector}
    map(eachindex(p)) do i
        f(stochastic_triple_vec(f, p, i, backend = backend))
    end
end

function StochasticAD.derivative_contribution(f::Function, p::P; backend = StochasticAD.PrunedFIs) where {P <: AbstractVector}
    p0 = similar(p, StochasticTriple)
    p0 .= stochastic_triple_vec(f, p, backend = backend)
end

# Softmax probabilities passed to Categorical as a vector of triples
g(x) = begin
    z = exp.(x) ./ sum(exp, x)
    sum(abs2, rand(Categorical(z)) .- 1.0)
end

# Sigmoid of the first entry, with the probability vector built explicitly
f(x) = begin
    z = inv.(1 .+ exp.(-x))
    sum(abs2, rand(Categorical([z[1]; 1 .- z[1]])) .- 1.0)
end

ps = randn(2)

derivative_contribution(f, ps) # Works 
derivative_contribution(g, ps) # Error 

I am using the latest release.


PS: Sorry to open up the issue this early, but I want to get rid of REINFORCE as fast as possible 😅. I am really excited about this package. If there is anything I can help with, let me know.

gaurav-arya commented 1 year ago

Thanks for the bug report! This issue is in fact independent of, and should be resolved by (your MWE looks to be resolved, but feel free to reopen if the problem is still persisting in some form)
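
For anyone checking an estimator against the MWE's `g`, the expectation has a closed form that can serve as a reference: with `z = softmax(x)` and `k ~ Categorical(z)`, `g` returns `(k - 1)^2`, so `E[g(x)] = sum_i z[i] * (i - 1)^2`. The sketch below uses plain Julia only and makes no claim about StochasticAD's internals. The helper names `softmax`, `expected_g`, `expected_g_gradient`, and `finite_difference_gradient` are hypothetical and introduced here purely for illustration.

```julia
# Closed-form reference for the MWE's g (hypothetical helpers, not part of StochasticAD).
softmax(x) = exp.(x) ./ sum(exp, x)

# E[(k - 1)^2] with k ~ Categorical(softmax(x)).
expected_g(x) = sum(softmax(x) .* (eachindex(x) .- 1) .^ 2)

# Analytic gradient: dE/dx_j = z_j * (c_j - sum_i z_i * c_i), where c_i = (i - 1)^2.
function expected_g_gradient(x)
    z = softmax(x)
    c = (eachindex(x) .- 1) .^ 2
    return z .* (c .- sum(z .* c))
end

# Central finite differences as an independent check of the analytic gradient.
function finite_difference_gradient(f, x; h = 1e-6)
    map(eachindex(x)) do j
        e = zeros(length(x))
        e[j] = h
        (f(x .+ e) - f(x .- e)) / (2h)
    end
end

x = [0.3, -1.2]
@show expected_g_gradient(x)
@show finite_difference_gradient(expected_g, x)
```

Averaging many samples of a working derivative estimator for `g` should approach `expected_g_gradient(x)`; because the softmax probabilities sum to one, the entries of that gradient sum to zero.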