EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License
439 stars 62 forks source link

Help writing complex rules #744

Closed sethaxen closed 1 year ago

sethaxen commented 1 year ago

While writing #739, I ran into some difficulties defining rules for functions with complex inputs and outputs. Here's a simple example:

# Return twice the complex input.
function foo(x::Complex)
    return 2 * x
end

# Augmented-primal half of the custom rule for `foo` (Duplicated complex argument).
# Evaluates `foo` on the primal value and hands back only the pieces that `config`
# asks for; nothing is saved on the tape.
function EnzymeRules.augmented_primal(
    config::EnzymeRules.ConfigWidth{1},
    func::Const{typeof(foo)},
    ::Type{<:Duplicated},
    x::Duplicated{<:Complex},
)
    println("In custom augmented primal rule.")
    # Evaluate the wrapped function once; both optional outputs are derived from it.
    r = func.val(x.val)
    # `nothing` stands in for any output the configuration does not request.
    primal = EnzymeRules.needs_primal(config) ? r : nothing
    # When a shadow is requested it starts out as a zero of the result's type.
    shadow = EnzymeRules.needs_shadow(config) ? zero(r) : nothing
    # Third slot is the tape; this rule stores nothing for the reverse pass.
    return EnzymeRules.AugmentedReturn(primal, shadow, nothing)
end

# Reverse half of the custom rule for `foo` (Duplicated complex argument).
# It only announces that it ran and returns an empty tuple; none of its
# arguments are read.
function EnzymeRules.reverse(
    config::EnzymeRules.ConfigWidth{1},
    func::Const{typeof(foo)},
    dret::Duplicated{<:Complex},
    tape,
    y::Duplicated{<:Complex},
)
    println(stdout, "In custom reverse rule.")
    return tuple()
end

When I execute this rule, I get the following stacktrace:

