Open kimauth opened 2 years ago
One solution is to document and export the _propagate_gradient function used by @implement_gradient. It is not as sleek as the macro but allows the user to solve such a problem. I.e.
tensor_exp(A::SymmetricTensor{2,dim,<:ForwardDiff.Dual}) where {dim} = Tensors._propagate_gradient(tensor_exp_gradient, A)
For this to work the output of tensor_exp must be a symmetric tensor, i.e.
function tensor_exp(A::SymmetricTensor{2})
    # Spectral decomposition of the symmetric tensor
    E = eigen(A)
    A_exp = zero(A)
    for (i, λ) in enumerate(E.values)
        N = E.vectors[:, i]           # i-th eigenvector
        A_exp += exp(λ) * otimes(N)   # add exp(λ) * (N ⊗ N)
    end
    return A_exp
end
I ran into the following problem today:
The problem here is that we run into the tensor_exp function with dual numbers (instead of using tensor_exp_gradient). Looking at the methods of tensor_exp, we can see why:
The original tensor_exp is more specific than the one defined for Dual numbers by @implement_gradient. The solution could be to not allow dual numbers in the original function at all, e.g. by specifying the number type of the tensor in its signature. (This is of course not so nice if one doesn't own this function.)
Perhaps there is a better solution than specifying the number type of the Tensor. If there isn't, we should probably add a hint about it to the docs.
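
For reference, the dispatch behaviour behind the workaround can be sketched in plain Julia without loading Tensors.jl or ForwardDiff.jl. Everything here is a hypothetical stand-in: FakeDual plays the role of ForwardDiff.Dual, FakeTensor the role of SymmetricTensor, and apply_exp the role of tensor_exp. The second method only imitates what a dual-number method would look like; it is not the method @implement_gradient actually generates.

```julia
# Hypothetical stand-ins so the sketch runs without Tensors.jl or ForwardDiff.jl.
struct FakeDual <: Real          # plays the role of ForwardDiff.Dual
    value::Float64
    partial::Float64
end

struct FakeTensor{T<:Real}       # plays the role of SymmetricTensor{2,dim,T}
    data::Vector{T}
end

# "Original" function, restricted to floating-point entries so that
# tensors of dual numbers can no longer reach it.
apply_exp(A::FakeTensor{<:AbstractFloat}) = :primal

# Stand-in for a method defined specifically for dual numbers.
apply_exp(A::FakeTensor{FakeDual}) = :gradient

println(apply_exp(FakeTensor([1.0, 2.0])))            # prints "primal"
println(apply_exp(FakeTensor([FakeDual(1.0, 1.0)])))  # prints "gradient"
```

Because FakeDual is a subtype of Real but not of AbstractFloat, the restricted method does not apply to it, so there is only one candidate for dual-number input. The cost of this restriction is that tensors with other element types (integers, for example) are excluded from the original method as well.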