danielwe opened this issue 2 years ago
We could implement @allow_scalar as a helper in ChainRules.jl. It is very simple, just setting a variable in task-local storage:
https://github.com/JuliaGPU/GPUArrays.jl/blob/master/src/host/indexing.jl#L74
The downside is that if GPUArrays changes the mechanism it uses, we wouldn't be protected by semver.
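A minimal sketch of that idea (the task-local key and value below are assumptions about GPUArrays' current internals, which is exactly the fragility just mentioned):

```julia
# Hypothetical sketch, not GPUArrays' actual code: wrap an expression so that a
# flag lives in task-local storage while it runs. For this to actually silence
# GPUArrays' check, the key and value would have to match whatever its
# assertscalar currently reads; they are assumed here for illustration.
macro allow_scalar(ex)
    quote
        # task_local_storage(f, key, value) sets `key` for the duration of `f`
        # and restores the previous value afterwards.
        task_local_storage(:ScalarIndexing, :allowed) do
            $(esc(ex))
        end
    end
end

# Usage: @allow_scalar gpu_array[1] = 0f0
```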
For rules that basically work already, because the natural AbstractArrays way of writing them matches what you want and they just need to tell GPUArrays "this is an OK scalar operation", this seems fine. For things with bigger differences (e.g. because they do bad scalar ops), I would say those warrant separate rules defined in GPUArrays.jl.
> For things with bigger differences (e.g. because they do bad scalar ops), I would say those warrant separate rules defined in GPUArrays.jl.
For sure, if a good rrule needs a custom GPU kernel it belongs somewhere else. On the other hand, there might exist rules where moving from explicit loops and indexing to a more vectorized/mapreducing style would only make the rule more generic with no downside. I haven't gone looking for examples, so I have no idea how big of an effort it would be to carry out such a rewrite across the package.
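As a made-up illustration of the kind of rewrite meant here (not a rule from the package): both functions below compute the same cotangent, but only the vectorized one works for GPU arrays because it avoids scalar indexing.

```julia
# Hypothetical example: pullback of `y = x .* w'` with respect to `x`.

# Loop-and-index style: correct on the CPU, but scalar indexing on GPU arrays.
function scale_cols_back_loop(dy, w)
    dx = similar(dy)
    for j in axes(dy, 2), i in axes(dy, 1)
        dx[i, j] = dy[i, j] * w[j]
    end
    return dx
end

# Vectorized style: identical result, and works for any AbstractArray,
# including GPU arrays, because it lowers to a broadcast kernel.
scale_cols_back(dy, w) = dy .* reshape(w, 1, :)
```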
This _normInf_back function, and others like it which write into an array of zeros, should probably eventually call the gradient of getindex. This is because mutation won't play nicely with second derivatives, and maybe also immutable arrays. Xref #382 maybe.
At the moment, the candidate for that is this function with an ugly name, which has a 2nd derivative rule, and takes care to avoid scalar indexing:
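Not that function, but as a rough sketch of the style: a cotangent that is nonzero at a single index can be built from the getindex rrule itself rather than by mutating an array of zeros (the helper name here is made up):

```julia
using ChainRules, ChainRulesCore

# Hypothetical helper: a cotangent for `x` that is `v` at index `k` and zero
# elsewhere, produced via the pullback of getindex instead of writing into
# fill!(similar(x), 0). That keeps it friendly to second derivatives and,
# ideally, to GPU and immutable arrays.
function onehot_cotangent(x::AbstractArray, k, v)
    _, pb = rrule(getindex, x, k)
    return unthunk(pb(v)[2])  # pb returns (NoTangent(), dx, NoTangent())
end
```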
There are other places where @allowscalar would be useful though, and if it's so simple we should just copy the macro. Like this hack, which (IIRC) may not actually work right now:
The other question is testing things. One possible way is to use a fake GPUArray which they provide, so that tests can work on everyone's laptop, and also on GitHub's CI. Something like this:
https://github.com/FluxML/OneHotArrays.jl/blob/main/test/runtests.jl#L17-L35
Maybe we could have a very simple test_GPU function which just checks that the rule agrees with itself when fed cu(x). And slowly start adding that to tests of rules which are known to work?
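A rough sketch of what such a helper could look like, using JLArrays.jl (the fake GPU array package that lives in the GPUArrays repo) so it runs on CPU-only machines and CI; the name test_GPU and all details are illustrative:

```julia
using ChainRules, ChainRulesCore, GPUArraysCore, JLArrays, Test

GPUArraysCore.allowscalar(false)  # make accidental scalar indexing an error, as on a real GPU

# Bring results back to plain CPU values for comparison.
_cpu(x) = x isa AbstractArray ? collect(x) : x

# Hypothetical helper: check that an rrule gives the same primal and cotangent
# for a plain Array and for a JLArray, given a cotangent `dy` for the output.
function test_GPU(f, x::AbstractArray, dy)
    y, back = rrule(f, x)
    y_jl, back_jl = rrule(f, JLArray(x))
    @test _cpu(y_jl) ≈ _cpu(y)
    dy_jl = dy isa AbstractArray ? JLArray(dy) : dy
    dx = unthunk(back(dy)[2])
    dx_jl = unthunk(back_jl(dy_jl)[2])
    @test _cpu(dx_jl) ≈ _cpu(dx)
end

# e.g. test_GPU(sum, rand(Float32, 8), 1f0)
```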
Some rules use scalar indexing, breaking GPU compatibility, e.g., https://github.com/JuliaDiff/ChainRules.jl/blob/3b3791f10bc88c41f004fbb9eb229745d1764593/src/rulesets/LinearAlgebra/norm.jl#L187
One solution would be to use @allowscalar from GPUArrays, but one concern about adding that dependency is the loading time (h/t @mcabbott). Another is to be clever and replace all scalar indexing by size-1 views, such as:
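For instance (illustrative only, loosely following the _normInf_back pattern):

```julia
# Scalar-indexing version: errors on GPU arrays when scalar indexing is disallowed.
set_grad_scalar!(dx, x, k, du) = (dx[k] = sign(x[k]) * du; dx)

# Length-1-view version: the read and the write both go through broadcasting
# over 1-element views, so no host-side scalar indexing is needed.
function set_grad_view!(dx, x, k, du)
    view(dx, k:k) .= sign.(view(x, k:k)) .* du
    return dx
end
```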
For people unfamiliar with GPU programming and @allowscalar: the reason scalar setindex/getindex is disallowed for GPU arrays is that sequential processing of GPU arrays from loopy CPU code defeats the purpose of using the GPU and leads to terrible performance. However, scalar setindex/getindex is of course legitimate in cases like the above, where you're not in a loop and only intend to set or retrieve that single element of the array. @allowscalar from GPUArrays is how you indicate that an exception should be made in any given instance. See https://cuda.juliagpu.org/stable/usage/workflow/#UsageWorkflowScalar for more about this mechanism.
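For example, a minimal usage sketch (assuming CUDA.jl; GPUArrays provides the same macro):

```julia
using CUDA

x = cu(zeros(Float32, 4))
v = CUDA.@allowscalar x[1]    # read a single element: legitimate, but must be opted into
CUDA.@allowscalar x[1] = 1f0  # likewise for writing a single element
```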
MWE:
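A sketch of the kind of MWE meant here (assuming a CUDA-capable machine; JLArrays shows the same failure on a CPU-only one):

```julia
using CUDA, Zygote, LinearAlgebra

CUDA.allowscalar(false)  # make scalar indexing an error, as in CI

x = cu(rand(Float32, 16))
Zygote.gradient(v -> norm(v, Inf), x)
# ERROR: Scalar indexing is disallowed
# (the Inf-norm pullback in ChainRules writes its result with scalar setindex)
```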