FluxML / NNlib.jl

Neural Network primitives with multiple backends

Nested AD with `leakyrelu` activations fails on GPU #386

Open vincentmolin opened 2 years ago

vincentmolin commented 2 years ago

Using leakyrelu causes a compilation error when differentiating through the following gradient penalty loss. The same code works on the CPU, and on the GPU with other activations such as elu or relu.

using Flux, Zygote, CUDA

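function gradient_penalty(m, x)
    # Gradient of sum(m(x)) with respect to the input x (implicit-params style)
    _, back = Flux.pullback(() -> sum(m(x)), params(x))
    grads = back(1.0f0)[x]
    # Penalty: squared norm of the input gradient
    return sum(grads .^ 2)
end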

x = randn(Float32, 1, 4) # dims, batch

m₁ = Chain(Dense(1, 1), x -> leakyrelu.(x, 0.2f0))
l, b = Flux.pullback(() -> gradient_penalty(m₁, x), params(m₁))    # Ok

cx = x |> gpu
cm₂ = Chain(Dense(1, 1), x -> elu.(x)) |> gpu
l, b = Flux.pullback(() -> gradient_penalty(cm₂, cx), params(cm₂)) # Ok

cm₁ = Chain(Dense(1, 1), x -> leakyrelu.(x, 0.2f0)) |> gpu
l, b = Flux.pullback(() -> gradient_penalty(cm₁, cx), params(cm₁)) # Fails to compile

This throws:

ERROR: LoadError: GPU compilation of kernel broadcast_kernel(CUDA.CuKernelContext, CuDeviceMatrix{Tuple{Float32, typeof(∂(#1122))}, 1}, Base.Broadcast.Broadcasted{Nothing, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, Zygote.var"#561#565"{Zygote.Context, Zygote.var"#1122#1126"}, Tuple{Base.Broadcast.Extruded{CuDeviceMatrix{ForwardDiff.Dual{Nothing, Float32, 2}, 1}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}}}, Int64) failed
KernelError: passing and using non-bitstype argument

Argument 4 to your kernel function is of type Base.Broadcast.Broadcasted{Nothing, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, Zygote.var"#561#565"{Zygote.Context, Zygote.var"#1122#1126"}, Tuple{Base.Broadcast.Extruded{CuDeviceMatrix{ForwardDiff.Dual{Nothing, Float32, 2}, 1}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}}}, which is not isbits:
  .f is of type Zygote.var"#561#565"{Zygote.Context, Zygote.var"#1122#1126"} which is not isbits.
    .cx is of type Zygote.Context which is not isbits.
      .cache is of type Union{Nothing, IdDict{Any, Any}} which is not isbits.

Stacktrace:
  [1] check_invocation(job::GPUCompiler.CompilerJob)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/1Ajz2/src/validation.jl:66
  [2] macro expansion
    @ ~/.julia/packages/GPUCompiler/1Ajz2/src/driver.jl:325 [inlined]
  [3] macro expansion
    @ ~/.julia/packages/TimerOutputs/5tW2E/src/TimerOutput.jl:252 [inlined]
  [4] macro expansion
    @ ~/.julia/packages/GPUCompiler/1Ajz2/src/driver.jl:324 [inlined]
  [5] emit_asm(job::GPUCompiler.CompilerJob, ir::LLVM.Module; strip::Bool, validate::Bool, format::LLVM.API.LLVMCodeGenFileType)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/1Ajz2/src/utils.jl:64
  [6] cufunction_compile(job::GPUCompiler.CompilerJob)
    @ CUDA ~/.julia/packages/CUDA/bki2w/src/compiler/execution.jl:326
  [7] cached_compilation(cache::Dict{UInt64, Any}, job::GPUCompiler.CompilerJob, compiler::typeof(CUDA.cufunction_compile), linker::typeof(CUDA.cufunction_link))
    @ GPUCompiler ~/.julia/packages/GPUCompiler/1Ajz2/src/cache.jl:90
  [8] cufunction(f::GPUArrays.var"#broadcast_kernel#17", tt::Type{Tuple{CUDA.CuKernelContext, CuDeviceMatrix{Tuple{Float32, typeof(∂(#1122))}, 1}, Base.Broadcast.Broadcasted{Nothing, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}, Zygote.var"#561#565"{Zygote.Context, Zygote.var"#1122#1126"}, Tuple{Base.Broadcast.Extruded{CuDeviceMatrix{ForwardDiff.Dual{Nothing, Float32, 2}, 1}, Tuple{Bool, Bool}, Tuple{Int64, Int64}}}}, Int64}}; name::Nothing, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ CUDA ~/.julia/packages/CUDA/bki2w/src/compiler/execution.jl:297
  [9] cufunction
    @ ~/.julia/packages/CUDA/bki2w/src/compiler/execution.jl:291 [inlined]
 [10] macro expansion
    @ ~/.julia/packages/CUDA/bki2w/src/compiler/execution.jl:102 [inlined]
 [11] #launch_heuristic#270
    @ ~/.julia/packages/CUDA/bki2w/src/gpuarrays.jl:17 [inlined]
 [12] copyto!
    @ ~/.julia/packages/GPUArrays/umZob/src/host/broadcast.jl:65 [inlined]
 [13] copyto!
    @ ./broadcast.jl:913 [inlined]
 [14] copy
    @ ~/.julia/packages/GPUArrays/umZob/src/host/broadcast.jl:47 [inlined]
 [15] materialize(bc::Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{2}, Nothing, Zygote.var"#561#565"{Zygote.Context, Zygote.var"#1122#1126"}, Tuple{CuArray{ForwardDiff.Dual{Nothing, Float32, 2}, 2, CUDA.Mem.DeviceBuffer}}})
    @ Base.Broadcast ./broadcast.jl:860
 [16] map(::Function, ::CuArray{ForwardDiff.Dual{Nothing, Float32, 2}, 2, CUDA.Mem.DeviceBuffer})
    @ GPUArrays ~/.julia/packages/GPUArrays/umZob/src/host/broadcast.jl:90
 [17] ∇map(cx::Zygote.Context, f::Zygote.var"#1122#1126", args::CuArray{ForwardDiff.Dual{Nothing, Float32, 2}, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/lib/array.jl:197
 [18] adjoint
    @ ~/.julia/packages/Zygote/FPUm3/src/lib/array.jl:223 [inlined]
 [19] _pullback(__context__::Zygote.Context, 541::typeof(map), f::Function, args::CuArray{ForwardDiff.Dual{Nothing, Float32, 2}, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65
 [20] _pullback
    @ ~/.julia/packages/Zygote/FPUm3/src/lib/broadcast.jl:241 [inlined]
 [21] _pullback(::Zygote.Context, ::typeof(Zygote.broadcast_forward), ::typeof(leakyrelu), ::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::Float32)
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface2.jl:0
 [22] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:814
 [23] adjoint
    @ ~/.julia/packages/Zygote/FPUm3/src/lib/lib.jl:200 [inlined]
 [24] _pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
 [25] _pullback
    @ ~/.julia/packages/Zygote/FPUm3/src/lib/broadcast.jl:265 [inlined]
 [26] _pullback(::Zygote.Context, ::typeof(ZygoteRules.adjoint), ::Zygote.Context, ::typeof(Base.Broadcast.broadcasted), ::CUDA.CuArrayStyle{2}, ::typeof(leakyrelu), ::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::Float32)
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface2.jl:0
 [27] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:814
 [28] adjoint
    @ ~/.julia/packages/Zygote/FPUm3/src/lib/lib.jl:200 [inlined]
 [29] _pullback (repeats 2 times)
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
 [30] _pullback(::Zygote.Context, ::typeof(ZygoteRules._pullback), ::Zygote.Context, ::typeof(Base.Broadcast.broadcasted), ::CUDA.CuArrayStyle{2}, ::typeof(leakyrelu), ::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, ::Float32)
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface2.jl:0
 [31] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:814
 [32] adjoint
    @ ~/.julia/packages/Zygote/FPUm3/src/lib/lib.jl:189 [inlined]
 [33] _pullback(::Zygote.Context, ::typeof(Core._apply), ::Function, ::Tuple{Zygote.Context, typeof(Base.Broadcast.broadcasted)}, ::Tuple{CUDA.CuArrayStyle{2}, typeof(leakyrelu), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Float32}, ::Tuple{})
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65
 [34] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:814
 [35] adjoint
    @ ~/.julia/packages/Zygote/FPUm3/src/lib/lib.jl:200 [inlined]
 [36] _pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
 [37] _pullback
    @ ~/.julia/packages/Zygote/FPUm3/src/lib/lib.jl:200 [inlined]
 [38] _pullback(::Zygote.Context, ::typeof(ZygoteRules.adjoint), ::Zygote.Context, ::typeof(Core._apply_iterate), ::typeof(iterate), ::typeof(Base.Broadcast.broadcasted), ::Tuple{CUDA.CuArrayStyle{2}, typeof(leakyrelu), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Float32}, ::Tuple{})
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface2.jl:0
 [39] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:814
 [40] adjoint
    @ ~/.julia/packages/Zygote/FPUm3/src/lib/lib.jl:200 [inlined]
 [41] _pullback (repeats 2 times)
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
 [42] _pullback(::Zygote.Context, ::typeof(ZygoteRules._pullback), ::Zygote.Context, ::typeof(Core._apply_iterate), ::typeof(iterate), ::typeof(Base.Broadcast.broadcasted), ::Tuple{CUDA.CuArrayStyle{2}, typeof(leakyrelu), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Float32}, ::Tuple{})
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface2.jl:0
 [43] _pullback
    @ ./broadcast.jl:1303 [inlined]
 [44] _pullback
    @ ~/ws/msc/scratch/gpsmaller.jl:13 [inlined]
 [45] _pullback(::Zygote.Context, ::typeof(ZygoteRules._pullback), ::Zygote.Context, ::var"#15#16", ::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface2.jl:0
 [46] _pullback
    @ ~/.julia/packages/Flux/qAdFM/src/layers/basic.jl:47 [inlined]
--- the last 2 lines are repeated 1 more time ---
 [49] _pullback(::Zygote.Context, ::typeof(ZygoteRules._pullback), ::Zygote.Context, ::typeof(Flux.applychain), ::Tuple{Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, var"#15#16"}, ::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface2.jl:0
 [50] _pullback
    @ ~/.julia/packages/Flux/qAdFM/src/layers/basic.jl:49 [inlined]
 [51] _pullback(::Zygote.Context, ::typeof(ZygoteRules._pullback), ::Zygote.Context, ::Chain{Tuple{Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, var"#15#16"}}, ::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface2.jl:0
 [52] _pullback
    @ ~/ws/msc/scratch/gpsmaller.jl:6 [inlined]
 [53] _pullback(::Zygote.Context, ::typeof(ZygoteRules._pullback), ::Zygote.Context, ::var"#13#14"{Chain{Tuple{Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, var"#15#16"}}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}})
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface2.jl:0
 [54] _pullback
    @ ~/.julia/packages/Zygote/FPUm3/src/compiler/interface.jl:352 [inlined]
 [55] _pullback(::Zygote.Context, ::typeof(pullback), ::var"#13#14"{Chain{Tuple{Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, var"#15#16"}}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, ::Params)
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface2.jl:0
 [56] _pullback
    @ ~/ws/msc/scratch/gpsmaller.jl:6 [inlined]
 [57] _pullback(::Zygote.Context, ::typeof(gradient_penalty), ::Chain{Tuple{Dense{typeof(identity), CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, var"#15#16"}}, ::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface2.jl:0
 [58] _pullback
    @ ~/ws/msc/scratch/gpsmaller.jl:21 [inlined]
 [59] _pullback(::Zygote.Context, ::var"#23#24")
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface2.jl:0
 [60] pullback(f::Function, ps::Params)
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface.jl:352
in expression starting at /home/vincent/ws/msc/scratch/gpsmaller.jl:21
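The KernelError above comes down to Julia's `isbits` rule: after conversion to device types, every argument passed to a CUDA kernel must be a plain-bits value, and a closure is only `isbits` if everything it captures is. The CPU-only sketch below mimics the chain in the error message (closure, then `.cx`, then `.cache::Union{Nothing, IdDict{Any, Any}}`) using a hypothetical `FakeContext` stand-in for `Zygote.Context`; it does not touch Zygote or CUDA.

```julia
# Hypothetical stand-in for Zygote.Context: its only field may hold an IdDict.
struct FakeContext
    cache::Union{Nothing, IdDict{Any, Any}}
end

# Build the closure inside a function so that it really captures `captured`
# (closures written at global scope refer to globals instead of capturing them).
make_closure(captured) = x -> (captured, x)

f_ok  = make_closure(0.2f0)                # captures only a Float32
f_bad = make_closure(FakeContext(nothing)) # captures a struct with a non-isbits field

println(isbits(f_ok))    # this closure could be passed to a GPU kernel
println(isbits(f_bad))   # same failure mode as `.cx::Zygote.Context` in the trace
println(isbitstype(FakeContext))
```

The point is that the offending state is captured by the closure, not passed explicitly, so the kernel argument check only fails once the closure itself reaches the GPU broadcast.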
DhairyaLGandhi commented 2 years ago

Let's add these cases to the activation tests. We probably want to cover a decent range of layers as well, so the tests could live in Flux as extensions of https://github.com/FluxML/Flux.jl/blob/0b7e1b61addbe245e4a565d522df334ce0d41584/test/cuda/layers.jl#L84
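A minimal sketch of such a test, assuming Flux, Zygote, CUDA and the standard `Test` library are installed. The activation list is illustrative rather than a proposed final set, and the CPU fallback (`dev = cpu` when no GPU is usable) is an assumption added so the file still runs without a GPU. On a GPU with the package versions from the report, the `leakyrelu` case would be expected to error.

```julia
using Flux, Zygote, CUDA, Test

# Loss from the report: squared norm of the input gradient of sum(m(x)).
function gradient_penalty(m, x)
    _, back = Flux.pullback(() -> sum(m(x)), Flux.params(x))
    grads = back(1.0f0)[x]
    return sum(grads .^ 2)
end

# Illustrative activations, wrapped the same way as in the report.
acts = [
    "relu"      => (x -> relu.(x)),
    "elu"       => (x -> elu.(x)),
    "leakyrelu" => (x -> leakyrelu.(x, 0.2f0)),
]

# Run on the GPU when one is usable, otherwise fall back to the CPU.
dev = CUDA.functional() ? gpu : cpu

@testset "nested AD through activations" begin
    for (name, act) in acts
        @testset "$name" begin
            x = randn(Float32, 1, 4) |> dev
            m = Chain(Dense(1, 1), act) |> dev
            ps = Flux.params(m)
            # Second-order pass: differentiate the gradient penalty w.r.t. the model parameters.
            l, back = Flux.pullback(() -> gradient_penalty(m, x), ps)
            gs = back(1f0)
            @test isfinite(l)
            @test gs isa Zygote.Grads
        end
    end
end
```

Wrapping each activation in its own nested `@testset` means a compilation error in one case is recorded as an error for that case while the remaining activations still run.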

ToucheSir commented 2 years ago

Thanks for the report, this is an interesting one. The stack trace points to https://github.com/FluxML/Zygote.jl/blob/v0.6.34/src/lib/broadcast.jl#L241, which, when differentiated through, runs the very GPU-unfriendly https://github.com/FluxML/Zygote.jl/blob/v0.6.34/src/lib/array.jl#L197. I'm not sure why other activations are fine here (I would have to look at their call stack to be sure). @mcabbott, would replacing https://github.com/FluxML/Zygote.jl/blob/v0.6.34/src/lib/broadcast.jl#L241 with `y = ForwardDiff.value.(out)` help here?
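On the forward pass, the suggested change should be value-preserving: the trace shows that line `map`ping an anonymous function over an array of `ForwardDiff.Dual`s, and broadcasting `ForwardDiff.value` over the same array gives the same plain values. A CPU-only sketch, assuming only ForwardDiff; the `Dual{Nothing, Float32, 2}` element type mirrors the one in the stack trace. Whether the broadcast form avoids the GPU compilation error under nested AD is exactly the open question here, and this snippet does not show it.

```julia
using ForwardDiff

# Duals with a Float32 value and two partials, like Dual{Nothing, Float32, 2} in the trace.
out = [ForwardDiff.Dual(1f0, 1f0, 0f0), ForwardDiff.Dual(-2f0, 0f0, 1f0)]

y_map   = map(d -> ForwardDiff.value(d), out)  # map-based extraction via an anonymous function
y_bcast = ForwardDiff.value.(out)              # broadcast-based extraction, as suggested

println(y_map == y_bcast)
println(eltype(y_bcast))
```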

DhairyaLGandhi commented 2 years ago

In general, we would expect to be able to differentiate to higher orders through map (including differentiating through f itself). That line is pretty general, and IIRC it would be the same for the GPU and CPU cases.
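As a CPU-only illustration of the higher-order case, the sketch below differentiates through `map` twice with Zygote, assuming a Zygote version where nested `gradient` calls over `map` are supported (the report indicates the analogous CPU path works). For the hypothetical helper `cube(x) = x * x * x`, the inner gradient is `3x^2` and the outer gradient should be the second derivative `6x`, elementwise. This exercises the generic `map` adjoint on the CPU only; it says nothing about the GPU path.

```julia
using Zygote

cube(x) = x * x * x

# Inner gradient: d/dxs of sum(map(cube, xs)), mathematically 3 .* xs .^ 2
inner(xs) = Zygote.gradient(ys -> sum(map(cube, ys)), xs)[1]

# Outer gradient: d/dxs of sum(inner(xs)), mathematically 6 .* xs
xs = [1.0, 2.0, 3.0]
d2 = Zygote.gradient(v -> sum(inner(v)), xs)[1]

println(inner(xs))
println(d2)
```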

ToucheSir commented 2 years ago

When running forward, yes, but the map adjoint captures the context, along with a bunch of other GPU-unfriendly state, in https://github.com/FluxML/Zygote.jl/blob/v0.6.34/src/lib/array.jl#L197. To my knowledge, broadcasting does not do this, but I'm not sure whether swapping map for broadcast would run into issues with nested Duals.