FluxML / Zygote.jl

https://fluxml.ai/Zygote.jl/

Unnecessary gradients #658

Open · cossio opened this issue 4 years ago

cossio commented 4 years ago
using Zygote

# Gradient reflection, per https://fluxml.ai/Zygote.jl/latest/adjoints/#Gradient-Reflection-1:
# isderiving() returns true only when called from code Zygote is differentiating.
isderiving() = false
@adjoint isderiving() = true, _ -> nothing

g(x) = (println("g: ", isderiving()); x)
h(x) = (println("h: ", isderiving()); 2x)
f(x, y) = g(x) + h(y)

gradient(1.0) do x
    f(x, 2.0)
end

Executing this code shows that both g and h are run through Zygote's differentiation:

g: true
h: true

But here we only need to differentiate g, not h: the gradient is taken with respect to x, and h only ever receives the constant 2.0.

Maybe related to https://github.com/FluxML/Zygote.jl/issues/621, https://github.com/FluxML/Zygote.jl/issues/571.

MikeInnes commented 4 years ago

I don't think differentiating h is unnecessary here, since the output of the code actually depends on whether h gets differentiated. I take the point that it's redundant if you remove the print statement (which is #571).

Zygote should be better at optimising away unnecessary work, but by definition it's only unnecessary if you can't see its effects.
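As a rough, print-free sketch of that redundancy (the names g0, h0, f0 are mine, not from the issue): with the prints gone there is nothing observable left, so differentiating h0 is pure overhead.

using Zygote

g0(x) = x
h0(x) = 2x
f0(x, y) = g0(x) + h0(y)

# h0 is still run through differentiation, but its gradient with respect to
# the constant 2.0 is never used
gradient(x -> f0(x, 2.0), 1.0)  # (1.0,)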

cossio commented 4 years ago

But why should we care about preserving the side effects of adjoint code? I mean, this is the original program:

isderiving() = false
g(x) = (println("g: ", isderiving()); x)
h(x) = (println("h: ", isderiving()); 2x)
f(x,y) = g(x) + h(y)

I think we should only care about preserving the semantics of this original program. And, ignoring the "hacky" isderiving(), that's satisfied here whether h is differentiated or not.

MikeInnes commented 4 years ago

But why should we care about preserving the side effects of adjoint code?

For one thing, side effects in the adjoint code often mirror those in the primal, e.g. when you're mutating a dictionary. Ignoring those cases wouldn't just remove a print statement, it'd get you incorrect gradients. That's the issue discussed in #571.
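A hedged sketch of that kind of case (record! and the seen Dict are my illustration, not from #571): the primal mutates a Dict, so the forward pass of a custom adjoint has to mirror that mutation, and its pullback carries the gradient bookkeeping; those side effects can't simply be discarded.

using Zygote

const seen = Dict{Int,Float64}()

# primal: record the input, then square it
record!(x) = (seen[length(seen) + 1] = x; x^2)

@adjoint function record!(x)
    seen[length(seen) + 1] = x      # mirrors the primal's mutation
    return x^2, dy -> (2x * dy,)
end

gradient(record!, 3.0)  # (6.0,); seen now also holds the recorded value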

Aside from that, we could potentially warn people that Zygote might replace y, _ = pullback(f, x) with y = f(x) when it doesn't need the gradient, making functions like zygote_isderiving incorrect (and you'd have to assume code using it produces nonsense).

But this wouldn't help us in the short term. The main problem right now is that we don't do dead code elimination, not that we can't. I'm not sure adding this guarantee would actually open significantly more optimisation opportunities.
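For concreteness, the rewrite under discussion would look roughly like this (a hypothetical transformation, reusing h and isderiving from the example at the top; Zygote does not do this today):

# the differentiated call, whose gradient part is never used...
y, _ = Zygote.pullback(h, 2.0)  # prints "h: true"

# ...could in principle be replaced by a plain primal call:
y = h(2.0)                      # prints "h: false"

# which is exactly what would make an isderiving-style check unreliable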

cossio commented 4 years ago

Is there a situation where pullback(f, x) has side effects that differ from the side effects of f(x)?

(Again, ignoring the "pathological" isderiving.)

MikeInnes commented 4 years ago

pullback(f, x) usually doesn't, although I couldn't rule out the possibility (it's common to compute f(x) differently inside the pullback for efficiency reasons – e.g. splitting broadcasts rather than fusing them – so this could apply to side effects as well in principle).
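A hedged sketch of that pattern (scaled_exp is my name, not an existing rule): the primal fuses the broadcast, while a hand-written adjoint computes it in unfused pieces so the pullback can reuse the intermediate exp.(x).

using Zygote

scaled_exp(x, w) = exp.(x) .* w      # primal: a single fused broadcast

@adjoint function scaled_exp(x, w)
    ex = exp.(x)                     # unfused intermediate, reused by the pullback
    return ex .* w, dy -> (dy .* w .* ex, dy .* ex)
end

# gradients with respect to x and w
gradient((x, w) -> sum(scaled_exp(x, w)), [0.0, 1.0], [2.0, 3.0])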

The main point is that the pullback function itself has side effects in general, so it needs to be called even when the output of back(dy) is not used.
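A loose illustration of that point (the concrete code is mine): with Zygote's implicit-parameter interface, the gradient for W only appears because the pullback is actually run; the accumulation happens inside the backward pass itself.

using Zygote

W = [1.0 2.0; 3.0 4.0]
loss() = sum(W * [1.0, 1.0])

y, back = Zygote.pullback(loss, Zygote.Params([W]))
gs = back(1.0)   # running the pullback is what fills in the gradients
gs[W]            # 2×2 matrix of ones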

cossio commented 4 years ago

Can't we say that it's the responsibility of the user (or whoever writes an explicit adjoint rule) to make sure that pullback(f,x) and f(x) are interchangeable?

Maybe that's a strong restriction, but I don't see why it would be.

And if we rely on this assumption, then Zygote can decide not to call pullback(f,x) if it's not necessary to compute a gradient.

MikeInnes commented 4 years ago

We could; that's the option I addressed in detail in my previous comment.

I think there's some confusion over the behaviour / side effects of y, b = pullback(f, x) versus the side effects of the pullback itself, b(dy). The former is what you're discussing, but the latter is what actually prevents us from doing simple DCE, since those side effects are needed to compute gradients. Are you clear on that distinction?
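A small sketch of the two stages, reusing isderiving from the top of the issue (the println is mine):

using Zygote

# stage 1: build the pullback. Side effects here come from running the
# (differentiated) forward pass, e.g. the print below.
y, b = Zygote.pullback(x -> (println("forward, isderiving = ", isderiving()); x^2), 3.0)

# stage 2: run the pullback itself. In general this stage has its own side
# effects (e.g. accumulating gradients for shared parameters), and those are
# what block naive dead code elimination.
dx = b(1.0)   # (6.0,)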