NAThompson opened this issue 2 months ago.
Not straightforward. ForwardDiff differentiates with respect to numbers and arrays (AbstractArray), not arbitrary structs. You might be able to hack something together with generated functions.
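One workaround stays within what ForwardDiff does support. This is a swapped-in technique, not the generated-function idea. Copy the fields into a vector, rebuild the struct inside a closure, and call ForwardDiff.gradient on the vector. Below is a minimal sketch. Foo, bar and gradient_via_vector are hypothetical stand-ins, because the issue's actual reproducer isn't included here. The field values are picked so the result happens to match the gradient printed in the REPL session.

```julia
using ForwardDiff

# Hypothetical stand-ins for the issue's reproducer (not the original code).
# Parametric fields so the struct can also hold Dual numbers.
struct Foo{T<:Real}
    x::T
    t::T
    c::T
end
bar(z::Foo) = z.x - z.c * z.t

# Hypothetical helper: pack the fields into a vector, rebuild a Foo inside
# the closure, and let ForwardDiff.gradient differentiate w.r.t. the vector.
function gradient_via_vector(f, z::Foo)
    v = [z.x, z.t, z.c]
    g = ForwardDiff.gradient(w -> f(Foo(w[1], w[2], w[3])), v)
    # Map gradient entries back to field names, in the same order as v.
    return (; x = g[1], t = g[2], c = g[3])
end

foo = Foo(1.0, 3.0, 3.0e8)
gradient_via_vector(bar, foo)   # (x = 1.0, t = -3.0e8, c = -3.0)
```

The pack and unpack steps must list the fields in the same order, and they have to be written by hand for each struct type.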
If you want to do this yourself and your struct contains only real numbers, it is fairly simple:
julia> using ForwardDiff: Dual, partials
julia> make_dual(z::Foo) = Foo(Dual(z.x,1,0,0), Dual(z.t,0,1,0), Dual(z.c,0,0,1));  # seed each field with its own one-hot partial
julia> get_Foo(dy::Dual) = (; x=partials(dy,1), t=partials(dy,2), c=partials(dy,3));  # read the partials back by field name
julia> get_Foo(bar(make_dual(foo)))  # one forward pass gives all three derivatives
(x = 1.0, t = -3.0e8, c = -3.0)
julia> using Zygote
julia> Zygote.gradient(bar, foo)[1]
(x = 1.0, t = -3.0e8, c = -3.0)
With a bit more work you could automate this to work with many structs of numbers, as a function struct_gradient(f, x), and even allow structs of structs.
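Here is a rough sketch of how that automation might look for flat structs of Real fields. The name struct_gradient comes from the comment. Its body here is a guess, not an existing API, and Foo and bar are again hypothetical stand-ins. It generalises make_dual/get_Foo by looping over fieldnames instead of naming x, t and c explicitly. It assumes the struct's fields are parametric, so they can hold Dual numbers. It does not handle nested structs or arrays.

```julia
using ForwardDiff: Dual, partials

# Hypothetical stand-ins for the issue's reproducer (not the original code).
struct Foo{T<:Real}
    x::T
    t::T
    c::T
end
bar(z::Foo) = z.x - z.c * z.t

# Hypothetical struct_gradient(f, z): seed field i of z with the i-th one-hot
# partial, call f once, and return the partials as a NamedTuple keyed by field.
# Assumes every field is a Real, f returns a Real, and the struct can be
# rebuilt positionally from its fields.
function struct_gradient(f, z)
    T = typeof(z)
    names = fieldnames(T)
    N = length(names)
    # One Dual per field: field i carries 1 in partial slot i and 0 elsewhere.
    duals = ntuple(N) do i
        v = getfield(z, i)
        Dual(v, ntuple(j -> j == i ? one(v) : zero(v), N))
    end
    # Rebuild through the unparameterised type (Foo rather than Foo{Float64})
    # so the type parameter is inferred as a Dual. Base.typename is a Base internal.
    ctor = Base.typename(T).wrapper
    y = f(ctor(duals...))
    return NamedTuple{names}(ntuple(i -> partials(y, i), N))
end

struct_gradient(bar, Foo(1.0, 3.0, 3.0e8))   # (x = 1.0, t = -3.0e8, c = -3.0)
```

To extend this to structs of structs, you would recurse into fields that are not Real and give each leaf its own partial slot.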
Allowing structs that contain arrays will be much trickier, mainly because of ForwardDiff's chunk mode.
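For context on chunk mode: ForwardDiff.gradient does not always seed every input at once. When the chunk size k is smaller than length(x), it evaluates f once per chunk of k inputs and assembles the gradient from those passes. The sketch below counts calls to make that visible. count_calls is a hypothetical helper. The link to structs is an interpretation: a hand-seeded struct with array fields would need either one partial per array element (a very wide Dual) or its own chunking logic.

```julia
using ForwardDiff

# Hypothetical helper: count how many times ForwardDiff evaluates f while
# computing the gradient of a length-n vector with a fixed chunk size k.
function count_calls(n::Int, k::Int)
    calls = Ref(0)
    f = v -> (calls[] += 1; sum(abs2, v))   # gradient is 2v
    x = collect(1.0:n)
    cfg = ForwardDiff.GradientConfig(f, x, ForwardDiff.Chunk{k}())
    g = ForwardDiff.gradient(f, x, cfg)
    return calls[], g
end

count_calls(10, 10)  # chunk size equals length(x): all 10 partials in one pass
count_calls(10, 2)   # 2 partials per pass, so f is evaluated once per chunk
```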
In Zygote.jl, we can take the gradient with respect to all fields of a struct foo passed through a function bar via:
Can this be done in ForwardDiff as well?
Reproducer: