JuliaDiff / SparseDiffTools.jl

Fast Jacobian computation through sparsity exploitation and matrix coloring
MIT License
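
For context, a minimal sketch of the colored-Jacobian workflow this issue exercises (the diagonal map fdiag and its sparsity pattern are made up for illustration; the keyword names assume the out-of-place API described in the README):

using SparseDiffTools, SparseArrays

fdiag(x) = x .^ 2                              # hypothetical map with a diagonal Jacobian
sparsity = sparse([1,2,3], [1,2,3], trues(3))  # declared sparsity pattern
colors = matrix_colors(sparsity)               # structurally independent columns share one color
J = forwarddiff_color_jacobian(fdiag, rand(3); colorvec = colors, sparsity = sparsity)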

[WIP] Fix OOP GPU AD #106

Closed: ChrisRackauckas closed this 2 years ago

ChrisRackauckas commented 4 years ago

This Heisenbug is particularly unsettling. When I run:

using Zygote, SparseDiffTools, LinearAlgebra, CuArrays
CuArrays.allowscalar(false)  # error on scalar indexing rather than silently falling back
A = cu(rand(2,2))

g(u) = A*u
function f(p)
    # out-of-place colored Jacobian of g at p
    J = SparseDiffTools.forwarddiff_color_jacobian(g,p)
    sum(J)
end
p = cu(rand(2))
Zygote.gradient(f,p)

I get:

partial_i = Tuple{Bool,Bool}[(1, 0), (0, 1)]
vecx = Float32[0.41212302, 0.8292642]
typeof(vecx) = CuArray{Float32,1,Nothing}
x = Float32[0.41212302, 0.8292642]
typeof(x) = CuArray{Float32,1,Nothing}
t = ForwardDiff.Dual{Nothing,Float32,2}[Dual{Nothing}(0.41212302,1.0,0.0), Dual{Nothing}(0.8292642,1.0,0.0)]

Notice that while the partials are (1,0) and (0,1), the broadcast seeds t with (1,0) and (1,0). The tag on the duals is also wrong: it's Nothing instead of ForwardDiff.Tag{typeof(g),Float32}. If I copy that computation out to the REPL:

using ForwardDiff
partial_i = cu(Tuple{Bool,Bool}[(1, 0), (0, 1)])
x = p
vecx = vec(x)
t = reshape(ForwardDiff.Dual{typeof(ForwardDiff.Tag(g,eltype(vecx)))}.(vecx,partial_i),size(x))
@show t

it computes the correct thing:

t = ForwardDiff.Dual{ForwardDiff.Tag{typeof(g),Float32},Float32,2}[Dual{ForwardDiff.Tag{typeof(g),Float32}}(0.41212302,1.0,0.0), Dual{ForwardDiff.Tag{typeof(g),Float32}}(0.8292642,0.0,1.0)]
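
To see why that duplicated seed silently corrupts results rather than erroring, here is a minimal CPU-only sketch (my own illustration, not package code; g2 is a made-up linear map). Each partials slot of an output dual corresponds to one color, i.e. one Jacobian column, so seeding both inputs into the same slot sums columns together:

using ForwardDiff

g2(u) = [u[1] + 2u[2], 3u[1] + 4u[2]]  # Jacobian is [1 2; 3 4]
good = ForwardDiff.Dual{Nothing}.([0.5, 0.5], [(true, false), (false, true)])
bad  = ForwardDiff.Dual{Nothing}.([0.5, 0.5], [(true, false), (true, false)])
ForwardDiff.partials.(g2(good))  # (1.0, 2.0) and (3.0, 4.0): the true Jacobian rows
ForwardDiff.partials.(g2(bad))   # (3.0, 0.0) and (7.0, 0.0): both columns summed into slot 1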

So somehow, being inside the package module changes which broadcast gets chosen and breaks it. While this made me scared that everything about this package has been silently broken all along, I checked the CPU path:

p = rand(2)
Zygote.gradient(f,p)

and get:

Array{Tuple{Bool,Bool},1}
partial_i = Tuple{Bool,Bool}[(1, 0), (0, 1)]
vecx = [0.708969256741854, 0.21105700050316667]
typeof(vecx) = Array{Float64,1}
x = [0.708969256741854, 0.21105700050316667]
typeof(x) = Array{Float64,1}
t = ForwardDiff.Dual{ForwardDiff.Tag{typeof(g),Float64},Float64,2}[Dual{ForwardDiff.Tag{typeof(g),Float64}}(0.708969256741854,1.0,0.0), Dual{ForwardDiff.Tag{typeof(g),Float64}}(0.21105700050316667,0.0,1.0)]

which demonstrates that the incorrect values are computed silently only on GPUs. Since the GPU case currently throws an error anyway, users won't get silently wrong results (they just get an error), but it would be good to track this down before attempting to fix the GPU path.
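
If someone wants to bisect this, a sketch of the kind of standalone check I'd start with (SeedCheck is a throwaway helper module, not package code; it tests the hypothesis that the module context changes which broadcast method is chosen for the seeding):

using CuArrays, ForwardDiff
CuArrays.allowscalar(false)

module SeedCheck
using ForwardDiff
# same seeding broadcast as above, but run from inside a module
seed(g, x, partial_i) = reshape(
    ForwardDiff.Dual{typeof(ForwardDiff.Tag(g, eltype(vec(x))))}.(vec(x), partial_i),
    size(x))
end

g(u) = u  # any map works; only the seeding broadcast is under test
x = cu(rand(2))
partial_i = cu(Tuple{Bool,Bool}[(1, 0), (0, 1)])
t_repl = reshape(
    ForwardDiff.Dual{typeof(ForwardDiff.Tag(g, eltype(vec(x))))}.(vec(x), partial_i),
    size(x))
t_mod = SeedCheck.seed(g, x, partial_i)
# == on Duals compares values only, so compare the partials explicitly
ForwardDiff.partials.(Array(t_repl)) == ForwardDiff.partials.(Array(t_mod))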