JuliaDiff / ForwardDiff.jl

Forward Mode Automatic Differentiation for Julia

Gradients with respect to struct fields? #709

Open NAThompson opened 2 months ago

NAThompson commented 2 months ago

In Zygote.jl, we can take the gradient with respect to all fields of a struct foo passed through a function bar via

g = Zygote.gradient(f -> bar(f), foo)

Can this be done in ForwardDiff as well?

Reproducer:

using Zygote
using ForwardDiff

struct Foo
    x::Number
    t::Number
    c::Number
end

function bar(f::Foo)
    return f.x - f.c*f.t
end

foo = Foo(2, 3, 3e8)
println(foo)

g = Zygote.gradient(f -> bar(f), foo)

println(g)

g = ForwardDiff.gradient(f -> bar(f), foo)
println(g)

KristofferC commented 2 months ago

Not straightforward. ForwardDiff differentiates w.r.t. numbers and abstract vectors. You might be able to hack something together with generated functions.
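
Since ForwardDiff works on vectors, one possible workaround is to pack the fields into a vector and rebuild the struct inside the closure. This is only a minimal sketch, assuming the struct holds nothing but real numbers and that its constructor accepts ForwardDiff's `Dual` values (true for the `Foo` from the reproducer, whose fields are typed `::Number`). The names `v` and `g` are made up for illustration.

```julia
using ForwardDiff

# Same struct and function as in the reproducer.
struct Foo
    x::Number
    t::Number
    c::Number
end

bar(f::Foo) = f.x - f.c * f.t

foo = Foo(2, 3, 3e8)

# Pack the fields into a plain vector of floats ...
v = Float64[foo.x, foo.t, foo.c]

# ... and rebuild a Foo inside the closure, so ForwardDiff only ever
# sees a function from a vector to a number.
g = ForwardDiff.gradient(w -> bar(Foo(w[1], w[2], w[3])), v)
```

Here `g` is a plain vector ordered like the fields (`x`, `t`, `c`), not a named tuple as in Zygote.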

mcabbott commented 2 months ago

If you want to do this yourself, and only have a struct of real numbers, then it will be fairly simple:

julia> using ForwardDiff: Dual, partials

julia> make_dual(z::Foo) = Foo(Dual(z.x,1,0,0), Dual(z.t,0,1,0), Dual(z.c,0,0,1));

julia> get_Foo(dy::Dual) = (; x=partials(dy,1), t=partials(dy,2), c=partials(dy,3));

julia> get_Foo(bar(make_dual(foo)))
(x = 1.0, t = -3.0e8, c = -3.0)

julia> Zygote.gradient(bar, foo)[1]
(x = 1.0, t = -3.0e8, c = -3.0)

With a bit more work you could automate this into a generic struct_gradient(f, x) that works with many structs of numbers, and even allow structs of structs.
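
A rough sketch of what such a helper might look like for flat structs of real numbers. `struct_gradient` is a hypothetical name here, not part of ForwardDiff's API. The sketch assumes that the struct's default positional constructor accepts `Dual` values in every field. It does not handle nested structs or arrays, and since it uses untagged duals it is probably not safe to nest inside other ForwardDiff calls.

```julia
using ForwardDiff: Dual, partials

# Same struct and function as in the reproducer.
struct Foo
    x::Number
    t::Number
    c::Number
end

bar(f::Foo) = f.x - f.c * f.t

# Hypothetical helper: gradient of f with respect to every field of s.
# Assumes all fields are real numbers and T(args...) accepts Duals.
function struct_gradient(f, s::T) where {T}
    names = fieldnames(T)
    n = length(names)
    # Seed field i with the i-th unit vector as its partials.
    duals = ntuple(i -> Dual(getfield(s, i), ntuple(j -> Int(i == j), n)...), n)
    y = f(T(duals...))
    # Read the partials back into a NamedTuple keyed by field name.
    return NamedTuple{names}(ntuple(i -> partials(y, i), n))
end

g = struct_gradient(bar, Foo(2, 3, 3e8))
```

For `bar` and `foo` this should agree with the hand-written `make_dual` / `get_Foo` version above. A generated function could make the tuple construction type-stable, at the cost of more machinery.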

Allowing structs containing arrays will be much more tricky, basically thanks to ForwardDiff's chunk mode.