FluxML / Zygote.jl


Zygote differentiates through dropgrad even though the result is discarded #621

Open cossio opened 4 years ago

cossio commented 4 years ago
using Zygote

# Gradient reflection (https://fluxml.ai/Zygote.jl/latest/adjoints/#Gradient-Reflection-1):
# isderiving() is false normally, but its adjoint makes it return true inside a gradient call.
isderiving() = false
Zygote.@adjoint isderiving() = true, _ -> nothing

f(x) = (@show isderiving(); sum(x))
g(x) = Zygote.dropgrad(f(x))  # the gradient of f(x) is dropped
g'(randn())  # prints isderiving() = true   (!)

Here g'(randn()) returns nothing, which is the correct answer, because we discard the gradient of f. However, the call prints isderiving() = true, indicating that Zygote still differentiates through f: its adjoint is built and evaluated even though the result is thrown away. This is wasted work, since that branch of the gradient is discarded anyway.

There should be a way for the gradient computation to completely bypass the branch through f.
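For reference, a minimal sketch of the mechanism that would allow this: a wrapper whose adjoint only evaluates the closure for its value and returns a pullback producing no gradients, so Zygote never differentiates through the closure's body. The name nograd below is hypothetical (not part of Zygote's API); it is essentially what the ignore helper from the PR in the next comment provides.

using Zygote
using Zygote: @adjoint

# Hypothetical no-gradient wrapper: outside differentiation it just calls f();
# inside a gradient call, its adjoint evaluates f() once (without differentiating
# through it) and returns a pullback that produces no gradients.
nograd(f) = f()
@adjoint nograd(f) = f(), _ -> nothing

With such a wrapper, g(x) = nograd(() -> f(x)) would skip the adjoint of f entirely during the gradient call.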

cossio commented 4 years ago

Once https://github.com/FluxML/Zygote.jl/pull/623 is merged, we will be able to write:

using Zygote; using Zygote: @adjoint, ignore
isderiving() = false  # https://fluxml.ai/Zygote.jl/latest/adjoints/#Gradient-Reflection-1
@adjoint isderiving() = true, _ -> nothing
f(x) = (@show isderiving(); sum(x))
g(x) = ignore(() -> f(x))
g'(randn())  # isderiving() = false

So with ignore the gradient computation through f is genuinely pruned: f is evaluated once for its value, but its adjoint is never built.
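As a small usage sketch (assuming ignore is available as in the PR above), the ignored term behaves like a constant and contributes no gradient, while the rest of the expression is differentiated as usual:

using Zygote
using Zygote: ignore

loss(x) = sum(abs2, x) + ignore(() -> sum(x))  # the ignored term is treated as a constant
Zygote.gradient(loss, [1.0, 2.0, 3.0])  # ([2.0, 4.0, 6.0],): only sum(abs2, x) contributes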