jonniedie / ComponentArrays.jl

Arrays with arbitrarily nested named components.
MIT License
286 stars 34 forks source link

add missing LinearAlgebra methods on GPU #227

Closed — vpuri3 closed this issue 9 months ago.

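vpuri3 commented 9 months ago

`dot(a, b)` and `a' * b` fail when one argument is a CUDA-backed `ComponentVector`. The call is routed to the CPU BLAS wrapper `LinearAlgebra.BLAS.dot`, which has no method accepting `CuPtr` device pointers: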
julia> using ComponentArrays, CUDA, LinearAlgebra

julia> a, b = ComponentArray((; a = ones(2))) |> cu, CUDA.ones(2)                                                                                                  
((a = Float32[1.0, 1.0]), Float32[1.0, 1.0])                                                                                                                       

julia> dot(a, b)  # dot(b, a) fails the same way
ERROR: MethodError: no method matching dot(::Int64, ::CuPtr{Float32}, ::Int64, ::CuPtr{Float32}, ::Int64)

Closest candidates are:
  dot(::Integer, ::Union{Ptr{Float32}, AbstractArray{Float32}}, ::Integer, ::Union{Ptr{Float32}, AbstractArray{Float32}}, ::Integer)
   @ LinearAlgebra ~/.julia/juliaup/julia-1.10.0-beta3+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/blas.jl:344
  dot(::Integer, ::Union{Ptr{Float64}, AbstractArray{Float64}}, ::Integer, ::Union{Ptr{Float64}, AbstractArray{Float64}}, ::Integer)
   @ LinearAlgebra ~/.julia/juliaup/julia-1.10.0-beta3+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/blas.jl:344

Stacktrace:
 [1] dot(x::ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(a = 1:2,)}}}, y::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
   @ LinearAlgebra.BLAS ~/.julia/juliaup/julia-1.10.0-beta3+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/blas.jl:395
 [2] dot(x::ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(a = 1:2,)}}}, y::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
   @ LinearAlgebra ~/.julia/juliaup/julia-1.10.0-beta3+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:15
 [3] top-level scope
   @ REPL[83]:1
 [4] top-level scope
   @ ~/.julia/packages/CUDA/nbRJk/src/initialization.jl:205

julia> a' * b  # b' * a fails the same way
ERROR: MethodError: no method matching dot(::Int64, ::CuPtr{Float32}, ::Int64, ::CuPtr{Float32}, ::Int64)                                                          

Closest candidates are:                                                                                                                                            
  dot(::Integer, ::Union{Ptr{Float32}, AbstractArray{Float32}}, ::Integer, ::Union{Ptr{Float32}, AbstractArray{Float32}}, ::Integer)                               
   @ LinearAlgebra ~/.julia/juliaup/julia-1.10.0-beta3+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/blas.jl:344                                      
  dot(::Integer, ::Union{Ptr{Float64}, AbstractArray{Float64}}, ::Integer, ::Union{Ptr{Float64}, AbstractArray{Float64}}, ::Integer)                               
   @ LinearAlgebra ~/.julia/juliaup/julia-1.10.0-beta3+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/blas.jl:344                                      

Stacktrace:                                                                                                                                                        
 [1] dot(x::ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(a = 1:2,)}}}, y::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})          
   @ LinearAlgebra.BLAS ~/.julia/juliaup/julia-1.10.0-beta3+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/blas.jl:395                                 
 [2] dot(x::ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{(a = 1:2,)}}}, y::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})          
   @ LinearAlgebra ~/.julia/juliaup/julia-1.10.0-beta3+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:15                                     
 [3] *(u::Adjoint{Float32, ComponentVector{Float32, CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Axis{…}}}}, v::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})   
   @ LinearAlgebra ~/.julia/juliaup/julia-1.10.0-beta3+0.x64.linux.gnu/share/julia/stdlib/v1.10/LinearAlgebra/src/adjtrans.jl:462                                  
 [4] top-level scope                                                                                                                                               
   @ REPL[85]:1                                                                                                                                                    
 [5] top-level scope                                                                                                                                               
   @ ~/.julia/packages/CUDA/nbRJk/src/initialization.jl:205                                                                                                        
Some type information was truncated. Use `show(err)` to see complete types.