gaurav-arya / StochasticAD.jl

Research package for automatic differentiation of programs containing discrete randomness.
MIT License
198 stars, 16 forks

gradient of function with vector input #32

Closed: slwu89 closed this 1 year ago

slwu89 commented 1 year ago

Hi! Thanks for the very cool package/method. I really enjoyed reading the paper.

I have a question about taking the derivative of a function whose input p is a vector. In the example below, I use a deterministic function to keep things simple. I was trying to get the same behavior as ForwardDiff.gradient on this simple function but couldn't find a good way to do it.

using ForwardDiff, StochasticAD
f(x) = (x[1]*x[2]*sin(x[3])+ exp(x[1]*x[2]))/x[3]
x = [1,2,π/2]

ForwardDiff.gradient(f,x)

# bad way, not generic
[derivative_estimate(p -> f([p,x[2],x[3]]), x[1]), derivative_estimate(p -> f([x[1],p,x[3]]), x[2]), derivative_estimate(p -> f([x[1],x[2],p]), x[3])]

# naive way doesn't work: copy(x) is a Vector{Float64}, so setindex! cannot store the StochasticTriple p in it
# [derivative_estimate(p -> f(setindex!(copy(x), p, i)), x[i]) for i in 1:3]
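A generic version of the per-coordinate approach can avoid `setindex!` on the `Float64` array by rebuilding the input for each coordinate. That way, the rebuilt vector can hold the differentiated element. The following is a minimal sketch only. The helper name `coordinatewise_gradient` is hypothetical, and the scalar derivative routine is passed in as an argument. The sketch assumes that routine has a `(function, point)` signature like `derivative_estimate` and that `f` accepts a vector with a widened element type. For a self-contained check, it uses a central finite difference.

```julia
# Hypothetical helper (not part of StochasticAD): gradient of f at x built
# from one scalar derivative per coordinate. `derivative(h, p)` is any scalar
# differentiation routine, e.g. StochasticAD.derivative_estimate.
function coordinatewise_gradient(derivative, f, x::AbstractVector)
    map(eachindex(x)) do i
        # Rebuild the input with coordinate i replaced by p. The comprehension
        # widens its element type when p's type differs from eltype(x), so p
        # can be a non-Float64 number while the other entries stay as-is.
        h(p) = f([j == i ? p : x[j] for j in eachindex(x)])
        derivative(h, x[i])
    end
end

f(x) = (x[1] * x[2] * sin(x[3]) + exp(x[1] * x[2])) / x[3]
x = [1.0, 2.0, π / 2]

# Self-contained check: central finite difference as the scalar derivative
central_diff(h, p; ε = 1e-6) = (h(p + ε) - h(p - ε)) / (2ε)
coordinatewise_gradient(central_diff, f, x)
```

With StochasticAD loaded, the corresponding call would be `coordinatewise_gradient(derivative_estimate, f, x)`, which still performs one derivative evaluation per coordinate.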

I ended up peeking at the source to see how to set the infinitesimal component only for the element of the vector we are differentiating with respect to. This recovers the same behavior as ForwardDiff.gradient:

# Seed p with derivative 1 (the coordinate being differentiated)
function stochastic_triple1(f, p::V; backend = StochasticAD.PrunedFIs) where {V}
  StochasticAD.StochasticTriple{StochasticAD.Tag{typeof(f), V}}(p, one(p), backend)
end

# Seed p with derivative 0 (a coordinate held constant)
function stochastic_triple0(f, p::V; backend = StochasticAD.PrunedFIs) where {V}
  StochasticAD.StochasticTriple{StochasticAD.Tag{typeof(f), V}}(p, zero(p), backend)
end

# ix: which index gets its "d/dx" term set to 1; all other indices get 0
function stochastic_triple_vec(f, p, ix)
  [i == ix ? stochastic_triple1(f, p[i]) : stochastic_triple0(f, p[i]) for i in eachindex(p)]
end

# One evaluation of f per coordinate; each yields the partial derivative w.r.t. x[i]
[derivative_contribution(f(stochastic_triple_vec(f, x, i))) for i in eachindex(x)]
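Seeding `one(p)` at index `ix` and `zero(p)` elsewhere recovers a partial derivative for the same reason as in ordinary forward-mode AD. Each forward pass propagates exactly one directional derivative, so n passes give the n partials. The following minimal sketch illustrates only that seeding idea. It uses a hypothetical toy dual-number type `ToyDual`, which is not StochasticAD's `StochasticTriple` and has no handling of discrete randomness.

```julia
# Hypothetical toy forward-mode number: a value plus one derivative component.
# Illustration only; this is not StochasticAD's StochasticTriple.
struct ToyDual
    val::Float64
    der::Float64
end

# Derivative rules for the operations used by f below
Base.:+(a::ToyDual, b::ToyDual) = ToyDual(a.val + b.val, a.der + b.der)
Base.:*(a::ToyDual, b::ToyDual) = ToyDual(a.val * b.val, a.der * b.val + a.val * b.der)
Base.:/(a::ToyDual, b::ToyDual) = ToyDual(a.val / b.val, (a.der * b.val - a.val * b.der) / b.val^2)
Base.sin(a::ToyDual) = ToyDual(sin(a.val), cos(a.val) * a.der)
Base.exp(a::ToyDual) = ToyDual(exp(a.val), exp(a.val) * a.der)

# Seed coordinate ix with derivative 1 and every other coordinate with 0,
# mirroring stochastic_triple1 / stochastic_triple0
seed(x, ix) = [ToyDual(x[i], i == ix ? 1.0 : 0.0) for i in eachindex(x)]

# One forward pass per coordinate; the output's derivative part is ∂f/∂x[i]
toy_gradient(f, x) = [f(seed(x, i)).der for i in eachindex(x)]

f(x) = (x[1] * x[2] * sin(x[3]) + exp(x[1] * x[2])) / x[3]
x = [1.0, 2.0, π / 2]
toy_gradient(f, x)
```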

I was curious whether this is how you'd recommend getting gradients for functions with vector input, or whether there is something better I've overlooked. This particular code is just a quick test, but if it looks useful in any way, I'm very happy to work on a PR!

gaurav-arya commented 1 year ago

I agree we should have this functionality! I don't think there's an easier solution right now than what you have.

Regarding your proposed solution, it looks good to me, although I do wonder if we could extend it from just vectors to something Functors-based. Feel free to make a draft PR with your current structure, and I can try to tweak it as necessary!

slwu89 commented 1 year ago

Ok thanks @gaurav-arya! I'll do that.

AlCap23 commented 1 year ago

Just a quick comment for people trying to use it with a ComponentVector:


using StochasticAD
using ComponentArrays
using StochasticAD: StochasticTriple
using Optimisers
ps = ComponentVector((; a = 0.1, b = 0.2, c = 0.7))

function stochastic_triple1(f, p::V; backend = StochasticAD.PrunedFIs) where {V}
    StochasticAD.StochasticTriple{StochasticAD.Tag{typeof(f), V}}(p, one(p), backend)
end

function stochastic_triple0(f, p::V; backend = StochasticAD.PrunedFIs) where {V}
    StochasticAD.StochasticTriple{StochasticAD.Tag{typeof(f), V}}(p, zero(p), backend)
end

# ix: which index gets its "d/dx" term set to 1; backend is forwarded to each seed
function stochastic_triple_vec(f, p, ix; backend = StochasticAD.PrunedFIs)
    [i == ix ? stochastic_triple1(f, p[i]; backend = backend) : stochastic_triple0(f, p[i]; backend = backend) for i in eachindex(p)]
end

function stochastic_triple_vec(f::Function, p::ComponentVector, ix::Int; backend = StochasticAD.PrunedFIs)
    # Same axes/field names as p, but able to hold StochasticTriples
    v_cp = similar(p, StochasticTriple)
    for i in eachindex(p)
        # Seed derivative 1 for index ix and 0 for every other index
        seed = i == ix ? stochastic_triple1 : stochastic_triple0
        v_cp[i] = seed(f, p[i]; backend = backend)
    end
    v_cp
end

# Evaluate f once per coordinate (seeding only that coordinate); returns the output triples
function stochastic_triple_vec(f::Function, p::P; backend = StochasticAD.PrunedFIs) where {P <: AbstractVector}
    map(eachindex(p)) do i
        f(stochastic_triple_vec(f, p, i, backend = backend))
    end
end

# Gradient estimate with the same structure as p (e.g. ComponentVector field names)
function StochasticAD.derivative_contribution(f::Function, p::P; backend = StochasticAD.PrunedFIs) where {P <: AbstractVector}
    p0 = similar(p, StochasticTriple)
    p0 .= stochastic_triple_vec(f, p, backend = backend)
    derivative_contribution.(p0)
end

g(x) = x.a^2 + (1-x.b)^2
@time stochastic_triple_vec(g, ps, 1)
@time stochastic_triple_vec(g, ps)
@time [derivative_contribution(g(stochastic_triple_vec(g, ps, i))) for i in 1:3]
@time derivative_contribution(g, ps)

opt = Adam()
state = Optimisers.setup(opt, ps)
for i in 1:10_000
    global state, ps
    # update! returns the new optimiser state and the updated parameters
    state, ps = Optimisers.update!(state, ps, derivative_contribution(g, ps))
end
ps

Seems to do the trick (feel free to double-check; it's early ;) ).

Edit: full examples.
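For inputs with named fields, the same per-coordinate idea can keep the field names in the result without ComponentArrays. The following minimal sketch works on a plain NamedTuple. The helper `namedtuple_gradient` is hypothetical, and the scalar derivative routine is passed in as an argument; StochasticAD's `derivative_estimate` is the intended one, and a central finite difference keeps the check self-contained. For `g` above, the exact gradient at `(a = 0.1, b = 0.2, c = 0.7)` is `(a = 2·0.1, b = -2·(1 - 0.2), c = 0) = (0.2, -1.6, 0.0)`.

```julia
# Hypothetical helper (not part of StochasticAD): per-field gradient of f at a
# NamedTuple input, returned as a NamedTuple with the same field names.
function namedtuple_gradient(derivative, f, x::NamedTuple{names}) where {names}
    partials = map(names) do k
        # Replace only field k by the differentiated point p
        h(p) = f(merge(x, NamedTuple{(k,)}((p,))))
        derivative(h, x[k])
    end
    NamedTuple{names}(partials)
end

g(x) = x.a^2 + (1 - x.b)^2
ps = (a = 0.1, b = 0.2, c = 0.7)

# Self-contained check: central finite difference as the scalar derivative
central_diff(h, p; ε = 1e-6) = (h(p + ε) - h(p - ε)) / (2ε)
namedtuple_gradient(central_diff, g, ps)
```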

mschauer commented 1 year ago

Nice! I thought I'd just leave this here too: when computing partial derivatives with ForwardDiff, I use a closure trick I got from @cscherrer:

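using ForwardDiff, MappedArrays

# Returns a closure computing the i-th partial derivative of f at x (length d);
# the one-hot seed buffer `ith` is allocated once and reused across calls
function make_partial(f, d)
    ith = zeros(d)
    return function (x, i)
        ith[i] = 1
        # Lazily pair each x[j] with the derivative seed ith[j] as a Dual number
        sa = mappedarray(ForwardDiff.Dual{}, x, ith)
        δ = f(sa).partials[]
        ith[i] = 0  # reset the seed for the next call
        δ
    end
end
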
gaurav-arya commented 1 year ago

Vector inputs should be resolved by #40! If support for ComponentArrays, Functors, etc. is desired, feel free to make a draft PR with the desired functionality, which I can tweak as needed.