cossio opened this issue 4 years ago
Once https://github.com/FluxML/Zygote.jl/pull/623 is merged, we will be able to write:
```julia
using Zygote; using Zygote: @adjoint, ignore
isderiving() = false # https://fluxml.ai/Zygote.jl/latest/adjoints/#Gradient-Reflection-1
@adjoint isderiving() = true, _ -> nothing
f(x) = (@show isderiving(); sum(x))
g(x) = ignore(() -> f(x))
g'(randn()) # isderiving() = false
```
So with `ignore` we really prune the computation of the gradient of `f`.
Here `g'(randn())` returns `nothing`, which is the correct answer because we are discarding the gradient of `f`. However, this prints `isderiving() = true`, indicating that the adjoint computation of `f` is still being done. This is wasteful, since this branch will be discarded anyway. There should be a way for the gradient computation to completely bypass the branch through `f`.
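A minimal sketch of why the thunk form matters, assuming a Zygote version that still provides `@adjoint` and `Zygote.pullback`. The helpers `skipgrad` and `stopgrad` and the probe `f` are hypothetical, written only for this illustration: `skipgrad` mirrors the thunk-taking design of `ignore`, while `stopgrad` drops the gradient of a value the caller has already computed. Because this `f` returns 1 only when `isderiving()` is true, the primal value returned by `pullback` reveals whether `f` was traced.

```julia
using Zygote
using Zygote: @adjoint

# Gradient reflection, as in the issue's snippet: true only when Zygote
# reaches this call while it is building a pullback.
isderiving() = false
@adjoint isderiving() = true, _ -> nothing

# Hypothetical probe: the primal value reports whether `f` was traced.
f(x) = isderiving() ? one(x) : zero(x)

# Hypothetical thunk-based stop-gradient, mirroring how `ignore` works:
# the closure is called directly, so Zygote never traces its body.
skipgrad(thunk) = thunk()
@adjoint skipgrad(thunk) = skipgrad(thunk), _ -> nothing

# Hypothetical value-based stop-gradient: its argument is computed by the
# caller, which is already being traced, so `f` is differentiated anyway.
stopgrad(v) = v
@adjoint stopgrad(v) = stopgrad(v), _ -> nothing

g_thunk(x) = skipgrad(() -> f(x))
g_value(x) = stopgrad(f(x))

# The primal outputs show whether `f` ran under AD.
y_thunk, _ = Zygote.pullback(g_thunk, 0.3)
y_value, _ = Zygote.pullback(g_value, 0.3)

println("thunk: traced = ", y_thunk == 1)
println("value: traced = ", y_value == 1)
# Both versions discard the gradient with respect to x.
println(gradient(g_thunk, 0.3), " ", gradient(g_value, 0.3))
```

Under these assumptions the thunk version should report `traced = false` and the value version `traced = true`, even though both gradients are `nothing`. A stop-gradient that only receives an already-computed value cannot undo the tracing that produced it, since Zygote builds the pullback of `f` during the forward pass.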