Closed cortner closed 2 years ago
Comments by @ettersi on Zulip:
This sounds like the kind of problem that the Julia community is currently working on in ChainRules.jl. I'd therefore recommend you either directly implement their `frule` and `rrule` interface, or at least have a good look at it to learn about the various pitfalls which led to this design.
In a nutshell, here's what I believe their approach is / should be (some design aspects are still work in progress):
Two functions come in: `frule(dx,f,x)` (implements $$\nabla f(x) \, dx$$) and `rrule(f,x) -> g` (implements $$g(df) = df^T \, \nabla f(x)$$).
`dx` and `df` need not be represented as vectors. Any representation will do as long as 1) it contains the same information as the vector $$dx \in \mathbb{R}^m$$ / $$df \in \mathbb{R}^n$$ and 2) it is addable, i.e. `+(dx1,dx2)` is defined (see https://juliadiff.org/ChainRulesCore.jl/stable/gradient_accumulation.html for some more details). I like to think of this as saying that every type should be assigned a unique differential type, and then the forward and reverse chain rules become functions `(value, differential) -> differential`.
I believe this framework applies to your situation pretty straightforwardly. In particular:
In the invariant case, should `dphi * dAAdrr` be a `EuclideanVector` or `Adjoint{EuclideanVector}`?
In the ChainRules.jl framework, what you are computing here is `(g = rrule(phi, rr); g(1))` (the final output is `g(1)` because that corresponds to the vjp `1 * dphi_drr`). I'd therefore choose `g = rrule(phi,rr)` to be a function `g(::Real) -> ::EuclideanVector`. The `Adjoint` is implied by the fact that what we are computing is the output of an `rrule`.
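A minimal sketch of the invariant case in plain Julia, mirroring the shape described above without depending on ChainRulesCore. The names `phi_demo`, `frule_demo` and `rrule_demo` are hypothetical, the test function is made up for illustration, and a plain `Vector{Float64}` stands in for `EuclideanVector`. Note that the actual ChainRulesCore `rrule` also returns the primal value, and its pullback also returns a tangent for the function itself; both are left out here to match the simplified notation above.

```julia
using LinearAlgebra

# Hypothetical invariant (scalar) property: phi(rr) = |rr|^2 / 2, so grad phi(rr) = rr.
phi_demo(rr::AbstractVector) = dot(rr, rr) / 2

# Forward rule: Jacobian-vector product grad phi(rr) * drr (a scalar here).
frule_demo(drr, ::typeof(phi_demo), rr) = dot(rr, drr)

# Reverse rule: returns the pullback g(dphi) = dphi' * grad phi(rr).
# The result is stored as an ordinary vector; the adjoint is implied.
function rrule_demo(::typeof(phi_demo), rr)
    g(dphi::Real) = dphi .* rr
    return g
end

rr = [1.0, 2.0, 3.0]
g = rrule_demo(phi_demo, rr)
grad = g(1)                            # g(1) is the gradient of phi at rr
drr = [0.5, -1.0, 2.0]
jvp = frule_demo(drr, phi_demo, rr)    # directional derivative along drr
```

The forward and reverse results agree in the sense that `dot(grad, drr)` equals `jvp`, which is a convenient consistency check when writing rules by hand.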
In the equivariant case, should `phi * dAAdrr` be a `EuclideanMatrix`?
In this case, `g = rrule(phi,rr)` becomes a function `g(::Vector) -> ::Vector` (unless I'm misinterpreting your problem statement). I'd probably leave it at that, i.e. I would not translate this `g(dphi)` into a matrix, but of course the details here depend on what you want to do with your derivative.
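As a rough illustration of the equivariant case, the sketch below uses a made-up vector-valued function and a pullback that maps a vector to a vector without ever building the Jacobian. The names `phi_vec_demo` and `rrule_vec_demo` are hypothetical and plain `Vector{Float64}` values are an assumption; the explicit matrix `J` is formed only to check the result.

```julia
using LinearAlgebra

# Hypothetical equivariant (vector-valued) property: phi(rr) = |rr|^2 * rr.
phi_vec_demo(rr::AbstractVector) = dot(rr, rr) .* rr

# Reverse rule: the pullback maps an output cotangent dphi to J(rr)' * dphi,
# with J(rr) = |rr|^2 * I + 2 * rr * rr'. The Jacobian is never formed.
function rrule_vec_demo(::typeof(phi_vec_demo), rr)
    g(dphi::AbstractVector) = dot(rr, rr) .* dphi .+ 2 * dot(rr, dphi) .* rr
    return g
end

rr = [1.0, 2.0, 3.0]
g = rrule_vec_demo(phi_vec_demo, rr)
dphi = [1.0, 0.0, -1.0]
vjp = g(dphi)

# Only for checking: the explicit Jacobian, which the pullback avoids building.
J = dot(rr, rr) * I + 2 * rr * rr'
```

Keeping `g` as a function rather than a matrix means the cost per call stays linear in the length of `rr` for this example.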
What is the corresponding "thing" for Spherical vectors or matrices?
Define a differential type for your spherical vectors and matrices, and everything else should follow from there. I guess we can talk about this further once we have an agreement on the earlier points.
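One possible shape for such a differential type is sketched below. The type name `SphericalVectorDiff` and its coefficient storage are assumptions made for illustration only; the thread does not specify either. The point is just that the type carries the same information as a vector and supports `+`, as required above.

```julia
# Hypothetical differential type for a hypothetical spherical-vector value type.
struct SphericalVectorDiff{T}
    coeffs::Vector{T}
end

# Differentials must be addable, so that contributions from several
# paths through a computation can be accumulated.
Base.:+(a::SphericalVectorDiff, b::SphericalVectorDiff) =
    SphericalVectorDiff(a.coeffs .+ b.coeffs)

# Scaling by a real cotangent, as used by a pullback g(::Real).
Base.:*(s::Real, a::SphericalVectorDiff) = SphericalVectorDiff(s .* a.coeffs)

d1 = SphericalVectorDiff([1.0, 0.0, 2.0])
d2 = SphericalVectorDiff([0.5, 1.0, -2.0])
acc = d1 + 2 * d2    # accumulate two contributions
```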
Just to record what I've currently implemented as a stop-gap solution: (typing is too strong but ok for now... this is just illustrative, the actual implementation varies...)
*(phi::AbstractProperty, dAA::SVector) = reshape( phi.val[:] * dAA', Size( size(phi.val)..., length(dAA) ) )
That way we can re-use the matrix multiplication to get from `AA` to `B`, but it is very risky, and definitely needs rethinking!! This is no more than a hack.
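To make the index layout of this stop-gap concrete, here is a small sketch using ordinary arrays instead of StaticArrays. It assumes the intended operation is an outer product of the flattened value with `dAA`, which is an interpretation of the definition above rather than something the thread confirms. `DemoProperty` and `outer_demo` are hypothetical names.

```julia
# Hypothetical stand-in for a property type holding a value `val`.
struct DemoProperty{T}
    val::T
end

# Outer product of the flattened value with dAA, reshaped so that the last
# index runs over the components of dAA.
function outer_demo(phi::DemoProperty, dAA::AbstractVector)
    return reshape(vec(phi.val) * dAA', size(phi.val)..., length(dAA))
end

phi = DemoProperty([1.0 2.0; 3.0 4.0])
dAA = [10.0, 20.0, 30.0]
out = outer_demo(phi, dAA)    # out[i, j, k] == phi.val[i, j] * dAA[k]
```

Defining this as a separate named function, rather than as a method of `*`, avoids the risk mentioned above of the overload being picked up in places where an ordinary product was intended.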
another comment from JuLIP: adjoints can be used to evaluate derivatives in a lazy way. This means we can have a relatively simple unified implementation for multiple arguments with respect to which we might want derivatives.
this is all happening and in fact mostly done so I'm closing it. new rules will be added as needed.
We should probably redesign differentiation in a more organised way, likely following the `ChainRules.jl` ideas or even using `ChainRules.jl` directly. In particular this should enable us to leverage AD tools when needed.