JuliaDiff / ChainRules.jl

forward and reverse mode automatic differentiation primitives for Julia Base + StdLibs

`rrule` for `mean(f, x)` is not vectorized? #733

Open Red-Portal opened 1 year ago

Red-Portal commented 1 year ago

Hi, it seems that the rrule for mean(f, x) is not vectorized and thus does not play nicely with CUDA:

using Zygote, CUDA, Statistics

julia> gradient(y -> mean(x -> x.^2, y), CUDA.randn(10))
ERROR: Scalar indexing is disallowed.
Invocation of getindex resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
and therefore are only permitted from the REPL for prototyping purposes.
If you did intend to index this array, annotate the caller with @allowscalar.
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] assertscalar(op::String)
    @ GPUArraysCore ~/.julia/packages/GPUArraysCore/uOYfN/src/GPUArraysCore.jl:103
  [3] getindex
    @ ~/.julia/packages/GPUArrays/5XhED/src/host/indexing.jl:9 [inlined]
  [4] iterate
    @ ./abstractarray.jl:1220 [inlined]
  [5] iterate
    @ ./abstractarray.jl:1218 [inlined]
  [6] iterate
    @ ./generator.jl:44 [inlined]
  [7] collect(itr::Base.Generator{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ChainRules.var"#1655#1660"{Zygote.ZygoteRuleConfig{Zygote.Context{false}}, var"#24#26"}})
    @ Base ./array.jl:782
  [8] rrule(config::Zygote.ZygoteRuleConfig{Zygote.Context{false}}, ::typeof(sum), f::var"#24#26", xs::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}; dims::Function)
    @ ChainRules ~/.julia/packages/ChainRules/9sNmB/src/rulesets/Base/mapreduce.jl:102
  [9] rrule
    @ ~/.julia/packages/ChainRules/9sNmB/src/rulesets/Base/mapreduce.jl:76 [inlined]
 [10] #rrule#1808
    @ ~/.julia/packages/ChainRules/9sNmB/src/rulesets/Statistics/statistics.jl:28 [inlined]
 [11] rrule
    @ ~/.julia/packages/ChainRules/9sNmB/src/rulesets/Statistics/statistics.jl:21 [inlined]
 [12] chain_rrule
    @ ~/.julia/packages/Zygote/4rucm/src/compiler/chainrules.jl:223 [inlined]
 [13] macro expansion
    @ ~/.julia/packages/Zygote/4rucm/src/compiler/interface2.jl:101 [inlined]
 [14] _pullback
    @ ~/.julia/packages/Zygote/4rucm/src/compiler/interface2.jl:101 [inlined]
 [15] _pullback
    @ ./REPL[14]:1 [inlined]
 [16] _pullback(ctx::Zygote.Context{false}, f::var"#23#25", args::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/4rucm/src/compiler/interface2.jl:0
 [17] pullback(f::Function, cx::Zygote.Context{false}, args::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/4rucm/src/compiler/interface.jl:44
 [18] pullback
    @ ~/.julia/packages/Zygote/4rucm/src/compiler/interface.jl:42 [inlined]
 [19] gradient(f::Function, args::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/4rucm/src/compiler/interface.jl:96
 [20] top-level scope
    @ REPL[14]:1
 [21] top-level scope
    @ ~/.julia/packages/CUDA/tVtYo/src/initialization.jl:185

The problem seems to be that this line does not use map or broadcasting. But the comment seems to suggest that we can't do that here. Is there anything we can do?
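To make the map/broadcast point concrete, here is a minimal sketch of a pullback for mean(f, x) that only uses broadcasting, so it never iterates over the array element by element. It is not the ChainRules implementation: `mean_with_pullback` and its `df` argument are hypothetical names, and the sketch assumes the derivative of f is known in closed form, which a generic rrule cannot assume. It uses the identity d/dx_i mean(f, x) = f'(x_i) / n.

```julia
using Statistics

# Hypothetical helper (not part of ChainRules): returns mean(f, x) together with
# a pullback built purely from broadcasting. `df` is assumed to be the
# hand-written derivative of `f`.
function mean_with_pullback(f, df, x::AbstractArray)
    y = mean(f, x)
    n = length(x)
    # Gradient of mean(f, x) w.r.t. x[i] is df(x[i]) / n, scaled by the cotangent dy.
    pullback(dy) = (dy / n) .* df.(x)
    return y, pullback
end

x = Float32[1, 2, 3, 4]
y, back = mean_with_pullback(abs2, t -> 2t, x)
grad = back(1f0)  # one broadcast over x, no per-element getindex
```

Since the pullback is a single broadcast, the same pattern should in principle work on a GPU array without triggering scalar indexing, though that is not tested here.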

By the way, sum(f, x) for the same f works perfectly, so I'm quite curious why the behavior is different. Both hit the same rrule, right?

julia> gradient(y -> sum(x -> x^2, y)/10, CUDA.randn(10))
(Float32[-0.03543221, -0.002124702, 0.068868384, -0.21756743, 0.234217, -0.16418666, -0.033367466, -0.26496077, 0.095435165, -0.044487894],)

Red-Portal commented 1 year ago

This appears to be more complicated. It seems that gradient(y -> sum(x -> x^2, y)/10, CUDA.randn(10)) does not hit the sum(f, x) rrule, while mean(f, x) does. This is super weird. I have no idea which rrule is being hit for sum(f, x).

mcabbott commented 1 year ago

Zygote has this rule for sum(f, xs::CuArray), which takes precedence over the one here:

https://github.com/FluxML/Zygote.jl/blob/d4562e330d588cb986604bb4f1942bf9fca8ecc5/src/lib/broadcast.jl#L372-L377

Note also that sum(x -> x^2, xs) is equivalent (for real xs) to sum(abs2, xs), which has a special rule. I think that mean(abs2, xs) goes here and should call that.

(One example above uses x -> x.^2, with an extra broadcast; there is some chance that this changes which path is taken in the sum(f, xs) rule.)