I'm generally on board with this approach, and it's nicely compatible with what's currently in #75. But new mutating adjoints will have to go on that branch at the moment, rather than master, because we still need some extra infrastructure to support this properly.
For example, consider deriving something like

```julia
add_vecs!(a, b)
return sum(a)
```
Mutation gives us a kind of non-local dataflow, here from `b` to the function output, but in general it could be across functions or even across different tasks/threads (`a` gets updated on one thread and used from another). As it stands now, the sensitivity for `b` will be `nothing`, and so we'll get incorrect gradients.
To solve this in general you need adjoints of mutable objects to also be mutable objects (and in particular, the same object each time), so that we can update the gradient itself in-place. Rather than passing gradients around as values, we do something like `da = cache[a]; da .= Δ`, where `cache` is a global `IdDict` of gradients. Much of the complexity in #75 comes from trying to do this automatically, without changing the `@adjoint` syntax.
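Concretely, a sketch of that idea might look like this (hypothetical names; `grad_cache`, `grad_for`, and `accum_grad!` are not Zygote API):

```julia
# Hypothetical sketch of the cache idea; none of these names are Zygote API.
const grad_cache = IdDict{Any,Any}()  # maps each mutable object to its gradient

# Fetch (or lazily create) the single gradient buffer associated with `a`,
# so every pullback that touches `a` updates the same array in-place.
grad_for(a::AbstractArray) = get!(() -> zero(a), grad_cache, a)

# Accumulate a sensitivity into a's gradient in-place, rather than
# passing a fresh value around.
function accum_grad!(a, Δ)
    da = grad_for(a)
    da .+= Δ
    return da
end
```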
A couple of clarifications: I'm assuming the example above is meant to read

```julia
function add_sum!(a, b)
    add_vecs!(a, b)
    return sum(a)
end
```
When you say deriving, do you mean implementing a custom adjoint, or running it through Zygote with just a custom adjoint for `add_vecs!`?
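(For reference, `add_vecs!` itself never appears in this thread; a definition consistent with the examples, together with a naive custom adjoint, might look like the following. Both are assumptions, not code from the discussion.)

```julia
using Zygote: @adjoint

# Assumed definition; the thread never shows add_vecs! itself.
add_vecs!(a::AbstractVector, b::AbstractVector) = (a .+= b; a)

# A naive custom adjoint: the output is a .+ b, so both input
# sensitivities are just the output sensitivity Δ.
@adjoint add_vecs!(a, b) = add_vecs!(a, b), Δ -> (Δ, Δ)
```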
Right, the issue is when running through Zygote. Building on your example:
```julia
julia> gradient([1], [1]) do a, b
           add_vecs!(a, b)
           return sum(a)
       end
(Int8[1], nothing)

julia> gradient([1], [1]) do a, b
           a = add_vecs!(a, b)
           return sum(a)
       end
(Int8[1], Int8[1])
```
It's possible to get the right answer on top of #75; I haven't tried it yet but would be happy to work through it. It might not be convenient yet.
Interesting. I have to say that I'm confused as to why

```julia
julia> Zygote.gradient(add_sum!, randn(2), randn(2))
(Int8[1, 1], nothing)
```

but

```julia
julia> y, back = Zygote.forward(add_sum!, randn(2), randn(2))
(-0.007131090465492429, getfield(Zygote, Symbol("##73#74")){typeof(∂(add_sum!))}(∂(add_sum!)))

julia> back(1)
([1, 1], [1, 1])
```

as it should be. Is there some implicit assumption that `gradient` is making that manually calling `forward` and `back` doesn't? (I'm just running on master btw)
See comment below.
That shouldn't be happening. `Zygote.refresh()` maybe?
Okay, as I understand it this highlights the distinction between two ways of implementing the same functionality:

```julia
function add_sum!(a, b)
    a = add_vecs!(a, b)
    return sum(a)
end

function add_sum_!(a, b)
    add_vecs!(a, b)
    return sum(a)
end
```

where

```julia
julia> Zygote.gradient(add_sum!, randn(2), randn(2))
(Int8[1, 1], Int8[1, 1])

julia> Zygote.gradient(add_sum_!, randn(2), randn(2))
(Int8[1, 1], nothing)
```
Looking at the IR:
```julia
julia> Zygote.@code_adjoint add_sum!(a, b)
2 1 ─ %1 = (Zygote._forward)(_2, Main.add_vecs!, _4, _5)::Any │
  │   %2 = (Base.getindex)(%1, 1)::Any │
  │        (Base.getindex)(%1, 2)::Any │
3 │   %4 = (Zygote._forward)(_2, Main.sum, %2)::Any │
  │   %5 = (Base.getindex)(%4, 1)::Any │
  │        (Base.getindex)(%4, 2)::Any │
  └── return %5 │
1 ─ %1 = Δ()::Any │
3 │   %2 = (@6)(%1)::Any │
  │   %3 = (Zygote.gradindex)(%2, 2)::Any │
2 │   %4 = (@3)(%3)::Any │
  │   %5 = (Zygote.gradindex)(%4, 2)::Any │
  │   %6 = (Zygote.gradindex)(%4, 3)::Any │
  │   %7 = (Zygote.tuple)(nothing, %5, %6)::Any │
  └── return %7 │
, [1])
```
whereas
```julia
julia> Zygote.@code_adjoint add_sum_!(a, b)
2 1 ─ %1 = (Zygote._forward)(_2, Main.add_vecs!, _4, _5)::Any │
  │        (Base.getindex)(%1, 1)::Any │
  │        (Base.getindex)(%1, 2)::Any │
3 │   %4 = (Zygote._forward)(_2, Main.sum, _4)::Any │
  │   %5 = (Base.getindex)(%4, 1)::Any │
  │        (Base.getindex)(%4, 2)::Any │
  └── return %5 │
1 ─ %1 = Δ()::Any │
3 │   %2 = (@6)(%1)::Any │
  │   %3 = (Zygote.gradindex)(%2, 2)::Any │
2 │   %4 = (@3)(nothing)::Any │
  │   %5 = (Zygote.gradindex)(%4, 2)::Any │
  │   %6 = (Zygote.gradindex)(%4, 3)::Any │
  │   %7 = (Zygote.accum)(%3, %5)::Any │
  │   %8 = (Zygote.tuple)(nothing, %7, %6)::Any │
  └── return %8 │
, [1])
```
The crucial difference being that in the first example we call `(@3)(%3)`, since `a` is an explicit output of `add_vecs!`, whereas in the second example `(@3)(nothing)` is called, as `a` isn't an explicit output of `add_vecs!`. Have I more or less understood correctly?
I guess there's also the double-counting of the sensitivity w.r.t. `a` in the 2nd example: `%7 = (Zygote.accum)(%3, %5)::Any`, and it's just a coincidence that `%5` is `nothing` -- otherwise this would have been incorrect.
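To make the difference concrete, here is a hand-written model of the two generated backward passes (hypothetical helper names; `back_sum` and `back_vecs` stand for the pullbacks of `sum` and `add_vecs!` recorded on the forward pass, each returning input sensitivities):

```julia
# Local model of Zygote.accum: combine two sensitivities, treating
# `nothing` as a zero gradient.
accum(x, y) = x === nothing ? y : y === nothing ? x : x .+ y

# add_sum!: `a = add_vecs!(a, b)`, so sum's input is add_vecs!'s output
# and the pullback of add_vecs! receives the real sensitivity: (@3)(%3).
function back_add_sum(Δ, back_sum, back_vecs)
    dsum = back_sum(Δ)        # %3: sensitivity of add_vecs!'s output
    da, db = back_vecs(dsum)  # (@3)(%3)
    return (da, db)
end

# add_sum_!: add_vecs!'s return value is discarded, so its pullback is
# called with `nothing` -- (@3)(nothing) -- and b's gradient is lost.
function back_add_sum_(Δ, back_sum, back_vecs)
    dsum = back_sum(Δ)           # sum was applied to the argument a directly
    da, db = back_vecs(nothing)  # (@3)(nothing): both come back as nothing
    return (accum(dsum, da), db) # %7: a's gradient survives only via dsum
end
```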
Yes, that's exactly it: the gradient for `a` here is essentially coincidental, and it wouldn't be right for a different definition of `add_vecs!`. The fact that we call `(@3)(nothing)` actually has mutation in mind; normally it would be redundant, but it's expected that the pullback will pull a gradient from somewhere else (the cache) and continue if needed.
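A toy version of that pattern might look like the following (hypothetical sketch, not Zygote's actual machinery; `grad_cache` is the global `IdDict` idea from earlier in the thread):

```julia
using Zygote: @adjoint

const grad_cache = IdDict{Any,Any}()  # as in the earlier sketch

# If the incoming sensitivity is `nothing`, the pullback falls back to
# whatever gradient `a` has accumulated elsewhere and continues from there.
@adjoint function add_vecs!(a, b)
    add_vecs!(a, b), function (Δ)
        da = Δ === nothing ? get(grad_cache, a, nothing) : Δ
        return (da, da)
    end
end
```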
Most of the time you have to make a copy of the inputs to a primitive if said primitive modifies them. There are some operations, however, which are efficiently invertible, and for which it may therefore be advantageous to re-compute the inputs to the adjoint on the fly in the reverse pass, for the sake of keeping memory requirements low; `ldiv!` and `cholesky!`, to name but two.

As far as I can tell, from Zygote's perspective they're both essentially the same thing and, this is where I might be wrong, they're both doable at the minute. My understanding is that you really just have to ensure that the state of the arguments is restored when the adjoint method is computed, and one can do this either by caching the inputs on the forward pass / avoiding modifying them by working on copies of them, or by reconstructing them at some point in the adjoint. Consider the following simple function and its adjoint:
As far as I can tell, this produces the correct answer. The only thing that's maybe slightly iffy is that `forward` isn't `forward!`, but everything else seems to be fine. @MikeInnes am I missing anything, or are we okay to go ahead and start implementing in-place adjoints?