sethaxen closed this issue 1 year ago.
This will be fixed by https://github.com/EnzymeAD/Enzyme.jl/pull/754 but will require a jll bump.
Thanks! For completeness, this now works:
```julia
julia> using Enzyme

julia> foo(x::Complex) = 2x;

julia> function EnzymeRules.augmented_primal(
           config::EnzymeRules.ConfigWidth{1},
           func::Const{typeof(foo)},
           ::Type{<:Active},
           x::Active{<:Complex},
       )
           println("In custom augmented primal rule.")
           # Compute primal
           r = func.val(x.val)
           if EnzymeRules.needs_primal(config)
               primal = r
           else
               primal = nothing
           end
           if EnzymeRules.needs_shadow(config)
               shadow = zero(r)
           else
               shadow = nothing
           end
           tape = nothing
           return EnzymeRules.AugmentedReturn(primal, shadow, tape)
       end

julia> function EnzymeRules.reverse(
           config::EnzymeRules.ConfigWidth{1},
           func::Const{typeof(foo)},
           dret::Active{<:Complex},
           tape,
           y::Active{<:Complex},
       )
           println("In custom reverse rule.")
           return (2*dret.val,)
       end

julia> autodiff(Reverse, foo, Active, Active(1.0+3im))
In custom augmented primal rule.
In custom reverse rule.
((2.0 + 0.0im,),)
```
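As a quick sanity check on the reverse rule, which returns `2*dret.val`: `foo` is holomorphic with constant derivative 2, so a finite-difference estimate taken along either the real or the imaginary axis should come out close to 2. Because that derivative is the real constant 2, it is the same whether or not a conjugation convention is applied to it. The sketch below is not from the thread; it uses only Base Julia (no Enzyme), and the step size `h` is an illustrative assumption.

```julia
# Finite-difference check of the derivative used in the reverse rule above.
# Plain Base Julia; Enzyme is not needed for this check.
foo(x::Complex) = 2x

z = 1.0 + 3im   # same input as the autodiff call above
h = 1e-6        # illustrative step size (assumption, not from the thread)

# Step along the real axis and along the imaginary axis.
d_real = (foo(z + h) - foo(z)) / h
d_imag = (foo(z + h * im) - foo(z)) / (h * im)

# For a holomorphic function both estimates approximate f'(z) = 2.
println("real-axis estimate:      ", d_real)
println("imaginary-axis estimate: ", d_imag)
```

Both printed values should be close to `2 + 0im`, matching the factor applied to `dret.val` in the custom reverse rule.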
While writing #739, I ran into some difficulties defining rules for functions with complex inputs and outputs. Here's a simple example:
When I execute this rule, I get the following stacktrace:
I was surprised that Enzyme seems to insist on using `Duplicated` annotations for complex scalars. First, if I specify `Active` for the inputs as done above, they are replaced with a `Duplicated`. Second, if I specify `shadow=nothing`, Enzyme complains that it expects the `shadow` to be a `ComplexF64`, but if I make it a `ComplexF64`, then I see this error. How can I repair the above rules to work?