EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License

Supporting covariant derivatives #1334

Open MasonProtter opened 6 months ago

MasonProtter commented 6 months ago

The problem:

Currently, when you compute the reverse-mode derivative of some function of a struct MyStruct in Enzyme, i.e. autodiff(Reverse, f, Duplicated(A::MyStruct, dA::MyStruct)), the object dA must have the same type as A, but it should be interpreted in a very different way from A. What dA actually is, is the object you get by treating MyStruct as a Cartesian vector whose elements are the struct's fields.

The $i$th field of dA will be given by $\mathrm{d}A_i = \partial f / \partial A_i$

The problem is that what people often want is not this component-wise Cartesian derivative, but a covariant derivative, i.e. the mathematically meaningful derivative on the manifold. (Note: I consider the type MyStruct to define a manifold whose metric could be constructed via the constructor for MyStruct, and I consider f to be a function on that manifold.)

Here's a concrete example posted by @simsurace in Slack:

julia> using Enzyme, LinearAlgebra

julia> A = Symmetric(rand(3, 3)); dA = make_zero(A);

julia> Enzyme.autodiff(Reverse, sum, Duplicated(A, dA))
((nothing,),)

julia> dA
3×3 Symmetric{Float64, Matrix{Float64}}:
 1.0  2.0  2.0
 2.0  1.0  2.0
 2.0  2.0  1.0

compared to the answer given by Zygote:

julia> using Zygote

julia> Zygote.gradient(sum, A) |> only
3×3 Symmetric{Float64, FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}}:
 1.0  1.0  1.0
 1.0  1.0  1.0
 1.0  1.0  1.0

What's going on here? Enzyme is treating A::Symmetric as a vector with only one differentiable field (A.data), and so it computes

$$ dA_{\mathrm{data}}[a, b] = {\partial \over \partial A_{\mathrm{data}}[a, b]}\sum_{i,j} A[i,j] = {\partial \over \partial A_{\mathrm{data}}[a, b]} \sum_{i,j} \begin{cases} A_{\mathrm{data}}[i, j], & j \geq i \\ A_{\mathrm{data}}[j, i], & \text{otherwise} \end{cases} $$

Now, since A.data[i,j] is never accessed for j < i, this evaluates to

$$ dA_{\mathrm{data}}[a, b] = \sum_{i,j} \begin{cases} \delta_{a,i} \delta_{b,j}, & j \geq i \\ \delta_{a,j} \delta_{b,i}, & \text{otherwise} \end{cases} = \sum_{i} \delta_{a,i} \delta_{b,i} + \sum_{i,\, j>i} 2\, \delta_{a,i}\delta_{b,j} = \begin{cases} 1, & a = b \\ 2, & a < b \\ 0, & \text{otherwise} \end{cases} $$

meaning

dA.data = [1.0 2.0 2.0
           0.0 1.0 2.0
           0.0 0.0 1.0]
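As a quick sanity check (this snippet is mine, not part of the original post), a finite-difference probe of the entries of the backing data reproduces exactly these component-wise slopes:

```julia
using LinearAlgebra

B = rand(3, 3); ε = 1e-6

# Perturb one entry of the backing data and see how sum(Symmetric(B)) responds.
function slope(i, j)
    Bp = copy(B); Bp[i, j] += ε
    (sum(Symmetric(Bp)) - sum(Symmetric(B))) / ε
end

slope(1, 1)  # ≈ 1: a diagonal entry of data appears once in A
slope(1, 2)  # ≈ 2: data[1,2] feeds both A[1,2] and A[2,1]
slope(2, 1)  # ≈ 0: the lower triangle of data is never read
```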

The problem is that dA itself is a Symmetric, so what we end up getting is

julia> dA
3×3 Symmetric{Float64, Matrix{Float64}}:
 1.0  2.0  2.0
 2.0  1.0  2.0
 2.0  2.0  1.0

dA has the incorrect mathematical properties if we want to think of it as giving an actual derivative. If we interpret this object as a gradient, it suggests something incorrect about the slopes of the function sum over the space of symmetric 3×3 matrices. If you check, you'll find that the slope of sum on this space in every [i,j] direction is uniformly 1 (the answer given by Zygote). Basically, Zygote / ChainRules try to transform the derivative so that it lives in the (co)tangent space, whereas Enzyme just calculates the component-wise derivatives and stops.

Now, the user can extract the covariant derivative from this Cartesian derivative dA and the primal A, but doing so can be difficult, and many users are going to be quite surprised to find that Enzyme is giving them a seemingly "wrong" derivative, and are unlikely to even know how to transform their component-wise derivatives into covariant derivatives.
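For this specific Symmetric example the conversion is a one-liner (my snippet, not from the original post; it relies on Enzyme having left the unused lower triangle of dA.data at zero, as shown above): symmetrizing the structural shadow recovers Zygote's answer.

```julia
using Enzyme, LinearAlgebra

A = Symmetric(rand(3, 3)); dA = make_zero(A)
Enzyme.autodiff(Reverse, sum, Duplicated(A, dA))

# dA.data is the structural gradient: 1.0 on the diagonal, 2.0 above it,
# 0.0 below (the lower triangle is never touched). Averaging it with its
# transpose yields the covariant gradient, a Symmetric matrix of ones.
dA_covariant = Symmetric((dA.data .+ dA.data') ./ 2)
```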

What are Zygote / ChainRules doing here?

Basically, ChainRules-based AD systems default to these Cartesian 'structural' derivatives, unless they have rules put in place to ProjectTo the shadow values back onto the correct manifold, and they do this at the level of rules, so the projections happen throughout the whole AD process. Here are some resources:

The problems with this approach:

  1. It's pretty hard. It can be computationally expensive, rule authors never know what type of shadow values they'll be handed, and rules can end up returning all sorts of crazy stuff to downstream rules. I must say that despite the many difficulties with this approach, I kinda like it, but I'll let people who have actually had to suffer through writing rules for it talk about their experiences.
  2. You never know if you'll get a Cartesian or a covariant derivative. If someone implements a type and you differentiate it without any ProjectTo rules having been written, you'll end up getting a Cartesian derivative by default.

Now, I should mention that this approach is surprisingly robust, i.e. it's pretty smart for general abstract arrays: writing Symmetric myself works correctly without any custom rules or ProjectTo methods:

julia> struct MySymmetric{T} <: AbstractMatrix{T}
           data::Matrix{T}
       end;

julia> Base.getindex(m::MySymmetric, i::Int, j::Int) = i <= j ? m.data[i, j] : m.data[j, i];

julia> Base.size(m::MySymmetric) = size(m.data);

julia> Zygote.gradient(sum, MySymmetric(rand(3,3))) |> only
3×3 Fill{Float64}, with entries equal to 1.0

julia> Zygote.gradient(MySymmetric(rand(3,3))) do A
           s = zero(eltype(A))
           for i ∈ eachindex(A)
               s += A[i]
           end
           s
       end |> only
3×3 Matrix{Float64}:
 1.0  1.0  1.0
 1.0  1.0  1.0
 1.0  1.0  1.0

so this is mostly a potential concern for non-AbstractArray types, if I understand correctly (please chime in if you disagree with this assessment; I'm not saying this is 100% reliable, but I think it is already pretty reliable, and can be made even more reliable).
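For reference, the projection machinery referred to above is ChainRulesCore's ProjectTo; here is a small demonstration (mine, not from the issue) of it mapping the structural gradient from the earlier example back onto the space of symmetric matrices:

```julia
using ChainRulesCore, LinearAlgebra

A = Symmetric(rand(3, 3))
project = ProjectTo(A)      # a projector that remembers the primal's structure

# Projecting the structural gradient symmetrizes it, recovering the covariant answer:
project([1.0 2.0 2.0;
         0.0 1.0 2.0;
         0.0 0.0 1.0])      # a symmetric matrix with all entries ≈ 1.0
```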

What should Enzyme do?

Well, that's certainly up for discussion, but I think maybe the right thing to do would be to develop a set of functions that, given a primal and a Cartesian derivative, can calculate a covariant derivative as a post-processing pass.

E.g. we could have autodiff_cartesian which does something similar to what autodiff does currently, and then make it so that autodiff does something like

autodiff(Reverse, f, Active(x), Duplicated(y, dy_cartesian))

turning into something like

tape, df_cartesian, dx_cartesian = autodiff_cartesian(ReverseWithTangentTape, f, Active(x), Duplicated(y, dy_cartesian))
df = to_tangent_space(tape, f, df_cartesian)
dx = to_tangent_space(tape, x, dx_cartesian)
dy = to_tangent_space(tape, y, dy_cartesian)
return (df, dx, dy)

I think that in general, in order to do this conversion to the tangent space, we'd need a record of which functions were hit (the tape); a series of transformations would then be applied to dy_cartesian, based potentially on that tape and on the primal y (see https://github.com/mcabbott/OddArrays.jl for examples of cases that need both y and dy_tangent).

I think it'd be good to support a purely Cartesian mode, and then have to_tangent_space be pretty fussy, rejecting cases it doesn't understand with hard errors but useful error hints.
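To make that concrete, here is a rough sketch of such a conversion (entirely hypothetical; to_tangent_space is not an existing Enzyme function, and the methods shown are just illustrative):

```julia
using LinearAlgebra

# Hypothetical post-processing pass: convert a structural (Cartesian) shadow into a
# covariant one, with a hard but hopefully helpful error for unknown types.
to_tangent_space(x::Number, dx::Number) = dx
to_tangent_space(x::Array, dx::Array) = dx                 # plain arrays need no fixup
to_tangent_space(x::Diagonal, dx::Diagonal) = dx           # structural == covariant here

# Assumes the default uplo = :U wrapper and a zero lower triangle in the shadow,
# as produced by make_zero in the example above.
to_tangent_space(x::Symmetric, dx::Symmetric) =
    Symmetric((parent(dx) .+ parent(dx)') ./ 2)

to_tangent_space(x, dx) = error(
    "Don't know how to convert the structural derivative of a $(typeof(x)) into a " *
    "covariant one. Either keep the purely Cartesian result, or define a " *
    "to_tangent_space method for this type.")
```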

The design of the tape would be pretty hard to do right though.

wsmoses commented 6 months ago

So one question I have is whether it would suffice to handle this covariant-vs-structural distinction just at the user level. Specifically, I'd like rule authors to assume one convention (e.g. just structural) that they need to write rules for, rather than have to consider both.

If that's the case, perhaps we could have a type parameter to ReverseMode and ForwardMode signifying whether the covariant derivative is taken or not. We just added one for Holomorphic (see the dev docs on complex numbers for more detail), and already have type parameters for whether the original return is desired, so this seems like a natural place to put it.

We could then have the covariant autodiff call the structural autodiff, then do all the shims required by recursing through types.

Does this seem like a reasonable compromise between enabling the natural-result functionality for those who want it (by calling the covariant form), and preserving the mutation-aware in-place form for those who don't (since constructing the new object inherently requires allocation)?

mcabbott commented 5 months ago

I do think applying this to the final user output would be enough.

Maybe it's helpful to classify how some fixup(x, dx) function could behave:

  1. Simplest is that it just strips all custom structs, so that you get something obviously structural
  2. It could reconstruct the simplest "covariant" representations, such as Complex, Diagonal, SparseVector, SArray, but leave other or unknown ones structural. Here field access grad.a.b.c always gets you the structural component.
  3. It could further reconstruct things like Symmetric, which require less-obvious rules, and may require allocating new arrays, but leave unknown types structural. Here you cannot trust that field access gets you structural gradient components, which is a bit weird.
  4. Maximally, it could do level 3 but also error on unknown AbstractArray types, prompting you to supply the rule. Non-array structs containing arrays (e.g. Flux models) cannot always be reconstructed, as the "covariant" representation may not have the same type. Perhaps it's better not to anyway, since other behaviours attached to the struct may not make sense.
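A minimal sketch of what levels 1 and 2 of such a function might look like (hypothetical code of mine; fixup does not exist in Enzyme, and the handling of non-differentiable fields is glossed over):

```julia
using LinearAlgebra

fixup(x::Number, dx::Number; level = 1) = dx
fixup(x::Array, dx::Array; level = 1) = map((xi, dxi) -> fixup(xi, dxi; level), x, dx)
fixup(x::Char, dx::Char; level = 1) = nothing   # non-differentiable field, e.g. Symmetric's uplo

function fixup(x, dx; level = 1)
    if level >= 2 && x isa Diagonal
        # Level 2: re-wrap types whose structural and covariant forms coincide.
        return Diagonal(fixup(x.diag, dx.diag; level))
    end
    # Level 1: strip the custom struct down to a NamedTuple of recursively
    # fixed-up fields, so the result is unambiguously structural.
    fields = fieldnames(typeof(x))
    return NamedTuple{fields}(map(n -> fixup(getfield(x, n), getfield(dx, n); level), fields))
end
```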

Levels 1 & 2 could potentially be done to the internal types, widening Duplicated(x, dx) to allow mismatched types. Seems tricky, especially for mutable structs.

Any level could be done on final output. If you call gradient(Reverse, f, x) then perhaps this should be done automatically -- at least level 2 seems unobjectionable, you can still access fields, but won't be tempted to go wrong. Perhaps level 4 should be an option. (Does it belong to gradient or to Reverse, if the actual AD is unchanged?)

If you construct Duplicated(x, make_zero(x)) yourself, then perhaps you are also in charge of calling fixup(x, dx, level=4) afterwards too, if desired? autodiff(Reverse, f, Active, Duplicated(x, dx)) does not at present return dx. There are cases like autodiff(Reverse, sum, Active, Active(Symmetric(SA[1 2; 3 4.])))[1][1] |> dump which return a struct.