dfdx / Yota.jl

Reverse-mode automatic differentiation in Julia
MIT License

Custom trivial rrule causes NNlib CPU backend to be used for CUDA Flux.Conv #140

Closed DrChainsaw closed 11 months ago

DrChainsaw commented 11 months ago

Sorry for the high-level MWE; I could not come up with a way to break it down further:

using Flux, Yota

struct Wrapper{L}
  l::L
end

(w::Wrapper)(x) = w.l(x)

let
  model = Wrapper(gpu(Conv((1,1), 3=>1)))
  x = gpu(randn(Float32, 32,32,3,1))
  Yota.grad(m -> sum(m(x)), model)
end; # No problem!

However, after adding a trivial custom rrule for Wrapper, it seems like the CPU backend is used:

import ChainRulesCore

function ChainRulesCore.rrule(config::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasReverseMode}, w::Wrapper, args...)
    res, back = ChainRulesCore.rrule_via_ad(config, w.l, args...)
    function Wrapper_back(Δ)
      δs = back(Δ)
      ChainRulesCore.Tangent{Wrapper}(l=δs[1]), δs[2:end]...
    end
    return res, Wrapper_back
end

julia> let
       model = Wrapper(gpu(Conv((1,1), 3=>1)))
       x = gpu(randn(Float32, 32,32,3,1))
       Yota.grad(m -> sum(m(x)), model)
       end;
┌ Warning: Performing scalar indexing on task Task (runnable) @0x0000022b2f8789c0.
│ Invocation of getindex resulted in scalar indexing of a GPU array.
│ This is typically caused by calling an iterating implementation of a method.
│ Such implementations *do not* execute on the GPU, but very slowly on the CPU,
│ and therefore are only permitted from the REPL for prototyping purposes.
│ If you did intend to index this array, annotate the caller with @allowscalar.
└ @ GPUArraysCore E:\Programs\julia\.julia\packages\GPUArraysCore\uOYfN\src\GPUArraysCore.jl:106
ERROR: TaskFailedException
### Full stacktrace in details in the end ###

Seems like Zygote can handle it:

let
    model = Wrapper(gpu(Conv((1,1), 3=>1)))
    x = gpu(randn(Float32, 32,32,3,1))
    gradient(m -> sum(m(x)), model)
end;
Full Stack Trace

```julia
julia> let
       model = Wrapper(gpu(Conv((1,1), 3=>1)))
       x = gpu(randn(Float32, 32,32,3,1))
       Yota.grad(m -> sum(m(x)), model)
       end;
┌ Warning: Performing scalar indexing on task Task (runnable) @0x0000022b2f8789c0.
│ Invocation of getindex resulted in scalar indexing of a GPU array.
│ This is typically caused by calling an iterating implementation of a method.
│ Such implementations *do not* execute on the GPU, but very slowly on the CPU,
│ and therefore are only permitted from the REPL for prototyping purposes.
│ If you did intend to index this array, annotate the caller with @allowscalar.
└ @ GPUArraysCore E:\Programs\julia\.julia\packages\GPUArraysCore\uOYfN\src\GPUArraysCore.jl:106
ERROR: TaskFailedException

    nested task error: TaskFailedException

        nested task error: MethodError: no method matching gemm!(::Val{false}, ::Val{true}, ::Int64, ::Int64, ::Int64, ::Float32, ::Ptr{Float32}, ::CUDA.CuPtr{Float32}, ::Float32, ::Ptr{Float32})

        Closest candidates are:
          gemm!(::Val, ::Val, ::Int64, ::Int64, ::Int64, ::Float32, ::Ptr{Float32}, ::Ptr{Float32}, ::Float32, ::Ptr{Float32})
           @ NNlib E:\Programs\julia\.julia\packages\NNlib\aaK3U\src\gemm.jl:29
          gemm!(::Val, ::Val, ::Int64, ::Int64, ::Int64, ::Float64, ::Ptr{Float64}, ::Ptr{Float64}, ::Float64, ::Ptr{Float64})
           @ NNlib E:\Programs\julia\.julia\packages\NNlib\aaK3U\src\gemm.jl:29
          gemm!(::Val, ::Val, ::Int64, ::Int64, ::Int64, ::ComplexF64, ::Ptr{ComplexF64}, ::Ptr{ComplexF64}, ::ComplexF64, ::Ptr{ComplexF64})
           @ NNlib E:\Programs\julia\.julia\packages\NNlib\aaK3U\src\gemm.jl:29
          ...

        Stacktrace:
         [1] macro expansion
           @ E:\Programs\julia\.julia\packages\NNlib\aaK3U\src\impl\conv_im2col.jl:163 [inlined]
         [2] (::NNlib.var"#640#641"{Float32, Array{Float32, 3}, Float32, SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, CUDA.CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}, DenseConvDims{3, 3, 3, 6, 3}, Int64, Int64, Int64, UnitRange{Int64}, Int64})()
           @ NNlib .\threadingconstructs.jl:404

    Stacktrace:
     [1] sync_end(c::Channel{Any})
       @ Base .\task.jl:445
     [2] macro expansion
       @ .\task.jl:477 [inlined]
     [3] ∇conv_data_im2col!(dx::SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, dy::SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, w::CUDA.CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}, cdims::DenseConvDims{3, 3, 3, 6, 3}; col::Array{Float32, 3}, alpha::Float32, beta::Float32, ntasks::Int64)
       @ NNlib E:\Programs\julia\.julia\packages\NNlib\aaK3U\src\impl\conv_im2col.jl:155
     [4] ∇conv_data_im2col!(dx::SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, dy::SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, w::CUDA.CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}, cdims::DenseConvDims{3, 3, 3, 6, 3})
       @ NNlib E:\Programs\julia\.julia\packages\NNlib\aaK3U\src\impl\conv_im2col.jl:126
     [5] (::NNlib.var"#323#327"{Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, DenseConvDims{3, 3, 3, 6, 3}, CUDA.CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}, SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, SubArray{Float32, 5, Array{Float32, 5}, Tuple{Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}})()
       @ NNlib .\threadingconstructs.jl:404

Stacktrace:
  [1] sync_end(c::Channel{Any})
    @ Base .\task.jl:445
  [2] macro expansion
    @ .\task.jl:477 [inlined]
  [3] ∇conv_data!(out::Array{Float32, 5}, in1::Array{Float32, 5}, in2::CUDA.CuArray{Float32, 5, CUDA.Mem.DeviceBuffer}, cdims::DenseConvDims{3, 3, 3, 6, 3}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ NNlib E:\Programs\julia\.julia\packages\NNlib\aaK3U\src\conv.jl:249
  [4] ∇conv_data!
    @ E:\Programs\julia\.julia\packages\NNlib\aaK3U\src\conv.jl:226 [inlined]
  [5] ∇conv_data!(y::Array{Float32, 4}, x::Array{Float32, 4}, w::CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, cdims::DenseConvDims{2, 2, 2, 4, 2}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ NNlib E:\Programs\julia\.julia\packages\NNlib\aaK3U\src\conv.jl:145
  [6] ∇conv_data!
    @ E:\Programs\julia\.julia\packages\NNlib\aaK3U\src\conv.jl:140 [inlined]
  [7] #∇conv_data#241
    @ E:\Programs\julia\.julia\packages\NNlib\aaK3U\src\conv.jl:99 [inlined]
  [8] ∇conv_data
    @ E:\Programs\julia\.julia\packages\NNlib\aaK3U\src\conv.jl:95 [inlined]
  [9] #380
    @ E:\Programs\julia\.julia\packages\NNlib\aaK3U\src\conv.jl:350 [inlined]
 [10] unthunk
    @ E:\Programs\julia\.julia\packages\ChainRulesCore\0t04l\src\tangent_types\thunks.jl:204 [inlined]
 [11] map(f::typeof(ChainRulesCore.unthunk), t::Tuple{ChainRulesCore.Tangent{Conv{2, 4, typeof(identity), CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, NamedTuple{(:weight, :bias), Tuple{Array{Float32, 4}, Vector{Float32}}}}, ChainRulesCore.Thunk{NNlib.var"#380#383"{Array{Float32, 4}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, DenseConvDims{2, 2, 2, 4, 2}}}})
    @ Base .\tuple.jl:274
 [12] mkcall(::Any, ::Any, ::Vararg{Any}; val::Any, line::Any, kwargs::Any, free_kwargs::Base.Pairs{Symbol, V, Tuple{Vararg{Symbol, N}}, NamedTuple{names, T}} where {V, N, names, T<:Tuple{Vararg{Any, N}}})
    @ Umlaut E:\Programs\julia\.julia\packages\Umlaut\XPASX\src\tape.jl:207
 [13] mkcall(::Any, ::Any, ::Vararg{Any})
    @ Umlaut E:\Programs\julia\.julia\packages\Umlaut\XPASX\src\tape.jl:188
 [14] finalize_grad!(tape::Umlaut.Tape{Yota.GradCtx})
    @ Yota E:\Programs\julia\.julia\packages\Yota\sbYPp\src\grad.jl:238
 [15] #gradtape!#77
    @ E:\Programs\julia\.julia\packages\Yota\sbYPp\src\grad.jl:251 [inlined]
 [16] gradtape(f::Conv{2, 4, typeof(identity), CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, args::CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}; ctx::Yota.GradCtx, seed::Symbol)
    @ Yota E:\Programs\julia\.julia\packages\Yota\sbYPp\src\grad.jl:268
 [17] gradtape
    @ E:\Programs\julia\.julia\packages\Yota\sbYPp\src\grad.jl:263 [inlined]
 [18] make_rrule!(f::Conv{2, 4, typeof(identity), CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, args::CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer})
    @ Yota E:\Programs\julia\.julia\packages\Yota\sbYPp\src\chainrules.jl:91
 [19] rrule_via_ad(cfg::Yota.YotaRuleConfig, f::Conv{2, 4, typeof(identity), CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, args::CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer})
    @ Yota E:\Programs\julia\.julia\packages\Yota\sbYPp\src\chainrules.jl:119
 [20] rrule(config::Yota.YotaRuleConfig, w::Wrapper{Conv{2, 4, typeof(identity), CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, args::CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer})
    @ Main .\REPL[38]:2
 [21] mkcall(::Any, ::Any, ::Vararg{Any}; val::Any, line::Any, kwargs::Any, free_kwargs::Base.Pairs{Symbol, V, Tuple{Vararg{Symbol, N}}, NamedTuple{names, T}} where {V, N, names, T<:Tuple{Vararg{Any, N}}})
    @ Umlaut E:\Programs\julia\.julia\packages\Umlaut\XPASX\src\tape.jl:207
 [22] chainrules_transform!(tape::Umlaut.Tape{Yota.GradCtx})
    @ Yota E:\Programs\julia\.julia\packages\Yota\sbYPp\src\grad.jl:149
 [23] #gradtape!#77
    @ E:\Programs\julia\.julia\packages\Yota\sbYPp\src\grad.jl:247 [inlined]
 [24] gradtape(::Function, ::Wrapper{Conv{2, 4, typeof(identity), CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, ::Vararg{Any}; ctx::Yota.GradCtx, seed::Int64)
    @ Yota E:\Programs\julia\.julia\packages\Yota\sbYPp\src\grad.jl:268
 [25] grad(::Function, ::Wrapper{Conv{2, 4, typeof(identity), CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, ::Vararg{Any}; seed::Int64)
    @ Yota E:\Programs\julia\.julia\packages\Yota\sbYPp\src\grad.jl:360
 [26] grad(::Function, ::Wrapper{Conv{2, 4, typeof(identity), CUDA.CuArray{Float32, 4, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}, ::Vararg{Any})
    @ Yota E:\Programs\julia\.julia\packages\Yota\sbYPp\src\grad.jl:352
 [27] top-level scope
    @ REPL[46]:4
 [28] top-level scope
    @ E:\Programs\julia\.julia\packages\CUDA\tVtYo\src\initialization.jl:185
```
dfdx commented 11 months ago

It will take me some time to get an environment with a GPU to try it, but I have a couple of quick guesses:

  1. x is captured implicitly in your code (the closure refers to it instead of taking it as an argument). Yota treats all non-argument values as constants, so x is recorded on the underlying tape as-is and never changes. I'm not sure it causes the issue, but I'd certainly use Yota.grad((m, x) -> sum(m(x)), model, x) instead.
  2. rrule_via_ad() has always been quite fragile, so I'd try an rrule() without it first. If that works, we can then check what's going on inside rrule_via_ad() as follows:
using CUDA: CuArray
import Yota: V

# this is what happens internally when you call rrule_via_ad()
tape = Yota.gradtape(model.l, x; seed=:auto, ctx=Yota.GradCtx())

# check if any operation produces a non-GPU array
for i in 1:length(tape)
    op = tape[V(i)]
    if isa(op.val, AbstractArray) && !isa(op.val, CuArray)
        println("Operation $(op) produces a non-GPU array")
    elseif isa(op.val, Tuple)
        # operations returning tuples (e.g. value + pullback) are checked element-wise
        for res in op.val
            if isa(res, AbstractArray) && !isa(res, CuArray)
                println("Operation $(op) produces a non-GPU array")
            end
        end
    end
end
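For suggestion 1, a minimal sketch of the explicit-argument form on a bare Conv layer (the Wrapper is left out here to keep it self-contained). It assumes, as in Yota's usual usage, that Yota.grad returns the value plus a tuple of tangents whose first entry belongs to the function itself; without a working GPU, Flux's gpu leaves the arrays on the CPU, so the snippet still runs.

```julia
using Flux, Yota

model = gpu(Conv((1, 1), 3 => 1))
x = gpu(randn(Float32, 32, 32, 3, 1))

# x is now an explicit argument, so Yota records it as a tape input
# instead of baking it into the tape as a constant.
val, g = Yota.grad((m, x) -> sum(m(x)), model, x)

# g holds one tangent per traced input: the function itself, then m, then x.
dmodel, dx = g[2], g[3]
```

dmodel is a structural tangent for the Conv layer (with weight and bias fields), and dx has the same shape as x.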
DrChainsaw commented 11 months ago

Thanks!

I tried 1 and it gave the same error:


let
       model = Wrapper(gpu(Conv((1,1), 3=>1)))
       x = gpu(randn(Float32, 32,32,3,1))
       Yota.grad((m, xx) -> sum(model(xx)), model, x)
       end;
┌ Warning: Performing scalar indexing on task Task (runnable) @0x000002428231c7d0.

For 2 I can't really wrap my head around how to do it without calling back to AD for the wrapped function.

Here is one example from the wild where the only purpose of the rrule is to snoop on gradients. How can it be written without calling back to AD?
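One possible way to snoop without rrule_via_ad() (not something proposed in this thread) is to put the hook on an identity function and let the AD differentiate the wrapped layer normally. The names snoop and SnoopWrapper below are hypothetical. Note the limitation: this only sees the gradient of the layer's output, not the parameter gradients that δs[1] exposes in the rrule above.

```julia
using ChainRulesCore, Yota

# Hypothetical helper: an identity function whose only job is to expose
# the gradient flowing back through it.
snoop(x) = x

function ChainRulesCore.rrule(::typeof(snoop), x)
    function snoop_pullback(Δ)
        Δ = unthunk(Δ)
        @info "gradient reaching snoop" summary(Δ)
        return NoTangent(), Δ   # pass the gradient through unchanged
    end
    return x, snoop_pullback
end

# Hypothetical wrapper: the wrapped layer is traced by Yota as usual,
# only its output passes through the snooping identity.
struct SnoopWrapper{L}
    l::L
end
(w::SnoopWrapper)(x) = snoop(w.l(x))

# Usage on a plain array: the logged gradient is that of sum, i.e. all ones.
val, g = Yota.grad(x -> sum(snoop(x)), [1.0, 2.0])
```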

I'm not sure whether the last part was meant to help pinpoint where it goes wrong or was just FYI, but I tried running it and got the same error from gradtape:

import Yota: V

let
       model = Wrapper(gpu(Conv((1,1), 3=>1)))
       x = gpu(randn(Float32, 32,32,3,1))
       tape = Yota.gradtape(model.l, x; seed=:auto, ctx=Yota.GradCtx())
end;
┌ Warning: Performing scalar indexing on task Task (runnable) @0x0000024281dbb650.

Let me know if there is something else I can do to help troubleshoot from my side.

dfdx commented 11 months ago

> For 2 I can't really wrap my head around how to do it without calling back to AD for the wrapped function.

I meant using a fake rrule just to understand where the error comes from. But from the last experiment, I'm now pretty sure the problem comes from rrule_via_ad(), which narrows down the search. I will try to get a Julia + GPU setup and debug the issue this week. Thanks for discovering it!
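A "fake" rrule of the kind meant here could look like the sketch below. It keeps the same signature as the rrule above (so it replaces it), runs the forward pass directly, and returns zero tangents, so rrule_via_ad() is never called. It is for debugging only: the gradients it produces are deliberately wrong.

```julia
using Flux, Yota
import ChainRulesCore
using ChainRulesCore: ZeroTangent

struct Wrapper{L}
    l::L
end
(w::Wrapper)(x) = w.l(x)

# Debugging-only rrule: same signature as the real one, but no rrule_via_ad().
function ChainRulesCore.rrule(::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasReverseMode},
                              w::Wrapper, args...)
    res = w.l(args...)  # plain forward pass
    # zero tangents for the wrapper itself and for every argument
    fake_back(Δ) = (ZeroTangent(), map(_ -> ZeroTangent(), args)...)
    return res, fake_back
end

model = Wrapper(gpu(Conv((1, 1), 3 => 1)))
x = gpu(randn(Float32, 32, 32, 3, 1))

val, g = Yota.grad(m -> sum(m(x)), model)
```

If this runs without the scalar-indexing warning while the real rrule fails, that points at rrule_via_ad() rather than at the forward pass.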

dfdx commented 11 months ago

TL;DR: Run ] add cuDNN

It turns out that neither CUDA nor Flux installs cuDNN by default. import Flux reports it as a possible issue but doesn't prevent you from running the code on the GPU. As a result, even the forward pass model(x) leads to the error you posted, without Yota or ChainRules involved. Installing cuDNN solves the issue.

I don't quite understand how Zygote works around this issue; it may indirectly install cuDNN, or route calls to conv to a different implementation than the one model(x) invokes.
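The check described above can be sketched as follows: install cuDNN, load it next to CUDA, and run only the forward pass, with no Yota or ChainRules involved. This assumes a Flux version where loading CUDA and cuDNN enables the GPU backend.

```julia
using Pkg
Pkg.add("cuDNN")  # neither CUDA.jl nor Flux installs it automatically

using Flux, CUDA, cuDNN

model = gpu(Conv((1, 1), 3 => 1))
x = gpu(randn(Float32, 32, 32, 3, 1))

# Forward pass only. On a working CUDA + cuDNN setup this should run without
# the scalar-indexing warning and return a CuArray.
y = model(x)
@show CUDA.functional() y isa CuArray
```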

DrChainsaw commented 11 months ago

I had both CUDA and cuDNN installed in the MWE. :(

The example without the rrule and the Zygote example wouldn't have worked without it. Sorry for not making that clear.

dfdx commented 11 months ago

So you mean that with cuDNN installed, the Yota + ChainRules example still doesn't work? Can you post your ] st output then?

dfdx commented 11 months ago

Ok, I can reproduce it now. Investigating.

dfdx commented 11 months ago

Fixed in version 0.8.5, please re-open if you still experience the issue.
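To pick up the fix, a minimal sketch of upgrading and checking the installed version (pkgversion assumes Julia 1.9 or later):

```julia
using Pkg
Pkg.update("Yota")  # newest Yota release allowed by the environment's compat bounds

using Yota
@assert pkgversion(Yota) >= v"0.8.5" "Yota is still older than 0.8.5; check ] st for what holds it back"
```

If Yota was already loaded in the running session, restart Julia before rerunning the MWE so the new version is actually used.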