Closed dominic-chang closed 6 months ago
@dchang10 I cannot reproduce this myself.
Right after using Enzyme, can you add Enzyme.API.printall!(true), pipe all output to a file, and upload the logs?
I'm not sure why, but I was also unable to reproduce the behavior I saw with the code I provided. I got a similar behavior by just running the loop at the end twice.
for _ in 1:10
    begin
        function test(ηtemp, λtemp)
            _, r2, r3, r4 = Ferrari_Method(0.5, ηtemp, λtemp)
            real((r3 - r2) * (r4 - r2))
        end
        println(autodiff(Enzyme.ReverseWithPrimal, test, Active, Active(0.25), Active(-0.14618585236136836)))
    end
end
for _ in 1:10
    begin
        function test(ηtemp, λtemp)
            _, r2, r3, r4 = Ferrari_Method(0.5, ηtemp, λtemp)
            real((r3 - r2) * (r4 - r2))
        end
        println(autodiff(Enzyme.ReverseWithPrimal, test, Active, Active(0.25), Active(-0.14618585236136836)))
    end
end
Here is the resulting output with the enzyme dump:
((0.6561537762449112, -1.5304304926046994), 1.1402929364197087)
((0.6561537762449112, -1.5304304926046994), 1.1402929364197087)
((0.6561537762449112, -1.5304304926046994), 1.1402929364197087)
((0.6561537762449112, -1.5304304926046994), 1.1402929364197087)
((0.6561537762449112, -1.5304304926046994), 1.1402929364197087)
((0.6561537762449112, -1.5304304926046994), 1.1402929364197087)
((0.6561537762449112, -1.5304304926046994), 1.1402929364197087)
((0.6561537762449112, -1.5304304926046994), 1.1402929364197087)
((0.6561537762449112, -1.5304304926046994), 1.1402929364197087)
((NaN, NaN), 1.1402929364197087)
((NaN, NaN), 1.1402929364197087)
((NaN, NaN), 1.1402929364197087)
((NaN, NaN), 1.1402929364197087)
((NaN, NaN), 1.1402929364197087)
((NaN, NaN), 1.1402929364197087)
((NaN, NaN), 1.1402929364197087)
((NaN, NaN), 1.1402929364197087)
((NaN, NaN), 1.1402929364197087)
((NaN, NaN), 1.1402929364197087)
This one is somewhat tricky, especially because it is nondeterministic, though apparently tied to compilation.
Maybe sticking this in an eval in the for loop will make it recompile each time?
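A minimal sketch of the eval idea, assuming that redefining the function with @eval on every iteration is enough to invalidate the previously compiled method. The name test_eval and its body are hypothetical placeholders, not the code from this issue; with Enzyme loaded, the autodiff call would take the place of the plain call.

```julia
# Hypothetical stand-in for the function under test; in the real reproducer
# this would be the `test` function that calls Ferrari_Method.
results = Float64[]
for _ in 1:3
    # @eval runs the definition at top level, replacing the method each time,
    # which should invalidate code compiled for the previous definition.
    @eval function test_eval(x)
        return x^2 + 1.0
    end
    # invokelatest ensures the freshly defined method is the one being called.
    # With Enzyme loaded, the autodiff(...) call would go here instead.
    push!(results, Base.invokelatest(test_eval, 2.0))
end
println(results)
```

Whether this actually triggers a fresh Enzyme compilation on each iteration is an assumption that would need to be confirmed against the printall! output.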
Regardless, the code here is too complex to debug. Removing as much computation as you can, especially calls to pow, would be helpful, even if it changes the results.
Sorry, to clarify: simplifying the code as far as you can, while it still produces a NaN, would be helpful.
I've been trying to reduce the code to something simpler. I've lost the nondeterministic behavior, but still have some flavor of it.
Enzyme is able to differentiate through the following code and return a primal:
using Enzyme
#Enzyme.Compiler.CheckNan[] = true

function _pow(z::Complex{T}, i) where {T}
    zabs = abs(z)
    zangle = angle(z)
    return (zabs^i) * (cos(zangle * i) + sin(zangle * i) * one(T)im)
end

function _pow(z::T, i) where {T<:Real}
    zabs = abs(z)
    if sign(z) < zero(T)
        return (zabs^i) * (cos(T(π) * i) + sin(T(π) * i)im)
    end
    return zabs^i + zero(T)im
end

function Ferrari_Method(η::Float64, λ)
    C::Union{Float64,NTuple{2,ComplexF64}} = η
    ωp = _pow(-λ / 2 + _pow(-η / 108, 0.5) + 0im, 1 / 3)
    C = ((√3 / 2)*im, 1im) .* ωp
    #_pow(-(η- λ), 0.5)
    return C[1]
end

function test(ηtemp, λtemp)
    return real(Ferrari_Method(ηtemp, λtemp))
end

autodiff(Enzyme.ReverseWithPrimal, test, Active, Active(0.25), Active(-0.14618585236136836))
which produces:
((-0.13044933090281938, -0.27712000773357137), -0.0741409227434657)
Uncommenting the #_pow(-(η- λ), 0.5) line, however, results in NaNs as the output:
using Enzyme
#Enzyme.Compiler.CheckNan[] = true

function _pow(z::Complex{T}, i) where {T}
    zabs = abs(z)
    zangle = angle(z)
    return (zabs^i) * (cos(zangle * i) + sin(zangle * i) * one(T)im)
end

function _pow(z::T, i) where {T<:Real}
    zabs = abs(z)
    if sign(z) < zero(T)
        return (zabs^i) * (cos(T(π) * i) + sin(T(π) * i)im)
    end
    return zabs^i + zero(T)im
end

function Ferrari_Method(η::Float64, λ)
    C::Union{Float64,NTuple{2,ComplexF64}} = η
    ωp = _pow(-λ / 2 + _pow(-η / 108, 0.5) + 0im, 1 / 3)
    C = ((√3 / 2)*im, 1im) .* ωp
    _pow(-(η- λ), 0.5)
    return C[1]
end

function test(ηtemp, λtemp)
    return real(Ferrari_Method(ηtemp, λtemp))
end

autodiff(Enzyme.ReverseWithPrimal, test, Active, Active(0.25), Active(-0.14618585236136836))
with output
((-0.0, NaN), NaN)
I forgot to mention that inlining the _pow functions allows Enzyme to return derivatives normally.
((-0.13044933090281938, -0.27712000773357137), -0.0741409227434657)
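The derivative values reported for the working case can be cross-checked without Enzyme. The following is a sketch using central finite differences on a copy of the reduced reproducer above; ferrari_reduced, test_fd and central_diff are illustrative names introduced here, the step size is an arbitrary choice, and the Union annotation on C is dropped because it does not affect the plain evaluation.

```julia
# Copy of the reduced reproducer, with the imaginary unit written as `* im`.
function _pow(z::Complex{T}, i) where {T}
    zabs = abs(z)
    zangle = angle(z)
    return (zabs^i) * (cos(zangle * i) + sin(zangle * i) * im)
end

function _pow(z::T, i) where {T<:Real}
    zabs = abs(z)
    if sign(z) < zero(T)
        return (zabs^i) * (cos(T(π) * i) + sin(T(π) * i) * im)
    end
    return zabs^i + zero(T) * im
end

function ferrari_reduced(η::Float64, λ)
    ωp = _pow(-λ / 2 + _pow(-η / 108, 0.5) + 0im, 1 / 3)
    C = ((√3 / 2) * im, 1im) .* ωp
    return C[1]
end

test_fd(η, λ) = real(ferrari_reduced(η, λ))

# Central finite difference; h is an assumed step size, not a tuned value.
central_diff(f, x; h=1e-6) = (f(x + h) - f(x - h)) / (2h)

η0, λ0 = 0.25, -0.14618585236136836
dη = central_diff(η -> test_fd(η, λ0), η0)
dλ = central_diff(λ -> test_fd(η0, λ), λ0)
println(((dη, dλ), test_fd(η0, λ0)))
```

If the finite-difference estimates agree with the non-NaN Enzyme output to several digits, that supports treating the NaN results as the incorrect ones.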
I'm not sure whether this is related, but while trying to produce an MWE I hit an error that prevents me from taking the derivative of operations that involve sqrt of complex numbers.
using Enzyme
Enzyme.API.printall!(true)

function test(η)
    return abs(sqrt(-η + 0im))
end

test(1.0)
autodiff(Enzyme.ReverseWithPrimal, test, Active, Active(1.0))
with stacktrace:
Stacktrace:
[1] throwerr(cstr::Cstring)
@ Enzyme.Compiler ~/.julia/dev/Enzyme/src/compiler.jl:1289
[2] |
@ ./int.jl:372 [inlined]
[3] ldexp
@ ./math.jl:964 [inlined]
[4] sqrt
@ ./complex.jl:541 [inlined]
[5] test
@ ~/Software/Krang.jl/examples/mwe.jl:5 [inlined]
[6] diffejulia_test_1383wrap
@ ~/Software/Krang.jl/examples/mwe.jl:0
[7] macro expansion
@ Enzyme.Compiler ~/.julia/dev/Enzyme/src/compiler.jl:5440 [inlined]
[8] enzyme_call(::Val{false}, ::Ptr{Nothing}, ::Type{Enzyme.Compiler.CombinedAdjointThunk}, ::Type{Val{1}}, ::Val{true}, ::Type{Tuple{Active{Float64}}}, ::Type{Active{Float64}}, ::Const{typeof(test)}, ::Type{Nothing}, ::Active{Float64}, ::Float64)
@ Enzyme.Compiler ~/.julia/dev/Enzyme/src/compiler.jl:5118
[9] (::Enzyme.Compiler.CombinedAdjointThunk{Ptr{Nothing}, Const{typeof(test)}, Active{Float64}, Tuple{Active{Float64}}, Val{1}, Val{true}()})(::Const{typeof(test)}, ::Active{Float64}, ::Vararg{Any})
@ Enzyme.Compiler ~/.julia/dev/Enzyme/src/compiler.jl:5000
[10] autodiff
@ Enzyme ~/.julia/dev/Enzyme/src/Enzyme.jl:275 [inlined]
[11] autodiff(mode::ReverseMode{true, FFIABI, false}, f::typeof(test), ::Type{Active}, args::Active{Float64})
@ Enzyme ~/.julia/dev/Enzyme/src/Enzyme.jl:287
[12] top-level scope
@ ~/Software/Krang.jl/examples/mwe.jl:9
in expression starting at /Users/dominicchang/Software/Krang.jl/examples/mwe.jl:9
Here is the attached enzyme output. out.txt
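As a reference point for this sqrt case: for η > 0, abs(sqrt(-η + 0im)) reduces to sqrt(η), so the derivative at η = 1.0 would be expected to be 1 / (2 * sqrt(η)) = 0.5. A small sketch that checks this in plain Julia, without Enzyme (sqrt_test and expected_derivative are hypothetical names used only here):

```julia
# Same expression as in the MWE, under a different name.
sqrt_test(η) = abs(sqrt(-η + 0im))

# Closed form of the derivative, valid for η > 0.
expected_derivative(η) = 1 / (2 * sqrt(η))

# Central finite difference around η = 1.0 with an assumed step of 1e-6.
fd_estimate = (sqrt_test(1.0 + 1e-6) - sqrt_test(1.0 - 1e-6)) / 2e-6

println((sqrt_test(1.0), expected_derivative(1.0), fd_estimate))
```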
@dominic-chang I have been looking into this, and somehow the original primal code creates an undefined value (which means that differentiating it can produce arbitrary results). Looking into how/why.
Simplified version:
wmoses@beast:~/git/Enzyme.jl ((HEAD detached from 43193d62)) $ cat nan.jl
using Enzyme
Enzyme.API.printall!(true)
#Enzyme.Compiler.CheckNan[] = true

function _pow(z::T, i) where {T<:Real}
    zabs = abs(z)
    if sign(z) < zero(T)
        return (zabs^i) * (cos(T(π) * i) + sin(T(π) * i)im)
    end
    return zabs^i + zero(T)im
end

function test(n)
    wp = 1 + _pow(-n, 0.5)
    _pow(-n, 0.5)
    return real(wp)
end

test(0.25)
@show autodiff(Enzyme.ReverseWithPrimal, test, Active, Active(0.25))
LLVM:
after simplification :
; Function Attrs: mustprogress willreturn
define "enzyme_type"="{[-1]:Float@double}" double @preprocess_julia_test_1422(double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="124475843291040" "enzymejl_parmtype_ref"="0" %0) local_unnamed_addr #23 !dbg !2227 {
top:
%1 = call {}*** @julia.get_pgcstack() #24
%ptls_field3 = getelementptr inbounds {}**, {}*** %1, i64 2
%2 = bitcast {}*** %ptls_field3 to i64***
%ptls_load45 = load i64**, i64*** %2, align 8, !tbaa !21
%3 = getelementptr inbounds i64*, i64** %ptls_load45, i64 2
%safepoint = load i64*, i64** %3, align 8, !tbaa !25
fence syncscope("singlethread") seq_cst
call void @julia.safepoint(i64* %safepoint) #24, !dbg !2228
fence syncscope("singlethread") seq_cst
%4 = fneg double %0, !dbg !2229
call fastcc void @julia__pow_1425(double %4) #24, !dbg !2230
%5 = fadd double undef, 1.000000e+00, !dbg !2231
call fastcc void @julia__pow_1425(double %4) #24, !dbg !2234
ret double %5, !dbg !2235
}
after simplification :
; Function Attrs: mustprogress willreturn
define internal fastcc void @preprocess_julia__pow_1425(double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="124475843291040" "enzymejl_parmtype_ref"="0" %0) unnamed_addr #24 !dbg !2245 {
top:
%1 = call {}*** @julia.get_pgcstack() #25
%ptls_field7 = getelementptr inbounds {}**, {}*** %1, i64 2
%2 = bitcast {}*** %ptls_field7 to i64***
%ptls_load89 = load i64**, i64*** %2, align 8, !tbaa !21
%3 = getelementptr inbounds i64*, i64** %ptls_load89, i64 2
%safepoint = load i64*, i64** %3, align 8, !tbaa !25
fence syncscope("singlethread") seq_cst
call void @julia.safepoint(i64* %safepoint) #25, !dbg !2246
fence syncscope("singlethread") seq_cst
%4 = call double @llvm.fabs.f64(double %0) #25, !dbg !2247
%5 = fcmp uge double %0, 0.000000e+00, !dbg !2249
%6 = fcmp ule double %0, 0.000000e+00, !dbg !2252
%7 = select i1 %6, double %0, double 1.000000e+00, !dbg !2254
%8 = select i1 %5, double %7, double -1.000000e+00, !dbg !2254
%9 = fcmp uge double %8, 0.000000e+00, !dbg !2255
br i1 %9, label %L22, label %L8, !dbg !2251
common.ret: ; preds = %L22, %L8
ret void, !dbg !2256
L8: ; preds = %top
%10 = call double @julia___1444(double %4, double noundef 5.000000e-01) #25, !dbg !2257
%11 = call double @julia_cos_1438(double noundef 0x3FF921FB54442D18) #25, !dbg !2257
%12 = call double @julia_sin_1430(double noundef 0x3FF921FB54442D18) #25, !dbg !2257
br label %common.ret
L22: ; preds = %top
%13 = call double @julia___1444(double %4, double noundef 5.000000e-01) #25, !dbg !2258
br label %common.ret
}
; Function Attrs: mustprogress willreturn
define internal fastcc { double } @diffejulia__pow_1425(double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="124475843291040" "enzymejl_parmtype_ref"="0" %0) unnamed_addr #24 !dbg !2259 {
top:
%"'de" = alloca double, align 8
%1 = getelementptr double, double* %"'de", i64 0
store double 0.000000e+00, double* %1, align 8
%2 = call {}*** @julia.get_pgcstack() #25
%ptls_field7 = getelementptr inbounds {}**, {}*** %2, i64 2
%3 = bitcast {}*** %ptls_field7 to i64***
%ptls_load89 = load i64**, i64*** %3, align 8, !tbaa !21, !alias.scope !2260, !noalias !2263
%4 = getelementptr inbounds i64*, i64** %ptls_load89, i64 2
%safepoint = load i64*, i64** %4, align 8, !tbaa !25, !alias.scope !2265, !noalias !2268
fence syncscope("singlethread") seq_cst
call void @julia.safepoint(i64* %safepoint) #25, !dbg !2270
fence syncscope("singlethread") seq_cst
%5 = fcmp uge double %0, 0.000000e+00, !dbg !2271
%6 = fcmp ule double %0, 0.000000e+00, !dbg !2274
%7 = select i1 %6, double %0, double 1.000000e+00, !dbg !2276
%8 = select i1 %5, double %7, double -1.000000e+00, !dbg !2276
%9 = fcmp uge double %8, 0.000000e+00, !dbg !2277
br i1 %9, label %L22, label %L8, !dbg !2273
common.ret: ; preds = %L22, %L8
br label %invertcommon.ret, !dbg !2278
L8: ; preds = %top
br label %common.ret
L22: ; preds = %top
br label %common.ret
inverttop: ; preds = %invertL22, %invertL8
fence syncscope("singlethread") seq_cst
fence syncscope("singlethread") seq_cst
%10 = load double, double* %"'de", align 8
%11 = insertvalue { double } undef, double %10, 0
ret { double } %11
invertcommon.ret: ; preds = %common.ret
br i1 %9, label %invertL22, label %invertL8
invertL8: ; preds = %invertcommon.ret
br label %inverttop
invertL22: ; preds = %invertcommon.ret
br label %inverttop
}
after simplification :
; Function Attrs: mustprogress willreturn
define internal fastcc void @preprocess_julia__pow_1425.3(double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="124475843291040" "enzymejl_parmtype_ref"="0" %0) unnamed_addr #24 !dbg !2279 {
top:
%1 = call {}*** @julia.get_pgcstack() #25
%ptls_field7 = getelementptr inbounds {}**, {}*** %1, i64 2
%2 = bitcast {}*** %ptls_field7 to i64***
%ptls_load89 = load i64**, i64*** %2, align 8, !tbaa !21
%3 = getelementptr inbounds i64*, i64** %ptls_load89, i64 2
%safepoint = load i64*, i64** %3, align 8, !tbaa !25
fence syncscope("singlethread") seq_cst
call void @julia.safepoint(i64* %safepoint) #25, !dbg !2280
fence syncscope("singlethread") seq_cst
%4 = call double @llvm.fabs.f64(double %0) #25, !dbg !2281
%5 = fcmp uge double %0, 0.000000e+00, !dbg !2283
%6 = fcmp ule double %0, 0.000000e+00, !dbg !2286
%7 = select i1 %6, double %0, double 1.000000e+00, !dbg !2288
%8 = select i1 %5, double %7, double -1.000000e+00, !dbg !2288
%9 = fcmp uge double %8, 0.000000e+00, !dbg !2289
br i1 %9, label %L22, label %L8, !dbg !2285
common.ret: ; preds = %L22, %L8
ret void, !dbg !2290
L8: ; preds = %top
%10 = call double @julia___1444(double %4, double noundef 5.000000e-01) #25, !dbg !2291
%11 = call double @julia_cos_1438(double noundef 0x3FF921FB54442D18) #25, !dbg !2291
%12 = call double @julia_sin_1430(double noundef 0x3FF921FB54442D18) #25, !dbg !2291
br label %common.ret
L22: ; preds = %top
%13 = call double @julia___1444(double %4, double noundef 5.000000e-01) #25, !dbg !2292
br label %common.ret
}
; Function Attrs: mustprogress willreturn
define internal fastcc void @augmented_julia__pow_1425(double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="124475843291040" "enzymejl_parmtype_ref"="0" %0) unnamed_addr #24 !dbg !2293 {
top:
%1 = call {}*** @julia.get_pgcstack() #25
%ptls_field7 = getelementptr inbounds {}**, {}*** %1, i64 2
%2 = bitcast {}*** %ptls_field7 to i64***
%ptls_load89 = load i64**, i64*** %2, align 8, !tbaa !21, !alias.scope !2294, !noalias !2297
%3 = getelementptr inbounds i64*, i64** %ptls_load89, i64 2
%safepoint = load i64*, i64** %3, align 8, !tbaa !25, !alias.scope !2299, !noalias !2302
fence syncscope("singlethread") seq_cst
call void @julia.safepoint(i64* %safepoint) #25, !dbg !2304
fence syncscope("singlethread") seq_cst
%4 = fcmp uge double %0, 0.000000e+00, !dbg !2305
%5 = fcmp ule double %0, 0.000000e+00, !dbg !2308
%6 = select i1 %5, double %0, double 1.000000e+00, !dbg !2310
%7 = select i1 %4, double %6, double -1.000000e+00, !dbg !2310
%8 = fcmp uge double %7, 0.000000e+00, !dbg !2311
br i1 %8, label %L22, label %L8, !dbg !2307
common.ret: ; preds = %L22, %L8
ret void, !dbg !2312
L8: ; preds = %top
br label %common.ret
L22: ; preds = %top
br label %common.ret
}
; Function Attrs: mustprogress willreturn
define internal fastcc { double } @diffejulia__pow_1425.4(double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="124475843291040" "enzymejl_parmtype_ref"="0" %0) unnamed_addr #24 !dbg !2313 {
top:
%"'de" = alloca double, align 8
%1 = getelementptr double, double* %"'de", i64 0
store double 0.000000e+00, double* %1, align 8
%2 = call {}*** @julia.get_pgcstack() #25
%3 = fcmp uge double %0, 0.000000e+00, !dbg !2314
%4 = fcmp ule double %0, 0.000000e+00, !dbg !2317
%5 = select i1 %4, double %0, double 1.000000e+00, !dbg !2319
%6 = select i1 %3, double %5, double -1.000000e+00, !dbg !2319
%7 = fcmp uge double %6, 0.000000e+00, !dbg !2320
br i1 %7, label %L22, label %L8, !dbg !2316
common.ret: ; preds = %L22, %L8
br label %invertcommon.ret, !dbg !2321
L8: ; preds = %top
br label %common.ret
L22: ; preds = %top
br label %common.ret
inverttop: ; preds = %invertL22, %invertL8
fence syncscope("singlethread") seq_cst
fence syncscope("singlethread") seq_cst
%8 = load double, double* %"'de", align 8
%9 = insertvalue { double } undef, double %8, 0
ret { double } %9
invertcommon.ret: ; preds = %common.ret
br i1 %7, label %invertL22, label %invertL8
invertL8: ; preds = %invertcommon.ret
br label %inverttop
invertL22: ; preds = %invertcommon.ret
br label %inverttop
}
; Function Attrs: mustprogress willreturn
define internal "enzyme_type"="{[-1]:Float@double}" { double, double } @diffejulia_test_1422(double "enzyme_type"="{[-1]:Float@double}" "enzymejl_parmtype"="124475843291040" "enzymejl_parmtype_ref"="0" %0, double %differeturn) local_unnamed_addr #23 !dbg !2236 {
top:
%"'de" = alloca double, align 8
%1 = getelementptr double, double* %"'de", i64 0
store double 0.000000e+00, double* %1, align 8
%"'de1" = alloca double, align 8
%2 = getelementptr double, double* %"'de1", i64 0
store double 0.000000e+00, double* %2, align 8
%toreturn = alloca double, align 8
%3 = call {}*** @julia.get_pgcstack() #25
%ptls_field3 = getelementptr inbounds {}**, {}*** %3, i64 2
%4 = bitcast {}*** %ptls_field3 to i64***
%ptls_load45 = load i64**, i64*** %4, align 8, !tbaa !21, !alias.scope !2237, !noalias !2240
%5 = getelementptr inbounds i64*, i64** %ptls_load45, i64 2
%safepoint = load i64*, i64** %5, align 8, !tbaa !25, !alias.scope !2242, !noalias !2245
fence syncscope("singlethread") seq_cst
call void @julia.safepoint(i64* %safepoint) #25, !dbg !2247
fence syncscope("singlethread") seq_cst
%6 = fneg double %0, !dbg !2248
call fastcc void @augmented_julia__pow_1425(double %6), !dbg !2249
%7 = fadd double undef, 1.000000e+00, !dbg !2250
br label %inverttop, !dbg !2253
inverttop: ; preds = %top
%8 = call fastcc { double } @diffejulia__pow_1425(double %6), !dbg !2254
%9 = extractvalue { double } %8, 0, !dbg !2254
%10 = load double, double* %"'de", align 8, !dbg !2254
%11 = fadd fast double %10, %9, !dbg !2254
store double %11, double* %"'de", align 8, !dbg !2254
store double %7, double* %toreturn, align 8, !dbg !2253
%12 = call fastcc { double } @diffejulia__pow_1425.4(double %6), !dbg !2249
%13 = extractvalue { double } %12, 0, !dbg !2249
%14 = load double, double* %"'de", align 8, !dbg !2249
%15 = fadd fast double %14, %13, !dbg !2249
store double %15, double* %"'de", align 8, !dbg !2249
%16 = load double, double* %"'de", align 8, !dbg !2248
store double 0.000000e+00, double* %"'de", align 8, !dbg !2248
%17 = fneg fast double %16, !dbg !2248
%18 = load double, double* %"'de1", align 8, !dbg !2248
%19 = fadd fast double %18, %17, !dbg !2248
store double %19, double* %"'de1", align 8, !dbg !2248
fence syncscope("singlethread") seq_cst
fence syncscope("singlethread") seq_cst
%retreload = load double, double* %toreturn, align 8
%20 = load double, double* %"'de1", align 8
%21 = insertvalue { double, double } undef, double %retreload, 0
%22 = insertvalue { double, double } %21, double %20, 1
ret { double, double } %22
}
autodiff(Enzyme.ReverseWithPrimal, test, Active, Active(0.25)) = ((-0.0,), NaN)
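For comparison, the values that nan.jl would be expected to produce can be worked out in plain Julia. This is a sketch, assuming the negative-argument branch is taken (true for n = 0.25, since -n < 0): there real(wp) = 1 + sqrt(n) * cos(π/2), so the primal is 1.0 up to rounding, and the derivative is cos(π/2) / (2 * sqrt(n)), which is tiny but finite rather than NaN. The names _pow_real and test_simplified are renamed copies introduced here for illustration.

```julia
# Renamed copy of _pow from nan.jl, with the imaginary unit written as `* im`.
function _pow_real(z::T, i) where {T<:Real}
    zabs = abs(z)
    if sign(z) < zero(T)
        return (zabs^i) * (cos(T(π) * i) + sin(T(π) * i) * im)
    end
    return zabs^i + zero(T) * im
end

function test_simplified(n)
    wp = 1 + _pow_real(-n, 0.5)
    _pow_real(-n, 0.5)  # result unused, as in nan.jl
    return real(wp)
end

primal = test_simplified(0.25)

# Closed-form derivative of 1 + sqrt(n) * cos(π/2) with respect to n at n = 0.25.
analytic_derivative = cos(π / 2) / (2 * sqrt(0.25))

println((primal, analytic_derivative))
```

The `fadd double undef, 1.000000e+00` in the dump above corresponds to the `1 + _pow(-n, 0.5)` line, which is consistent with the primal coming back as NaN instead of 1.0.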
Should be fixed by https://github.com/EnzymeAD/Enzyme.jl/pull/1421
Wow amazing! Hope I was more helpful than less.
A piece of my code uses Ferrari's method to solve for the roots of a quartic equation and then takes the difference between the roots. This code needs to be GPUized on my M1 Mac, and so needs to be type preserving. I have noticed that using reverse-mode autodiff on this code sometimes results in NaNs for the derivatives and sometimes in actual numbers, in a seemingly random way. Here is an MWE:
And here is an example output from this code:
I am using Enzyme v0.11.16 with julia v1.10.1 on a 2021 16-inch MacBook Pro running macOS 14.3 (23D56).