JuliaDiff / ChainRules.jl

forward and reverse mode automatic differentiation primitives for Julia Base + StdLibs

Rules that use scalar indexing are not GPU-compatible #617

Open danielwe opened 2 years ago

danielwe commented 2 years ago

Some rules use scalar indexing, breaking GPU compatibility, e.g., https://github.com/JuliaDiff/ChainRules.jl/blob/3b3791f10bc88c41f004fbb9eb229745d1764593/src/rulesets/LinearAlgebra/norm.jl#L187

One solution would be to use @allowscalar from GPUArrays, but one concern about adding that dependency is the loading time (h/t @mcabbott).

Another is to be clever and replace all scalar indexing with size-1 views, for example:

@inbounds @views ∂x[yind:yind] .= sign.(x[yind:yind]) .* Δu
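A CPU-only sketch of the size-1-view trick, wrapped in a hypothetical helper `set_single_view!` (not a ChainRules function). On a plain Array it should give exactly the same result as the scalar version. On a GPU array, the broadcast into a length-1 view would presumably go through GPU broadcasting rather than scalar getindex/setindex!, but that part is not exercised here.

```julia
# Hypothetical helper illustrating the size-1-view trick.
# Both the read of x and the write into ∂x go through length-1 views and a
# broadcast, so no single element is accessed with scalar getindex/setindex!.
function set_single_view!(∂x::AbstractVector, x::AbstractVector, yind::Integer, Δu)
    @views ∂x[yind:yind] .= sign.(x[yind:yind]) .* Δu
    return ∂x
end

x = [1.0, -3.0, 2.0]
∂x_view = set_single_view!(zeros(3), x, 2, 0.5)

# Reference: the scalar-indexing form from the issue (fine on CPU arrays).
∂x_scalar = zeros(3)
∂x_scalar[2] = sign(x[2]) * 0.5

@assert ∂x_view == ∂x_scalar   # both are [0.0, -0.5, 0.0]
```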

For people unfamiliar with GPU programming and @allowscalar: scalar getindex/setindex! is disallowed for GPU arrays because processing a GPU array element by element from loopy CPU code defeats the purpose of using the GPU and leads to terrible performance. However, scalar getindex/setindex! is of course legitimate in cases like the one above, where you're not in a loop and only intend to set or retrieve that single element of the array. @allowscalar from GPUArrays is how you indicate that an exception should be made in a given instance, e.g.,

@inbounds @allowscalar ∂x[yind] = sign(x[yind]) * Δu

See https://cuda.juliagpu.org/stable/usage/workflow/#UsageWorkflowScalar for more about this mechanism.


MWE:

using ChainRules
using CUDA
using LinearAlgebra

CUDA.allowscalar(false)
x = CUDA.ones(3)
nx, pb = ChainRules.rrule(norm, x, Inf)
ChainRules.unthunk(pb(1f0)[2])
ERROR: Scalar indexing is disallowed.
Invocation of getindex resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
and therefore are only permitted from the REPL for prototyping purposes.
If you did intend to index this array, annotate the caller with @allowscalar.
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] assertscalar(op::String)
    @ GPUArrays ~/.julia/packages/GPUArrays/Zecv7/src/host/indexing.jl:53
  [3] getindex
    @ ~/.julia/packages/GPUArrays/Zecv7/src/host/indexing.jl:86 [inlined]
  [4] findprev(testf::ChainRules.var"#1770#1771"{Float32}, A::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, start::Int64)
    @ Base ./array.jl:2151
  [5] findlast
    @ ./array.jl:2199 [inlined]
  [6] _normInf_back(x::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, y::Float32, Δy::Float32)
    @ ChainRules ~/.julia/packages/ChainRules/SrdPq/src/rulesets/LinearAlgebra/norm.jl:185
  [7] (::ChainRules.var"#1742#1746"{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float64, Float32})()
    @ ChainRules ~/.julia/packages/ChainRules/SrdPq/src/rulesets/LinearAlgebra/norm.jl:49
  [8] unthunk
    @ ~/.julia/packages/ChainRulesCore/RbX5a/src/tangent_types/thunks.jl:195 [inlined]
  [9] unthunk(x::ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{ChainRules.var"#1742#1746"{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float64, Float32}}, ChainRules.var"#1741#1745"{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Float64, Float32}})
    @ ChainRulesCore ~/.julia/packages/ChainRulesCore/RbX5a/src/tangent_types/thunks.jl:222
 [10] top-level scope
    @ REPL[329]:1
 [11] top-level scope
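For comparison, here is a CPU-only approximation of what the failing pullback appears to do, pieced together from the stack trace (`findlast` at norm.jl:185) and the snippets quoted above. It is an assumption, not the verbatim ChainRules source, and the name `normInf_back_reconstructed` is hypothetical. It marks both scalar-indexing sites; the trace suggests the `findlast` call fails before the single-element write is even reached, so the size-1-view trick alone may not be enough for this rule.

```julia
using LinearAlgebra

# Approximate reconstruction (assumption) of the Inf-norm pullback.
function normInf_back_reconstructed(x::AbstractArray, y, Δy)
    Δu = real(Δy)
    ∂x = fill!(similar(x, float(eltype(x))), 0)
    # Scalar access #1: `findlast` walks the array element by element
    # (this is the call that errors first in the GPU stack trace).
    yind = findlast(xi -> abs(xi) == y, x)
    # Scalar access #2: a single-element read of x and write into ∂x.
    @inbounds ∂x[yind] = sign(x[yind]) * Δu
    return ∂x
end

# Same input as the MWE, but on the CPU: all entries tie, so the last one wins.
x = ones(Float32, 3)
@assert normInf_back_reconstructed(x, norm(x, Inf), 1f0) == Float32[0, 0, 1]
```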
oxinabox commented 2 years ago

We could implement @allow_scalar as a helper in ChainRules.jl. It is very simple: it just sets a variable in task-local storage (https://github.com/JuliaGPU/GPUArrays.jl/blob/master/src/host/indexing.jl#L74). The downside is that if GPUArrays changes the mechanism it uses, we wouldn't be protected by semver.
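A minimal sketch of that idea, using only Base's `task_local_storage`. All names here (`SCALAR_KEY`, `scalar_allowed`, `assert_scalar_allowed`, `@allow_scalar_sketch`, `guarded_getindex`) are hypothetical, and this is not GPUArrays' actual implementation; `guarded_getindex` merely stands in for a GPU array's getindex that consults the flag.

```julia
# Key under which the hypothetical "scalar indexing allowed" flag is stored.
const SCALAR_KEY = :hypothetical_scalar_allowed

# Read the flag for the current task; the default is "not allowed".
scalar_allowed() = get(task_local_storage(), SCALAR_KEY, false)::Bool

# Throw unless scalar operations are currently permitted.
function assert_scalar_allowed(op::AbstractString)
    scalar_allowed() || error("Scalar indexing is disallowed (", op, ").")
    return nothing
end

# Run `ex` with the flag set to `true`. `task_local_storage(body, key, value)`
# restores the previous state afterwards, even if `ex` throws.
macro allow_scalar_sketch(ex)
    return quote
        task_local_storage($(QuoteNode(SCALAR_KEY)), true) do
            $(esc(ex))
        end
    end
end

# Stand-in for a GPU array's getindex that checks the flag first.
guarded_getindex(A, i) = (assert_scalar_allowed("getindex"); A[i])

x = [1.0, 2.0, 3.0]
caught = try
    guarded_getindex(x, 2)
    false
catch
    true
end
@assert caught                                          # disallowed by default
@assert (@allow_scalar_sketch guarded_getindex(x, 2)) == 2.0
@assert !scalar_allowed()                               # flag restored afterwards
```

Because the wrapped expression runs inside a closure, a plain variable assignment inside the macro is not visible outside it; indexed assignments such as `∂x[yind] = ...` are unaffected.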

For rules that basically work already, because the way to do it for AbstractArrays naturally matches what you want and it just needs to tell GPUArrays that "this is an OK scalar operation", this seems fine. For things with bigger differences (e.g. because they do bad scalar ops), I would say those warrant separate rules defined in GPUArrays.jl.

danielwe commented 2 years ago

"For things with bigger differences (e.g. because they do bad scalar ops), I would say those warrant separate rules defined in GPUArrays.jl."

For sure: if a good rrule needs a custom GPU kernel, it belongs somewhere else. On the other hand, there may be rules where moving from explicit loops and indexing to a more vectorized/mapreduce style would simply make the rule more generic, with no downside. I haven't gone looking for examples, so I have no idea how big an effort it would be to carry out such a rewrite across the package.
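As a hedged illustration of what a vectorized rewrite could look like, here is a pullback for `norm(x, Inf)` that uses only broadcasting and `sum`, so nothing is indexed element by element. The name `normInf_back_vectorized` is hypothetical, and it deliberately makes a different design choice from the `findlast`-based rule: when several entries tie for the maximum |xᵢ|, the cotangent is split equally among them, which is still a valid subgradient. Whether this runs efficiently on a GPU is untested here.

```julia
using LinearAlgebra

# Hypothetical scalar-index-free, non-mutating pullback for y = norm(x, Inf).
function normInf_back_vectorized(x::AbstractArray, y::Real, Δy)
    mask = abs.(x) .== y               # Bool array marking the maximisers
    weight = real(Δy) / sum(mask)      # share of the cotangent per maximiser
    return sign.(x) .* mask .* weight  # zero wherever mask is false
end

x = [1.0, -4.0, 2.0]
@assert normInf_back_vectorized(x, norm(x, Inf), 1.0) == [0.0, -1.0, 0.0]

# With a tie, the cotangent is shared instead of going to the last maximiser.
@assert normInf_back_vectorized([3.0, -3.0, 1.0], 3.0, 1.0) == [0.5, -0.5, 0.0]
```

Spreading over ties avoids having to pick a single index, which is exactly the step that needs a sequential search in the existing rule.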

mcabbott commented 2 years ago

This _normInf_back function, and others like it that write into an array of zeros, should probably eventually call the gradient of getindex instead. That's because mutation won't play nicely with second derivatives, and maybe not with immutable arrays either. Xref #382, maybe.
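A minimal sketch of a non-mutating alternative to the "fill with zeros, then setindex!" pattern, assuming `x` is a vector and `yind` has already been found. The name `onehot_normInf_cotangent` is hypothetical, and this is not the getindex-gradient function referred to in this comment. The cotangent is produced by a single broadcast over the indices, so no array is written to after it is created and `x` is never read element by element.

```julia
# Hypothetical non-mutating cotangent for norm(x, Inf) once `yind` is known.
# (eachindex(x) .== yind) is true only at `yind`, so the result equals
# sign(x[yind]) * Δu there and zero everywhere else.
function onehot_normInf_cotangent(x::AbstractVector, yind::Integer, Δu::Real)
    return (eachindex(x) .== yind) .* sign.(x) .* Δu
end

x = [1.0, -4.0, 2.0]
yind = findlast(xi -> abs(xi) == 4.0, x)   # still a sequential search, fine on CPU
@assert onehot_normInf_cotangent(x, yind, 0.5) == [0.0, -0.5, 0.0]
```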

At the moment, the candidate for that is this function with an ugly name, which has a second-derivative rule and takes care to avoid scalar indexing:

https://github.com/JuliaDiff/ChainRules.jl/blob/7b5f4d137f338661e266d438932b27b04a3c20b1/src/rulesets/Base/array.jl#L537-L538

There are other places where @allowscalar would be useful, though, and if it's that simple we should just copy the macro. One example is this hack, which (IIRC) may not actually work right now:

https://github.com/JuliaDiff/ChainRules.jl/blob/3b3791f10bc88c41f004fbb9eb229745d1764593/src/rulesets/Base/array.jl#L240-L241

The other question is how to test this. One option is to use the fake GPU array type that GPUArrays provides, so that tests work on everyone's laptop and also on GitHub's CI. Something like this:

https://github.com/FluxML/OneHotArrays.jl/blob/main/test/runtests.jl#L17-L35

Maybe we could have a very simple test_GPU function that just checks that the rule agrees with itself when fed cu(x), and slowly start adding it to the tests of rules that are known to work?
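A hedged sketch of such a self-agreement check. `test_device_agreement` and `toy_rule` are hypothetical names; in real use `todevice` would be something like `cu` or the fake GPU array's constructor, but here it defaults to `identity` so the example runs anywhere. A real ChainRules pullback returns a tuple of tangents that may be thunks (as in the MWE's `unthunk(pb(1f0)[2])`), so a real version would need to pick out and unthunk the relevant entry.

```julia
# Hypothetical helper: run a rule on host data and on "device" data, move the
# device results back to the host, and check that the two runs agree.
function test_device_agreement(rule, x, Δ; todevice=identity, tohost=Array, rtol=1e-6)
    y_cpu, pb_cpu = rule(x)
    y_dev, pb_dev = rule(todevice(x))
    ∂x_cpu = pb_cpu(Δ)
    ∂x_dev = tohost(pb_dev(Δ))
    return isapprox(y_cpu, y_dev; rtol=rtol) && isapprox(∂x_cpu, ∂x_dev; rtol=rtol)
end

# Toy hand-written rule for y = sum(abs2, x), standing in for a real rrule.
toy_rule(x) = (sum(abs2, x), Δ -> 2 .* x .* Δ)

@assert test_device_agreement(toy_rule, [1.0, -2.0, 3.0], 1.0)
# A non-Array input type still flows through the same check.
@assert test_device_agreement(toy_rule, [1.0, -2.0, 3.0], 1.0; todevice = x -> view(x, :))
```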