FluxML / Zygote.jl

21st century AD
https://fluxml.ai/Zygote.jl/

Gradient involving `LinearAlgebra.tr` errors #1512

Open jondeuce opened 2 months ago

jondeuce commented 2 months ago

MWE:

julia> using LinearAlgebra, Zygote, CUDA

julia> Zygote.gradient(x -> tr(x), CUDA.zeros(2,2))[1] # works
2×2 Diagonal{Float32, FillArrays.Fill{Float32, 1, Tuple{Base.OneTo{Int64}}}}:
 1.0   ⋅ 
  ⋅   1.0

julia> Zygote.gradient(x -> sum(abs2, x) - tr(x), zeros(2,2))[1] # works on cpu
2×2 Matrix{Float64}:
 -1.0   0.0
  0.0  -1.0

julia> Zygote.gradient(x -> sum(abs2, x) - tr(x), CUDA.zeros(2,2))[1] # fails
ERROR: Scalar indexing is disallowed.
Invocation of getindex resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
and therefore should be avoided.

If you want to allow scalar iteration, use `allowscalar` or `@allowscalar`
to enable scalar iteration globally or for the operations in question.
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] errorscalar(op::String)
    @ GPUArraysCore ~/.julia/packages/GPUArraysCore/GMsgk/src/GPUArraysCore.jl:155
  [3] _assertscalar(op::String, behavior::GPUArraysCore.ScalarIndexing)
    @ GPUArraysCore ~/.julia/packages/GPUArraysCore/GMsgk/src/GPUArraysCore.jl:128
  [4] assertscalar(op::String)
    @ GPUArraysCore ~/.julia/packages/GPUArraysCore/GMsgk/src/GPUArraysCore.jl:116
  [5] getindex(A::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, I::Int64)
    @ GPUArrays ~/.julia/packages/GPUArrays/OKkAu/src/host/indexing.jl:48
  [6] scalar_getindex
    @ ~/.julia/packages/GPUArrays/OKkAu/src/host/indexing.jl:34 [inlined]
  [7] _getindex
    @ ~/.julia/packages/GPUArrays/OKkAu/src/host/indexing.jl:17 [inlined]
  [8] getindex
    @ ~/.julia/packages/GPUArrays/OKkAu/src/host/indexing.jl:15 [inlined]
  [9] _broadcast_getindex
    @ ./broadcast.jl:675 [inlined]
 [10] _getindex
    @ ./broadcast.jl:706 [inlined]
 [11] _getindex
    @ ./broadcast.jl:705 [inlined]
 [12] _broadcast_getindex
    @ ./broadcast.jl:681 [inlined]
 [13] getindex
    @ ./broadcast.jl:636 [inlined]
 [14] macro expansion
    @ ./broadcast.jl:1004 [inlined]
 [15] macro expansion
    @ ./simdloop.jl:77 [inlined]
 [16] copyto!
    @ ./broadcast.jl:1003 [inlined]
 [17] copyto!
    @ ./broadcast.jl:956 [inlined]
 [18] copy
    @ ./broadcast.jl:928 [inlined]
 [19] materialize
    @ ./broadcast.jl:903 [inlined]
 [20] broadcast_preserving_zero_d
    @ ./broadcast.jl:892 [inlined]
 [21] accum(x::Diagonal{Float32, FillArrays.Fill{Float32, 1, Tuple{…}}}, ys::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:25
 [22] ComposedFunction
    @ ./operators.jl:1041 [inlined]
 [23] #11
    @ ./REPL[12]:1 [inlined]
 [24] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:91
 [25] gradient(f::Function, args::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:148
 [26] top-level scope
    @ REPL[12]:1
Some type information was truncated. Use `show(err)` to see complete types.

This seems to hit Zygote's generic `accum` method for arrays (`src/lib/lib.jl:25` in the trace above), which broadcasts and falls back to scalar indexing.

There's a note about implementing the rrule for `LinearAlgebra.tr` efficiently by returning a `Fill` wrapped in a `Diagonal`, and this wrapper seems to cause problems with broadcasting. In fact, here's an even smaller MWE:

julia> Diagonal(Zygote.Fill(1f0, 2)) .+ CUDA.zeros(2, 2)
ERROR: Scalar indexing is disallowed.
Invocation of getindex resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
and therefore should be avoided.

If you want to allow scalar iteration, use `allowscalar` or `@allowscalar`
to enable scalar iteration globally or for the operations in question.
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] errorscalar(op::String)
    @ GPUArraysCore ~/.julia/packages/GPUArraysCore/GMsgk/src/GPUArraysCore.jl:155
  [3] _assertscalar(op::String, behavior::GPUArraysCore.ScalarIndexing)
    @ GPUArraysCore ~/.julia/packages/GPUArraysCore/GMsgk/src/GPUArraysCore.jl:128
  [4] assertscalar(op::String)
    @ GPUArraysCore ~/.julia/packages/GPUArraysCore/GMsgk/src/GPUArraysCore.jl:116
  [5] getindex(A::CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, I::Int64)
    @ GPUArrays ~/.julia/packages/GPUArrays/OKkAu/src/host/indexing.jl:48
  [6] scalar_getindex
    @ ~/.julia/packages/GPUArrays/OKkAu/src/host/indexing.jl:34 [inlined]
  [7] _getindex
    @ ~/.julia/packages/GPUArrays/OKkAu/src/host/indexing.jl:17 [inlined]
  [8] getindex
    @ ~/.julia/packages/GPUArrays/OKkAu/src/host/indexing.jl:15 [inlined]
  [9] _broadcast_getindex
    @ ./broadcast.jl:675 [inlined]
 [10] _getindex
    @ ./broadcast.jl:706 [inlined]
 [11] _getindex
    @ ./broadcast.jl:705 [inlined]
 [12] _broadcast_getindex
    @ ./broadcast.jl:681 [inlined]
 [13] getindex
    @ ./broadcast.jl:636 [inlined]
 [14] macro expansion
    @ ./broadcast.jl:1004 [inlined]
 [15] macro expansion
    @ ./simdloop.jl:77 [inlined]
 [16] copyto!
    @ ./broadcast.jl:1003 [inlined]
 [17] copyto!
    @ ./broadcast.jl:956 [inlined]
 [18] copy
    @ ./broadcast.jl:928 [inlined]
 [19] materialize(bc::Base.Broadcast.Broadcasted{Base.Broadcast.ArrayConflict, Nothing, typeof(+), Tuple{…}})
    @ Base.Broadcast ./broadcast.jl:903
 [20] top-level scope
    @ REPL[28]:1
Some type information was truncated. Use `show(err)` to see complete types.
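The `ArrayConflict` style in the trace above is what sends this broadcast down Base's generic path: the destination is allocated as a plain `Array` and filled by reading each argument one element at a time. The CPU-only sketch below mimics that, assuming a hypothetical `FakeDeviceMatrix` type (not part of CUDA.jl) that stands in for `CuArray` by declaring its own broadcast style; it counts scalar reads at the point where a real `CuArray` would throw the error shown above.

```julia
using LinearAlgebra

# Counter for element-by-element reads of the stand-in "device" array.
const SCALAR_READS = Ref(0)

# Hypothetical stand-in for a GPU array (NOT a CUDA.jl type): it wraps a CPU
# Matrix and declares its own broadcast style, as CuArray does.
struct FakeDeviceMatrix{T} <: AbstractMatrix{T}
    data::Matrix{T}
end
Base.size(A::FakeDeviceMatrix) = size(A.data)
function Base.getindex(A::FakeDeviceMatrix, i::Int, j::Int)
    SCALAR_READS[] += 1        # a real CuArray raises "Scalar indexing is disallowed" here
    return A.data[i, j]
end

# Dedicated 2-d broadcast style, playing the role of CUDA's own array style.
struct FakeDeviceStyle <: Broadcast.AbstractArrayStyle{2} end
FakeDeviceStyle(::Val{2}) = FakeDeviceStyle()
Base.BroadcastStyle(::Type{<:FakeDeviceMatrix}) = FakeDeviceStyle()

# A Diagonal over an ordinary Vector, standing in for Diagonal(Fill(1f0, 2)).
D = Diagonal(fill(1f0, 2))
X = FakeDeviceMatrix(zeros(Float32, 2, 2))

style = Broadcast.combine_styles(D, X)   # compare with the style in the trace
Y = D .+ X                               # goes through Base's generic copyto! loop

println("combined style: ", style)
println("result type:    ", typeof(Y))
println("scalar reads:   ", SCALAR_READS[])
```

The result is a CPU `Matrix` built by per-element reads of `X`, which is exactly the access pattern GPUArrays forbids by default.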

Package and version info:

julia> using Pkg

julia> Pkg.status()
Status `/tmp/jl_xXMchX/Project.toml`
  [052768ef] CUDA v5.3.4
  [e88e6eb3] Zygote v0.6.70
  [37e2e46d] LinearAlgebra
  [44cfe95a] Pkg v1.10.0

julia> versioninfo()
Julia Version 1.10.3
Commit 0b4590a5507 (2024-04-30 10:59 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 32 × AMD Ryzen 9 3950X 16-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 32 default, 0 interactive, 16 GC (on 32 virtual cores)
Environment:
  JULIA_EDITOR = code
  JULIA_NUM_THREADS = auto
  JULIA_CONDAPKG_BACKEND = System
  JULIA_CONDAPKG_EXE = /home/jdoucette/miniconda3/bin/conda

mcabbott commented 1 month ago

Somehow the gradient of `tr` needs to be a `Diagonal{..., CuArray{...`, made either by calling `similar` or by having a special rule for `CuArray`.

The rule in ChainRules (CR) always makes a `Diagonal{..., Array{...`, which won't be any better than the present state:

https://github.com/JuliaDiff/ChainRules.jl/blob/be9c221cf01c79f99938ad81192dab44b549c158/src/rulesets/LinearAlgebra/dense.jl#L192-L200

The rule here in Zygote, which makes a `Diagonal{..., Fill{...`, seems like a bad idea, probably premature optimisation:

https://github.com/FluxML/Zygote.jl/blob/3c3325d9987931f15bd478c932332be19c316de4/src/lib/array.jl#L597-L606
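One way to read the `similar` suggestion above, as a minimal CPU sketch: `tr_cotangent` is a hypothetical helper, not the actual rule in Zygote or ChainRules. Because the diagonal storage comes from `similar(x, n)`, it lives in the same array family as `x` (a `Vector` here, and presumably a `CuVector` for a `CuArray` input). Whether broadcasting such a `Diagonal` against a `CuArray` then avoids scalar indexing depends on GPUArrays' broadcast rules for wrapped arrays, which this sketch does not test.

```julia
using LinearAlgebra

# Hypothetical helper (not the actual rule in Zygote or ChainRules): the
# cotangent of tr(x) for an incoming sensitivity Δ. Since the gradient of tr is
# the identity, the cotangent is Δ on the diagonal and zero elsewhere.
function tr_cotangent(x::AbstractMatrix, Δ)
    n = LinearAlgebra.checksquare(x)   # tr needs a square matrix
    d = similar(x, n)                  # storage in the same array family as x
    fill!(d, Δ)                        # converts Δ to eltype(x)
    return Diagonal(d)
end

x = zeros(Float32, 3, 3)
g = tr_cotangent(x, 1)
println(typeof(g))    # the diagonal's storage type follows x
println(g .+ x)       # accumulate with a dense gradient of the same family
```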

There's a similar issue for sum, where https://github.com/FluxML/Zygote.jl/pull/1453/files wants to remove the rule which uses FillArrays to give this:

julia> Zygote.gradient(sum, [1 2; 3 4.])
(Fill(1.0, 2, 2),)

and also remove the special rule for `sum(xs::AbstractGPUArray)`, which uses `similar`. In that case, though, the CR rules always use `similar`.
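A minimal sketch of the `similar`-based style of `sum` gradient discussed above, assuming a hypothetical helper `sum_cotangent` (not the code in PR #1453 or in ChainRules): instead of a lazy `Fill`, it allocates an array of the same family as `x`, so later accumulation is a plain broadcast between arrays of one type. Taking the element type from `Δ` is a choice of this sketch, so that integer inputs still get floating-point gradients.

```julia
# Hypothetical helper (not the code in the PR or in ChainRules): the cotangent
# of sum(x) for sensitivity Δ. The gradient of sum is all ones, so this is an
# array shaped like x, allocated with `similar` and filled with Δ.
sum_cotangent(x::AbstractArray, Δ) = fill!(similar(x, typeof(Δ)), Δ)

g = sum_cotangent([1 2; 3 4.], 1.0)
println(typeof(g))   # dense array in the same family as x, not a lazy Fill
println(g)
```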