FluxML / Zygote.jl

21st century AD
https://fluxml.ai/Zygote.jl/
1.48k stars, 213 forks

NaN in gradient of abs() on complex 0 #1472

Closed: koenvos closed this issue 12 months ago

koenvos commented 12 months ago

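```julia
using Zygote

x = [0f0 + 1im * 0f0]   # a one-element vector holding a complex zero (ComplexF32)
f(x) = sum(abs.(x))     # real-valued loss: sum of absolute values
f'(x)                   # Zygote shorthand for the gradient of f at x
```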

Result:

```
1-element Vector{ComplexF32}:
 NaN32 + NaN32*im
```
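As a rough sanity check (not part of the original report), a central finite difference shows what a symmetric estimate of the gradient looks like at this point. abs is not differentiable at 0, but by symmetry the central difference along both the real and the imaginary axis is exactly 0 there. The helper name fd_abs_grad and the step h = 1e-3 are illustrative assumptions, and packing the result as d/dre + im * d/dim assumes the convention described in Zygote's complex-differentiation docs.

```julia
# Hypothetical helper (not part of Zygote): central finite-difference estimate
# of the gradient of abs at a complex point z, returned as d/dre + im * d/dim.
function fd_abs_grad(z::Complex; h::Float64 = 1e-3)
    dre = (abs(z + h) - abs(z - h)) / (2h)            # step along the real axis
    dim = (abs(z + h * im) - abs(z - h * im)) / (2h)  # step along the imaginary axis
    return complex(dre, dim)
end

fd_abs_grad(0f0 + 0f0im)   # zero: the two one-sided slopes cancel at the kink
fd_abs_grad(3f0 + 4f0im)   # approximately 0.6 + 0.8im, i.e. z / abs(z) away from zero
```

A zero gradient at the kink is one common convention (it agrees with Julia's sign, which returns zero for a complex zero). By contrast, a closed-form quotient such as z / abs(z) evaluates to 0/0 at z = 0, which is NaN in both components.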

koenvos commented 12 months ago

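It seems to be caused by the sqrt() applied to the sum of the squared real and imaginary parts. I worked around it by adding a small offset inside the square root:

```julia
# Workaround: the 1f-12 offset keeps the sqrt argument strictly positive,
# so the derivative of sqrt stays finite when x == 0.
abs_(x::Complex) = sqrt(real(x)^2 + imag(x)^2 + 1f-12)
```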
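The diagnosis in the comment can be reproduced by hand. Writing abs(z) = sqrt(u) with u = re^2 + im^2, the chain rule gives d/dre = 2re / (2 sqrt(u)); at z = 0 the sqrt factor is Inf and the 2re factor is 0, and Inf * 0 is NaN in IEEE floating point. The sketch below uses a hypothetical helper sqrt_chain_grad whose offset keyword plays the role of the 1f-12 term; it mirrors the decomposition described in the comment and is not Zygote's internal rule for abs.

```julia
# Hypothetical helper: hand-applied chain rule for sqrt(re^2 + im^2 + offset),
# returned as d/dre + im * d/dim. Not Zygote's internal rule.
function sqrt_chain_grad(z::Complex{T}; offset::T = zero(T)) where {T<:AbstractFloat}
    u = real(z)^2 + imag(z)^2 + offset   # argument of sqrt
    dsqrt = one(T) / (2 * sqrt(u))       # d sqrt(u)/du; Inf when u == 0
    dre = dsqrt * 2 * real(z)            # Inf * 0 == NaN when z == 0 and offset == 0
    dim = dsqrt * 2 * imag(z)
    return complex(dre, dim)
end

# The workaround from the comment.
abs_(x::Complex) = sqrt(real(x)^2 + imag(x)^2 + 1f-12)

z0 = 0f0 + 0f0im
sqrt_chain_grad(z0)                   # NaN in both components, as in the report
sqrt_chain_grad(z0; offset = 1f-12)   # zero gradient: the offset keeps dsqrt finite
abs_(z0)                              # about 1f-6 rather than 0
```

The offset trades the NaN for a small bias: abs_ returns roughly 1f-6 at zero and slightly overestimates abs everywhere else, which may or may not matter depending on the loss.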