YichengDWu opened this issue 1 year ago
Broadcasting on GPU unconditionally takes the ForwardDiff path, which is why you see Duals. But those Duals should not make it to the user, so that's a bug. Evidently https://github.com/FluxML/Zygote.jl/blob/v0.6.61/src/lib/broadcast.jl#L295 is not smart enough to recurse into the Tuples to remove any Duals there.
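For illustration only (a hypothetical sketch, not Zygote's actual code), a recursive unwrapping step along those lines could descend into Tuples like this; unwrap_dual is an invented helper name:

using ForwardDiff: Dual, value

# Hypothetical helper, not part of Zygote: strip Duals, recursing into Tuples so
# that broadcast outputs like Tuple{Dual, Dual} come back as plain numbers.
unwrap_dual(x) = x
unwrap_dual(d::Dual) = value(d)
unwrap_dual(t::Tuple) = map(unwrap_dual, t)

unwrap_dual((Dual(0.5f0, 1f0, 0f0), Dual(0.2f0, 0f0, 1f0)))  # (0.5f0, 0.2f0)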
What about a custom struct? It just throws an error:
struct Point{T}
    x::T
    y::T
    Point(x::T, y::T) where T = new{T}(x, y)
end
import Adapt
Adapt.@adapt_structure Point
f(x,y) = Point.(x,y)
julia> pullback(f, x, y)[1]
ERROR: GPU broadcast resulted in non-concrete element type Any.
This probably means that the function you are broadcasting contains an error or type instability.
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] copy
@ ~/.julia/packages/GPUArrays/g2pOV/src/host/broadcast.jl:34 [inlined]
[3] materialize(bc::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1}, Nothing, Zygote.var"#1550#1551"{UnionAll}, Tuple{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}})
@ Base.Broadcast ./broadcast.jl:860
[4] broadcast_forward(::Type, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/packages/Zygote/g2w9o/src/lib/broadcast.jl:269
[5] adjoint
@ ~/.julia/packages/Zygote/g2w9o/src/lib/broadcast.jl:348 [inlined]
[6] _pullback(::Zygote.Context{false}, ::typeof(Base.Broadcast.broadcasted), ::CUDA.CuArrayStyle{1}, ::Type, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:66
[7] _apply(::Function, ::Vararg{Any})
@ Core ./boot.jl:816
[8] adjoint
@ ~/.julia/packages/Zygote/g2w9o/src/lib/lib.jl:203 [inlined]
[9] _pullback
@ ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:66 [inlined]
[10] _pullback
@ ./broadcast.jl:1304 [inlined]
[11] _pullback
@ ./REPL[262]:1 [inlined]
[12] _pullback(::Zygote.Context{false}, ::typeof(f), ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/packages/Zygote/g2w9o/src/compiler/interface2.jl:0
[13] pullback(::Function, ::Zygote.Context{false}, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::Vararg{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}})
@ Zygote ~/.julia/packages/Zygote/g2w9o/src/compiler/interface.jl:44
[14] pullback(::Function, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/packages/Zygote/g2w9o/src/compiler/interface.jl:42
[15] top-level scope
@ REPL[263]:1
[16] top-level scope
@ ~/.julia/packages/CUDA/BbliS/src/initialization.jl:52
But on the CPU it's fine:
julia> pullback(f, Vector(x), Vector(y))[1]
2-element Vector{Point{Float32}}:
Point{Float32}(0.26743817f0, 0.2564943f0)
Point{Float32}(0.34023497f0, 0.41681844f0)
That's because of https://github.com/JuliaGPU/CUDA.jl/issues/1761, which Zygote doesn't have any control over. Defining a differently-named constructor function (as tuple is to Tuple) and using that should work. Filling out the type parameters of Point (i.e. f(x::AbstractArray{T}, y::AbstractArray{T}) where {T} = Point{T}.(x, y)) might also work, as it avoids the UnionAll.
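For concreteness, a minimal sketch of the first suggestion, continuing from the definitions above (the lowercase name point is an illustrative choice, not an existing API):

# A differently-named constructor function, analogous to tuple vs. Tuple.
# Broadcasting a plain function avoids handing the UnionAll Point to the kernel.
point(x, y) = Point(x, y)
g(x, y) = point.(x, y)

pullback(g, x, y)[1]  # should avoid the non-concrete element type error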
I don't understand why it is using forward-mode AD here. Is there a way to force reverse-mode AD on the GPU, say by writing a custom rrule?
Only by defining a rule for broadcasted(::typeof(myfunc), ...), which doesn't exist for most functions; otherwise it won't be GPU compatible. You could look at how ChainRules' broadcast rule handles this. If it does a better job on the GPU, that's another argument for trying to replace Zygote's broadcasting machinery with it.
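As a rough sketch of the kind of broadcast-level rule meant here, using the tuple example from this thread (whether Zygote prefers such a rule over its own broadcasting adjoint may depend on the Zygote version):

using ChainRulesCore
import Base.Broadcast: broadcasted

# Sketch: the forward pass zips x and y into tuples; the pullback splits an array
# of 2-tuple cotangents back into cotangents for x and y using plain map calls,
# which also work on GPU arrays.
function ChainRulesCore.rrule(::typeof(broadcasted), ::typeof(tuple),
                              x::AbstractArray, y::AbstractArray)
    o = broadcast(tuple, x, y)
    function tuple_broadcast_pullback(ō)
        x̄ = map(first, ō)
        ȳ = map(last, ō)
        return NoTangent(), NoTangent(), x̄, ȳ
    end
    return o, tuple_broadcast_pullback
end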
Thanks, I will try writing an rrule then.
pullback still uses forward mode even when there is an rrule:
julia> function rrule(::typeof(f), x, y)
           o = f(x, y)
           function f_pullback(x̄)
               return NoTangent(), x, y
           end
           return o, f_pullback
       end
rrule (generic function with 2 methods)
julia> o, back = pullback(f, x, y)
(Tuple{ForwardDiff.Dual{Nothing, Float32, 2}, ForwardDiff.Dual{Nothing, Float32, 2}}[(Dual{Nothing}(0.61599565,1.0,0.0), Dual{Nothing}(0.16058706,0.0,1.0)), (Dual{Nothing}(0.7189002,1.0,0.0), Dual{Nothing}(0.69142073,0.0,1.0))], Zygote.var"#75#76"{Zygote.Pullback{Tuple{typeof(f), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{Zygote.var"#4155#back#1388"{Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcasted), typeof(tuple), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{Zygote.var"#2157#back#289"{Zygote.var"#287#288"{Tuple{NTuple{4, Nothing}, Tuple{}}, Zygote.var"#4165#back#1428"{Zygote.var"#1394#1396"}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcastable), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{}}, Zygote.var"#2001#back#200"{typeof(identity)}, Zygote.var"#2001#back#200"{typeof(identity)}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcastable), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{}}, Zygote.var"#2157#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing, Nothing}, Tuple{}}, Zygote.var"#combine_styles_pullback#1170"{Tuple{Nothing, Nothing, Nothing}}}}, Zygote.var"#2865#back#684"{Zygote.var"#map_back#678"{typeof(Base.Broadcast.broadcastable), 1, Tuple{Tuple{}}, Tuple{Val{0}}, Tuple{}}}}}}}}}(∂(f)))
julia> o, back = rrule(f, x, y)
(Tuple{Float32, Float32}[(0.61599565, 0.16058706), (0.7189002, 0.69142073)], var"#f_pullback#34"{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}(Float32[0.61599565, 0.7189002], Float32[0.16058706, 0.69142073]))
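One possible explanation (an assumption based on the rrule (generic function with 2 methods) output above, not something confirmed in the thread) is that the rule was defined as a new function named rrule in Main rather than as a method of ChainRulesCore.rrule, and Zygote only consults the latter. A minimal sketch of registering it so Zygote can find it:

using ChainRulesCore

# Extend ChainRulesCore.rrule (note the qualified name); Zygote only picks up
# rules that are methods of this function. The dummy cotangents mirror the
# definition above.
function ChainRulesCore.rrule(::typeof(f), x, y)
    o = f(x, y)
    function f_pullback(ō)
        return NoTangent(), x, y
    end
    return o, f_pullback
end

pullback(f, x, y)  # should now hit f_pullback instead of broadcast_forward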
OK, this works:
f(x,y) = ChainRulesCore.@ignore_derivatives broadcast(tuple,x,y)
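A caveat worth noting (a behavioral note, not something stated in the thread): @ignore_derivatives treats the whole broadcast as a constant, so it avoids the Duals by skipping differentiation rather than by running reverse mode through the broadcast, and no gradient flows back to x and y along this path:

o, back = pullback(f, x, y)
back(o)  # expected to be (nothing, nothing): the ignored block contributes no gradient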
Not sure how ForwardDiff.Dual kicks in here. On the CPU it's normal. This causes the following bug: