EnzymeAD / Enzyme.jl

Julia bindings for the Enzyme automatic differentiator
https://enzyme.mit.edu
MIT License

Non-Deterministic behaviour from Enzyme Reverse Mode #1298

Closed. dominic-chang closed this issue 6 months ago.

dominic-chang commented 8 months ago

A piece of my code uses Ferrari's method to solve for the roots of a quartic equation and then takes the differences between the roots. This code needs to run on the GPU of my M1 Mac, so it needs to be type-preserving. I have noticed that reverse-mode autodiff on this code sometimes returns NaN derivatives and sometimes finite numbers, seemingly at random. Here is an MWE:

using Enzyme

function _pow(z::Complex{T}, i) where {T}
    zabs = abs(z)
    zangle = angle(z)
    return (zabs^i) * (cos(zangle * i) + sin(zangle * i) * one(T)im)
end

function _pow(z::T, i) where {T<:Real}
    zabs = abs(z)
    if sign(z) < zero(T)
        return (zabs^i) * (cos(T(π) * i) + sin(T(π) * i)im)
    end
    return zabs^i + zero(T)im
end

"""
Checks if a complex number is real to some tolerance
"""
function _isreal2(num::Complex{T}) where {T}
    ren, imn = reim(num)
    ren2 = ren^2
    imn2 = imn^2
    return imn2 / (imn2 + ren2) < eps(T)
end

function Ferrari_Method(a::T, η, λ) where {T}
    a2 = a * a
    A = a2 - η - λ * λ
    A2 = A + A
    B = T(2) * (η + (λ - a)^2)
    C = -a2 * η

    P = -A * A / T(12) - C
    Q = -A / T(3) * (A * A / T(36) + zero(T)im - C) - B * B / T(8)

    Δ3 = -T(4) * P * P * P - T(27) * Q * Q
    ωp = _pow(-Q / T(2) + _pow(-Δ3 / T(108), T(0.5)) + zero(T)im, T(1 / 3))

    #C = ((-1+0im)^(2/3), (-1+0im)^(4/3), 1) .* ωp
    C = (-T(1 / 2) + T(√3 / 2)im, -T(1 / 2) - T(√3 / 2)im, one(T) + zero(T)im) .* ωp

    v = -P .* _pow.(T(3) .* C, -one(T))

    ξ0 = argmax(real, (C .+ v)) - A / T(3)
    ξ02 = ξ0 + ξ0

    predet1 = A2 + ξ02
    predet2 = (√T(2) * B) * _pow(ξ0, T(-0.5))
    det1 = _pow(-(predet1 - predet2), T(0.5))
    det2 = _pow(-(predet1 + predet2), T(0.5))

    sqrtξ02 = _pow(ξ02, T(0.5))

    r1 = (-sqrtξ02 - det1) / 2
    r2 = (-sqrtξ02 + det1) / 2
    r3 = (sqrtξ02 - det2) / 2
    r4 = (sqrtξ02 + det2) / 2

    roots = (r1, r2, r3, r4)
    if (sum(_isreal2.(roots)) == 2) && (abs(imag(roots[4])) < sqrt(eps(T)))
        roots = (roots[1], roots[4], roots[2], roots[3])
    end
    return roots
end

for _ in 1:10
    begin
        function test(ηtemp, λtemp)
            _, r2, r3, r4 = Ferrari_Method(0.5, ηtemp, λtemp)

            real((r3 - r2) * (r4 - r2))
        end

        println(autodiff(Enzyme.ReverseWithPrimal, test, Active, Active(0.25), Active(-0.14618585236136836)))
    end
end

And here is an example output from this code:

((NaN, NaN), 1.1402929364197087)
((NaN, NaN), 1.1402929364197087)
((NaN, NaN), 1.1402929364197087)
((0.6561537762449112, -1.5304304926046994), 1.1402929364197087)
((NaN, NaN), 1.1402929364197087)
((NaN, NaN), 1.1402929364197087)
((0.6561537762449112, -1.5304304926046994), 1.1402929364197087)
((0.6561537762449112, -1.5304304926046994), 1.1402929364197087)
((NaN, NaN), 1.1402929364197087)
((0.6561537762449112, -1.5304304926046994), 1.1402929364197087)
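
To decide which of the two answers above is correct, one option is to compare against finite differences. A minimal sketch, assuming central differences with a step of `h = 1e-6` are accurate enough at this point; `fd_grad` and `g` are hypothetical names, not part of Enzyme:

```julia
# Hypothetical helper: central finite-difference gradient of a scalar
# function of two real arguments. Truncation error is O(h^2).
function fd_grad(f, x, y; h = 1e-6)
    dfdx = (f(x + h, y) - f(x - h, y)) / (2h)
    dfdy = (f(x, y + h) - f(x, y - h)) / (2h)
    return (dfdx, dfdy)
end

# Sanity check on a function with a known gradient: d/dx = 2xy, d/dy = x^2.
g(x, y) = x^2 * y
gx, gy = fd_grad(g, 3.0, 2.0)   # approximately (12.0, 9.0)
```

With a copy of `test` defined at top level (outside the loop), `fd_grad(test, 0.25, -0.14618585236136836)` gives an independent estimate to compare with the gradient tuple printed by `autodiff`.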

I am using Enzyme v0.11.16 with Julia v1.10.1 on a 2021 16-inch MacBook Pro running macOS 14.3 (23D56).
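
The type-preservation requirement is why the MWE wraps constants as T(2), T(1 / 3), T(π), and so on: a bare Float64 literal silently promotes Float32 arithmetic to Float64, which (we assume) the M1 GPU back end cannot use. A small illustration in plain Julia; `half` is a hypothetical helper:

```julia
x = 0.25f0                        # Float32 input, as a GPU kernel would see it

promoted  = x * 0.5               # Float64 literal: result is promoted to Float64
preserved = x * typeof(x)(0.5)    # convert the constant first: stays Float32
czero     = zero(typeof(x)) * im  # ComplexF32, mirrors zero(T)im in the MWE

# The same idea in a generic function: T(...) follows the input's type.
half(z::T) where {T<:Real} = z * T(0.5)
```

Writing constants through T keeps one code path correct for both Float32 (GPU) and Float64 (CPU) inputs.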

wsmoses commented 8 months ago

@dchang10 I cannot reproduce this myself.

Right after using Enzyme, can you add Enzyme.API.printall!(true), pipe all output to a file, and upload the logs?

dominic-chang commented 8 months ago

I'm not sure why, but I was also unable to reproduce the behavior with the code I provided. I got similar behavior by running the loop at the end twice:

for _ in 1:10
    begin
        function test(ηtemp, λtemp)
            _, r2, r3, r4 = Ferrari_Method(0.5, ηtemp, λtemp)

            real((r3 - r2) * (r4 - r2))
        end

        println(autodiff(Enzyme.ReverseWithPrimal, test, Active, Active(0.25), Active(-0.14618585236136836)))
    end
end
for _ in 1:10
    begin
        function test(ηtemp, λtemp)
            _, r2, r3, r4 = Ferrari_Method(0.5, ηtemp, λtemp)

            real((r3 - r2) * (r4 - r2))
        end

        println(autodiff(Enzyme.ReverseWithPrimal, test, Active, Active(0.25), Active(-0.14618585236136836)))
    end
end

Here is the resulting output, with the Enzyme dump attached:

((0.6561537762449112, -1.5304304926046994), 1.1402929364197087)
((0.6561537762449112, -1.5304304926046994), 1.1402929364197087)
((0.6561537762449112, -1.5304304926046994), 1.1402929364197087)
((0.6561537762449112, -1.5304304926046994), 1.1402929364197087)
((0.6561537762449112, -1.5304304926046994), 1.1402929364197087)
((0.6561537762449112, -1.5304304926046994), 1.1402929364197087)
((0.6561537762449112, -1.5304304926046994), 1.1402929364197087)
((0.6561537762449112, -1.5304304926046994), 1.1402929364197087)
((0.6561537762449112, -1.5304304926046994), 1.1402929364197087)
((NaN, NaN), 1.1402929364197087)
((NaN, NaN), 1.1402929364197087)
((NaN, NaN), 1.1402929364197087)
((NaN, NaN), 1.1402929364197087)
((NaN, NaN), 1.1402929364197087)
((NaN, NaN), 1.1402929364197087)
((NaN, NaN), 1.1402929364197087)
((NaN, NaN), 1.1402929364197087)
((NaN, NaN), 1.1402929364197087)
((NaN, NaN), 1.1402929364197087)

out.txt

wsmoses commented 8 months ago

This one is somewhat tricky, especially because it is nondeterministic, apparently depending on compilation.

Maybe sticking this in an eval in the for loop will make it recompile each time?

Regardless, the code here is too complex to debug. It would help to remove as much computation as you can, especially calls to pow, even if that changes the results.
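
A minimal sketch of the eval suggestion, assuming the goal is a brand-new function (and therefore a fresh compilation) on every iteration. The names test_1, test_2, ... and the body x^2 + k are placeholders for the real test function. The new function is also called through @eval, because a method defined by @eval inside a running loop can be too new to call directly from that same loop:

```julia
results = Float64[]

for k in 1:3
    fname = Symbol(:test_, k)        # hypothetical names: test_1, test_2, test_3
    @eval $fname(x) = x^2 + $k       # define a brand-new function each iteration
    # Calling through @eval runs in the newest world, so the fresh method is
    # visible. In the real use case this is where the autodiff call would go.
    val = @eval $fname(3.0)
    push!(results, val)
end
```

Because each iteration produces a distinct function, nothing compiled for one iteration can be reused by the next.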

wsmoses commented 8 months ago

Sorry, I mean: it would help to simplify the code as much as you can while it still produces a NaN.

dominic-chang commented 8 months ago

I've been trying to reduce the code to something simpler. I've lost the non-deterministic behavior, but I still see some flavor of it.

Enzyme is able to differentiate the following code and returns the derivatives along with the primal:

using Enzyme
#Enzyme.Compiler.CheckNan[] = true

function _pow(z::Complex{T}, i) where {T}
    zabs = abs(z)
    zangle = angle(z)
    return (zabs^i) * (cos(zangle * i) + sin(zangle * i) * one(T)im)
end

function _pow(z::T, i) where {T<:Real}
    zabs = abs(z)
    if sign(z) < zero(T)
        return (zabs^i) * (cos(T(π) * i) + sin(T(π) * i)im)
    end
    return zabs^i + zero(T)im
end

function Ferrari_Method(η::Float64, λ)
    C::Union{Float64,NTuple{2,ComplexF64}} =  η

    ωp = _pow(-λ / 2 + _pow(-η / 108, 0.5) + 0im, 1 / 3)

    C = ((√3 / 2)*im, 1im) .* ωp

    #_pow(-(η- λ), 0.5)

    return C[1]

end

function test(ηtemp, λtemp)
    return real(Ferrari_Method(ηtemp, λtemp))
end

autodiff(Enzyme.ReverseWithPrimal, test, Active, Active(0.25), Active(-0.14618585236136836))

which produces:

((-0.13044933090281938, -0.27712000773357137), -0.0741409227434657)

Uncommenting the #_pow(-(η- λ), 0.5) line, however, results in NaNs in the output:

using Enzyme
#Enzyme.Compiler.CheckNan[] = true

function _pow(z::Complex{T}, i) where {T}
    zabs = abs(z)
    zangle = angle(z)
    return (zabs^i) * (cos(zangle * i) + sin(zangle * i) * one(T)im)
end

function _pow(z::T, i) where {T<:Real}
    zabs = abs(z)
    if sign(z) < zero(T)
        return (zabs^i) * (cos(T(π) * i) + sin(T(π) * i)im)
    end
    return zabs^i + zero(T)im
end

function Ferrari_Method(η::Float64, λ)
    C::Union{Float64,NTuple{2,ComplexF64}} =  η

    ωp = _pow(-λ / 2 + _pow(-η / 108, 0.5) + 0im, 1 / 3)

    C = ((√3 / 2)*im, 1im) .* ωp

    _pow(-(η- λ), 0.5)

    return C[1]

end

function test(ηtemp, λtemp)
    return real(Ferrari_Method(ηtemp, λtemp))
end

autodiff(Enzyme.ReverseWithPrimal, test, Active, Active(0.25), Active(-0.14618585236136836))

with output

((-0.0, NaN), NaN)

dominic-chang commented 8 months ago

I forgot to mention that inlining the _pow functions allows Enzyme to return the derivatives normally:

((-0.13044933090281938, -0.27712000773357137), -0.0741409227434657)
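
A sketch of what inlining could look like, assuming it was done with Julia's @inline annotation (the thread does not say exactly how). The bodies are unchanged from the reduced example; @inline is a hint to the compiler rather than a guarantee:

```julia
# Complex-argument method: polar form, |z|^i * (cos(i*arg z) + i*sin(i*arg z)).
@inline function _pow(z::Complex{T}, i) where {T}
    zabs = abs(z)
    zangle = angle(z)
    return (zabs^i) * (cos(zangle * i) + sin(zangle * i) * one(T)im)
end

# Real-argument method: negative inputs use arg z = π, so the result is complex.
@inline function _pow(z::T, i) where {T<:Real}
    zabs = abs(z)
    if sign(z) < zero(T)
        return (zabs^i) * (cos(T(π) * i) + sin(T(π) * i)im)
    end
    return zabs^i + zero(T)im
end
```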

dominic-chang commented 8 months ago

I'm not sure whether this is related, but while trying to produce an MWE I found an error that prevents me from differentiating operations involving the sqrt of complex numbers.

using Enzyme
Enzyme.API.printall!(true)

function test(η)
    return abs(sqrt(-η + 0im))
end
test(1.0)

autodiff(Enzyme.ReverseWithPrimal, test, Active, Active(1.0))

with stacktrace:

Stacktrace:
  [1] throwerr(cstr::Cstring)
    @ Enzyme.Compiler ~/.julia/dev/Enzyme/src/compiler.jl:1289
  [2] |
    @ ./int.jl:372 [inlined]
  [3] ldexp
    @ ./math.jl:964 [inlined]
  [4] sqrt
    @ ./complex.jl:541 [inlined]
  [5] test
    @ ~/Software/Krang.jl/examples/mwe.jl:5 [inlined]
  [6] diffejulia_test_1383wrap
    @ ~/Software/Krang.jl/examples/mwe.jl:0
  [7] macro expansion
    @ Enzyme.Compiler ~/.julia/dev/Enzyme/src/compiler.jl:5440 [inlined]
  [8] enzyme_call(::Val{false}, ::Ptr{Nothing}, ::Type{Enzyme.Compiler.CombinedAdjointThunk}, ::Type{Val{1}}, ::Val{true}, ::Type{Tuple{Active{Float64}}}, ::Type{Active{Float64}}, ::Const{typeof(test)}, ::Type{Nothing}, ::Active{Float64}, ::Float64)
    @ Enzyme.Compiler ~/.julia/dev/Enzyme/src/compiler.jl:5118
  [9] (::Enzyme.Compiler.CombinedAdjointThunk{Ptr{Nothing}, Const{typeof(test)}, Active{Float64}, Tuple{Active{Float64}}, Val{1}, Val{true}()})(::Const{typeof(test)}, ::Active{Float64}, ::Vararg{Any})
    @ Enzyme.Compiler ~/.julia/dev/Enzyme/src/compiler.jl:5000
 [10] autodiff
    @ Enzyme ~/.julia/dev/Enzyme/src/Enzyme.jl:275 [inlined]
 [11] autodiff(mode::ReverseMode{true, FFIABI, false}, f::typeof(test), ::Type{Active}, args::Active{Float64})
    @ Enzyme ~/.julia/dev/Enzyme/src/Enzyme.jl:287
 [12] top-level scope
    @ ~/Software/Krang.jl/examples/mwe.jl:9
in expression starting at /Users/dominicchang/Software/Krang.jl/examples/mwe.jl:9

Here is the attached Enzyme output: out.txt
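
For this particular MWE, one possible workaround (our suggestion, not something confirmed in the thread) is to avoid the complex square root altogether: for any complex z, |√z| = √|z|, so abs(sqrt(-η + 0im)) equals sqrt(abs(η)) for real η. The name `test_real` is hypothetical:

```julia
# Original formulation: goes through Base's complex sqrt (the failing path).
test(η) = abs(sqrt(-η + 0im))

# Hypothetical real-only rewrite using |sqrt(z)| == sqrt(|z|);
# here |-η + 0im| == |η| for real η.
test_real(η) = sqrt(abs(η))
```

This only sidesteps the error for expressions that reduce to a modulus; code that needs the full complex root still depends on the fix.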

wsmoses commented 6 months ago

@dominic-chang I have been looking into this, and somehow the original primal code creates an undefined value (which, once differentiated, can produce arbitrary results). Looking into how and why.

Simplified version:

wmoses@beast:~/git/Enzyme.jl ((HEAD detached from 43193d62)) $ cat nan.jl 
using Enzyme

Enzyme.API.printall!(true)

#Enzyme.Compiler.CheckNan[] = true

function _pow(z::T, i) where {T<:Real}
    zabs = abs(z)
    if sign(z) < zero(T)
        return (zabs^i) * (cos(T(π) * i) + sin(T(π) * i)im)
    end
    return zabs^i + zero(T)im
end

function test(n)
    wp = 1 + _pow(-n, 0.5)

    _pow(-n, 0.5)

    return real(wp)
end

test(0.25)
@show autodiff(Enzyme.ReverseWithPrimal, test, Active, Active(0.25))

LLVM:

after simplification :
; Function Attrs: mustprogress willreturn
define "enzyme_type"="{[-1]:Float@double}" double @preprocess_julia_test_1422(double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="124475843291040" "enzymejl_parmtype_ref"="0" %0) local_unnamed_addr #23 !dbg !2227 {
top:
  %1 = call {}*** @julia.get_pgcstack() #24
  %ptls_field3 = getelementptr inbounds {}**, {}*** %1, i64 2
  %2 = bitcast {}*** %ptls_field3 to i64***
  %ptls_load45 = load i64**, i64*** %2, align 8, !tbaa !21
  %3 = getelementptr inbounds i64*, i64** %ptls_load45, i64 2
  %safepoint = load i64*, i64** %3, align 8, !tbaa !25
  fence syncscope("singlethread") seq_cst
  call void @julia.safepoint(i64* %safepoint) #24, !dbg !2228
  fence syncscope("singlethread") seq_cst
  %4 = fneg double %0, !dbg !2229
  call fastcc void @julia__pow_1425(double %4) #24, !dbg !2230
  %5 = fadd double undef, 1.000000e+00, !dbg !2231
  call fastcc void @julia__pow_1425(double %4) #24, !dbg !2234
  ret double %5, !dbg !2235
}

after simplification :
; Function Attrs: mustprogress willreturn
define internal fastcc void @preprocess_julia__pow_1425(double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="124475843291040" "enzymejl_parmtype_ref"="0" %0) unnamed_addr #24 !dbg !2245 {
top:
  %1 = call {}*** @julia.get_pgcstack() #25
  %ptls_field7 = getelementptr inbounds {}**, {}*** %1, i64 2
  %2 = bitcast {}*** %ptls_field7 to i64***
  %ptls_load89 = load i64**, i64*** %2, align 8, !tbaa !21
  %3 = getelementptr inbounds i64*, i64** %ptls_load89, i64 2
  %safepoint = load i64*, i64** %3, align 8, !tbaa !25
  fence syncscope("singlethread") seq_cst
  call void @julia.safepoint(i64* %safepoint) #25, !dbg !2246
  fence syncscope("singlethread") seq_cst
  %4 = call double @llvm.fabs.f64(double %0) #25, !dbg !2247
  %5 = fcmp uge double %0, 0.000000e+00, !dbg !2249
  %6 = fcmp ule double %0, 0.000000e+00, !dbg !2252
  %7 = select i1 %6, double %0, double 1.000000e+00, !dbg !2254
  %8 = select i1 %5, double %7, double -1.000000e+00, !dbg !2254
  %9 = fcmp uge double %8, 0.000000e+00, !dbg !2255
  br i1 %9, label %L22, label %L8, !dbg !2251

common.ret:                                       ; preds = %L22, %L8
  ret void, !dbg !2256

L8:                                               ; preds = %top
  %10 = call double @julia___1444(double %4, double noundef 5.000000e-01) #25, !dbg !2257
  %11 = call double @julia_cos_1438(double noundef 0x3FF921FB54442D18) #25, !dbg !2257
  %12 = call double @julia_sin_1430(double noundef 0x3FF921FB54442D18) #25, !dbg !2257
  br label %common.ret

L22:                                              ; preds = %top
  %13 = call double @julia___1444(double %4, double noundef 5.000000e-01) #25, !dbg !2258
  br label %common.ret
}

; Function Attrs: mustprogress willreturn
define internal fastcc { double } @diffejulia__pow_1425(double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="124475843291040" "enzymejl_parmtype_ref"="0" %0) unnamed_addr #24 !dbg !2259 {
top:
  %"'de" = alloca double, align 8
  %1 = getelementptr double, double* %"'de", i64 0
  store double 0.000000e+00, double* %1, align 8
  %2 = call {}*** @julia.get_pgcstack() #25
  %ptls_field7 = getelementptr inbounds {}**, {}*** %2, i64 2
  %3 = bitcast {}*** %ptls_field7 to i64***
  %ptls_load89 = load i64**, i64*** %3, align 8, !tbaa !21, !alias.scope !2260, !noalias !2263
  %4 = getelementptr inbounds i64*, i64** %ptls_load89, i64 2
  %safepoint = load i64*, i64** %4, align 8, !tbaa !25, !alias.scope !2265, !noalias !2268
  fence syncscope("singlethread") seq_cst
  call void @julia.safepoint(i64* %safepoint) #25, !dbg !2270
  fence syncscope("singlethread") seq_cst
  %5 = fcmp uge double %0, 0.000000e+00, !dbg !2271
  %6 = fcmp ule double %0, 0.000000e+00, !dbg !2274
  %7 = select i1 %6, double %0, double 1.000000e+00, !dbg !2276
  %8 = select i1 %5, double %7, double -1.000000e+00, !dbg !2276
  %9 = fcmp uge double %8, 0.000000e+00, !dbg !2277
  br i1 %9, label %L22, label %L8, !dbg !2273

common.ret:                                       ; preds = %L22, %L8
  br label %invertcommon.ret, !dbg !2278

L8:                                               ; preds = %top
  br label %common.ret

L22:                                              ; preds = %top
  br label %common.ret

inverttop:                                        ; preds = %invertL22, %invertL8
  fence syncscope("singlethread") seq_cst
  fence syncscope("singlethread") seq_cst
  %10 = load double, double* %"'de", align 8
  %11 = insertvalue { double } undef, double %10, 0
  ret { double } %11

invertcommon.ret:                                 ; preds = %common.ret
  br i1 %9, label %invertL22, label %invertL8

invertL8:                                         ; preds = %invertcommon.ret
  br label %inverttop

invertL22:                                        ; preds = %invertcommon.ret
  br label %inverttop
}

after simplification :
; Function Attrs: mustprogress willreturn
define internal fastcc void @preprocess_julia__pow_1425.3(double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="124475843291040" "enzymejl_parmtype_ref"="0" %0) unnamed_addr #24 !dbg !2279 {
top:
  %1 = call {}*** @julia.get_pgcstack() #25
  %ptls_field7 = getelementptr inbounds {}**, {}*** %1, i64 2
  %2 = bitcast {}*** %ptls_field7 to i64***
  %ptls_load89 = load i64**, i64*** %2, align 8, !tbaa !21
  %3 = getelementptr inbounds i64*, i64** %ptls_load89, i64 2
  %safepoint = load i64*, i64** %3, align 8, !tbaa !25
  fence syncscope("singlethread") seq_cst
  call void @julia.safepoint(i64* %safepoint) #25, !dbg !2280
  fence syncscope("singlethread") seq_cst
  %4 = call double @llvm.fabs.f64(double %0) #25, !dbg !2281
  %5 = fcmp uge double %0, 0.000000e+00, !dbg !2283
  %6 = fcmp ule double %0, 0.000000e+00, !dbg !2286
  %7 = select i1 %6, double %0, double 1.000000e+00, !dbg !2288
  %8 = select i1 %5, double %7, double -1.000000e+00, !dbg !2288
  %9 = fcmp uge double %8, 0.000000e+00, !dbg !2289
  br i1 %9, label %L22, label %L8, !dbg !2285

common.ret:                                       ; preds = %L22, %L8
  ret void, !dbg !2290

L8:                                               ; preds = %top
  %10 = call double @julia___1444(double %4, double noundef 5.000000e-01) #25, !dbg !2291
  %11 = call double @julia_cos_1438(double noundef 0x3FF921FB54442D18) #25, !dbg !2291
  %12 = call double @julia_sin_1430(double noundef 0x3FF921FB54442D18) #25, !dbg !2291
  br label %common.ret

L22:                                              ; preds = %top
  %13 = call double @julia___1444(double %4, double noundef 5.000000e-01) #25, !dbg !2292
  br label %common.ret
}

; Function Attrs: mustprogress willreturn
define internal fastcc void @augmented_julia__pow_1425(double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="124475843291040" "enzymejl_parmtype_ref"="0" %0) unnamed_addr #24 !dbg !2293 {
top:
  %1 = call {}*** @julia.get_pgcstack() #25
  %ptls_field7 = getelementptr inbounds {}**, {}*** %1, i64 2
  %2 = bitcast {}*** %ptls_field7 to i64***
  %ptls_load89 = load i64**, i64*** %2, align 8, !tbaa !21, !alias.scope !2294, !noalias !2297
  %3 = getelementptr inbounds i64*, i64** %ptls_load89, i64 2
  %safepoint = load i64*, i64** %3, align 8, !tbaa !25, !alias.scope !2299, !noalias !2302
  fence syncscope("singlethread") seq_cst
  call void @julia.safepoint(i64* %safepoint) #25, !dbg !2304
  fence syncscope("singlethread") seq_cst
  %4 = fcmp uge double %0, 0.000000e+00, !dbg !2305
  %5 = fcmp ule double %0, 0.000000e+00, !dbg !2308
  %6 = select i1 %5, double %0, double 1.000000e+00, !dbg !2310
  %7 = select i1 %4, double %6, double -1.000000e+00, !dbg !2310
  %8 = fcmp uge double %7, 0.000000e+00, !dbg !2311
  br i1 %8, label %L22, label %L8, !dbg !2307

common.ret:                                       ; preds = %L22, %L8
  ret void, !dbg !2312

L8:                                               ; preds = %top
  br label %common.ret

L22:                                              ; preds = %top
  br label %common.ret
}

; Function Attrs: mustprogress willreturn
define internal fastcc { double } @diffejulia__pow_1425.4(double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="124475843291040" "enzymejl_parmtype_ref"="0" %0) unnamed_addr #24 !dbg !2313 {
top:
  %"'de" = alloca double, align 8
  %1 = getelementptr double, double* %"'de", i64 0
  store double 0.000000e+00, double* %1, align 8
  %2 = call {}*** @julia.get_pgcstack() #25
  %3 = fcmp uge double %0, 0.000000e+00, !dbg !2314
  %4 = fcmp ule double %0, 0.000000e+00, !dbg !2317
  %5 = select i1 %4, double %0, double 1.000000e+00, !dbg !2319
  %6 = select i1 %3, double %5, double -1.000000e+00, !dbg !2319
  %7 = fcmp uge double %6, 0.000000e+00, !dbg !2320
  br i1 %7, label %L22, label %L8, !dbg !2316

common.ret:                                       ; preds = %L22, %L8
  br label %invertcommon.ret, !dbg !2321

L8:                                               ; preds = %top
  br label %common.ret

L22:                                              ; preds = %top
  br label %common.ret

inverttop:                                        ; preds = %invertL22, %invertL8
  fence syncscope("singlethread") seq_cst
  fence syncscope("singlethread") seq_cst
  %8 = load double, double* %"'de", align 8
  %9 = insertvalue { double } undef, double %8, 0
  ret { double } %9

invertcommon.ret:                                 ; preds = %common.ret
  br i1 %7, label %invertL22, label %invertL8

invertL8:                                         ; preds = %invertcommon.ret
  br label %inverttop

invertL22:                                        ; preds = %invertcommon.ret
  br label %inverttop
}

; Function Attrs: mustprogress willreturn
define internal "enzyme_type"="{[-1]:Float@double}" { double, double } @diffejulia_test_1422(double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="124475843291040" "enzymejl_parmtype_ref"="0" %0, double %differeturn) local_unnamed_addr #23 !dbg !2236 {
top:
  %"'de" = alloca double, align 8
  %1 = getelementptr double, double* %"'de", i64 0
  store double 0.000000e+00, double* %1, align 8
  %"'de1" = alloca double, align 8
  %2 = getelementptr double, double* %"'de1", i64 0
  store double 0.000000e+00, double* %2, align 8
  %toreturn = alloca double, align 8
  %3 = call {}*** @julia.get_pgcstack() #25
  %ptls_field3 = getelementptr inbounds {}**, {}*** %3, i64 2
  %4 = bitcast {}*** %ptls_field3 to i64***
  %ptls_load45 = load i64**, i64*** %4, align 8, !tbaa !21, !alias.scope !2237, !noalias !2240
  %5 = getelementptr inbounds i64*, i64** %ptls_load45, i64 2
  %safepoint = load i64*, i64** %5, align 8, !tbaa !25, !alias.scope !2242, !noalias !2245
  fence syncscope("singlethread") seq_cst
  call void @julia.safepoint(i64* %safepoint) #25, !dbg !2247
  fence syncscope("singlethread") seq_cst
  %6 = fneg double %0, !dbg !2248
  call fastcc void @augmented_julia__pow_1425(double %6), !dbg !2249
  %7 = fadd double undef, 1.000000e+00, !dbg !2250
  br label %inverttop, !dbg !2253

inverttop:                                        ; preds = %top
  %8 = call fastcc { double } @diffejulia__pow_1425(double %6), !dbg !2254
  %9 = extractvalue { double } %8, 0, !dbg !2254
  %10 = load double, double* %"'de", align 8, !dbg !2254
  %11 = fadd fast double %10, %9, !dbg !2254
  store double %11, double* %"'de", align 8, !dbg !2254
  store double %7, double* %toreturn, align 8, !dbg !2253
  %12 = call fastcc { double } @diffejulia__pow_1425.4(double %6), !dbg !2249
  %13 = extractvalue { double } %12, 0, !dbg !2249
  %14 = load double, double* %"'de", align 8, !dbg !2249
  %15 = fadd fast double %14, %13, !dbg !2249
  store double %15, double* %"'de", align 8, !dbg !2249
  %16 = load double, double* %"'de", align 8, !dbg !2248
  store double 0.000000e+00, double* %"'de", align 8, !dbg !2248
  %17 = fneg fast double %16, !dbg !2248
  %18 = load double, double* %"'de1", align 8, !dbg !2248
  %19 = fadd fast double %18, %17, !dbg !2248
  store double %19, double* %"'de1", align 8, !dbg !2248
  fence syncscope("singlethread") seq_cst
  fence syncscope("singlethread") seq_cst
  %retreload = load double, double* %toreturn, align 8
  %20 = load double, double* %"'de1", align 8
  %21 = insertvalue { double, double } undef, double %retreload, 0
  %22 = insertvalue { double, double } %21, double %20, 1
  ret { double, double } %22
}

autodiff(Enzyme.ReverseWithPrimal, test, Active, Active(0.25)) = ((-0.0,), NaN)
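
For comparison, evaluating the simplified reproducer in plain Julia, without Enzyme, shows what the primal should be. _pow(-0.25, 0.5) is 0.5 * (cos(π/2) + sin(π/2)im), whose real part is only about 3e-17, so real(1 + _pow(-0.25, 0.5)) rounds to 1.0. The NaN primal appears instead to come from the preprocessed IR above, where the sum is computed as fadd double undef, 1.0:

```julia
# Real-argument method of _pow, copied from nan.jl.
function _pow(z::T, i) where {T<:Real}
    zabs = abs(z)
    if sign(z) < zero(T)
        return (zabs^i) * (cos(T(π) * i) + sin(T(π) * i)im)
    end
    return zabs^i + zero(T)im
end

# Simplified reproducer from nan.jl, without the Enzyme calls.
function test(n)
    wp = 1 + _pow(-n, 0.5)
    _pow(-n, 0.5)   # result discarded, as in the reproducer
    return real(wp)
end

primal = test(0.25)   # finite, approximately 1.0
```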

wsmoses commented 6 months ago

Should be fixed by https://github.com/EnzymeAD/Enzyme.jl/pull/1421

dominic-chang commented 6 months ago

Wow, amazing! I hope I was more helpful than not.