FluxML / Zygote.jl


Inconsistent values of sum(abs, _) between GPU and CPU (NaNs for zero input only on GPU) #1529

Open piever opened 1 month ago

piever commented 1 month ago

## Bug description

I've experienced the following inconsistency between GPU and CPU gradient computation for `sum(abs, _)`.

```julia
julia> using Zygote, CUDA

julia> rl, cplx = [0.0f0], [0.0f0 + 0.0f0im]
(Float32[0.0], ComplexF32[0.0f0 + 0.0f0im])

julia> l1(x) = sum(abs, x)
l1 (generic function with 1 method)

julia> Zygote.gradient(l1, rl)
(Float32[0.0],)

julia> Zygote.gradient(l1, cplx)
(ComplexF32[0.0f0 + 0.0f0im],)

julia> Zygote.gradient(l1, cu(rl))
(Float32[1.0],)

julia> Zygote.gradient(l1, cu(cplx))
(ComplexF32[NaN32 + NaN32*im],)
```

The last one is particularly problematic, as it leads to NaN values in the gradient that may be hard to understand in a more complex model.
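In the meantime, a possible user-side mitigation (my own sketch, not something discussed in this issue) is to smooth the absolute value so its derivative is defined at zero; `smooth_abs`, `l1_smooth`, and `ϵ` below are hypothetical names:

```julia
# Sketch of a smoothed L1 penalty that sidesteps the 0/0 at the origin.
# `smooth_abs` and `ϵ` are made-up names for illustration; the smoothing
# slightly changes the value of the loss near zero.
using Zygote

smooth_abs(z, ϵ=1f-12) = sqrt(abs2(z) + ϵ)   # abs2 + sqrt never calls hypot
l1_smooth(x) = sum(smooth_abs, x)

Zygote.gradient(l1_smooth, [0.0f0 + 0.0f0im])  # finite (zero) gradient instead of NaN
```

Since `abs2` plus `sqrt` never routes through `hypot`, the same expression should also avoid the NaN on the GPU broadcast path, at the cost of a slightly perturbed loss value near zero.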

## Slack discussion

On Slack, @mcabbott explained to me the most likely cause for this:

```julia
julia> abs(ForwardDiff.Dual(0,1) + 0im)
Dual{Nothing}(0.0,NaN)
```

- even though DiffRules has a rule for `abs` (used for real inputs), for complex inputs the computation goes through `hypot`, and the [DiffRule method](https://github.com/JuliaDiff/DiffRules.jl/blob/8842177391b07dcd8234ac7612b9ca8ca72d28e0/src/rules.jl#L86) for the derivative of `hypot` at `(0, 0)` gives `NaN` (illustrated below)
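Concretely, that rule computes the partials as quotients, ∂hypot(x, y)/∂x = x / hypot(x, y), which is 0/0 at the origin. A hand-written illustration of the same expression (plain floats, no AD; `naive_dhypot_dx` is just a name for this sketch):

```julia
# The quotient form of the hypot partial, written out by hand; this mirrors
# the rule linked above but is not the actual DiffRules code.
naive_dhypot_dx(x, y) = x / hypot(x, y)

naive_dhypot_dx(3.0, 4.0)  # 0.6
naive_dhypot_dx(0.0, 0.0)  # NaN, because 0 / 0
```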

Not sure what the best fix is here. If DiffRules is open to it, maybe the easiest option is to fix their `hypot` derivative rule?
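For what it's worth, one way such a rule could avoid the NaN is to special-case the origin; returning zero there would match the CPU gradients in the transcript above. This is only a sketch of the idea (`safe_dhypot_dx` is a made-up name), not an actual DiffRules patch:

```julia
# Hypothetical zero-safe partial: return 0 at the origin instead of 0/0.
# Whether the subgradient there should be 0, 1, or sign(x) is a separate
# convention question; 0 matches the CPU results reported above.
function safe_dhypot_dx(x, y)
    h = hypot(x, y)
    return iszero(h) ? zero(x) : x / h
end

safe_dhypot_dx(0.0, 0.0)  # 0.0
safe_dhypot_dx(3.0, 4.0)  # 0.6
```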

## Version info

I'm on Julia 1.10.5, in a fresh environment with

```
(jl_FHvUua) pkg> st
Status `/tmp/jl_FHvUua/Project.toml`
  [052768ef] CUDA v5.5.2
  [e88e6eb3] Zygote v0.6.71
```

mcabbott commented 1 month ago

Any chance that https://github.com/JuliaDiff/ForwardDiff.jl/pull/669 solves this?

piever commented 1 month ago

Somehow it doesn't... Unless I messed something up, I checked out https://github.com/JuliaDiff/ForwardDiff.jl/pull/669 (manually changing the ForwardDiff version number to 0.10) and still get

```julia
julia> Zygote.gradient(l1, cu(cplx))
(ComplexF32[NaN32 + NaN32*im],)
```

which is weird, because `hypot` itself differentiates just fine:

```julia
julia> f(x) = hypot(x, 0, 0)
f (generic function with 1 method)

julia> ForwardDiff.derivative(f, 0.0)
1.0

julia> ForwardDiff.derivative(f, -0.0)
-1.0
```