gaurav-arya / StochasticAD.jl

Research package for automatic differentiation of programs containing discrete randomness.
MIT License

One-hot Encoding #69

Open · AlCap23 opened this issue 1 year ago

AlCap23 commented 1 year ago

Hi there!

I've been trying to implement a one-hot encoding using StochasticAD. So far, I've failed 🥲.

I think it essentially boils down to this TODO in the src. After tinkering for a while, I've decided to ask for help, since I didn't come up with a good solution.
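
Roughly, the pattern I'm after is differentiating through a one-hot encoding of a categorical sample. A sketch of the shape (not runnable: onehot here is hypothetical, it's exactly the triple-compatible piece I don't know how to write; softmax/logsoftmax are the usual definitions, e.g. from NNlib):

using StochasticAD, Distributions
using NNlib: softmax, logsoftmax

f(θ) = begin
    i = rand(Categorical(softmax(θ)))  # discrete sample; a stochastic triple under StochasticAD
    v = onehot(i, length(θ))           # hypothetical helper: a one-hot that keeps i's perturbations
    v' * logsoftmax(θ)
end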

Cheers!

gaurav-arya commented 1 year ago

Hey! Is it possible to provide a minimal example that you can't differentiate?

AlCap23 commented 1 year ago

Hey! Sorry for being dormant. I think this might work as an MWE (and maybe it works in general, I still have to test).

using Revise
using StochasticAD
using Distributions
using LinearAlgebra  # for the vector adjoint in v' * _logsoftmax(θ)

# Simple stochastic program

struct OneHot{T, K} <: AbstractVector{T}
    n::Int  # length of the encoded vector
    k::K    # index of the hot entry
    val::T  # value stored at the hot entry
end

# Strip the stochastic part off the index, and shift `val` so that its primal
# is one while keeping any perturbations it carries.
OneHot(n::Int, k::K, val::T = one(K)) where {T, K} = OneHot{T, K}(n, StochasticAD.value(k), val - StochasticAD.value(val) + 1)

Base.size(x::OneHot) = (x.n,)

Base.getindex(x::OneHot{T}, i::Int) where T = (x.k == i ? x.val : zero(T))

Base.argmax(x::OneHot) = x.k
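
# Quick sanity check with plain Ints: OneHot behaves like an ordinary one-hot vector
@assert OneHot(3, 2) == [0, 1, 0]
@assert argmax(OneHot(3, 2)) == 2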

# Numerically stable softmax / log-softmax helpers
_softmax(x) = begin
    y = exp.(x .- maximum(x))
    y ./ sum(y)
end

_logsoftmax(x) = begin
    y = x .- maximum(x)
    y .- log(sum(exp, y))
end

f(θ) = begin
    id = rand(Categorical(_softmax(θ)))  # discrete sample; a stochastic triple when differentiated
    @info id
    v = OneHot(length(θ), id, id)  # one-hot of the sample; the derivative info lives in `val`
    v' * _logsoftmax(θ)  # inner product picks out the id-th log-probability
end

θ = randn(3)
f(θ)

m = StochasticModel(f, θ)

stochastic_gradient(m) # Returns a gradient, still have to check if it finds the right value though
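
One way I plan to check it, sketched below: since the constructor pins `val`'s primal to one, the primal of f is _logsoftmax(θ)[id], so the target should be the gradient of E[f(θ)] = sum(p .* log.(p)) with p = softmax(θ), which ForwardDiff can compute exactly. I'm assuming here that stochastic_gradient(m) returns the estimate in the model's p field:

using Statistics, ForwardDiff

# Monte Carlo average of many single-sample gradient estimates
# (assumption: the gradient lands in the .p field of the returned model)
mean_grad = mean(stochastic_gradient(m).p for _ in 1:100_000)

# exact gradient of the closed-form expectation E[f(θ)] = sum_k p_k * log(p_k)
exact = ForwardDiff.gradient(θ -> sum(_softmax(θ) .* _logsoftmax(θ)), θ)

isapprox(mean_grad, exact; rtol = 0.05)  # hopefully true if the estimator is unbiased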