gaurav-arya / StochasticAD.jl

Research package for automatic differentiation of programs containing discrete randomness.
MIT License
198 stars 16 forks source link

Problems with Categoricals and AbstractVectors #38

Closed AlCap23 closed 1 year ago

AlCap23 commented 1 year ago

Probably not an issue now, but related to the discussion in #32.

When I try to compute the gradient with respect to a vector of probabilities for the categorical distribution, I run into an error:

ERROR: UndefVarError: right_nonzero not defined
Stacktrace:
  [1] δtoΔs(d::Categorical{Float64, Vector{Float64}}, val::Int64, δs::Vector{Float64}, Δs::StochasticAD.DictFIsBackend.DictFIs{Float64})
    @ StochasticAD ~/.julia/packages/StochasticAD/pyr3C/src/discrete_randomness.jl:168
  [2] rand(rng::Random._GLOBAL_RNG, d_st::Categorical{StochasticTriple{StochasticAD.Tag{typeof(g), Float64}, Float64, StochasticAD.DictFIsBackend.DictFIs{Float64}}, Vector{StochasticTriple{StochasticAD.Tag{typeof(g), Float64}, Float64, StochasticAD.DictFIsBackend.DictFIs{Float64}}}})
    @ StochasticAD ~/.julia/packages/StochasticAD/pyr3C/src/discrete_randomness.jl:189
  [3] rand
    @ ~/.julia/packages/Distributions/7iOJp/src/genericrand.jl:22 [inlined]
  [4] g(x::Vector{StochasticTriple{StochasticAD.Tag{typeof(g), Float64}, Float64, StochasticAD.DictFIsBackend.DictFIs{Float64}}})
    @ Main ~/.julia/dev/RLSR/benchmark/ode/predator_prey.jl:249
  [5] (::var"#59#60"{UnionAll, typeof(g), Vector{Float64}})(i::Int64)
    @ Main ~/.julia/dev/RLSR/benchmark/ode/predator_prey.jl:235
  [6] iterate
    @ ./generator.jl:47 [inlined]
  [7] _collect(c::Base.OneTo{Int64}, itr::Base.Generator{Base.OneTo{Int64}, var"#59#60"{UnionAll, typeof(g), Vector{Float64}}}, #unused#::Base.EltypeUnknown, isz::Base.HasShape{1})
    @ Base ./array.jl:807
  [8] collect_similar(cont::Base.OneTo{Int64}, itr::Base.Generator{Base.OneTo{Int64}, var"#59#60"{UnionAll, typeof(g), Vector{Float64}}})
    @ Base ./array.jl:716
  [9] map(f::Function, A::Base.OneTo{Int64})
    @ Base ./abstractarray.jl:2933
 [10] #stochastic_triple_vec#58
    @ ~/.julia/dev/RLSR/benchmark/ode/predator_prey.jl:234 [inlined]
 [11] derivative_contribution(f::Function, p::Vector{Float64}; backend::Type)
    @ Main ~/.julia/dev/RLSR/benchmark/ode/predator_prey.jl:241
 [12] top-level scope
    @ ~/.julia/dev/RLSR/benchmark/ode/predator_prey.jl:265

The MWE is:

using StochasticAD
using StochasticAD: StochasticTriple
using Distributions

# Wrap the scalar `p` in a `StochasticTriple` tagged by the function `f` and
# by `p`'s own type, passing `one(p)` as the second constructor argument.
# NOTE(review): presumably that argument is the derivative seed (δ), so this
# marks `p` as the variable being differentiated; verify against StochasticAD.
function stochastic_triple1(f, p::V; backend = StochasticAD.PrunedFIs) where {V}
    TagType = StochasticAD.Tag{typeof(f), V}
    seed = one(p)
    return StochasticAD.StochasticTriple{TagType}(p, seed, backend)
end

# Wrap the scalar `p` in a `StochasticTriple` tagged by the function `f` and
# by `p`'s own type, passing `zero(p)` as the second constructor argument.
# NOTE(review): presumably that argument is the derivative seed (δ), so this
# treats `p` as a constant w.r.t. the differentiated variable; verify.
function stochastic_triple0(f, p::V; backend = StochasticAD.PrunedFIs) where {V}
    TagType = StochasticAD.Tag{typeof(f), V}
    seed = zero(p)
    return StochasticAD.StochasticTriple{TagType}(p, seed, backend)
end

  # ix: which index to set "d/dx" term to 1
function stochastic_triple_vec(f, p, ix; backend = StochasticAD.PrunedFIs)
  [i == ix ? stochastic_triple1(f, p[i]; backend = backend) : stochastic_triple0(f, p[i], backend = backend) for i in 1:length(p)]
end

# Evaluate `f` once per index of `p`. On pass `i`, the input vector comes from
# the three-argument `stochastic_triple_vec(f, p, i)`, which seeds only
# element `i`. Return the collected outputs of `f`, one per index of `p`.
function stochastic_triple_vec(
    f::Function, p::P; backend = StochasticAD.PrunedFIs
) where {P <: AbstractVector}
    seeded_eval(i) = f(stochastic_triple_vec(f, p, i; backend = backend))
    return map(seeded_eval, eachindex(p))
end

# Run `f` once per index of `p`, seeding only that index each time. Then
# apply `derivative_contribution` to each resulting output.
#
# The outputs of `stochastic_triple_vec` are broadcast over directly. The
# previous version first copied them into `similar(p, StochasticTriple)`.
# That buffer has an abstract (UnionAll) element type, and the copy was an
# extra allocation with no benefit.
#
# NOTE(review): this adds a method to `StochasticAD.derivative_contribution`
# for `(::Function, ::AbstractVector)`, neither of which this script owns
# (type piracy). That is tolerable in a throwaway MWE; prefer a locally named
# function in reusable code.
function StochasticAD.derivative_contribution(
    f::Function, p::P; backend = StochasticAD.PrunedFIs
) where {P <: AbstractVector}
    outputs = stochastic_triple_vec(f, p; backend = backend)
    return derivative_contribution.(outputs)
end

# Turn `x` into category probabilities with a softmax, draw a category index
# from `Categorical`, and return the draw's squared distance from 1.0.
function g(x)
    probs = exp.(x) ./ sum(exp, x)
    draw = rand(Categorical(probs))
    return sum(abs2, draw .- 1.0)
end

# Apply the logistic sigmoid elementwise to `x`. Only the first component is
# used: it forms a two-category distribution `[σ(x[1]), 1 - σ(x[1])]`. Draw
# from it and return the draw's squared distance from 1.0.
function f(x)
    z = inv.(1 .+ exp.(-x))
    p1 = z[1]
    draw = rand(Categorical([p1; 1 .- p1]))
    return sum(abs2, draw .- 1.0)
end

# Random two-element parameter vector shared by both test functions.
ps = randn(2)

derivative_contribution(f, ps) # Works: probability vector built from a single sigmoid of ps[1]
derivative_contribution(g, ps) # Error as reported: UndefVarError (`right_nonzero`) inside StochasticAD's Categorical handling; reported fixed by PR #39

I am using the latest release.

Cheers!

PS: Sorry to open up the issue this early, but I want to get rid of REINFORCE as fast as possible 😅. I am really excited about this package. If there is anything I can help with, let me know.

gaurav-arya commented 1 year ago

Thanks for the bug report! This issue is in fact independent of https://github.com/gaurav-arya/StochasticAD.jl/issues/32 and should be resolved by https://github.com/gaurav-arya/StochasticAD.jl/pull/39. Your MWE appears to be resolved, but feel free to reopen if the problem persists in some form.