Ferrite-FEM / Tensors.jl

Efficient computations with symmetric and non-symmetric tensors with support for automatic differentiation.
https://ferrite-fem.github.io/Tensors.jl/

implement custom gradient with multi-argument functions #197

Open KnutAM opened 1 year ago

KnutAM commented 1 year ago

From a Slack comment by @koehlerson: how to implement a custom gradient calculation for a multi-argument function. This case is common with autodiff, so it would be good to have a clear way of doing it. The solution I can come up with now is:

using Tensors
import ForwardDiff: Dual

# General setup for any function f(x, args...)
struct Foo{F,T<:Tuple} <: Function # <:Function optional
    f::F
    args::T
end
struct FooGrad{FT<:Foo} <: Function # <: Function required
    foo::FT
end

function (foo::Foo)(x)
    println("Foo with Any: ", typeof(x))  # To show that it works
    return foo.f(x, foo.args...)
end
function (foo::Foo)(x::AbstractTensor{<:Any,<:Any,<:Dual})
    println("Foo with Dual: ", typeof(x))  # To show that it works
    return Tensors._propagate_gradient(FooGrad(foo), x)
end
function (fg::FooGrad)(x)
    println("FooGrad: ", typeof(x)) # To show that it works
    return f_dfdx(fg.foo.f, x, fg.foo.args...)
end

# Specific example to setup for bar(x, a, b), must then also define f_dfdx(::typeof(bar), x, a, b):
bar(x, a, b) = norm(a*x)^b 
dbar_dx(x, a, b) = b*(a^b)*norm(x)^(b-2)*x
f_dfdx(::typeof(bar), args...) = (bar(args...), dbar_dx(args...))

# At the location in the code where the derivative will be calculated
t = rand(SymmetricTensor{2,3}); a = π; b = 2 # Typically inputs
foo = Foo(bar, (a, b))
gradient(foo, t) ≈ dbar_dx(t, a, b) # ≈ rather than == to tolerate floating-point round-off

But it is quite cumbersome, especially if only needed for one function, so a better method would be good. (Tensors._propagate_gradient is renamed to propagate_gradient, exported, and documented in #181)
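
For reference, the closed form used in dbar_dx above can be checked by hand. The short derivation below is a sketch that assumes a > 0 and x ≠ 0, and takes norm(x) to be the Frobenius norm sqrt(x : x):

```latex
\begin{aligned}
\mathrm{bar}(x, a, b) &= \lVert a\,x \rVert^{b} = a^{b}\,\lVert x \rVert^{b},
\qquad \lVert x \rVert = \sqrt{x : x}, \\
\frac{\partial \lVert x \rVert}{\partial x} &= \frac{x}{\lVert x \rVert}, \\
\frac{\partial\,\mathrm{bar}}{\partial x}
  &= a^{b}\, b\, \lVert x \rVert^{b-1}\, \frac{x}{\lVert x \rVert}
   = b\, a^{b}\, \lVert x \rVert^{b-2}\, x .
\end{aligned}
```

The last line matches b*(a^b)*norm(x)^(b-2)*x term by term.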

KristofferC commented 1 year ago

I don't understand why a closure over a and b wouldn't work here.

x->bar(x, a, b)
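
For context, a minimal sketch of what the closure does on its own: it lets gradient differentiate bar by plain AD with a and b held fixed, but it never calls the hand-written dbar_dx. The definitions are copied from the first post; nothing here is new API.

```julia
using Tensors

bar(x, a, b) = norm(a * x)^b
dbar_dx(x, a, b) = b * (a^b) * norm(x)^(b - 2) * x  # analytic gradient, unused by AD below

t = rand(SymmetricTensor{2,3}); a = π; b = 2

# Plain AD through bar; the closure only fixes a and b
g_ad = gradient(x -> bar(x, a, b), t)

# The AD result should agree with the analytic expression up to round-off
agrees = g_ad ≈ dbar_dx(t, a, b)
```

So the closure solves the multi-argument part, but not the "use my own derivative" part.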

KnutAM commented 1 year ago

I'm not sure I follow; that would only define one method, for x::Any. Do you have a complete example? If working directly on bar, I think it is necessary to write a custom propagate_gradient using Tensors._extract_value and Tensors._insert_gradient. Alternatively, we could extend propagate_gradient to accept args...:

function propagate_gradient(f_dfdx::Function, x::Union{AbstractTensor{<:Any, <:Any, <:Dual}, Dual}, args...)
    fval, dfdx_val = f_dfdx(_extract_value(x), args...)
    _check_gradient_shape(fval,x,dfdx_val)
    return _insert_gradient(fval, dfdx_val, x)
end

KristofferC commented 1 year ago

Okay, I missed the point:

implement custom gradient calculation for a multi-argument function.

Carry on..

koehlerson commented 1 year ago

Initially I planned to do a custom layer for energy densities something like

energy(F,material,state) = #something

analytic_or_AD(energy::FUN, F, material, state) where FUN<:Function = Tensors.hessian(x->energy(x,material,state),F)

where a generic method uses Tensors.hessian, and known analytic parts get their own, more specific method. However, @implement_gradient should be capable of handling this imo. Furthermore, it feels like I am reinventing the wheel. I don't think the dispatch-based approach can substitute only pieces of the derivative, i.e. mix and match analytic and automatic differentiation when the energy function itself calls something that is known analytically, e.g. strain energy densities.
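
A rough sketch of that dispatch layer, to make the limitation concrete. The energies quadratic_energy and composite_energy and the NamedTuple material are hypothetical stand-ins, not part of Tensors.jl or any real material library.

```julia
using Tensors

# Hypothetical energy with a known closed-form Hessian: ψ = μ/2 * (F : F)
quadratic_energy(F, material, state) = material.μ / 2 * (F ⊡ F)

# Hypothetical energy that calls the one above internally
composite_energy(F, material, state) = quadratic_energy(F, material, state)^2

# Generic fallback: differentiate everything with AD
analytic_or_AD(energy::FUN, F, material, state) where {FUN<:Function} =
    hessian(x -> energy(x, material, state), F)

# Specialised method: ∂²ψ/∂F∂F = μ * (fourth-order identity)
analytic_or_AD(::typeof(quadratic_energy), F, material, state) =
    material.μ * one(Tensor{4,3})

F = rand(Tensor{2,3}); material = (μ = 2.0,)
H_analytic  = analytic_or_AD(quadratic_energy, F, material, nothing)
H_composite = analytic_or_AD(composite_energy, F, material, nothing)  # AD all the way down
```

The second call goes entirely through AD: the analytic method for quadratic_energy is not picked up inside composite_energy, which is the "cannot substitute only pieces" problem described above.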

KnutAM commented 1 year ago

But I think the approach of allowing args... in propagate_gradient could be nice for this:

using Tensors
import Tensors: _extract_value, _insert_gradient, Dual
# Change in Tensors.jl
function propagate_gradient(f_dfdx::Function, x::Union{AbstractTensor{<:Any, <:Any, <:Dual}, Dual}, args...)
    fval, dfdx_val = f_dfdx(_extract_value(x), args...)
    # _check_gradient_shape(fval,x,dfdx_val) # PR181
    return _insert_gradient(fval, dfdx_val, x)
end

# User code:
# - Definitions
bar(x, a, b) = norm(a*x)^b
dbar_dx(x, a, b) = b*(a^b)*norm(x)^(b-2)*x
bar_dbar_dx(x, a, b) = (bar(x, a, b), dbar_dx(x, a, b))
bar(x::AbstractTensor{<:Any, <:Any, <:Dual}, args...) = (println("DualBar"); propagate_gradient(bar_dbar_dx, x, args...))
# - At call-site
t = rand(SymmetricTensor{2,3}); a = π; b = 2 # Typically inputs
gradient(x->bar(x, a, b), t)
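
A possible variant that needs no change in Tensors.jl, sketched under the assumption that the installed version already provides the single-argument propagate_gradient(f_dfdx, x) from #181: capture the extra arguments in a closure inside the Dual method.

```julia
using Tensors
import Tensors: Dual

bar(x, a, b) = norm(a * x)^b
dbar_dx(x, a, b) = b * (a^b) * norm(x)^(b - 2) * x
bar_dbar_dx(x, a, b) = (bar(x, a, b), dbar_dx(x, a, b))

# Method for dual-valued tensors: a and b are captured by the closure,
# so propagate_gradient only ever sees a one-argument function.
# Assumes propagate_gradient exists in this Tensors.jl version (see #181).
bar(x::AbstractTensor{<:Any,<:Any,<:Dual}, a, b) =
    Tensors.propagate_gradient(y -> bar_dbar_dx(y, a, b), x)

t = rand(SymmetricTensor{2,3}); a = π; b = 2
g = gradient(x -> bar(x, a, b), t)
matches = g ≈ dbar_dx(t, a, b)
```

Spelling out a and b in the Dual method (instead of args...) keeps it strictly more specific than the generic bar(x, a, b), at the cost of repeating the argument list.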