JuliaDiff / ChainRulesCore.jl

AD-backend agnostic system defining custom forward and reverse mode rules. This is the lightweight core that lets you define rules for the functions in your packages without depending on any particular AD system.

thunk still runs for non Flux.params which leads to unnecessary computation #558

Closed ziyiyin97 closed 5 months ago

ziyiyin97 commented 2 years ago

Hello! Here is a minimal example using Flux and ChainRulesCore:

using Flux, ChainRulesCore
import ChainRulesCore.rrule

function f(a::Float32, b::Float32)
    return a * b
end

function rrule(::typeof(f), a::Float32, b::Float32)
    println("rrule is called")
    y = f(a,b)
    function pullback(Δy)
        da = @thunk(∇a(a,b,Δy))
        db = @thunk(∇b(a,b,Δy))
        return (NoTangent(), da, db)
    end
    return y, pullback
end

function ∇a(a,b,Δy)
    println("∇a is called")
    return b * Δy
end

function ∇b(a,b,Δy)
    println("∇b is called")
    return a * Δy
end

a = 1f0
b = 2f0

ga = gradient(()->f(a,b), Flux.params(a))

This defines a custom function f (the product of two scalars) and its rrule via ChainRulesCore. In the last line I compute the gradient with respect to a only, as ga = gradient(()->f(a,b), Flux.params(a)), so I expected only ∇a to be called. However, the log shows that both ∇a and ∇b are called:

rrule is called
∇a is called
∇b is called

Any idea why? This could be a problem when f is a complicated function: calling ∇b is unnecessary here, and wasteful if it is time-consuming. Thanks for any help!

nickrobinson251 commented 2 years ago

Flux.jl uses Zygote.jl, and Zygote.jl doesn't yet utilise ChainRulesCore.jl's Thunks; see https://github.com/FluxML/Zygote.jl/blob/9602c6b2038879034c2de14d1f4aa251d99c6ea4/src/compiler/chainrules.jl#L104

There is a WIP PR to make Zygote.jl utilise Thunks here: https://github.com/FluxML/Zygote.jl/pull/966
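
To see the laziness that ChainRulesCore itself provides, independent of any AD backend, here is a minimal sketch that calls the rrule and its pullback by hand. It assumes only that ChainRulesCore is installed. The `calls` vector and the name `f_pullback` are illustrative and not part of the original report, and the rule is written for generic arguments rather than `Float32` only.

```julia
using ChainRulesCore

# Illustrative log of which gradient computations actually ran.
calls = String[]

f(a, b) = a * b

function ChainRulesCore.rrule(::typeof(f), a, b)
    y = f(a, b)
    function f_pullback(Δy)
        # Each @thunk defers its body until unthunk is called on it.
        da = @thunk begin
            push!(calls, "∇a")
            b * Δy
        end
        db = @thunk begin
            push!(calls, "∇b")
            a * Δy
        end
        return (NoTangent(), da, db)
    end
    return y, f_pullback
end

y, pb = rrule(f, 1f0, 2f0)
_, da, db = pb(1f0)     # calling the pullback alone computes nothing

ga = unthunk(da)        # forces only the gradient with respect to a
```

After this runs, `calls` contains only "∇a", because `db` was never unthunked. A backend that forces every tangent returned by the pullback would run both bodies, which matches the log in the report.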

ziyiyin97 commented 2 years ago

Thanks for your quick reply. Looking forward to the PR being merged.

oxinabox commented 5 months ago

This is a Zygote problem, not a ChainRulesCore (CRC) problem.