This Heisenbug is particularly unsettling. When I run:
```julia
using Zygote, SparseDiffTools, LinearAlgebra, CuArrays
CuArrays.allowscalar(false)
A = cu(rand(2,2))
g(u) = A*u
function f(p)
    J = SparseDiffTools.forwarddiff_color_jacobian(g,p)
    sum(J)
end
p = cu(rand(2))
Zygote.gradient(f,p)
```
Notice that while the partials are (1,0) and (0,1), `t` gets seeded with (1,0) and (1,0), i.e. both entries receive the first partial. Also, the tag on the dual is incorrect. If I copy that computation out to the REPL:
```julia
using ForwardDiff
partial_i = cu(Tuple{Bool,Bool}[(1, 0), (0, 1)])
x = p
vecx = vec(x)
t = reshape(ForwardDiff.Dual{typeof(ForwardDiff.Tag(g,eltype(vecx)))}.(vecx,partial_i),size(x))
@show t
```
it computes the correct thing:
```
t = ForwardDiff.Dual{ForwardDiff.Tag{typeof(g),Float32},Float32,2}[Dual{ForwardDiff.Tag{typeof(g),Float32}}(0.41212302,1.0,0.0), Dual{ForwardDiff.Tag{typeof(g),Float32}}(0.8292642,0.0,1.0)]
```
So somehow, being inside of the package module changes the broadcast that it chooses and breaks it. While this made me scared that everything about this package is broken ever, I checked:
which demonstrates it's only silently calculating incorrect values on GPUs. Since this case currently errors on GPUs, that's okay for now: users get an error instead of a silent incorrect value. Still, it would be good to track this down before attempting to fix this case.
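For anyone trying to track this down, here is a minimal CPU-only sketch of the seeding step from the REPL snippet, useful as a baseline to compare the GPU result against. It assumes plain `Array`s in place of `cu` arrays, and `T` and `seeds` are names I made up for inspection. This is not the package's internal code path, just the same broadcast with the partials pulled back out:

```julia
using ForwardDiff

# Assumption: plain CPU Arrays stand in for the cu(...) arrays in the report.
A = rand(Float32, 2, 2)
g(u) = A * u
x = rand(Float32, 2)
partial_i = Tuple{Bool,Bool}[(1, 0), (0, 1)]

vecx = vec(x)
# Same tag construction as in the REPL snippet; `T` is just a local alias.
T = typeof(ForwardDiff.Tag(g, eltype(vecx)))
t = reshape(ForwardDiff.Dual{T}.(vecx, partial_i), size(x))

# Pull the partials back out of each dual so they can be compared directly.
# `seeds` is a hypothetical helper variable, not something the package defines.
seeds = [(ForwardDiff.partials(d, 1), ForwardDiff.partials(d, 2)) for d in t]
@show seeds
@show eltype(t)
```

If the seeding is right, element `i` of `seeds` is the `i`-th unit vector and the element type carries the tag built from `g`. Running the same extraction on the `t` produced inside the package on the GPU should make the mismatch easy to see.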