jonniedie / ComponentArrays.jl

Arrays with arbitrarily nested named components.
MIT License
286 stars · 34 forks · source link

Scalar indexing on GPU when computing `Zygote.gradient` of `dot(x::CA, x::CA)` #236

Status: Open — opened by vpuri3 7 months ago

vpuri3 commented 7 months ago
julia> Zygote.gradient(x -> dot(x, x), CUDA.ones(4)) # works
(Float32[2.0, 2.0, 2.0, 2.0],)
julia> Zygote.gradient(x -> dot(x, x), (;x=ones(4)) |> ComponentArray |> cu)                                                                                                                          
ERROR: Scalar indexing is disallowed.                                                              
Invocation of getindex resulted in scalar indexing of a GPU array.                                                                                                                                    
This is typically caused by calling an iterating implementation of a method.                                                                                                                          
Such implementations *do not* execute on the GPU, but very slowly on the CPU,                                                                                                                         
and therefore are only permitted from the REPL for prototyping purposes.                                                                                                                              
If you did intend to index this array, annotate the caller with @allowscalar.                                                                                                                         
Stacktrace:                                                                                        
  [1] error(s::String)                           
    @ Base ./error.jl:35                                                                           
  [2] assertscalar(op::String)                                                                     
    @ GPUArraysCore ~/.julia/packages/GPUArraysCore/uOYfN/src/GPUArraysCore.jl:103                                                                                                                    
  [3] getindex                                   
    @ ~/.julia/packages/GPUArrays/dAUOE/src/host/indexing.jl:48 [inlined]                                                                                                                             
  [4] getindex                                                                                                                                                                                        
    @ ~/.julia/dev/ComponentArrays/src/array_interface.jl:94 [inlined]                                                                                                                                
  [5] getindex                                   
    @ ~/.julia/packages/OffsetArrays/0MOrf/src/OffsetArrays.jl:436 [inlined]                                                                                                                          
  [6] _broadcast_getindex                                                                                                                                                                             
    @ ./broadcast.jl:675 [inlined]                                                                                                                                                                    
  [7] _getindex                                                                                                                                                                                       
    @ ./broadcast.jl:705 [inlined]                                                                                                                                                                    
  [8] _broadcast_getindex                        
    @ ./broadcast.jl:681 [inlined]                                                                 
  [9] getindex                                   
    @ ./broadcast.jl:636 [inlined]                                                                                                                                                                    
 [10] macro expansion                                                                                                                                                                                 
    @ ./broadcast.jl:1004 [inlined]                                                                
 [11] macro expansion                                                                              
    @ ./simdloop.jl:77 [inlined]                                                                   
 [12] copyto!                                    
    @ ./broadcast.jl:1003 [inlined]                                                                
 [13] copyto!                                    
    @ ./broadcast.jl:956 [inlined]                                                                 
 [14] copy                                       
    @ ./broadcast.jl:928 [inlined]                                                                 
 [15] materialize                                
    @ ./broadcast.jl:903 [inlined]                                                                 
 [16] accum(x::OffsetArrays.OffsetVector{…}, ys::OffsetArrays.OffsetVector{…})                                                                                                                        
    @ Zygote ~/.julia/packages/Zygote/YYT6v/src/lib/lib.jl:25                                                                                                                                         
 [17] accum                                      
    @ ComponentArraysZygoteExt ~/.julia/dev/ComponentArrays/ext/ComponentArraysZygoteExt.jl:10 [inlined]                                                                                              
 [18] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{var"#64#65", ComponentVector{…}}, Tuple{Zygote.ZBack{…}}}})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface.jl:45                                                                                                                              
 [19] gradient(f::Function, args::ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{…}}})                                                                               
    @ Zygote ~/.julia/packages/Zygote/YYT6v/src/compiler/interface.jl:97