EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License

Allow custom rule for constant arg/ret in rev mode #1371

Closed wsmoses closed 4 months ago

codecov-commenter commented 6 months ago

Codecov Report

Attention: Patch coverage is 91.17647%, with 9 lines in your changes missing coverage. Please review.

Project coverage is 70.55%. Comparing base (724b9bc) to head (73a3fc5).

| Files | Patch % | Lines |
|---|---|---|
| src/rules/customrules.jl | 90.21% | 9 Missing :warning: |

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main    #1371      +/-   ##
==========================================
- Coverage   75.40%   70.55%   -4.86%
==========================================
  Files          36       36
  Lines       10671    10276     -395
==========================================
- Hits         8047     7250     -797
- Misses       2624     3026     +402
```

michel2323 commented 6 months ago

This works with arrays. I've now checked with KA using CPU() and CUDABackend(). Oddly, the rule is triggered with CUDABackend() but not with CPU().

Also, with CUDABackend() (but not with CPU()) I get lots of these warnings:

┌ Warning: Type does not have a definite number of fields
│   T = Tuple{Vararg{Union{UInt64, String}}}
└ @ Enzyme /disk/mschanen/julia_depot/packages/GPUCompiler/U36Ed/src/utils.jl:59

MWE is here (local Enzyme_jll with Enzyme build is needed)

This only prints:

Forward synchronize for Const{CUDABackend}
Reverse synchronize for Const{CUDABackend}

Instead of

Forward synchronize for Const{CPU}
Reverse synchronize for Const{CPU}
Forward synchronize for Const{CUDABackend}
Reverse synchronize for Const{CUDABackend}

michel2323 commented 6 months ago

Could it be that Enzyme makes decisions based on the body of synchronize(backend::T) where T? I think that for CUDABackend() it goes through KA.synchronize(::CUDABackend) = CUDA.synchronize(), whereas KA.synchronize(::CPU) = nothing is probably removed as dead code before a rule is applied.

@vchuravy Didn't we observe Enzyme hitting weird stuff with the KA kernel rules? I somehow wonder whether some stage of Enzyme doesn't go through the whole function that has a rule defined.

michel2323 commented 6 months ago

Narrowed it down.

using Enzyme
using EnzymeCore
using EnzymeCore.EnzymeRules

struct MyConst end
struct MyConst2
    v::Vector{Float64}
end
MyConst2() = MyConst2(zeros(5))

bar(x::MyConst)::Nothing = nothing
function bar(x::MyConst2)
    x.v .*= 2.0
    nothing
end

function foo(myconst)
    bar(myconst)
    return nothing
end

function EnzymeRules.augmented_primal(
    config::Config,
    func::Const{typeof(bar)},
    ::Type{Const{Nothing}},
    myconst::T
) where T <: EnzymeCore.Annotation
    println("bar aug_fwd rule $(typeof(myconst))")
    return AugmentedReturn{Nothing, Nothing, Any}(
        nothing, nothing, (nothing)
    )
end

function EnzymeRules.reverse(
    config::Config,
    func::Const{typeof(bar)},
    ::Type{Const{Nothing}},
    tape,
    myconst
)
    println("bar rev rule $(typeof(myconst))")
    return (nothing,)
end

function driver(myconst)
    println("Running $(typeof(myconst()))")
    Enzyme.autodiff(
        ReverseWithPrimal, foo, Const(myconst())
    )
end

# Doesn't trigger rules above
driver(MyConst)
# Triggers rules above
driver(MyConst2)

This outputs

Running MyConst
Running MyConst2
bar aug_fwd rule Const{MyConst2}
bar rev rule Const{MyConst2}

when the rule should also be applied in the case of MyConst.
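
One difference between the two cases can be checked in plain Julia, without Enzyme. The sketch below reuses the type and method definitions from the MWE and inspects them with Base reflection only. It shows that MyConst is a field-less singleton type whose bar method has an empty body, while MyConst2 carries data that bar mutates. Whether this is what makes Enzyme skip the rule is an assumption in line with the dead-code guess earlier in the thread, not something confirmed here.

```julia
# Same definitions as in the MWE; Enzyme is not needed for this check.
struct MyConst end
struct MyConst2
    v::Vector{Float64}
end

bar(x::MyConst)::Nothing = nothing
function bar(x::MyConst2)
    x.v .*= 2.0
    nothing
end

# MyConst has no fields, so it is a zero-size singleton type;
# MyConst2 holds a vector and is not.
@show Base.issingletontype(MyConst)
@show Base.issingletontype(MyConst2)
@show sizeof(MyConst)

# bar(::MyConst2) has an observable side effect: it doubles the vector in place.
c = MyConst2(ones(5))
bar(c)
@show c.v

# Print the optimized typed IR of the no-op method to see how little is left
# for a later compiler stage to attach a rule to.
println(Base.code_typed(bar, (MyConst,)))
```

If the call to bar(::MyConst) is eliminated before Enzyme looks for custom rules, that would match the output above, where only the MyConst2 rule fires.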

michel2323 commented 6 months ago

As per our chat, this should all be resolved and ready to merge.