FluxML / Zygote.jl

https://fluxml.ai/Zygote.jl/

Unexpected return values of `pullback` on GPU #1424

Open · YichengDWu opened this issue 1 year ago

YichengDWu commented 1 year ago
using CUDA, Zygote

x = CUDA.rand(2)
y = CUDA.rand(2)

f(x,y) = broadcast(tuple,x,y)

pullback(f, x, y)[1]
2-element CuArray{Tuple{ForwardDiff.Dual{Nothing, Float32, 2}, ForwardDiff.Dual{Nothing, Float32, 2}}, 1, CUDA.Mem.DeviceBuffer}:
 (Dual{Nothing}(0.19380774,1.0,0.0), Dual{Nothing}(0.0026825257,0.0,1.0))
 (Dual{Nothing}(0.28045696,1.0,0.0), Dual{Nothing}(0.62378126,0.0,1.0))

I'm not sure how ForwardDiff.Dual gets involved here. On the CPU the result is as expected:

pullback(f, Vector(x), Vector(y))[1]
2-element Vector{Tuple{Float32, Float32}}:
 (0.19380774, 0.0026825257)
 (0.28045696, 0.62378126)

This causes the following bug:

julia> o, back = pullback(f, x, y)
(Tuple{ForwardDiff.Dual{Nothing, Float32, 2}, ForwardDiff.Dual{Nothing, Float32, 2}}[(Dual{Nothing}(0.8242247,1.0,0.0), Dual{Nothing}(0.5431607,0.0,1.0)), (Dual{Nothing}(0.83507776,1.0,0.0), Dual{Nothing}(0.4801908,0.0,1.0))], Zygote.var"#68#69"{typeof(∂(f))}(∂(f)))

julia> back(o)
(nothing, nothing)

julia> o, back = pullback(f, Vector(x), Vector(y))
(Tuple{Float32, Float32}[(0.8242247, 0.5431607), (0.83507776, 0.4801908)], Zygote.var"#68#69"{typeof(∂(f))}(∂(f)))

julia> back(o)
(Float32[0.8242247, 0.83507776], Float32[0.5431607, 0.4801908])

ToucheSir commented 1 year ago

Broadcasting on GPU unconditionally takes the ForwardDiff path, which is why you see Duals. But those Duals should not make it to the user, so that's a bug. Evidently https://github.com/FluxML/Zygote.jl/blob/v0.6.61/src/lib/broadcast.jl#L295 is not smart enough to recurse into the Tuples to remove any Duals there.
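
For illustration, the kind of recursion that line would need might look roughly like the sketch below (hypothetical helper names, not Zygote's actual code; the real fix would live inside the broadcast_forward machinery):

```julia
using ForwardDiff: Dual, value

# Strip Duals from a broadcast result, recursing into Tuples so that e.g. a
# CuArray{Tuple{Dual,Dual}} comes back as a CuArray{Tuple{Float32,Float32}}.
strip_dual(x) = x
strip_dual(d::Dual) = value(d)
strip_dual(t::Tuple) = map(strip_dual, t)
strip_dual(A::AbstractArray) = map(strip_dual, A)
```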

YichengDWu commented 1 year ago

What about a custom struct? It just throws an error:

struct Point{T}
    x::T
    y::T
    Point(x::T, y::T) where {T} = new{T}(x, y)
end

import Adapt
Adapt.@adapt_structure Point

f(x,y) = Point.(x,y)

julia> pullback(f, x, y)[1]
ERROR: GPU broadcast resulted in non-concrete element type Any.
This probably means that the function you are broadcasting contains an error or type instability.
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] copy
    @ ~/.julia/packages/GPUArrays/g2pOV/src/host/broadcast.jl:34 [inlined]
  [3] materialize(bc::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1}, Nothing, Zygote.var"#1550#1551"{UnionAll}, Tuple{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}})
    @ Base.Broadcast ./broadcast.jl:860
  [4] broadcast_forward(::Type, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/g2w9o/src/lib/broadcast.jl:269
  [5] adjoint
    @ ~/.julia/packages/Zygote/g2w9o/src/lib/broadcast.jl:348 [inlined]
  [6] _pullback(::Zygote.Context{false}, ::typeof(Base.Broadcast.broadcasted), ::CUDA.CuArrayStyle{1}, ::Type, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:66
  [7] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:816
  [8] adjoint
    @ ~/.julia/packages/Zygote/g2w9o/src/lib/lib.jl:203 [inlined]
  [9] _pullback
    @ ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:66 [inlined]
 [10] _pullback
    @ ./broadcast.jl:1304 [inlined]
 [11] _pullback
    @ ./REPL[262]:1 [inlined]
 [12] _pullback(::Zygote.Context{false}, ::typeof(f), ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/g2w9o/src/compiler/interface2.jl:0
 [13] pullback(::Function, ::Zygote.Context{false}, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::Vararg{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}})
    @ Zygote ~/.julia/packages/Zygote/g2w9o/src/compiler/interface.jl:44
 [14] pullback(::Function, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/g2w9o/src/compiler/interface.jl:42
 [15] top-level scope
    @ REPL[263]:1
 [16] top-level scope
    @ ~/.julia/packages/CUDA/BbliS/src/initialization.jl:52

But on the CPU it works fine:

julia> pullback(f, Vector(x), Vector(y))[1]
2-element Vector{Point{Float32}}:
 Point{Float32}(0.26743817f0, 0.2564943f0)
 Point{Float32}(0.34023497f0, 0.41681844f0)

ToucheSir commented 1 year ago

That's because of https://github.com/JuliaGPU/CUDA.jl/issues/1761, which Zygote doesn't have any control over. Defining a differently-named constructor function (analogous to how `tuple` relates to `Tuple`) and broadcasting that should work. Filling in the type parameter of `Point` (e.g. `f(x::AbstractArray{T}, y::AbstractArray{T}) where {T} = Point{T}.(x, y)`) might also work, since it avoids broadcasting a UnionAll.
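
To make those suggestions concrete, here is a rough sketch of both workarounds (hypothetical helper names, untested on the GPU here):

```julia
# 1. Broadcast a differently-named constructor *function* (analogous to how
#    `tuple` relates to `Tuple`), so the kernel closes over a plain function
#    rather than a Type.
makepoint(x, y) = Point(x, y)
f1(x, y) = makepoint.(x, y)

# 2. Broadcast a fully parameterised constructor so no UnionAll is involved.
#    The struct above only defines the inner constructor `Point(x::T, y::T)`,
#    so a `Point{T}` method has to be added first.
Point{T}(x, y) where {T} = Point(convert(T, x), convert(T, y))
f2(x::AbstractArray{T}, y::AbstractArray{T}) where {T} = Point{T}.(x, y)
```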

YichengDWu commented 1 year ago

I don't understand why forward-mode AD is being used here. Is there a way to force reverse-mode AD on the GPU, say by writing a custom rrule?

ToucheSir commented 1 year ago

Only by defining a rule for `broadcasted(::typeof(myfunc), ...)`, which doesn't exist for most functions; otherwise the pullback won't be GPU-compatible. You could look at how ChainRules' broadcast rule handles this. If it does a better job on GPU, that's another argument for trying to replace Zygote's broadcasting machinery with it.
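
For the `tuple` example above, such a rule could look roughly like the sketch below. It is only illustrative: it assumes the incoming cotangent is an array of 2-tuples (as in `back(o)` earlier), and that this adjoint is found before the generic GPU broadcast path.

```julia
using Zygote, CUDA
import Base.Broadcast: broadcasted

# Hand-written reverse rule for broadcasting `tuple` over two GPU arrays,
# bypassing the ForwardDiff-based generic GPU broadcast adjoint.
Zygote.@adjoint function broadcasted(::typeof(tuple), x::CuArray, y::CuArray)
    o = broadcast(tuple, x, y)
    back(Δ) = (nothing, map(first, Δ), map(last, Δ))  # cotangents for (tuple, x, y)
    return o, back
end
```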

YichengDWu commented 1 year ago

Thanks, I will try writing an rrule then.

YichengDWu commented 1 year ago

`pullback` still uses forward mode even when there is an rrule:

julia> function rrule(::typeof(f), x, y)
       o = f(x,y)
       function f_pullback(x̄)
           return NoTangent(), x, y
       end
       return o, f_pullback
       end
rrule (generic function with 2 methods)

julia> o, back = pullback(f, x, y)
(Tuple{ForwardDiff.Dual{Nothing, Float32, 2}, ForwardDiff.Dual{Nothing, Float32, 2}}[(Dual{Nothing}(0.61599565,1.0,0.0), Dual{Nothing}(0.16058706,0.0,1.0)), (Dual{Nothing}(0.7189002,1.0,0.0), Dual{Nothing}(0.69142073,0.0,1.0))], Zygote.var"#75#76"{Zygote.Pullback{Tuple{typeof(f), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{Zygote.var"#4155#back#1388"{Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcasted), typeof(tuple), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{Zygote.var"#2157#back#289"{Zygote.var"#287#288"{Tuple{NTuple{4, Nothing}, Tuple{}}, Zygote.var"#4165#back#1428"{Zygote.var"#1394#1396"}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcastable), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{}}, Zygote.var"#2001#back#200"{typeof(identity)}, Zygote.var"#2001#back#200"{typeof(identity)}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.broadcastable), CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Tuple{}}, Zygote.var"#2157#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing, Nothing}, Tuple{}}, Zygote.var"#combine_styles_pullback#1170"{Tuple{Nothing, Nothing, Nothing}}}}, Zygote.var"#2865#back#684"{Zygote.var"#map_back#678"{typeof(Base.Broadcast.broadcastable), 1, Tuple{Tuple{}}, Tuple{Val{0}}, Tuple{}}}}}}}}}(∂(f)))

julia> o, back = rrule(f, x, y)
(Tuple{Float32, Float32}[(0.61599565, 0.16058706), (0.7189002, 0.69142073)], var"#f_pullback#34"{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}(Float32[0.61599565, 0.7189002], Float32[0.16058706, 0.69142073]))
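
One likely reason the rule above is ignored is that it defines a standalone `rrule` function in `Main` instead of extending `ChainRulesCore.rrule`, which is the method table Zygote actually consults. A hedged sketch of that overload (the pullback here is illustrative and again assumes an array-of-2-tuples cotangent):

```julia
import ChainRulesCore
using ChainRulesCore: NoTangent

function ChainRulesCore.rrule(::typeof(f), x, y)
    o = f(x, y)
    function f_pullback(ō)
        # Map the array-of-tuples cotangent back to one array per argument.
        return NoTangent(), map(first, ō), map(last, ō)
    end
    return o, f_pullback
end
```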

YichengDWu commented 1 year ago

OK, this works:

f(x,y) = ChainRulesCore.@ignore_derivatives broadcast(tuple,x,y)
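
Worth noting: `@ignore_derivatives` treats the whole broadcast as a constant, so presumably no gradient flows back to `x` and `y` at all. A quick check (expected behaviour, not re-run here):

```julia
using CUDA, Zygote, ChainRulesCore

f(x, y) = ChainRulesCore.@ignore_derivatives broadcast(tuple, x, y)

o, back = pullback(f, x, y)  # o is a plain CuArray of Float32 tuples, no Duals
back(o)                      # expected to be (nothing, nothing): the ignored
                             # broadcast contributes no gradient to x or y
```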