FluxML / Zygote.jl

https://fluxml.ai/Zygote.jl/

Adjoints for functions of specialized matrices #402

Closed sdewaele closed 3 years ago

sdewaele commented 4 years ago

I encountered an error with the adjoint of getindex of a special matrix. The issue is that the current getindex adjoint tries to assign to an element of the special matrix, which is not always allowed. It points to a more general problem/opportunity for adjoints of functions of special matrices. Using Diagonal as an example, this is the error:

using LinearAlgebra
using Zygote
using Random
rng = MersenneTwister(54754)
n = 3
A = rand(rng,n,n)

y,B_getindex = Zygote.pullback(x->getindex(x,2,1),Diagonal(A))
bA = B_getindex(1)[1] |> display
#ArgumentError: cannot set off-diagonal entry (2, 1) to a nonzero value (1)

Two possible solutions are:

Solution A: adjoint is a tuple with an unconstrained array

# Copy of getindex(D::Diagonal,...)
using LinearAlgebra:diagzero
@inline function getindex_duplicate(D::Diagonal, i::Int, j::Int)
  @boundscheck checkbounds(D, i, j)
  if i == j
      @inbounds r = D.diag[i]
  else
      r = diagzero(D, i, j)
  end
  r
end

yA,∂getindexA = Zygote.pullback(x->getindex_duplicate(x,3,3),Diagonal(A))
D̄A = ∂getindexA(1)[1]
display(D̄A)
# (diag = [0.0, 0.0, 1.0],)

Pro: Simple. Con: The adjoint does not return a Diagonal, which reduces the efficiency of the adjoint code. This can perhaps serve as a fallback while we implement solution B.

Solution B: Insert an orthogonal projection and return a special matrix

getindexB(D::Diagonal,i::Int,j::Int) = getindex(D,i,j)
Zygote.@adjoint getindexB(D::Diagonal,i::Int,j::Int) = begin
  y,∂U = Zygote.pullback(x->getindex_duplicate(x,i,j),D)
  y, function(Ȳ)
    D̄U = ∂U(Ȳ)[1]
    D̄ = Diagonal(D̄U.diag)
    return D̄, nothing, nothing
  end
end

yB,∂getindexB = Zygote.pullback(x->getindexB(x,3,3),Diagonal(A))
D̄B = ∂getindexB(1)[1]
display(D̄B)
# 3×3 Diagonal{Float64,Array{Float64,1}}:
# 0.0   ⋅    ⋅
#  ⋅   0.0   ⋅
#  ⋅    ⋅   1.0

Pro: Returns a special matrix, so the adjoint code becomes more efficient. Con: Some work to write the adjoint.

Both solutions provide the same gradients:

ftestA(x) = getindex_duplicate(Diagonal(x.^2),3,3)
ftestB(x) = getindexB(Diagonal(x.^2),3,3)
x = [0.3,-2.5,4.3]
gA = Zygote.gradient(ftestA,x)[1]
gB = Zygote.gradient(ftestB,x)[1]
display(gA==gB)
# true

Mathematical background for solution B

Suppose we have a special matrix S ∈ 𝕊. For example, for Diagonal, 𝕊 is the space of diagonal matrices. 𝕊 is a subspace of the space of regular matrices ℤ: 𝕊 ⊂ ℤ. The original function maps from 𝕊 to 𝕐:

Y = f(S) : 𝕊 → 𝕐

For example, for getindex(S,i::Int,j::Int), 𝕐 is scalar, so 𝕐 = ℝ. If we take the adjoint through the code of f, we currently get an unconstrained matrix S̄ ∈ ℤ, as shown above for getindex. To ensure that S̄ ∈ 𝕊, we insert an orthogonal projection P from ℤ to 𝕊:

S = P(Z) : ℤ → 𝕊

Note that:

1) For S ∈ 𝕊, P is the identity: P(S) = S.
2) Because P is an orthogonal projection, it is self-adjoint: P' = P. Therefore, the adjoint also maps ℤ → 𝕊.

Instead of computing the adjoint for f, we compute it for the function g = f∘P, whose adjoint is g' = P'∘f' = P∘f' by property 2. Because of property 1, g(S) = f(S) for all special matrices S ∈ 𝕊.
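
A quick numeric sanity check of these two properties, using the Symmetric projection P(A) = (A + transpose(A))/2 as an example (a sketch; the names are illustrative):

using LinearAlgebra

P(A) = (A + transpose(A))/2       # orthogonal projection onto symmetric matrices

S = Symmetric(randn(3, 3))
Z₁, Z₂ = randn(3, 3), randn(3, 3)
P(S) ≈ S                          # property 1: P is the identity on 𝕊
dot(P(Z₁), Z₂) ≈ dot(Z₁, P(Z₂))   # property 2: ⟨P(Z₁), Z₂⟩ = ⟨Z₁, P(Z₂)⟩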

This methodology applies to, among others: Diagonal, Symmetric, UpperTriangular, LowerTriangular, UniformScaling (?). Often, P is simply the constructor (Diagonal, UpperTriangular, LowerTriangular). For Symmetric, it is P(A) = (A+transpose(A))/2.
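
For concreteness, minimal sketches of P for these types could look as follows (illustrative code, not Zygote's):

using LinearAlgebra

project(dx, ::Type{<:Diagonal})        = Diagonal(diag(dx))
project(dx, ::Type{<:UpperTriangular}) = UpperTriangular(triu(dx))
project(dx, ::Type{<:LowerTriangular}) = LowerTriangular(tril(dx))
project(dx, ::Type{<:Symmetric})       = Symmetric((dx + transpose(dx))/2)
project(dx, ::Type{<:UniformScaling})  = (tr(dx)/size(dx, 1)) * I   # projection onto span{I}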

Beyond getindex

In general, this approach could be used for adjoints of functions mapping from the constrained to the unconstrained space, where the adjoint is not guaranteed to map to 𝕊, notably collect and Array(). As with getindex, specializing the current adjoint to Array will probably give you solution A for these. It should be noted that the adjoints for these functions do work now, unlike getindex, so coding these adjoints is less urgent.

Open questions

1) Is solution A or solution B preferred? I think B, using A as a fallback.
2) Do we need to restrict the current adjoint for getindex to apply only to Array instead of AbstractArray? (I think it was like this in the past.) This would have the benefit that solution A is automatically used for special matrices. But I don't know what the repercussions of this are. With the current signature of the getindex adjoint, I don't see how I can circumvent it to autodiff through the original getindex code in solution B, apart from creating a duplicate, which is of course undesirable. A minimal sketch of the narrowing is below.
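
For reference, a sketch of what the narrowing in question 2 might look like (illustrative only; Zygote's real getindex adjoint handles more index types than plain Ints):

using Zygote
using Zygote: @adjoint

@adjoint function Base.getindex(x::Array, i::Int...)
    x[i...], ȳ -> begin
        x̄ = zeros(float(eltype(x)), size(x))   # unconstrained array, as in solution A
        x̄[i...] = ȳ
        (x̄, map(_ -> nothing, i)...)
    end
end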

If there is an interest in this, I am happy to start a branch to start implementing some of this. I would appreciate guidance and help from others to get it right!

mcabbott commented 4 years ago

I don't get the error you see, at least on Zygote v0.3.4. (I hit #393 on master today.)

EDIT: On Zygote v0.4.1 now, I do get the error. And my examples below also no longer work. (And looking at the source of 0.3.4, I don't see how they did work!)

But here are some more compact examples of the behaviour on version 0.3.4. Gradients of Diagonal input regard the off-diagonal elements as strongly zero, rather than simply treating this as a compact way of specifying a full matrix. I guess ideally these would return another Diagonal object, not a NamedTuple:

julia> gradient(x -> x[1,1] + 10x[1,2], Diagonal(ones(2)))[1]
(diag = [1.0, 0.0],)

julia> gradient(x -> x[1,1] + 10x[1,2], Diagonal(ones(2,2)))[1]
(diag = [1.0, 0.0],)

julia> gradient(x -> x[1,1] + 10x[1,2], collect(Diagonal(ones(2))))[1]
2×2 Array{Float64,2}:
 1.0  10.0
 0.0   0.0

Gradients with Symmetric also seem to take this structure seriously, and again return a namedtuple:

julia> gradient(x -> x[1,1] + 10x[1,2] + 100x[2,1], Symmetric(ones(2,2)))[1]
(data = [1.0 110.0; 0.0 0.0], uplo = nothing)

julia> gradient(x -> x[1,1] + 10x[1,2] + 100x[2,1], collect(Symmetric(ones(2,2))))[1]
2×2 Array{Float64,2}:
   1.0  10.0
 100.0   0.0

This is what you call solution B I think. Which seems correct to me.

For sparse matrices the current behaviour is more like what you call solution A: it's happy to return a dense matrix. Discussion here: https://github.com/FluxML/Zygote.jl/issues/163#issuecomment-486343665

Unfortunately the above examples are broken by #256, which fixes other getindex problems, so perhaps that needs to change:

∇getindex (generic function with 1 method)

julia> gradient(x -> x[1,1] + 10x[1,2], Diagonal(ones(2,2)))[1]
ERROR: ArgumentError: cannot set off-diagonal entry (1, 2) to a nonzero value (10.0)

julia> gradient(x -> x[1,1], Diagonal(ones(2,2)))[1]
2×2 Diagonal{Float64,Array{Float64,1}}:
 1.0   ⋅ 
  ⋅   0.0
sdewaele commented 4 years ago

Thanks for your comments!

As you note, in the current code the examples do not work. All of them would work if the scope of the getindex adjoint were reduced to Array instead of AbstractArray, because then solution A (returning tuples with regular arrays) becomes active.

All the examples you mention are solution A, with the exception of the very last one, gradient(x -> x[1,1], Diagonal(ones(2,2)))[1] (because diagonal indices of a Diagonal can be set). It does result in the correct gradients. However, it does not exploit the computational efficiency that can be gained by using the original special matrix, e.g. Diagonal. As I mentioned in my first post, in many cases simply converting the object to the special matrix yields the better solution B. However, this is not always the case, e.g. for Symmetric: the Symmetric constructor is not an orthogonal projection onto the space of symmetric matrices, as the quick check below shows.
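
(A sketch; small integer matrices so the Frobenius inner products are easy to verify by hand.)

using LinearAlgebra

C(X) = Matrix(Symmetric(X))   # the constructor, viewed as a linear map on matrices
P(X) = (X + transpose(X))/2   # the orthogonal projection onto symmetric matrices

A = [0 1; 2 0]; B = [0 3; 5 0]
dot(C(A), B) == dot(A, C(B))  # false (8 vs 9): the constructor is not self-adjoint
dot(P(A), B) == dot(A, P(B))  # true (12 vs 12): the projection is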

mcabbott commented 4 years ago

Ah I see, with @adjoint getindex(x::Array, inds...) you get the behaviour of my examples, because the getindex on the wrapper later calls this on the parent. Perhaps that is more elegant, although I would have thought it should be a bit wider than Array (sparse, cu?). And somehow it ought to return something array-like.

See what you think of this: https://github.com/FluxML/Zygote.jl/pull/256/commits/73fdfa9db68941509937d76dcddc98aecdf421dd . With this approach, _zero makes an unstructured array to write into, which is then projected.

In fact gradient(x -> x[1,1] + 10x[1,2], Diagonal(ones(2))) returns a Diagonal wrapping a SparseVector. Just by chance that’s what similar does here. It’s possible that _zero should more often return a sparse matrix.

sdewaele commented 4 years ago

Yes, this does the trick; thanks for writing this! I did not expect that this call to similar would have the desired effect.

I propose to rename _makelike to reflect that it has to be an orthogonal projection as I described in the math section above, e.g. by calling it _project, and to state in the documentation that it is an orthogonal projection. Not all functions that make dx like x will result in the correct gradients.

Although I have not had the time to test, I think that the code for Symmetric should be slightly different:

_project(dx, x::Symmetric) = Symmetric((dx+transpose(dx))/2, Symbol(x.uplo))

or, equivalently, and more similar to the current code:

_project(dx, x::Symmetric) = Symmetric(_symmetric_back(dx, x.uplo)/2, Symbol(x.uplo))

Note the added /2 compared to the definition in your commit, needed to make it an orthogonal projection. Similarly for Hermitian, except using adjoint instead of transpose. It would be good to confirm this using ngradient.
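
Zygote's test utilities provide ngradient; for experimenting outside the test suite, a minimal central-difference stand-in could look like this (a sketch, not Zygote's implementation):

# Central-difference gradient of a scalar-valued f at x, one entry at a time.
function ngradient_sketch(f, x::AbstractArray; h = 1e-6)
    g = zero(float.(x))
    for i in eachindex(x)
        xp = float.(x); xp[i] += h
        xm = float.(x); xm[i] -= h
        g[i] = (f(xp) - f(xm)) / (2h)
    end
    g
end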

mcabbott commented 4 years ago

Yeah, it's not quite right yet. I think the examples above should work (and are included as tests), but they would be broken by what you suggest. However, I don't like this 6:

julia> gradient(x -> Symmetric([x 3x; 5x 7x])[1,1], 10)
(1,)

julia> gradient(x -> Symmetric([x 3x; 5x 7x])[1,2], 10) # wrong on 73fdfa9, should be 3
(6,)
sdewaele commented 4 years ago

The issue is that the right-hand side of this test is not right:

@test gradient(x -> x[1,1] + 10x[1,2] + 100x[2,1], Symmetric(ones(2,2)))[1] == Symmetric([1 110; 0 0])

Try this on 73fdfa9:

function f(A)
  As = Symmetric(A)
  y = As[1,1]+10As[1,2]+100As[2,1]
  return y
end
A = ones(2,2)
g = ngradient(f,A)[1]
gZ = Zygote.gradient(f,A)[1]
@test g ≈ gZ
#Test Failed:
#  Expression: g ≈ gZ
#   Evaluated: [1.0 110.0; 0.0 0.0] ≈ [1.0 220.0; 0.0 0.0]

BTW, I was wrong about the /2 addition to the _symmetric_back code; the only correct definition is:

_project(dx, x::Symmetric) = Symmetric((dx+transpose(dx))/2, Symbol(x.uplo))

With this definition, the ngradient test above passes.

I like it that the current test checks the type of the return value with isa Symmetric. The following may be a good addition to the tests; using ones(2,2) carries the risk that the input is symmetric itself and so may hide certain bugs.

rng = MersenneTwister(348302)
n = 3
A = rand(rng,n,n)
@test gradcheck(f,A)

BTW, with this out of the way, the following might be nice to add in the same spirit of returning special matrices:

for AT in [:Diagonal, :LowerTriangular, :UpperTriangular, :Symmetric]
    # @eval is needed so that $AT is interpolated into the definition
    @eval @adjoint collect(x::$AT) = collect(x), Δ -> (_project(Δ, x),)
end

or perhaps even just widening the scope of the current collect adjoint? Or is this going to break stuff?

@adjoint collect(x::AbstractArray) = collect(x), Δ -> (_project(Δ,x),)
mcabbott commented 4 years ago

OK now thinking more clearly I agree that what’s needed is the projection operator. Done!

sdewaele commented 4 years ago

Excellent! Thanks a lot! I see this is part of #256.

I agree this is done. I assume we'll close this issue once #256 is merged to master?

Two final comments (may be of interest to @MikeInnes, @sethaxen, too):

  1. Possibly many other (all?) adjoints for linear algebra functions (e.g. matrix multiplication, etc.) could benefit from applying _project to returned matrices, in that they would return specialized matrices that could accelerate the reverse pass.
  2. The current signature _project(dx,x) includes the value of x. Right now this is somewhat superfluous; the type alone, as in _project(dx,typeof(x)), would be enough. However, we could use the same mechanism for spaces with non-linear constraints, e.g. a correlation matrix. In that case the value of x would be required. A sketch of the distinction follows this list.
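
To illustrate why the value can matter, compare a linear subspace with a non-linear constraint set (hypothetical code, just to show the two signatures):

using LinearAlgebra

# Linear subspace (Diagonal): the projection only needs the type of x.
_project_bytype(dx, ::Type{<:Diagonal}) = Diagonal(diag(dx))

# Non-linear constraint (unit-norm vectors, ‖x‖ = 1): the projection onto the
# tangent space at x removes the component along x, so the value is required.
_project_at(dx::AbstractVector, x::AbstractVector) = dx - x * dot(x, dx)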
mcabbott commented 4 years ago

Yes, I thought about writing _project(dx, typeof(x)) but got lazy because _zero takes x anyway, and similar(typeof(x)) doesn’t work. But for other uses that would be an improvement.

Things like collect could be a different PR. Unlike the wrappers, _project isn’t free, it broadcasts a new copy, so it’s not obvious you’d want to apply it where the existing gradient is already going to be a symmetric matrix.

sdewaele commented 4 years ago

That's right, for Symmetric a copy is created. Is this also the case for Diagonal etc., or not? I think it may still often be worthwhile because of the speed gain from using the specialized matrix.

What we can do to experiment with this is to write the adjoint for _project, which can be inserted here and there to see how it compares. I'll post it here soon.

sethaxen commented 4 years ago

I'm very interested in solutions for this problem. Solution B as proposed above sounds similar to some of the proposals we've made in extending AD to Riemannian manifolds for Manifolds.jl. These issues may be useful:

I'll hopefully have the time to read this carefully within the next few days.

sdewaele commented 4 years ago

_project isn’t free, it broadcasts a new copy

If it turns out that using _project does not provide a computational advantage, then we could also omit it from the adjoint of getindex. The resulting gradients would still be correct; the only difference is that the reverse pass no longer returns a special matrix, as in solution A above. It would be a simplification of the code. And there is still the benefit that the error I was getting with getindex no longer occurs.

I have done some small experiments with and without the projection, but could not demonstrate a computational advantage yet. Maybe a more complicated example is needed. I am sharing it here in case others want to try it. The way the projection is used below is only for experimentation; if beneficial, it should be included in the Zygote adjoints so that the user does not have to insert the projections.

using Zygote
using BenchmarkTools
using Random
using LinearAlgebra
using Test

# ℙ is the identity in the forward pass; its adjoint projects onto the space of T
ℙ(X) = X
ℙ(::Type{T},X) where {T<:AbstractArray} = X
ℙ(::Type{T},X) where {T<:Diagonal} = Diagonal(X)
ℙ(::Type{T},X) where {T<:Symmetric} = Symmetric((X+X')/2)
Zygote.@adjoint ℙ(X::T) where {T} = ℙ(X), Ȳ -> (ℙ(T,Ȳ),)

rng = MersenneTwister(2845493)
n = 20

## Test function
A = randn(rng,n,n)
D = Diagonal(randn(rng,n,n))
x = 2.0
f(x) = A*(x*D)
fℙ(x) = A*ℙ(x*D) # equal to `f`, since ℙ is the identity in the forward pass
@test f(x)≈fℙ(x)

## Adjoints
Ȳ = randn(rng,n,n)
Y,B = Zygote.pullback(f,x)
x̄ = B(Ȳ)[1]
Yℙ,Bℙ = Zygote.pullback(fℙ,x)
x̄ℙ = Bℙ(Ȳ)[1]
@test x̄≈x̄ℙ

## Timing
@btime B($Ȳ)
@btime Bℙ($Ȳ)
# result: similar timing for B and Bℙ
mcabbott commented 4 years ago

_project isn’t free, it broadcasts a new copy

If it turns out that using _project does not provide a computational advantage, then we could also omit it from the adjoint of getindex.

No: if you delete _project in #256, then gradient(A -> A[1,2], Diagonal([1,1])) and gradient(A -> A[1,2] - A[2,1], Symmetric(ones(2,2)))[1] will be nonzero.

Whether using _project elsewhere is a good idea is a different question. I'm not sure I've quite decoded your example here, but note that f(x), and hence B, depends on global variables, so @code_warntype f(x) is not happy.
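
For what it's worth, the usual fix for benchmarking is to make the captured globals const (a sketch of the pattern, using the sizes from the example above):

using LinearAlgebra, Zygote, BenchmarkTools

const Ac = randn(20, 20)        # `const` globals can be inferred inside fc
const Dc = Diagonal(randn(20))
fc(x) = Ac * (x * Dc)

Yc, Bc = Zygote.pullback(fc, 2.0)
@btime $Bc($(randn(20, 20)))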

sdewaele commented 4 years ago

If we are okay with an unconstrained adjoint, then it is correct, in the sense that it provides the correct gradients when it is part of a composition of functions. Just try a function like ftestA or ftestB from my original post.

Don't get me wrong, I think solution B, as you have implemented now, is great and mathematically the most satisfactory. I'm fine with proceeding with it!

sdewaele commented 4 years ago

As mentioned previously, I am okay with working with solution B as initially implemented. However, as discussed in https://github.com/FluxML/Zygote.jl/pull/256#issuecomment-562632961, it was decided not to use solution B.

Therefore, my proposal is to, at least for now, use solution A, implemented simply with the definition of _zero from this commit: https://github.com/FluxML/Zygote.jl/blob/73fdfa9db68941509937d76dcddc98aecdf421dd/src/lib/array.jl#L25-L26. So, to make it explicit: _project is not used until the aforementioned discussion is completed.

I'd like to show that this approach does result in the correct gradients. I will use the first of @mcabbott's examples:

No: if you delete _project in #256, then gradient(A -> A[1,2], Diagonal([1,1])) and gradient(A -> A[1,2] - A[2,1], Symmetric(ones(2,2)))[1] will be nonzero.

Indeed, the gradient computed using the definition of _zero above is nonzero:

using Zygote
using LinearAlgebra

Zygote._zero(xs::AbstractArray{<:Number}, T=float(eltype(xs))) =
    fill!(similar(xs, T, size(xs)), false)

g = gradient(A -> A[1,2], Diagonal([1,1]))[1]
display(collect(g))
# 2×2 Array{Int64,2}:
# 0  1
# 0  0

The meaning of a gradient is that its inner product with a disturbance Δx of the input gives the first-order disturbance of the output. In this case, all admissible disturbances Δx are diagonal matrices. The inner product of any diagonal matrix with the computed gradient is zero, so we correctly find that the disturbance of the output is zero. The final result is therefore identical to what we would have found after projection (solution B), which yields a zero gradient. We can further see the correctness by using this gradient in a larger function:

function f2(x)
  A = Diagonal(x)
  return A[1,2]
end
g2 = gradient(f2,[1,1])[1]
display(collect(g2))
# 2-element Array{Int64,1}:
# 0
# 0

Which is the correct result.
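
A quick numeric version of the inner-product argument above (a sketch):

using LinearAlgebra

g  = [0 1; 0 0]               # the unprojected gradient computed above
Δx = Diagonal([0.3, -1.7])    # an arbitrary admissible perturbation of the input
dot(collect(Δx), g)           # 0.0: the off-diagonal entry of g never contributes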

So, I propose to implement this small fix now. Instead of throwing an error on specialized matrices, the correct gradients are computed, which is progress!

mcabbott commented 4 years ago

To clarify: this 3-arg similar makes a full matrix, while the 2-arg one makes a Diagonal, on which off-diagonal setindex! is an error. One reason to avoid this in general is that, using StaticArrays, similar(SA[1,2,3], Int) isa MArray but similar(SA[1,2,3], Int, (3,)) isa Array. The maximal #256 therefore used this only on the special LinearAlgebra types.

But that aside, the proposal is to treat the zeros of Diagonal as not structurally zero, just accidentally zero. This is how sparse arrays are currently treated; for example, gradient(A -> A[1,3], sparse(Diagonal(ones(3)))) is happy to keep the off-diagonal 1. I don’t think that’s a great way to treat Diagonal, whose structure is in the type! But indeed the difference drops out of some calculations, as in your example.

sdewaele commented 4 years ago

Thanks! I agree that ideally this _zero should be applied only to the special LinearAlgebra types. I wonder if there is a mechanism (a trait, perhaps?) that allows us to distinguish them? In the absence of that, a loop over the special types is an option as well.
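
A hypothetical sketch of such a mechanism (the names here are invented, not part of Zygote or LinearAlgebra):

using LinearAlgebra

# A Union over the structured wrappers whose gradients we would densify:
const StructuredWrapper = Union{Diagonal, Bidiagonal, Tridiagonal, SymTridiagonal,
                                Symmetric, Hermitian, UpperTriangular, LowerTriangular}

isstructured(::AbstractArray) = false
isstructured(::StructuredWrapper) = true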

As you say, Diagonal and sparse arrays are clearly different. When computing function values, a Diagonal can only have nonzero elements on the diagonal; only for the gradient computation does solution A use the entire space of matrices.

If there is agreement that we want to use it, I can write the code for the new proposal. Just to be clear: the result of all gradient computations will be correct, not just in "some calculations"; otherwise it would not be a good idea to implement it. We can add a couple of cases to the test suite to demonstrate this. The examples that you gave earlier are useful for that.