FluxML / Zygote.jl

https://fluxml.ai/Zygote.jl/

In-Place Adjoints #120

Closed: willtebbutt closed this 5 years ago

willtebbutt commented 5 years ago

Most of the time you have to make a copy of the inputs to a primitive if that primitive modifies them. Some operations, however, are efficiently invertible, so it can be advantageous to re-compute the inputs to the adjoint on the fly during the reverse pass to keep memory requirements low; ldiv! and cholesky! are two examples.

As far as I can tell, from Zygote's perspective they're both essentially the same thing and (this is where I might be wrong) both doable at the minute. My understanding is that you just have to ensure that the state of the arguments is restored by the time the adjoint is computed, and one can do this either by caching the inputs on the forward pass (or avoiding modification altogether by working on copies), or by reconstructing them at some point in the adjoint. Consider the following simple function and its adjoint:

using Zygote
using Zygote: @adjoint
function add_vecs!(a, b)
    a .+= b # do thing that modifies inputs
    return a
end
@adjoint function add_vecs!(a, b)
    c = add_vecs!(a, b)
    return c, function(Δ)
        a .-= b # restore inputs to original state
        return (Δ, Δ)
    end
end

As far as I can tell,

y, back = Zygote.forward(add_vecs!, a, b)
back(ȳ)

produces the correct answer. The only thing that's maybe slightly iffy is that forward isn't forward!, but everything else seems to be fine. @MikeInnes am I missing anything, or are we okay to go ahead and start implementing in-place adjoints?
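For contrast, here's a rough sketch of the first option mentioned above, caching the inputs on the forward pass rather than relying on the operation being invertible. add_vecs_copy! is a hypothetical name used purely for illustration, and this builds on the definitions above:

add_vecs_copy!(a, b) = add_vecs!(a, b)

@adjoint function add_vecs_copy!(a, b)
    a_orig = copy(a)              # cache the pre-mutation state
    c = add_vecs!(a, b)
    return c, function(Δ)
        a .= a_orig               # restore the input from the cached copy
        return (Δ, Δ)
    end
end

This trades the extra memory for the copy against not needing the operation to be invertible at all.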

MikeInnes commented 5 years ago

I'm generally on board with this approach, and it's nicely compatible with what's currently in #75. But new mutating adjoints will have to go on that branch at the moment, rather than master, because we still need some extra infrastructure to support this properly.

For example, consider deriving something like

add_vecs!(a, b)
return sum(a)

Mutation gives us a kind of non-local dataflow, here from b to the function output, but in general it could be across functions or even across different tasks/threads (a gets updated on one thread and used from another). As it stands, the sensitivity for b will be nothing and so we'll get incorrect gradients.

To solve this in general you need adjoints of mutable objects to also be mutable objects (and in particular, the same object each time), so that we can update the gradient itself in-place. Rather than passing gradients around as values, we do something like da = cache[a]; da .= Δ, where cache is a global IdDict of gradients. Much of the complexity in #75 comes from trying to do this automatically without changing the @adjoint syntax.
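To make that concrete, here is a minimal sketch of what such cache-based accumulation could look like. This is not Zygote's actual internals; GRAD_CACHE and grad_buffer are hypothetical names used only for illustration:

const GRAD_CACHE = IdDict{Any,Any}()

# fetch (or lazily create) the single gradient buffer associated with x
grad_buffer(x::AbstractArray) = get!(() -> zero(x), GRAD_CACHE, x)

# an adjoint written against this convention would accumulate in place,
#     da = grad_buffer(a); da .+= Δ
# rather than returning Δ as a fresh value, so every use of a, however
# non-local, contributes to the same gradient object.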

willtebbutt commented 5 years ago

A couple of clarifications: I'm assuming the example above is meant to read

function add_sum!(a, b)
    add_vecs!(a, b)
    return sum(a)
end

When you say deriving, do you mean implementing a custom adjoint, or running it through Zygote with just a custom adjoint for add_vecs!?

MikeInnes commented 5 years ago

Right, the issue is when running through Zygote. Building on your example:

julia> gradient([1], [1]) do a, b
         add_vecs!(a, b)
         return sum(a)
       end
(Int8[1], nothing)

julia> gradient([1], [1]) do a, b
         a = add_vecs!(a, b)
         return sum(a)
       end
(Int8[1], Int8[1])

It's possible to get the right answer on top of #75; I haven't tried it yet but would be happy to work through it. It might not be convenient yet.

willtebbutt commented 5 years ago

Interesting. I have to say that I'm confused as to why

julia> Zygote.gradient(add_sum!, randn(2), randn(2))
(Int8[1, 1], nothing)

but

julia> y, back = Zygote.forward(add_sum!, randn(2), randn(2))
(-0.007131090465492429, getfield(Zygote, Symbol("##73#74")){typeof(∂(add_sum!))}(∂(add_sum!)))
julia> back(1)
([1, 1], [1, 1])

which is as it should be. Is there some implicit assumption that gradient makes which manually calling forward and back doesn't? (I'm just running on master, btw.)

See comment below.

MikeInnes commented 5 years ago

That shouldn't be happening. Zygote.refresh() maybe?
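For reference, Zygote.refresh() rebuilds Zygote's generated adjoint code; on versions from around this time, newly defined @adjoint rules were sometimes not picked up by already-compiled code until it was run:

Zygote.refresh()   # force re-generation so new @adjoint definitions take effect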

willtebbutt commented 5 years ago

Okay, as I understand it this highlights the distinction between two ways of implementing the same functionality:

function add_sum!(a, b)
    a = add_vecs!(a, b)
    return sum(a)
end
function add_sum_!(a, b)
    add_vecs!(a, b)
    return sum(a)
end

where

julia> Zygote.gradient(add_sum!, randn(2), randn(2))
(Int8[1, 1], Int8[1, 1])

julia> Zygote.gradient(add_sum_!, randn(2), randn(2))
(Int8[1, 1], nothing)

Looking at the IR:

julia> Zygote.@code_adjoint add_sum!(a, b)
2 1 ─ %1 = (Zygote._forward)(_2, Main.add_vecs!, _4, _5)::Any                                                                        │
  │   %2 = (Base.getindex)(%1, 1)::Any                                                                                               │
  │        (Base.getindex)(%1, 2)::Any                                                                                               │
3 │   %4 = (Zygote._forward)(_2, Main.sum, %2)::Any                                                                                  │
  │   %5 = (Base.getindex)(%4, 1)::Any                                                                                               │
  │        (Base.getindex)(%4, 2)::Any                                                                                               │
  └──      return %5                                                                                                                 │
  1 ─ %1 = Δ()::Any                                                                                                                  │
3 │   %2 = (@6)(%1)::Any                                                                                                             │
  │   %3 = (Zygote.gradindex)(%2, 2)::Any                                                                                            │
2 │   %4 = (@3)(%3)::Any                                                                                                             │
  │   %5 = (Zygote.gradindex)(%4, 2)::Any                                                                                            │
  │   %6 = (Zygote.gradindex)(%4, 3)::Any                                                                                            │
  │   %7 = (Zygote.tuple)(nothing, %5, %6)::Any                                                                                      │
  └──      return %7                                                                                                                 │
, [1])

whereas

julia> Zygote.@code_adjoint add_sum_!(a, b)
2 1 ─ %1 = (Zygote._forward)(_2, Main.add_vecs!, _4, _5)::Any                                                                        │
  │        (Base.getindex)(%1, 1)::Any                                                                                               │
  │        (Base.getindex)(%1, 2)::Any                                                                                               │
3 │   %4 = (Zygote._forward)(_2, Main.sum, _4)::Any                                                                                  │
  │   %5 = (Base.getindex)(%4, 1)::Any                                                                                               │
  │        (Base.getindex)(%4, 2)::Any                                                                                               │
  └──      return %5                                                                                                                 │
  1 ─ %1 = Δ()::Any                                                                                                                  │
3 │   %2 = (@6)(%1)::Any                                                                                                             │
  │   %3 = (Zygote.gradindex)(%2, 2)::Any                                                                                            │
2 │   %4 = (@3)(nothing)::Any                                                                                                        │
  │   %5 = (Zygote.gradindex)(%4, 2)::Any                                                                                            │
  │   %6 = (Zygote.gradindex)(%4, 3)::Any                                                                                            │
  │   %7 = (Zygote.accum)(%3, %5)::Any                                                                                               │
  │   %8 = (Zygote.tuple)(nothing, %7, %6)::Any                                                                                      │
  └──      return %8                                                                                                                 │
, [1])

The crucial difference is that in the first example we call (@3)(%3), since a is an explicit output of add_vecs!, whereas in the second example (@3)(nothing) is called because a isn't used as an explicit output of add_vecs!. Have I more or less understood correctly?

I guess there's also the double-counting of the sensitivity w.r.t. a in the second example: %7 = (Zygote.accum)(%3, %5)::Any. It's just a coincidence that %5 is nothing; otherwise this would have been incorrect.
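For reference, Zygote.accum is what merges multiple gradient contributions to the same variable, treating nothing as a zero. A simplified sketch of its behaviour (not the actual source; accum_sketch is a made-up name):

accum_sketch(x, ::Nothing) = x                             # nothing acts as a zero gradient
accum_sketch(::Nothing, y) = y
accum_sketch(x::AbstractArray, y::AbstractArray) = x .+ y  # real contributions are summed

So accum(%3, %5) collapses to %3 here only because %5 is nothing; if the pullback of add_vecs! also returned a real gradient for a, the two contributions would be summed and a would be double-counted.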

MikeInnes commented 5 years ago

Yes, that's exactly it: the gradient for a here is essentially coincidental, and it wouldn't be right for a different definition of add_vecs!. The fact that we call (@3)(nothing) actually has mutation in mind; normally it would be redundant, but it's expected that the pullback will pull a gradient from somewhere else (the cache) and continue if needed.