dfdx / Yota.jl

Reverse-mode automatic differentiation in Julia
MIT License

Incorrect gradient when diff path goes through a `for` loop #126

Open dfdx opened 1 year ago

dfdx commented 1 year ago

From https://github.com/FluxML/NNlib.jl/pull/434#issuecomment-1235674312

```julia
using Yota
using ChainRulesCore
import Yota: trace, GradCtx   # assumes these internals are importable from Yota

function prod2(xs::Vector)
    p = one(eltype(xs))
    for x in xs
        p = p * x
        p == 0 && break  # exit early once you know the answer
    end
    p
end

# eltype carries no derivative information, so mark it non-differentiable
ChainRulesCore.@non_differentiable eltype(::Any)

function main()
    x = rand(3)
    Yota.grad(prod2, x)                       # gradient reported as incorrect
    _, tape = trace(prod2, x; ctx=GradCtx())  # record the tape for inspection
end
```
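
To make the discrepancy concrete, one can compare Yota's result against the analytic gradient of the plain product, `∂ prod(x)/∂xᵢ = prod(x)/xᵢ`. The sketch below is not from the original report; it assumes `Yota.grad` returns `(value, gradients)` with the gradient w.r.t. `x` as the second element of the gradient tuple, and the hypothetical helper name `check_grad` is introduced here only for illustration.

```julia
# Hypothetical check, not part of the original issue.
function check_grad()
    x = rand(3)
    val, g = Yota.grad(prod2, x)
    analytic = prod(x) ./ x   # gradient of the full product (no early exit taken)
    @show g[2] analytic       # the two vectors should match for non-zero x
end
```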