EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License

Enzyme requests reverse rule on Const returns #1380

Closed: dominic-chang closed this issue 6 months ago

dominic-chang commented 6 months ago

I'm trying to define custom autodiff rules on some special functions. Defining any custom autodiff rule, however, causes Enzyme to request a rule for Const returns. Here's an MWE:

using Enzyme
import Enzyme.EnzymeRules: augmented_primal, reverse
using Enzyme.EnzymeRules

function augmented_primal(config::ConfigWidth{1}, func::Const{typeof(sqrt)}, ::Type{<:Active}, x::Active)
    println("In custom augmented primal rule.")
    if needs_primal(config)
        primal = func.val(x.val)
    else
        primal = nothing
    end

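    # overwritten(config) has one entry per argument, function first, so [2] refers to x;
    # if x may be overwritten before the reverse pass, save a copy on the tape
    if overwritten(config)[2]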
        tape = copy(x.val)
    else
        tape = nothing
    end

    return AugmentedReturn(primal, nothing, tape)
end

function reverse(config::ConfigWidth{1}, ::Const{typeof(sqrt)}, dret::Active, tape, x::Active) 
    println("In custom reverse rule.")
    xval = overwritten(config)[2] ? tape : x.val
    dx = inv(2*sqrt(xval))' * dret.val
    return (dx, )
end

The following error then occurs when a Const return is requested in reverse mode (here via autodiff(Enzyme.Reverse, sqrt, Const, Active(0.5))):

In custom augmented primal rule.
ERROR: Enzyme execution failed.
Enzyme: No custom reverse rule was applicable for Tuple{ConfigWidth{1, false, false, (false, false)}, Const{typeof(sqrt)}, Type{Const{Float64}}, Nothing, Active{Float64}}

Stacktrace:
 [1] throwerr(cstr::Cstring)
   @ Enzyme.Compiler ~/.julia/packages/Enzyme/l4FS0/src/compiler.jl:1289
 [2] macro expansion
   @ ~/.julia/packages/Enzyme/l4FS0/src/compiler.jl:5378 [inlined]
 [3] enzyme_call
   @ ~/.julia/packages/Enzyme/l4FS0/src/compiler.jl:5056 [inlined]
 [4] CombinedAdjointThunk
   @ ~/.julia/packages/Enzyme/l4FS0/src/compiler.jl:4998 [inlined]
 [5] autodiff(::ReverseMode{false, FFIABI}, f::Const{typeof(sqrt)}, ::Type{Const}, args::Active{Float64})
   @ Enzyme ~/.julia/packages/Enzyme/l4FS0/src/Enzyme.jl:215
 [6] autodiff(::ReverseMode{false, FFIABI}, ::typeof(sqrt), ::Type, ::Active{Float64})
   @ Enzyme ~/.julia/packages/Enzyme/l4FS0/src/Enzyme.jl:224
 [7] top-level scope
   @ REPL[10]:1
wsmoses commented 6 months ago

Because a function could mutate things in place, a constant return is still well defined (gradients can flow back through mutated arguments). You also need to define a rule for the Const return case (which, for this function, would do nothing since it is read-only).
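To illustrate why a Const return can still carry derivative information, here is a minimal sketch, assuming a recent Enzyme.jl with autodiff, Reverse, and Duplicated; the function square_into! is hypothetical. Its result is written into a mutable argument and the return value itself is nothing:

```julia
using Enzyme

# Hypothetical example: the result is stored in y[1]; the return value is nothing.
function square_into!(y, x)
    y[1] = x[1]^2
    return nothing
end

x = [3.0]; dx = [0.0]   # input and its shadow, which accumulates the gradient w.r.t. x
y = [0.0]; dy = [1.0]   # output buffer and its seed: gradient 1 on y[1]

# The return is annotated Const, but the gradient still flows back through y.
autodiff(Reverse, square_into!, Const, Duplicated(y, dy), Duplicated(x, dx))

println(dx[1])  # 6.0, i.e. 2 * x[1]
```

A read-only function like sqrt has no such side channel, which is why its Const-return reverse rule has nothing to propagate.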

dominic-chang commented 6 months ago

Sorry, I forgot to mention that I did try doing that, but still received the same error.

function reverse(config, ::Const{typeof(sqrt)}, dret::Const, tape, x::Active) 
    println("In custom reverse rule.")
    return (zero(x.val), )
end
autodiff(Enzyme.Reverse, sqrt, Const, Active(0.5))
ERROR: Enzyme execution failed.
Enzyme: No custom augmented_primal rule was applicable for Tuple{ConfigWidth{1, false, false, (false, false)}, Const{typeof(sqrt)}, Type{Const{Float64}}, Active{Float64}}
wsmoses commented 6 months ago

dret should be annotated Type{<:Const}, not Const. An actual value isn't passed unless the return is Active.
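The distinction is ordinary Julia dispatch: an argument annotated ::Const only matches instances of Const, while Enzyme passes the type object itself in the dret position (the failing signature above shows Type{Const{Float64}} there). A self-contained sketch using hypothetical stand-in types MyConst and MyActive, which are not Enzyme's own definitions:

```julia
# Hypothetical stand-ins for Enzyme's annotation types; NOT Enzyme's definitions.
struct MyConst{T}
    val::T
end
struct MyActive{T}
    val::T
end

# Matches when an Active *value* is passed (it carries the incoming gradient).
dret_kind(dret::MyActive) = :active_value
# Matches when the *type* MyConst{T} itself is passed, analogous to Type{<:Const}.
dret_kind(dret::Type{<:MyConst}) = :const_type

# Annotating with the instance type only matches MyConst values, not the type object.
dret_kind_wrong(dret::MyConst) = :const_value

println(dret_kind(MyActive(1.0)))                       # active_value
println(dret_kind(MyConst{Float64}))                    # const_type
println(applicable(dret_kind_wrong, MyConst{Float64}))  # false
```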

dominic-chang commented 6 months ago

I receive the same error with this method signature

function reverse(config::ConfigWidth{1}, ::Const{typeof(sqrt)}, dret::Type{<:Const}, tape, x::Active) 
    println("In custom reverse rule.")
    xval = overwritten(config)[2] ? tape : x.val
    return (zero(x.val), )
end
wsmoses commented 6 months ago

You also need the corresponding augmented primal rule (the error above says no applicable augmented_primal rule was found).

dominic-chang commented 6 months ago

Sorry, you're right. I missed that. This worked 😀
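Putting the thread together: the Const-return case needs its own augmented_primal and reverse methods. Below is a minimal sketch written against the ConfigWidth-based EnzymeRules API used in this thread (newer Enzyme releases may have renamed these config types). Since sqrt is read-only, no tape is stored and the reverse pass returns a zero gradient. These methods complement, rather than replace, the Active-return rules from the original MWE.

```julia
using Enzyme
import Enzyme.EnzymeRules: augmented_primal, reverse
using Enzyme.EnzymeRules

# Forward sweep for a Const return: compute the primal only if Enzyme asks for it.
function augmented_primal(config::ConfigWidth{1}, func::Const{typeof(sqrt)},
                          ::Type{<:Const}, x::Active)
    primal = needs_primal(config) ? func.val(x.val) : nothing
    # sqrt is read-only and no gradient is propagated back, so no shadow or tape is needed.
    return AugmentedReturn(primal, nothing, nothing)
end

# Reverse sweep for a Const return: dret is the type Const{T}, not a value.
function reverse(config::ConfigWidth{1}, ::Const{typeof(sqrt)},
                 dret::Type{<:Const}, tape, x::Active)
    # A constant output contributes no gradient to x.
    return (zero(x.val),)
end

autodiff(Reverse, sqrt, Const, Active(0.5))
```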