julia> autodiff(Reverse, foo, Active, Active(1.0+3im))
ERROR: AssertionError: value_type(normalV) == value_type(orig)
Stacktrace:
  [1] enzyme_custom_common_rev(forward::Bool, B::Ptr{LLVM.API.LLVMOpaqueBuilder}, OrigCI::Ptr{LLVM.API.LLVMOpaqueValue}, gutils::Ptr{Nothing}, normalR::Ptr{Ptr{LLVM.API.LLVMOpaqueValue}}, shadowR::Ptr{Ptr{LLVM.API.LLVMOpaqueValue}}, tape::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/zSGqM/src/compiler.jl:4038
  [2] enzyme_custom_augfwd(B::Ptr{LLVM.API.LLVMOpaqueBuilder}, OrigCI::Ptr{LLVM.API.LLVMOpaqueValue}, gutils::Ptr{Nothing}, normalR::Ptr{Ptr{LLVM.API.LLVMOpaqueValue}}, shadowR::Ptr{Ptr{LLVM.API.LLVMOpaqueValue}}, tapeR::Ptr{Ptr{LLVM.API.LLVMOpaqueValue}})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/zSGqM/src/compiler.jl:4104
  [3] EnzymeCreatePrimalAndGradient(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{Enzyme.API.CDIFFE_TYPE}, TA::Enzyme.TypeAnalysis, returnValue::Bool, dretUsed::Bool, mode::Enzyme.API.CDerivativeMode, width::Int64, additionalArg::Ptr{Nothing}, forceAnonymousTape::Bool, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{Bool}, augmented::Ptr{Nothing}, atomicAdd::Bool)
    @ Enzyme.API ~/.julia/packages/Enzyme/zSGqM/src/api.jl:124
  [4] enzyme!(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::Tuple{Bool, Bool}, returnPrimal::Bool, jlrules::Vector{String}, expectedTapeType::Type)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/zSGqM/src/compiler.jl:6698
  [5] codegen(output::Symbol, job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, ctx::LLVM.ThreadSafeContext, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/zSGqM/src/compiler.jl:7939
  [6] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, ctx::Nothing, postopt::Bool)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/zSGqM/src/compiler.jl:8452
  [7] _thunk
    @ ~/.julia/packages/Enzyme/zSGqM/src/compiler.jl:8449 [inlined]
  [8] cached_compilation
    @ ~/.julia/packages/Enzyme/zSGqM/src/compiler.jl:8487 [inlined]
  [9] #s286#173
    @ ~/.julia/packages/Enzyme/zSGqM/src/compiler.jl:8545 [inlined]
 [10] var"#s286#173"(FA::Any, A::Any, TT::Any, Mode::Any, ModifiedBetween::Any, width::Any, ReturnPrimal::Any, ShadowInit::Any, World::Any, ::Any, ::Any, ::Any, ::Any, tt::Any, ::Any, ::Any, ::Any, ::Any, ::Any)
    @ Enzyme.Compiler ./none:0
 [11] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any})
    @ Core ./boot.jl:602
 [12] thunk
    @ ~/.julia/packages/Enzyme/zSGqM/src/compiler.jl:8504 [inlined]
 [13] autodiff(#unused#::EnzymeCore.ReverseMode{false}, f::Const{typeof(foo)}, #unused#::Type{Active}, args::Active{ComplexF64})
    @ Enzyme ~/.julia/packages/Enzyme/zSGqM/src/Enzyme.jl:199
 [14] autodiff(::EnzymeCore.ReverseMode{false}, ::typeof(foo), ::Type, ::Active{ComplexF64})
    @ Enzyme ~/.julia/packages/Enzyme/zSGqM/src/Enzyme.jl:214
 [15] top-level scope
    @ REPL[18]:1

I was surprised by two things. First, Enzyme seems to insist on using Duplicated annotations for complex scalars: if I specify Active for the inputs, as in the autodiff call above, they are replaced with Duplicated. Second, if I set shadow = nothing, Enzyme complains that it expects the shadow to be a ComplexF64, but if I make it a ComplexF64, I get the error shown above. How can I repair the above rules to work?

wsmoses commented 1 year ago

This will be fixed by https://github.com/EnzymeAD/Enzyme.jl/pull/754 but will require a jll bump.

wsmoses commented 1 year ago

Fixed by https://github.com/EnzymeAD/Enzyme.jl/pull/754

sethaxen commented 1 year ago

Thanks! For completeness, this now works:

julia> using Enzyme

julia> foo(x::Complex) = 2x;  # returns twice its complex argument

julia> function EnzymeRules.augmented_primal(
           config::EnzymeRules.ConfigWidth{1},
           func::Const{typeof(foo)},
           ::Type{<:Active},
           x::Active{<:Complex},
       )
           # Augmented-primal half of the custom rule for `foo`, here for an
           # Active complex argument and an Active return annotation.
           println("In custom augmented primal rule.")
           # Evaluate the wrapped function on the primal input.
           r = func.val(x.val)
           # Return the primal result only when the configuration requests it.
           if EnzymeRules.needs_primal(config)
               primal = r
           else
               primal = nothing
           end
           # Likewise for the shadow; when requested it is a zero of the same
           # type as the result.
           if EnzymeRules.needs_shadow(config)
               shadow = zero(r)
           else
               shadow = nothing
           end
           # Nothing is saved for the reverse pass.
           tape = nothing
           return EnzymeRules.AugmentedReturn(primal, shadow, tape)
       end

julia> function EnzymeRules.reverse(
           config::EnzymeRules.ConfigWidth{1},
           func::Const{typeof(foo)},
           dret::Active{<:Complex},
           tape,
           y::Active{<:Complex},
       )
           # Reverse half of the custom rule for `foo` with Active complex
           # annotations. `config`, `func`, `tape` and `y` are not read here.
           println("In custom reverse rule.")
           # One-element tuple: the incoming `dret.val` scaled by 2
           # (`foo(x) = 2x`).
           return (2*dret.val,)
       end

julia> autodiff(Reverse, foo, Active, Active(1.0+3im))
In custom augmented primal rule.
In custom reverse rule.
((2.0 + 0.0im,),)