MasonProtter opened this issue 8 months ago (status: Open)
So one question I have is whether it would suffice to do this covariant vs structural derivative just at the user level. Specifically, I'd like rules authors to assume one convention (e.g. just structural) that they need to write rules for, rather than have to consider both.
If that's the case, perhaps we could have a type parameter to ReverseMode and ForwardMode signifying whether the covariant derivative is taken or not. We just added one for Holomorphic (see the dev docs for complex numbers for more detail), and we already have type parameters for whether the original return is desired, so this seems like a natural place to put it.
We could then have the covariant autodiff call the structural autodiff, then do all the shims required by recursing through types.
Does this seem like a reasonable compromise between enabling the natural-result functionality for those who want it (by calling the covariant form), and preserving the mutation-aware in-place form for those who don't (since constructing the new object inherently requires allocation)?
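As a very rough sketch of the shape this could take (the names and the `Covariant` parameter below are hypothetical stand-ins, not Enzyme's actual types):

```julia
# Hypothetical mock-up (not Enzyme's actual types): a reverse-mode type whose
# type parameters record whether the primal is returned and whether the
# covariant conversion should be applied to the result.
struct MockReverseMode{ReturnPrimal,Covariant} end

to_covariant(x, dx) = dx   # would recurse through the types and apply the shims

# The covariant variant calls the structural one, then post-processes.
function mock_autodiff(::MockReverseMode{RP,true}, f, x, dx) where {RP}
    mock_autodiff(MockReverseMode{RP,false}(), f, x, dx)   # structural pass
    return to_covariant(x, dx)
end
mock_autodiff(::MockReverseMode{RP,false}, f, x, dx) where {RP} =
    error("stand-in for the existing structural autodiff")
```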
I do think applying this to the final user output would be enough.
Maybe it's helpful to classify how some `fixup(x, dx)` function could behave:

1. Fix up only simple wrapper types like `Complex`, `Diagonal`, `SparseVector`, `SArray`, but leave other or unknown ones structural. Here field access `grad.a.b.c` always gets you the structural component.
2. Also fix up types like `Symmetric`, which require less-obvious rules, and may require allocating new arrays, but leave unknown types structural. Here you cannot trust that field access gets you structural gradient components, which is a bit weird.
3. Error for other `AbstractArray` types, prompting you to supply the rule.
4. Also try to convert non-array structs containing arrays (e.g. Flux models). These cannot always be reconstructed, as the "covariant" representation may not have the same type. Perhaps it's better not to anyway, since other behaviours attached to the struct may not make sense.

Levels 1 & 2 could potentially be done to the internal types, widening `Duplicated(x, dx)` to allow mismatched types. Seems tricky, especially for mutable structs.
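For concreteness, here is a very rough sketch of what the level-1 and level-2 behaviour could look like; `fixup` and its methods below are hypothetical, not something Enzyme provides:

```julia
using LinearAlgebra

# Hypothetical `fixup`: take the primal `x` and a structural gradient `dx` and
# return the "natural"/covariant gradient, leaving anything it does not
# understand untouched (i.e. still structural).
fixup(x, dx) = dx                                   # fallback: unknown types stay structural

# Level 1: simple wrappers, where the fix is just recursion into the fields and
# field access on the result still gives structural components.
fixup(x::Diagonal, dx::Diagonal) = Diagonal(fixup(x.diag, dx.diag))

# Level 2: types like Symmetric, where the stored triangle has absorbed the
# contribution of the redundant triangle and must be re-weighted, allocating a
# new array (this sketch assumes the default uplo == 'U' storage).
function fixup(x::Symmetric, dx::Symmetric)
    U = triu(dx.data)              # keep only the stored (upper) triangle
    Symmetric((U + U') / 2)        # spread it evenly over both triangles
end

# The structural gradient of sum over a 2×2 Symmetric has data [1 2; 0 1];
# fixup turns it into the all-ones covariant gradient:
fixup(Symmetric([1.0 2.0; 2.0 4.0]), Symmetric([1.0 2.0; 0.0 1.0]))
```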
Any level could be done on final output. If you call `gradient(Reverse, f, x)` then perhaps this should be done automatically -- at least level 2 seems unobjectionable: you can still access fields, but won't be tempted to go wrong. Perhaps level 4 should be an option. (Does it belong to `gradient` or to `Reverse`, if the actual AD is unchanged?)
If you construct `Duplicated(x, make_zero(x))` yourself, then perhaps you are also in charge of calling `fixup(x, dx, level=4)` afterwards too, if desired? `autodiff(Reverse, f, Active, Duplicated(x, dx))` does not at present return `dx`. There are cases like `autodiff(Reverse, sum, Active, Active(Symmetric(SA[1 2; 3 4.])))[1][1] |> dump` which return a struct.
The problem:
Currently, when you compute the reverse mode derivative of some function of a `struct MyStruct` in Enzyme, i.e. `autodiff(Reverse, f, Duplicated(A::MyStruct, dA::MyStruct))`, the object `dA` must be of an identical type to `A`, but it should be interpreted in a very different way from `A`. What `dA` actually is, is the object you get by treating `MyStruct` as a Cartesian vector whose elements are the struct's fields. The `i`th field of `dA` will be given by $\mathrm{d}A_{i} = \partial f / \partial A_{i}$.

The problem arises, though, that often what people want is not this component-wise Cartesian derivative, but a covariant derivative, i.e. the mathematically meaningful derivative on the manifold. (Note: I consider the type `MyStruct` to define a manifold whose metric could be constructed via the constructor for `MyStruct`, and I consider `f` to be a function on that manifold.)

Here's a concrete example posted by @simsurace in Slack:
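A minimal sketch of that computation (not the verbatim snippet; it uses the 3x3 sum-over-`Symmetric` case analysed below):

```julia
using Enzyme, LinearAlgebra

A  = Symmetric([1.0 2.0 3.0; 2.0 4.0 5.0; 3.0 5.0 6.0])
dA = make_zero(A)        # shadow with the same (Symmetric) structure, all zeros
autodiff(Reverse, sum, Active, Duplicated(A, dA))
dA
# As derived below: dA.data picks up 1 on the diagonal, 2 on the stored upper
# off-diagonal entries and 0 below, so through the Symmetric wrapper dA shows
# 1 on the diagonal and 2 everywhere off it.
```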
compared to the answer given by Zygote:
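Again as a sketch rather than the verbatim snippet:

```julia
using Zygote, LinearAlgebra

A = Symmetric([1.0 2.0 3.0; 2.0 4.0 5.0; 3.0 5.0 6.0])
Zygote.gradient(sum, A)[1]
# every [i, j] slope comes back as 1 -- the uniform answer discussed below
```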
What's going on here? Enzyme is treating `A::Symmetric` as a vector with only one differentiable field (`A.data`), and so it computes

$$ dA_{\mathrm{data}}[a, b] = \frac{\partial}{\partial A_{\mathrm{data}}[a, b]} \sum_{i,j} A[i,j] = \frac{\partial}{\partial A_{\mathrm{data}}[a, b]} \sum_{i,j} \begin{cases} A_{\mathrm{data}}[i, j], & j \geq i \\ A_{\mathrm{data}}[j, i], & \text{otherwise} \end{cases} $$

Now, since `A.data[i,j]` is never accessed for `j < i`, this evaluates to

$$ dA_{\mathrm{data}}[a, b] = \sum_{i,j} \begin{cases} \delta_{a, i} \delta_{b,j}, & j \geq i \\ \delta_{a, j} \delta_{b,i}, & \text{otherwise} \end{cases} = \sum_{i} \delta_{a,i} \delta_{b,i} + \sum_{i,\, j>i} 2\, \delta_{a,i}\delta_{b,j} = \begin{cases} 1, & a = b \\ 2, & a < b \\ 0, & \text{otherwise} \end{cases} $$
meaning
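that, reading the entries off the case analysis above (with the default upper-triangular storage), the stored field and the matrix the `Symmetric` wrapper presents come out as

$$ dA_{\mathrm{data}} = \begin{pmatrix} 1 & 2 & 2 \\ 0 & 1 & 2 \\ 0 & 0 & 1 \end{pmatrix}, \qquad dA = \begin{pmatrix} 1 & 2 & 2 \\ 2 & 1 & 2 \\ 2 & 2 & 1 \end{pmatrix}. $$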
The problem is that `dA` itself is a `Symmetric`, so what we end up getting is a `dA` which has the incorrect mathematical properties if we want to think of it as giving an actual derivative. If we interpret this thing as a gradient, it would suggest something incorrect about the slopes of the function `sum` over the space of `Symmetric` 3x3 matrices. If you check, you'll find that the slope of `sum` on this space in every `[i,j]` direction is uniformly `1` (the answer given by Zygote). Basically, Zygote / ChainRules try to transform the derivative so that it lives in the (co)tangent space, whereas Enzyme just calculates the component-wise derivatives and stops.

Now, the user can extract the covariant derivative given this Cartesian derivative `dA` and the primal `A`, but doing so can be difficult, and many users are going to be quite surprised to find that Enzyme is giving them a seemingly "wrong" derivative, and are unlikely to even know how to transform their component-wise derivatives into covariant derivatives.

What are Zygote / ChainRules doing here?
Basically, ChainRules-based AD systems are going to default to these Cartesian 'structural' derivatives, unless they have rules put in place to `ProjectTo` the shadow values back to the correct manifold, and they're doing this at the level of rules, so the projections are happening throughout the whole AD process. Here's some resources: for instance, https://github.com/mcabbott/OddArrays.jl collects weird array types like `Symmetric` and shows how one can try to systematically deal with them.

The problem is that this approach:

1) is pretty hard. It can be computationally expensive, it means rule authors never know what type of shadow value they're dealing with, and rules can end up returning all sorts of crazy stuff to downstream rules. I must say that despite the many difficulties with this approach, I kinda like it, but I'll let people who have actually had to suffer through writing rules and such for it talk about their experiences.

2) means you never know if you'll get a Cartesian or a covariant derivative. If someone implements a type and you differentiate it without any `ProjectTo` rules being written, you'll end up getting a Cartesian derivative by default.

Now, I should mention that this thing is surprisingly robust, i.e. it's pretty smart for general abstract arrays: writing a `Symmetric`-like type myself works correctly without any custom rules or `ProjectTo` methods (see the sketch below), so this concern is more of a potential concern for non-`AbstractArray` types if I understand correctly (please chime in if you disagree with this assessment; I'm not saying this is 100% reliable, but I think it is already pretty reliable, and can be made even more reliable).
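Something like the following is the kind of experiment meant here (a sketch, not the original snippet; `MySymmetric` is a stand-in for the hand-written type):

```julia
using Zygote

# Hand-rolled symmetric wrapper with no rrule or ProjectTo methods of its own.
struct MySymmetric{T} <: AbstractMatrix{T}
    data::Matrix{T}
end
Base.size(A::MySymmetric) = size(A.data)
Base.getindex(A::MySymmetric, i::Int, j::Int) = j >= i ? A.data[i, j] : A.data[j, i]

A = MySymmetric([1.0 2.0; 3.0 4.0])
Zygote.gradient(sum, A)[1]
# comes back as an all-ones 2×2 array: the sensible elementwise slope, obtained
# through the generic AbstractArray machinery, no custom rules for MySymmetric
```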
What should Enzyme do?
Well, that's certainly up for discussion, but I think maybe the right thing to do would be to develop a set of functions that, given a primal and a Cartesian derivative, can calculate a covariant derivative as a post-processing pass.

E.g. we could have `autodiff_cartesian`, which does something similar to what `autodiff` does currently, and then make it so that `autodiff` does something like the sketch below, turning the Cartesian derivative into something like its covariant counterpart.
I think that in general, in order to do this conversion to the tangent space, we'd need a record of which functions were hit (the `tape`), and then a series of transformations would be applied to `dy_cartesian`, based potentially on that tape and the primal `y` (see https://github.com/mcabbott/OddArrays.jl for examples of cases that need `y` to get `dy_tangent`
). I think it'd be good to support a purely `cartesian` mode, and then have a `to_tangent_space` that is pretty fussy about rejecting cases it doesn't understand with hard errors, but useful error hints -- e.g. something like the sketch below. The design of the `tape` would be pretty hard to do right though.
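For instance (hypothetical, just to illustrate the kind of error hints meant):

```julia
using LinearAlgebra

# Hypothetical sketch: convert the cases we understand, refuse loudly otherwise.
to_tangent_space(y::Array, dy::Array) = dy     # plain arrays are already tangents
# ... methods for Symmetric, Diagonal, etc. would go here ...

function to_tangent_space(y, dy)
    error("Don't know how to map the Cartesian derivative of a `$(typeof(y))` onto ",
          "the tangent space. Either use the purely cartesian mode and interpret the ",
          "structural derivative yourself, or define `to_tangent_space` for this type.")
end
```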